tb_decoder.py (7229B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Transport block decoding functions for the 5G NR sub-package of Sionna.
"""

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna.fec.crc import CRCDecoder
from sionna.fec.scrambling import Descrambler
from sionna.fec.ldpc import LDPC5GDecoder
from sionna.nr import TBEncoder

class TBDecoder(Layer):
    # pylint: disable=line-too-long
    r"""TBDecoder(encoder, num_bp_iter=20, cn_type="boxplus-phi", output_dtype=tf.float32, **kwargs)
    5G NR transport block (TB) decoder as defined in TS 38.214
    [3GPP38214]_.

    The transport block decoder takes as input a sequence of noisy channel
    observations and reconstructs the corresponding `transport block` of
    information bits. The detailed procedure is described in TS 38.214
    [3GPP38214]_ and TS 38.211 [3GPP38211]_.

    The class inherits from the Keras layer class and can be used as layer in a
    Keras model.

    Parameters
    ----------
    encoder : :class:`~sionna.nr.TBEncoder`
        Associated transport block encoder used for encoding of the signal.
        Its scrambler, CRC encoders, LDPC encoder and interleaver
        permutation are reused to invert the encoding chain.

    num_bp_iter : int, 20 (default)
        Number of BP decoder iterations

    cn_type : str, "boxplus-phi" (default) | "boxplus" | "minsum"
        The check node processing function of the LDPC BP decoder.
        One of {`"boxplus"`, `"boxplus-phi"`, `"minsum"`} where
        `"boxplus"` implements the single-parity-check APP decoding rule.
        `"boxplus-phi"` implements the numerically more stable version of
        boxplus [Ryan]_.
        `"minsum"` implements the min-approximation of the CN update rule
        [Ryan]_.

    output_dtype : tf.float32 (default) | tf.float16 | tf.float64
        Defines the output dtype of ``b_hat``; it is also forwarded to the
        internal LDPC decoder. Note that the channel LLRs are cast to
        `tf.float32` before descrambling and decoding.

    Input
    -----
    inputs : [...,num_coded_bits], tf.float
        2+D tensor containing channel logits/llr values of the (noisy)
        channel observations.

    Output
    ------
    b_hat : [...,target_tb_size], tf.float
        2+D tensor containing hard decided bit estimates of all information
        bits of the transport block, i.e., ``tb_size`` bits with the
        encoder's ``k_padding`` zero-padding bits removed.

    tb_crc_status : [...], tf.bool
        Transport block CRC status indicating if a transport block was
        (most likely) correctly recovered. Note that false positives are
        possible.
    """

    def __init__(self,
                 encoder,
                 num_bp_iter=20,
                 cn_type="boxplus-phi",
                 output_dtype=tf.float32,
                 **kwargs):

        # validate before handing the dtype to the Keras base class
        assert output_dtype in (tf.float16, tf.float32, tf.float64), \
                "output_dtype must be (tf.float16, tf.float32, tf.float64)."

        super().__init__(dtype=output_dtype, **kwargs)

        assert isinstance(encoder, TBEncoder), "encoder must be TBEncoder."
        self._tb_encoder = encoder

        self._num_cbs = encoder.num_cbs

        # Number of zero-LLR filler positions that are appended before the
        # inverse output permutation. It only depends on the encoder
        # configuration, so it is computed once here instead of per call.
        self._num_fillers = int(encoder.ldpc_encoder.n * self._num_cbs
                                - np.sum(encoder.cw_lengths))

        # init BP decoder
        self._decoder = LDPC5GDecoder(encoder=encoder.ldpc_encoder,
                                      num_iter=num_bp_iter,
                                      cn_type=cn_type,
                                      hard_out=True, # TB operates on bit-level
                                      return_infobits=True,
                                      output_dtype=output_dtype)

        # init descrambler (only if the encoder applies scrambling)
        if encoder.scrambler is not None:
            self._descrambler = Descrambler(encoder.scrambler,
                                            binary=False)
        else:
            self._descrambler = None

        # init CRC Decoder for CB and TB
        self._tb_crc_decoder = CRCDecoder(encoder.tb_crc_encoder)

        if encoder.cb_crc_encoder is not None:
            self._cb_crc_decoder = CRCDecoder(encoder.cb_crc_encoder)
        else:
            self._cb_crc_decoder = None

    #########################################
    # Public methods and properties
    #########################################

    @property
    def tb_size(self):
        """Number of information bits per TB (including any zero-padding
        bits, which are removed from ``b_hat``)."""
        return self._tb_encoder.tb_size

    # NOTE(review): the original comment here was truncated ("required
    # for"); presumably kept for API parity with other FEC decoders that
    # expose `k` -- confirm.
    @property
    def k(self):
        """Number of input information bits. Equals TB size.

        Note: if the encoder applies zero-padding (``k_padding>0``), the
        last dimension of ``b_hat`` is ``k - k_padding``.
        """
        return self._tb_encoder.tb_size

    @property
    def n(self):
        """Total number of output codeword bits."""
        return self._tb_encoder.n

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shapes):
        """Test input shapes for consistency."""

        assert input_shapes[-1]==self.n, \
            f"Invalid input shape. Expected input length is {self.n}."

    def call(self, inputs):
        """Apply transport block decoding.

        Inverts the encoding chain step by step: descrambling, CB
        de-interleaving / de-puncturing, LDPC decoding, CB CRC removal,
        CB de-segmentation, TB CRC removal and removal of zero-padding.

        Args:
            inputs: Channel LLRs with last dimension ``self.n``.

        Returns:
            Tuple ``(b_hat, tb_crc_status)``; ``b_hat`` has the leading
            dimensions of ``inputs`` and last dimension
            ``tb_size - k_padding``, ``tb_crc_status`` has the leading
            dimensions of ``inputs``.
        """
        # Leading dimensions of the input, taken from the *dynamic* shape so
        # that statically unknown non-batch dimensions (e.g., inside a
        # tf.function with a partially unknown input signature) are
        # supported. A static shape list containing `None` cannot be passed
        # to tf.reshape.
        batch_shape = tf.shape(inputs)[:-1]

        num_tx = self._tb_encoder.num_tx
        n_ldpc = self._tb_encoder.ldpc_encoder.n

        llr_ch = tf.cast(inputs, tf.float32)
        llr_ch = tf.reshape(llr_ch, (-1, num_tx, self._tb_encoder.n))

        # undo scrambling (only if scrambler was used)
        if self._descrambler is not None:
            llr_scr = self._descrambler(llr_ch)
        else:
            llr_scr = llr_ch

        # undo CB interleaving and puncturing: append zero LLRs for the
        # filler positions, then apply the inverse output permutation
        fillers = tf.zeros([tf.shape(llr_scr)[0], num_tx, self._num_fillers],
                           dtype=llr_scr.dtype)
        llr_int = tf.concat([llr_scr, fillers], axis=-1)
        llr_int = tf.gather(llr_int, self._tb_encoder.output_perm_inv, axis=-1)

        # undo CB concatenation
        llr_cb = tf.reshape(llr_int, (-1, num_tx, self._num_cbs, n_ldpc))

        # LDPC decoding (hard decisions on the information bits)
        u_hat_cb = self._decoder(llr_cb)

        # CB CRC removal (if relevant)
        if self._cb_crc_decoder is not None:
            # we are ignoring the CB CRC status for the moment
            # Could be combined with the TB CRC for even better estimates
            u_hat_cb_crc, _ = self._cb_crc_decoder(u_hat_cb)
        else:
            u_hat_cb_crc = u_hat_cb

        # undo CB segmentation
        tb_len_with_crc = (self.tb_size
                           + self._tb_encoder.tb_crc_encoder.crc_length)
        u_hat_tb = tf.reshape(u_hat_cb_crc, (-1, num_tx, tb_len_with_crc))

        # TB CRC removal
        u_hat, tb_crc_status = self._tb_crc_decoder(u_hat_tb)

        # restore the leading dimensions of the input
        u_hat_shape = tf.concat(
            [batch_shape, tf.constant([self.tb_size], dtype=batch_shape.dtype)],
            axis=0)
        u_hat = tf.reshape(u_hat, u_hat_shape)
        # one CRC status per TB -> drop the trailing dimension entirely
        tb_crc_status = tf.reshape(tb_crc_status, batch_shape)

        # remove if zero-padding was applied
        if self._tb_encoder.k_padding>0:
            u_hat = u_hat[...,:-self._tb_encoder.k_padding]

        # cast to output dtype
        u_hat = tf.cast(u_hat, self.dtype)
        tb_crc_status = tf.cast(tb_crc_status, tf.bool)

        return u_hat, tb_crc_status