decoding.py
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layer for Turbo Decoding."""

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna.fec import interleaving
from sionna.fec.conv.decoding import BCJRDecoder
from sionna.fec.conv.utils import Trellis
from sionna.fec.turbo.utils import TurboTermination, polynomial_selector, puncture_pattern

class TurboDecoder(Layer):
    # pylint: disable=line-too-long
    r"""TurboDecoder(encoder=None, gen_poly=None, rate=1/3, constraint_length=None, interleaver='3GPP', terminate=False, num_iter=6, hard_out=True, algorithm='map', output_dtype=tf.float32, **kwargs)

    Turbo code decoder based on BCJR component decoders [Berrou]_.
    Takes LLRs as input and returns LLRs or hard-decided bits, i.e., an
    estimate of the information tensor.

    This decoder is based on the :class:`~sionna.fec.conv.decoding.BCJRDecoder`
    and, thus, internally instantiates two
    :class:`~sionna.fec.conv.decoding.BCJRDecoder` layers.

    The class inherits from the Keras layer class and can be used as a layer
    in a Keras model.

    Parameters
    ----------
    encoder: :class:`~sionna.fec.turbo.encoding.TurboEncoder`
        If ``encoder`` is provided as input, the following input parameters
        are not required and will be ignored: ``gen_poly``, ``rate``,
        ``constraint_length``, ``terminate``, ``interleaver``. They will be
        inferred from the ``encoder`` object itself.
        If ``encoder`` is `None`, the above parameters must be provided
        explicitly.

    gen_poly: tuple
        Tuple of strings with each string being a 0,1 sequence. If `None`,
        ``rate`` and ``constraint_length`` must be provided.

    rate: float
        Rate of the Turbo code. Valid values are 1/3 and 1/2. Note that
        ``gen_poly``, if provided, is used to encode the underlying
        convolutional code, which traditionally has rate 1/2.

    constraint_length: int
        Valid values are between 3 and 6 inclusive. Only required if
        ``encoder`` and ``gen_poly`` are `None`.

    interleaver: str
        `"3GPP"` or `"random"`. If `"3GPP"`, the internal interleaver for
        Turbo codes as specified in [3GPPTS36212_Turbo]_ will be used. Only
        required if ``encoder`` is `None`.

    terminate: bool
        If `True`, the two underlying convolutional encoders are assumed to
        have been terminated to the all-zero state.

    num_iter: int
        Number of Turbo decoding iterations to run. Each iteration runs one
        BCJR decoder for each of the underlying convolutional code
        components.

    hard_out: bool
        Defaults to `True` and indicates whether hard or soft decisions on
        the decoded information vector are output. `True` implies that a
        hard-decided information vector of 0/1's is output; `False` implies
        that the decoded LLRs of the information bits are output.

    algorithm: str
        Defaults to `"map"`. Indicates the implemented BCJR algorithm, where
        `"map"` denotes the exact MAP algorithm, `"log"` indicates the exact
        MAP implementation, but in log-domain, and `"maxlog"` indicates the
        approximated MAP implementation in log-domain, where
        :math:`\log(e^{a}+e^{b}) \sim \max(a,b)`.

    output_dtype: tf.DType
        Defaults to `tf.float32`. Defines the output datatype of the layer.

    Input
    -----
    inputs: tf.float32
        2+D tensor of shape `[...,n]` containing the (noisy) channel output
        symbols, where `n` is the codeword length.

    Output
    ------
    : tf.float32
        2+D tensor of shape `[...,coderate*n]` containing the estimates of
        the information bit tensor.

    Note
    ----
    For decoding, input `logits` defined as
    :math:`\operatorname{log} \frac{p(x=1)}{p(x=0)}` are assumed for
    compatibility with the rest of Sionna. Internally,
    log-likelihood ratios (LLRs) with definition
    :math:`\operatorname{log} \frac{p(x=0)}{p(x=1)}` are used.
    """
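
    # Illustrative construction sketch (not executed here; parameter values
    # and the ``encoder``, ``llr`` and ``u_hat`` names are assumptions chosen
    # for the example):
    #
    #   # from an existing TurboEncoder, inheriting its configuration
    #   decoder = TurboDecoder(encoder=encoder, num_iter=6, hard_out=True)
    #
    #   # or from explicit parameters, without an encoder object
    #   decoder = TurboDecoder(constraint_length=4, rate=1/3,
    #                          interleaver="3GPP", terminate=True,
    #                          num_iter=6, hard_out=True)
    #
    #   # llr has shape [batch_size, n]; u_hat has shape [batch_size, k]
    #   u_hat = decoder(llr)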

    def __init__(self,
                 encoder=None,
                 gen_poly=None,
                 rate=1/3,
                 constraint_length=None,
                 interleaver='3GPP',
                 terminate=False,
                 num_iter=6,
                 hard_out=True,
                 algorithm='map',
                 output_dtype=tf.float32,
                 **kwargs):

        super().__init__(**kwargs)
        if encoder is not None:
            self._coderate = encoder._coderate
            self._gen_poly = encoder._gen_poly
            self._terminate = encoder.terminate
            self._trellis = encoder.trellis
            assert self._trellis.rsc is True
            self.rsc = True
            self.internal_interleaver = encoder.internal_interleaver
        else:
            if gen_poly is not None:
                assert all(isinstance(poly, str) for poly in gen_poly), \
                    "Each polynomial must be a string."
                assert all(len(poly)==len(gen_poly[0]) for poly in gen_poly), \
                    "Each polynomial must be of same length."
                assert all(all(
                    char in ['0','1'] for char in poly) for poly in gen_poly),\
                    "Each polynomial must be a string of 0's and 1's."
                self._gen_poly = gen_poly
            else:
                valid_constraint_length = (3, 4, 5, 6)
                assert constraint_length in valid_constraint_length, \
                    "Constraint length must be between 3 and 6."
                self._gen_poly = polynomial_selector(constraint_length)

            valid_rates = (1/2, 1/3)
            assert rate in valid_rates
            self._coderate = rate

            tf.debugging.assert_type(terminate, tf.bool)
            self._terminate = terminate

            assert interleaver in ('3GPP', 'random')
            if interleaver == '3GPP':
                self.internal_interleaver = interleaving.Turbo3GPPInterleaver()
            else:
                self.internal_interleaver = interleaving.RandomInterleaver(
                                                    keep_batch_constant=True,
                                                    keep_state=True,
                                                    axis=-1)

            self.rsc = True
            self._trellis = Trellis(self._gen_poly, rsc=self.rsc)

        assert isinstance(hard_out, bool), 'hard_out must be bool.'

        self._coderate_conv = 1/len(self._gen_poly)
        self._mu = len(self._gen_poly[0])-1
        self.punct_pattern = puncture_pattern(self._coderate,
                                              self._coderate_conv)

        # num. of input bit streams, only 1 in current implementation
        self._conv_k = self._trellis.conv_k
        self._mu = self._trellis._mu
        # num. of output bits for conv_k input bits
        self._conv_n = self._trellis.conv_n
        self._ni = 2**self._conv_k
        self._no = 2**self._conv_n
        self._ns = self._trellis.ns

        assert self._conv_k == 1
        assert self._conv_n == 2

        self._k = None # Length of info-bit vector
        self._n = None # Length of Turbo codeword, including termination bits

        if self._terminate:
            self.turbo_term = TurboTermination(self._mu+1,
                                               conv_n=self._conv_n)
            self._num_term_bits = 3 * self.turbo_term.get_num_term_syms()
        else:
            self._num_term_bits = 0

        self._output_dtype = output_dtype
        self.num_iter = num_iter
        self._hard_out = hard_out

        self.bcjrdecoder = BCJRDecoder(gen_poly=self._gen_poly,
                                       rsc=self.rsc,
                                       hard_out=False,
                                       terminate=self._terminate,
                                       algorithm=algorithm)

    #########################################
    # Public methods and properties
    #########################################

    @property
    def gen_poly(self):
        """Generator polynomial used by the encoder"""
        return self._gen_poly

    @property
    def constraint_length(self):
        """Constraint length of the encoder"""
        return self._mu + 1

    @property
    def coderate(self):
        """Rate of the code used in the encoder"""
        return self._coderate

    @property
    def trellis(self):
        """Trellis object used during encoding"""
        return self._trellis

    @property
    def k(self):
        """Number of information bits per codeword"""
        if self._k is None:
            print("Note: The value of k cannot be computed before the first "
                  "call().")
        return self._k

    @property
    def n(self):
        """Number of codeword bits"""
        if self._n is None:
            print("Note: The value of n cannot be computed before the first "
                  "call().")
        return self._n

    #########################
    # Utility functions
    #########################

    def depuncture(self, y):
        """
        Given a tensor ``y`` of shape `[batch, n]`, depuncture() scatters the
        elements of ``y`` into shape `[batch, 3*rate*n]`, where the extra
        elements are filled with 0.

        For example, if the input is ``y``, the rate is 1/2 and
        ``punct_pattern`` is [1, 1, 0, 1, 0, 1], then the output is
        [y[0], y[1], 0., y[2], 0., y[3], y[4], y[5], 0., ...].
        """

        y_depunct = tf.scatter_nd(self._punct_indices,
                                  tf.transpose(y),
                                  shape=(self._depunct_len, tf.shape(y)[0]))
        y_depunct = tf.transpose(y_depunct)
        return y_depunct
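
    # Illustrative sketch of the depuncturing performed above (standalone,
    # with hypothetical values; in the layer, the mask is derived from
    # ``puncture_pattern`` and the indices are precomputed in build()):
    #
    #   mask    = tf.constant([True, True, False, True, False, True,
    #                          True, True, False, True, False, True])
    #   indices = tf.cast(tf.where(mask), tf.int32)  # [[0], [1], [3], [5], ...]
    #   y       = tf.constant([[1., 2., 3., 4., 5., 6., 7., 8.]])
    #   y_dep   = tf.transpose(tf.scatter_nd(indices, tf.transpose(y),
    #                                        shape=(12, 1)))
    #   # y_dep == [[1., 2., 0., 3., 0., 4., 5., 6., 0., 7., 0., 8.]]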

    def _convenc_cws(self, y_turbo):
        """
        Re-arranges a Turbo codeword into the format of the two constituent
        convolutional codewords.
        Given the channel output of a Turbo codeword ``y_turbo``, this method
        re-arranges ``y_turbo`` such that ``y1_cw`` contains the symbols
        corresponding to convolutional encoder 1 and, similarly, ``y2_cw``
        contains the symbols corresponding to convolutional encoder 2.
        """
        y_turbo = self.depuncture(y_turbo)
        prepunct_n = int(self._n * 3 * self._coderate)

        # Separate pre-termination & termination parts of y
        msg_idx = tf.range(0, prepunct_n - self._num_term_bits)
        term_idx = tf.range(prepunct_n-self._num_term_bits, prepunct_n)

        # Pre-termination & termination parts of y
        y_cw = tf.gather(y_turbo, msg_idx, axis=-1)
        y_term = tf.gather(y_turbo, term_idx, axis=-1)

        # Gather symbols corresponding to encoder 1 from y (pre-termination part)
        enc1_sys_idx = tf.expand_dims(tf.range(0, self._k*3, delta=3), 1)
        enc1_cw_idx = tf.stack([enc1_sys_idx, enc1_sys_idx+1], axis=1)
        enc1_cw_idx = tf.squeeze(tf.reshape(enc1_cw_idx, (-1, 2*self._k)))
        y1_cw = tf.gather(y_cw, enc1_cw_idx, axis=-1)

        # Gather systematic part of the encoder 1 codeword & interleave it
        y1_sys_cw = tf.gather(y_cw, enc1_sys_idx, axis=-1)
        y2_sys_cw = self.internal_interleaver(
            tf.squeeze(y1_sys_cw, -1))[:,:,None]

        # Using the above, gather symbols corresponding to encoder 2 from y
        # (pre-termination part)
        y2_nonsys_cw = tf.gather(y_cw, enc1_sys_idx+2, axis=-1)
        y2_cw = tf.squeeze(tf.stack([y2_sys_cw, y2_nonsys_cw], axis=-2))
        y2_cw = tf.reshape(y2_cw, [-1, 2*self._k])

        # Separate termination bits for encoders 1 & 2
        if self._terminate:
            term_vec1, term_vec2 = self.turbo_term.term_bits_turbo2conv(y_term)
            y1_cw = tf.concat([y1_cw, term_vec1], axis=1)
            y2_cw = tf.concat([y2_cw, term_vec2], axis=-1)
        return y1_cw, y2_cw

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """Build layer and check dimensions."""
        # rank must be at least two
        tf.debugging.assert_greater_equal(len(input_shape), 2)

        self._n = input_shape[-1]
        if self.coderate == 1/2:
            assert self._n%2 == 0, "Codeword length must be a multiple of 2."

        codefactor = self.coderate * 3
        turbo_n = int(self._n * codefactor)
        turbo_n_preterm = turbo_n - self._num_term_bits
        assert turbo_n_preterm%3 == 0, \
            "Invalid codeword length for a terminated Turbo code."

        self._k = int(turbo_n_preterm/3)

        # Number of symbols for the convolutional codes
        self._convenc_numsyms = self._k
        if self._terminate:
            self._convenc_numsyms += self._mu

        # Generate puncturing mask
        rate_factor = 3. * self._coderate

        self._depunct_len = int(rate_factor * self._n)
        punct_size = np.prod(self.punct_pattern.get_shape().as_list())
        rep_times = int(self._depunct_len//punct_size)

        mask_ = tf.tile(self.punct_pattern, [rep_times, 1])
        extra_bits = int(self._depunct_len - rep_times*punct_size)
        if extra_bits > 0:
            extra_periods = int(extra_bits/3)
            mask_ = tf.concat([mask_, self.punct_pattern[:extra_periods,:]],
                              axis=0)

        mask_ = tf.squeeze(tf.reshape(mask_, (-1,)))
        self._punct_indices = tf.cast(tf.where(mask_), tf.int32)

    def call(self, inputs):
        """
        Decoder for the Turbo code.

        Runs the BCJR decoder on both constituent convolutional codes
        iteratively for `num_iter` iterations. At the end, the resulting
        LLRs are computed and the decoded message vector (termination bits
        are excluded) is output.
        """

        llr_max = 20.
        tf.debugging.assert_type(inputs, tf.float32,
                                 message="Input must be tf.float32.")

        output_shape = inputs.get_shape().as_list()

        # allow different codeword lengths in eager mode
        if output_shape[-1] != self._n:
            self.build(output_shape)

        llr_ch = tf.reshape(inputs, [-1, self._n])

        output_shape[0] = -1
        output_shape[-1] = self._k # assign k to the last dimension

        # LLRs inside TurboDecoder are not sign-inverted after input, unlike
        # in the BCJR & LDPC decoders; they represent log(P(x=1)/P(x=0)), as
        # is the convention in Sionna.
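        # Schematic view of one decoding iteration below (notation only):
        # with L_post the a-posteriori LLRs of a BCJR pass, L_ch the
        # systematic channel LLRs and L_a the a-priori LLRs received from
        # the other component decoder, the extrinsic information is
        #   L_e1 = L_post1 - L_ch1 - L_a1  -> interleaved   -> a-priori of decoder 2
        #   L_e2 = L_post2 - L_ch2 - L_a2  -> deinterleaved -> a-priori of decoder 1
        # and the extrinsic values are clipped to [-llr_max, llr_max] for
        # numerical stability.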
        y1_cw, y2_cw = self._convenc_cws(llr_ch)

        sys_idx = tf.expand_dims(tf.range(0, self._k*2, delta=2), 1)
        llr_ch = tf.gather(y1_cw, sys_idx, axis=-1)
        llr_ch = tf.squeeze(llr_ch, -1)
        llr_ch2 = tf.gather(y2_cw, sys_idx, axis=-1)
        llr_ch2 = tf.squeeze(llr_ch2, -1)

        llr_1e = tf.zeros((tf.shape(llr_ch)[0], self._convenc_numsyms),
                          dtype=tf.float32)
        # define zero LLRs for the termination info bits
        term_info_bits = self._mu if self._terminate else 0
        llr_terminfo = tf.zeros(
            (tf.shape(llr_ch)[0], term_info_bits), tf.float32)

        # needs to be initialized for XLA before entering the loop
        llr_2i = tf.zeros_like(llr_ch2)

        # run decoding loop
        for _ in tf.range(self.num_iter):

            # run 1st component decoder
            llr_1i = self.bcjrdecoder((y1_cw, llr_1e))
            llr_1i = llr_1i[...,:self._k]
            llr_extr = llr_1i - llr_ch - llr_1e[...,:self._k]
            #llr_extr = llr_1i - llr_1e[...,:self._k]

            llr_2e = self.internal_interleaver(llr_extr)
            llr_2e = tf.concat([llr_2e, llr_terminfo], axis=-1)
            llr_2e = tf.clip_by_value(llr_2e,
                                      clip_value_min=-llr_max,
                                      clip_value_max=llr_max)
            # run 2nd component decoder
            llr_2i = self.bcjrdecoder((y2_cw, llr_2e))
            llr_2i = llr_2i[...,:self._k]
            llr_extr = llr_2i - llr_2e[...,:self._k] - llr_ch2
            #llr_extr = llr_2i - llr_2e[...,:self._k]

            llr_1e = self.internal_interleaver.call_inverse(llr_extr)

            llr_1e = tf.clip_by_value(llr_1e,
                                      clip_value_min=-llr_max,
                                      clip_value_max=llr_max)

            llr_1e = tf.concat([llr_1e, llr_terminfo], axis=-1)

        # use latest output of 2nd decoder
        output = self.internal_interleaver.call_inverse(llr_2i)

        if self._hard_out: # hard-decide decoder output if required
            output = tf.less(0.0, output)
            output = tf.cast(output, self._output_dtype)

        output_reshaped = tf.reshape(output, output_shape)
        return output_reshaped
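

# Minimal end-to-end sketch (illustrative only, not part of the library):
# encode random bits with a TurboEncoder, treat the noiseless BPSK symbols
# as logits, and decode them. The import path, parameter values and shapes
# below are assumptions chosen for the example.
if __name__ == "__main__":
    from sionna.fec.turbo import TurboEncoder

    encoder = TurboEncoder(constraint_length=4, rate=1/3, terminate=True)
    decoder = TurboDecoder(encoder=encoder, num_iter=6, hard_out=True)

    u = tf.cast(tf.random.uniform((10, 64), 0, 2, tf.int32), tf.float32)
    c = encoder(u)           # codewords, shape [10, n]
    llr = 2.*c - 1.          # noiseless "logits": positive values mean bit 1
    u_hat = decoder(llr)     # hard decisions, shape [10, 64]
    print("Bit errors:", int(tf.reduce_sum(tf.abs(u_hat - u))))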