pusch_receiver.py (11471B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""PUSCH Receiver for the nr (5G) sub-package of the Sionna library.
"""
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
import sionna
from sionna.mimo import StreamManagement
from sionna.ofdm import OFDMDemodulator, LinearDetector
from sionna.utils import insert_dims
from sionna.channel import time_to_ofdm_channel


class PUSCHReceiver(Layer):
    # pylint: disable=line-too-long
    r"""PUSCHReceiver(pusch_transmitter, channel_estimator=None, mimo_detector=None, tb_decoder=None, return_tb_crc_status=False, stream_management=None, input_domain="freq", l_min=None, dtype=tf.complex64, **kwargs)

    This layer implements a full receiver for batches of 5G NR PUSCH slots sent
    by multiple transmitters. Inputs can be in the time or frequency domain.
    Perfect channel state information can be optionally provided.
    Different channel estimators, MIMO detectors, and transport decoders
    can be configured.

    The layer combines multiple processing blocks into a single layer
    as shown in the following figure. Blocks with dashed lines are
    optional and depend on the configuration.

    .. figure:: ../figures/pusch_receiver_block_diagram.png
        :scale: 30%
        :align: center

    If the ``input_domain`` equals "time", the inputs :math:`\mathbf{y}` are first
    transformed to resource grids with the :class:`~sionna.ofdm.OFDMDemodulator`.
    Then channel estimation is performed, e.g., with the help of the
    :class:`~sionna.nr.PUSCHLSChannelEstimator`. If ``channel_estimator``
    is chosen to be "perfect", this step is skipped and the input :math:`\mathbf{h}`
    is used instead.
    Next, MIMO detection is carried out with an arbitrary :class:`~sionna.ofdm.OFDMDetector`.
    The resulting LLRs for each layer are then combined to transport blocks
    with the help of the :class:`~sionna.nr.LayerDemapper`.
    Finally, the transport blocks are decoded with the :class:`~sionna.nr.TBDecoder`.

    Parameters
    ----------
    pusch_transmitter : :class:`~sionna.nr.PUSCHTransmitter`
        Transmitter used for the generation of the transmit signals

    channel_estimator : :class:`~sionna.ofdm.BaseChannelEstimator`, "perfect", or `None`
        Channel estimator to be used.
        If `None`, the :class:`~sionna.nr.PUSCHLSChannelEstimator` with
        linear interpolation is used.
        If "perfect", no channel estimation is performed and the channel state information
        ``h`` must be provided as additional input.
        Any other string raises a `ValueError`.
        Defaults to `None`.

    mimo_detector : :class:`~sionna.ofdm.OFDMDetector` or `None`
        MIMO Detector to be used.
        If `None`, the :class:`~sionna.ofdm.LinearDetector` with
        LMMSE detection is used.
        Defaults to `None`.

    tb_decoder : :class:`~sionna.nr.TBDecoder` or `None`
        Transport block decoder to be used.
        If `None`, the :class:`~sionna.nr.TBDecoder` with its
        default settings is used.
        Defaults to `None`.

    return_tb_crc_status : bool
        If `True`, the status of the transport block CRC is returned
        as additional output.
        Defaults to `False`.

    stream_management : :class:`~sionna.mimo.StreamManagement` or `None`
        Stream management configuration to be used.
        If `None`, it is assumed that there is a single receiver
        which decodes all streams of all transmitters.
        Defaults to `None`.

    input_domain : str, one of ["freq", "time"]
        Domain of the input signal.
        Defaults to "freq".

    l_min : int or `None`
        Smallest time-lag for the discrete complex baseband channel.
        Only needed if ``input_domain`` equals "time".
        Defaults to `None`.

    dtype : tf.DType
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    (y, no) or (y, h, no) :
        Tuple. ``h`` is only expected if ``channel_estimator`` equals "perfect".

    y : [batch size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex or [batch size, num_rx, num_rx_ant, num_time_samples + l_max - l_min], tf.complex
        Frequency- or time-domain input signal

    h : [batch size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_ofdm_symbols, num_subcarriers], tf.complex or [batch size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_time_samples + l_max - l_min, l_max - l_min + 1], tf.complex
        Perfect channel state information in either frequency or time domain
        (depending on ``input_domain``) to be used for detection.
        Only required if ``channel_estimator`` equals "perfect".

    no : [batch_size, num_rx, num_rx_ant] or only the first n>=0 dims, tf.float
        Variance of the AWGN

    Output
    ------
    b_hat : [batch_size, num_tx, tb_size], tf.float
        Decoded information bits

    tb_crc_status : [batch_size, num_tx], tf.bool
        Transport block CRC status.
        Only returned if ``return_tb_crc_status`` is `True`.

    Example
    -------
    >>> pusch_config = PUSCHConfig()
    >>> pusch_transmitter = PUSCHTransmitter(pusch_config)
    >>> pusch_receiver = PUSCHReceiver(pusch_transmitter)
    >>> channel = AWGN()
    >>> x, b = pusch_transmitter(16)
    >>> no = 0.1
    >>> y = channel([x, no])
    >>> b_hat = pusch_receiver([y, no])
    >>> compute_ber(b, b_hat)
    <tf.Tensor: shape=(), dtype=float64, numpy=0.0>
    """
    def __init__(self,
                 pusch_transmitter,
                 channel_estimator=None,
                 mimo_detector=None,
                 tb_decoder=None,
                 return_tb_crc_status=False,
                 stream_management=None,
                 input_domain="freq",
                 l_min=None,
                 dtype=tf.complex64,
                 **kwargs):
        assert dtype in [tf.complex64, tf.complex128], \
            "dtype must be tf.complex64 or tf.complex128"
        super().__init__(dtype=dtype, **kwargs)

        assert input_domain in ["time", "freq"], \
            "input_domain must be 'time' or 'freq'"
        self._input_domain = input_domain

        self._return_tb_crc_status = return_tb_crc_status

        self._resource_grid = pusch_transmitter.resource_grid

        # (Optionally) Create OFDMDemodulator for time-domain inputs
        if self._input_domain=="time":
            assert l_min is not None, \
                "l_min must be provided for input_domain==time"
            self._l_min = l_min
            self._ofdm_demodulator = OFDMDemodulator(
                fft_size=pusch_transmitter._num_subcarriers,
                l_min=self._l_min,
                cyclic_prefix_length=pusch_transmitter._cyclic_prefix_length)

        # Use or create default ChannelEstimator
        self._perfect_csi = False
        # Precoding matrices; only set for perfect CSI with codebook precoding
        self._w = None
        if channel_estimator is None:
            # Default channel estimator
            self._channel_estimator = sionna.nr.PUSCHLSChannelEstimator(
                self.resource_grid,
                pusch_transmitter._dmrs_length,
                pusch_transmitter._dmrs_additional_position,
                pusch_transmitter._num_cdm_groups_without_data,
                interpolation_type='lin',
                dtype=dtype)
        elif isinstance(channel_estimator, str):
            # "perfect" is the only supported string option. Fail early
            # instead of storing the string as an estimator and crashing
            # later with an opaque error in call().
            if channel_estimator != "perfect":
                raise ValueError(
                    "channel_estimator must be None, 'perfect', or a "
                    "channel estimator instance")
            # Perfect channel estimation
            self._perfect_csi = True
            if pusch_transmitter._precoding=="codebook":
                # The provided h is the unprecoded channel, so the
                # precoding matrices are needed in call() to form the
                # effective channel. Two singleton axes are inserted at
                # axis 1 so that w broadcasts against the transposed h
                # in call() (assumes precoder._w is
                # [num_tx, num_tx_ant, num_layers] - TODO confirm).
                self._w = pusch_transmitter._precoder._w
                self._w = insert_dims(self._w, 2, 1)
        else:
            # User-provided channel estimator
            self._channel_estimator = channel_estimator

        # Use or create default StreamManagement
        if stream_management is None:
            # Default StreamManagement: a single receiver associated
            # with all transmitters
            rx_tx_association = np.ones([1, pusch_transmitter._num_tx], bool)
            self._stream_management = StreamManagement(
                rx_tx_association,
                pusch_transmitter._num_layers)
        else:
            # User-provided StreamManagement
            self._stream_management = stream_management

        # Use or create default MIMODetector
        if mimo_detector is None:
            # Default MIMO detector: LMMSE with max-log bit demapping
            self._mimo_detector = LinearDetector("lmmse", "bit", "maxlog",
                                   pusch_transmitter.resource_grid,
                                   self._stream_management,
                                   "qam",
                                   pusch_transmitter._num_bits_per_symbol,
                                   dtype=dtype)
        else:
            # User-provided MIMO detector
            self._mimo_detector = mimo_detector

        # Create LayerDemapper
        self._layer_demapper = sionna.nr.LayerDemapper(
            pusch_transmitter._layer_mapper,
            num_bits_per_symbol=pusch_transmitter._num_bits_per_symbol)

        # Use or create default TBDecoder
        if tb_decoder is None:
            # Default TBDecoder, configured from the transmitter's encoder
            self._tb_decoder = sionna.nr.TBDecoder(
                pusch_transmitter._tb_encoder,
                output_dtype=dtype.real_dtype)
        else:
            # User-provided TBDecoder
            self._tb_decoder = tb_decoder

    #########################################
    # Public methods and properties
    #########################################

    @property
    def resource_grid(self):
        """OFDM resource grid underlying the PUSCH transmissions"""
        return self._resource_grid

    def call(self, inputs):
        # Unpack inputs; h is only provided in perfect-CSI mode
        if self._perfect_csi:
            y, h, no = inputs
        else:
            y, no = inputs

        # (Optional) OFDM Demodulation
        if self._input_domain=="time":
            y = self._ofdm_demodulator(y)

        # Channel estimation
        if self._perfect_csi:

            # Transform time-domain to frequency-domain channel
            if self._input_domain=="time":
                h = time_to_ofdm_channel(h, self.resource_grid, self._l_min)

            if self._w is not None:
                # Transpose h to put channel matrix dimensions last
                # [batch size, num_rx, num_tx, num_ofdm_symbols,...
                #  ...fft_size, num_rx_ant, num_tx_ant]
                h = tf.transpose(h, perm=[0,1,3,5,6,2,4])

                # Multiply by precoding matrices to compute effective channels
                # [batch size, num_rx, num_tx, num_ofdm_symbols,...
                #  ...fft_size, num_rx_ant, num_streams]
                h = tf.matmul(h, self._w)

                # Transpose back to the layout expected by the detector
                # [batch size, num_rx, num_rx_ant, num_tx, num_streams,...
                #  ...num_ofdm_symbols, fft_size]
                h = tf.transpose(h, perm=[0,1,5,2,6,3,4])
            h_hat = h
            # Perfect CSI: zero channel estimation error variance
            err_var = tf.cast(0, dtype=h_hat.dtype.real_dtype)
        else:
            h_hat, err_var = self._channel_estimator([y, no])

        # MIMO Detection
        llr = self._mimo_detector([y, h_hat, err_var, no])

        # Layer demapping
        llr = self._layer_demapper(llr)

        # TB Decoding
        b_hat, tb_crc_status = self._tb_decoder(llr)

        if self._return_tb_crc_status:
            return b_hat, tb_crc_status
        else:
            return b_hat