anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

pusch_receiver.py (11471B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """PUSCH Receiver for the nr (5G) sub-package of the Sionna library.
      6 """
      7 import numpy as np
      8 import tensorflow as tf
      9 from tensorflow.keras.layers import Layer
     10 import sionna
     11 from sionna.mimo import StreamManagement
     12 from sionna.ofdm import OFDMDemodulator, LinearDetector
     13 from sionna.utils import insert_dims
     14 from sionna.channel import time_to_ofdm_channel
     15 
     16 class PUSCHReceiver(Layer):
     17     # pylint: disable=line-too-long
     18     r"""PUSCHReceiver(pusch_transmitter, channel_estimator=None, mimo_detector=None, tb_decoder=None, return_tb_crc_status=False, stream_management=None, input_domain="freq", l_min=None, dtype=tf.complex64, **kwargs)
     19 
     20     This layer implements a full receiver for batches of 5G NR PUSCH slots sent
     21     by multiple transmitters. Inputs can be in the time or frequency domain.
     22     Perfect channel state information can be optionally provided.
     23     Different channel estimatiors, MIMO detectors, and transport decoders
     24     can be configured.
     25 
     26     The layer combines multiple processing blocks into a single layer
     27     as shown in the following figure. Blocks with dashed lines are
     28     optional and depend on the configuration.
     29 
     30     .. figure:: ../figures/pusch_receiver_block_diagram.png
     31         :scale: 30%
     32         :align: center
     33 
     34     If the ``input_domain`` equals "time", the inputs :math:`\mathbf{y}` are first
     35     transformed to resource grids with the :class:`~sionna.ofdm.OFDMDemodulator`.
     36     Then channel estimation is performed, e.g., with the help of the
     37     :class:`~sionna.nr.PUSCHLSChannelEstimator`. If ``channel_estimator``
     38     is chosen to be "perfect", this step is skipped and the input :math:`\mathbf{h}`
     39     is used instead.
     40     Next, MIMO detection is carried out with an arbitrary :class:`~sionna.ofdm.OFDMDetector`.
     41     The resulting LLRs for each layer are then combined to transport blocks
     42     with the help of the :class:`~sionna.nr.LayerDemapper`.
     43     Finally, the transport blocks are decoded with the :class:`~sionna.nr.TBDecoder`.
     44 
     45     Parameters
     46     ----------
     47     pusch_transmitter : :class:`~sionna.nr.PUSCHTransmitter`
     48         Transmitter used for the generation of the transmit signals
     49 
     50     channel_estimator : :class:`~sionna.ofdm.BaseChannelEstimator`, "perfect", or `None`
     51         Channel estimator to be used.
     52         If `None`, the :class:`~sionna.nr.PUSCHLSChannelEstimator` with
     53         linear interpolation is used.
     54         If "perfect", no channel estimation is performed and the channel state information
     55         ``h`` must be provided as additional input.
     56         Defaults to `None`.
     57 
     58     mimo_detector : :class:`~sionna.ofdm.OFDMDetector` or `None`
     59         MIMO Detector to be used.
     60         If `None`, the :class:`~sionna.ofdm.LinearDetector` with
     61         LMMSE detection is used.
     62         Defaults to `None`.
     63 
     64     tb_decoder : :class:`~sionna.nr.TBDecoder` or `None`
     65         Transport block decoder to be used.
     66         If `None`, the :class:`~sionna.nr.TBDecoder` with its
     67         default settings is used.
     68         Defaults to `None`.
     69 
     70     return_tb_crc_status : bool
     71         If `True`, the status of the transport block CRC is returned
     72         as additional output.
     73         Defaults to `False`.
     74 
     75     stream_management : :class:`~sionna.mimo.StreamManagement` or `None`
     76         Stream management configuration to be used.
     77         If `None`, it is assumed that there is a single receiver
     78         which decodes all streams of all transmitters.
     79         Defaults to `None`.
     80 
     81     input_domain : str, one of ["freq", "time"]
     82         Domain of the input signal.
     83         Defaults to "freq".
     84 
     85     l_min : int or `None`
     86         Smallest time-lag for the discrete complex baseband channel.
     87         Only needed if ``input_domain`` equals "time".
     88         Defaults to `None`.
     89 
     90     dtype : tf.Dtype
     91         Datatype for internal calculations and the output dtype.
     92         Defaults to `tf.complex64`.
     93 
     94     Input
     95     -----
     96     (y, h, no) :
     97         Tuple:
     98 
     99     y : [batch size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex or [batch size, num_rx, num_rx_ant, num_time_samples + l_max - l_min], tf.complex
    100         Frequency- or time-domain input signal
    101 
    102     h : [batch size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_ofdm_symbols, num_subcarriers], tf.complex or [batch size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_time_samples + l_max - l_min, l_max - l_min + 1], tf.complex
    103         Perfect channel state information in either frequency or time domain
    104         (depending on ``input_domain``) to be used for detection.
    105         Only required if ``channel_estimator`` equals "perfect".
    106 
    107     no : [batch_size, num_rx, num_rx_ant] or only the first n>=0 dims, tf.float
    108         Variance of the AWGN
    109 
    110     Output
    111     ------
    112     b_hat : [batch_size, num_tx, tb_size], tf.float
    113         Decoded information bits
    114 
    115     tb_crc_status : [batch_size, num_tx], tf.bool
    116         Transport block CRC status
    117 
    118     Example
    119     -------
    120     >>> pusch_config = PUSCHConfig()
    121     >>> pusch_transmitter = PUSCHTransmitter(pusch_config)
    122     >>> pusch_receiver = PUSCHReceiver(pusch_transmitter)
    123     >>> channel = AWGN()
    124     >>> x, b = pusch_transmitter(16)
    125     >>> no = 0.1
    126     >>> y = channel([x, no])
    127     >>> b_hat = pusch_receiver([x, no])
    128     >>> compute_ber(b, b_hat)
    129     <tf.Tensor: shape=(), dtype=float64, numpy=0.0>
    130     """
    131     def __init__(self,
    132                  pusch_transmitter,
    133                  channel_estimator=None,
    134                  mimo_detector=None,
    135                  tb_decoder=None,
    136                  return_tb_crc_status=False,
    137                  stream_management=None,
    138                  input_domain="freq",
    139                  l_min=None,
    140                  dtype=tf.complex64,
    141                  **kwargs):
    142         assert dtype in [tf.complex64, tf.complex128], \
    143             "dtype must be tf.complex64 or tf.complex128"
    144         super().__init__(dtype=dtype, **kwargs)
    145 
    146         assert input_domain in ["time", "freq"], \
    147             "input_domain must be 'time' or 'freq'"
    148         self._input_domain = input_domain
    149 
    150         self._return_tb_crc_status = return_tb_crc_status
    151 
    152         self._resource_grid = pusch_transmitter.resource_grid
    153 
    154         # (Optionally) Create OFDMDemodulator
    155         if self._input_domain=="time":
    156             assert l_min is not None, \
    157                 "l_min must be provided for input_domain==time"
    158             self._l_min = l_min
    159             self._ofdm_demodulator = OFDMDemodulator(
    160                 fft_size=pusch_transmitter._num_subcarriers,
    161                 l_min=self._l_min,
    162                 cyclic_prefix_length=pusch_transmitter._cyclic_prefix_length)
    163 
    164         # Use or create default ChannelEstimator
    165         self._perfect_csi = False
    166         self._w = None
    167         if channel_estimator is None:
    168             # Default channel estimator
    169             self._channel_estimator = sionna.nr.PUSCHLSChannelEstimator(
    170                                 self.resource_grid,
    171                                 pusch_transmitter._dmrs_length,
    172                                 pusch_transmitter._dmrs_additional_position,
    173                                 pusch_transmitter._num_cdm_groups_without_data,
    174                                 interpolation_type='lin',
    175                                 dtype=dtype)
    176         elif channel_estimator=="perfect":
    177             # Perfect channel estimation
    178             self._perfect_csi = True
    179             if pusch_transmitter._precoding=="codebook":
    180                 self._w = pusch_transmitter._precoder._w
    181                 self._w = insert_dims(self._w, 2, 1)
    182         else:
    183             # User-provided channel estimator
    184             self._channel_estimator = channel_estimator
    185 
    186         # Use or create default StreamManagement
    187         if stream_management is None:
    188             # Default StreamManagement
    189             rx_tx_association = np.ones([1, pusch_transmitter._num_tx], bool)
    190             self._stream_management = StreamManagement(
    191                                         rx_tx_association,
    192                                         pusch_transmitter._num_layers)
    193         else:
    194             # User-provided StramManagement
    195             self._stream_management = stream_management
    196 
    197         # Use or create default MIMODetector
    198         if mimo_detector is None:
    199             # Default MIMO detector
    200             self._mimo_detector = LinearDetector("lmmse", "bit", "maxlog",
    201                                         pusch_transmitter.resource_grid,
    202                                         self._stream_management,
    203                                         "qam",
    204                                         pusch_transmitter._num_bits_per_symbol,
    205                                         dtype=dtype)
    206         else:
    207             # User-provided MIMO detector
    208             self._mimo_detector = mimo_detector
    209 
    210         # Create LayerDemapper
    211         self._layer_demapper = sionna.nr.LayerDemapper(
    212                     pusch_transmitter._layer_mapper,
    213                     num_bits_per_symbol=pusch_transmitter._num_bits_per_symbol)
    214 
    215         # Use or create default TBDecoder
    216         if tb_decoder is None:
    217             # Default TBEncoder
    218             self._tb_decoder = sionna.nr.TBDecoder(
    219                                     pusch_transmitter._tb_encoder,
    220                                     output_dtype=dtype.real_dtype)
    221         else:
    222             # User-provided TBEncoder
    223             self._tb_decoder = tb_decoder
    224 
    225     #########################################
    226     # Public methods and properties
    227     #########################################
    228 
    229     @property
    230     def resource_grid(self):
    231         """OFDM resource grid underlying the PUSCH transmissions"""
    232         return self._resource_grid
    233 
    234     def call(self, inputs):
    235         if self._perfect_csi:
    236             y, h, no = inputs
    237         else:
    238             y, no = inputs
    239 
    240         # (Optional) OFDM Demodulation
    241         if self._input_domain=="time":
    242             y = self._ofdm_demodulator(y)
    243 
    244         # Channel estimation
    245         if self._perfect_csi:
    246 
    247             # Transform time-domain to frequency-domain channel
    248             if self._input_domain=="time":
    249                 h = time_to_ofdm_channel(h, self.resource_grid, self._l_min)
    250 
    251 
    252             if self._w is not None:
    253                 # Reshape h to put channel matrix dimensions last
    254                 # [batch size, num_rx, num_tx, num_ofdm_symbols,...
    255                 #  ...fft_size, num_rx_ant, num_tx_ant]
    256                 h = tf.transpose(h, perm=[0,1,3,5,6,2,4])
    257 
    258                 # Multiply by precoding matrices to compute effective channels
    259                 # [batch size, num_rx, num_tx, num_ofdm_symbols,...
    260                 #  ...fft_size, num_rx_ant, num_streams]
    261                 h = tf.matmul(h, self._w)
    262 
    263                 # Reshape
    264                 # [batch size, num_rx, num_rx_ant, num_tx, num_streams,...
    265                 #  ...num_ofdm_symbols, fft_size]
    266                 h = tf.transpose(h, perm=[0,1,5,2,6,3,4])
    267             h_hat = h
    268             err_var = tf.cast(0, dtype=h_hat.dtype.real_dtype)
    269         else:
    270             h_hat,err_var = self._channel_estimator([y, no])
    271 
    272         # MIMO Detection
    273         llr = self._mimo_detector([y, h_hat, err_var, no])
    274 
    275         # Layer demapping
    276         llr = self._layer_demapper(llr)
    277 
    278         # TB Decoding
    279         b_hat, tb_crc_status = self._tb_decoder(llr)
    280 
    281         if self._return_tb_crc_status:
    282             return b_hat, tb_crc_status
    283         else:
    284             return b_hat