tb_encoder.py
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Transport block encoding functions for the 5G NR sub-package of Sionna.
"""

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna.fec.crc import CRCEncoder
from sionna.fec.scrambling import TB5GScrambler
from sionna.fec.ldpc import LDPC5GEncoder
from sionna.nr.utils import calculate_tb_size

class TBEncoder(Layer):
    # pylint: disable=line-too-long
    r"""TBEncoder(target_tb_size,num_coded_bits,target_coderate,num_bits_per_symbol,num_layers=1,n_rnti=1,n_id=1,channel_type="PUSCH",codeword_index=0,use_scrambler=True,verbose=False,output_dtype=tf.float32, **kwargs)
    5G NR transport block (TB) encoder as defined in TS 38.214
    [3GPP38214]_ and TS 38.211 [3GPP38211]_.

    The transport block (TB) encoder takes as input a `transport block` of
    information bits and generates a sequence of codewords for transmission.
    For this, the information bit sequence is segmented into multiple
    codewords, protected by additional CRC checks, and FEC encoded. Further,
    interleaving and scrambling are applied before a codeword concatenation
    generates the final bit sequence. Fig. 1 provides an overview of the TB
    encoding procedure and we refer the interested reader to [3GPP38214]_ and
    [3GPP38211]_ for further details.

    .. figure:: ../figures/tb_encoding.png

        Fig. 1: Overview of the TB encoding procedure (the CB CRC is not
        always applied).

    If ``n_rnti`` and ``n_id`` are given as lists, the TBEncoder encodes
    `num_tx = len(` ``n_rnti`` `)` parallel input streams with a different
    scrambling sequence per user.

    The class inherits from the Keras layer class and can be used as a layer
    in a Keras model.

    Parameters
    ----------
    target_tb_size: int
        Target transport block size, i.e., how many information bits are
        encoded into the TB. Note that the effective TB size can be
        slightly different due to quantization. If required, zero padding
        is internally applied.

    num_coded_bits: int
        Number of coded bits after TB encoding.

    target_coderate : float
        Target coderate.

    num_bits_per_symbol: int
        Modulation order, i.e., number of bits per QAM symbol.

    num_layers: int, 1 (default) | [1,...,8]
        Number of transmission layers.

    n_rnti: int or list of ints, 1 (default) | [0,...,65535]
        RNTI identifier provided by higher layer. Defaults to 1 and must be
        in range `[0, 65535]`. Defines a part of the random seed of the
        scrambler. If provided as list, every list entry defines the RNTI
        of an independent input stream.

    n_id: int or list of ints, 1 (default) | [0,...,1023]
        Data scrambling ID :math:`n_\text{ID}` related to cell id and
        provided by higher layer.
        Defaults to 1 and must be in range `[0, 1023]`. If provided as
        list, every list entry defines the scrambling id of an independent
        input stream.

    channel_type: str, "PUSCH" (default) | "PDSCH"
        Can be either "PUSCH" or "PDSCH".

    codeword_index: int, 0 (default) | 1
        Scrambler can be configured for two codeword transmission.
        ``codeword_index`` can be either 0 or 1. Must be 0 for
        ``channel_type`` = "PUSCH".

    use_scrambler: bool, True (default)
        If False, no data scrambling is applied (non-standard-compliant).

    verbose: bool, False (default)
        If `True`, additional parameters are printed during initialization.
    output_dtype: tf.float32 (default)
        Defines the datatype for internal calculations and the output dtype.

    Input
    -----
    inputs: [...,target_tb_size] or [...,num_tx,target_tb_size], tf.float
        2+D tensor containing the information bits to be encoded. If
        ``n_rnti`` and ``n_id`` are lists of size `num_tx`, the input must
        be of shape `[...,num_tx,target_tb_size]`.

    Output
    ------
    : [...,num_coded_bits], tf.float
        2+D tensor containing the sequence of the encoded codeword bits of
        the transport block.

    Note
    ----
    The parameters ``tb_size`` and ``num_coded_bits`` can be derived by the
    :meth:`~sionna.nr.calculate_tb_size` function or
    by accessing the corresponding :class:`~sionna.nr.PUSCHConfig` attributes.
    """

    def __init__(self,
                 target_tb_size,
                 num_coded_bits,
                 target_coderate,
                 num_bits_per_symbol,
                 num_layers=1,
                 n_rnti=1,
                 n_id=1,
                 channel_type="PUSCH",
                 codeword_index=0,
                 use_scrambler=True,
                 verbose=False,
                 output_dtype=tf.float32,
                 **kwargs):

        super().__init__(dtype=output_dtype, **kwargs)

        assert isinstance(use_scrambler, bool), \
            "use_scrambler must be bool."
        self._use_scrambler = use_scrambler
        assert isinstance(verbose, bool), \
            "verbose must be bool."
        self._verbose = verbose

        # check input for consistency
        assert channel_type in ("PDSCH", "PUSCH"), \
            "Unsupported channel_type."
        self._channel_type = channel_type

        assert target_tb_size%1==0, "target_tb_size must be int."
        self._target_tb_size = int(target_tb_size)

        assert num_coded_bits%1==0, "num_coded_bits must be int."
        self._num_coded_bits = int(num_coded_bits)

        assert 0. < target_coderate <= 948/1024, \
            "target_coderate must be in range (0, 948/1024]."
        self._target_coderate = target_coderate

        assert num_bits_per_symbol%1==0, "num_bits_per_symbol must be int."
        self._num_bits_per_symbol = int(num_bits_per_symbol)

        assert num_layers%1==0, "num_layers must be int."
        self._num_layers = int(num_layers)

        if channel_type=="PDSCH":
            assert codeword_index in (0,1), "codeword_index must be 0 or 1."
        else:
            assert codeword_index==0, 'codeword_index must be 0 for "PUSCH".'
        self._codeword_index = int(codeword_index)

        if isinstance(n_rnti, (list, tuple)):
            assert isinstance(n_id, (list, tuple)), "n_id must be also a list."
            assert len(n_rnti)==len(n_id), \
                "n_id and n_rnti must be of same length."
            self._n_rnti = n_rnti
            self._n_id = n_id
        else:
            self._n_rnti = [n_rnti]
            self._n_id = [n_id]

        for idx, n in enumerate(self._n_rnti):
            assert n%1==0, "n_rnti must be int."
            self._n_rnti[idx] = int(n)
        for idx, n in enumerate(self._n_id):
            assert n%1==0, "n_id must be int."
            self._n_id[idx] = int(n)

        self._num_tx = len(self._n_id)

        tbconfig = calculate_tb_size(target_tb_size=self._target_tb_size,
                                     num_coded_bits=self._num_coded_bits,
                                     target_coderate=self._target_coderate,
                                     modulation_order=self._num_bits_per_symbol,
                                     num_layers=self._num_layers,
                                     verbose=verbose)
        self._tb_size = tbconfig[0]
        self._cb_size = tbconfig[1]
        self._num_cbs = tbconfig[2]
        self._cw_lengths = tbconfig[3]
        self._tb_crc_length = tbconfig[4]
        self._cb_crc_length = tbconfig[5]

        assert self._tb_size <= self._tb_crc_length + np.sum(self._cw_lengths),\
            "Invalid TB parameters."
        # due to quantization, the tb_size can slightly differ from the
        # target tb_size.
        self._k_padding = self._tb_size - self._target_tb_size
        if self._tb_size != self._target_tb_size:
            print(f"Note: actual tb_size={self._tb_size} is slightly "
                  f"different than requested "
                  f"target_tb_size={self._target_tb_size} due to "
                  f"quantization. Internal zero padding will be applied.")

        # calculate effective coderate (incl. CRC)
        self._coderate = self._tb_size / self._num_coded_bits

        # Remark: CRC16 is only used for k<3824 (otherwise CRC24)
        if self._tb_crc_length==16:
            self._tb_crc_encoder = CRCEncoder("CRC16")
        else:
            # CRC24A as defined in Sec. 7.2.1
            self._tb_crc_encoder = CRCEncoder("CRC24A")

        # CB CRC is only applied if more than one CB is used
        if self._cb_crc_length==24:
            self._cb_crc_encoder = CRCEncoder("CRC24B")
        else:
            self._cb_crc_encoder = None

        # scrambler can be deactivated (non-standard-compliant)
        if self._use_scrambler:
            self._scrambler = TB5GScrambler(n_rnti=self._n_rnti,
                                            n_id=self._n_id,
                                            binary=True,
                                            channel_type=channel_type,
                                            codeword_index=codeword_index,
                                            dtype=tf.float32)
        else: # required for TBDecoder
            self._scrambler = None

        # ---- Init LDPC encoder ----
        # remark: as the codeword lengths can be (slightly) different
        # within a TB due to rounding, we initialize the encoder
        # with the max length and apply puncturing if required.
        # Thus, also the output interleaver cannot be applied in the encoder.
        # The procedure is defined in Sec. 5.4.2.1 of 38.212.
        self._encoder = LDPC5GEncoder(self._cb_size,
                                      np.max(self._cw_lengths),
                                      num_bits_per_symbol=1) # deact. interleaver

        # ---- Init interleaver ----
        # remark: an explicit interleaver is required as the rate matching
        # from Sec. 5.4.2.1 of 38.212 could otherwise not be applied here
        perm_seq_short, _ = self._encoder.generate_out_int(
                                            np.min(self._cw_lengths),
                                            num_bits_per_symbol)
        perm_seq_long, _ = self._encoder.generate_out_int(
                                            np.max(self._cw_lengths),
                                            num_bits_per_symbol)

        perm_seq = []
        perm_seq_punc = []

        # define one big interleaver that moves the punctured positions to
        # the end of the TB
        payload_bit_pos = 0 # points to the current position of the payload bits

        for l in self._cw_lengths:
            if np.min(self._cw_lengths)==l:
                perm_seq = np.concatenate([perm_seq,
                                           perm_seq_short + payload_bit_pos])
                # move unused bit positions to the end of the TB
                # this simplifies the inverse permutation
                r = np.arange(payload_bit_pos+np.min(self._cw_lengths),
                              payload_bit_pos+np.max(self._cw_lengths))
                perm_seq_punc = np.concatenate([perm_seq_punc, r])

                # update pointer
                payload_bit_pos += np.max(self._cw_lengths)
            elif np.max(self._cw_lengths)==l:
                perm_seq = np.concatenate([perm_seq,
                                           perm_seq_long + payload_bit_pos])
                # update pointer
                payload_bit_pos += l
            else:
                raise ValueError("Invalid cw_lengths.")

        # add punctured positions to the end of the sequence (only relevant
        # for deinterleaving)
        perm_seq = np.concatenate([perm_seq, perm_seq_punc])

        self._output_perm = tf.constant(perm_seq, tf.int32)
        self._output_perm_inv = tf.argsort(perm_seq, axis=-1)
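        # In summary: gathering with _output_perm reorders the concatenated
        # encoder outputs such that the first sum(cw_lengths) positions carry
        # the interleaved, rate-matched bits of all codewords, while the
        # surplus positions of the shorter codewords are collected at the
        # very end; the slicing in call() then removes exactly these bits.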

    #########################################
    # Public methods and properties
    #########################################

    @property
    def tb_size(self):
        r"""Effective number of information bits per TB.
        Note that internal zero padding can be applied (if required) to match
        the requested ``target_tb_size`` exactly."""
        return self._tb_size

    @property
    def k(self):
        r"""Number of input information bits. Equals `tb_size` except when
        ``target_tb_size`` is quantized; in that case the last ``k_padding``
        positions are internally zero padded."""
        return self._target_tb_size

    @property
    def k_padding(self):
        """Number of zero padded bits at the end of the TB."""
        return self._k_padding

    @property
    def n(self):
        "Total number of output bits."
        return self._num_coded_bits

    @property
    def num_cbs(self):
        "Number of code blocks."
        return self._num_cbs

    @property
    def coderate(self):
        """Effective coderate of the TB after rate-matching including the
        overhead for the CRC."""
        return self._coderate

    @property
    def ldpc_encoder(self):
        """LDPC encoder used for TB encoding."""
        return self._encoder

    @property
    def scrambler(self):
        """Scrambler used for TB scrambling. `None` if no scrambler is used."""
        return self._scrambler

    @property
    def tb_crc_encoder(self):
        """TB CRC encoder."""
        return self._tb_crc_encoder

    @property
    def cb_crc_encoder(self):
        """CB CRC encoder. `None` if no CB CRC is applied."""
        return self._cb_crc_encoder

    @property
    def num_tx(self):
        """Number of independent streams."""
        return self._num_tx

    @property
    def cw_lengths(self):
        r"""Each list element defines the codeword length of each of the
        codewords after LDPC encoding and rate-matching. The total number of
        coded bits is :math:`\sum` `cw_lengths`."""
        return self._cw_lengths

    @property
    def output_perm_inv(self):
        r"""Inverse interleaver pattern of the output bit interleaver."""
        return self._output_perm_inv

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shapes):
        """Test input shapes for consistency."""

        assert input_shapes[-1]==self.k, \
            f"Invalid input shape. Expected TB length is {self.k}."
    def call(self, inputs):
        """Apply transport block encoding procedure."""

        # store shapes
        input_shape = inputs.shape.as_list()
        u = tf.cast(inputs, tf.float32)

        # apply zero padding if tb_size differs slightly from target_tb_size
        if self._k_padding>0:
            s = tf.shape(u)
            s = tf.concat((s[:-1], [self._k_padding]), axis=0)
            u = tf.concat((u, tf.zeros(s, u.dtype)), axis=-1)

        # apply TB CRC
        u_crc = self._tb_crc_encoder(u)

        # CB segmentation
        u_cb = tf.reshape(u_crc,
                          (-1, self._num_tx, self._num_cbs,
                           self._cb_size-self._cb_crc_length))

        # apply CB CRC if relevant
        if self._cb_crc_length==24:
            u_cb_crc = self._cb_crc_encoder(u_cb)
        else:
            u_cb_crc = u_cb # no CRC applied if only one CB exists

        # LDPC encode all code blocks
        c_cb = self._encoder(u_cb_crc)

        # CB concatenation
        c = tf.reshape(c_cb,
                       (-1, self._num_tx,
                        self._num_cbs*np.max(self._cw_lengths)))

        # apply interleaver (done after CB concatenation)
        c = tf.gather(c, self._output_perm, axis=-1)
        # puncture last bits
        c = c[:, :, :np.sum(self._cw_lengths)]

        # scrambler
        if self._use_scrambler:
            c_scr = self._scrambler(c)
        else: # disable scrambler (non-standard-compliant)
            c_scr = c

        # cast to output dtype
        c_scr = tf.cast(c_scr, self.dtype)

        # ensure output shapes
        output_shape = input_shape
        output_shape[0] = -1
        output_shape[-1] = np.sum(self._cw_lengths)
        c_tb = tf.reshape(c_scr, output_shape)

        return c_tb
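
if __name__ == "__main__":
    # Minimal usage sketch. The parameter values below are illustrative
    # example choices (not from the original module) and must form a valid
    # TS 38.214 configuration; in practice, tb_size and num_coded_bits are
    # typically derived via calculate_tb_size() or taken from a PUSCHConfig.
    encoder = TBEncoder(target_tb_size=3072,
                        num_coded_bits=6144,
                        target_coderate=0.5,
                        num_bits_per_symbol=4)  # 16-QAM

    # random information bits of shape [batch_size, target_tb_size]
    u = tf.cast(tf.random.uniform((8, encoder.k), maxval=2, dtype=tf.int32),
                tf.float32)

    # encode; the output has shape [batch_size, num_coded_bits]
    c = encoder(u)
    print("input shape:", u.shape, "output shape:", c.shape)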