anomaly-detection-material-parameters-calibration

Sionna material-parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

pusch_transmitter.py (9406B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """PUSCH Transmitter for the nr (5G) sub-package of the Sionna library.
      6 """
      7 
      8 import tensorflow as tf
      9 from tensorflow.keras.layers import Layer
     10 from sionna.mapping import Mapper
     11 from sionna.utils import BinarySource
     12 from sionna.ofdm import ResourceGrid, ResourceGridMapper, OFDMModulator
     13 from .config import Config
     14 from .pusch_config import PUSCHConfig, check_pusch_configs
     15 from .pusch_pilot_pattern import PUSCHPilotPattern
     16 from .pusch_precoder import PUSCHPrecoder
     17 from .tb_encoder import TBEncoder
     18 from .layer_mapping import LayerMapper
     19 
     20 class PUSCHTransmitter(Layer):
     21     # pylint: disable=line-too-long
     22     r"""PUSCHTransmitter(pusch_configs, return_bits=True, output_domain="freq", dtype=tf.complex64, verbose=False, **kwargs)
     23 
     24     This layer generates batches of 5G NR PUSCH slots for multiple transmitters
     25     with random or provided payloads. Frequency- or time-domain outputs can be generated.
     26 
     27     It combines multiple processing blocks into a single layer
     28     as shown in the following figure. Blocks with dashed lines are
     29     optional and depend on the configuration.
     30 
     31     .. figure:: ../figures/pusch_transmitter_block_diagram.png
     32         :scale: 30%
     33         :align: center
     34 
     35     Information bits :math:`\mathbf{b}` that are either randomly generated or
     36     provided as input are encoded into a transport block by the :class:`~sionna.nr.TBEncoder`.
     37     The encoded bits are then mapped to QAM constellation symbols by the :class:`~sionna.mapping.Mapper`.
     38     The :class:`~sionna.nr.LayerMapper` splits the modulated symbols into different layers
     39     which are then mapped onto OFDM resource grids by the :class:`~sionna.ofdm.ResourceGridMapper`.
     40     If precoding is enabled in the :class:`~sionna.nr.PUSCHConfig`, the resource grids
     41     are further precoded so that there is one for each transmitter and antenna port.
     42     If ``output_domain`` equals "freq", these are the outputs :math:`\mathbf{x}`.
     43     If ``output_domain`` is chosen to be "time", the resource grids are transformed into
     44     time-domain signals by the :class:`~sionna.ofdm.OFDMModulator`.
     45 
     46     Parameters
     47     ----------
     48     pusch_configs : instance or list of :class:`~sionna.nr.PUSCHConfig`
     49         PUSCH Configurations according to which the resource grid and pilot pattern
     50         will created. One configuration is needed for each transmitter.
     51 
     52     return_bits : bool
     53         If set to `True`, the layer generates random information bits
     54         to be transmitted and returns them together with the transmit signal.
     55         Defaults to `True`.
     56 
     57     output_domain : str, one of ["freq", "time"]
     58         The domain of the output. Defaults to "freq".
     59 
     60     dtype : One of [tf.complex64, tf.complex128]
     61         Dtype of inputs and outputs. Defaults to tf.complex64.
     62 
     63     verbose: bool
     64         If `True`, additional parameters are printed during initialization.
     65         Defaults to `False`.
     66 
     67     Input
     68     -----
     69     One of:
     70 
     71     batch_size : int
     72         Batch size of random transmit signals to be generated,
     73         if ``return_bits`` is `True`.
     74 
     75     b : [batch_size, num_tx, tb_size], tf.float
     76         Information bits to be transmitted,
     77         if ``return_bits`` is `False`.
     78 
     79     Output
     80     ------
     81     x : [batch size, num_tx, num_tx_ant, num_ofdm_symbols, fft_size], tf.complex or [batch size, num_tx, num_tx_ant, num_time_samples], tf.complex
     82         Transmit signal in either frequency or time domain, depending on ``output_domain``.
     83 
     84     b : [batch_size, num_tx, tb_size], tf.float
     85         Transmitted information bits.
     86         Only returned if ``return_bits`` is `True`.
     87 
     88     Example
     89     -------
     90     >>> pusch_config = PUSCHConfig()
     91     >>> pusch_transmitter = PUSCHTransmitter(pusch_config)
     92     >>> x, b = pusch_transmitter(16)
     93     >>> print("Shape of x:", x.shape)
     94     Shape of x: (16, 1, 1, 14, 48)
     95     >>> print("Shape of b:", b.shape)
     96     Shape of b: (16, 1, 1352)
     97 
     98     """
     99     def __init__(self,
    100                  pusch_configs,
    101                  return_bits=True,
    102                  output_domain="freq",
    103                  dtype=tf.complex64,
    104                  verbose=False,
    105                  **kwargs):
    106 
    107         assert dtype in [tf.complex64, tf.complex128], \
    108             "dtype must be tf.complex64 or tf.complex128"
    109         super().__init__(dtype=dtype, **kwargs)
    110 
    111         # Validate inputs and extract parameters
    112         assert isinstance(return_bits, bool), "return_bits must be bool"
    113         self._return_bits = return_bits
    114 
    115         assert output_domain in ["time", "freq"], \
    116             "output_domain must be 'time' or 'freq'"
    117         self._output_domain = output_domain
    118 
    119         assert isinstance(verbose, bool), "verbose must be bool"
    120         self._verbose = verbose
    121 
    122         if isinstance(pusch_configs, PUSCHConfig):
    123             pusch_configs = [pusch_configs]
    124 
    125         params = check_pusch_configs(pusch_configs)
    126         for key, value in params.items():
    127             self.__setattr__(f"_{key}", value)
    128 
    129         self._pusch_configs = pusch_configs
    130 
    131         # (Optionally) Create BinarySource
    132         if self._return_bits:
    133             self._binary_source = BinarySource(dtype=dtype.real_dtype)
    134 
    135         # Create TBEncoder
    136         self._tb_encoder = TBEncoder(
    137                             target_tb_size=self._tb_size,
    138                             num_coded_bits=self._num_coded_bits,
    139                             target_coderate=self._target_coderate,
    140                             num_bits_per_symbol=self._num_bits_per_symbol,
    141                             num_layers=self._num_layers,
    142                             n_rnti=self._n_rnti,
    143                             n_id=self._n_id,
    144                             channel_type="PUSCH", # PUSCHTransmitter
    145                             codeword_index=0, # not supported for PUSCH
    146                             use_scrambler=True,
    147                             verbose=self._verbose,
    148                             output_dtype=dtype.real_dtype)
    149 
    150         # Create PUSCHLayerMapper
    151         self._layer_mapper = LayerMapper(
    152                                 num_layers=self._num_layers,
    153                                 dtype=dtype)
    154 
    155         # Create Mapper
    156         self._mapper = Mapper("qam",
    157                               self._num_bits_per_symbol,
    158                               dtype=dtype)
    159 
    160         # Create PUSCHPilotPattern
    161         self._pilot_pattern = PUSCHPilotPattern(self._pusch_configs,
    162                                                 dtype=dtype)
    163 
    164         # Create ResourceGrid
    165         self._resource_grid = ResourceGrid(
    166                             num_ofdm_symbols=self._num_ofdm_symbols,
    167                             fft_size=self._num_subcarriers,
    168                             subcarrier_spacing=self._subcarrier_spacing,
    169                             num_tx=self._num_tx,
    170                             num_streams_per_tx=self._num_layers,
    171                             cyclic_prefix_length=self._cyclic_prefix_length,
    172                             pilot_pattern=self._pilot_pattern,
    173                             dtype=dtype)
    174 
    175         # Create ResourceGridMapper
    176         self._resource_grid_mapper = ResourceGridMapper(self._resource_grid,
    177                                                         dtype=dtype)
    178 
    179         # (Optionally) Create PUSCHPrecoder
    180         if self._precoding=="codebook":
    181             self._precoder = PUSCHPrecoder(self._precoding_matrices,
    182                                            dtype=dtype)
    183 
    184         # (Optionally) Create OFDMModulator
    185         if self._output_domain=="time":
    186             self._ofdm_modulator = OFDMModulator(self._cyclic_prefix_length)
    187 
    188     #########################################
    189     # Public methods and properties
    190     #########################################
    191 
    192     @property
    193     def resource_grid(self):
    194         """OFDM resource grid underlying the PUSCH transmissions"""
    195         return self._resource_grid
    196 
    197     @property
    198     def pilot_pattern(self):
    199         """Aggregate pilot pattern of all transmitters"""
    200         return self._pilot_pattern
    201 
    202     def show(self):
    203         """Print all properties of the PUSCHConfig and children"""
    204         # CarrierConfig is always the same
    205         self._pusch_configs[0].carrier.show()
    206         Config.show(self._pusch_configs[0])
    207         for idx,p in enumerate(self._pusch_configs):
    208             print(f"---- UE {idx} ----")
    209             p.dmrs.show()
    210             p.tb.show()
    211 
    212     def call(self, inputs):
    213 
    214         if self._return_bits:
    215             # inputs defines batch_size
    216             batch_size = inputs
    217             b = self._binary_source([batch_size, self._num_tx, self._tb_size])
    218         else:
    219             b = inputs
    220 
    221         # Encode transport block
    222         c = self._tb_encoder(b)
    223 
    224         # Map to constellations
    225         x_map = self._mapper(c)
    226 
    227         # Map to layers
    228         x_layer = self._layer_mapper(x_map)
    229 
    230         # Apply resource grid mapping
    231         x_grid = self._resource_grid_mapper(x_layer)
    232 
    233         # (Optionally) apply PUSCH precoding
    234         if self._precoding=="codebook":
    235             x_pre = self._precoder(x_grid)
    236         else:
    237             x_pre = x_grid
    238 
    239         # (Optionally) apply OFDM modulation
    240         if self._output_domain=="time":
    241             x = self._ofdm_modulator(x_pre)
    242         else:
    243             x = x_pre
    244 
    245         if self._return_bits:
    246             return x, b
    247         else:
    248             return x
    249 
    250