anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

tb_decoder.py (7229B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Transport block decoding layer for the 5G NR sub-package of Sionna.
      6 """
      7 
      8 import numpy as np
      9 import tensorflow as tf
     10 from tensorflow.keras.layers import Layer
     11 from sionna.fec.crc import CRCDecoder
     12 from sionna.fec.scrambling import  Descrambler
     13 from sionna.fec.ldpc import LDPC5GDecoder
     14 from sionna.nr import TBEncoder
     15 
     16 class TBDecoder(Layer):
     17     # pylint: disable=line-too-long
     18     r"""TBDecoder(encoder, num_bp_iter=20, cn_type="boxplus-phi", output_dtype=tf.float32, **kwargs)
     19     5G NR transport block (TB) decoder as defined in TS 38.214
     20     [3GPP38214]_.
     21 
     22     The transport block decoder takes as input a sequence of noisy channel
     23     observations and reconstructs the corresponding `transport block` of
     24     information bits. The detailed procedure is described in TS 38.214
     25     [3GPP38214]_ and TS 38.211 [3GPP38211]_.
     26 
     27     The class inherits from the Keras layer class and can be used as layer in a
     28     Keras model.
     29 
     30     Parameters
     31     ----------
     32         encoder : :class:`~sionna.nr.TBEncoder`
     33             Associated transport block encoder used for encoding of the signal.
     34 
     35         num_bp_iter : int, 20 (default)
     36             Number of BP decoder iterations
     37 
     38         cn_type : str, "boxplus-phi" (default) | "boxplus" | "minsum"
     39             The check node processing function of the LDPC BP decoder.
     40             One of {`"boxplus"`, `"boxplus-phi"`, `"minsum"`} where
     41             '"boxplus"' implements the single-parity-check APP decoding rule.
     42             '"boxplus-phi"' implements the numerical more stable version of
     43             boxplus [Ryan]_.
     44             '"minsum"' implements the min-approximation of the CN update rule
     45             [Ryan]_.
     46 
     47         output_dtype : tf.float32 (default)
     48             Defines the datatype for internal calculations and the output dtype.
     49 
     50     Input
     51     -----
     52         inputs : [...,num_coded_bits], tf.float
     53             2+D tensor containing channel logits/llr values of the (noisy)
     54             channel observations.
     55 
     56     Output
     57     ------
     58         b_hat : [...,target_tb_size], tf.float
     59             2+D tensor containing hard decided bit estimates of all information
     60             bits of the transport block.
     61 
     62         tb_crc_status : [...], tf.bool
     63             Transport block CRC status indicating if a transport block was
     64             (most likely) correctly recovered. Note that false positives are
     65             possible.
     66     """
     67 
     68     def __init__(self,
     69                  encoder,
     70                  num_bp_iter=20,
     71                  cn_type="boxplus-phi",
     72                  output_dtype=tf.float32,
     73                  **kwargs):
     74 
     75         super().__init__(dtype=output_dtype, **kwargs)
     76 
     77         assert output_dtype in (tf.float16, tf.float32, tf.float64), \
     78                 "output_dtype must be (tf.float16, tf.float32, tf.float64)."
     79 
     80         assert isinstance(encoder, TBEncoder), "encoder must be TBEncoder."
     81         self._tb_encoder = encoder
     82 
     83         self._num_cbs = encoder.num_cbs
     84 
     85         # init BP decoder
     86         self._decoder = LDPC5GDecoder(encoder=encoder.ldpc_encoder,
     87                                       num_iter=num_bp_iter,
     88                                       cn_type=cn_type,
     89                                       hard_out=True, # TB operates on bit-level
     90                                       return_infobits=True,
     91                                       output_dtype=output_dtype)
     92 
     93         # init descrambler
     94         if encoder.scrambler is not None:
     95             self._descrambler = Descrambler(encoder.scrambler,
     96                                             binary=False)
     97         else:
     98             self._descrambler = None
     99 
    100         # init CRC Decoder for CB and TB
    101         self._tb_crc_decoder = CRCDecoder(encoder.tb_crc_encoder)
    102 
    103         if encoder.cb_crc_encoder is not None:
    104             self._cb_crc_decoder = CRCDecoder(encoder.cb_crc_encoder)
    105         else:
    106             self._cb_crc_decoder = None
    107 
    108     #########################################
    109     # Public methods and properties
    110     #########################################
    111 
    112     @property
    113     def tb_size(self):
    114         """Number of information bits per TB."""
    115         return self._tb_encoder.tb_size
    116 
    117     # required for
    118     @property
    119     def k(self):
    120         """Number of input information bits. Equals TB size."""
    121         return self._tb_encoder.tb_size
    122 
    123     @property
    124     def n(self):
    125         "Total number of output codeword bits."
    126         return self._tb_encoder.n
    127 
    128     #########################
    129     # Keras layer functions
    130     #########################
    131 
    132     def build(self, input_shapes):
    133         """Test input shapes for consistency."""
    134 
    135         assert input_shapes[-1]==self.n, \
    136             f"Invalid input shape. Expected input length is {self.n}."
    137 
    138     def call(self, inputs):
    139         """Apply transport block decoding."""
    140 
    141         # store shapes
    142         input_shape = inputs.shape.as_list()
    143         llr_ch = tf.cast(inputs, tf.float32)
    144 
    145         llr_ch = tf.reshape(llr_ch,
    146                             (-1, self._tb_encoder.num_tx, self._tb_encoder.n))
    147 
    148         # undo scrambling (only if scrambler was used)
    149         if self._descrambler is not None:
    150             llr_scr = self._descrambler(llr_ch)
    151         else:
    152             llr_scr = llr_ch
    153 
    154         # undo CB interleaving and puncturing
    155         num_fillers = self._tb_encoder.ldpc_encoder.n * self._tb_encoder.num_cbs - np.sum(self._tb_encoder.cw_lengths)
    156         llr_int = tf.concat([llr_scr,
    157                             tf.zeros([tf.shape(llr_scr)[0], self._tb_encoder.num_tx, num_fillers])], axis=-1)
    158         llr_int = tf.gather(llr_int, self._tb_encoder.output_perm_inv, axis=-1)
    159 
    160         # undo CB concatenation
    161         llr_cb = tf.reshape(llr_int,
    162                         (-1, self._tb_encoder.num_tx, self._num_cbs, self._tb_encoder.ldpc_encoder.n))
    163 
    164         # LDPC decoding
    165         u_hat_cb = self._decoder(llr_cb)
    166 
    167         # CB CRC removal (if relevant)
    168         if self._cb_crc_decoder is not None:
    169             # we are ignoring the CB CRC status for the moment
    170             # Could be combined with the TB CRC for even better estimates
    171             u_hat_cb_crc, _ = self._cb_crc_decoder(u_hat_cb)
    172         else:
    173             u_hat_cb_crc = u_hat_cb
    174 
    175         # undo CB segmentation
    176         u_hat_tb = tf.reshape(u_hat_cb_crc,
    177                 (-1, self._tb_encoder.num_tx, self.tb_size+self._tb_encoder.tb_crc_encoder.crc_length))
    178 
    179         # TB CRC removal
    180         u_hat, tb_crc_status = self._tb_crc_decoder(u_hat_tb)
    181 
    182         # restore input shape
    183         output_shape = input_shape
    184         output_shape[0] = -1
    185         output_shape[-1] = self.tb_size
    186         u_hat = tf.reshape(u_hat, output_shape)
    187         # also apply to tb_crc_status
    188         output_shape[-1] = 1 # but last dim is 1
    189         tb_crc_status = tf.reshape(tb_crc_status, output_shape)
    190 
    191         # remove if zero-padding was applied
    192         if self._tb_encoder.k_padding>0:
    193             u_hat = u_hat[...,:-self._tb_encoder.k_padding]
    194 
    195         # cast to output dtype
    196         u_hat = tf.cast(u_hat, self.dtype)
    197         tb_crc_status = tf.squeeze(tf.cast(tb_crc_status, tf.bool), axis=-1)
    198 
    199         return u_hat, tb_crc_status