anomaly-detection-material-parameters-calibration

Sionna material-parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

detection.py (55944B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Class definitions and functions related to OFDM MIMO detection"""
      6 
      7 import tensorflow as tf
      8 from tensorflow.keras.layers import Layer
      9 from sionna.utils import flatten_dims, split_dim, flatten_last_dims, expand_to_rank
     10 from sionna.ofdm import RemoveNulledSubcarriers
     11 from sionna.mimo import MaximumLikelihoodDetectorWithPrior as MaximumLikelihoodDetectorWithPrior_
     12 from sionna.mimo import MaximumLikelihoodDetector as MaximumLikelihoodDetector_
     13 from sionna.mimo import LinearDetector as LinearDetector_
     14 from sionna.mimo import KBestDetector as KBestDetector_
     15 from sionna.mimo import EPDetector as EPDetector_
     16 from sionna.mimo import MMSEPICDetector as MMSEPICDetector_
     17 from sionna.mapping import Constellation
     18 
     19 
     20 class OFDMDetector(Layer):
     21     # pylint: disable=line-too-long
     22     r"""OFDMDetector(detector, output, resource_grid, stream_management, dtype=tf.complex64, **kwargs)
     23 
     24     Layer that wraps a MIMO detector for use with the OFDM waveform.
     25 
     26     The parameter ``detector`` is a callable (e.g., a function) that
     27     implements a MIMO detection algorithm for arbitrary batch dimensions.
     28 
     29     This class pre-processes the received resource grid ``y`` and channel
     30     estimate ``h_hat``, and computes for each receiver the
     31     noise-plus-interference covariance matrix according to the OFDM and stream
     32     configuration provided by the ``resource_grid`` and
     33     ``stream_management``, which also accounts for the channel
     34     estimation error variance ``err_var``. These quantities serve as input to the detection
     35     algorithm that is implemented by ``detector``.
     36     Both detection of symbols or bits with either soft- or hard-decisions are supported.
     37 
     38     Note
     39     -----
     40     The callable ``detector`` must take as input a tuple :math:`(\mathbf{y}, \mathbf{h}, \mathbf{s})` such that:
     41 
     42     * **y** ([...,num_rx_ant], tf.complex) -- 1+D tensor containing the received signals.
     43     * **h** ([...,num_rx_ant,num_streams_per_rx], tf.complex) -- 2+D tensor containing the channel matrices.
     44     * **s** ([...,num_rx_ant,num_rx_ant], tf.complex) -- 2+D tensor containing the noise-plus-interference covariance matrices.
     45 
     46     It must generate one of following outputs depending on the value of ``output``:
     47 
     48     * **b_hat** ([..., num_streams_per_rx, num_bits_per_symbol], tf.float) -- LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`.
     49     * **x_hat** ([..., num_streams_per_rx, num_points], tf.float) or ([..., num_streams_per_rx], tf.int) -- Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`. Hard-decisions correspond to the symbol indices.
     50 
     51     Parameters
     52     ----------
     53     detector : Callable
     54         Callable object (e.g., a function) that implements a MIMO detection
     55         algorithm for arbitrary batch dimensions. Either one of the existing detectors, e.g.,
     56         :class:`~sionna.mimo.LinearDetector`, :class:`~sionna.mimo.MaximumLikelihoodDetector`, or
     57         :class:`~sionna.mimo.KBestDetector` can be used, or a custom detector
     58         callable provided that has the same input/output specification.
     59 
     60     output : One of ["bit", "symbol"], str
     61         Type of output, either bits or symbols
     62 
     63     resource_grid : ResourceGrid
     64         Instance of :class:`~sionna.ofdm.ResourceGrid`
     65 
     66     stream_management : StreamManagement
     67         Instance of :class:`~sionna.mimo.StreamManagement`
     68 
     69     dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
     70         The dtype of `y`. Defaults to tf.complex64.
     71         The output dtype is the corresponding real dtype (tf.float32 or tf.float64).
     72 
     73     Input
     74     ------
     75     (y, h_hat, err_var, no) :
     76         Tuple:
     77 
     78     y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
     79         Received OFDM resource grid after cyclic prefix removal and FFT
     80 
     81     h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
     82         Channel estimates for all streams from all transmitters
     83 
     84     err_var : [Broadcastable to shape of ``h_hat``], tf.float
     85         Variance of the channel estimation error
     86 
     87     no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
     88         Variance of the AWGN
     89 
     90     Output
     91     ------
     92     One of:
     93 
     94     : [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol], tf.float
     95         LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`
     96 
     97     : [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float or [batch_size, num_tx, num_streams, num_data_symbols], tf.int
     98         Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`.
     99         Hard-decisions correspond to the symbol indices.
    100     """
    101     def __init__(self,
    102                  detector,
    103                  output,
    104                  resource_grid,
    105                  stream_management,
    106                  dtype=tf.complex64,
    107                  **kwargs):
    108         super().__init__(dtype=dtype, **kwargs)
    109         self._detector = detector
    110         self._resource_grid = resource_grid
    111         self._stream_management = stream_management
    112         self._removed_nulled_scs = RemoveNulledSubcarriers(self._resource_grid)
    113         self._output = output
    114 
    115         # Precompute indices to extract data symbols
    116         mask = resource_grid.pilot_pattern.mask
    117         num_data_symbols = resource_grid.pilot_pattern.num_data_symbols
    118         data_ind = tf.argsort(flatten_last_dims(mask), direction="ASCENDING")
    119         self._data_ind = data_ind[...,:num_data_symbols]
    120 
    121     def _preprocess_inputs(self, y, h_hat, err_var, no):
    122         """Pro-process the received signal and compute the
    123         noise-plus-interference covariance matrix"""
    124 
    125         # Remove nulled subcarriers from y (guards, dc). New shape:
    126         # [batch_size, num_rx, num_rx_ant, ...
    127         #  ..., num_ofdm_symbols, num_effective_subcarriers]
    128         y_eff = self._removed_nulled_scs(y)
    129 
    130         ####################################################
    131         ### Prepare the observation y for MIMO detection ###
    132         ####################################################
    133         # Transpose y_eff to put num_rx_ant last. New shape:
    134         # [batch_size, num_rx, num_ofdm_symbols,...
    135         #  ..., num_effective_subcarriers, num_rx_ant]
    136         y_dt = tf.transpose(y_eff, [0, 1, 3, 4, 2])
    137         y_dt = tf.cast(y_dt, self._dtype)
    138 
    139         # Transpose y_eff to put num_rx_ant last. New shape:
    140         # [batch_size, num_rx, num_ofdm_symbols,...
    141         #  ..., num_effective_subcarriers, num_rx_ant]
    142         y_dt = tf.transpose(y_eff, [0, 1, 3, 4, 2])
    143         y_dt = tf.cast(y_dt, self._dtype)
    144 
    145         ##############################################
    146         ### Prepare the err_var for MIMO detection ###
    147         ##############################################
    148         # New shape is:
    149         # [batch_size, num_rx, num_ofdm_symbols,...
    150         #  ..., num_effective_subcarriers, num_rx_ant, num_tx*num_streams]
    151         err_var_dt = tf.broadcast_to(err_var, tf.shape(h_hat))
    152         err_var_dt = tf.transpose(err_var_dt, [0, 1, 5, 6, 2, 3, 4])
    153         err_var_dt = flatten_last_dims(err_var_dt, 2)
    154         err_var_dt = tf.cast(err_var_dt, self._dtype)
    155 
    156         ###############################
    157         ### Construct MIMO channels ###
    158         ###############################
    159 
    160         # Reshape h_hat for the construction of desired/interfering channels:
    161         # [num_rx, num_tx, num_streams_per_tx, batch_size, num_rx_ant, ,...
    162         #  ..., num_ofdm_symbols, num_effective_subcarriers]
    163         perm = [1, 3, 4, 0, 2, 5, 6]
    164         h_dt = tf.transpose(h_hat, perm)
    165 
    166         # Flatten first tthree dimensions:
    167         # [num_rx*num_tx*num_streams_per_tx, batch_size, num_rx_ant, ...
    168         #  ..., num_ofdm_symbols, num_effective_subcarriers]
    169         h_dt = flatten_dims(h_dt, 3, 0)
    170 
    171         # Gather desired and undesired channels
    172         ind_desired = self._stream_management.detection_desired_ind
    173         ind_undesired = self._stream_management.detection_undesired_ind
    174         h_dt_desired = tf.gather(h_dt, ind_desired, axis=0)
    175         h_dt_undesired = tf.gather(h_dt, ind_undesired, axis=0)
    176 
    177         # Split first dimension to separate RX and TX:
    178         # [num_rx, num_streams_per_rx, batch_size, num_rx_ant, ...
    179         #  ..., num_ofdm_symbols, num_effective_subcarriers]
    180         h_dt_desired = split_dim(h_dt_desired,
    181                                  [self._stream_management.num_rx,
    182                                   self._stream_management.num_streams_per_rx],
    183                                  0)
    184         h_dt_undesired = split_dim(h_dt_undesired,
    185                                    [self._stream_management.num_rx, -1], 0)
    186 
    187         # Permutate dims to
    188         # [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers,..
    189         #  ..., num_rx_ant, num_streams_per_rx(num_Interfering_streams_per_rx)]
    190         perm = [2, 0, 4, 5, 3, 1]
    191         h_dt_desired = tf.transpose(h_dt_desired, perm)
    192         h_dt_desired = tf.cast(h_dt_desired, self._dtype)
    193         h_dt_undesired = tf.transpose(h_dt_undesired, perm)
    194 
    195         ##################################
    196         ### Prepare the noise variance ###
    197         ##################################
    198         # no is first broadcast to [batch_size, num_rx, num_rx_ant]
    199         # then the rank is expanded to that of y
    200         # then it is transposed like y to the final shape
    201         # [batch_size, num_rx, num_ofdm_symbols,...
    202         #  ..., num_effective_subcarriers, num_rx_ant]
    203         no_dt = expand_to_rank(no, 3, -1)
    204         no_dt = tf.broadcast_to(no_dt, tf.shape(y)[:3])
    205         no_dt = expand_to_rank(no_dt, tf.rank(y), -1)
    206         no_dt = tf.transpose(no_dt, [0,1,3,4,2])
    207         no_dt = tf.cast(no_dt, self._dtype)
    208 
    209         ##################################################
    210         ### Compute the interference covariance matrix ###
    211         ##################################################
    212         # Covariance of undesired transmitters
    213         s_inf = tf.matmul(h_dt_undesired, h_dt_undesired, adjoint_b=True)
    214 
    215         #Thermal noise
    216         s_no = tf.linalg.diag(no_dt)
    217 
    218         # Channel estimation errors
    219         # As we have only error variance information for each element,
    220         # we simply sum them across transmitters and build a
    221         # diagonal covariance matrix from this
    222         s_csi = tf.linalg.diag(tf.reduce_sum(err_var_dt, -1))
    223 
    224         # Final covariance matrix
    225         s = s_inf + s_no + s_csi
    226         s = tf.cast(s, self._dtype)
    227 
    228         return y_dt, h_dt_desired, s
    229 
    230     def _extract_datasymbols(self, z):
    231         """Extract data symbols for all detected TX"""
    232 
    233         # If output is symbols with hard decision, the rank is 5 and not 6 as
    234         # for other cases. The tensor rank is therefore expanded with one extra
    235         # dimension, which is removed later.
    236         rank_extanded = len(z.shape) < 6
    237         z = expand_to_rank(z, 6, -1)
    238 
    239         # Transpose tensor to shape
    240         # [num_rx, num_streams_per_rx, num_ofdm_symbols,
    241         #    num_effective_subcarriers, num_bits_per_symbol or num_points,
    242         #       batch_size]
    243         z = tf.transpose(z, [1, 4, 2, 3, 5, 0])
    244 
    245         # Merge num_rx amd num_streams_per_rx
    246         # [num_rx * num_streams_per_rx, num_ofdm_symbols,
    247         #    num_effective_subcarriers, num_bits_per_symbol or num_points,
    248         #   batch_size]
    249         z = flatten_dims(z, 2, 0)
    250 
    251         # Put first dimension into the right ordering
    252         stream_ind = self._stream_management.stream_ind
    253         z = tf.gather(z, stream_ind, axis=0)
    254 
    255         # Reshape first dimensions to [num_tx, num_streams] so that
    256         # we can compare to the way the streams were created.
    257         # [num_tx, num_streams, num_ofdm_symbols, num_effective_subcarriers,
    258         #     num_bits_per_symbol or num_points, batch_size]
    259         num_streams = self._stream_management.num_streams_per_tx
    260         num_tx = self._stream_management.num_tx
    261         z = split_dim(z, [num_tx, num_streams], 0)
    262 
    263         # Flatten resource grid dimensions
    264         # [num_tx, num_streams, num_ofdm_symbols*num_effective_subcarrier,
    265         #    num_bits_per_symbol or num_points, batch_size]
    266         z = flatten_dims(z, 2, 2)
    267 
    268         # Gather data symbols
    269         # [num_tx, num_streams, num_data_symbols,
    270         #    num_bits_per_symbol or num_points, batch_size]
    271         z = tf.gather(z, self._data_ind, batch_dims=2, axis=2)
    272 
    273         # Put batch_dim first
    274         # [batch_size, num_tx, num_streams,
    275         #     num_data_symbols, num_bits_per_symbol or num_points]
    276         z = tf.transpose(z, [4, 0, 1, 2, 3])
    277 
    278         # Reshape LLRs to
    279         # [batch_size, num_tx, num_streams,
    280         #     n = num_data_symbols*num_bits_per_symbol]
    281         # if output is LLRs on bits
    282         if self._output == 'bit':
    283             z = flatten_dims(z, 2, 3)
    284         # Remove dummy dimension if output is symbols with hard decision
    285         if rank_extanded:
    286             z = tf.squeeze(z, axis=-1)
    287 
    288         return z
    289 
    290     def call(self, inputs):
    291         y, h_hat, err_var, no = inputs
    292         # y has shape:
    293         # [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size]
    294 
    295         # h_hat has shape:
    296         # [batch_size, num_rx, num_rx_ant, num_tx, num_streams,...
    297         #  ..., num_ofdm_symbols, num_effective_subcarriers]
    298 
    299         # err_var has a shape that is broadcastable to h_hat
    300 
    301         # no has shape [batch_size, num_rx, num_rx_ant]
    302         # or just the first n dimensions of this
    303 
    304         ################################
    305         ### Pre-process the inputs
    306         ################################
    307         y_dt, h_dt_desired, s = self._preprocess_inputs(y, h_hat, err_var, no)
    308 
    309         #################################
    310         ### Detection
    311         #################################
    312         z = self._detector([y_dt, h_dt_desired, s])
    313 
    314         ##############################################
    315         ### Extract data symbols for all detected TX
    316         ##############################################
    317         z = self._extract_datasymbols(z)
    318 
    319         return z
    320 
    321 
class OFDMDetectorWithPrior(OFDMDetector):
    # pylint: disable=line-too-long
    r"""OFDMDetectorWithPrior(detector, output, resource_grid, stream_management, constellation_type, num_bits_per_symbol, constellation, dtype=tf.complex64, **kwargs)

    Layer that wraps a MIMO detector that assumes prior knowledge of the bits or
    constellation points is available, for use with the OFDM waveform.

    The parameter ``detector`` is a callable (e.g., a function) that
    implements a MIMO detection algorithm with prior for arbitrary batch
    dimensions.

    This class pre-processes the received resource grid ``y``, channel
    estimate ``h_hat``, and the prior information ``prior``, and computes for each receiver the
    noise-plus-interference covariance matrix according to the OFDM and stream
    configuration provided by the ``resource_grid`` and
    ``stream_management``, which also accounts for the channel
    estimation error variance ``err_var``. These quantities serve as input to the detection
    algorithm that is implemented by ``detector``.
    Both detection of symbols or bits with either soft- or hard-decisions are supported.

    Note
    -----
    The callable ``detector`` must take as input a tuple :math:`(\mathbf{y}, \mathbf{h}, \mathbf{prior}, \mathbf{s})` such that:

    * **y** ([...,num_rx_ant], tf.complex) -- 1+D tensor containing the received signals.
    * **h** ([...,num_rx_ant,num_streams_per_rx], tf.complex) -- 2+D tensor containing the channel matrices.
    * **prior** ([...,num_streams_per_rx,num_bits_per_symbol] or [...,num_streams_per_rx,num_points], tf.float) -- Prior for the transmitted signals. If ``output`` equals "bit", then LLRs for the transmitted bits are expected. If ``output`` equals "symbol", then logits for the transmitted constellation points are expected.
    * **s** ([...,num_rx_ant,num_rx_ant], tf.complex) -- 2+D tensor containing the noise-plus-interference covariance matrices.

    It must generate one of the following outputs depending on the value of ``output``:

    * **b_hat** ([..., num_streams_per_rx, num_bits_per_symbol], tf.float) -- LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`.
    * **x_hat** ([..., num_streams_per_rx, num_points], tf.float) or ([..., num_streams_per_rx], tf.int) -- Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`. Hard-decisions correspond to the symbol indices.

    Parameters
    ----------
    detector : Callable
        Callable object (e.g., a function) that implements a MIMO detection
        algorithm with prior for arbitrary batch dimensions. Either the existing detector
        :class:`~sionna.mimo.MaximumLikelihoodDetectorWithPrior` can be used, or a custom detector
        callable provided that has the same input/output specification.

    output : One of ["bit", "symbol"], str
        Type of output, either bits or symbols

    resource_grid : ResourceGrid
        Instance of :class:`~sionna.ofdm.ResourceGrid`

    stream_management : StreamManagement
        Instance of :class:`~sionna.mimo.StreamManagement`

    constellation_type : One of ["qam", "pam", "custom"], str
        For "custom", an instance of :class:`~sionna.mapping.Constellation`
        must be provided.

    num_bits_per_symbol : int
        Number of bits per constellation symbol, e.g., 4 for QAM16.
        Only required for ``constellation_type`` in ["qam", "pam"].

    constellation : Constellation
        Instance of :class:`~sionna.mapping.Constellation` or `None`.
        In the latter case, ``constellation_type``
        and ``num_bits_per_symbol`` must be provided.

    dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
        The dtype of `y`. Defaults to tf.complex64.
        The output dtype is the corresponding real dtype (tf.float32 or tf.float64).

    Input
    ------
    (y, h_hat, prior, err_var, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
        Received OFDM resource grid after cyclic prefix removal and FFT

    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates for all streams from all transmitters

    prior : [batch_size, num_tx, num_streams, num_data_symbols x num_bits_per_symbol] or [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float
        Prior of the transmitted signals.
        If ``output`` equals "bit", LLRs of the transmitted bits are expected.
        If ``output`` equals "symbol", logits of the transmitted constellation points are expected.

    err_var : [Broadcastable to shape of ``h_hat``], tf.float
        Variance of the channel estimation error

    no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
        Variance of the AWGN

    Output
    ------
    One of:

    : [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol], tf.float
        LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`.

    : [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float or [batch_size, num_tx, num_streams, num_data_symbols], tf.int
        Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`.
        Hard-decisions correspond to the symbol indices.
    """
    def __init__(self,
                 detector,
                 output,
                 resource_grid,
                 stream_management,
                 constellation_type=None,
                 num_bits_per_symbol=None,
                 constellation=None,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(detector=detector,
                         output=output,
                         resource_grid=resource_grid,
                         stream_management=stream_management,
                         dtype=dtype,
                         **kwargs)

        # Constellation object
        # (validates the constellation_type/num_bits_per_symbol/constellation
        # combination and returns a single Constellation instance)
        self._constellation = Constellation.create_or_check_constellation(
                                                        constellation_type,
                                                        num_bits_per_symbol,
                                                        constellation,
                                                        dtype=dtype)

        # Precompute indices to map priors to a resource grid
        rg_type = resource_grid.build_type_grid()
        # The nulled subcarriers (nulled DC and guard carriers) are removed to
        # get the correct indices of data-carrying resource elements.
        # Resource elements of type 0 carry data; their coordinates are used
        # later to scatter the prior onto a zero template of the grid.
        remove_nulled_sc = RemoveNulledSubcarriers(resource_grid)
        self._data_ind_scatter = tf.where(remove_nulled_sc(rg_type)==0)

    # Overrides OFDMDetector.call() to additionally prepare the prior
    # before invoking the wrapped detector
    def call(self, inputs):
        y, h_hat, prior, err_var, no = inputs
        # y has shape:
        # [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size]

        # h_hat has shape:
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams,...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]

        # prior has shape
        # [batch_size, num_tx, num_streams,...
        #   ... num_data_symbols x num_bits_per_symbol]
        # or [batch_size, num_tx, num_streams, num_data_symbols, num_points]

        # err_var has a shape that is broadcastable to h_hat

        # no has shape [batch_size, num_rx, num_rx_ant]
        # or just the first n dimensions of this

        ################################
        ### Pre-process the inputs
        ################################
        y_dt, h_dt_desired, s = self._preprocess_inputs(y, h_hat, err_var, no)

        #########################
        ### Prepare the prior ###
        #########################
        # For "bit" output the prior arrives flattened over
        # (data symbols x bits); split it back into
        # [batch_size, num_tx, num_streams_per_tx, num_data_symbols,
        #   ... num_bits_per_symbol/num_points]
        if self._output == 'bit':
            prior = split_dim(  prior,
                                [   self._resource_grid.num_data_symbols,
                                    self._constellation.num_bits_per_symbol],
                                3)
        # Create a zero template for the prior
        # [num_tx, num_streams_per_tx, num_ofdm_symbols,...
        #   ... num_effective_subcarriers, num_bits_per_symbol/num_points,
        #   ... batch_size]
        template = tf.zeros([   self._resource_grid.num_tx,
                                self._resource_grid.num_streams_per_tx,
                                self._resource_grid.num_ofdm_symbols,
                                self._resource_grid.num_effective_subcarriers,
                                tf.shape(prior)[-1],
                                tf.shape(prior)[0]],
                            tf.as_dtype(self._dtype).real_dtype)
        # [num_tx, num_streams_per_tx, num_data_symbols,
        #   ... num_bits_per_symbol/num_points, batch_size]
        prior = tf.transpose(prior, [1, 2, 3, 4, 0])
        # [num_tx, num_streams_per_tx, num_ofdm_symbols,...
        #   ... num_effective_subcarriers, num_bits_per_symbol/num_points,...
        #   ... batch_size]
        # Scatter the per-data-symbol prior onto the data-carrying resource
        # elements of the zero template (indices precomputed in __init__)
        prior = flatten_dims(prior, 3, 0)
        prior = tf.tensor_scatter_nd_update(template, self._data_ind_scatter,
                                                prior)
        # [batch_size, num_ofdm_symbols, num_effective_subcarriers,...
        #  num_tx*num_streams_per_tx, num_bits_per_symbol/num_points]
        prior = tf.transpose(prior, [5, 2, 3, 0, 1, 4])
        prior = flatten_dims(prior, 2, 3)
        # Add the receive antenna dimension for broadcasting
        # [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers,...
        #  num_tx*num_streams_per_tx, num_bits_per_symbol/num_points]
        prior = tf.tile(tf.expand_dims(prior, axis=1),
                        [1, tf.shape(y)[1], 1, 1, 1, 1])

        #################################
        ### Maximum-likelihood detection
        #################################
        z = self._detector([y_dt, h_dt_desired, prior, s])

        ##############################################
        ### Extract data symbols for all detected TX
        ##############################################
        z = self._extract_datasymbols(z)

        return z
    530 
    531 
    532 class MaximumLikelihoodDetector(OFDMDetector):
    533     # pylint: disable=line-too-long
    534     r"""MaximumLikelihoodDetector(output, demapping_method, resource_grid, stream_management, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, dtype=tf.complex64, **kwargs)
    535 
    536     Maximum-likelihood (ML) detection for OFDM MIMO transmissions.
    537 
    538     This layer implements maximum-likelihood (ML) detection
    539     for OFDM MIMO transmissions. Both ML detection of symbols or bits with either
    540     soft- or hard-decisions are supported. The OFDM and stream configuration are provided
    541     by a :class:`~sionna.ofdm.ResourceGrid` and
    542     :class:`~sionna.mimo.StreamManagement` instance, respectively. The
    543     actual detector is an instance of :class:`~sionna.mimo.MaximumLikelihoodDetector`.
    544 
    545     Parameters
    546     ----------
    547     output : One of ["bit", "symbol"], str
    548         Type of output, either bits or symbols. Whether soft- or
    549         hard-decisions are returned can be configured with the
    550         ``hard_out`` flag.
    551 
    552     demapping_method : One of ["app", "maxlog"], str
    553         Demapping method used
    554 
    555     resource_grid : ResourceGrid
    556         Instance of :class:`~sionna.ofdm.ResourceGrid`
    557 
    558     stream_management : StreamManagement
    559         Instance of :class:`~sionna.mimo.StreamManagement`
    560 
    561     constellation_type : One of ["qam", "pam", "custom"], str
    562         For "custom", an instance of :class:`~sionna.mapping.Constellation`
    563         must be provided.
    564 
    565     num_bits_per_symbol : int
    566         Number of bits per constellation symbol, e.g., 4 for QAM16.
    567         Only required for ``constellation_type`` in ["qam", "pam"].
    568 
    569     constellation : Constellation
    570         Instance of :class:`~sionna.mapping.Constellation` or `None`.
    571         In the latter case, ``constellation_type``
    572         and ``num_bits_per_symbol`` must be provided.
    573 
    574     hard_out : bool
    575         If `True`, the detector computes hard-decided bit values or
    576         constellation point indices instead of soft-values.
    577         Defaults to `False`.
    578 
    579     dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
    580         The dtype of `y`. Defaults to tf.complex64.
    581         The output dtype is the corresponding real dtype (tf.float32 or tf.float64).
    582 
    583     Input
    584     ------
    585     (y, h_hat, err_var, no) :
    586         Tuple:
    587 
    588     y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
    589         Received OFDM resource grid after cyclic prefix removal and FFT
    590 
    591     h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
    592         Channel estimates for all streams from all transmitters
    593 
    594     err_var : [Broadcastable to shape of ``h_hat``], tf.float
    595         Variance of the channel estimation error
    596 
    597     no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
    598         Variance of the AWGN noise
    599 
    600     Output
    601     ------
    602     One of:
    603 
    604     : [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol], tf.float
    605         LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`.
    606 
    607     : [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float or [batch_size, num_tx, num_streams, num_data_symbols], tf.int
    608         Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`.
    609         Hard-decisions correspond to the symbol indices.
    610 
    611     Note
    612     ----
    613     If you want to use this layer in Graph mode with XLA, i.e., within
    614     a function that is decorated with ``@tf.function(jit_compile=True)``,
    615     you must set ``sionna.Config.xla_compat=true``.
    616     See :py:attr:`~sionna.Config.xla_compat`.
    617     """
    618 
    def __init__(self,
                 output,
                 demapping_method,
                 resource_grid,
                 stream_management,
                 constellation_type=None,
                 num_bits_per_symbol=None,
                 constellation=None,
                 hard_out=False,
                 dtype=tf.complex64,
                 **kwargs):
        # Build the per-resource-element ML detector. The number of streams
        # to jointly detect is dictated by the stream management object.
        ml_detector = MaximumLikelihoodDetector_(
            output=output,
            demapping_method=demapping_method,
            num_streams=stream_management.num_streams_per_rx,
            constellation_type=constellation_type,
            num_bits_per_symbol=num_bits_per_symbol,
            constellation=constellation,
            hard_out=hard_out,
            dtype=dtype,
            **kwargs)

        # The OFDM wrapper takes care of resource-grid pre-processing and of
        # computing the noise-plus-interference covariance per receiver.
        super().__init__(detector=ml_detector,
                         output=output,
                         resource_grid=resource_grid,
                         stream_management=stream_management,
                         dtype=dtype,
                         **kwargs)
    648 
    649 
    650 class MaximumLikelihoodDetectorWithPrior(OFDMDetectorWithPrior):
    651     # pylint: disable=line-too-long
    652     r"""MaximumLikelihoodDetectorWithPrior(output, demapping_method, resource_grid, stream_management, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, dtype=tf.complex64, **kwargs)
    653 
    654     Maximum-likelihood (ML) detection for OFDM MIMO transmissions, assuming prior
    655     knowledge of the bits or constellation points is available.
    656 
    657     This layer implements maximum-likelihood (ML) detection
    658     for OFDM MIMO transmissions assuming prior knowledge on the transmitted data is available.
    659     Both ML detection of symbols or bits with either
    660     soft- or hard-decisions are supported. The OFDM and stream configuration are provided
    661     by a :class:`~sionna.ofdm.ResourceGrid` and
    662     :class:`~sionna.mimo.StreamManagement` instance, respectively. The
    663     actual detector is an instance of :class:`~sionna.mimo.MaximumLikelihoodDetectorWithPrior`.
    664 
    665     Parameters
    666     ----------
    667     output : One of ["bit", "symbol"], str
    668         Type of output, either bits or symbols. Whether soft- or
    669         hard-decisions are returned can be configured with the
    670         ``hard_out`` flag.
    671 
    672     demapping_method : One of ["app", "maxlog"], str
    673         Demapping method used
    674 
    675     resource_grid : ResourceGrid
    676         Instance of :class:`~sionna.ofdm.ResourceGrid`
    677 
    678     stream_management : StreamManagement
    679         Instance of :class:`~sionna.mimo.StreamManagement`
    680 
    681     constellation_type : One of ["qam", "pam", "custom"], str
    682         For "custom", an instance of :class:`~sionna.mapping.Constellation`
    683         must be provided.
    684 
    685     num_bits_per_symbol : int
    686         Number of bits per constellation symbol, e.g., 4 for QAM16.
    687         Only required for ``constellation_type`` in ["qam", "pam"].
    688 
    689     constellation : Constellation
    690         Instance of :class:`~sionna.mapping.Constellation` or `None`.
    691         In the latter case, ``constellation_type``
    692         and ``num_bits_per_symbol`` must be provided.
    693 
    694     hard_out : bool
    695         If `True`, the detector computes hard-decided bit values or
    696         constellation point indices instead of soft-values.
    697         Defaults to `False`.
    698 
    699     dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
    700         The dtype of `y`. Defaults to tf.complex64.
    701         The output dtype is the corresponding real dtype (tf.float32 or tf.float64).
    702 
    703     Input
    704     ------
    705     (y, h_hat, prior, err_var, no) :
    706         Tuple:
    707 
    708     y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
    709         Received OFDM resource grid after cyclic prefix removal and FFT
    710 
    711     h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
    712         Channel estimates for all streams from all transmitters
    713 
    714     prior : [batch_size, num_tx, num_streams, num_data_symbols x num_bits_per_symbol] or [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float
    715         Prior of the transmitted signals.
    716         If ``output`` equals "bit", LLRs of the transmitted bits are expected.
    717         If ``output`` equals "symbol", logits of the transmitted constellation points are expected.
    718 
    719     err_var : [Broadcastable to shape of ``h_hat``], tf.float
    720         Variance of the channel estimation error
    721 
    722     no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
    723         Variance of the AWGN noise
    724 
    725     Output
    726     ------
    727     One of:
    728 
    729     : [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol], tf.float
    730         LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`.
    731 
    732     : [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float or [batch_size, num_tx, num_streams, num_data_symbols], tf.int
    733         Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`.
    734         Hard-decisions correspond to the symbol indices.
    735 
    736     Note
    737     ----
    738     If you want to use this layer in Graph mode with XLA, i.e., within
    739     a function that is decorated with ``@tf.function(jit_compile=True)``,
    740     you must set ``sionna.Config.xla_compat=true``.
    741     See :py:attr:`~sionna.Config.xla_compat`.
    742     """
    743 
    744     def __init__(self,
    745                  output,
    746                  demapping_method,
    747                  resource_grid,
    748                  stream_management,
    749                  constellation_type=None,
    750                  num_bits_per_symbol=None,
    751                  constellation=None,
    752                  hard_out=False,
    753                  dtype=tf.complex64,
    754                  **kwargs):
    755 
    756         # Instantiate the maximum-likelihood detector
    757         detector = MaximumLikelihoodDetectorWithPrior_(output=output,
    758                             demapping_method=demapping_method,
    759                             num_streams = stream_management.num_streams_per_rx,
    760                             constellation_type=constellation_type,
    761                             num_bits_per_symbol=num_bits_per_symbol,
    762                             constellation=constellation,
    763                             hard_out=hard_out,
    764                             dtype=dtype,
    765                             **kwargs)
    766 
    767         super().__init__(detector=detector,
    768                          output=output,
    769                          resource_grid=resource_grid,
    770                          stream_management=stream_management,
    771                          constellation_type=constellation_type,
    772                          num_bits_per_symbol=num_bits_per_symbol,
    773                          constellation=constellation,
    774                          dtype=dtype,
    775                          **kwargs)
    776 
    777 
    778 class LinearDetector(OFDMDetector):
    779     # pylint: disable=line-too-long
    780     r"""LinearDetector(equalizer, output, demapping_method, resource_grid, stream_management, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, dtype=tf.complex64, **kwargs)
    781 
    782     This layer wraps a MIMO linear equalizer and a :class:`~sionna.mapping.Demapper`
    783     for use with the OFDM waveform.
    784 
    785     Both detection of symbols or bits with either
    786     soft- or hard-decisions are supported. The OFDM and stream configuration are provided
    787     by a :class:`~sionna.ofdm.ResourceGrid` and
    788     :class:`~sionna.mimo.StreamManagement` instance, respectively. The
    789     actual detector is an instance of :class:`~sionna.mimo.LinearDetector`.
    790 
    791     Parameters
    792     ----------
    793     equalizer : str, one of ["lmmse", "zf", "mf"], or an equalizer function
    794         Equalizer to be used. Either one of the existing equalizers, e.g.,
    795         :func:`~sionna.mimo.lmmse_equalizer`, :func:`~sionna.mimo.zf_equalizer`, or
    796         :func:`~sionna.mimo.mf_equalizer` can be used, or a custom equalizer
    797         function provided that has the same input/output specification.
    798 
    799     output : One of ["bit", "symbol"], str
    800         Type of output, either bits or symbols. Whether soft- or
    801         hard-decisions are returned can be configured with the
    802         ``hard_out`` flag.
    803 
    804     demapping_method : One of ["app", "maxlog"], str
    805         Demapping method used
    806 
    807     resource_grid : ResourceGrid
    808         Instance of :class:`~sionna.ofdm.ResourceGrid`
    809 
    810     stream_management : StreamManagement
    811         Instance of :class:`~sionna.mimo.StreamManagement`
    812 
    813     constellation_type : One of ["qam", "pam", "custom"], str
    814         For "custom", an instance of :class:`~sionna.mapping.Constellation`
    815         must be provided.
    816 
    817     num_bits_per_symbol : int
    818         Number of bits per constellation symbol, e.g., 4 for QAM16.
    819         Only required for ``constellation_type`` in ["qam", "pam"].
    820 
    821     constellation : Constellation
    822         Instance of :class:`~sionna.mapping.Constellation` or `None`.
    823         In the latter case, ``constellation_type``
    824         and ``num_bits_per_symbol`` must be provided.
    825 
    826     hard_out : bool
    827         If `True`, the detector computes hard-decided bit values or
    828         constellation point indices instead of soft-values.
    829         Defaults to `False`.
    830 
    831     dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
    832         The dtype of `y`. Defaults to tf.complex64.
    833         The output dtype is the corresponding real dtype (tf.float32 or tf.float64).
    834 
    835     Input
    836     ------
    837     (y, h_hat, err_var, no) :
    838         Tuple:
    839 
    840     y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
    841         Received OFDM resource grid after cyclic prefix removal and FFT
    842 
    843     h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
    844         Channel estimates for all streams from all transmitters
    845 
    846     err_var : [Broadcastable to shape of ``h_hat``], tf.float
    847         Variance of the channel estimation error
    848 
    849     no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
    850         Variance of the AWGN
    851 
    852     Output
    853     ------
    854     One of:
    855 
    856     : [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol], tf.float
    857         LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`.
    858 
    859     : [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float or [batch_size, num_tx, num_streams, num_data_symbols], tf.int
    860         Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`.
    861         Hard-decisions correspond to the symbol indices.
    862 
    863     Note
    864     ----
    865     If you want to use this layer in Graph mode with XLA, i.e., within
    866     a function that is decorated with ``@tf.function(jit_compile=True)``,
    867     you must set ``sionna.Config.xla_compat=true``.
    868     See :py:attr:`~sionna.Config.xla_compat`.
    869     """
    870 
    871     def __init__(self,
    872                  equalizer,
    873                  output,
    874                  demapping_method,
    875                  resource_grid,
    876                  stream_management,
    877                  constellation_type=None,
    878                  num_bits_per_symbol=None,
    879                  constellation=None,
    880                  hard_out=False,
    881                  dtype=tf.complex64,
    882                  **kwargs):
    883 
    884         # Instantiate the linear detector
    885         detector = LinearDetector_(equalizer=equalizer,
    886                                    output=output,
    887                                    demapping_method=demapping_method,
    888                                    constellation_type=constellation_type,
    889                                    num_bits_per_symbol=num_bits_per_symbol,
    890                                    constellation=constellation,
    891                                    hard_out=hard_out,
    892                                    dtype=dtype,
    893                                    **kwargs)
    894 
    895         super().__init__(detector=detector,
    896                          output=output,
    897                          resource_grid=resource_grid,
    898                          stream_management=stream_management,
    899                          dtype=dtype,
    900                          **kwargs)
    901 
    902 
    903 class KBestDetector(OFDMDetector):
    904     # pylint: disable=line-too-long
    905     r"""KBestDetector(output, num_streams, k, resource_grid, stream_management, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, use_real_rep=False, list2llr=None, dtype=tf.complex64, **kwargs)
    906 
    907     This layer wraps the MIMO K-Best detector for use with the OFDM waveform.
    908 
    909     Both detection of symbols or bits with either
    910     soft- or hard-decisions are supported. The OFDM and stream configuration are provided
    911     by a :class:`~sionna.ofdm.ResourceGrid` and
    912     :class:`~sionna.mimo.StreamManagement` instance, respectively. The
    913     actual detector is an instance of :class:`~sionna.mimo.KBestDetector`.
    914 
    915     Parameters
    916     ----------
    917     output : One of ["bit", "symbol"], str
    918         Type of output, either bits or symbols. Whether soft- or
    919         hard-decisions are returned can be configured with the
    920         ``hard_out`` flag.
    921 
    922     num_streams : tf.int
    923         Number of transmitted streams
    924 
    925     k : tf.int
    926         Number of paths to keep. Cannot be larger than the
    927         number of constellation points to the power of the number of
    928         streams.
    929 
    930     resource_grid : ResourceGrid
    931         Instance of :class:`~sionna.ofdm.ResourceGrid`
    932 
    933     stream_management : StreamManagement
    934         Instance of :class:`~sionna.mimo.StreamManagement`
    935 
    936     constellation_type : One of ["qam", "pam", "custom"], str
    937         For "custom", an instance of :class:`~sionna.mapping.Constellation`
    938         must be provided.
    939 
    940     num_bits_per_symbol : int
    941         Number of bits per constellation symbol, e.g., 4 for QAM16.
    942         Only required for ``constellation_type`` in ["qam", "pam"].
    943 
    944     constellation : Constellation
    945         Instance of :class:`~sionna.mapping.Constellation` or `None`.
    946         In the latter case, ``constellation_type``
    947         and ``num_bits_per_symbol`` must be provided.
    948 
    949     hard_out : bool
    950         If `True`, the detector computes hard-decided bit values or
    951         constellation point indices instead of soft-values.
    952         Defaults to `False`.
    953 
    954     use_real_rep : bool
    955         If `True`, the detector use the real-valued equivalent representation
    956         of the channel. Note that this only works with a QAM constellation.
    957         Defaults to `False`.
    958 
    959     list2llr: `None` or instance of :class:`~sionna.mimo.List2LLR`
    960         The function to be used to compute LLRs from a list of candidate solutions.
    961         If `None`, the default solution :class:`~sionna.mimo.List2LLRSimple`
    962         is used.
    963 
    964     dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
    965         The dtype of `y`. Defaults to tf.complex64.
    966         The output dtype is the corresponding real dtype (tf.float32 or tf.float64).
    967 
    968     Input
    969     ------
    970     (y, h_hat, err_var, no) :
    971         Tuple:
    972 
    973     y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
    974         Received OFDM resource grid after cyclic prefix removal and FFT
    975 
    976     h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
    977         Channel estimates for all streams from all transmitters
    978 
    979     err_var : [Broadcastable to shape of ``h_hat``], tf.float
    980         Variance of the channel estimation error
    981 
    982     no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
    983         Variance of the AWGN
    984 
    985     Output
    986     ------
    987     One of:
    988 
    989     : [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol], tf.float
    990         LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`.
    991 
    992     : [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float or [batch_size, num_tx, num_streams, num_data_symbols], tf.int
    993         Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`.
    994         Hard-decisions correspond to the symbol indices.
    995 
    996     Note
    997     ----
    998     If you want to use this layer in Graph mode with XLA, i.e., within
    999     a function that is decorated with ``@tf.function(jit_compile=True)``,
   1000     you must set ``sionna.Config.xla_compat=true``.
   1001     See :py:attr:`~sionna.Config.xla_compat`.
   1002     """
   1003 
   1004     def __init__(self,
   1005                  output,
   1006                  num_streams,
   1007                  k,
   1008                  resource_grid,
   1009                  stream_management,
   1010                  constellation_type=None,
   1011                  num_bits_per_symbol=None,
   1012                  constellation=None,
   1013                  hard_out=False,
   1014                  use_real_rep=False,
   1015                  list2llr="default",
   1016                  dtype=tf.complex64,
   1017                  **kwargs):
   1018 
   1019         # Instantiate the K-Best detector
   1020         detector = KBestDetector_(output=output,
   1021                                   num_streams=num_streams,
   1022                                   k=k,
   1023                                   constellation_type=constellation_type,
   1024                                   num_bits_per_symbol=num_bits_per_symbol,
   1025                                   constellation=constellation,
   1026                                   hard_out=hard_out,
   1027                                   use_real_rep=use_real_rep,
   1028                                   list2llr=list2llr,
   1029                                   dtype=dtype,
   1030                                   **kwargs)
   1031 
   1032         super().__init__(detector=detector,
   1033                          output=output,
   1034                          resource_grid=resource_grid,
   1035                          stream_management=stream_management,
   1036                          dtype=dtype,
   1037                          **kwargs)
   1038 
   1039 
   1040 class EPDetector(OFDMDetector):
   1041     # pylint: disable=line-too-long
   1042     r"""EPDetector(output, resource_grid, stream_management, num_bits_per_symbol, hard_out=False, l=10, beta=0.9, dtype=tf.complex64, **kwargs)
   1043 
   1044     This layer wraps the MIMO EP detector for use with the OFDM waveform.
   1045 
   1046     Both detection of symbols or bits with either
   1047     soft- or hard-decisions are supported. The OFDM and stream configuration are provided
   1048     by a :class:`~sionna.ofdm.ResourceGrid` and
   1049     :class:`~sionna.mimo.StreamManagement` instance, respectively. The
   1050     actual detector is an instance of :class:`~sionna.mimo.EPDetector`.
   1051 
   1052     Parameters
   1053     ----------
   1054     output : One of ["bit", "symbol"], str
   1055         Type of output, either bits or symbols. Whether soft- or
   1056         hard-decisions are returned can be configured with the
   1057         ``hard_out`` flag.
   1058 
   1059     resource_grid : ResourceGrid
   1060         Instance of :class:`~sionna.ofdm.ResourceGrid`
   1061 
   1062     stream_management : StreamManagement
   1063         Instance of :class:`~sionna.mimo.StreamManagement`
   1064 
   1065     num_bits_per_symbol : int
   1066         Number of bits per constellation symbol, e.g., 4 for QAM16.
   1067         Only required for ``constellation_type`` in ["qam", "pam"].
   1068 
   1069     hard_out : bool
   1070         If `True`, the detector computes hard-decided bit values or
   1071         constellation point indices instead of soft-values.
   1072         Defaults to `False`.
   1073 
   1074     l : int
   1075         Number of iterations. Defaults to 10.
   1076 
   1077     beta : float
   1078         Parameter :math:`\beta\in[0,1]` for update smoothing.
   1079         Defaults to 0.9.
   1080 
   1081     dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
   1082         Precision used for internal computations. Defaults to ``tf.complex64``.
   1083         Especially for large MIMO setups, the precision can make a significant
   1084         performance difference.
   1085 
   1086     Input
   1087     ------
   1088     (y, h_hat, err_var, no) :
   1089         Tuple:
   1090 
   1091     y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
   1092         Received OFDM resource grid after cyclic prefix removal and FFT
   1093 
   1094     h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
   1095         Channel estimates for all streams from all transmitters
   1096 
   1097     err_var : [Broadcastable to shape of ``h_hat``], tf.float
   1098         Variance of the channel estimation error
   1099 
   1100     no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
   1101         Variance of the AWGN
   1102 
   1103     Output
   1104     ------
   1105     One of:
   1106 
   1107     : [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol], tf.float
   1108         LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`.
   1109 
   1110     : [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float or [batch_size, num_tx, num_streams, num_data_symbols], tf.int
   1111         Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`.
   1112         Hard-decisions correspond to the symbol indices.
   1113 
   1114     Note
   1115     ----
   1116     For numerical stability, we do not recommend to use this function in Graph
   1117     mode with XLA, i.e., within a function that is decorated with
   1118     ``@tf.function(jit_compile=True)``.
   1119     However, it is possible to do so by setting
   1120     ``sionna.Config.xla_compat=true``.
   1121     See :py:attr:`~sionna.Config.xla_compat`.
   1122     """
   1123     def __init__(self,
   1124                  output,
   1125                  resource_grid,
   1126                  stream_management,
   1127                  num_bits_per_symbol=None,
   1128                  hard_out=False,
   1129                  l=10,
   1130                  beta=0.9,
   1131                  dtype=tf.complex64,
   1132                  **kwargs):
   1133 
   1134         # Instantiate the EP detector
   1135         detector = EPDetector_(output=output,
   1136                                num_bits_per_symbol=num_bits_per_symbol,
   1137                                hard_out=hard_out,
   1138                                l=l,
   1139                                beta=beta,
   1140                                dtype=dtype,
   1141                                **kwargs)
   1142 
   1143         super().__init__(detector=detector,
   1144                          output=output,
   1145                          resource_grid=resource_grid,
   1146                          stream_management=stream_management,
   1147                          dtype=dtype,
   1148                          **kwargs)
   1149 
   1150 class MMSEPICDetector(OFDMDetectorWithPrior):
   1151     # pylint: disable=line-too-long
   1152     r"""MMSEPICDetector(output, resource_grid, stream_management, demapping_method="maxlog", num_iter=1, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, dtype=tf.complex64, **kwargs)
   1153 
   1154     This layer wraps the MIMO MMSE PIC detector for use with the OFDM waveform.
   1155 
   1156     Both detection of symbols or bits with either
   1157     soft- or hard-decisions are supported. The OFDM and stream configuration are provided
   1158     by a :class:`~sionna.ofdm.ResourceGrid` and
   1159     :class:`~sionna.mimo.StreamManagement` instance, respectively. The
   1160     actual detector is an instance of :class:`~sionna.mimo.MMSEPICDetector`.
   1161 
   1162     Parameters
   1163     ----------
   1164     output : One of ["bit", "symbol"], str
   1165         Type of output, either bits or symbols. Whether soft- or
   1166         hard-decisions are returned can be configured with the
   1167         ``hard_out`` flag.
   1168 
   1169     resource_grid : ResourceGrid
   1170         Instance of :class:`~sionna.ofdm.ResourceGrid`
   1171 
   1172     stream_management : StreamManagement
   1173         Instance of :class:`~sionna.mimo.StreamManagement`
   1174 
   1175     demapping_method : One of ["app", "maxlog"], str
   1176         The demapping method used.
   1177         Defaults to "maxlog".
   1178 
   1179     num_iter : int
   1180         Number of MMSE PIC iterations.
   1181         Defaults to 1.
   1182 
   1183     constellation_type : One of ["qam", "pam", "custom"], str
   1184         For "custom", an instance of :class:`~sionna.mapping.Constellation`
   1185         must be provided.
   1186 
   1187     num_bits_per_symbol : int
   1188         The number of bits per constellation symbol, e.g., 4 for QAM16.
   1189         Only required for ``constellation_type`` in ["qam", "pam"].
   1190 
   1191     constellation : Constellation
   1192         An instance of :class:`~sionna.mapping.Constellation` or `None`.
   1193         In the latter case, ``constellation_type``
   1194         and ``num_bits_per_symbol`` must be provided.
   1195 
   1196     hard_out : bool
   1197         If `True`, the detector computes hard-decided bit values or
   1198         constellation point indices instead of soft-values.
   1199         Defaults to `False`.
   1200 
   1201     dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
   1202         Precision used for internal computations. Defaults to ``tf.complex64``.
   1203         Especially for large MIMO setups, the precision can make a significant
   1204         performance difference.
   1205 
   1206     Input
   1207     ------
   1208     (y, h_hat, prior, err_var, no) :
   1209         Tuple:
   1210 
   1211     y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
   1212         Received OFDM resource grid after cyclic prefix removal and FFT
   1213 
   1214     h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
   1215         Channel estimates for all streams from all transmitters
   1216 
   1217     prior : [batch_size, num_tx, num_streams, num_data_symbols x num_bits_per_symbol] or [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float
   1218         Prior of the transmitted signals.
   1219         If ``output`` equals "bit", LLRs of the transmitted bits are expected.
   1220         If ``output`` equals "symbol", logits of the transmitted constellation points are expected.
   1221 
   1222     err_var : [Broadcastable to shape of ``h_hat``], tf.float
   1223         Variance of the channel estimation error
   1224 
   1225     no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
   1226         Variance of the AWGN
   1227 
   1228     Output
   1229     ------
   1230     One of:
   1231 
   1232     : [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol], tf.float
   1233         LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`.
   1234 
   1235     : [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float or [batch_size, num_tx, num_streams, num_data_symbols], tf.int
   1236         Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`.
   1237         Hard-decisions correspond to the symbol indices.
   1238 
   1239     Note
   1240     ----
   1241     For numerical stability, we do not recommend using this function in Graph
   1242     mode with XLA, i.e., within a function that is decorated with
   1243     ``@tf.function(jit_compile=True)``.
   1244     However, it is possible to do so by setting
   1245     ``sionna.Config.xla_compat=true``.
   1246     See :py:attr:`~sionna.Config.xla_compat`.
   1247     """
   1248     def __init__(self,
   1249                  output,
   1250                  resource_grid,
   1251                  stream_management,
   1252                  demapping_method="maxlog",
   1253                  num_iter=1,
   1254                  constellation_type=None,
   1255                  num_bits_per_symbol=None,
   1256                  constellation=None,
   1257                  hard_out=False,
   1258                  dtype=tf.complex64,
   1259                  **kwargs):
   1260 
   1261         # Instantiate the EP detector
   1262         detector = MMSEPICDetector_(output=output,
   1263                                     demapping_method=demapping_method,
   1264                                     num_iter=num_iter,
   1265                                     constellation_type=constellation_type,
   1266                                     num_bits_per_symbol=num_bits_per_symbol,
   1267                                     constellation=constellation,
   1268                                     hard_out=hard_out,
   1269                                     dtype=dtype,
   1270                                     **kwargs)
   1271 
   1272         super().__init__(detector=detector,
   1273                          output=output,
   1274                          resource_grid=resource_grid,
   1275                          stream_management=stream_management,
   1276                          constellation_type=constellation_type,
   1277                          num_bits_per_symbol=num_bits_per_symbol,
   1278                          constellation=constellation,
   1279                          dtype=dtype,
   1280                          **kwargs)