#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layer for Turbo Code Encoding."""

import math
import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna.fec import interleaving
from sionna.fec.utils import bin2int_tf, int2bin_tf
from sionna.fec.conv.encoding import ConvEncoder
from sionna.fec.conv.utils import Trellis
from sionna.fec.turbo.utils import polynomial_selector, puncture_pattern, TurboTermination

class TurboEncoder(Layer):
    # pylint: disable=line-too-long
    r"""TurboEncoder(gen_poly=None, constraint_length=3, rate=1/3, terminate=False, interleaver_type='3GPP', output_dtype=tf.float32, **kwargs)

    Performs encoding of information bits to a Turbo code codeword [Berrou]_.
    Implements the standard Turbo code framework [Berrou]_: Two identical
    rate-1/2 convolutional encoders :class:`~sionna.fec.conv.encoding.ConvEncoder`
    are combined to produce a rate-1/3 Turbo code. Further,
    puncturing to attain a rate-1/2 Turbo code is supported.

    The class inherits from the Keras layer class and can be used as layer in a
    Keras model.

    Parameters
    ----------
    gen_poly: tuple
        Tuple of strings with each string being a 0,1 sequence. If
        `None`, ``constraint_length`` must be provided.

    constraint_length: int
        Valid values are between 3 and 6 inclusive. Only required if
        ``gen_poly`` is `None`.

    rate: float
        Valid values are 1/3 and 1/2. Note that ``rate`` here denotes
        the `design` rate of the Turbo code. If ``terminate`` is `True`, a
        small rate-loss occurs.

    terminate: boolean
        Underlying convolutional encoders are terminated to all zero state
        if `True`. If terminated, the true rate of the code is slightly lower
        than ``rate``.

    interleaver_type: str
        Valid values are `"3GPP"` or `"random"`. Determines the choice of
        the interleaver to interleave the message bits before input to the
        second convolutional encoder. If `"3GPP"`, the Turbo code interleaver
        from the 3GPP LTE standard [3GPPTS36212_Turbo]_ is used. If `"random"`,
        a random interleaver is used.

    output_dtype: tf.DType
        Defaults to `tf.float32`. Defines the output datatype of the layer.

    Input
    -----
    inputs : [...,k], tf.float32
        2+D tensor of information bits where `k` is the information length

    Output
    ------
    : `[...,k/rate]`, tf.float32
        2+D tensor where `rate` is provided as input
        parameter. The output is the encoded codeword for the input
        information tensor. When ``terminate`` is `True`, the effective rate
        of the Turbo code is slightly less than ``rate``.

    Note
    ----
    Various notations are used in literature to represent the generator
    polynomials for convolutional codes. For simplicity
    :class:`~sionna.fec.turbo.encoding.TurboEncoder` only
    accepts the binary format, i.e., `10011`, for the ``gen_poly`` argument
    which corresponds to the polynomial :math:`1 + D^3 + D^4`.

    Note that Turbo codes require the underlying convolutional encoders
    to be recursive systematic encoders. Only then the channel output
    from the systematic part of the first encoder can be used to decode
    the second encoder.

    Also note that ``constraint_length`` and ``memory`` are two different
    terms often used to denote the strength of the convolutional code. In
    this sub-package we use ``constraint_length``. For example, the polynomial
    `10011` has a ``constraint_length`` of 5, however its ``memory`` is
    only 4.

    When ``terminate`` is `True`, the true rate of the Turbo code is
    slightly lower than ``rate``. It can be computed as
    :math:`\frac{k}{\frac{k}{r}+\frac{4\mu}{3r}}` where `r` denotes
    ``rate`` and :math:`\mu` is the ``constraint_length`` - 1. For example, in
    3GPP, ``constraint_length`` = 4, ``terminate`` = `True`, for
    ``rate`` = 1/3, true rate is equal to :math:`\frac{k}{3k+12}` .
    """

    def __init__(self,
                 gen_poly=None,
                 constraint_length=3,
                 rate=1/3,
                 terminate=False,
                 interleaver_type='3GPP',
                 output_dtype=tf.float32,
                 **kwargs):

        super().__init__(**kwargs)

        if gen_poly is not None:
            assert all(isinstance(poly, str) for poly in gen_poly), \
                "Each element of gen_poly must be a string."
            assert all(len(poly)==len(gen_poly[0]) for poly in gen_poly), \
                "Each polynomial must be of same length."
            assert all(all(
                char in ['0','1'] for char in poly) for poly in gen_poly),\
                "Each Polynomial must be a string of 0/1 s."
            assert len(gen_poly)==2, \
                "Generator polynomials need to be of Rate-1/2 "
            self._gen_poly = gen_poly
        else:
            valid_constraint_length = (3, 4, 5, 6)
            assert constraint_length in valid_constraint_length, \
                "Constraint length must be between 3 and 6."
            self._gen_poly = polynomial_selector(constraint_length)

        valid_rates = (1/2, 1/3)
        assert rate in valid_rates, "Invalid coderate."
        assert isinstance(terminate, bool), "terminate must be bool."
        assert interleaver_type in ('3GPP', 'random'),\
            "Invalid interleaver_type."

        self._coderate_desired = rate
        self._coderate = self._coderate_desired
        self._terminate = terminate
        self._interleaver_type = interleaver_type
        self.output_dtype = output_dtype
        # Underlying convolutional encoders to be rsc or not
        rsc = True

        # Each constituent convolutional encoder has rate 1/len(gen_poly).
        self._coderate_conv = 1/len(self.gen_poly)
        self._punct_pattern = puncture_pattern(rate, self._coderate_conv)

        self._trellis = Trellis(self.gen_poly, rsc=rsc)
        self._mu = self.trellis._mu

        # conv_n denotes number of output bits for conv_k input bits.
        self._conv_k = self._trellis.conv_k
        self._conv_n = self._trellis.conv_n

        self._ni = 2**self._conv_k
        self._no = 2**self._conv_n
        self._ns = self._trellis.ns

        # k and n depend on the input shape; set during build().
        self._k = None
        self._n = None

        if self.terminate:
            self.turbo_term = TurboTermination(self._mu+1, conv_n=self._conv_n)

        if self._interleaver_type == '3GPP':
            self.internal_interleaver = interleaving.Turbo3GPPInterleaver()
        else:
            self.internal_interleaver = interleaving.RandomInterleaver(
                                                keep_batch_constant=True,
                                                keep_state=True,
                                                axis=-1)

        # Precompute puncturing indices so their shape is known at graph
        # creation time (required for XLA; see _puncture_cw).
        if self.punct_pattern is not None:
            self.punct_idx = tf.where(self.punct_pattern)

        self.convencoder = ConvEncoder(gen_poly=self._gen_poly,
                                       rsc=rsc,
                                       terminate=self._terminate)

    #########################################
    # Public methods and properties
    #########################################

    @property
    def gen_poly(self):
        """Generator polynomial used by the encoder"""
        return self._gen_poly

    @property
    def constraint_length(self):
        """Constraint length of the encoder"""
        return self._mu + 1

    @property
    def coderate(self):
        """Rate of the code used in the encoder"""
        if self.terminate and self._k is None:
            print("Note that, due to termination, the true coderate is lower "\
                  "than the returned design rate. "\
                  "The exact true rate is dependent on the value of k and "\
                  "hence cannot be computed before the first call().")
        elif self.terminate and self._k is not None:
            # True rate accounts for the extra termination bits:
            # r_true = r_design / (1 + ceil(4*mu/3)/k).
            term_factor = 1+math.ceil(4*self._mu/3)/self._k
            self._coderate = self._coderate_desired/term_factor
        return self._coderate

    @property
    def trellis(self):
        """Trellis object used during encoding"""
        return self._trellis

    @property
    def terminate(self):
        """Indicates if the convolutional encoders are terminated"""
        return self._terminate

    @property
    def punct_pattern(self):
        """Puncturing pattern for the Turbo codeword"""
        return self._punct_pattern

    @property
    def k(self):
        """Number of information bits per codeword"""
        if self._k is None:
            print("Note: The value of k cannot be computed before the first " \
                  "call().")
        return self._k

    @property
    def n(self):
        """Number of codeword bits"""
        if self._n is None:
            print("Note: The value of n cannot be computed before the first " \
                  "call().")
        return self._n

    def _conv_enc(self, info_vec, terminate):
        """
        This method encodes the information tensor info_vec using the
        underlying convolutional encoder. Returns the encoded codeword tensor
        array ta, and the tensor array containing termination bits ta_term.
        If the terminate variable is False, ta_term is array of length 0.
        """
        msg = tf.cast(info_vec, tf.int32)

        msg_reshaped = tf.reshape(msg, [-1, self._k])
        term_syms = int(self._mu) if terminate else 0

        # All encoders start in the all-zero state.
        prev_st = tf.zeros([tf.shape(msg_reshaped)[0]], tf.int32)
        ta = tf.TensorArray(tf.int32, size=self.num_syms, dynamic_size=False)

        idx_offset = range(0, self._conv_k)
        for idx in tf.range(0, self._k, self._conv_k):
            msg_bits_idx = tf.gather(msg_reshaped,
                                     idx + idx_offset,
                                     axis=-1)

            msg_idx = bin2int_tf(msg_bits_idx)

            # State transition: (prev_state, input symbol) -> next state.
            indices = tf.stack([prev_st, msg_idx], -1)
            new_st = tf.gather_nd(self._trellis.to_nodes, indices=indices)

            # Output symbol for the transition, expanded to conv_n bits.
            idx_syms = tf.gather_nd(self._trellis.op_mat,
                                    tf.stack([prev_st, new_st], -1))
            idx_bits = int2bin_tf(idx_syms, self._conv_n)
            ta = ta.write(idx//self._conv_k, idx_bits)
            prev_st = new_st

        ta_term = tf.TensorArray(tf.int32, size=term_syms, dynamic_size=False)
        # Termination
        if terminate:
            # Feed back the state through the feedback polynomial so that the
            # recursive encoder is driven back to the all-zero state.
            fb_poly = tf.constant([int(x) for x in self.gen_poly[0][1:]])
            fb_poly_tiled = tf.tile(
                tf.expand_dims(fb_poly,0),[tf.shape(prev_st)[0],1])
            for idx in tf.range(0, term_syms, self._conv_k):
                prev_st_bits = int2bin_tf(prev_st, self._mu)
                msg_idx = tf.math.reduce_sum(
                    tf.multiply(fb_poly_tiled, prev_st_bits),-1)
                msg_idx = tf.squeeze(int2bin_tf(msg_idx,1),-1)

                indices = tf.stack([prev_st, msg_idx], -1)
                new_st = tf.gather_nd(self._trellis.to_nodes, indices=indices)
                idx_syms = tf.gather_nd(self._trellis.op_mat,
                                        tf.stack([prev_st, new_st], -1))
                idx_bits = int2bin_tf(idx_syms, self._conv_n)
                ta_term = ta_term.write(idx//self._conv_k, idx_bits)
                prev_st = new_st

        return ta, ta_term

    def _puncture_cw(self, cw):
        """
        Given the codeword ``cw``, this method punctures ``cw`` using the
        puncturing pattern defined in self.punct_pattern. A simple tile
        operation of self.punct_pattern followed by tf.boolean_mask(cw, mask_)
        works. However this fails in XLA mode as the dimension of the above
        operation is unknown.

        Hence, idx is obtained from `tf.where(self.punct_pattern)` during
        initialization. This way the dimension of idx is known during graph
        creation. Then during the call(), idx is tiled followed by row offset
        addition to idx (the indices tensor) will achieve the same result as
        applying a tiled boolean_mask.
        """
        # cw shape: (bs, n, 3)- transpose to (n, 3, bs)
        cw = tf.transpose(cw, perm=[1, 2, 0])
        cw_n = cw.get_shape()[0]

        punct_period = self.punct_pattern.shape[0]
        mask_reps = cw_n//punct_period
        idx = tf.tile(self.punct_idx, [mask_reps, 1])

        idx_per_period = self.punct_idx.shape[0]
        idx_per_time = idx_per_period/punct_period

        # When tiling punct_pattern doesn't cover cw, delta_times > 0
        delta_times = cw_n - (mask_reps * punct_period)
        delta_idx_rows = int(delta_times*idx_per_time)

        # Row offsets shift each tiled copy of punct_idx by one pattern period.
        time_offset = punct_period * tf.range(mask_reps)[None,:]
        row_idx = tf.transpose(tf.tile(time_offset,[idx_per_period,1]))
        row_idx = tf.reshape(row_idx, (-1, 1))

        total_indices = mask_reps*idx_per_period + delta_idx_rows
        col_idx = tf.zeros((total_indices,1), tf.int32)

        if delta_times > 0:
            idx = tf.concat([idx, self.punct_idx[:delta_idx_rows]], axis=0)
            # Additional index row offsets if delta_times > 0
            time_n = punct_period*mask_reps
            row_idx_delta = tf.tile(
                tf.range(time_n, time_n+delta_times)[None, :],
                [delta_idx_rows, 1])
            row_idx = tf.concat([row_idx, row_idx_delta], axis=0)

        idx_offset = tf.cast(tf.concat([row_idx, col_idx], axis=1), tf.int64)
        idx = tf.add(idx, idx_offset)

        cw = tf.gather_nd(cw, idx)
        cw = tf.transpose(cw)
        return cw

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """Build layer and check dimensions.

        Args:
            input_shape: shape of input tensor (...,k).
        """
        self._k = input_shape[-1]
        self._n = int(self._k/self._coderate_desired)
        if self._interleaver_type == '3GPP':
            assert self._k <= 6144, '3GPP Turbo Codes define Interleavers only\
            upto frame lengths of 6144'

        # Num. of encoding periods/state transitions.
        # Not equal to _k if _conv_k > 1.
        self.num_syms = int(self._k//self._conv_k)

    def call(self, inputs):
        """Turbo code encoding function.

        Args:
            inputs (tf.float32): Information tensor of shape `[...,k]`.

        Returns:
            `tf.float32`: Encoded codeword tensor of shape `[...,n]`.
        """
        tf.debugging.assert_greater(tf.rank(inputs), 1)

        # Re-build if the information length changed since the last call.
        if inputs.shape[-1] != self._k:
            self.build(inputs.shape)

        if self._terminate:
            num_term_bits_ = int(
                self.turbo_term.get_num_term_syms()/self._coderate_conv)
            num_term_bits_punct = int(
                num_term_bits_*self._coderate_conv/self._coderate_desired)
        else:
            num_term_bits_ = 0
            num_term_bits_punct = 0

        output_shape = inputs.get_shape().as_list()
        output_shape[0] = -1
        output_shape[-1] = self._n + num_term_bits_punct

        preterm_n = int(self._k/self._coderate_conv)
        msg = tf.cast(tf.reshape(inputs, [-1, self._k]), tf.int32)
        msg2 = self.internal_interleaver(msg)

        # Encode the original and the interleaved message streams.
        cw1_ = self.convencoder(msg)
        cw2_ = self.convencoder(msg2)

        # Split off the termination bits appended by each encoder.
        cw1, term1 = cw1_[:, :preterm_n], cw1_[:, preterm_n:]
        cw2, term2 = cw2_[:, :preterm_n], cw2_[:, preterm_n:]

        # Gather parity stream from 2nd enc
        par_idx = tf.range(1, preterm_n, delta=self._conv_n)
        cw2_par = tf.gather(cw2, indices=par_idx, axis=-1)

        cw1 = tf.reshape(cw1, (-1, self._k, self._conv_n))
        cw2_par = tf.reshape(cw2_par, (-1, self._k, 1))

        # Concatenate 2nd enc parity to _conv_n streams from first encoder
        cw = tf.concat([cw1, cw2_par], axis=-1)

        if self.terminate:
            term_syms_turbo = self.turbo_term.termbits_conv2turbo(term1, term2)
            term_syms_turbo = tf.reshape(
                term_syms_turbo, (-1, num_term_bits_//2, 3))
            cw = tf.concat([cw, term_syms_turbo], axis=-2)

        if self.punct_pattern is not None:
            cw = self._puncture_cw(cw)

        cw = tf.cast(cw, self.output_dtype)
        cw_reshaped = tf.reshape(cw, output_shape)
        return cw_reshaped