anomaly-detection-material-parameters-calibration

Sionna material-parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

detection.py (82648B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Classes and functions related to MIMO channel detection"""
      6 
      7 import warnings
      8 import numpy as np
      9 import tensorflow as tf
     10 from tensorflow.keras.layers import Layer
     11 from sionna.utils import expand_to_rank, matrix_sqrt_inv, flatten_last_dims, flatten_dims, split_dim, insert_dims, hard_decisions
     12 from sionna.mapping import Constellation, SymbolLogits2LLRs, LLRs2SymbolLogits, PAM2QAM, Demapper, SymbolDemapper, SymbolInds2Bits, DemapperWithPrior, SymbolLogits2Moments
     13 from sionna.mimo.utils import complex2real_channel, whiten_channel, List2LLR, List2LLRSimple, complex2real_matrix, complex2real_vector, real2complex_vector
     14 from sionna.mimo.equalization import lmmse_equalizer, zf_equalizer, mf_equalizer
     15 
     16 class LinearDetector(Layer):
     17     # pylint: disable=line-too-long
     18     r"""LinearDetector(equalizer, output, demapping_method, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, dtype=tf.complex64, **kwargs)
     19 
     20     Convenience class that combines an equalizer,
     21     such as :func:`~sionna.mimo.lmmse_equalizer`, and a :class:`~sionna.mapping.Demapper`.
     22 
     23     Parameters
     24     ----------
     25     equalizer : str, one of ["lmmse", "zf", "mf"], or an equalizer function
     26         The equalizer to be used. Either one of the existing equalizers
     27         :func:`~sionna.mimo.lmmse_equalizer`, :func:`~sionna.mimo.zf_equalizer`, or
     28         :func:`~sionna.mimo.mf_equalizer` can be used, or a custom equalizer
     29         callable provided that has the same input/output specification.
     30 
     31     output : One of ["bit", "symbol"], str
     32         The type of output, either LLRs on bits or logits on constellation symbols.
     33 
     34     demapping_method : One of ["app", "maxlog"], str
     35         The demapping method used.
     36 
     37     constellation_type : One of ["qam", "pam", "custom"], str
     38         For "custom", an instance of :class:`~sionna.mapping.Constellation`
     39         must be provided.
     40 
     41     num_bits_per_symbol : int
     42         The number of bits per constellation symbol, e.g., 4 for QAM16.
     43         Only required for ``constellation_type`` in ["qam", "pam"].
     44 
     45     constellation : Constellation
     46         An instance of :class:`~sionna.mapping.Constellation` or `None`.
     47         In the latter case, ``constellation_type``
     48         and ``num_bits_per_symbol`` must be provided.
     49 
     50     hard_out : bool
     51         If `True`, the detector computes hard-decided bit values or
     52         constellation point indices instead of soft-values.
     53         Defaults to `False`.
     54 
     55     dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
     56         The dtype of ``y``. Defaults to tf.complex64.
     57         The output dtype is the corresponding real dtype (tf.float32 or tf.float64).
     58 
     59     Input
     60     ------
     61     (y, h, s) :
     62         Tuple:
     63 
     64     y : [...,M], tf.complex
     65         1+D tensor containing the received signals
     66 
     67     h : [...,M,num_streams], tf.complex
     68         2+D tensor containing the channel matrices
     69 
     70     s : [...,M,M], tf.complex
     71         2+D tensor containing the noise covariance matrices
     72 
     73     Output
     74     ------
     75     One of:
     76 
     77     : [..., num_streams, num_bits_per_symbol], tf.float
     78         LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`
     79 
     80     : [..., num_streams, num_points], tf.float or [..., num_streams], tf.int
     81        Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`
     82        Hard-decisions correspond to the symbol indices.
     83 
     84     Note
     85     ----
     86     If you want to use this layer in Graph mode with XLA, i.e., within
     87     a function that is decorated with ``@tf.function(jit_compile=True)``,
     88     you might need to set ``sionna.Config.xla_compat=true``. This depends on the
     89     chosen equalizer function. See :py:attr:`~sionna.Config.xla_compat`.
     90     """
     91     def __init__(self,
     92                  equalizer,
     93                  output,
     94                  demapping_method,
     95                  constellation_type=None,
     96                  num_bits_per_symbol=None,
     97                  constellation=None,
     98                  hard_out=False,
     99                  dtype=tf.complex64,
    100                  **kwargs):
    101         super().__init__(dtype=dtype, **kwargs)
    102         self._output = output
    103         self._hard_out = hard_out
    104 
    105         # Determine the equalizer to use
    106         if isinstance(equalizer, str):
    107             assert equalizer in ["lmmse", "zf", "mf"], "Unknown equalizer."
    108             if equalizer=="lmmse":
    109                 self._equalizer = lmmse_equalizer
    110             elif equalizer=="zf":
    111                 self._equalizer = zf_equalizer
    112             else:
    113                 self._equalizer = mf_equalizer
    114         else:
    115             self._equalizer = equalizer
    116 
    117         assert output in ("bit", "symbol"), "Unknown output"
    118         assert demapping_method in ("app","maxlog"), "Unknown demapping method"
    119 
    120         constellation = Constellation.create_or_check_constellation(
    121                                                             constellation_type,
    122                                                             num_bits_per_symbol,
    123                                                             constellation,
    124                                                             dtype=dtype)
    125         self._constellation = constellation
    126 
    127         # Determine the demapper to use
    128         if output=="bit":
    129             self._demapper = Demapper(demapping_method,
    130                                       constellation=constellation,
    131                                       hard_out=hard_out,
    132                                       dtype=dtype)
    133         else:
    134             self._demapper = SymbolDemapper(constellation=constellation,
    135                                             hard_out=hard_out,
    136                                             dtype=dtype)
    137 
    138     def call(self, inputs):
    139         x_hat, no_eff = self._equalizer(*inputs)
    140         z = self._demapper([x_hat, no_eff])
    141 
    142         # Reshape to the expected output shape
    143         num_streams = tf.shape(inputs[1])[-1]
    144         if self._output == 'bit':
    145             num_bits_per_symbol = self._constellation.num_bits_per_symbol
    146             z = split_dim(z, [num_streams, num_bits_per_symbol], tf.rank(z)-1)
    147 
    148         return z
    149 
class MaximumLikelihoodDetector(Layer):
    # pylint: disable=line-too-long
    r"""
    MaximumLikelihoodDetector(output, demapping_method, num_streams, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, with_prior=False, dtype=tf.complex64, **kwargs)

    MIMO maximum-likelihood (ML) detector.
    If the ``with_prior`` flag is set, prior knowledge on the bits or constellation points is assumed to be available.

    This layer implements MIMO maximum-likelihood (ML) detection assuming the
    following channel model:

    .. math::
        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    :math:`\mathbf{x}\in\mathcal{C}^K` is the vector of transmitted symbols which
    are uniformly and independently drawn from the constellation :math:`\mathcal{C}`,
    :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a complex Gaussian noise vector.
    It is assumed that :math:`\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}` and
    :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`,
    where :math:`\mathbf{S}` has full rank.
    If the ``with_prior`` flag is set, it is assumed that prior information of the transmitted signal :math:`\mathbf{x}` is available,
    provided either as LLRs on the bits mapped onto :math:`\mathbf{x}` or as logits on the individual
    constellation points forming :math:`\mathbf{x}`.

    Prior to demapping, the received signal is whitened:

    .. math::
        \tilde{\mathbf{y}} &= \mathbf{S}^{-\frac{1}{2}}\mathbf{y}\\
        &=  \mathbf{S}^{-\frac{1}{2}}\mathbf{H}\mathbf{x} + \mathbf{S}^{-\frac{1}{2}}\mathbf{n}\\
        &= \tilde{\mathbf{H}}\mathbf{x} + \tilde{\mathbf{n}}

    The layer can compute ML detection of symbols or bits with either
    soft- or hard-decisions. Note that decisions are computed symbol-/bit-wise
    and not jointly for the entire vector :math:`\textbf{x}` (or the underlying vector
    of bits).

    **ML detection of bits:**

    Soft-decisions on bits are called log-likelihood ratios (LLR).
    With the “app” demapping method, the LLR for the :math:`i\text{th}` bit
    of the :math:`k\text{th}` user is then computed according to

    .. math::
        \begin{align}
            LLR(k,i)&= \ln\left(\frac{\Pr\left(b_{k,i}=1\lvert \mathbf{y},\mathbf{H}\right)}{\Pr\left(b_{k,i}=0\lvert \mathbf{y},\mathbf{H}\right)}\right)\\
                    &=\ln\left(\frac{
                    \sum_{\mathbf{x}\in\mathcal{C}_{k,i,1}} \exp\left(
                        -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                        \right) \Pr\left( \mathbf{x} \right)
                    }{
                    \sum_{\mathbf{x}\in\mathcal{C}_{k,i,0}} \exp\left(
                        -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                        \right) \Pr\left( \mathbf{x} \right)
                    }\right)
        \end{align}

    where :math:`\mathcal{C}_{k,i,1}` and :math:`\mathcal{C}_{k,i,0}` are the
    sets of vectors of constellation points for which the :math:`i\text{th}` bit
    of the :math:`k\text{th}` user is equal to 1 and 0, respectively.
    :math:`\Pr\left( \mathbf{x} \right)` is the prior distribution of the vector of
    constellation points :math:`\mathbf{x}`. Assuming that the constellation points and
    bit levels are independent, it is computed from the prior of the bits according to

    .. math::
        \Pr\left( \mathbf{x} \right) = \prod_{k=1}^K \prod_{i=1}^{I} \sigma \left( LLR_p(k,i) \right)

    where :math:`LLR_p(k,i)` is the prior knowledge of the :math:`i\text{th}` bit of the
    :math:`k\text{th}` user given as an LLR and which is set to :math:`0` if no prior knowledge is assumed to be available,
    and :math:`\sigma\left(\cdot\right)` is the sigmoid function.
    The definition of the LLR has been chosen such that it is equivalent with that of logit. This is
    different from many textbooks in communications, where the LLR is
    defined as :math:`LLR(k,i) = \ln\left(\frac{\Pr\left(b_{k,i}=0\lvert \mathbf{y},\mathbf{H}\right)}{\Pr\left(b_{k,i}=1\lvert \mathbf{y},\mathbf{H}\right)}\right)`.

    With the "maxlog" demapping method, the LLR for the :math:`i\text{th}` bit
    of the :math:`k\text{th}` user is approximated like

    .. math::
        \begin{align}
            LLR(k,i) \approx&\ln\left(\frac{
                \max_{\mathbf{x}\in\mathcal{C}_{k,i,1}} \left( \exp\left(
                    -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                    \right) \Pr\left( \mathbf{x} \right) \right)
                }{
                \max_{\mathbf{x}\in\mathcal{C}_{k,i,0}} \left( \exp\left(
                    -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                    \right) \Pr\left( \mathbf{x} \right) \right)
                }\right)\\
                = &\min_{\mathbf{x}\in\mathcal{C}_{k,i,0}} \left( \left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 - \ln \left(\Pr\left( \mathbf{x} \right) \right) \right) -
                    \min_{\mathbf{x}\in\mathcal{C}_{k,i,1}} \left( \left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 - \ln \left( \Pr\left( \mathbf{x} \right) \right) \right).
            \end{align}

    **ML detection of symbols:**

    Soft-decisions on symbols are called logits (i.e., unnormalized log-probability).

    With the “app” demapping method, the logit for the
    constellation point :math:`c \in \mathcal{C}` of the :math:`k\text{th}` user  is computed according to

    .. math::
        \begin{align}
            \text{logit}(k,c) &= \ln\left(\sum_{\mathbf{x} : x_k = c} \exp\left(
                        -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                        \right)\Pr\left( \mathbf{x} \right)\right).
        \end{align}

    With the "maxlog" demapping method, the logit for the constellation point :math:`c \in \mathcal{C}`
    of the :math:`k\text{th}` user  is approximated like

    .. math::
        \text{logit}(k,c) \approx \max_{\mathbf{x} : x_k = c} \left(
                -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 + \ln \left( \Pr\left( \mathbf{x} \right) \right)
                \right).

    When hard decisions are requested, this layer returns for the :math:`k` th stream

    .. math::
        \hat{c}_k = \underset{c \in \mathcal{C}}{\text{argmax}} \left( \sum_{\mathbf{x} : x_k = c} \exp\left(
                        -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
                        \right)\Pr\left( \mathbf{x} \right) \right)

    where :math:`\mathcal{C}` is the set of constellation points.

    Parameters
    -----------
    output : One of ["bit", "symbol"], str
        The type of output, either LLRs on bits or logits on constellation symbols.

    demapping_method : One of ["app", "maxlog"], str
        The demapping method used.

    num_streams : tf.int
        Number of transmitted streams

    constellation_type : One of ["qam", "pam", "custom"], str
        For "custom", an instance of :class:`~sionna.mapping.Constellation`
        must be provided.

    num_bits_per_symbol : int
        The number of bits per constellation symbol, e.g., 4 for QAM16.
        Only required for ``constellation_type`` in ["qam", "pam"].

    constellation : Constellation
        An instance of :class:`~sionna.mapping.Constellation` or `None`.
        In the latter case, ``constellation_type``
        and ``num_bits_per_symbol`` must be provided.

    hard_out : bool
        If `True`, the detector computes hard-decided bit values or
        constellation point indices instead of soft-values.
        Defaults to `False`.

    with_prior : bool
        If `True`, it is assumed that prior knowledge on the bits or constellation points is available.
        This prior information is given as LLRs (for bits) or log-probabilities (for constellation points) as an
        additional input to the layer.
        Defaults to `False`.

    dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
        The dtype of ``y``. Defaults to tf.complex64.
        The output dtype is the corresponding real dtype (tf.float32 or tf.float64).

    Input
    ------
    (y, h, s) or (y, h, prior, s) :
        Tuple:

    y : [...,M], tf.complex
        1+D tensor containing the received signals.

    h : [...,M,num_streams], tf.complex
        2+D tensor containing the channel matrices.

    prior : [...,num_streams,num_bits_per_symbol] or [...,num_streams,num_points], tf.float
        Prior of the transmitted signals.
        If ``output`` equals "bit", then LLRs of the transmitted bits are expected.
        If ``output`` equals "symbol", then logits of the transmitted constellation points are expected.
        Only required if the ``with_prior`` flag is set.

    s : [...,M,M], tf.complex
        2+D tensor containing the noise covariance matrices.

    Output
    ------
    One of:

    : [..., num_streams, num_bits_per_symbol], tf.float
        LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`.

    : [..., num_streams, num_points], tf.float or [..., num_streams], tf.int
       Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`.
       Hard-decisions correspond to the symbol indices.

    Note
    ----
    If you want to use this layer in Graph mode with XLA, i.e., within
    a function that is decorated with ``@tf.function(jit_compile=True)``,
    you must set ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """

    def __init__(self,
                 output,
                 demapping_method,
                 num_streams,
                 constellation_type=None,
                 num_bits_per_symbol=None,
                 constellation=None,
                 hard_out=False,
                 with_prior=False,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(dtype=dtype, **kwargs)

        assert dtype in [tf.complex64, tf.complex128],\
            "dtype must be tf.complex64 or tf.complex128"

        assert output in ("bit", "symbol"), "Unknown output"

        assert demapping_method in ("app","maxlog"), "Unknown demapping method"

        self._output = output
        self._demapping_method = demapping_method
        self._hard_out = hard_out
        self._with_prior = with_prior

        # Determine the reduce function for LLR computation:
        # "app" marginalizes over candidate vectors with logsumexp,
        # "maxlog" approximates the marginalization with a maximum
        if self._demapping_method == "app":
            self._reduce = tf.reduce_logsumexp
        else:
            self._reduce = tf.reduce_max

        # Create constellation object
        self._constellation = Constellation.create_or_check_constellation(
                                                        constellation_type,
                                                        num_bits_per_symbol,
                                                        constellation,
                                                        dtype=dtype)

        # Utility function to compute
        # vecs : [num_vecs, num_streams] The list of all possible transmitted vectors.
        # vecs_ind : [num_vecs, num_streams] The list of all possible transmitted vectors
        #   constellation indices
        # c : [num_vecs/num_points, num_streams, num_points] Which is such that `c[:,k,s]`
        #   gives the symbol indices in the first dimension of `vecs` for which
        #   the `k`th stream transmitted the `s`th constellation point.
        # These are computed once at construction time as they only depend on
        # the constellation and the number of streams.
        vecs, vecs_ind, c = self._build_vecs(num_streams)
        self._vecs = tf.cast(vecs, dtype)
        self._vecs_ind = tf.cast(vecs_ind, tf.int32)
        self._c = tf.cast(c, tf.int32)

        if output == 'bit':
            num_bits_per_symbol = self._constellation.num_bits_per_symbol
            # Converts symbol logits into (hard or soft) bit decisions
            self._logits2llr = SymbolLogits2LLRs(
                                    method=demapping_method,
                                    num_bits_per_symbol=num_bits_per_symbol,
                                    hard_out=hard_out,
                                    dtype=dtype.real_dtype,
                                    **kwargs)
            # Converts bit priors (LLRs) into symbol logits for use as priors
            self._llrs2logits = LLRs2SymbolLogits(
                                    num_bits_per_symbol=num_bits_per_symbol,
                                    hard_out=False,
                                    dtype=dtype.real_dtype,
                                    **kwargs)

    @property
    def constellation(self):
        """Constellation used by the detector"""
        return self._constellation

    def _build_vecs(self, num_streams):
        """
        Utility function for building the list of all possible transmitted
        vectors of constellation points and the symbol indices corresponding to
        all possibly transmitted constellation points for every stream.

        Input
        ------
        num_streams : int
            Number of transmitted streams

        Output
        -------
        vecs : [num_vecs, K], tf.complex
            List of all possible transmitted vectors.

        vecs_ind : [num_vecs, K, 2], int
            For every entry of ``vecs``, the pair
            (stream index, constellation point index), suitable for use
            with ``tf.gather_nd``.

        c : [num_vecs/num_points, num_streams, num_points], int
            `c[:,k,s]` gives the symbol indices in the first dimension of `vecs`
            for which the `k`th stream transmitted the `s`th symbol.
        """

        points = self._constellation.points
        num_points = points.shape[0]

        # Recursive function for generating all possible transmitted
        # vector of symbols and indices
        # `n` is the remaining number of streams to process
        def _build_vecs_(n):
            if n == 1:
                # If there is a single stream, then the list of possibly
                # transmitted vectors corresponds to the constellation points.
                # No recursion is needed.
                vecs = np.expand_dims(points, axis=1)
                vecs_ind = np.expand_dims(np.arange(num_points), axis=1)
            else:
                # If the number of streams is `n >= 2` streams, then the list
                # of possibly transmitted vectors is
                #
                # [c_1 v , c_2 v, ..., c_N v]
                #
                # where `[c_1, ..., c_N]` is the constellation of size N, and
                # `v` is the list of possible vectors for `n-1` streams.
                # This list has therefore length `N x len(v)`.
                #
                # Building the list for `n-1` streams, recursively.
                v, vi = _build_vecs_(n-1)
                # Building the list of `n` streams by appending the
                # constellation points.
                vecs = []
                vecs_ind = []
                for i,p in enumerate(points):
                    vecs.append(np.concatenate([np.full([v.shape[0], 1], p),
                                                v], axis=1))
                    vecs_ind.append(np.concatenate([np.full([v.shape[0], 1], i),
                                                vi], axis=1))
                vecs = np.concatenate(vecs, axis=0)
                vecs_ind = np.concatenate(vecs_ind, axis=0)
            return vecs, vecs_ind

        # Building the list of possible vectors for the `k` streams.
        # [num_vecs, K]
        vecs, vecs_ind = _build_vecs_(num_streams)

        # Pair every constellation point index with its stream index so that
        # the result can be used as indices for tf.gather_nd
        tx_ind = np.arange(num_streams)
        tx_ind = np.expand_dims(tx_ind, axis=0)
        tx_ind = np.tile(tx_ind, [vecs_ind.shape[0], 1])
        vecs_ind = np.stack([tx_ind, vecs_ind], axis=-1)

        # Compute symbol indices for every stream.
        # For every constellation point `p` and for every stream `j`, we gather
        # the list of vector indices from `vecs` corresponding the vectors for
        # which the `jth` stream transmitted `p`.
        # [num_vecs/num_points, num_streams, num_points]
        c = []
        for p in points:
            c_ = []
            for j in range(num_streams):
                c_.append(np.where(vecs[:,j]==p)[0])
            c_ = np.stack(c_, axis=-1)
            c.append(c_)
        c = np.stack(c, axis=-1)

        return vecs, vecs_ind, c

    def call(self, inputs):
        if self._with_prior:
            y, h, prior, s = inputs

            # If operating on bits, computes prior on symbols from the prior
            # on bits
            if self._output == 'bit':
                # [..., K, num_points]
                prior = self._llrs2logits(prior)
        else:
            y, h, s = inputs

        # Compute square-root of interference covariance matrix
        s_inv = matrix_sqrt_inv(s)

        # Whiten the observation
        y = tf.expand_dims(y, -1)
        y = tf.squeeze(tf.matmul(s_inv, y), axis=-1)

        # Compute channel after whitening
        h = tf.matmul(s_inv, h)

        # Add extra dims for broadcasting with the dimensions corresponding
        # to all possible transmitted vectors
        # Shape: [..., 1, M, K]
        h = tf.expand_dims(h, axis=-3)

        # Add extra dims for broadcasting with the dimensions corresponding
        # to all possible transmitted vectors
        # Shape: [..., 1, M]
        y = tf.expand_dims(y, axis=-2)

        # Reshape list of all possible vectors from
        # [num_vecs, K]
        # to
        # [1,...,1, num_vecs, K, 1]
        vecs = self._vecs
        vecs = tf.expand_dims(vecs, axis=-1)
        vecs = expand_to_rank(vecs, tf.rank(h), 0)

        # Compute exponents
        # [..., num_vecs]
        diff = y - tf.squeeze(h@vecs, axis=-1)
        exponents = -tf.reduce_sum(tf.square(tf.abs(diff)), axis=-1)

        # Add prior
        if self._with_prior:
            # [..., num_vecs, K]
            prior = expand_to_rank(prior, tf.rank(exponents), axis=0)
            prior_rank = tf.rank(prior)
            # Move the last two dims (K, num_points) to the front so that
            # tf.gather_nd can select, for every candidate vector, the prior
            # logit of the constellation point each stream transmitted
            transpose_ind = tf.concat([[prior_rank-2, prior_rank-1],
                                        tf.range(prior_rank-2)], axis=0)
            prior = tf.transpose(prior, transpose_ind)
            prior = tf.gather_nd(prior, self._vecs_ind)
            # Move the gathered (num_vecs, K) dims back to the end
            transpose_ind = tf.concat([ tf.range(2, prior_rank),
                                        [0, 1]], axis=0)
            prior = tf.transpose(prior, transpose_ind)
            # Sum the per-stream log-priors to get the log-prior of each
            # candidate vector (streams assumed independent)
            # [..., num_vecs]
            prior = tf.reduce_sum(prior, axis=-1)
            exponents = exponents + prior

        # Gather exponents for all symbols
        # [..., num_vecs/num_points, K, num_points]
        exp = tf.gather(exponents, self._c, axis=-1)

        # Compute logits on constellation points
        # [..., K, num_points]
        logits = self._reduce(exp, axis=-3)

        if self._output == 'bit':
            # Compute LLRs or hard decisions
            return self._logits2llr(logits)
        else:
            if self._hard_out:
                return tf.argmax(logits, axis=-1, output_type=tf.int32)
            else:
                return logits
    581 
    582 class MaximumLikelihoodDetectorWithPrior(MaximumLikelihoodDetector):
    583     # pylint: disable=line-too-long
    584     r"""
    585     MaximumLikelihoodDetectorWithPrior(output, demapping_method, num_streams, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, dtype=tf.complex64, **kwargs)
    586 
    587     MIMO maximum-likelihood (ML) detector, assuming prior
    588     knowledge on the bits or constellation points is available.
    589 
    590     This class is deprecated as the functionality has been integrated
    591     into :class:`~sionna.mimo.MaximumLikelihoodDetector`.
    592 
    593     This layer implements MIMO maximum-likelihood (ML) detection assuming the
    594     following channel model:
    595 
    596     .. math::
    597         \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}
    598 
    599     where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    600     :math:`\mathbf{x}\in\mathcal{C}^K` is the vector of transmitted symbols which
    601     are uniformly and independently drawn from the constellation :math:`\mathcal{C}`,
    602     :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix,
    603     and :math:`\mathbf{n}\in\mathbb{C}^M` is a complex Gaussian noise vector.
    604     It is assumed that :math:`\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}` and
    605     :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`,
    606     where :math:`\mathbf{S}` has full rank.
    607     It is assumed that prior information of the transmitted signal :math:`\mathbf{x}` is available,
    608     provided either as LLRs on the bits modulated onto :math:`\mathbf{x}` or as logits on the individual
    609     constellation points forming :math:`\mathbf{x}`.
    610 
    611     Prior to demapping, the received signal is whitened:
    612 
    613     .. math::
    614         \tilde{\mathbf{y}} &= \mathbf{S}^{-\frac{1}{2}}\mathbf{y}\\
    615         &=  \mathbf{S}^{-\frac{1}{2}}\mathbf{H}\mathbf{x} + \mathbf{S}^{-\frac{1}{2}}\mathbf{n}\\
    616         &= \tilde{\mathbf{H}}\mathbf{x} + \tilde{\mathbf{n}}
    617 
    618     The layer can compute ML detection of symbols or bits with either
    619     soft- or hard-decisions. Note that decisions are computed symbol-/bit-wise
    620     and not jointly for the entire vector :math:`\textbf{x}` (or the underlying vector
    621     of bits).
    622 
    623     **\ML detection of bits:**
    624 
    625     Soft-decisions on bits are called log-likelihood ratios (LLR).
    626     With the “app” demapping method, the LLR for the :math:`i\text{th}` bit
    627     of the :math:`k\text{th}` user is then computed according to
    628 
    629     .. math::
    630         \begin{align}
    631             LLR(k,i)&= \ln\left(\frac{\Pr\left(b_{k,i}=1\lvert \mathbf{y},\mathbf{H}\right)}{\Pr\left(b_{k,i}=0\lvert \mathbf{y},\mathbf{H}\right)}\right)\\
    632                     &=\ln\left(\frac{
    633                     \sum_{\mathbf{x}\in\mathcal{C}_{k,i,1}} \exp\left(
    634                         -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
    635                         \right) \Pr\left( \mathbf{x} \right)
    636                     }{
    637                     \sum_{\mathbf{x}\in\mathcal{C}_{k,i,0}} \exp\left(
    638                         -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
    639                         \right) \Pr\left( \mathbf{x} \right)
    640                     }\right)
    641         \end{align}
    642 
    643     where :math:`\mathcal{C}_{k,i,1}` and :math:`\mathcal{C}_{k,i,0}` are the
    644     sets of vectors of constellation points for which the :math:`i\text{th}` bit
    645     of the :math:`k\text{th}` user is equal to 1 and 0, respectively.
    646     :math:`\Pr\left( \mathbf{x} \right)` is the prior distribution of the vector of
    647     constellation points :math:`\mathbf{x}`. Assuming that the constellation points and
    648     bit levels are independent, it is computed from the prior of the bits according to
    649 
    650     .. math::
    651         \Pr\left( \mathbf{x} \right) = \prod_{k=1}^K \prod_{i=1}^{I} \sigma \left( LLR_p(k,i) \right)
    652 
    653     where :math:`LLR_p(k,i)` is the prior knowledge of the :math:`i\text{th}` bit of the
    654     :math:`k\text{th}` user given as an LLR, and :math:`\sigma\left(\cdot\right)` is the sigmoid function.
    655     The definition of the LLR has been chosen such that it is equivalent with that of logit. This is
    656     different from many textbooks in communications, where the LLR is
    657     defined as :math:`LLR(k,i) = \ln\left(\frac{\Pr\left(b_{k,i}=0\lvert \mathbf{y},\mathbf{H}\right)}{\Pr\left(b_{k,i}=1\lvert \mathbf{y},\mathbf{H}\right)}\right)`.
    658 
    659     With the "maxlog" demapping method, the LLR for the :math:`i\text{th}` bit
    660     of the :math:`k\text{th}` user is approximated like
    661 
    662     .. math::
    663         \begin{align}
    664             LLR(k,i) \approx&\ln\left(\frac{
    665                 \max_{\mathbf{x}\in\mathcal{C}_{k,i,1}} \left( \exp\left(
    666                     -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
    667                     \right) \Pr\left( \mathbf{x} \right) \right)
    668                 }{
    669                 \max_{\mathbf{x}\in\mathcal{C}_{k,i,0}} \left( \exp\left(
    670                     -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
    671                     \right) \Pr\left( \mathbf{x} \right) \right)
    672                 }\right)\\
    673                 = &\min_{\mathbf{x}\in\mathcal{C}_{k,i,0}} \left( \left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 - \ln \left(\Pr\left( \mathbf{x} \right) \right) \right) -
    674                     \min_{\mathbf{x}\in\mathcal{C}_{k,i,1}} \left( \left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 - \ln \left( \Pr\left( \mathbf{x} \right) \right) \right).
    675             \end{align}
    676 
    677     **ML detection of symbols:**
    678 
    679     Soft-decisions on symbols are called logits (i.e., unnormalized log-probability).
    680 
    681     With the “app” demapping method, the logit for the
    682     constellation point :math:`c \in \mathcal{C}` of the :math:`k\text{th}` user  is computed according to
    683 
    684     .. math::
    685         \begin{align}
    686             \text{logit}(k,c) &= \ln\left(\sum_{\mathbf{x} : x_k = c} \exp\left(
    687                         -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
    688                         \right)\Pr\left( \mathbf{x} \right)\right).
    689         \end{align}
    690 
    691     With the "maxlog" demapping method, the logit for the constellation point :math:`c \in \mathcal{C}`
    692     of the :math:`k\text{th}` user  is approximated like
    693 
    694     .. math::
    695         \text{logit}(k,c) \approx \max_{\mathbf{x} : x_k = c} \left(
    696                 -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2 + \ln \left( \Pr\left( \mathbf{x} \right) \right)
    697                 \right).
    698 
    699     When hard decisions are requested, this layer returns for the :math:`k` th stream
    700 
    701     .. math::
    702         \hat{c}_k = \underset{c \in \mathcal{C}}{\text{argmax}} \left( \sum_{\mathbf{x} : x_k = c} \exp\left(
    703                         -\left\lVert\tilde{\mathbf{y}}-\tilde{\mathbf{H}}\mathbf{x}\right\rVert^2
    704                         \right)\Pr\left( \mathbf{x} \right) \right)
    705 
    706     where :math:`\mathcal{C}` is the set of constellation points.
    707 
    708     Parameters
    709     -----------
    710     output : One of ["bit", "symbol"], str
    711         The type of output, either LLRs on bits or logits on constellation symbols.
    712 
    713     demapping_method : One of ["app", "maxlog"], str
    714         The demapping method used.
    715 
    716     num_streams : tf.int
    717         Number of transmitted streams
    718 
    719     constellation_type : One of ["qam", "pam", "custom"], str
    720         For "custom", an instance of :class:`~sionna.mapping.Constellation`
    721         must be provided.
    722 
    723     num_bits_per_symbol : int
    724         The number of bits per constellation symbol, e.g., 4 for QAM16.
    725         Only required for ``constellation_type`` in ["qam", "pam"].
    726 
    727     constellation : Constellation
    728         An instance of :class:`~sionna.mapping.Constellation` or `None`.
    729         In the latter case, ``constellation_type``
    730         and ``num_bits_per_symbol`` must be provided.
    731 
    732     hard_out : bool
    733         If `True`, the detector computes hard-decided bit values or
    734         constellation point indices instead of soft-values.
    735         Defaults to `False`.
    736 
    737     dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
    738         The dtype of ``y``. Defaults to tf.complex64.
    739         The output dtype is the corresponding real dtype (tf.float32 or tf.float64).
    740 
    741     Input
    742     ------
    743     (y, h, prior, s) :
    744         Tuple:
    745 
    746     y : [...,M], tf.complex
    747         1+D tensor containing the received signals.
    748 
    749     h : [...,M,num_streams], tf.complex
    750         2+D tensor containing the channel matrices.
    751 
    752     prior : [...,num_streams,num_bits_per_symbol] or [...,num_streams,num_points], tf.float
    753         Prior of the transmitted signals.
    754         If ``output`` equals "bit", then LLRs of the transmitted bits are expected.
    755         If ``output`` equals "symbol", then logits of the transmitted constellation points are expected.
    756 
    757     s : [...,M,M], tf.complex
    758         2+D tensor containing the noise covariance matrices.
    759 
    760     Output
    761     ------
    762     One of:
    763 
    764     : [..., num_streams, num_bits_per_symbol], tf.float
    765         LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`.
    766 
    767     : [..., num_streams, num_points], tf.float or [..., num_streams], tf.int
    768        Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`.
    769        Hard-decisions correspond to the symbol indices.
    770 
    771     Note
    772     ----
    773     If you want to use this layer in Graph mode with XLA, i.e., within
    774     a function that is decorated with ``@tf.function(jit_compile=True)``,
    775     you must set ``sionna.Config.xla_compat=true``.
    776     See :py:attr:`~sionna.Config.xla_compat`.
    777     """
    778 
    779     def __init__(self,
    780                  output,
    781                  demapping_method,
    782                  num_streams,
    783                  constellation_type=None,
    784                  num_bits_per_symbol=None,
    785                  constellation=None,
    786                  hard_out=False,
    787                  dtype=tf.complex64,
    788                  **kwargs):
    789         super().__init__(   output=output,
    790                             demapping_method=demapping_method,
    791                             num_streams=num_streams,
    792                             constellation_type=constellation_type,
    793                             num_bits_per_symbol=num_bits_per_symbol,
    794                             constellation=constellation,
    795                             hard_out=hard_out,
    796                             with_prior=True,
    797                             dtype=dtype,
    798                             **kwargs)
    799 
    800 class KBestDetector(Layer):
    801     # pylint: disable=line-too-long
    802     r"""KBestDetector(output, num_streams, k, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, use_real_rep=False, list2llr=None, dtype=tf.complex64)
    803 
    804     MIMO K-Best detector
    805 
    806     This layer implements K-Best MIMO detection as described
    807     in (Eq. 4-5) [FT2015]_. It can either generate hard decisions (for symbols
    808     or bits) or compute LLRs.
    809 
    810     The algorithm operates in either the complex or real-valued domain.
    811     Although both options produce identical results, the former has the advantage
    812     that it can be applied to arbitrary non-QAM constellations. It also reduces
    813     the number of streams (or depth) by a factor of two.
    814 
    815     The way soft-outputs (i.e., LLRs) are computed is determined by the
    816     ``list2llr`` function. The default solution
    817     :class:`~sionna.mimo.List2LLRSimple` assigns a predetermined
    818     value to all LLRs without counter-hypothesis.
    819 
    820     This layer assumes the following channel model:
    821 
    822     .. math::
    823         \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}
    824 
    825     where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    826     :math:`\mathbf{x}\in\mathcal{C}^S` is the vector of transmitted symbols which
    827     are uniformly and independently drawn from the constellation :math:`\mathcal{C}`,
    828     :math:`\mathbf{H}\in\mathbb{C}^{M\times S}` is the known channel matrix,
    829     and :math:`\mathbf{n}\in\mathbb{C}^M` is a complex Gaussian noise vector.
    830     It is assumed that :math:`\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}` and
    831     :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`,
    832     where :math:`\mathbf{S}` has full rank.
    833 
    834     In a first optional step, the channel model is converted to its real-valued equivalent,
    835     see :func:`~sionna.mimo.complex2real_channel`. We assume in the sequel the complex-valued
    836     representation. Then, the channel is whitened using :func:`~sionna.mimo.whiten_channel`:
    837 
    838     .. math::
    839         \tilde{\mathbf{y}} &= \mathbf{S}^{-\frac{1}{2}}\mathbf{y}\\
    840         &=  \mathbf{S}^{-\frac{1}{2}}\mathbf{H}\mathbf{x} + \mathbf{S}^{-\frac{1}{2}}\mathbf{n}\\
    841         &= \tilde{\mathbf{H}}\mathbf{x} + \tilde{\mathbf{n}}.
    842 
    843     Next, the columns of :math:`\tilde{\mathbf{H}}` are sorted according
    844     to their norm in descending order. Then, the QR decomposition of the
    845     resulting channel matrix is computed:
    846 
    847     .. math::
    848         \tilde{\mathbf{H}} = \mathbf{Q}\mathbf{R}
    849 
    850     where :math:`\mathbf{Q}\in\mathbb{C}^{M\times S}` is unitary and
    851     :math:`\mathbf{R}\in\mathbb{C}^{S\times S}` is upper-triangular.
    852     The channel outputs are then pre-multiplied by :math:`\mathbf{Q}^{\mathsf{H}}`.
    853     This leads to the final channel model on which the K-Best detection algorithm operates:
    854 
    855     .. math::
    856         \bar{\mathbf{y}} = \mathbf{R}\bar{\mathbf{x}} + \bar{\mathbf{n}}
    857 
    858     where :math:`\bar{\mathbf{y}}\in\mathbb{C}^S`,
    859     :math:`\bar{\mathbf{x}}\in\mathbb{C}^S`, and :math:`\bar{\mathbf{n}}\in\mathbb{C}^S`
    860     with :math:`\mathbb{E}\left[\bar{\mathbf{n}}\right]=\mathbf{0}` and
    861     :math:`\mathbb{E}\left[\bar{\mathbf{n}}\bar{\mathbf{n}}^{\mathsf{H}}\right]=\mathbf{I}`.
    862 
    863     **LLR Computation**
    864 
    865     The K-Best algorithm produces :math:`K` candidate solutions :math:`\bar{\mathbf{x}}_k\in\mathcal{C}^S`
    866     and their associated distance metrics :math:`d_k=\lVert \bar{\mathbf{y}} - \mathbf{R}\bar{\mathbf{x}}_k \rVert^2`
    867     for :math:`k=1,\dots,K`. If the real-valued channel representation is used, the distance
    868     metrics are scaled by 0.5 to account for the reduced noise power in each complex dimension.
    869     A hard-decision is simply the candidate with the shortest distance.
    870     Various ways to compute LLRs from this list (and possibly
    871     additional side-information) are possible. The (sub-optimal) default solution
    872     is :class:`~sionna.mimo.List2LLRSimple`. Custom solutions can be provided.
    873 
    874     Parameters
    875     -----------
    876     output : One of ["bit", "symbol"], str
    877         The type of output, either bits or symbols. Whether soft- or
    878         hard-decisions are returned can be configured with the
    879         ``hard_out`` flag.
    880 
    881     num_streams : tf.int
    882         Number of transmitted streams
    883 
    884     k : tf.int
    885         The number of paths to keep. Cannot be larger than the
    886         number of constellation points to the power of the number of
    887         streams.
    888 
    889     constellation_type : One of ["qam", "pam", "custom"], str
    890         For "custom", an instance of :class:`~sionna.mapping.Constellation`
    891         must be provided.
    892 
    893     num_bits_per_symbol : int
    894         The number of bits per constellation symbol, e.g., 4 for QAM16.
    895         Only required for ``constellation_type`` in ["qam", "pam"].
    896 
    897     constellation : Constellation
    898         An instance of :class:`~sionna.mapping.Constellation` or `None`.
    899         In the latter case, ``constellation_type``
    900         and ``num_bits_per_symbol`` must be provided.
    901 
    902     hard_out : bool
    903         If `True`, the detector computes hard-decided bit values or
    904         constellation point indices instead of soft-values.
    905         Defaults to `False`. The detector cannot compute soft-symbols.
    906 
    907     use_real_rep : bool
        If `True`, the detector uses the real-valued equivalent representation
    909         of the channel. Note that this only works with a QAM constellation.
    910         Defaults to `False`.
    911 
    912     list2llr: `None` or instance of :class:`~sionna.mimo.List2LLR`
    913         The function to be used to compute LLRs from a list of candidate solutions.
    914         If `None`, the default solution :class:`~sionna.mimo.List2LLRSimple`
    915         is used.
    916 
    917     dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
    918         The dtype of ``y``. Defaults to tf.complex64.
    919         The output dtype is the corresponding real dtype (tf.float32 or tf.float64).
    920 
    921     Input
    922     -----
    923     (y, h, s) :
    924         Tuple:
    925 
    926     y : [...,M], tf.complex
    927         1+D tensor containing the received signals
    928 
    929     h : [...,M,num_streams], tf.complex
    930         2+D tensor containing the channel matrices
    931 
    932     s : [...,M,M], tf.complex
    933         2+D tensor containing the noise covariance matrices
    934 
    935     Output
    936     ------
    937     One of:
    938 
    939     : [...,num_streams,num_bits_per_symbol], tf.float
    940         LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`
    941 
    : [...,num_streams,num_points], tf.float or [...,num_streams], tf.int
    943        Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`
    944        Hard-decisions correspond to the symbol indices.
    945 
    946     Note
    947     ----
    948     If you want to use this layer in Graph mode with XLA, i.e., within
    949     a function that is decorated with ``@tf.function(jit_compile=True)``,
    950     you must set ``sionna.Config.xla_compat=true``.
    951     See :py:attr:`~sionna.Config.xla_compat`.
    952     """
    def __init__(self,
                 output,
                 num_streams,
                 k,
                 constellation_type=None,
                 num_bits_per_symbol=None,
                 constellation=None,
                 hard_out=False,
                 use_real_rep=False,
                 list2llr="default",
                 dtype=tf.complex64,
                 **kwargs):
        """Configure the detector and precompute all static tensors
        (constellation, per-layer path counts, scatter indices) needed by
        the sequential K-Best search. See the class docstring for the
        meaning of each parameter."""
        super().__init__(dtype=dtype, **kwargs)
        assert dtype in [tf.complex64, tf.complex128],\
            "dtype must be tf.complex64 or tf.complex128."

        assert output in ("bit", "symbol"), "Unknown output"

        # Exactly one way of specifying the constellation must be used
        err_msg = "You must provide either constellation or " + \
                  "constellation_type and num_bits_per_symbol."
        if constellation is None:
            assert constellation_type is not None and \
                   num_bits_per_symbol is not None, err_msg
        else:
            assert constellation_type is None and \
                   num_bits_per_symbol is None, err_msg

        if constellation is not None:
            assert constellation.points.dtype==dtype, \
                "Constellation has wrong dtype."

        self._output = output
        self._hard_out = hard_out
        self._use_real_rep = use_real_rep

        if self._use_real_rep:
            # Real-valued representation is used.
            # The complex-to-real conversion maps QAM onto two PAM streams.
            err_msg = "Only QAM can be used for the real-valued representation"
            if constellation_type is not None:
                assert constellation_type=="qam", err_msg
            else:
                assert constellation._constellation_type=="qam", err_msg

            # Double the number of streams to detect
            # (real and imaginary parts become separate streams)
            self._num_streams = 2*num_streams

            # Half the number of bits for the PAM constellation
            if num_bits_per_symbol is None:
                n = constellation.num_bits_per_symbol//2
                self._num_bits_per_symbol = n
            else:
                self._num_bits_per_symbol = num_bits_per_symbol//2

            # Generate a PAM constellation with 0.5 energy
            # (each real dimension carries half of the QAM symbol energy)
            c = Constellation("pam",
                                self._num_bits_per_symbol,
                                normalize=False,
                                dtype=dtype)
            c._points /= tf.cast(np.std(c._points)*np.sqrt(2), c._points.dtype)
            self._constellation = tf.cast(c.points, dtype.real_dtype)

            # Maps pairs of detected PAM indices back to QAM symbol indices
            self._pam2qam = PAM2QAM(2*self._num_bits_per_symbol)

        else:
            # Complex-valued representation is used
            # Number of streams is equal to number of transmitters
            self._num_streams = num_streams

            # Create constellation or take the one provided
            c = Constellation.create_or_check_constellation(
                                                        constellation_type,
                                                        num_bits_per_symbol,
                                                        constellation,
                                                        dtype=dtype)
            self._constellation = c.points
            self._num_bits_per_symbol = c.num_bits_per_symbol

        # Number of constellation symbols
        self._num_symbols = self._constellation.shape[0]

        # Number of best paths to keep;
        # clipped to the total number of possible paths
        self._k = np.minimum(k, self._num_symbols**self._num_streams)
        if self._k < k:
            msg = "KBestDetector: " + \
                  f"The provided value of k={k} is larger than " + \
                  "the possible maximum number of paths. " + \
                  f"It has been set to k={self._k}."
            warnings.warn(msg)

        # Compute the number of previous paths a layer needs to consider
        num_paths = [1] # The first layer considers a single path
        for l in range(1, self._num_streams+1):
            # The lth layer considers min(k, num_symbols**l) paths
            num_paths.append(np.minimum(self._k, self._num_symbols**l))
        self._num_paths = tf.constant(tf.stack(num_paths, 0), tf.int32)

        # The symbols and indices for all paths will be stored in tensors
        # of shape [batch_size, k, num_streams]. However, only
        # a subset of the available entries are updated by each stream.
        # To enable XLA, we need to compute the relevant indices of the tensors
        # that will be updated through tf.tensor_scatter_nd_update.
        indices = np.zeros([self._num_streams, self._k*self._num_streams, 2],
                           np.int32)
        for l in range(0, self._num_streams):
            # Mark the (path, stream) entries that layer l writes to
            ind = np.zeros([self._num_paths[l+1], self._num_streams])
            ind[:, :l+1] = 1
            ind = np.stack(np.where(ind), -1)
            indices[l,:ind.shape[0],:ind.shape[1]] = ind
        self._indices = tf.constant(indices, dtype=tf.int32)

        if self._output=="bit":
            if self._hard_out is False:
                # Soft bit output: LLRs are computed from the candidate list
                if list2llr=="default":
                    self.list2llr = List2LLRSimple(self._num_bits_per_symbol)
                else:
                    self.list2llr = list2llr
            else:
                # Hard bit output: map symbol indices directly to bits;
                # for the real representation, each QAM symbol carries
                # twice the bits of its PAM components
                if self._use_real_rep:
                    n = 2*self._num_bits_per_symbol
                else:
                    n = self._num_bits_per_symbol
                self._symbolinds2bits = SymbolInds2Bits(n,
                                             dtype=dtype.real_dtype)
        else:
            assert self._hard_out is True, \
                "Soft-symbols are not supported for this detector."
   1079 
    @property
    def list2llr(self):
        """Instance of :class:`~sionna.mimo.List2LLR` used to compute
        LLRs from the list of candidate solutions."""
        return self._list2llr
   1083 
    @list2llr.setter
    def list2llr(self, value):
        # Only List2LLR instances provide the expected call interface
        assert isinstance(value, List2LLR)
        self._list2llr = value
   1088 
   1089     def _preprocessing(self, inputs):
   1090 
   1091         y, h, s = inputs
   1092 
   1093         # Convert to real-valued representation if desired
   1094         if self._use_real_rep:
   1095             y, h, s = complex2real_channel(y, h, s)
   1096 
   1097         # Whiten channel
   1098         y, h = whiten_channel(y, h, s, return_s=False) # pylint: disable=W0632
   1099 
   1100         # Order columns of H in order of decreasing norm
   1101         h_norm = tf.reduce_sum(tf.abs(h)**2, axis=1)
   1102         column_order = tf.argsort(h_norm, axis=-1, direction="DESCENDING")
   1103         h = tf.gather(h, column_order, axis=-1, batch_dims=1)
   1104 
   1105         # Compute QR decomposition of sorted channel
   1106         # r is upper triangular
   1107         q, r = tf.linalg.qr(h)
   1108 
   1109         # Project y on Q'
   1110         y = tf.squeeze(tf.matmul(q, tf.expand_dims(y, -1), adjoint_a=True),
   1111                        -1)
   1112 
   1113         return y, r, column_order
   1114 
   1115     def _select_best_paths(self, dists, path_syms, path_inds):
   1116 
   1117         # Determine the number of paths to keep (either all or k)
   1118         num_paths = tf.shape(path_syms)[1]
   1119         k = tf.minimum(num_paths, self._k)
   1120 
   1121         # Get the k paths with the shortest distance
   1122         dists, ind = tf.math.top_k(-dists, k=k, sorted=True)
   1123         dists = -dists
   1124 
   1125         # Select the same best paths for the symbols and symbol indices
   1126         path_syms = tf.gather(path_syms, ind, axis=1, batch_dims=1)
   1127         path_inds = tf.gather(path_inds, ind, axis=1, batch_dims=1)
   1128 
   1129         return dists, path_syms, path_inds
   1130 
    def _next_layer(self, y, r, dists, path_syms, path_inds, stream):
        """Extend all surviving paths by one stream and reselect the k best.

        ``stream`` is the 0-based count of processed streams; since ``r`` is
        upper-triangular, detection proceeds from the last stream backwards.
        Returns updated (dists, path_syms, path_inds), written back into the
        full-size tensors via XLA-compatible scatter updates.
        """
        batch_size = tf.shape(y)[0]

        # Streams are processed in reverse order
        stream_ind = self._num_streams-1-stream

        # Current number of considered paths
        num_paths = tf.gather(self._num_paths, stream)

        # Store input tensors for scatter update later on
        dists_o = dists
        path_syms_o = path_syms
        path_inds_o = path_inds

        # Extract relevant values from input tensor
        dists = dists[..., :num_paths]
        path_syms = path_syms[..., :num_paths, :stream]
        path_inds = path_inds[..., :num_paths, :stream]

        # Each path creates num_symbols branches
        dists     = tf.repeat(dists,     repeats=self._num_symbols, axis=1)
        path_syms = tf.repeat(path_syms, repeats=self._num_symbols, axis=1)
        path_inds = tf.repeat(path_inds, repeats=self._num_symbols, axis=1)

        # Append to each path the symbols corresponding to the branch
        syms = tf.reshape(self._constellation, [1,-1])
        syms = tf.repeat(syms, self._k, 0)
        syms = tf.reshape(syms, [1, -1, 1])
        syms = tf.repeat(syms, batch_size, 0)
        syms = syms[:,:num_paths*self._num_symbols]
        path_syms = tf.concat([path_syms, syms], axis=-1)

        # Do the same for the symbol indices
        inds = tf.reshape(tf.range(0, self._num_symbols), [1, -1])
        inds = tf.repeat(inds, self._k, 0)
        inds = tf.reshape(inds, [1, -1, 1])
        inds = tf.repeat(inds, batch_size, 0)
        inds = inds[:,:num_paths*self._num_symbols]
        path_inds = tf.concat([path_inds, inds], axis=-1)

        # Compute partial distances
        # Extract the row of r corresponding to layer and reverse the order
        # (path symbols are stored in detection order, i.e., reversed)
        y = tf.expand_dims(y[:, stream_ind], axis=-1)
        r = tf.expand_dims(tf.reverse(r[:, stream_ind, stream_ind:], [-1]), 1)
        delta = tf.pow(tf.abs(y - tf.reduce_sum(r*path_syms, axis=-1)), 2)

        # Update distances
        dists += delta

        # Get k best paths
        dists, path_syms, path_inds = self._select_best_paths(dists, path_syms, path_inds)

        # Scatter updates of dists
        # (transpose so the path dimension leads; required for the
        # XLA-compatible tensor_scatter_nd_update)
        tensor = tf.transpose(dists_o, perm=[1, 0])
        updates = tf.transpose(dists, perm=[1, 0])
        indices = tf.expand_dims(tf.range(tf.shape(updates)[0], dtype=tf.int32), -1)
        dists = tf.tensor_scatter_nd_update(tensor, indices, updates)
        dists = tf.transpose(dists, perm=[1, 0])

        # Scatter update of path_syms
        tensor = tf.transpose(path_syms_o, [1, 2, 0])
        updates = tf.transpose(path_syms, [1, 2, 0])
        updates = tf.reshape(updates, [-1, batch_size])
        indices = self._indices[stream, :self._num_paths[stream+1]*(stream+1)]
        path_syms = tf.tensor_scatter_nd_update(tensor, indices, updates)
        path_syms = tf.transpose(path_syms, perm=[2, 0, 1])

        # Scatter update of path_inds (reuses `indices` from above)
        tensor = tf.transpose(path_inds_o, [1, 2, 0])
        updates = tf.transpose(path_inds, [1, 2, 0])
        updates = tf.reshape(updates, [-1, batch_size])
        path_inds = tf.tensor_scatter_nd_update(tensor, indices, updates)
        path_inds = tf.transpose(path_inds, perm=[2, 0, 1])

        return dists, path_syms, path_inds
   1207 
   1208     def _unsort(self, column_order, tensor, transpose=True):
   1209         # Undo the column sorting
   1210         # If transpose=True, the unsorting is done along the last dimension
   1211         # Otherwise, sorting is done along the second-last index
   1212         unsort_inds = tf.argsort(column_order, axis=-1)
   1213         if transpose:
   1214             tensor = tf.transpose(tensor, perm=[0, 2, 1])
   1215         tensor = tf.gather(tensor, unsort_inds, axis=-2, batch_dims=1)
   1216         if transpose:
   1217             tensor = tf.transpose(tensor, perm=[0, 2, 1])
   1218         return tensor
   1219 
    def build(self, input_shape):
        # input_shape[1] is the shape of h: [..., M, num_streams];
        # K-Best requires at least as many receive antennas as streams
        assert input_shape[1][-2]>=input_shape[1][-1], \
                "The number of receive antennas cannot be smaller \
                 than the number of streams"
   1224 
    def call(self, inputs):
        """Run K-Best detection on a batch of ``(y, h, s)`` inputs.

        Returns hard decisions (symbol indices or bits) when ``hard_out``
        is set, otherwise LLRs computed from the list of candidate paths.
        """
        # Flatten the batch dimensions
        y, h, s = inputs
        batch_shape = tf.shape(y)[:-1]
        num_batch_dims = len(batch_shape)
        if num_batch_dims > 1:
            y = flatten_dims(y, num_batch_dims, 0)
            h = flatten_dims(h, num_batch_dims, 0)
            s = flatten_dims(s, num_batch_dims, 0)
            inputs = (y,h,s)

        # Initialization
        # (i) (optional) Convert to real-valued representation
        # (ii) Whiten channel
        # (iii) Sort columns of H by decreasing column norm
        # (iv) QR Decomposition of H
        # (v) Project y onto Q'
        y, r, column_order = self._preprocessing(inputs)

        batch_size = tf.shape(y)[0]

        # Tensor to keep track of the aggregate distances of all paths
        dists = tf.zeros([batch_size, self._k], y.dtype.real_dtype)

        # Tensor to store constellation symbols of all paths
        path_syms = tf.zeros([batch_size, self._k, self._num_streams], y.dtype)

        # Tensor to store constellation symbol indices of all paths
        path_inds = tf.zeros([batch_size, self._k, self._num_streams],tf.int32)

        # Sequential K-Best algorithm
        # One layer of the upper-triangular system is detected per
        # iteration, keeping the k best partial paths at each step.
        for stream in range(0, self._num_streams):
            dists, path_syms, path_inds = self._next_layer(y,
                                                           r,
                                                           dists,
                                                           path_syms,
                                                           path_inds,
                                                           stream)

        # Reverse order as detection started with the last symbol first
        path_syms = tf.reverse(path_syms, axis=[-1])
        path_inds = tf.reverse(path_inds, axis=[-1])

        # Processing for hard-decisions
        if self._hard_out:
            # Undo the column sorting applied during preprocessing
            path_inds = self._unsort(column_order, path_inds)
            # NOTE(review): assumes paths are kept sorted by aggregate
            # distance so that index 0 is the best path — maintained by
            # _next_layer; confirm if that invariant changes.
            hard_dec = path_inds[:,0,:]

            # Real-valued representation
            # The first/second half of the streams are the two PAM
            # components; recombine them into QAM symbol indices.
            if self._use_real_rep:
                hard_dec = \
                    self._pam2qam(hard_dec[...,:self._num_streams//2],
                                  hard_dec[...,self._num_streams//2:])

            # Hard decisions on bits
            if self._output=="bit":
                hard_dec = self._symbolinds2bits(hard_dec)

            # Reshape batch dimensions
            if num_batch_dims > 1:
                hard_dec = split_dim(hard_dec, batch_shape, 0)

            return hard_dec

        # Processing for soft-decisions
        else:
            # Real-valued representation
            if self._use_real_rep:
                llr = self.list2llr([y, r, dists, path_inds, path_syms])
                llr = self._unsort(column_order, llr, transpose=False)

                # Combine LLRs from PAM symbols in the correct order
                # by interleaving the bits of both PAM components
                llr1 = llr[:,:self._num_streams//2]
                llr2 = llr[:,self._num_streams//2:]
                llr1 = tf.expand_dims(llr1, -1)
                llr2 = tf.expand_dims(llr2, -1)
                llr = tf.concat([llr1, llr2], -1)
                llr = tf.reshape(llr, [-1, self._num_streams//2,
                                   2*self._num_bits_per_symbol])

            # Complex-valued representation
            else:
                llr = self.list2llr([y, r, dists, path_inds, path_syms])
                llr = self._unsort(column_order, llr, transpose=False)

            # Reshape batch dimensions
            if num_batch_dims > 1:
                llr = split_dim(llr, batch_shape, 0)

            return llr
   1316 
class EPDetector(Layer):
    # pylint: disable=line-too-long
    r"""EPDetector(output, num_bits_per_symbol, hard_out=False, l=10, beta=0.9, dtype=tf.complex64)

    MIMO Expectation Propagation (EP) detector

    This layer implements Expectation Propagation (EP) MIMO detection as described
    in [EP2014]_. It can generate hard- or soft-decisions for symbols or bits.

    This layer assumes the following channel model:

    .. math::
        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    :math:`\mathbf{x}\in\mathcal{C}^S` is the vector of transmitted symbols which
    are uniformly and independently drawn from the constellation :math:`\mathcal{C}`,
    :math:`\mathbf{H}\in\mathbb{C}^{M\times S}` is the known channel matrix,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a complex Gaussian noise vector.
    It is assumed that :math:`\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}` and
    :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`,
    where :math:`\mathbf{S}` has full rank.

    The channel model is first whitened using :func:`~sionna.mimo.whiten_channel`
    and then converted to its real-valued equivalent,
    see :func:`~sionna.mimo.complex2real_channel`, prior to MIMO detection.

    The computation of LLRs is done by converting the symbol logits
    that naturally arise in the algorithm to LLRs using
    :func:`~sionna.mapping.PAM2QAM`. Custom conversions of symbol logits to LLRs
    can be implemented by using the soft-symbol output.

    Parameters
    -----------
    output : One of ["bit", "symbol"], str
        The type of output, either bits or symbols. Whether soft- or
        hard-decisions are returned can be configured with the
        ``hard_out`` flag.

    num_bits_per_symbol : int
        The number of bits per QAM constellation symbol, e.g., 4 for QAM16.

    hard_out : bool
        If `True`, the detector computes hard-decided bit values or
        constellation point indices instead of soft-values.
        Defaults to `False`.

    l : int
        Number of iterations. Defaults to 10.

    beta : float
        Parameter :math:`\beta\in[0,1]` for update smoothing.
        Defaults to 0.9.

    dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
        Precision used for internal computations. Defaults to ``tf.complex64``.
        Especially for large MIMO setups, the precision can make a significant
        performance difference.

    Input
    -----
    (y, h, s) :
        Tuple:

    y : [...,M], tf.complex
        1+D tensor containing the received signals

    h : [...,M,num_streams], tf.complex
        2+D tensor containing the channel matrices

    s : [...,M,M], tf.complex
        2+D tensor containing the noise covariance matrices

    Output
    ------
    One of:

    : [...,num_streams,num_bits_per_symbol], tf.float
        LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`

    : [...,num_streams,2**num_bits_per_symbol], tf.float or [...,num_streams], tf.int
       Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`

    Note
    ----
    For numerical stability, we do not recommend to use this function in Graph
    mode with XLA, i.e., within a function that is decorated with
    ``@tf.function(jit_compile=True)``.
    However, it is possible to do so by setting
    ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """
    def __init__(self,
                 output,
                 num_bits_per_symbol,
                 hard_out=False,
                 l=10,
                 beta=0.9,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        assert dtype in [tf.complex64, tf.complex128], \
            "Invalid dtype"
        # Complex compute dtype and its matching real dtype
        self._cdtype = tf.dtypes.as_dtype(dtype)
        self._rdtype = self._cdtype.real_dtype

        # Variable used to avoid numerical instabilities
        # See paragraph after Eq. (38)
        # NOTE(review): relies on Keras exposing self.dtype as the dtype's
        # string name (e.g., "complex64") — verify on Keras upgrades.
        if self.dtype=="complex64":
            self._prec = 1e-6
        else:
            self._prec = 1e-12

        assert output in ("bit", "symbol"), "Unknown output"
        self._output = output

        self._hard_out = hard_out

        # Detection runs on two PAM constellations (real/imag halves),
        # so the symbol path recombines them into QAM, while the bit
        # path demaps each PAM half with num_bits_per_symbol//2 bits.
        if self._output=="symbol":
            self._pam2qam = PAM2QAM(num_bits_per_symbol, hard_out)
        else:
            self._symbollogits2llrs = SymbolLogits2LLRs("maxlog",
                                                        num_bits_per_symbol//2,
                                                        hard_out=hard_out)
            self._demapper = Demapper("maxlog", "pam", num_bits_per_symbol//2)

        assert l>=1, "l must be a positive integer"
        self._l = l

        assert 0.0<= beta <=1.0, "beta must be in [0,1]"
        self._beta = beta

        # Create PAM constellations for real-valued detection
        self._num_bits_per_symbol = num_bits_per_symbol//2
        points = Constellation("pam", int(self._num_bits_per_symbol)).points

        # Scale constellation points to half the energy because QAM is assumed
        self._points = tf.cast(points/np.sqrt(2.0), self._rdtype)

        # Average symbol energy
        self._es = tf.constant(np.var(self._points), self._rdtype)

    def compute_sigma_mu(self, h_t_h, h_t_y, no, lam, gam):
        """Equations (28) and (29)

        Computes the cavity-marginal covariance diagonal ``sigma`` and
        mean ``mu`` of the Gaussian approximation.
        """

        # Prepare inputs
        # lam becomes a diagonal matrix, gam a column vector
        lam = tf.linalg.diag(lam)
        gam = tf.expand_dims(gam, axis=-1)

        # Computations
        sigma = tf.linalg.inv(h_t_h + no*lam)
        mu = tf.squeeze(tf.matmul(sigma, h_t_y + no*gam), axis=-1)
        sigma *= no
        # Only the per-stream variances (diagonal) are needed
        sigma = tf.linalg.diag_part(sigma)

        return sigma, mu

    def compute_v_x_obs(self, sigma, mu, lam, gam):
        """Equations (31) and (32)

        Computes the variance and mean of the cavity distribution.
        ``v_obs`` is clamped from below by ``self._prec`` for stability.
        """

        v_obs = tf.maximum(1/(1/sigma-lam), self._prec)
        x_obs = v_obs*(mu/sigma-gam)

        return v_obs, x_obs

    def compute_v_x(self, v_obs, x_obs):
        """Equation (33)

        Projects the cavity distribution onto the discrete PAM prior,
        returning posterior variance ``v``, mean ``x``, and the raw
        per-point ``logits``.
        """

        # Compute probability mass function for the symbols
        x_obs = tf.expand_dims(x_obs, -1)
        v_obs = tf.expand_dims(v_obs, -1)

        points = expand_to_rank(self._points, tf.rank(x_obs), axis=0)
        logits = -tf.pow(x_obs-points, 2) / (tf.cast(2, self._rdtype)*v_obs)
        pmf = tf.math.softmax(logits)

        # Compute mean and variance of all symbols
        x = tf.reduce_sum(points * pmf, axis=-1, keepdims=True)
        v = tf.reduce_sum((points-x)**2 * pmf, axis=-1)
        # Clamp the variance to avoid division by zero downstream
        v = tf.maximum(v, self._prec)
        x = tf.squeeze(x, axis=-1)

        return v, x, logits

    def update_lam_gam(self, v, v_obs, x, x_obs, lam, gam):
        """Equations (35), (36), (37), (38)

        Updates the factor parameters ``lam`` and ``gam``, rejecting
        negative-precision updates and damping with factor ``beta``.
        """

        # Save old values of lam, and gam
        lam_old = lam
        gam_old = gam

        # Compute potential new values (35), (36)
        lam = 1/v - 1/v_obs
        gam = x/v - x_obs/v_obs

        # Only update nonnegative values
        lam_new = tf.where(lam<0, lam_old, lam)
        gam_new = tf.where(lam<0, gam_old, gam)

        # Damp updates (37), (38)
        lam_damp = (1-self._beta)*lam_new + self._beta*lam_old
        gam_damp = (1-self._beta)*gam_new + self._beta*gam_old

        return lam_damp, gam_damp

    def call(self, inputs):
        """Run EP detection on a batch of ``(y, h, s)`` inputs."""

        # Flatten the batch dimensions
        y, h, s = inputs
        batch_shape = tf.shape(y)[:-1]
        num_batch_dims = len(batch_shape)
        if num_batch_dims > 1:
            y = flatten_dims(y, num_batch_dims, 0)
            h = flatten_dims(h, num_batch_dims, 0)
            s = flatten_dims(s, num_batch_dims, 0)
            inputs = (y,h,s)

        # Number of transmit streams
        n_t = tf.shape(h)[-1]

        # Whiten channel
        y, h, s = whiten_channel(y, h, s)

        # Convert channel to real-valued representation
        # (doubles the number of streams to 2*n_t real-valued streams)
        y, h, s = complex2real_channel(y,h,s)

        # Convert all inputs to desired dtypes
        y = tf.cast(y, self._rdtype)
        h = tf.cast(h, self._rdtype)
        # After whitening, the complex noise has identity covariance, so
        # each real-valued component has variance 0.5
        no = tf.cast(0.5, self._rdtype)

        # Gather relevant parameters
        batch_dims = tf.shape(y)[:-1]
        n_t_r = tf.shape(h)[-1]

        # Initialize gamma and lambda (Paragraph after Eq. (29))
        gam = tf.zeros(tf.concat([batch_dims, [n_t_r]], axis=0), y.dtype)
        lam = tf.ones(tf.concat([batch_dims, [n_t_r]], axis=0), y.dtype)
        lam /= tf.cast(self._es, y.dtype)

        # Precompute values that are repeatedly needed
        h_t_h = tf.matmul(h, h, transpose_a=True)
        y = tf.expand_dims(y, axis=-1)
        h_t_y = tf.matmul(h, y, transpose_a=True)
        no = expand_to_rank(no, tf.rank(h), axis=-1)

        # EP message-passing iterations; logits from the last iteration
        # are used for the outputs below
        for _ in range(self._l):
            sigma, mu = self.compute_sigma_mu(h_t_h, h_t_y, no, lam, gam)
            v_obs, x_obs = self.compute_v_x_obs(sigma, mu, lam, gam)
            v, x, logits = self.compute_v_x(v_obs, x_obs)
            lam, gam = self.update_lam_gam(v, v_obs, x, x_obs, lam, gam)

        # Extract the logits for the 2 PAM constellations for each streams
        pam1_logits = logits[...,:n_t,:]
        pam2_logits = logits[...,n_t:,:]

        if self._output=="symbol" and self._hard_out:
            # Take hard decisions on PAM symbols
            pam1_ind = tf.argmax(pam1_logits, axis=-1, output_type=tf.int32)
            pam2_ind = tf.argmax(pam2_logits, axis=-1, output_type=tf.int32)

            # Transform to QAM indices
            qam_ind = self._pam2qam(pam1_ind, pam2_ind)

            # Reshape batch dimensions
            if num_batch_dims > 1:
                qam_ind = split_dim(qam_ind, batch_shape, 0)

            return qam_ind

        elif self._output=="symbol" and not self._hard_out:
            # Combine the per-PAM logits into QAM symbol logits
            qam_logits = self._pam2qam(pam1_logits, pam2_logits)

            # Reshape batch dimensions
            if num_batch_dims > 1:
                qam_logits = split_dim(qam_logits, batch_shape, 0)

            return qam_logits

        elif self._output=="bit":
            # Compute LLRs for both PAM constellations
            llr1 = self._symbollogits2llrs(pam1_logits)
            llr2 = self._symbollogits2llrs(pam2_logits)

            # Put LLRs in the correct order and shape
            # (interleave the bits of the two PAM components)
            llr = tf.stack([llr1, llr2], -1)
            llr = flatten_last_dims(llr)

            # Reshape batch dimensions
            if num_batch_dims > 1:
                llr = split_dim(llr, batch_shape, 0)

            return llr
   1611 class MMSEPICDetector(Layer):
   1612     # pylint: disable=line-too-long
   1613     r"""MMSEPICDetector(output, demapping_method="maxlog", num_iter=1, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, dtype=tf.complex64, **kwargs)
   1614 
   1615     Minimum mean square error (MMSE) with parallel interference cancellation (PIC) detector
   1616 
   1617     This layer implements the MMSE PIC detector, as proposed in [CST2011]_.
   1618     For ``num_iter``>1, this implementation performs MMSE PIC self-iterations.
   1619     MMSE PIC self-iterations can be understood as a concatenation of MMSE PIC
   1620     detectors from [CST2011]_, which forward intrinsic LLRs to the next
   1621     self-iteration.
   1622 
   1623     Compared to [CST2011]_, this implementation also accepts priors on the
   1624     constellation symbols as an alternative to priors on the bits.
   1625 
   1626     This layer assumes the following channel model:
   1627 
   1628     .. math::
   1629         \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}
   1630 
   1631     where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
   1632     :math:`\mathbf{x}\in\mathcal{C}^S` is the vector of transmitted symbols which
   1633     are uniformly and independently drawn from the constellation :math:`\mathcal{C}`,
   1634     :math:`\mathbf{H}\in\mathbb{C}^{M\times S}` is the known channel matrix,
   1635     and :math:`\mathbf{n}\in\mathbb{C}^M` is a complex Gaussian noise vector.
   1636     It is assumed that :math:`\mathbb{E}\left[\mathbf{n}\right]=\mathbf{0}` and
   1637     :math:`\mathbb{E}\left[\mathbf{n}\mathbf{n}^{\mathsf{H}}\right]=\mathbf{S}`,
   1638     where :math:`\mathbf{S}` has full rank.
   1639 
   1640     The algorithm starts by computing the soft symbols
   1641     :math:`\bar{x}_s=\mathbb{E}\left[ x_s \right]` and
   1642     variances :math:`v_s=\mathbb{E}\left[ |e_s|^2\right]` from the priors,
   1643     where :math:`e_s = x_s - \bar{x}_s`, for all :math:`s=1,\dots,S`.
   1644 
   1645     Next, for each stream, the interference caused by all other streams is cancelled
   1646     from the observation :math:`\mathbf{y}`, leading to
   1647 
   1648     .. math::
   1649         \hat{\mathbf{y}}_s = \mathbf{y} - \sum_{j\neq s} \mathbf{h}_j x_j = \mathbf{h}_s x_s + \tilde{\mathbf{n}}_s,\quad s=1,\dots,S
   1650 
   1651     where :math:`\tilde{\mathbf{n}}_s=\sum_{j\neq s} \mathbf{h}_j e_j + \mathbf{n}`.
   1652 
   1653     Then, a linear MMSE filter :math:`\mathbf{w}_s` is computed to reduce the resdiual noise
   1654     for each observation :math:`\hat{\mathbf{y}}_s`, which is given as
   1655 
   1656     .. math::
   1657         \mathbf{w}_s = \mathbf{h}_s^{\mathsf{H}}\left( \mathbf{H} \mathbf{D}_s\mathbf{H}^{\mathsf{H}} +\mathbf{S} \right)^{-1}
   1658 
   1659     where :math:`\mathbf{D}_s \in \mathbb{C}^{S\times S}` is diagonal with entries
   1660 
   1661     .. math::
   1662         \left[\mathbf{D}_s\right]_{i,i} = \begin{cases}
   1663                                             v_i & i\neq s \\
   1664                                             1 & i=s.
   1665                                           \end{cases}
   1666 
   1667     The filtered observations
   1668 
   1669     .. math::
   1670         \tilde{z}_s = \mathbf{w}_s^{\mathsf{H}} \hat{\mathbf{y}}_s = \tilde{\mu}_s x_s + \mathbf{w}_s^{\mathsf{H}}\tilde{\mathbf{n}}_s
   1671 
   1672     where :math:`\tilde{\mu}_s=\mathbf{w}_s^{\mathsf{H}} \mathbf{h}_s`, are then demapped to either symbol logits or LLRs, assuming that the remaining noise is Gaussian with variance
   1673 
   1674     .. math::
   1675         \nu_s^2 = \mathop{\text{Var}}\left[\tilde{z}_s\right] = \mathbf{w}_s^{\mathsf{H}} \left(\sum_{j\neq s} \mathbf{h}_j \mathbf{h}_j^{\mathsf{H}} v_j +\mathbf{S} \right)\mathbf{w}_s.
   1676 
   1677     The resulting soft-symbols can then be used for the next self-iteration of the algorithm.
   1678 
   1679     Note that this algorithm can be substantially simplified as described in [CST2011]_ to avoid
   1680     the computation of different matrix inverses for each stream. This is the version which is
   1681     implemented.
   1682 
   1683     Parameters
   1684     -----------
   1685     output : One of ["bit", "symbol"], str
   1686         The type of output, either LLRs on bits or logits on constellation
   1687         symbols.
   1688 
   1689     demapping_method : One of ["app", "maxlog"], str
   1690         The demapping method used.
   1691         Defaults to "maxlog".
   1692 
   1693     num_iter : int
   1694         Number of MMSE PIC iterations.
   1695         Defaults to 1.
   1696 
   1697     constellation_type : One of ["qam", "pam", "custom"], str
   1698         For "custom", an instance of :class:`~sionna.mapping.Constellation`
   1699         must be provided.
   1700 
   1701     num_bits_per_symbol : int
   1702         The number of bits per constellation symbol, e.g., 4 for QAM16.
   1703         Only required for ``constellation_type`` in ["qam", "pam"].
   1704 
   1705     constellation : Constellation
   1706         An instance of :class:`~sionna.mapping.Constellation` or `None`.
   1707         In the latter case, ``constellation_type``
   1708         and ``num_bits_per_symbol`` must be provided.
   1709 
   1710     hard_out : bool
   1711         If `True`, the detector computes hard-decided bit values or
   1712         constellation point indices instead of soft-values.
   1713         Defaults to `False`.
   1714 
   1715     dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
   1716         The dtype of ``y``. Defaults to tf.complex64.
   1717         The output dtype is the corresponding real dtype
   1718         (tf.float32 or tf.float64).
   1719 
   1720     Input
   1721     -----
   1722     (y, h, prior, s) :
   1723         Tuple:
   1724 
   1725     y : [...,M], tf.complex
   1726         1+D tensor containing the received signals
   1727 
   1728     h : [...,M,S], tf.complex
   1729         2+D tensor containing the channel matrices
   1730 
   1731     prior : [...,S,num_bits_per_symbol] or [...,S,num_points], tf.float
   1732         Prior of the transmitted signals.
   1733         If ``output`` equals "bit", then LLRs of the transmitted bits are expected.
   1734         If ``output`` equals "symbol", then logits of the transmitted constellation points are expected.
   1735 
   1736     s : [...,M,M], tf.complex
   1737         2+D tensor containing the noise covariance matrices
   1738 
   1739     Output
   1740     ------
   1741     One of:
   1742 
   1743     : [...,S,num_bits_per_symbol], tf.float
   1744         LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`
   1745 
   1746     : [...,S,2**num_bits_per_symbol], tf.float or [...,S], tf.int
   1747        Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`
   1748 
   1749     Note
   1750     ----
   1751     For numerical stability, we do not recommend to use this function in Graph
   1752     mode with XLA, i.e., within a function that is decorated with
   1753     ``@tf.function(jit_compile=True)``.
   1754     However, it is possible to do so by setting
   1755     ``sionna.Config.xla_compat=true``.
   1756     See :py:attr:`~sionna.Config.xla_compat`.
   1757     """
   1758     def __init__(self,
   1759                  output,
   1760                  demapping_method="maxlog",
   1761                  num_iter=1,
   1762                  constellation_type=None,
   1763                  num_bits_per_symbol=None,
   1764                  constellation=None,
   1765                  hard_out=False,
   1766                  dtype=tf.complex64,
   1767                  **kwargs):
   1768         super().__init__(dtype=dtype, **kwargs)
   1769 
   1770         assert isinstance(num_iter, int), "num_iter must be an integer"
   1771         assert output in ("bit", "symbol"), "Unknown output"
   1772         assert demapping_method in ("app", "maxlog"), "Unknown demapping method"
   1773 
   1774         assert dtype in [tf.complex64, tf.complex128], \
   1775             "dtype must be tf.complex64 or tf.complex128"
   1776 
   1777         self._num_iter = num_iter
   1778         self._output = output
   1779         self._epsilon = 1e-4
   1780         self._realdtype = dtype.real_dtype
   1781         self._demapping_method = demapping_method
   1782         self._hard_out = hard_out
   1783 
   1784         # Create constellation object
   1785         self._constellation = Constellation.create_or_check_constellation(
   1786             constellation_type,
   1787             num_bits_per_symbol,
   1788             constellation,
   1789             dtype=dtype)
   1790 
   1791         # Soft symbol mapping
   1792         self._llr_2_symbol_logits = LLRs2SymbolLogits(
   1793                                         self._constellation.num_bits_per_symbol,
   1794                                         dtype=self._realdtype)
   1795 
   1796         if self._output == "symbol":
   1797             self._llr_2_symbol_logits_output = LLRs2SymbolLogits(
   1798                                     self._constellation.num_bits_per_symbol,
   1799                                     dtype=self._realdtype,
   1800                                     hard_out=hard_out)
   1801             self._symbol_logits_2_llrs = SymbolLogits2LLRs(
   1802                 method=demapping_method,
   1803                 num_bits_per_symbol=self._constellation.num_bits_per_symbol)
   1804         self._symbol_logits_2_moments = SymbolLogits2Moments(
   1805                                             constellation=self._constellation,
   1806                                             dtype=self._realdtype)
   1807 
   1808         # soft output demapping
   1809         self._bit_demapper = DemapperWithPrior(
   1810                                             demapping_method=demapping_method,
   1811                                             constellation=self._constellation,
   1812                                             dtype=dtype)
   1813 
   1814 
   1815     def call(self, inputs):
   1816         y, h, prior, s = inputs
   1817         # y is unwhitened receive signal
   1818         #   [..., M]
   1819         # h the channel estimate
   1820         #   [..., M, K]
   1821         # prior is either the soft input LLRs
   1822         #   [..., K, num_bits_per_symbol] or symbol logits [..., K, Q]
   1823         # s the noise covariance matrix
   1824         #   [..., M, M]
   1825 
   1826         ## Preprocessing
   1827         # Whiten channel
   1828         # y : [..., M]
   1829         # s : [..., M, M]
   1830         y, h = whiten_channel(y, h, s, return_s=False)  # pylint: disable=unbalanced-tuple-unpacking
   1831 
   1832         # matched filtering of y
   1833         # [..., K, 1]
   1834         y_mf = insert_dims(tf.linalg.matvec(h, y, adjoint_a=True),
   1835                             num_dims=1, axis=-1)
   1836 
   1837         ## Step 1: compute Gramm matrix
   1838         # [..., K, K]
   1839         g = tf.matmul(h, h, adjoint_a=True)
   1840 
   1841         # For XLA compatibility, this implementation performs the MIMO
   1842         # equalization in the real-valued domain
   1843         # [..., 2M, 2K]
   1844         hr = complex2real_matrix(h)
   1845         # [..., 2K, 2K]
   1846         gr = tf.matmul(hr, hr, adjoint_a=True)
   1847 
   1848         # Compute a priori LLRs
   1849         if self._output == "symbol":
   1850             llr_a = self._symbol_logits_2_llrs(prior)
   1851         else:
   1852             llr_a = prior
   1853         # llr_a is [..., K, num_bits_per_symbol]
   1854         llr_shape = tf.shape(llr_a)
   1855 
   1856         def mmse_pic_self_iteration(llr_d, llr_a, it):
   1857             # MMSE PIC takes in a priori LLRs
   1858             llr_a = llr_d
   1859 
   1860             # Step 2: compute soft symbol estimates and variances
   1861             # x_hat, var_x : [..., K]
   1862             x_logits = self._llr_2_symbol_logits(llr_a)
   1863             x_hat, var_x = self._symbol_logits_2_moments(x_logits)
   1864 
   1865             # Step 3: perform parallel interference cancellation
   1866             # H^H y_hat_i = y_mf - sum_j!=i gj x_hat_j = y + g_i x_hat_i
   1867             #               - sum_j g_j x_hat_j
   1868             # [..., K, K]
   1869             y_mf_pic = y_mf + g * insert_dims(x_hat, num_dims=1, axis=-2) \
   1870                 - tf.linalg.matmul(g, insert_dims(x_hat, num_dims=1, axis=-1))
   1871 
   1872             # Step 4: compute A^-1 matrix
   1873             # Calculate MMSE Filter (efficiently)
   1874             # W^H = A^-1 H^H
   1875             # A = H^H H \Lambda + N_0 I_Mt
   1876             # \Lambda_ii is a diagonal matrix with \Lambda_ii = E_i = error_var
   1877 
   1878             # Stack error variances and make it real
   1879             # Note: Imaginary part is zero
   1880             var_x = tf.cast(tf.concat([var_x, var_x], axis=-1),
   1881                             dtype=self._realdtype)
   1882             var_x_row_vec = insert_dims(var_x, num_dims=1, axis=-2)
   1883             # [..., 2K, 2K]
   1884             a = gr * var_x_row_vec
   1885 
   1886             i = expand_to_rank(tf.eye(tf.shape(a)[-1], dtype=a.dtype),
   1887                                 tf.rank(a), 0)
   1888             a = a + i
   1889 
   1890             # a is non-hermitian! that's why we can't use sn.utils.matrix_inv
   1891             # XLA can't invert complex matrices, that's why we work with the
   1892             # real valued domain
   1893             a_inv = tf.linalg.inv(a)
   1894 
   1895             # Step 5: compute unbiased MMSE filter and outputs, calculate A\H^H
   1896 
   1897             # Calculate bias mu_i = diag(A^-1 H^H H) = diag(A^-1 G)
   1898             # Diagonal elements of matrix matrix multiplication simplified
   1899             # to sum and dot-product
   1900             # [..., 2K]
   1901             mu = tf.reduce_sum(a_inv * tf.linalg.matrix_transpose(gr), axis=-1)
   1902 
   1903             # Make y_mf_pic columns real (after transposition,
   1904             # the last dimension corresponds to vectors)
   1905             # [..., K, 2K]
   1906             y_mf_pic_trans = tf.linalg.matrix_transpose(y_mf_pic)
   1907             y_mf_pic_trans = complex2real_vector(y_mf_pic_trans)
   1908             # stack them such that y_mf_pic_trans has shape [..., 2K, 2K]
   1909             y_mf_pic_trans = tf.concat([y_mf_pic_trans, y_mf_pic_trans],
   1910                                         axis=-2)
   1911 
   1912             # Efficient parallel equalization after PIC
   1913             # z_i = i'th row of a_inv * y_MF_PIC_i
   1914             # boils down to tf.reduce_sum(a_inv * y_mf_pic_trans, axis=-1)
   1915             # divide by mu_i for unbiasedness
   1916             # [..., K]
   1917             x_hat = real2complex_vector(tf.reduce_sum(a_inv * y_mf_pic_trans,
   1918                                     axis=-1) / tf.cast(mu, dtype=a_inv.dtype))
   1919 
   1920             # Compute post equalization signal error estimate:
   1921             # rho_i = mu_i / (1 - var_x_i * mu_i)
   1922             # 1 - var_x_i * mu_i can become numerically 0, or even slightly
   1923             # smaller than zero due to limited numerical precision
   1924             # [..., 2K]
   1925             var_x = tf.divide(mu, tf.maximum(1 - var_x * mu, self._epsilon))
   1926             # real variances map to the same complex valued variances in this
   1927             # model
   1928             var_x, _ = tf.split(var_x, 2, -1)
   1929 
   1930             no_eff = 1. / var_x
   1931 
   1932             # Step 6: LLR demapping (extrinsic LLRs)
   1933             # [..., K, num_bits_per_symbols]
   1934             llr_d = tf.reshape(self._bit_demapper([x_hat, llr_a, no_eff]),
   1935                                 llr_shape)
   1936 
   1937             return llr_d, llr_a, it
   1938 
   1939         # Stopping condition (required for tf.while_loop)
   1940         def dec_stop(llr_d, llr_a, it):  # pylint: disable=W0613
   1941             return tf.less(it, self._num_iter)
   1942 
   1943         # start decoding iterations
   1944         it = tf.constant(0)
   1945         null_prior = tf.zeros(llr_shape, dtype=self._realdtype)
   1946         llr_d, llr_a, _ = tf.while_loop(dec_stop,
   1947                                     mmse_pic_self_iteration,
   1948                                     (llr_a, null_prior, it),
   1949                                     parallel_iterations=1,
   1950                                     maximum_iterations=self._num_iter)
   1951         llr_e = llr_d - llr_a
   1952         if self._output == "symbol":
   1953             # convert back to symbols if requested.
   1954              # output symbol logits computed on extrinsic LLRs
   1955             out = self._llr_2_symbol_logits_output(llr_e)
   1956         else:
   1957             # output extrinsic LLRs
   1958             out = llr_e
   1959             if self._hard_out:
   1960                 out = hard_decisions(out)
   1961 
   1962         return out