pusch_transmitter.py (9406B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""PUSCH Transmitter for the nr (5G) sub-package of the Sionna library.
"""

import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna.mapping import Mapper
from sionna.utils import BinarySource
from sionna.ofdm import ResourceGrid, ResourceGridMapper, OFDMModulator
from .config import Config
from .pusch_config import PUSCHConfig, check_pusch_configs
from .pusch_pilot_pattern import PUSCHPilotPattern
from .pusch_precoder import PUSCHPrecoder
from .tb_encoder import TBEncoder
from .layer_mapping import LayerMapper


class PUSCHTransmitter(Layer):
    # pylint: disable=line-too-long
    r"""PUSCHTransmitter(pusch_configs, return_bits=True, output_domain="freq", dtype=tf.complex64, verbose=False, **kwargs)

    This layer generates batches of 5G NR PUSCH slots for multiple transmitters
    with random or provided payloads. Frequency- or time-domain outputs can be generated.

    It combines multiple processing blocks into a single layer
    as shown in the following figure. Blocks with dashed lines are
    optional and depend on the configuration.

    .. figure:: ../figures/pusch_transmitter_block_diagram.png
        :scale: 30%
        :align: center

    Information bits :math:`\mathbf{b}` that are either randomly generated or
    provided as input are encoded into a transport block by the :class:`~sionna.nr.TBEncoder`.
    The encoded bits are then mapped to QAM constellation symbols by the :class:`~sionna.mapping.Mapper`.
    The :class:`~sionna.nr.LayerMapper` splits the modulated symbols into different layers
    which are then mapped onto OFDM resource grids by the :class:`~sionna.ofdm.ResourceGridMapper`.
    If precoding is enabled in the :class:`~sionna.nr.PUSCHConfig`, the resource grids
    are further precoded so that there is one for each transmitter and antenna port.
    If ``output_domain`` equals "freq", these are the outputs :math:`\mathbf{x}`.
    If ``output_domain`` is chosen to be "time", the resource grids are transformed into
    time-domain signals by the :class:`~sionna.ofdm.OFDMModulator`.

    Parameters
    ----------
    pusch_configs : instance or list of :class:`~sionna.nr.PUSCHConfig`
        PUSCH Configurations according to which the resource grid and pilot pattern
        will be created. One configuration is needed for each transmitter.

    return_bits : bool
        If set to `True`, the layer generates random information bits
        to be transmitted and returns them together with the transmit signal.
        Defaults to `True`.

    output_domain : str, one of ["freq", "time"]
        The domain of the output. Defaults to "freq".

    dtype : One of [tf.complex64, tf.complex128]
        Dtype of inputs and outputs. Defaults to tf.complex64.

    verbose: bool
        If `True`, additional parameters are printed during initialization.
        Defaults to `False`.

    Input
    -----
    One of:

    batch_size : int
        Batch size of random transmit signals to be generated,
        if ``return_bits`` is `True`.

    b : [batch_size, num_tx, tb_size], tf.float
        Information bits to be transmitted,
        if ``return_bits`` is `False`.

    Output
    ------
    x : [batch size, num_tx, num_tx_ant, num_ofdm_symbols, fft_size], tf.complex or [batch size, num_tx, num_tx_ant, num_time_samples], tf.complex
        Transmit signal in either frequency or time domain, depending on ``output_domain``.

    b : [batch_size, num_tx, tb_size], tf.float
        Transmitted information bits.
        Only returned if ``return_bits`` is `True`.

    Example
    -------
    >>> pusch_config = PUSCHConfig()
    >>> pusch_transmitter = PUSCHTransmitter(pusch_config)
    >>> x, b = pusch_transmitter(16)
    >>> print("Shape of x:", x.shape)
    Shape of x: (16, 1, 1, 14, 48)
    >>> print("Shape of b:", b.shape)
    Shape of b: (16, 1, 1352)

    """
    def __init__(self,
                 pusch_configs,
                 return_bits=True,
                 output_domain="freq",
                 dtype=tf.complex64,
                 verbose=False,
                 **kwargs):

        assert dtype in [tf.complex64, tf.complex128], \
            "dtype must be tf.complex64 or tf.complex128"
        super().__init__(dtype=dtype, **kwargs)

        # Validate inputs and extract parameters
        assert isinstance(return_bits, bool), "return_bits must be bool"
        self._return_bits = return_bits

        assert output_domain in ["time", "freq"], \
            "output_domain must be 'time' or 'freq'"
        self._output_domain = output_domain

        assert isinstance(verbose, bool), "verbose must be bool"
        self._verbose = verbose

        # Accept a single config for convenience; downstream code expects a list
        if isinstance(pusch_configs, PUSCHConfig):
            pusch_configs = [pusch_configs]

        # Expose every parameter returned by check_pusch_configs as a private
        # attribute (e.g. self._tb_size, self._num_layers, self._precoding).
        params = check_pusch_configs(pusch_configs)
        for key, value in params.items():
            setattr(self, f"_{key}", value)

        self._pusch_configs = pusch_configs

        # (Optionally) Create BinarySource. Bits use the real counterpart of
        # the complex dtype (e.g. tf.float32 for tf.complex64).
        if self._return_bits:
            self._binary_source = BinarySource(dtype=dtype.real_dtype)

        # Create TBEncoder
        self._tb_encoder = TBEncoder(
                                target_tb_size=self._tb_size,
                                num_coded_bits=self._num_coded_bits,
                                target_coderate=self._target_coderate,
                                num_bits_per_symbol=self._num_bits_per_symbol,
                                num_layers=self._num_layers,
                                n_rnti=self._n_rnti,
                                n_id=self._n_id,
                                channel_type="PUSCH", # PUSCHTransmitter
                                codeword_index=0, # not supported for PUSCH
                                use_scrambler=True,
                                verbose=self._verbose,
                                output_dtype=dtype.real_dtype)

        # Create PUSCHLayerMapper
        self._layer_mapper = LayerMapper(
                                num_layers=self._num_layers,
                                dtype=dtype)

        # Create Mapper
        self._mapper = Mapper("qam",
                              self._num_bits_per_symbol,
                              dtype=dtype)

        # Create PUSCHPilotPattern
        self._pilot_pattern = PUSCHPilotPattern(self._pusch_configs,
                                                dtype=dtype)

        # Create ResourceGrid
        self._resource_grid = ResourceGrid(
                            num_ofdm_symbols=self._num_ofdm_symbols,
                            fft_size=self._num_subcarriers,
                            subcarrier_spacing=self._subcarrier_spacing,
                            num_tx=self._num_tx,
                            num_streams_per_tx=self._num_layers,
                            cyclic_prefix_length=self._cyclic_prefix_length,
                            pilot_pattern=self._pilot_pattern,
                            dtype=dtype)

        # Create ResourceGridMapper
        self._resource_grid_mapper = ResourceGridMapper(self._resource_grid,
                                                        dtype=dtype)

        # (Optionally) Create PUSCHPrecoder. Only built for codebook-based
        # precoding; call() applies it under the same condition.
        if self._precoding == "codebook":
            self._precoder = PUSCHPrecoder(self._precoding_matrices,
                                           dtype=dtype)

        # (Optionally) Create OFDMModulator. Only needed for time-domain output.
        if self._output_domain == "time":
            self._ofdm_modulator = OFDMModulator(self._cyclic_prefix_length)

    #########################################
    # Public methods and properties
    #########################################

    @property
    def resource_grid(self):
        """OFDM resource grid underlying the PUSCH transmissions"""
        return self._resource_grid

    @property
    def pilot_pattern(self):
        """Aggregate pilot pattern of all transmitters"""
        return self._pilot_pattern

    def show(self):
        """Print all properties of the PUSCHConfig and children"""
        # CarrierConfig is always the same, so it is printed only once,
        # taken from the first configuration.
        self._pusch_configs[0].carrier.show()
        Config.show(self._pusch_configs[0])
        for idx, p in enumerate(self._pusch_configs):
            print(f"---- UE {idx} ----")
            p.dmrs.show()
            p.tb.show()

    def call(self, inputs):
        """Run the PUSCH transmit chain.

        Parameters
        ----------
        inputs : int or tensor
            The batch size if ``return_bits`` is `True`. Otherwise, the
            information bits of shape [batch_size, num_tx, tb_size].

        Returns
        -------
        ``(x, b)`` if ``return_bits`` is `True`, otherwise ``x``.
        """
        if self._return_bits:
            # inputs defines batch_size
            batch_size = inputs
            b = self._binary_source([batch_size, self._num_tx, self._tb_size])
        else:
            b = inputs

        # Encode transport block
        c = self._tb_encoder(b)

        # Map to constellations
        x_map = self._mapper(c)

        # Map to layers
        x_layer = self._layer_mapper(x_map)

        # Apply resource grid mapping
        x_grid = self._resource_grid_mapper(x_layer)

        # (Optionally) apply PUSCH precoding
        if self._precoding == "codebook":
            x_pre = self._precoder(x_grid)
        else:
            x_pre = x_grid

        # (Optionally) apply OFDM modulation
        if self._output_domain == "time":
            x = self._ofdm_modulator(x_pre)
        else:
            x = x_pre

        if self._return_bits:
            return x, b
        return x