detection.py (82648B)
1 # 2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 # SPDX-License-Identifier: Apache-2.0 4 # 5 """Classes and functions related to MIMO channel detection""" 6 7 import warnings 8 import numpy as np 9 import tensorflow as tf 10 from tensorflow.keras.layers import Layer 11 from sionna.utils import expand_to_rank, matrix_sqrt_inv, flatten_last_dims, flatten_dims, split_dim, insert_dims, hard_decisions 12 from sionna.mapping import Constellation, SymbolLogits2LLRs, LLRs2SymbolLogits, PAM2QAM, Demapper, SymbolDemapper, SymbolInds2Bits, DemapperWithPrior, SymbolLogits2Moments 13 from sionna.mimo.utils import complex2real_channel, whiten_channel, List2LLR, List2LLRSimple, complex2real_matrix, complex2real_vector, real2complex_vector 14 from sionna.mimo.equalization import lmmse_equalizer, zf_equalizer, mf_equalizer 15 16 class LinearDetector(Layer): 17 # pylint: disable=line-too-long 18 r"""LinearDetector(equalizer, output, demapping_method, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, dtype=tf.complex64, **kwargs) 19 20 Convenience class that combines an equalizer, 21 such as :func:`~sionna.mimo.lmmse_equalizer`, and a :class:`~sionna.mapping.Demapper`. 22 23 Parameters 24 ---------- 25 equalizer : str, one of ["lmmse", "zf", "mf"], or an equalizer function 26 The equalizer to be used. Either one of the existing equalizers 27 :func:`~sionna.mimo.lmmse_equalizer`, :func:`~sionna.mimo.zf_equalizer`, or 28 :func:`~sionna.mimo.mf_equalizer` can be used, or a custom equalizer 29 callable provided that has the same input/output specification. 30 31 output : One of ["bit", "symbol"], str 32 The type of output, either LLRs on bits or logits on constellation symbols. 33 34 demapping_method : One of ["app", "maxlog"], str 35 The demapping method used. 
36 37 constellation_type : One of ["qam", "pam", "custom"], str 38 For "custom", an instance of :class:`~sionna.mapping.Constellation` 39 must be provided. 40 41 num_bits_per_symbol : int 42 The number of bits per constellation symbol, e.g., 4 for QAM16. 43 Only required for ``constellation_type`` in ["qam", "pam"]. 44 45 constellation : Constellation 46 An instance of :class:`~sionna.mapping.Constellation` or `None`. 47 In the latter case, ``constellation_type`` 48 and ``num_bits_per_symbol`` must be provided. 49 50 hard_out : bool 51 If `True`, the detector computes hard-decided bit values or 52 constellation point indices instead of soft-values. 53 Defaults to `False`. 54 55 dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype) 56 The dtype of ``y``. Defaults to tf.complex64. 57 The output dtype is the corresponding real dtype (tf.float32 or tf.float64). 58 59 Input 60 ------ 61 (y, h, s) : 62 Tuple: 63 64 y : [...,M], tf.complex 65 1+D tensor containing the received signals 66 67 h : [...,M,num_streams], tf.complex 68 2+D tensor containing the channel matrices 69 70 s : [...,M,M], tf.complex 71 2+D tensor containing the noise covariance matrices 72 73 Output 74 ------ 75 One of: 76 77 : [..., num_streams, num_bits_per_symbol], tf.float 78 LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"` 79 80 : [..., num_streams, num_points], tf.float or [..., num_streams], tf.int 81 Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"` 82 Hard-decisions correspond to the symbol indices. 83 84 Note 85 ---- 86 If you want to use this layer in Graph mode with XLA, i.e., within 87 a function that is decorated with ``@tf.function(jit_compile=True)``, 88 you might need to set ``sionna.Config.xla_compat=true``. This depends on the 89 chosen equalizer function. See :py:attr:`~sionna.Config.xla_compat`. 
90 """ 91 def __init__(self, 92 equalizer, 93 output, 94 demapping_method, 95 constellation_type=None, 96 num_bits_per_symbol=None, 97 constellation=None, 98 hard_out=False, 99 dtype=tf.complex64, 100 **kwargs): 101 super().__init__(dtype=dtype, **kwargs) 102 self._output = output 103 self._hard_out = hard_out 104 105 # Determine the equalizer to use 106 if isinstance(equalizer, str): 107 assert equalizer in ["lmmse", "zf", "mf"], "Unknown equalizer." 108 if equalizer=="lmmse": 109 self._equalizer = lmmse_equalizer 110 elif equalizer=="zf": 111 self._equalizer = zf_equalizer 112 else: 113 self._equalizer = mf_equalizer 114 else: 115 self._equalizer = equalizer 116 117 assert output in ("bit", "symbol"), "Unknown output" 118 assert demapping_method in ("app","maxlog"), "Unknown demapping method" 119 120 constellation = Constellation.create_or_check_constellation( 121 constellation_type, 122 num_bits_per_symbol, 123 constellation, 124 dtype=dtype) 125 self._constellation = constellation 126 127 # Determine the demapper to use 128 if output=="bit": 129 self._demapper = Demapper(demapping_method, 130 constellation=constellation, 131 hard_out=hard_out, 132 dtype=dtype) 133 else: 134 self._demapper = SymbolDemapper(constellation=constellation, 135 hard_out=hard_out, 136 dtype=dtype) 137 138 def call(self, inputs): 139 x_hat, no_eff = self._equalizer(*inputs) 140 z = self._demapper([x_hat, no_eff]) 141 142 # Reshape to the expected output shape 143 num_streams = tf.shape(inputs[1])[-1] 144 if self._output == 'bit': 145 num_bits_per_symbol = self._constellation.num_bits_per_symbol 146 z = split_dim(z, [num_streams, num_bits_per_symbol], tf.rank(z)-1) 147 148 return z 149 150 class MaximumLikelihoodDetector(Layer): 151 # pylint: disable=line-too-long 152 r""" 153 MaximumLikelihoodDetector(output, demapping_method, num_streams, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, with_prior=False, dtype=tf.complex64, **kwargs) 154 155 MIMO 
maximum-likelihood (ML) detector. 156 If the ``with_prior`` flag is set, prior knowledge on the bits or constellation points is assumed to be available. 157 158 This layer implements MIMO maximum-likelihood (ML) detection assuming the 159 following channel model: 160 161 .. math:: 162 \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n} 163 164 where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector, 165 :math:`\mathbf{x}\in\mathcal{C}^K` is the vector of transmitted symbols which 166 are uniformly and independently drawn from the constellation :math:`\mathcal{C}`, 167 :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix, 168 and :math:`\mathbf{n}\in\mathbb{C}^M` is a complex Gaussian noise vector. 169 It is assumed that :math:`\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}` and 170 :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`, 171 where :math:`\mathbf{S}` has full rank. 172 If the ``with_prior`` flag is set, it is assumed that prior information of the transmitted signal :math:`\mathbf{x}` is available, 173 provided either as LLRs on the bits mapped onto :math:`\mathbf{x}` or as logits on the individual 174 constellation points forming :math:`\mathbf{x}`. 175 176 Prior to demapping, the received signal is whitened: 177 178 .. math:: 179 \tilde{\mathbf{y}} &= \mathbf{S}^{-\frac{1}{2}}\mathbf{y}\\ 180 &= \mathbf{S}^{-\frac{1}{2}}\mathbf{H}\mathbf{x} + \mathbf{S}^{-\frac{1}{2}}\mathbf{n}\\ 181 &= \tilde{\mathbf{H}}\mathbf{x} + \tilde{\mathbf{n}} 182 183 The layer can compute ML detection of symbols or bits with either 184 soft- or hard-decisions. Note that decisions are computed symbol-/bit-wise 185 and not jointly for the entire vector :math:`\textbf{x}` (or the underlying vector 186 of bits). 187 188 **\ML detection of bits:** 189 190 Soft-decisions on bits are called log-likelihood ratios (LLR). 
With the "app" demapping method, the LLR for the :math:`i\text{th}` bit
    of the :math:`k\text{th}` user is then computed according to

    .. math::
        \begin{align}
            LLR(k,i)&= \ln\left(\frac{\Pr\left(b_{k,i}=1\lvert \mathbf{y},\mathbf{H}\right)}{\Pr\left(b_{k,i}=0\lvert \mathbf{y},\mathbf{H}\right)}\right)\\
                    &=\ln\left(\frac{
                    \sum_{\mathbf{x}\in\mathcal{C}_{k,i,1}} \exp\left(
                        -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                        \right) \Pr\left( \mathbf{x} \right)
                    }{
                    \sum_{\mathbf{x}\in\mathcal{C}_{k,i,0}} \exp\left(
                        -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                        \right) \Pr\left( \mathbf{x} \right)
                    }\right)
        \end{align}

    where :math:`\mathcal{C}_{k,i,1}` and :math:`\mathcal{C}_{k,i,0}` are the
    sets of vectors of constellation points for which the :math:`i\text{th}` bit
    of the :math:`k\text{th}` user is equal to 1 and 0, respectively.
    :math:`\Pr\left( \mathbf{x} \right)` is the prior distribution of the vector of
    constellation points :math:`\mathbf{x}`. Assuming that the constellation points and
    bit levels are independent, it is computed from the prior of the bits according to

    .. math::
        \Pr\left( \mathbf{x} \right) = \prod_{k=1}^K \prod_{i=1}^{I} \sigma \left( \ell_{k,i}\left(\mathbf{x}\right) LLR_p(k,i) \right)

    where :math:`LLR_p(k,i)` is the prior knowledge of the :math:`i\text{th}` bit of the
    :math:`k\text{th}` user given as an LLR and which is set to :math:`0` if no prior knowledge is assumed to be available,
    :math:`\ell_{k,i}\left(\mathbf{x}\right)` equals :math:`+1` if the :math:`i\text{th}` bit of the
    :math:`k\text{th}` entry of :math:`\mathbf{x}` is 1 and :math:`-1` otherwise,
    and :math:`\sigma\left(\cdot\right)` is the sigmoid function.
    The definition of the LLR has been chosen such that it is equivalent to that of a logit. This is
    different from many textbooks in communications, where the LLR is
    defined as :math:`LLR(k,i) = \ln\left(\frac{\Pr\left(b_{k,i}=0\lvert \mathbf{y},\mathbf{H}\right)}{\Pr\left(b_{k,i}=1\lvert \mathbf{y},\mathbf{H}\right)}\right)`.
224 225 With the "maxlog" demapping method, the LLR for the :math:`i\text{th}` bit 226 of the :math:`k\text{th}` user is approximated like 227 228 .. math:: 229 \begin{align} 230 LLR(k,i) \approx&\ln\left(\frac{ 231 \max_{\mathbf{x}\in\mathcal{C}_{k,i,1}} \left( \exp\left( 232 -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 233 \right) \Pr\left( \mathbf{x} \right) \right) 234 }{ 235 \max_{\mathbf{x}\in\mathcal{C}_{k,i,0}} \left( \exp\left( 236 -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 237 \right) \Pr\left( \mathbf{x} \right) \right) 238 }\right)\\ 239 = &\min_{\mathbf{x}\in\mathcal{C}_{k,i,0}} \left( \left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 - \ln \left(\Pr\left( \mathbf{x} \right) \right) \right) - 240 \min_{\mathbf{x}\in\mathcal{C}_{k,i,1}} \left( \left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 - \ln \left( \Pr\left( \mathbf{x} \right) \right) \right). 241 \end{align} 242 243 **ML detection of symbols:** 244 245 Soft-decisions on symbols are called logits (i.e., unnormalized log-probability). 246 247 With the “app” demapping method, the logit for the 248 constellation point :math:`c \in \mathcal{C}` of the :math:`k\text{th}` user is computed according to 249 250 .. math:: 251 \begin{align} 252 \text{logit}(k,c) &= \ln\left(\sum_{\mathbf{x} : x_k = c} \exp\left( 253 -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 254 \right)\Pr\left( \mathbf{x} \right)\right). 255 \end{align} 256 257 With the "maxlog" demapping method, the logit for the constellation point :math:`c \in \mathcal{C}` 258 of the :math:`k\text{th}` user is approximated like 259 260 .. math:: 261 \text{logit}(k,c) \approx \max_{\mathbf{x} : x_k = c} \left( 262 -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 + \ln \left( \Pr\left( \mathbf{x} \right) \right) 263 \right). 
264 265 When hard decisions are requested, this layer returns for the :math:`k` th stream 266 267 .. math:: 268 \hat{c}_k = \underset{c \in \mathcal{C}}{\text{argmax}} \left( \sum_{\mathbf{x} : x_k = c} \exp\left( 269 -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 270 \right)\Pr\left( \mathbf{x} \right) \right) 271 272 where :math:`\mathcal{C}` is the set of constellation points. 273 274 Parameters 275 ----------- 276 output : One of ["bit", "symbol"], str 277 The type of output, either LLRs on bits or logits on constellation symbols. 278 279 demapping_method : One of ["app", "maxlog"], str 280 The demapping method used. 281 282 num_streams : tf.int 283 Number of transmitted streams 284 285 constellation_type : One of ["qam", "pam", "custom"], str 286 For "custom", an instance of :class:`~sionna.mapping.Constellation` 287 must be provided. 288 289 num_bits_per_symbol : int 290 The number of bits per constellation symbol, e.g., 4 for QAM16. 291 Only required for ``constellation_type`` in ["qam", "pam"]. 292 293 constellation : Constellation 294 An instance of :class:`~sionna.mapping.Constellation` or `None`. 295 In the latter case, ``constellation_type`` 296 and ``num_bits_per_symbol`` must be provided. 297 298 hard_out : bool 299 If `True`, the detector computes hard-decided bit values or 300 constellation point indices instead of soft-values. 301 Defaults to `False`. 302 303 with_prior : bool 304 If `True`, it is assumed that prior knowledge on the bits or constellation points is available. 305 This prior information is given as LLRs (for bits) or log-probabilities (for constellation points) as an 306 additional input to the layer. 307 Defaults to `False`. 308 309 dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype) 310 The dtype of ``y``. Defaults to tf.complex64. 311 The output dtype is the corresponding real dtype (tf.float32 or tf.float64). 
312 313 Input 314 ------ 315 (y, h, s) or (y, h, prior, s) : 316 Tuple: 317 318 y : [...,M], tf.complex 319 1+D tensor containing the received signals. 320 321 h : [...,M,num_streams], tf.complex 322 2+D tensor containing the channel matrices. 323 324 prior : [...,num_streams,num_bits_per_symbol] or [...,num_streams,num_points], tf.float 325 Prior of the transmitted signals. 326 If ``output`` equals "bit", then LLRs of the transmitted bits are expected. 327 If ``output`` equals "symbol", then logits of the transmitted constellation points are expected. 328 Only required if the ``with_prior`` flag is set. 329 330 s : [...,M,M], tf.complex 331 2+D tensor containing the noise covariance matrices. 332 333 Output 334 ------ 335 One of: 336 337 : [..., num_streams, num_bits_per_symbol], tf.float 338 LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`. 339 340 : [..., num_streams, num_points], tf.float or [..., num_streams], tf.int 341 Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`. 342 Hard-decisions correspond to the symbol indices. 343 344 Note 345 ---- 346 If you want to use this layer in Graph mode with XLA, i.e., within 347 a function that is decorated with ``@tf.function(jit_compile=True)``, 348 you must set ``sionna.Config.xla_compat=true``. 349 See :py:attr:`~sionna.Config.xla_compat`. 
350 """ 351 352 def __init__(self, 353 output, 354 demapping_method, 355 num_streams, 356 constellation_type=None, 357 num_bits_per_symbol=None, 358 constellation=None, 359 hard_out=False, 360 with_prior=False, 361 dtype=tf.complex64, 362 **kwargs): 363 super().__init__(dtype=dtype, **kwargs) 364 365 assert dtype in [tf.complex64, tf.complex128],\ 366 "dtype must be tf.complex64 or tf.complex128" 367 368 assert output in ("bit", "symbol"), "Unknown output" 369 370 assert demapping_method in ("app","maxlog"), "Unknown demapping method" 371 372 self._output = output 373 self._demapping_method = demapping_method 374 self._hard_out = hard_out 375 self._with_prior = with_prior 376 377 # Determine the reduce function for LLR computation 378 if self._demapping_method == "app": 379 self._reduce = tf.reduce_logsumexp 380 else: 381 self._reduce = tf.reduce_max 382 383 # Create constellation object 384 self._constellation = Constellation.create_or_check_constellation( 385 constellation_type, 386 num_bits_per_symbol, 387 constellation, 388 dtype=dtype) 389 390 # Utility function to compute 391 # vecs : [num_vecs, num_streams] The list of all possible transmitted vectors. 392 # vecs_ind : [num_vecs, num_streams] The list of all possible transmitted vectors 393 # constellation indices 394 # c : [num_vecs/num_points, num_streams, num_points] Which is such that `c[:,k,s]` 395 # gives the symbol indices in the first dimension of `vecs` for which 396 # the `k`th stream transmitted the `s`th constellation point. 
397 vecs, vecs_ind, c = self._build_vecs(num_streams) 398 self._vecs = tf.cast(vecs, dtype) 399 self._vecs_ind = tf.cast(vecs_ind, tf.int32) 400 self._c = tf.cast(c, tf.int32) 401 402 if output == 'bit': 403 num_bits_per_symbol = self._constellation.num_bits_per_symbol 404 self._logits2llr = SymbolLogits2LLRs( 405 method=demapping_method, 406 num_bits_per_symbol=num_bits_per_symbol, 407 hard_out=hard_out, 408 dtype=dtype.real_dtype, 409 **kwargs) 410 self._llrs2logits = LLRs2SymbolLogits( 411 num_bits_per_symbol=num_bits_per_symbol, 412 hard_out=False, 413 dtype=dtype.real_dtype, 414 **kwargs) 415 416 @property 417 def constellation(self): 418 return self._constellation 419 420 def _build_vecs(self, num_streams): 421 """ 422 Utility function for building the list of all possible transmitted 423 vectors of constellation points and the symbol indices corresponding to 424 all possibly transmitted constellation points for every stream. 425 426 Input 427 ------ 428 num_streams : int 429 Number of transmitted streams 430 431 Output 432 ------- 433 vecs : [num_vecs, K], tf.complex 434 List of all possible transmitted vectors. 435 436 c : [num_vecs/num_points, num_streams, num_points], int 437 `c[:,k,s]` gives the symbol indices in the first dimension of `vecs` 438 for which the `k`th stream transmitted the `s`th symbol. 439 """ 440 441 points = self._constellation.points 442 num_points = points.shape[0] 443 444 # Recursive function for generating all possible transmitted 445 # vector of symbols and indices 446 # `n` is the remaining number of stream to process 447 def _build_vecs_(n): 448 if n == 1: 449 # If there is a single stream, then the list of possibly 450 # transmitted vectors corresponds to the constellation points. 451 # No recusrion is needed. 
452 vecs = np.expand_dims(points, axis=1) 453 vecs_ind = np.expand_dims(np.arange(num_points), axis=1) 454 else: 455 # If the number of streams is `n >= 2` streams, then the list 456 # of possibly transmitted vectors is 457 # 458 # [c_1 v , c_2 v, ..., c_N v] 459 # 460 # where `[c_1, ..., c_N]` is the constellation of size N, and 461 # `v` is the list of possible vectors for `n-1` streams. 462 # This list has therefore length `N x len(v)`. 463 # 464 # Building the list for `n-1` streams, recursively. 465 v, vi = _build_vecs_(n-1) 466 # Building the list of `n` streams by appending the 467 # constellation points. 468 vecs = [] 469 vecs_ind = [] 470 for i,p in enumerate(points): 471 vecs.append(np.concatenate([np.full([v.shape[0], 1], p), 472 v], axis=1)) 473 vecs_ind.append(np.concatenate([np.full([v.shape[0], 1], i), 474 vi], axis=1)) 475 vecs = np.concatenate(vecs, axis=0) 476 vecs_ind = np.concatenate(vecs_ind, axis=0) 477 return vecs, vecs_ind 478 479 # Building the list of possible vectors for the `k` streams. 480 # [num_vecs, K] 481 vecs, vecs_ind = _build_vecs_(num_streams) 482 483 tx_ind = np.arange(num_streams) 484 tx_ind = np.expand_dims(tx_ind, axis=0) 485 tx_ind = np.tile(tx_ind, [vecs_ind.shape[0], 1]) 486 vecs_ind = np.stack([tx_ind, vecs_ind], axis=-1) 487 488 # Compute symbol indices for every stream. 489 # For every constellation point `p` and for every stream `j`, we gather 490 # the list of vector indices from `vecs` corresponding the vectors for 491 # which the `jth` stream transmitted `p`. 
492 # [num_vecs/num_points, num_streams, num_points] 493 c = [] 494 for p in points: 495 c_ = [] 496 for j in range(num_streams): 497 c_.append(np.where(vecs[:,j]==p)[0]) 498 c_ = np.stack(c_, axis=-1) 499 c.append(c_) 500 c = np.stack(c, axis=-1) 501 502 return vecs, vecs_ind, c 503 504 def call(self, inputs): 505 if self._with_prior: 506 y, h, prior, s = inputs 507 508 # If operating on bits, computes prior on symbols from the prior 509 # on bits 510 if self._output == 'bit': 511 # [..., K, num_points] 512 prior = self._llrs2logits(prior) 513 else: 514 y, h, s = inputs 515 516 # Compute square-root of interference covariance matrix 517 s_inv = matrix_sqrt_inv(s) 518 519 # Whiten the observation 520 y = tf.expand_dims(y, -1) 521 y = tf.squeeze(tf.matmul(s_inv, y), axis=-1) 522 523 # Compute channel after whitening 524 h = tf.matmul(s_inv, h) 525 526 # Add extra dims for broadcasting with the dimensions corresponding 527 # to all possible transmimtted vectors 528 # Shape: [..., 1, M, K] 529 h = tf.expand_dims(h, axis=-3) 530 531 # Add extra dims for broadcasting with the dimensions corresponding 532 # to all possible transmimtted vectors 533 # Shape: [..., 1, M] 534 y = tf.expand_dims(y, axis=-2) 535 536 # Reshape list of all possible vectors from 537 # [num_vecs, K] 538 # to 539 # [1,...,1, num_vecs, K, 1] 540 vecs = self._vecs 541 vecs = tf.expand_dims(vecs, axis=-1) 542 vecs = expand_to_rank(vecs, tf.rank(h), 0) 543 544 # Compute exponents 545 # [..., num_vecs] 546 diff = y - tf.squeeze(h@vecs, axis=-1) 547 exponents = -tf.reduce_sum(tf.square(tf.abs(diff)), axis=-1) 548 549 # Add prior 550 if self._with_prior: 551 # [..., num_vecs, K] 552 prior = expand_to_rank(prior, tf.rank(exponents), axis=0) 553 prior_rank = tf.rank(prior) 554 transpose_ind = tf.concat([[prior_rank-2, prior_rank-1], 555 tf.range(prior_rank-2)], axis=0) 556 prior = tf.transpose(prior, transpose_ind) 557 prior = tf.gather_nd(prior, self._vecs_ind) 558 transpose_ind = tf.concat([ tf.range(2, 
prior_rank), 559 [0, 1]], axis=0) 560 prior = tf.transpose(prior, transpose_ind) 561 # [..., num_vecs] 562 prior = tf.reduce_sum(prior, axis=-1) 563 exponents = exponents + prior 564 565 # Gather exponents for all symbols 566 # [..., num_vecs/num_points, K, num_points] 567 exp = tf.gather(exponents, self._c, axis=-1) 568 569 # Compute logits on constellation points 570 # [..., K, num_points] 571 logits = self._reduce(exp, axis=-3) 572 573 if self._output == 'bit': 574 # Compute LLRs or hard decisions 575 return self._logits2llr(logits) 576 else: 577 if self._hard_out: 578 return tf.argmax(logits, axis=-1, output_type=tf.int32) 579 else: 580 return logits 581 582 class MaximumLikelihoodDetectorWithPrior(MaximumLikelihoodDetector): 583 # pylint: disable=line-too-long 584 r""" 585 MaximumLikelihoodDetectorWithPrior(output, demapping_method, num_streams, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, dtype=tf.complex64, **kwargs) 586 587 MIMO maximum-likelihood (ML) detector, assuming prior 588 knowledge on the bits or constellation points is available. 589 590 This class is deprecated as the functionality has been integrated 591 into :class:`~sionna.mimo.MaximumLikelihoodDetector`. 592 593 This layer implements MIMO maximum-likelihood (ML) detection assuming the 594 following channel model: 595 596 .. math:: 597 \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n} 598 599 where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector, 600 :math:`\mathbf{x}\in\mathcal{C}^K` is the vector of transmitted symbols which 601 are uniformly and independently drawn from the constellation :math:`\mathcal{C}`, 602 :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix, 603 and :math:`\mathbf{n}\in\mathbb{C}^M` is a complex Gaussian noise vector. 
604 It is assumed that :math:`\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}` and 605 :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`, 606 where :math:`\mathbf{S}` has full rank. 607 It is assumed that prior information of the transmitted signal :math:`\mathbf{x}` is available, 608 provided either as LLRs on the bits modulated onto :math:`\mathbf{x}` or as logits on the individual 609 constellation points forming :math:`\mathbf{x}`. 610 611 Prior to demapping, the received signal is whitened: 612 613 .. math:: 614 \tilde{\mathbf{y}} &= \mathbf{S}^{-\frac{1}{2}}\mathbf{y}\\ 615 &= \mathbf{S}^{-\frac{1}{2}}\mathbf{H}\mathbf{x} + \mathbf{S}^{-\frac{1}{2}}\mathbf{n}\\ 616 &= \tilde{\mathbf{H}}\mathbf{x} + \tilde{\mathbf{n}} 617 618 The layer can compute ML detection of symbols or bits with either 619 soft- or hard-decisions. Note that decisions are computed symbol-/bit-wise 620 and not jointly for the entire vector :math:`\textbf{x}` (or the underlying vector 621 of bits). 622 623 **\ML detection of bits:** 624 625 Soft-decisions on bits are called log-likelihood ratios (LLR). 626 With the “app” demapping method, the LLR for the :math:`i\text{th}` bit 627 of the :math:`k\text{th}` user is then computed according to 628 629 .. 
math::
        \begin{align}
            LLR(k,i)&= \ln\left(\frac{\Pr\left(b_{k,i}=1\lvert \mathbf{y},\mathbf{H}\right)}{\Pr\left(b_{k,i}=0\lvert \mathbf{y},\mathbf{H}\right)}\right)\\
                    &=\ln\left(\frac{
                    \sum_{\mathbf{x}\in\mathcal{C}_{k,i,1}} \exp\left(
                        -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                        \right) \Pr\left( \mathbf{x} \right)
                    }{
                    \sum_{\mathbf{x}\in\mathcal{C}_{k,i,0}} \exp\left(
                        -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                        \right) \Pr\left( \mathbf{x} \right)
                    }\right)
        \end{align}

    where :math:`\mathcal{C}_{k,i,1}` and :math:`\mathcal{C}_{k,i,0}` are the
    sets of vectors of constellation points for which the :math:`i\text{th}` bit
    of the :math:`k\text{th}` user is equal to 1 and 0, respectively.
    :math:`\Pr\left( \mathbf{x} \right)` is the prior distribution of the vector of
    constellation points :math:`\mathbf{x}`. Assuming that the constellation points and
    bit levels are independent, it is computed from the prior of the bits according to

    .. math::
        \Pr\left( \mathbf{x} \right) = \prod_{k=1}^K \prod_{i=1}^{I} \sigma \left( \ell_{k,i}\left(\mathbf{x}\right) LLR_p(k,i) \right)

    where :math:`LLR_p(k,i)` is the prior knowledge of the :math:`i\text{th}` bit of the
    :math:`k\text{th}` user given as an LLR, :math:`\ell_{k,i}\left(\mathbf{x}\right)` equals :math:`+1`
    if the :math:`i\text{th}` bit of the :math:`k\text{th}` entry of :math:`\mathbf{x}` is 1 and :math:`-1` otherwise,
    and :math:`\sigma\left(\cdot\right)` is the sigmoid function.
    The definition of the LLR has been chosen such that it is equivalent to that of a logit. This is
    different from many textbooks in communications, where the LLR is
    defined as :math:`LLR(k,i) = \ln\left(\frac{\Pr\left(b_{k,i}=0\lvert \mathbf{y},\mathbf{H}\right)}{\Pr\left(b_{k,i}=1\lvert \mathbf{y},\mathbf{H}\right)}\right)`.

    With the "maxlog" demapping method, the LLR for the :math:`i\text{th}` bit
    of the :math:`k\text{th}` user is approximated like

    ..
math:: 663 \begin{align} 664 LLR(k,i) \approx&\ln\left(\frac{ 665 \max_{\mathbf{x}\in\mathcal{C}_{k,i,1}} \left( \exp\left( 666 -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 667 \right) \Pr\left( \mathbf{x} \right) \right) 668 }{ 669 \max_{\mathbf{x}\in\mathcal{C}_{k,i,0}} \left( \exp\left( 670 -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 671 \right) \Pr\left( \mathbf{x} \right) \right) 672 }\right)\\ 673 = &\min_{\mathbf{x}\in\mathcal{C}_{k,i,0}} \left( \left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 - \ln \left(\Pr\left( \mathbf{x} \right) \right) \right) - 674 \min_{\mathbf{x}\in\mathcal{C}_{k,i,1}} \left( \left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 - \ln \left( \Pr\left( \mathbf{x} \right) \right) \right). 675 \end{align} 676 677 **ML detection of symbols:** 678 679 Soft-decisions on symbols are called logits (i.e., unnormalized log-probability). 680 681 With the “app” demapping method, the logit for the 682 constellation point :math:`c \in \mathcal{C}` of the :math:`k\text{th}` user is computed according to 683 684 .. math:: 685 \begin{align} 686 \text{logit}(k,c) &= \ln\left(\sum_{\mathbf{x} : x_k = c} \exp\left( 687 -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 688 \right)\Pr\left( \mathbf{x} \right)\right). 689 \end{align} 690 691 With the "maxlog" demapping method, the logit for the constellation point :math:`c \in \mathcal{C}` 692 of the :math:`k\text{th}` user is approximated like 693 694 .. math:: 695 \text{logit}(k,c) \approx \max_{\mathbf{x} : x_k = c} \left( 696 -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 + \ln \left( \Pr\left( \mathbf{x} \right) \right) 697 \right). 698 699 When hard decisions are requested, this layer returns for the :math:`k` th stream 700 701 .. 
math:: 702 \hat{c}_k = \underset{c \in \mathcal{C}}{\text{argmax}} \left( \sum_{\mathbf{x} : x_k = c} \exp\left( 703 -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 704 \right)\Pr\left( \mathbf{x} \right) \right) 705 706 where :math:`\mathcal{C}` is the set of constellation points. 707 708 Parameters 709 ----------- 710 output : One of ["bit", "symbol"], str 711 The type of output, either LLRs on bits or logits on constellation symbols. 712 713 demapping_method : One of ["app", "maxlog"], str 714 The demapping method used. 715 716 num_streams : tf.int 717 Number of transmitted streams 718 719 constellation_type : One of ["qam", "pam", "custom"], str 720 For "custom", an instance of :class:`~sionna.mapping.Constellation` 721 must be provided. 722 723 num_bits_per_symbol : int 724 The number of bits per constellation symbol, e.g., 4 for QAM16. 725 Only required for ``constellation_type`` in ["qam", "pam"]. 726 727 constellation : Constellation 728 An instance of :class:`~sionna.mapping.Constellation` or `None`. 729 In the latter case, ``constellation_type`` 730 and ``num_bits_per_symbol`` must be provided. 731 732 hard_out : bool 733 If `True`, the detector computes hard-decided bit values or 734 constellation point indices instead of soft-values. 735 Defaults to `False`. 736 737 dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype) 738 The dtype of ``y``. Defaults to tf.complex64. 739 The output dtype is the corresponding real dtype (tf.float32 or tf.float64). 740 741 Input 742 ------ 743 (y, h, prior, s) : 744 Tuple: 745 746 y : [...,M], tf.complex 747 1+D tensor containing the received signals. 748 749 h : [...,M,num_streams], tf.complex 750 2+D tensor containing the channel matrices. 751 752 prior : [...,num_streams,num_bits_per_symbol] or [...,num_streams,num_points], tf.float 753 Prior of the transmitted signals. 754 If ``output`` equals "bit", then LLRs of the transmitted bits are expected. 
755 If ``output`` equals "symbol", then logits of the transmitted constellation points are expected. 756 757 s : [...,M,M], tf.complex 758 2+D tensor containing the noise covariance matrices. 759 760 Output 761 ------ 762 One of: 763 764 : [..., num_streams, num_bits_per_symbol], tf.float 765 LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`. 766 767 : [..., num_streams, num_points], tf.float or [..., num_streams], tf.int 768 Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`. 769 Hard-decisions correspond to the symbol indices. 770 771 Note 772 ---- 773 If you want to use this layer in Graph mode with XLA, i.e., within 774 a function that is decorated with ``@tf.function(jit_compile=True)``, 775 you must set ``sionna.Config.xla_compat=true``. 776 See :py:attr:`~sionna.Config.xla_compat`. 777 """ 778 779 def __init__(self, 780 output, 781 demapping_method, 782 num_streams, 783 constellation_type=None, 784 num_bits_per_symbol=None, 785 constellation=None, 786 hard_out=False, 787 dtype=tf.complex64, 788 **kwargs): 789 super().__init__( output=output, 790 demapping_method=demapping_method, 791 num_streams=num_streams, 792 constellation_type=constellation_type, 793 num_bits_per_symbol=num_bits_per_symbol, 794 constellation=constellation, 795 hard_out=hard_out, 796 with_prior=True, 797 dtype=dtype, 798 **kwargs) 799 800 class KBestDetector(Layer): 801 # pylint: disable=line-too-long 802 r"""KBestDetector(output, num_streams, k, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, use_real_rep=False, list2llr=None, dtype=tf.complex64) 803 804 MIMO K-Best detector 805 806 This layer implements K-Best MIMO detection as described 807 in (Eq. 4-5) [FT2015]_. It can either generate hard decisions (for symbols 808 or bits) or compute LLRs. 809 810 The algorithm operates in either the complex or real-valued domain. 
811 Although both options produce identical results, the former has the advantage 812 that it can be applied to arbitrary non-QAM constellations. It also reduces 813 the number of streams (or depth) by a factor of two. 814 815 The way soft-outputs (i.e., LLRs) are computed is determined by the 816 ``list2llr`` function. The default solution 817 :class:`~sionna.mimo.List2LLRSimple` assigns a predetermined 818 value to all LLRs without counter-hypothesis. 819 820 This layer assumes the following channel model: 821 822 .. math:: 823 \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n} 824 825 where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector, 826 :math:`\mathbf{x}\in\mathcal{C}^S` is the vector of transmitted symbols which 827 are uniformly and independently drawn from the constellation :math:`\mathcal{C}`, 828 :math:`\mathbf{H}\in\mathbb{C}^{M\times S}` is the known channel matrix, 829 and :math:`\mathbf{n}\in\mathbb{C}^M` is a complex Gaussian noise vector. 830 It is assumed that :math:`\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}` and 831 :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`, 832 where :math:`\mathbf{S}` has full rank. 833 834 In a first optional step, the channel model is converted to its real-valued equivalent, 835 see :func:`~sionna.mimo.complex2real_channel`. We assume in the sequel the complex-valued 836 representation. Then, the channel is whitened using :func:`~sionna.mimo.whiten_channel`: 837 838 .. math:: 839 \tilde{\mathbf{y}} &= \mathbf{S}^{-\frac{1}{2}}\mathbf{y}\\ 840 &= \mathbf{S}^{-\frac{1}{2}}\mathbf{H}\mathbf{x} + \mathbf{S}^{-\frac{1}{2}}\mathbf{n}\\ 841 &= \tilde{\mathbf{H}}\mathbf{x} + \tilde{\mathbf{n}}. 842 843 Next, the columns of :math:`\tilde{\mathbf{H}}` are sorted according 844 to their norm in descending order. Then, the QR decomposition of the 845 resulting channel matrix is computed: 846 847 .. 
math:: 848 \tilde{\mathbf{H}} = \mathbf{Q}\mathbf{R} 849 850 where :math:`\mathbf{Q}\in\mathbb{C}^{M\times S}` is unitary and 851 :math:`\mathbf{R}\in\mathbb{C}^{S\times S}` is upper-triangular. 852 The channel outputs are then pre-multiplied by :math:`\mathbf{Q}^{\mathsf{H}}`. 853 This leads to the final channel model on which the K-Best detection algorithm operates: 854 855 .. math:: 856 \bar{\mathbf{y}} = \mathbf{R}\bar{\mathbf{x}} + \bar{\mathbf{n}} 857 858 where :math:`\bar{\mathbf{y}}\in\mathbb{C}^S`, 859 :math:`\bar{\mathbf{x}}\in\mathbb{C}^S`, and :math:`\bar{\mathbf{n}}\in\mathbb{C}^S` 860 with :math:`\mathbb{E}\left[\bar{\mathbf{n}}\right]=\mathbf{0}` and 861 :math:`\mathbb{E}\left[\bar{\mathbf{n}}\bar{\mathbf{n}}^{\mathsf{H}}\right]=\mathbf{I}`. 862 863 **LLR Computation** 864 865 The K-Best algorithm produces :math:`K` candidate solutions :math:`\bar{\mathbf{x}}_k\in\mathcal{C}^S` 866 and their associated distance metrics :math:`d_k=\lVert \bar{\mathbf{y}} - \mathbf{R}\bar{\mathbf{x}}_k \rVert^2` 867 for :math:`k=1,\dots,K`. If the real-valued channel representation is used, the distance 868 metrics are scaled by 0.5 to account for the reduced noise power in each complex dimension. 869 A hard-decision is simply the candidate with the shortest distance. 870 Various ways to compute LLRs from this list (and possibly 871 additional side-information) are possible. The (sub-optimal) default solution 872 is :class:`~sionna.mimo.List2LLRSimple`. Custom solutions can be provided. 873 874 Parameters 875 ----------- 876 output : One of ["bit", "symbol"], str 877 The type of output, either bits or symbols. Whether soft- or 878 hard-decisions are returned can be configured with the 879 ``hard_out`` flag. 880 881 num_streams : tf.int 882 Number of transmitted streams 883 884 k : tf.int 885 The number of paths to keep. Cannot be larger than the 886 number of constellation points to the power of the number of 887 streams. 
888 889 constellation_type : One of ["qam", "pam", "custom"], str 890 For "custom", an instance of :class:`~sionna.mapping.Constellation` 891 must be provided. 892 893 num_bits_per_symbol : int 894 The number of bits per constellation symbol, e.g., 4 for QAM16. 895 Only required for ``constellation_type`` in ["qam", "pam"]. 896 897 constellation : Constellation 898 An instance of :class:`~sionna.mapping.Constellation` or `None`. 899 In the latter case, ``constellation_type`` 900 and ``num_bits_per_symbol`` must be provided. 901 902 hard_out : bool 903 If `True`, the detector computes hard-decided bit values or 904 constellation point indices instead of soft-values. 905 Defaults to `False`. The detector cannot compute soft-symbols. 906 907 use_real_rep : bool 908 If `True`, the detector use the real-valued equivalent representation 909 of the channel. Note that this only works with a QAM constellation. 910 Defaults to `False`. 911 912 list2llr: `None` or instance of :class:`~sionna.mimo.List2LLR` 913 The function to be used to compute LLRs from a list of candidate solutions. 914 If `None`, the default solution :class:`~sionna.mimo.List2LLRSimple` 915 is used. 916 917 dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype) 918 The dtype of ``y``. Defaults to tf.complex64. 919 The output dtype is the corresponding real dtype (tf.float32 or tf.float64). 
920 921 Input 922 ----- 923 (y, h, s) : 924 Tuple: 925 926 y : [...,M], tf.complex 927 1+D tensor containing the received signals 928 929 h : [...,M,num_streams], tf.complex 930 2+D tensor containing the channel matrices 931 932 s : [...,M,M], tf.complex 933 2+D tensor containing the noise covariance matrices 934 935 Output 936 ------ 937 One of: 938 939 : [...,num_streams,num_bits_per_symbol], tf.float 940 LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"` 941 942 : [...,num_streams,2**num_points], tf.float or [...,num_streams], tf.int 943 Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"` 944 Hard-decisions correspond to the symbol indices. 945 946 Note 947 ---- 948 If you want to use this layer in Graph mode with XLA, i.e., within 949 a function that is decorated with ``@tf.function(jit_compile=True)``, 950 you must set ``sionna.Config.xla_compat=true``. 951 See :py:attr:`~sionna.Config.xla_compat`. 952 """ 953 def __init__(self, 954 output, 955 num_streams, 956 k, 957 constellation_type=None, 958 num_bits_per_symbol=None, 959 constellation=None, 960 hard_out=False, 961 use_real_rep=False, 962 list2llr="default", 963 dtype=tf.complex64, 964 **kwargs): 965 super().__init__(dtype=dtype, **kwargs) 966 assert dtype in [tf.complex64, tf.complex128],\ 967 "dtype must be tf.complex64 or tf.complex128." 968 969 assert output in ("bit", "symbol"), "Unknown output" 970 971 err_msg = "You must provide either constellation or " + \ 972 "constellation_type and num_bits_per_symbol." 973 if constellation is None: 974 assert constellation_type is not None and \ 975 num_bits_per_symbol is not None, err_msg 976 else: 977 assert constellation_type is None and \ 978 num_bits_per_symbol is None, err_msg 979 980 if constellation is not None: 981 assert constellation.points.dtype==dtype, \ 982 "Constellation has wrong dtype." 
983 984 self._output = output 985 self._hard_out = hard_out 986 self._use_real_rep = use_real_rep 987 988 if self._use_real_rep: 989 # Real-valued representation is used 990 err_msg = "Only QAM can be used for the real-valued representation" 991 if constellation_type is not None: 992 assert constellation_type=="qam", err_msg 993 else: 994 assert constellation._constellation_type=="qam", err_msg 995 996 # Double the number of streams to dectect 997 self._num_streams = 2*num_streams 998 999 # Half the number of bits for the PAM constellation 1000 if num_bits_per_symbol is None: 1001 n = constellation.num_bits_per_symbol//2 1002 self._num_bits_per_symbol = n 1003 else: 1004 self._num_bits_per_symbol = num_bits_per_symbol//2 1005 1006 # Geerate a PAM constellation with 0.5 energy 1007 c = Constellation("pam", 1008 self._num_bits_per_symbol, 1009 normalize=False, 1010 dtype=dtype) 1011 c._points /= tf.cast(np.std(c._points)*np.sqrt(2), c._points.dtype) 1012 self._constellation = tf.cast(c.points, dtype.real_dtype) 1013 1014 self._pam2qam = PAM2QAM(2*self._num_bits_per_symbol) 1015 1016 else: 1017 # Complex-valued representation is used 1018 # Number of streams is equal to number of transmitters 1019 self._num_streams = num_streams 1020 1021 # Create constellation or take the one provided 1022 c = Constellation.create_or_check_constellation( 1023 constellation_type, 1024 num_bits_per_symbol, 1025 constellation, 1026 dtype=dtype) 1027 self._constellation = c.points 1028 self._num_bits_per_symbol = c.num_bits_per_symbol 1029 1030 # Number of constellation symbols 1031 self._num_symbols = self._constellation.shape[0] 1032 1033 # Number of best paths to keep 1034 self._k = np.minimum(k, self._num_symbols**self._num_streams) 1035 if self._k < k: 1036 msg = "KBestDetector: " + \ 1037 f"The provided value of k={k} is larger than " + \ 1038 "the possible maximum number of paths. " + \ 1039 f"It has been set to k={self._k}." 
1040 warnings.warn(msg) 1041 1042 # Compute the number of previous paths a layer needs to consider 1043 num_paths = [1] # The first layer considers a single path 1044 for l in range(1, self._num_streams+1): 1045 # The lth layer considers min(k, num_symbols**l) paths 1046 num_paths.append(np.minimum(self._k, self._num_symbols**l)) 1047 self._num_paths = tf.constant(tf.stack(num_paths, 0), tf.int32) 1048 1049 # The symbols and indices for all paths will be stored in tensors 1050 # of shape [batch_size, k, num_streams]. However, only 1051 # a subset of the available entries are updated by each stream. 1052 # To enable XLA, we need to compute the relevant indices of the tensors 1053 # that will be updated through tf.tensor_scatter_nd_update. 1054 indices = np.zeros([self._num_streams, self._k*self._num_streams, 2], 1055 np.int32) 1056 for l in range(0, self._num_streams): 1057 ind = np.zeros([self._num_paths[l+1], self._num_streams]) 1058 ind[:, :l+1] = 1 1059 ind = np.stack(np.where(ind), -1) 1060 indices[l,:ind.shape[0],:ind.shape[1]] = ind 1061 self._indices = tf.constant(indices, dtype=tf.int32) 1062 1063 if self._output=="bit": 1064 if self._hard_out is False: 1065 if list2llr=="default": 1066 self.list2llr = List2LLRSimple(self._num_bits_per_symbol) 1067 else: 1068 self.list2llr = list2llr 1069 else: 1070 if self._use_real_rep: 1071 n = 2*self._num_bits_per_symbol 1072 else: 1073 n = self._num_bits_per_symbol 1074 self._symbolinds2bits = SymbolInds2Bits(n, 1075 dtype=dtype.real_dtype) 1076 else: 1077 assert self._hard_out is True, \ 1078 "Soft-symbols are not supported for this detector." 
1079 1080 @property 1081 def list2llr(self): 1082 return self._list2llr 1083 1084 @list2llr.setter 1085 def list2llr(self, value): 1086 assert isinstance(value, List2LLR) 1087 self._list2llr = value 1088 1089 def _preprocessing(self, inputs): 1090 1091 y, h, s = inputs 1092 1093 # Convert to real-valued representation if desired 1094 if self._use_real_rep: 1095 y, h, s = complex2real_channel(y, h, s) 1096 1097 # Whiten channel 1098 y, h = whiten_channel(y, h, s, return_s=False) # pylint: disable=W0632 1099 1100 # Order columns of H in order of decreasing norm 1101 h_norm = tf.reduce_sum(tf.abs(h)**2, axis=1) 1102 column_order = tf.argsort(h_norm, axis=-1, direction="DESCENDING") 1103 h = tf.gather(h, column_order, axis=-1, batch_dims=1) 1104 1105 # Compute QR decomposition of sorted channel 1106 # r is upper triangular 1107 q, r = tf.linalg.qr(h) 1108 1109 # Project y on Q' 1110 y = tf.squeeze(tf.matmul(q, tf.expand_dims(y, -1), adjoint_a=True), 1111 -1) 1112 1113 return y, r, column_order 1114 1115 def _select_best_paths(self, dists, path_syms, path_inds): 1116 1117 # Determine the number of paths to keep (either all or k) 1118 num_paths = tf.shape(path_syms)[1] 1119 k = tf.minimum(num_paths, self._k) 1120 1121 # Get the k paths with the shortest distance 1122 dists, ind = tf.math.top_k(-dists, k=k, sorted=True) 1123 dists = -dists 1124 1125 # Select the same best paths for the symbols and symbol indices 1126 path_syms = tf.gather(path_syms, ind, axis=1, batch_dims=1) 1127 path_inds = tf.gather(path_inds, ind, axis=1, batch_dims=1) 1128 1129 return dists, path_syms, path_inds 1130 1131 def _next_layer(self, y, r, dists, path_syms, path_inds, stream): 1132 1133 batch_size = tf.shape(y)[0] 1134 1135 # Streams are processed in reverse order 1136 stream_ind = self._num_streams-1-stream 1137 1138 # Current number of considered paths 1139 num_paths = tf.gather(self._num_paths, stream) 1140 1141 # Store input tensors for scatter update later on 1142 dists_o = dists 1143 
path_syms_o = path_syms 1144 path_inds_o = path_inds 1145 1146 # Extract relevant values from input tensor 1147 dists = dists[..., :num_paths] 1148 path_syms = path_syms[..., :num_paths, :stream] 1149 path_inds = path_inds[..., :num_paths, :stream] 1150 1151 # Each path creates num_symbols branches 1152 dists = tf.repeat(dists, repeats=self._num_symbols, axis=1) 1153 path_syms = tf.repeat(path_syms, repeats=self._num_symbols, axis=1) 1154 path_inds = tf.repeat(path_inds, repeats=self._num_symbols, axis=1) 1155 1156 # Append to each path the symbols corresponding to the branch 1157 syms = tf.reshape(self._constellation, [1,-1]) 1158 syms = tf.repeat(syms, self._k, 0) 1159 syms = tf.reshape(syms, [1, -1, 1]) 1160 syms = tf.repeat(syms, batch_size, 0) 1161 syms = syms[:,:num_paths*self._num_symbols] 1162 path_syms = tf.concat([path_syms, syms], axis=-1) 1163 1164 # Do the same for the symbol indices 1165 inds = tf.reshape(tf.range(0, self._num_symbols), [1, -1]) 1166 inds = tf.repeat(inds, self._k, 0) 1167 inds = tf.reshape(inds, [1, -1, 1]) 1168 inds = tf.repeat(inds, batch_size, 0) 1169 inds = inds[:,:num_paths*self._num_symbols] 1170 path_inds = tf.concat([path_inds, inds], axis=-1) 1171 1172 # Compute partial distances 1173 # Extract the row of r corresponding to layer and reverse the order 1174 y = tf.expand_dims(y[:, stream_ind], axis=-1) 1175 r = tf.expand_dims(tf.reverse(r[:, stream_ind, stream_ind:], [-1]), 1) 1176 delta = tf.pow(tf.abs(y - tf.reduce_sum(r*path_syms, axis=-1)), 2) 1177 1178 # Update distances 1179 dists += delta 1180 1181 # Get k best paths 1182 dists, path_syms, path_inds = self._select_best_paths(dists, path_syms, path_inds) 1183 1184 # Scatter updates of dists 1185 tensor = tf.transpose(dists_o, perm=[1, 0]) 1186 updates = tf.transpose(dists, perm=[1, 0]) 1187 indices = tf.expand_dims(tf.range(tf.shape(updates)[0], dtype=tf.int32), -1) 1188 dists = tf.tensor_scatter_nd_update(tensor, indices, updates) 1189 dists = tf.transpose(dists, 
perm=[1, 0]) 1190 1191 # Scatter update of path_syms 1192 tensor = tf.transpose(path_syms_o, [1, 2, 0]) 1193 updates = tf.transpose(path_syms, [1, 2, 0]) 1194 updates = tf.reshape(updates, [-1, batch_size]) 1195 indices = self._indices[stream, :self._num_paths[stream+1]*(stream+1)] 1196 path_syms = tf.tensor_scatter_nd_update(tensor, indices, updates) 1197 path_syms = tf.transpose(path_syms, perm=[2, 0, 1]) 1198 1199 # Scatter update of path_inds 1200 tensor = tf.transpose(path_inds_o, [1, 2, 0]) 1201 updates = tf.transpose(path_inds, [1, 2, 0]) 1202 updates = tf.reshape(updates, [-1, batch_size]) 1203 path_inds = tf.tensor_scatter_nd_update(tensor, indices, updates) 1204 path_inds = tf.transpose(path_inds, perm=[2, 0, 1]) 1205 1206 return dists, path_syms, path_inds 1207 1208 def _unsort(self, column_order, tensor, transpose=True): 1209 # Undo the column sorting 1210 # If transpose=True, the unsorting is done along the last dimension 1211 # Otherwise, sorting is done along the second-last index 1212 unsort_inds = tf.argsort(column_order, axis=-1) 1213 if transpose: 1214 tensor = tf.transpose(tensor, perm=[0, 2, 1]) 1215 tensor = tf.gather(tensor, unsort_inds, axis=-2, batch_dims=1) 1216 if transpose: 1217 tensor = tf.transpose(tensor, perm=[0, 2, 1]) 1218 return tensor 1219 1220 def build(self, input_shape): 1221 assert input_shape[1][-2]>=input_shape[1][-1], \ 1222 "The number of receive antennas cannot be smaller \ 1223 than the number of streams" 1224 1225 def call(self, inputs): 1226 1227 # Flatten the batch dimensions 1228 y, h, s = inputs 1229 batch_shape = tf.shape(y)[:-1] 1230 num_batch_dims = len(batch_shape) 1231 if num_batch_dims > 1: 1232 y = flatten_dims(y, num_batch_dims, 0) 1233 h = flatten_dims(h, num_batch_dims, 0) 1234 s = flatten_dims(s, num_batch_dims, 0) 1235 inputs = (y,h,s) 1236 1237 # Initialization 1238 # (i) (optional) Convert to real-valued representation 1239 # (ii) Whiten channel 1240 # (iii) Sort columns of H by decreasing column norm 
1241 # (iv) QR Decomposition of H 1242 # (v) Project y onto Q' 1243 y, r, column_order = self._preprocessing(inputs) 1244 1245 batch_size = tf.shape(y)[0] 1246 1247 # Tensor to keep track of the aggregate distances of all paths 1248 dists = tf.zeros([batch_size, self._k], y.dtype.real_dtype) 1249 1250 # Tensor to store constellation symbols of all paths 1251 path_syms = tf.zeros([batch_size, self._k, self._num_streams], y.dtype) 1252 1253 # Tensor to store constellation symbol indices of all paths 1254 path_inds = tf.zeros([batch_size, self._k, self._num_streams],tf.int32) 1255 1256 # Sequential K-Best algorithm 1257 for stream in range(0, self._num_streams): 1258 dists, path_syms, path_inds = self._next_layer(y, 1259 r, 1260 dists, 1261 path_syms, 1262 path_inds, 1263 stream) 1264 1265 # Reverse order as detection started with the last symbol first 1266 path_syms = tf.reverse(path_syms, axis=[-1]) 1267 path_inds = tf.reverse(path_inds, axis=[-1]) 1268 1269 # Processing for hard-decisions 1270 if self._hard_out: 1271 path_inds = self._unsort(column_order, path_inds) 1272 hard_dec = path_inds[:,0,:] 1273 1274 # Real-valued representation 1275 if self._use_real_rep: 1276 hard_dec = \ 1277 self._pam2qam(hard_dec[...,:self._num_streams//2], 1278 hard_dec[...,self._num_streams//2:]) 1279 1280 # Hard decisions on bits 1281 if self._output=="bit": 1282 hard_dec = self._symbolinds2bits(hard_dec) 1283 1284 # Reshape batch dimensions 1285 if num_batch_dims > 1: 1286 hard_dec = split_dim(hard_dec, batch_shape, 0) 1287 1288 return hard_dec 1289 1290 # Processing for soft-decisions 1291 else: 1292 # Real-valued representation 1293 if self._use_real_rep: 1294 llr = self.list2llr([y, r, dists, path_inds, path_syms]) 1295 llr = self._unsort(column_order, llr, transpose=False) 1296 1297 # Combine LLRs from PAM symbols in the correct order 1298 llr1 = llr[:,:self._num_streams//2] 1299 llr2 = llr[:,self._num_streams//2:] 1300 llr1 = tf.expand_dims(llr1, -1) 1301 llr2 = 
tf.expand_dims(llr2, -1) 1302 llr = tf.concat([llr1, llr2], -1) 1303 llr = tf.reshape(llr, [-1, self._num_streams//2, 1304 2*self._num_bits_per_symbol]) 1305 1306 # Complex-valued representation 1307 else: 1308 llr = self.list2llr([y, r, dists, path_inds, path_syms]) 1309 llr = self._unsort(column_order, llr, transpose=False) 1310 1311 # Reshape batch dimensions 1312 if num_batch_dims > 1: 1313 llr = split_dim(llr, batch_shape, 0) 1314 1315 return llr 1316 1317 class EPDetector(Layer): 1318 # pylint: disable=line-too-long 1319 r"""EPDetector(output, num_bits_per_symbol, hard_out=False, l=10, beta=0.9, dtype=tf.complex64) 1320 1321 MIMO Expectation Propagation (EP) detector 1322 1323 This layer implements Expectation Propagation (EP) MIMO detection as described 1324 in [EP2014]_. It can generate hard- or soft-decisions for symbols or bits. 1325 1326 This layer assumes the following channel model: 1327 1328 .. math:: 1329 \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n} 1330 1331 where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector, 1332 :math:`\mathbf{x}\in\mathcal{C}^S` is the vector of transmitted symbols which 1333 are uniformly and independently drawn from the constellation :math:`\mathcal{C}`, 1334 :math:`\mathbf{H}\in\mathbb{C}^{M\times S}` is the known channel matrix, 1335 and :math:`\mathbf{n}\in\mathbb{C}^M` is a complex Gaussian noise vector. 1336 It is assumed that :math:`\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}` and 1337 :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`, 1338 where :math:`\mathbf{S}` has full rank. 1339 1340 The channel model is first whitened using :func:`~sionna.mimo.whiten_channel` 1341 and then converted to its real-valued equivalent, 1342 see :func:`~sionna.mimo.complex2real_channel`, prior to MIMO detection. 1343 1344 The computation of LLRs is done by converting the symbol logits 1345 that naturally arise in the algorithm to LLRs using 1346 :func:`~sionna.mapping.PAM2QAM`. 
Custom conversions of symbol logits to LLRs 1347 can be implemented by using the soft-symbol output. 1348 1349 Parameters 1350 ----------- 1351 output : One of ["bit", "symbol"], str 1352 The type of output, either bits or symbols. Whether soft- or 1353 hard-decisions are returned can be configured with the 1354 ``hard_out`` flag. 1355 1356 num_bits_per_symbol : int 1357 The number of bits per QAM constellation symbol, e.g., 4 for QAM16. 1358 1359 hard_out : bool 1360 If `True`, the detector computes hard-decided bit values or 1361 constellation point indices instead of soft-values. 1362 Defaults to `False`. 1363 1364 l : int 1365 Number of iterations. Defaults to 10. 1366 1367 beta : float 1368 Parameter :math:`\beta\in[0,1]` for update smoothing. 1369 Defaults to 0.9. 1370 1371 dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype) 1372 Precision used for internal computations. Defaults to ``tf.complex64``. 1373 Especially for large MIMO setups, the precision can make a significant 1374 performance difference. 1375 1376 Input 1377 ----- 1378 (y, h, s) : 1379 Tuple: 1380 1381 y : [...,M], tf.complex 1382 1+D tensor containing the received signals 1383 1384 h : [...,M,num_streams], tf.complex 1385 2+D tensor containing the channel matrices 1386 1387 s : [...,M,M], tf.complex 1388 2+D tensor containing the noise covariance matrices 1389 1390 Output 1391 ------ 1392 One of: 1393 1394 : [...,num_streams,num_bits_per_symbol], tf.float 1395 LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"` 1396 1397 : [...,num_streams,2**num_bits_per_symbol], tf.float or [...,num_streams], tf.int 1398 Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"` 1399 1400 Note 1401 ---- 1402 For numerical stability, we do not recommend to use this function in Graph 1403 mode with XLA, i.e., within a function that is decorated with 1404 ``@tf.function(jit_compile=True)``. 
1405 However, it is possible to do so by setting 1406 ``sionna.Config.xla_compat=true``. 1407 See :py:attr:`~sionna.Config.xla_compat`. 1408 """ 1409 def __init__(self, 1410 output, 1411 num_bits_per_symbol, 1412 hard_out=False, 1413 l=10, 1414 beta=0.9, 1415 dtype=tf.complex64, 1416 **kwargs): 1417 super().__init__(dtype=dtype, **kwargs) 1418 assert dtype in [tf.complex64, tf.complex128], \ 1419 "Invalid dtype" 1420 self._cdtype = tf.dtypes.as_dtype(dtype) 1421 self._rdtype = self._cdtype.real_dtype 1422 1423 # Variable used to avoid numerical instabilities 1424 # See paragraph after Eq. (38) 1425 if self.dtype=="complex64": 1426 self._prec = 1e-6 1427 else: 1428 self._prec = 1e-12 1429 1430 assert output in ("bit", "symbol"), "Unknown output" 1431 self._output = output 1432 1433 self._hard_out = hard_out 1434 1435 if self._output=="symbol": 1436 self._pam2qam = PAM2QAM(num_bits_per_symbol, hard_out) 1437 else: 1438 self._symbollogits2llrs = SymbolLogits2LLRs("maxlog", 1439 num_bits_per_symbol//2, 1440 hard_out=hard_out) 1441 self._demapper = Demapper("maxlog", "pam", num_bits_per_symbol//2) 1442 1443 assert l>=1, "l must be a positive integer" 1444 self._l = l 1445 1446 assert 0.0<= beta <=1.0, "beta must be in [0,1]" 1447 self._beta = beta 1448 1449 # Create PAM constellations for real-valued detection 1450 self._num_bits_per_symbol = num_bits_per_symbol//2 1451 points = Constellation("pam", int(self._num_bits_per_symbol)).points 1452 1453 # Scale constellation points to half the energy because QAM is assumed 1454 self._points = tf.cast(points/np.sqrt(2.0), self._rdtype) 1455 1456 # Average symbol energy 1457 self._es = tf.constant(np.var(self._points), self._rdtype) 1458 1459 def compute_sigma_mu(self, h_t_h, h_t_y, no, lam, gam): 1460 """Equations (28) and (29)""" 1461 1462 # Prepare inputs 1463 lam = tf.linalg.diag(lam) 1464 gam = tf.expand_dims(gam, axis=-1) 1465 1466 # Computations 1467 sigma = tf.linalg.inv(h_t_h + no*lam) 1468 mu = 
tf.squeeze(tf.matmul(sigma, h_t_y + no*gam), axis=-1) 1469 sigma *= no 1470 sigma = tf.linalg.diag_part(sigma) 1471 1472 return sigma, mu 1473 1474 def compute_v_x_obs(self, sigma, mu, lam, gam): 1475 """Equations (31) and (32)""" 1476 1477 v_obs = tf.maximum(1/(1/sigma-lam), self._prec) 1478 x_obs = v_obs*(mu/sigma-gam) 1479 1480 return v_obs, x_obs 1481 1482 def compute_v_x(self, v_obs, x_obs): 1483 """Equation (33)""" 1484 1485 # Compute probability mass function for the symbols 1486 x_obs = tf.expand_dims(x_obs, -1) 1487 v_obs = tf.expand_dims(v_obs, -1) 1488 1489 points = expand_to_rank(self._points, tf.rank(x_obs), axis=0) 1490 logits = -tf.pow(x_obs-points, 2) / (tf.cast(2, self._rdtype)*v_obs) 1491 pmf = tf.math.softmax(logits) 1492 1493 # Compute mean and variance of all symbols 1494 x = tf.reduce_sum(points * pmf, axis=-1, keepdims=True) 1495 v = tf.reduce_sum((points-x)**2 * pmf, axis=-1) 1496 v = tf.maximum(v, self._prec) 1497 x = tf.squeeze(x, axis=-1) 1498 1499 return v, x, logits 1500 1501 def update_lam_gam(self, v, v_obs, x, x_obs, lam, gam): 1502 """Equations (35), (36), (37), (38)""" 1503 1504 # Save old values of lam, and gam 1505 lam_old = lam 1506 gam_old = gam 1507 1508 # Compute potential new values (35), (36) 1509 lam = 1/v - 1/v_obs 1510 gam = x/v - x_obs/v_obs 1511 1512 # Only update nonnegative values 1513 lam_new = tf.where(lam<0, lam_old, lam) 1514 gam_new = tf.where(lam<0, gam_old, gam) 1515 1516 # Damp updates (37), (38) 1517 lam_damp = (1-self._beta)*lam_new + self._beta*lam_old 1518 gam_damp = (1-self._beta)*gam_new + self._beta*gam_old 1519 1520 return lam_damp, gam_damp 1521 1522 def call(self, inputs): 1523 1524 # Flatten the batch dimensions 1525 y, h, s = inputs 1526 batch_shape = tf.shape(y)[:-1] 1527 num_batch_dims = len(batch_shape) 1528 if num_batch_dims > 1: 1529 y = flatten_dims(y, num_batch_dims, 0) 1530 h = flatten_dims(h, num_batch_dims, 0) 1531 s = flatten_dims(s, num_batch_dims, 0) 1532 inputs = (y,h,s) 1533 1534 # 
Number of transmit streams 1535 n_t = tf.shape(h)[-1] 1536 1537 # Whiten channel 1538 y, h, s = whiten_channel(y, h, s) 1539 1540 # Convert channel to real-valued representation 1541 y, h, s = complex2real_channel(y,h,s) 1542 1543 # Convert all inputs to desired dtypes 1544 y = tf.cast(y, self._rdtype) 1545 h = tf.cast(h, self._rdtype) 1546 no = tf.cast(0.5, self._rdtype) 1547 1548 # Gather relevant parameters 1549 batch_dims = tf.shape(y)[:-1] 1550 n_t_r = tf.shape(h)[-1] 1551 1552 # Initialize gamma and lambda (Paragraph after Eq. (29)) 1553 gam = tf.zeros(tf.concat([batch_dims, [n_t_r]], axis=0), y.dtype) 1554 lam = tf.ones(tf.concat([batch_dims, [n_t_r]], axis=0), y.dtype) 1555 lam /= tf.cast(self._es, y.dtype) 1556 1557 # Precompute values that are repeatedly needed 1558 h_t_h = tf.matmul(h, h, transpose_a=True) 1559 y = tf.expand_dims(y, axis=-1) 1560 h_t_y = tf.matmul(h, y, transpose_a=True) 1561 no = expand_to_rank(no, tf.rank(h), axis=-1) 1562 1563 for _ in range(self._l): 1564 sigma, mu = self.compute_sigma_mu(h_t_h, h_t_y, no, lam, gam) 1565 v_obs, x_obs = self.compute_v_x_obs(sigma, mu, lam, gam) 1566 v, x, logits = self.compute_v_x(v_obs, x_obs) 1567 lam, gam = self.update_lam_gam(v, v_obs, x, x_obs, lam, gam) 1568 1569 # Extract the logits for the 2 PAM constellations for each streams 1570 pam1_logits = logits[...,:n_t,:] 1571 pam2_logits = logits[...,n_t:,:] 1572 1573 if self._output=="symbol" and self._hard_out: 1574 # Take hard decisions on PAM symbol;s 1575 pam1_ind = tf.argmax(pam1_logits, axis=-1, output_type=tf.int32) 1576 pam2_ind = tf.argmax(pam2_logits, axis=-1, output_type=tf.int32) 1577 1578 # Transform to QAM indices 1579 qam_ind = self._pam2qam(pam1_ind, pam2_ind) 1580 1581 # Reshape batch dimensions 1582 if num_batch_dims > 1: 1583 qam_ind = split_dim(qam_ind, batch_shape, 0) 1584 1585 return qam_ind 1586 1587 elif self._output=="symbol" and not self._hard_out: 1588 qam_logits = self._pam2qam(pam1_logits, pam2_logits) 1589 1590 # 
Reshape batch dimensions 1591 if num_batch_dims > 1: 1592 qam_logits = split_dim(qam_logits, batch_shape, 0) 1593 1594 return qam_logits 1595 1596 elif self._output=="bit": 1597 # Compute LLRs for both PAM constellations 1598 llr1 = self._symbollogits2llrs(pam1_logits) 1599 llr2 = self._symbollogits2llrs(pam2_logits) 1600 1601 # Put LLRs in the correct order and shape 1602 llr = tf.stack([llr1, llr2], -1) 1603 llr = flatten_last_dims(llr) 1604 1605 # Reshape batch dimensions 1606 if num_batch_dims > 1: 1607 llr = split_dim(llr, batch_shape, 0) 1608 1609 return llr 1610 1611 class MMSEPICDetector(Layer): 1612 # pylint: disable=line-too-long 1613 r"""MMSEPICDetector(output, demapping_method="maxlog", num_iter=1, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, dtype=tf.complex64, **kwargs) 1614 1615 Minimum mean square error (MMSE) with parallel interference cancellation (PIC) detector 1616 1617 This layer implements the MMSE PIC detector, as proposed in [CST2011]_. 1618 For ``num_iter``>1, this implementation performs MMSE PIC self-iterations. 1619 MMSE PIC self-iterations can be understood as a concatenation of MMSE PIC 1620 detectors from [CST2011]_, which forward intrinsic LLRs to the next 1621 self-iteration. 1622 1623 Compared to [CST2011]_, this implementation also accepts priors on the 1624 constellation symbols as an alternative to priors on the bits. 1625 1626 This layer assumes the following channel model: 1627 1628 .. math:: 1629 \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n} 1630 1631 where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector, 1632 :math:`\mathbf{x}\in\mathcal{C}^S` is the vector of transmitted symbols which 1633 are uniformly and independently drawn from the constellation :math:`\mathcal{C}`, 1634 :math:`\mathbf{H}\in\mathbb{C}^{M\times S}` is the known channel matrix, 1635 and :math:`\mathbf{n}\in\mathbb{C}^M` is a complex Gaussian noise vector. 
1636 It is assumed that :math:`\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}` and 1637 :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`, 1638 where :math:`\mathbf{S}` has full rank. 1639 1640 The algorithm starts by computing the soft symbols 1641 :math:`\bar{x}_s=\mathbb{E}\left[ x_s \right]` and 1642 variances :math:`v_s=\mathbb{E}\left[ |e_s|^2\right]` from the priors, 1643 where :math:`e_s = x_s - \bar{x}_s`, for all :math:`s=1,\dots,S`. 1644 1645 Next, for each stream, the interference caused by all other streams is cancelled 1646 from the observation :math:`\mathbf{y}`, leading to 1647 1648 .. math:: 1649 \hat{\mathbf{y}}_s = \mathbf{y} - \sum_{j\neq s} \mathbf{h}_j \bar{x}_j = \mathbf{h}_s x_s + \tilde{\mathbf{n}}_s,\quad s=1,\dots,S 1650 1651 where :math:`\tilde{\mathbf{n}}_s=\sum_{j\neq s} \mathbf{h}_j e_j + \mathbf{n}`. 1652 1653 Then, a linear MMSE filter :math:`\mathbf{w}_s` is computed to reduce the residual noise 1654 for each observation :math:`\hat{\mathbf{y}}_s`, which is given as 1655 1656 .. math:: 1657 \mathbf{w}_s = \left( \mathbf{H} \mathbf{D}_s\mathbf{H}^{\mathsf{H}} +\mathbf{S} \right)^{-1}\mathbf{h}_s 1658 1659 where :math:`\mathbf{D}_s \in \mathbb{C}^{S\times S}` is diagonal with entries 1660 1661 .. math:: 1662 \left[\mathbf{D}_s\right]_{i,i} = \begin{cases} 1663 v_i & i\neq s \\ 1664 1 & i=s. 1665 \end{cases} 1666 1667 The filtered observations 1668 1669 .. math:: 1670 \tilde{z}_s = \mathbf{w}_s^{\mathsf{H}} \hat{\mathbf{y}}_s = \tilde{\mu}_s x_s + \mathbf{w}_s^{\mathsf{H}}\tilde{\mathbf{n}}_s 1671 1672 where :math:`\tilde{\mu}_s=\mathbf{w}_s^{\mathsf{H}} \mathbf{h}_s`, are then demapped to either symbol logits or LLRs, assuming that the remaining noise is Gaussian with variance 1673 1674 .. math:: 1675 \nu_s^2 = \mathop{\text{Var}}\left[\tilde{z}_s\right] = \mathbf{w}_s^{\mathsf{H}} \left(\sum_{j\neq s} \mathbf{h}_j \mathbf{h}_j^{\mathsf{H}} v_j +\mathbf{S} \right)\mathbf{w}_s.
1676 1677 The resulting soft-symbols can then be used for the next self-iteration of the algorithm. 1678 1679 Note that this algorithm can be substantially simplified as described in [CST2011]_ to avoid 1680 the computation of different matrix inverses for each stream. This is the version which is 1681 implemented. 1682 1683 Parameters 1684 ----------- 1685 output : One of ["bit", "symbol"], str 1686 The type of output, either LLRs on bits or logits on constellation 1687 symbols. 1688 1689 demapping_method : One of ["app", "maxlog"], str 1690 The demapping method used. 1691 Defaults to "maxlog". 1692 1693 num_iter : int 1694 Number of MMSE PIC iterations. 1695 Defaults to 1. 1696 1697 constellation_type : One of ["qam", "pam", "custom"], str 1698 For "custom", an instance of :class:`~sionna.mapping.Constellation` 1699 must be provided. 1700 1701 num_bits_per_symbol : int 1702 The number of bits per constellation symbol, e.g., 4 for QAM16. 1703 Only required for ``constellation_type`` in ["qam", "pam"]. 1704 1705 constellation : Constellation 1706 An instance of :class:`~sionna.mapping.Constellation` or `None`. 1707 In the latter case, ``constellation_type`` 1708 and ``num_bits_per_symbol`` must be provided. 1709 1710 hard_out : bool 1711 If `True`, the detector computes hard-decided bit values or 1712 constellation point indices instead of soft-values. 1713 Defaults to `False`. 1714 1715 dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype) 1716 The dtype of ``y``. Defaults to tf.complex64. 1717 The output dtype is the corresponding real dtype 1718 (tf.float32 or tf.float64). 1719 1720 Input 1721 ----- 1722 (y, h, prior, s) : 1723 Tuple: 1724 1725 y : [...,M], tf.complex 1726 1+D tensor containing the received signals 1727 1728 h : [...,M,S], tf.complex 1729 2+D tensor containing the channel matrices 1730 1731 prior : [...,S,num_bits_per_symbol] or [...,S,num_points], tf.float 1732 Prior of the transmitted signals. 
1733 If ``output`` equals "bit", then LLRs of the transmitted bits are expected. 1734 If ``output`` equals "symbol", then logits of the transmitted constellation points are expected. 1735 1736 s : [...,M,M], tf.complex 1737 2+D tensor containing the noise covariance matrices 1738 1739 Output 1740 ------ 1741 One of: 1742 1743 : [...,S,num_bits_per_symbol], tf.float 1744 LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"` 1745 1746 : [...,S,2**num_bits_per_symbol], tf.float or [...,S], tf.int 1747 Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"` 1748 1749 Note 1750 ---- 1751 For numerical stability, we do not recommend to use this function in Graph 1752 mode with XLA, i.e., within a function that is decorated with 1753 ``@tf.function(jit_compile=True)``. 1754 However, it is possible to do so by setting 1755 ``sionna.Config.xla_compat=true``. 1756 See :py:attr:`~sionna.Config.xla_compat`. 1757 """ 1758 def __init__(self, 1759 output, 1760 demapping_method="maxlog", 1761 num_iter=1, 1762 constellation_type=None, 1763 num_bits_per_symbol=None, 1764 constellation=None, 1765 hard_out=False, 1766 dtype=tf.complex64, 1767 **kwargs): 1768 super().__init__(dtype=dtype, **kwargs) 1769 1770 assert isinstance(num_iter, int), "num_iter must be an integer" 1771 assert output in ("bit", "symbol"), "Unknown output" 1772 assert demapping_method in ("app", "maxlog"), "Unknown demapping method" 1773 1774 assert dtype in [tf.complex64, tf.complex128], \ 1775 "dtype must be tf.complex64 or tf.complex128" 1776 1777 self._num_iter = num_iter 1778 self._output = output 1779 self._epsilon = 1e-4 1780 self._realdtype = dtype.real_dtype 1781 self._demapping_method = demapping_method 1782 self._hard_out = hard_out 1783 1784 # Create constellation object 1785 self._constellation = Constellation.create_or_check_constellation( 1786 constellation_type, 1787 num_bits_per_symbol, 1788 constellation, 1789 dtype=dtype) 
1790 1791 # Soft symbol mapping 1792 self._llr_2_symbol_logits = LLRs2SymbolLogits( 1793 self._constellation.num_bits_per_symbol, 1794 dtype=self._realdtype) 1795 1796 if self._output == "symbol": 1797 self._llr_2_symbol_logits_output = LLRs2SymbolLogits( 1798 self._constellation.num_bits_per_symbol, 1799 dtype=self._realdtype, 1800 hard_out=hard_out) 1801 self._symbol_logits_2_llrs = SymbolLogits2LLRs( 1802 method=demapping_method, 1803 num_bits_per_symbol=self._constellation.num_bits_per_symbol) 1804 self._symbol_logits_2_moments = SymbolLogits2Moments( 1805 constellation=self._constellation, 1806 dtype=self._realdtype) 1807 1808 # soft output demapping 1809 self._bit_demapper = DemapperWithPrior( 1810 demapping_method=demapping_method, 1811 constellation=self._constellation, 1812 dtype=dtype) 1813 1814 1815 def call(self, inputs): 1816 y, h, prior, s = inputs 1817 # y is unwhitened receive signal 1818 # [..., M] 1819 # h the channel estimate 1820 # [..., M, K] 1821 # prior is either the soft input LLRs 1822 # [..., K, num_bits_per_symbol] or symbol logits [..., K, Q] 1823 # s the noise covariance matrix 1824 # [..., M, M] 1825 1826 ## Preprocessing 1827 # Whiten channel 1828 # y : [..., M] 1829 # s : [..., M, M] 1830 y, h = whiten_channel(y, h, s, return_s=False) # pylint: disable=unbalanced-tuple-unpacking 1831 1832 # matched filtering of y 1833 # [..., K, 1] 1834 y_mf = insert_dims(tf.linalg.matvec(h, y, adjoint_a=True), 1835 num_dims=1, axis=-1) 1836 1837 ## Step 1: compute Gramm matrix 1838 # [..., K, K] 1839 g = tf.matmul(h, h, adjoint_a=True) 1840 1841 # For XLA compatibility, this implementation performs the MIMO 1842 # equalization in the real-valued domain 1843 # [..., 2M, 2K] 1844 hr = complex2real_matrix(h) 1845 # [..., 2K, 2K] 1846 gr = tf.matmul(hr, hr, adjoint_a=True) 1847 1848 # Compute a priori LLRs 1849 if self._output == "symbol": 1850 llr_a = self._symbol_logits_2_llrs(prior) 1851 else: 1852 llr_a = prior 1853 # llr_a is [..., K, 
num_bits_per_symbol] 1854 llr_shape = tf.shape(llr_a) 1855 1856 def mmse_pic_self_iteration(llr_d, llr_a, it): 1857 # MMSE PIC takes in a priori LLRs 1858 llr_a = llr_d 1859 1860 # Step 2: compute soft symbol estimates and variances 1861 # x_hat, var_x : [..., K] 1862 x_logits = self._llr_2_symbol_logits(llr_a) 1863 x_hat, var_x = self._symbol_logits_2_moments(x_logits) 1864 1865 # Step 3: perform parallel interference cancellation 1866 # H^H y_hat_i = y_mf - sum_j!=i gj x_hat_j = y + g_i x_hat_i 1867 # - sum_j g_j x_hat_j 1868 # [..., K, K] 1869 y_mf_pic = y_mf + g * insert_dims(x_hat, num_dims=1, axis=-2) \ 1870 - tf.linalg.matmul(g, insert_dims(x_hat, num_dims=1, axis=-1)) 1871 1872 # Step 4: compute A^-1 matrix 1873 # Calculate MMSE Filter (efficiently) 1874 # W^H = A^-1 H^H 1875 # A = H^H H \Lambda + N_0 I_Mt 1876 # \Lambda_ii is a diagonal matrix with \Lambda_ii = E_i = error_var 1877 1878 # Stack error variances and make it real 1879 # Note: Imaginary part is zero 1880 var_x = tf.cast(tf.concat([var_x, var_x], axis=-1), 1881 dtype=self._realdtype) 1882 var_x_row_vec = insert_dims(var_x, num_dims=1, axis=-2) 1883 # [..., 2K, 2K] 1884 a = gr * var_x_row_vec 1885 1886 i = expand_to_rank(tf.eye(tf.shape(a)[-1], dtype=a.dtype), 1887 tf.rank(a), 0) 1888 a = a + i 1889 1890 # a is non-hermitian! 
that's why we can't use sn.utils.matrix_inv 1891 # XLA can't invert complex matrices, that's why we work with the 1892 # real valued domain 1893 a_inv = tf.linalg.inv(a) 1894 1895 # Step 5: compute unbiased MMSE filter and outputs, calculate A\H^H 1896 1897 # Calculate bias mu_i = diag(A^-1 H^H H) = diag(A^-1 G) 1898 # Diagonal elements of matrix matrix multiplication simplified 1899 # to sum and dot-product 1900 # [..., 2K] 1901 mu = tf.reduce_sum(a_inv * tf.linalg.matrix_transpose(gr), axis=-1) 1902 1903 # Make y_mf_pic columns real (after transposition, 1904 # the last dimension corresponds to vectors) 1905 # [..., K, 2K] 1906 y_mf_pic_trans = tf.linalg.matrix_transpose(y_mf_pic) 1907 y_mf_pic_trans = complex2real_vector(y_mf_pic_trans) 1908 # stack them such that y_mf_pic_trans has shape [..., 2K, 2K] 1909 y_mf_pic_trans = tf.concat([y_mf_pic_trans, y_mf_pic_trans], 1910 axis=-2) 1911 1912 # Efficient parallel equalization after PIC 1913 # z_i = i'th row of a_inv * y_MF_PIC_i 1914 # boils down to tf.reduce_sum(a_inv * y_mf_pic_trans, axis=-1) 1915 # divide by mu_i for unbiasedness 1916 # [..., K] 1917 x_hat = real2complex_vector(tf.reduce_sum(a_inv * y_mf_pic_trans, 1918 axis=-1) / tf.cast(mu, dtype=a_inv.dtype)) 1919 1920 # Compute post equalization signal error estimate: 1921 # rho_i = mu_i / (1 - var_x_i * mu_i) 1922 # 1 - var_x_i * mu_i can become numerically 0, or even slightly 1923 # smaller than zero due to limited numerical precision 1924 # [..., 2K] 1925 var_x = tf.divide(mu, tf.maximum(1 - var_x * mu, self._epsilon)) 1926 # real variances map to the same complex valued variances in this 1927 # model 1928 var_x, _ = tf.split(var_x, 2, -1) 1929 1930 no_eff = 1. 
/ var_x 1931 1932 # Step 6: LLR demapping (extrinsic LLRs) 1933 # [..., K, num_bits_per_symbols] 1934 llr_d = tf.reshape(self._bit_demapper([x_hat, llr_a, no_eff]), 1935 llr_shape) 1936 1937 return llr_d, llr_a, it 1938 1939 # Stopping condition (required for tf.while_loop) 1940 def dec_stop(llr_d, llr_a, it): # pylint: disable=W0613 1941 return tf.less(it, self._num_iter) 1942 1943 # start decoding iterations 1944 it = tf.constant(0) 1945 null_prior = tf.zeros(llr_shape, dtype=self._realdtype) 1946 llr_d, llr_a, _ = tf.while_loop(dec_stop, 1947 mmse_pic_self_iteration, 1948 (llr_a, null_prior, it), 1949 parallel_iterations=1, 1950 maximum_iterations=self._num_iter) 1951 llr_e = llr_d - llr_a 1952 if self._output == "symbol": 1953 # convert back to symbols if requested. 1954 # output symbol logits computed on extrinsic LLRs 1955 out = self._llr_2_symbol_logits_output(llr_e) 1956 else: 1957 # output extrinsic LLRs 1958 out = llr_e 1959 if self._hard_out: 1960 out = hard_decisions(out) 1961 1962 return out