anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

tb_encoder.py (16416B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Transport block encoding functions for the 5g NR sub-package of Sionna.
      6 """
      7 
      8 import numpy as np
      9 import tensorflow as tf
     10 from tensorflow.keras.layers import Layer
     11 from sionna.fec.crc import CRCEncoder
     12 from sionna.fec.scrambling import TB5GScrambler
     13 from sionna.fec.ldpc import LDPC5GEncoder
     14 from sionna.nr.utils import calculate_tb_size
     15 
     16 class TBEncoder(Layer):
     17     # pylint: disable=line-too-long
     18     r"""TBEncoder(target_tb_size,num_coded_bits,target_coderate,num_bits_per_symbol,num_layers=1,n_rnti=1,n_id=1,channel_type="PUSCH",codeword_index=0,use_scrambler=True,verbose=False,output_dtype=tf.float32,, **kwargs)
     19     5G NR transport block (TB) encoder as defined in TS 38.214
     20     [3GPP38214]_ and TS 38.211 [3GPP38211]_
     21 
     22     The transport block (TB) encoder takes as input a `transport block` of
     23     information bits and generates a sequence of codewords for transmission.
     24     For this, the information bit sequence is segmented into multiple codewords,
     25     protected by additional CRC checks and FEC encoded. Further, interleaving
     26     and scrambling is applied before a codeword concatenation generates the
     27     final bit sequence. Fig. 1 provides an overview of the TB encoding
     28     procedure and we refer the interested reader to [3GPP38214]_ and
     29     [3GPP38211]_ for further details.
     30 
     31     ..  figure:: ../figures/tb_encoding.png
     32 
     33         Fig. 1: Overview TB encoding (CB CRC does not always apply).
     34 
     35     If ``n_rnti`` and ``n_id`` are given as list, the TBEncoder encodes
     36     `num_tx = len(` ``n_rnti`` `)` parallel input streams with different
     37     scrambling sequences per user.
     38 
     39     The class inherits from the Keras layer class and can be used as layer in a
     40     Keras model.
     41 
     42     Parameters
     43     ----------
     44         target_tb_size: int
     45             Target transport block size, i.e., how many information bits are
     46             encoded into the TB. Note that the effective TB size can be
     47             slightly different due to quantization. If required, zero padding
     48             is internally applied.
     49 
     50         num_coded_bits: int
     51             Number of coded bits after TB encoding.
     52 
     53         target_coderate : float
     54             Target coderate.
     55 
     56         num_bits_per_symbol: int
     57             Modulation order, i.e., number of bits per QAM symbol.
     58 
     59         num_layers: int, 1 (default) | [1,...,8]
     60             Number of transmission layers.
     61 
     62         n_rnti: int or list of ints, 1 (default) | [0,...,65335]
     63             RNTI identifier provided by higher layer. Defaults to 1 and must be
     64             in range `[0, 65335]`. Defines a part of the random seed of the
     65             scrambler. If provided as list, every list entry defines the RNTI
     66             of an independent input stream.
     67 
     68         n_id: int or list of ints, 1 (default) | [0,...,1023]
     69             Data scrambling ID :math:`n_\text{ID}` related to cell id and
     70             provided by higher layer.
     71             Defaults to 1 and must be in range `[0, 1023]`. If provided as
     72             list, every list entry defines the scrambling id of an independent
     73             input stream.
     74 
     75         channel_type: str, "PUSCH" (default) | "PDSCH"
     76             Can be either "PUSCH" or "PDSCH".
     77 
     78         codeword_index: int, 0 (default) | 1
     79             Scrambler can be configured for two codeword transmission.
     80             ``codeword_index`` can be either 0 or 1. Must be 0 for
     81             ``channel_type`` = "PUSCH".
     82 
     83         use_scrambler: bool, True (default)
     84             If False, no data scrambling is applied (non standard-compliant).
     85 
     86         verbose: bool, False (default)
     87             If `True`, additional parameters are printed during initialization.
     88 
     89         dtype: tf.float32 (default)
     90             Defines the datatype for internal calculations and the output dtype.
     91 
     92     Input
     93     -----
     94         inputs: [...,target_tb_size] or [...,num_tx,target_tb_size], tf.float
     95             2+D tensor containing the information bits to be encoded. If
     96             ``n_rnti`` and ``n_id`` are a list of size `num_tx`, the input must
     97             be of shape `[...,num_tx,target_tb_size]`.
     98 
     99     Output
    100     ------
    101         : [...,num_coded_bits], tf.float
    102             2+D tensor containing the sequence of the encoded codeword bits of
    103             the transport block.
    104 
    105     Note
    106     ----
    107     The parameters ``tb_size`` and ``num_coded_bits`` can be derived by the
    108     :meth:`~sionna.nr.calculate_tb_size` function or
    109     by accessing the corresponding :class:`~sionna.nr.PUSCHConfig` attributes.
    110     """
    111 
    112     def __init__(self,
    113                  target_tb_size,
    114                  num_coded_bits,
    115                  target_coderate,
    116                  num_bits_per_symbol,
    117                  num_layers=1,
    118                  n_rnti=1,
    119                  n_id=1,
    120                  channel_type="PUSCH",
    121                  codeword_index=0,
    122                  use_scrambler=True,
    123                  verbose=False,
    124                  output_dtype=tf.float32,
    125                  **kwargs):
    126 
    127         super().__init__(dtype=output_dtype, **kwargs)
    128 
    129         assert isinstance(use_scrambler, bool), \
    130                                 "use_scrambler must be bool."
    131         self._use_scrambler = use_scrambler
    132         assert isinstance(verbose, bool), \
    133                                 "verbose must be bool."
    134         self._verbose = verbose
    135 
    136         # check input for consistency
    137         assert channel_type in ("PDSCH", "PUSCH"), \
    138                                 "Unsupported channel_type."
    139         self._channel_type = channel_type
    140 
    141         assert(target_tb_size%1==0), "target_tb_size must be int."
    142         self._target_tb_size = int(target_tb_size)
    143 
    144         assert(num_coded_bits%1==0), "num_coded_bits must be int."
    145         self._num_coded_bits = int(num_coded_bits)
    146 
    147         assert(0.<target_coderate <= 948/1024), \
    148                     "target_coderate must be in range(0,0.925)."
    149         self._target_coderate = target_coderate
    150 
    151         assert(num_bits_per_symbol%1==0), "num_bits_per_symbol must be int."
    152         self._num_bits_per_symbol = int(num_bits_per_symbol)
    153 
    154         assert(num_layers%1==0), "num_layers must be int."
    155         self._num_layers = int(num_layers)
    156 
    157         if channel_type=="PDSCH":
    158             assert(codeword_index in (0,1)), "codeword_index must be 0 or 1."
    159         else:
    160             assert codeword_index==0, 'codeword_index must be 0 for "PUSCH".'
    161         self._codeword_index = int(codeword_index)
    162 
    163         if isinstance(n_rnti, (list, tuple)):
    164             assert isinstance(n_id, (list, tuple)), "n_id must be also a list."
    165             assert (len(n_rnti)==len(n_id)), \
    166                                 "n_id and n_rnti must be of same length."
    167             self._n_rnti = n_rnti
    168             self._n_id = n_id
    169         else:
    170             self._n_rnti = [n_rnti]
    171             self._n_id = [n_id]
    172 
    173         for idx, n in enumerate(self._n_rnti):
    174             assert(n%1==0), "n_rnti must be int."
    175             self._n_rnti[idx] = int(n)
    176         for idx, n in enumerate(self._n_id):
    177             assert(n%1==0), "n_id must be int."
    178             self._n_id[idx] = int(n)
    179 
    180         self._num_tx = len(self._n_id)
    181 
    182         tbconfig = calculate_tb_size(target_tb_size=self._target_tb_size,
    183                                      num_coded_bits=self._num_coded_bits,
    184                                      target_coderate=self._target_coderate,
    185                                      modulation_order=self._num_bits_per_symbol,
    186                                      num_layers=self._num_layers,
    187                                      verbose=verbose)
    188         self._tb_size = tbconfig[0]
    189         self._cb_size = tbconfig[1]
    190         self._num_cbs = tbconfig[2]
    191         self._cw_lengths = tbconfig[3]
    192         self._tb_crc_length = tbconfig[4]
    193         self._cb_crc_length = tbconfig[5]
    194 
    195         assert self._tb_size <= self._tb_crc_length + np.sum(self._cw_lengths),\
    196             "Invalid TB parameters."
    197 
    198         # due to quantization, the tb_size can slightly differ from the
    199         # target tb_size.
    200         self._k_padding = self._tb_size - self._target_tb_size
    201         if self._tb_size != self._target_tb_size:
    202             print(f"Note: actual tb_size={self._tb_size} is slightly "\
    203                   f"different than requested " \
    204                   f"target_tb_size={self._target_tb_size} due to "\
    205                   f"quantization. Internal zero padding will be applied.")
    206 
    207         # calculate effective coderate (incl. CRC)
    208         self._coderate = self._tb_size / self._num_coded_bits
    209 
    210         # Remark: CRC16 is only used for k<3824 (otherwise CRC24)
    211         if self._tb_crc_length==16:
    212             self._tb_crc_encoder = CRCEncoder("CRC16")
    213         else:
    214             # CRC24A as defined in 7.2.1
    215             self._tb_crc_encoder = CRCEncoder("CRC24A")
    216 
    217         # CB CRC only if more than one CB is used
    218         if self._cb_crc_length==24:
    219             self._cb_crc_encoder = CRCEncoder("CRC24B")
    220         else:
    221             self._cb_crc_encoder = None
    222 
    223         # scrambler can be deactivated (non-standard compliant)
    224         if self._use_scrambler:
    225             self._scrambler = TB5GScrambler(n_rnti=self._n_rnti,
    226                                             n_id=self._n_id,
    227                                             binary=True,
    228                                             channel_type=channel_type,
    229                                             codeword_index=codeword_index,
    230                                             dtype=tf.float32,)
    231         else: # required for TBDecoder
    232             self._scrambler = None
    233 
    234         # ---- Init LDPC encoder ----
    235         # remark: as the codeword length can be (slightly) different
    236         # within a TB due to rounding, we initialize the encoder
    237         # with the max length and apply puncturing if required.
    238         # Thus, also the output interleaver cannot be applied in the encoder.
    239         # The procedure is defined in in 5.4.2.1 38.212
    240         self._encoder = LDPC5GEncoder(self._cb_size,
    241                                       np.max(self._cw_lengths),
    242                                       num_bits_per_symbol=1) #deact. interleaver
    243 
    244         # ---- Init interleaver ----
    245         # remark: explicit interleaver is required as the rate matching from
    246         # Sec. 5.4.2.1 38.212 could otherwise not be applied here
    247         perm_seq_short, _ = self._encoder.generate_out_int(
    248                                             np.min(self._cw_lengths),
    249                                             num_bits_per_symbol)
    250         perm_seq_long, _ = self._encoder.generate_out_int(
    251                                             np.max(self._cw_lengths),
    252                                             num_bits_per_symbol)
    253 
    254         perm_seq = []
    255         perm_seq_punc = []
    256 
    257         # define one big interleaver that moves the punctured positions to the
    258         # end of the TB
    259         payload_bit_pos = 0 # points to current pos of payload bits
    260 
    261         for l in self._cw_lengths:
    262             if np.min(self._cw_lengths)==l:
    263                 perm_seq = np.concatenate([perm_seq,
    264                                            perm_seq_short + payload_bit_pos])
    265                 # move unused bit positions to the end of TB
    266                 # this simplifies the inverse permutation
    267                 r = np.arange(payload_bit_pos+np.min(self._cw_lengths),
    268                               payload_bit_pos+np.max(self._cw_lengths))
    269                 perm_seq_punc = np.concatenate([perm_seq_punc, r])
    270 
    271                 # update pointer
    272                 payload_bit_pos += np.max(self._cw_lengths)
    273             elif np.max(self._cw_lengths)==l:
    274                 perm_seq = np.concatenate([perm_seq,
    275                                            perm_seq_long + payload_bit_pos])
    276                 # update pointer
    277                 payload_bit_pos += l
    278             else:
    279                 raise ValueError("Invalid cw_lengths.")
    280 
    281         # add punctured positions to end of sequence (only relevant for
    282         # deinterleaving)
    283         perm_seq = np.concatenate([perm_seq, perm_seq_punc])
    284 
    285         self._output_perm = tf.constant(perm_seq, tf.int32)
    286         self._output_perm_inv = tf.argsort(perm_seq, axis=-1)
    287 
    288     #########################################
    289     # Public methods and properties
    290     #########################################
    291 
    292     @property
    293     def tb_size(self):
    294         r"""Effective number of information bits per TB.
    295         Note that (if required) internal zero padding can be applied to match
    296         the request exact ``target_tb_size``."""
    297         return self._tb_size
    298 
    299     @property
    300     def k(self):
    301         r"""Number of input information bits. Equals `tb_size` except for zero
    302         padding of the last positions if the ``target_tb_size`` is quantized."""
    303         return self._target_tb_size
    304 
    305     @property
    306     def k_padding(self):
    307         """Number of zero padded bits at the end of the TB."""
    308         return self._k_padding
    309 
    310     @property
    311     def n(self):
    312         "Total number of output bits."
    313         return self._num_coded_bits
    314 
    315     @property
    316     def num_cbs(self):
    317         "Number code blocks."
    318         return self._num_cbs
    319 
    320     @property
    321     def coderate(self):
    322         """Effective coderate of the TB after rate-matching including overhead
    323         for the CRC."""
    324         return self._coderate
    325 
    326     @property
    327     def ldpc_encoder(self):
    328         """LDPC encoder used for TB encoding."""
    329         return self._encoder
    330 
    331     @property
    332     def scrambler(self):
    333         """Scrambler used for TB scrambling. `None` if no scrambler is used."""
    334         return self._scrambler
    335 
    336     @property
    337     def tb_crc_encoder(self):
    338         """TB CRC encoder"""
    339         return self._tb_crc_encoder
    340 
    341     @property
    342     def cb_crc_encoder(self):
    343         """CB CRC encoder. `None` if no CB CRC is applied."""
    344         return self._cb_crc_encoder
    345 
    346     @property
    347     def num_tx(self):
    348         """Number of independent streams"""
    349         return self._num_tx
    350 
    351     @property
    352     def cw_lengths(self):
    353         r"""Each list element defines the codeword length of each of the
    354         codewords after LDPC encoding and rate-matching. The total number of
    355         coded bits is :math:`\sum` `cw_lengths`."""
    356         return self._cw_lengths
    357 
    358     @property
    359     def output_perm_inv(self):
    360         r"""Inverse interleaver pattern for output bit interleaver."""
    361         return self._output_perm_inv
    362 
    363     #########################
    364     # Keras layer functions
    365     #########################
    366 
    367     def build(self, input_shapes):
    368         """Test input shapes for consistency."""
    369 
    370         assert input_shapes[-1]==self.k, \
    371             f"Invalid input shape. Expected TB length is {self.k}."
    372 
    373     def call(self, inputs):
    374         """Apply transport block encoding procedure."""
    375 
    376         # store shapes
    377         input_shape = inputs.shape.as_list()
    378         u = tf.cast(inputs, tf.float32)
    379 
    380         # apply zero padding if tb_size is slightly different to target_tb_size
    381         if self._k_padding>0:
    382             s = tf.shape(u)
    383             s = tf.concat((s[:-1], [self._k_padding]), axis=0)
    384             u = tf.concat((u, tf.zeros(s, u.dtype)), axis=-1)
    385 
    386         # apply TB CRC
    387         u_crc = self._tb_crc_encoder(u)
    388 
    389         # CB segmentation
    390         u_cb = tf.reshape(u_crc,
    391                           (-1, self._num_tx, self._num_cbs,
    392                           self._cb_size-self._cb_crc_length))
    393 
    394         # if relevant apply CB CRC
    395         if self._cb_crc_length==24:
    396             u_cb_crc = self._cb_crc_encoder(u_cb)
    397         else:
    398             u_cb_crc = u_cb # no CRC applied if only one CB exists
    399 
    400         c_cb = self._encoder(u_cb_crc)
    401 
    402         # CB concatenation
    403         c = tf.reshape(c_cb,
    404                        (-1, self._num_tx,
    405                        self._num_cbs*np.max(self._cw_lengths)))
    406 
    407         # apply interleaver (done after CB concatenation)
    408         c = tf.gather(c, self._output_perm, axis=-1)
    409         # puncture last bits
    410         c = c[:, :, :np.sum(self._cw_lengths)]
    411 
    412         # scrambler
    413         if self._use_scrambler:
    414             c_scr = self._scrambler(c)
    415         else: # disable scrambler (non-standard compliant)
    416             c_scr = c
    417 
    418         # cast to output dtype
    419         c_scr = tf.cast(c_scr, self.dtype)
    420 
    421         # ensure output shapes
    422         output_shape = input_shape
    423         output_shape[0] = -1
    424         output_shape[-1] = np.sum(self._cw_lengths)
    425         c_tb = tf.reshape(c_scr, output_shape)
    426 
    427         return c_tb