encoding.py (11361B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layer for Convolutional Code Encoding."""

import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna.fec.utils import bin2int_tf, int2bin_tf
from sionna.fec.conv.utils import polynomial_selector, Trellis

class ConvEncoder(Layer):
    # pylint: disable=line-too-long
    r"""ConvEncoder(gen_poly=None, rate=1/2, constraint_length=3, rsc=False, terminate=False, output_dtype=tf.float32, **kwargs)

    Encodes a binary information tensor into a convolutional codeword. Currently,
    only generator polynomials for codes of rate=1/n for n=2,3,4,... are allowed.

    The class inherits from the Keras layer class and can be used as a layer in
    a Keras model.

    Parameters
    ----------
    gen_poly: tuple
        Sequence of strings with each string being a 0,1 sequence. If
        `None`, ``rate`` and ``constraint_length`` must be provided.

    rate: float
        Valid values are 1/2 and 1/3. Only required if ``gen_poly`` is
        `None`.

    constraint_length: int
        Valid values are between 3 and 8 inclusive. Only required if
        ``gen_poly`` is `None`.

    rsc: boolean
        Boolean flag indicating whether the generated Trellis is recursive
        systematic or not. If `True`, the encoder is recursive-systematic.
        In this case, the first polynomial in ``gen_poly`` is used as the
        feedback polynomial. Defaults to `False`.

    terminate: boolean
        Encoder is terminated to the all-zero state if `True`.
        If terminated, the `true` rate of the code is slightly lower than
        ``rate``.

    output_dtype: tf.DType
        Defaults to `tf.float32`. Defines the output datatype of the layer.

    Input
    -----
    inputs : [...,k], tf.float32
        2+D tensor containing the information bits where `k` is the
        information length.

    Output
    ------
    : [...,k/rate], tf.float32
        2+D tensor containing the encoded codeword for the given input
        information tensor where `rate` is
        :math:`\frac{1}{\textrm{len}\left(\textrm{gen_poly}\right)}`
        (if ``gen_poly`` is provided).

    Note
    ----
    The generator polynomials from [Moon]_ are available for various
    rates and constraint lengths. To select them, use the ``rate`` and
    ``constraint_length`` arguments.

    In addition, polynomials for any non-recursive convolutional encoder
    can be given as input via the ``gen_poly`` argument. Currently, only
    polynomials with rate=1/n are supported. When the ``gen_poly`` argument
    is given, the ``rate`` and ``constraint_length`` arguments are ignored.

    Various notations are used in the literature to represent the generator
    polynomials for convolutional codes. In [Moon]_, the octal digit
    format is primarily used. In the octal format, the generator polynomial
    `10011` corresponds to 46. Another widely used format is decimal
    notation with the MSB first, in which the polynomial `10011`
    corresponds to 19. For simplicity, the
    :class:`~sionna.fec.conv.encoding.ConvEncoder` only accepts the bit
    format, i.e., `10011`, as the ``gen_poly`` argument.

    Also note that ``constraint_length`` and ``memory`` are two different
    terms often used to denote the strength of a convolutional code. In this
    sub-package, we use ``constraint_length``. For example, the
    polynomial `10011` has a ``constraint_length`` of 5, whereas its
    ``memory`` is only 4.

    When ``terminate`` is `True`, the true rate of the convolutional
    code is slightly lower than ``rate``. It equals
    :math:`\frac{r*k}{k+\mu}` where `r` denotes ``rate`` and
    :math:`\mu` is ``constraint_length`` - 1. For example, when
    ``terminate`` is `True`, ``k=100``,
    :math:`\mu=4` and ``rate=0.5``, the true rate equals
    :math:`\frac{0.5*100}{104}=0.481`.
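
    As a usage sketch, the snippet below encodes a batch of all-zero messages
    with the default rate-1/2, constraint-length-3 code. The batch size and
    message length are arbitrary illustrative choices:

    .. code-block:: python

        encoder = ConvEncoder(rate=1/2, constraint_length=3)
        # Generator polynomials can also be passed explicitly in bit format,
        # e.g., gen_poly=['101', '111'] for the octal (5,7) code.
        u = tf.zeros([10, 100])  # 10 messages of k=100 information bits each
        c = encoder(u)           # encoded codewords of shape [10, 200]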
    """

    def __init__(self,
                 gen_poly=None,
                 rate=1/2,
                 constraint_length=3,
                 rsc=False,
                 terminate=False,
                 output_dtype=tf.float32,
                 **kwargs):

        super().__init__(**kwargs)

        if gen_poly is not None:
            assert all(isinstance(poly, str) for poly in gen_poly), \
                "Each element of gen_poly must be a string."
            assert all(len(poly)==len(gen_poly[0]) for poly in gen_poly), \
                "Each polynomial must be of same length."
            assert all(all(
                char in ['0','1'] for char in poly) for poly in gen_poly),\
                "Each Polynomial must be a string of 0/1 s."
            self._gen_poly = gen_poly
        else:
            valid_rates = (1/2, 1/3)
            valid_constraint_length = (3, 4, 5, 6, 7, 8)

            assert constraint_length in valid_constraint_length, \
                "Constraint length must be between 3 and 8."
            assert rate in valid_rates, \
                "Rate must be 1/3 or 1/2."
            self._gen_poly = polynomial_selector(rate, constraint_length)

        self._rsc = rsc
        self._terminate = terminate

        self._coderate_desired = 1/len(self.gen_poly)
        # Differ when terminate is True
        self._coderate = self._coderate_desired

        self._trellis = Trellis(self.gen_poly, rsc=self._rsc)
        self._mu = self.trellis._mu

        # conv_k denotes number of input bit streams.
        # Only 1 allowed in current implementation
        self._conv_k = self._trellis.conv_k

        # conv_n denotes number of output bits for conv_k input bits
        self._conv_n = self._trellis.conv_n

        self._ni = 2**self._conv_k
        self._no = 2**self._conv_n
        self._ns = self._trellis.ns

        self._k = None
        self._n = None
        self.output_dtype = output_dtype

    #########################################
    # Public methods and properties
    #########################################

    @property
    def gen_poly(self):
        """Generator polynomial used by the encoder"""
        return self._gen_poly

    @property
    def coderate(self):
        """Rate of the code used in the encoder"""
        if self.terminate and self._k is None:
            print("Note that, due to termination, the true coderate is lower "\
                  "than the returned design rate. "\
                  "The exact true rate is dependent on the value of k and "\
                  "hence cannot be computed before the first call().")
        elif self.terminate and self._k is not None:
            term_factor = self._k/(self._k + self._mu)
            self._coderate = self._coderate_desired*term_factor
        return self._coderate

    @property
    def trellis(self):
        """Trellis object used during encoding"""
        return self._trellis

    @property
    def terminate(self):
        """Indicates if the convolutional encoder is terminated"""
        return self._terminate

    @property
    def k(self):
        """Number of information bits per codeword"""
        if self._k is None:
            print("Note: The value of k cannot be computed before the first " \
                  "call().")
        return self._k

    @property
    def n(self):
        """Number of codeword bits"""
        if self._n is None:
            print("Note: The value of n cannot be computed before the first " \
                  "call().")
        return self._n

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """Build layer and check dimensions.

        Args:
            input_shape: shape of input tensor (...,k).
        """
        self._k = input_shape[-1]
        self._n = int(self._k/self.coderate)

        # num_syms denote number of encoding periods or state transitions.
        # different from _k when _conv_k > 1.
        self.num_syms = int(self._k//self._conv_k)

    def call(self, inputs):
        """Convolutional code encoding function.

        Args:
            inputs (tf.float32): Information tensor of shape `[...,k]`.

        Returns:
            `tf.float32`: Encoded codeword tensor of shape `[...,n]`.
        """
        tf.debugging.assert_greater(tf.rank(inputs), 1)

        if inputs.shape[-1] != self._k:
            self.build(inputs.shape)

        msg = tf.cast(inputs, tf.int32)
        output_shape = msg.get_shape().as_list()
        output_shape[0] = -1 # overwrite batch dim (can be none in keras)
        output_shape[-1] = self._n # assign n to the last dim

        msg_reshaped = tf.reshape(msg, [-1, self._k])
        term_syms = int(self._mu) if self._terminate else 0

        prev_st = tf.zeros([tf.shape(msg_reshaped)[0]], tf.int32)
        ta = tf.TensorArray(tf.int32, size=self.num_syms, dynamic_size=False)

        idx_offset = range(0, self._conv_k)
        for idx in tf.range(0, self._k, self._conv_k):
            msg_bits_idx = tf.gather(msg_reshaped,
                                     idx + idx_offset,
                                     axis=-1)

            msg_idx = bin2int_tf(msg_bits_idx)

            indices = tf.stack([prev_st, msg_idx], -1)
            new_st = tf.gather_nd(self._trellis.to_nodes, indices=indices)

            idx_syms = tf.gather_nd(self._trellis.op_mat,
                                    tf.stack([prev_st, new_st], -1))
            idx_bits = int2bin_tf(idx_syms, self._conv_n)
            ta = ta.write(idx//self._conv_k, idx_bits)
            prev_st = new_st
        cw = tf.concat(tf.unstack(ta.stack()), axis=1)

        ta_term = tf.TensorArray(tf.int32, size=term_syms, dynamic_size=False)
        # Termination
        if self._terminate:
            if self._rsc:
                fb_poly = tf.constant([int(x) for x in self.gen_poly[0][1:]])
                fb_poly_tiled = tf.tile(
                    tf.expand_dims(fb_poly,0),[tf.shape(prev_st)[0],1])

            for idx in tf.range(0, term_syms, self._conv_k):
                prev_st_bits = int2bin_tf(prev_st, self._mu)
                if self._rsc:
                    msg_idx = tf.math.reduce_sum(
                        tf.multiply(fb_poly_tiled, prev_st_bits),-1)
                    msg_idx = tf.squeeze(int2bin_tf(msg_idx,1),-1)
                else:
                    msg_idx = tf.zeros((tf.shape(prev_st)[0],), dtype=tf.int32)

                indices = tf.stack([prev_st, msg_idx], -1)
                new_st = tf.gather_nd(self._trellis.to_nodes, indices=indices)
                idx_syms = tf.gather_nd(self._trellis.op_mat,
                                        tf.stack([prev_st, new_st], -1))
                idx_bits = int2bin_tf(idx_syms, self._conv_n)
                ta_term = ta_term.write(idx//self._conv_k, idx_bits)
                prev_st = new_st

            term_bits = tf.concat(tf.unstack(ta_term.stack()), axis=1)
            cw = tf.concat([cw, term_bits], axis=-1)

        cw = tf.cast(cw, self.output_dtype)
        cw_reshaped = tf.reshape(cw, output_shape)

        return cw_reshaped
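
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library API): it mirrors the
# termination example from the class docstring. The parameters below
# (rate=1/2, constraint_length=5, i.e. mu=4, and k=100) are assumptions
# chosen for demonstration only; the block runs only when this module is
# executed directly.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    enc = ConvEncoder(rate=1/2, constraint_length=5, terminate=True)

    # Before the first call, k is unknown, so only the design rate is
    # available (the coderate property also prints a note to that effect).
    print(enc.coderate)      # 0.5

    u = tf.zeros([4, 100])   # four all-zero messages of k=100 bits each
    c = enc(u)               # mu=4 termination transitions are appended

    print(c.shape)           # (4, 208), i.e., (k + mu)/rate coded bits
    print(enc.coderate)      # 0.5*100/104 ~ 0.481, as stated in the docstring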