anomaly-detection-material-parameters-calibration

Sionna parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration

encoding.py (16695B)


#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layer for Turbo Code Encoding."""

import math
import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna.fec import interleaving
from sionna.fec.utils import bin2int_tf, int2bin_tf
from sionna.fec.conv.encoding import ConvEncoder
from sionna.fec.conv.utils import Trellis
from sionna.fec.turbo.utils import polynomial_selector, puncture_pattern, TurboTermination

class TurboEncoder(Layer):
    # pylint: disable=line-too-long
    r"""TurboEncoder(gen_poly=None, constraint_length=3, rate=1/3, terminate=False, interleaver_type='3GPP', output_dtype=tf.float32, **kwargs)

    Performs encoding of information bits to a Turbo code codeword [Berrou]_.
    Implements the standard Turbo code framework [Berrou]_: Two identical
    rate-1/2 convolutional encoders :class:`~sionna.fec.conv.encoding.ConvEncoder`
    are combined to produce a rate-1/3 Turbo code. Further,
    puncturing to attain a rate-1/2 Turbo code is supported.

    The class inherits from the Keras layer class and can be used as layer in a
    Keras model.

    Parameters
    ----------
    gen_poly: tuple
        Tuple of strings, with each string being a sequence of 0s and 1s. If
        `None`, ``constraint_length`` must be provided.

    constraint_length: int
        Valid values are between 3 and 6 inclusive. Only required if
        ``gen_poly`` is `None`.

    rate: float
        Valid values are 1/3 and 1/2. Note that ``rate`` here denotes
        the `design` rate of the Turbo code. If ``terminate`` is `True`, a
        small rate-loss occurs.

    terminate: boolean
        Underlying convolutional encoders are terminated to all zero state
        if `True`. If terminated, the true rate of the code is slightly lower
        than ``rate``.

    interleaver_type: str
        Valid values are `"3GPP"` or `"random"`. Determines the choice of
        the interleaver to interleave the message bits before input to the
        second convolutional encoder. If `"3GPP"`, the Turbo code interleaver
        from the 3GPP LTE standard [3GPPTS36212_Turbo]_ is used. If `"random"`,
        a random interleaver is used.

    output_dtype: tf.DType
        Defaults to `tf.float32`. Defines the output datatype of the layer.

    Input
    -----
    inputs : [...,k], tf.float32
        2+D tensor of information bits where `k` is the information length

    Output
    ------
    : `[...,k/rate]`, tf.float32
        2+D tensor where `rate` is provided as input
        parameter. The output is the encoded codeword for the input
        information tensor. When ``terminate`` is `True`, the effective rate
        of the Turbo code is slightly less than ``rate``.

    Note
    ----
        Various notations are used in literature to represent the generator
        polynomials for convolutional codes. For simplicity
        :class:`~sionna.fec.turbo.encoding.TurboEncoder` only
        accepts the binary format, i.e., `10011`, for the ``gen_poly`` argument
        which corresponds to the polynomial :math:`1 + D^3 + D^4`.

        Note that Turbo codes require the underlying convolutional encoders
        to be recursive systematic encoders. Only then can the channel output
        from the systematic part of the first encoder be used to decode
        the second encoder.

        Also note that ``constraint_length`` and ``memory`` are two different
        terms often used to denote the strength of the convolutional code. In
        this sub-package we use ``constraint_length``. For example, the
        polynomial `10011` has a ``constraint_length`` of 5, whereas its
        ``memory`` is only 4.

        When ``terminate`` is `True`, the true rate of the Turbo code is
        slightly lower than ``rate``. It can be computed as
        :math:`\frac{k}{\frac{k}{r}+\frac{4\mu}{3r}}` where `r` denotes
        ``rate`` and :math:`\mu` denotes ``constraint_length`` - 1. For example,
        in 3GPP LTE, where ``constraint_length`` = 4 and ``terminate`` is `True`,
        the true rate for ``rate`` = 1/3 equals :math:`\frac{k}{3k+12}`.
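
        As an illustrative usage sketch (the parameter and shape values below
        are only an example, not defaults mandated by this class)::

            enc = TurboEncoder(constraint_length=4,
                               rate=1/3,
                               terminate=True)
            # Batch of 64 messages with k=40 information bits each
            u = tf.cast(tf.random.uniform([64, 40], 0, 2, tf.int32), tf.float32)
            c = enc(u)   # shape [64, 132], i.e., 3*k + 12 termination bits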
    """

    def __init__(self,
                 gen_poly=None,
                 constraint_length=3,
                 rate=1/3,
                 terminate=False,
                 interleaver_type='3GPP',
                 output_dtype=tf.float32,
                 **kwargs):

        super().__init__(**kwargs)

        if gen_poly is not None:
            assert all(isinstance(poly, str) for poly in gen_poly), \
                "Each element of gen_poly must be a string."
            assert all(len(poly)==len(gen_poly[0]) for poly in gen_poly), \
                "Each polynomial must be of same length."
            assert all(all(
                char in ['0','1'] for char in poly) for poly in gen_poly),\
                "Each polynomial must be a string of 0s and 1s."
            assert len(gen_poly)==2, \
                "Exactly two rate-1/2 generator polynomials are required."
            self._gen_poly = gen_poly
        else:
            valid_constraint_length = (3, 4, 5, 6)
            assert constraint_length in valid_constraint_length, \
                "Constraint length must be between 3 and 6."
            self._gen_poly = polynomial_selector(constraint_length)

        valid_rates = (1/2, 1/3)
        assert rate in valid_rates, "Invalid coderate."
        assert isinstance(terminate, bool), "terminate must be bool."
        assert interleaver_type in ('3GPP', 'random'),\
                                            "Invalid interleaver_type."

        self._coderate_desired = rate
        self._coderate = self._coderate_desired
        self._terminate = terminate
        self._interleaver_type = interleaver_type
        self.output_dtype = output_dtype
        # Whether the underlying convolutional encoders are recursive systematic (rsc)
        rsc = True

        self._coderate_conv = 1/len(self.gen_poly)
        self._punct_pattern = puncture_pattern(rate, self._coderate_conv)

        self._trellis = Trellis(self.gen_poly, rsc=rsc)
        self._mu = self.trellis._mu

        # conv_n denotes number of output bits for conv_k input bits.
        self._conv_k = self._trellis.conv_k
        self._conv_n = self._trellis.conv_n

        self._ni = 2**self._conv_k
        self._no  = 2**self._conv_n
        self._ns = self._trellis.ns

        self._k = None
        self._n = None

        if self.terminate:
            self.turbo_term = TurboTermination(self._mu+1, conv_n=self._conv_n)

        if self._interleaver_type == '3GPP':
            self.internal_interleaver = interleaving.Turbo3GPPInterleaver()
        else:
            self.internal_interleaver = interleaving.RandomInterleaver(
                                                    keep_batch_constant=True,
                                                    keep_state=True,
                                                    axis=-1)

        if self.punct_pattern is not None:
            self.punct_idx = tf.where(self.punct_pattern)

        self.convencoder = ConvEncoder(gen_poly=self._gen_poly,
                                       rsc=rsc,
                                       terminate=self._terminate)

    #########################################
    # Public methods and properties
    #########################################

    @property
    def gen_poly(self):
        """Generator polynomial used by the encoder"""
        return self._gen_poly

    @property
    def constraint_length(self):
        """Constraint length of the encoder"""
        return self._mu + 1

    @property
    def coderate(self):
        """Rate of the code used in the encoder"""
        if self.terminate and self._k is None:
            print("Note that, due to termination, the true coderate is lower "\
                  "than the returned design rate. "\
                  "The exact true rate depends on the value of k and "\
                  "hence cannot be computed before the first call().")
        elif self.terminate and self._k is not None:
            term_factor = 1+math.ceil(4*self._mu/3)/self._k
            self._coderate = self._coderate_desired/term_factor
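            # Illustrative check (example values, not from the original code):
            # for the 3GPP case (mu=3) and k=40, term_factor = 1 + 4/40 = 1.1,
            # so the true rate becomes (1/3)/1.1 = 40/132, i.e., k/(3k+12) as
            # stated in the class docstring.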
        return self._coderate

    @property
    def trellis(self):
        """Trellis object used during encoding"""
        return self._trellis

    @property
    def terminate(self):
        """Indicates if the convolutional encoders are terminated"""
        return self._terminate

    @property
    def punct_pattern(self):
        """Puncturing pattern for the Turbo codeword"""
        return self._punct_pattern

    @property
    def k(self):
        """Number of information bits per codeword"""
        if self._k is None:
            print("Note: The value of k cannot be computed before the first " \
                  "call().")
        return self._k

    @property
    def n(self):
        """Number of codeword bits"""
        if self._n is None:
            print("Note: The value of n cannot be computed before the first " \
                  "call().")
        return self._n

    def _conv_enc(self, info_vec, terminate):
        """
        Encodes the information tensor info_vec using the underlying
        convolutional encoder. Returns the encoded codeword as tensor array
        ta and the termination bits as tensor array ta_term.
        If terminate is False, ta_term is a tensor array of length 0.
        """
        msg = tf.cast(info_vec, tf.int32)

        msg_reshaped = tf.reshape(msg, [-1, self._k])
        term_syms = int(self._mu) if terminate else 0

        prev_st = tf.zeros([tf.shape(msg_reshaped)[0]], tf.int32)
        ta = tf.TensorArray(tf.int32, size=self.num_syms, dynamic_size=False)

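        # Each iteration of the loop below consumes conv_k message bits,
        # maps them to an input symbol, looks up the next state in
        # trellis.to_nodes, and writes the corresponding conv_n output bits
        # obtained from trellis.op_mat.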
        idx_offset = range(0, self._conv_k)
        for idx in tf.range(0, self._k, self._conv_k):
            msg_bits_idx = tf.gather(msg_reshaped,
                                     idx + idx_offset,
                                     axis=-1)

            #msg_bits_idx = tf.experimental.numpy.take_along_axis(msg_reshaped)

            msg_idx = bin2int_tf(msg_bits_idx)

            indices = tf.stack([prev_st, msg_idx], -1)
            new_st = tf.gather_nd(self._trellis.to_nodes, indices=indices)

            idx_syms = tf.gather_nd(self._trellis.op_mat,
                                    tf.stack([prev_st, new_st], -1))
            idx_bits = int2bin_tf(idx_syms, self._conv_n)
            ta = ta.write(idx//self._conv_k, idx_bits)
            prev_st = new_st

        ta_term = tf.TensorArray(tf.int32, size=term_syms, dynamic_size=False)
        # Termination
        if terminate:
            fb_poly = tf.constant([int(x) for x in self.gen_poly[0][1:]])
            fb_poly_tiled = tf.tile(
                tf.expand_dims(fb_poly,0),[tf.shape(prev_st)[0],1])
            for idx in tf.range(0, term_syms, self._conv_k):
                prev_st_bits = int2bin_tf(prev_st, self._mu)
                msg_idx = tf.math.reduce_sum(
                                    tf.multiply(fb_poly_tiled, prev_st_bits),-1)
                msg_idx = tf.squeeze(int2bin_tf(msg_idx,1),-1)

                indices = tf.stack([prev_st, msg_idx], -1)
                new_st = tf.gather_nd(self._trellis.to_nodes, indices=indices)
                idx_syms = tf.gather_nd(self._trellis.op_mat,
                                        tf.stack([prev_st, new_st], -1))
                idx_bits = int2bin_tf(idx_syms, self._conv_n)
                ta_term = ta_term.write(idx//self._conv_k, idx_bits)
                prev_st = new_st

        return ta, ta_term

    def _puncture_cw(self, cw):
        """
        Given the codeword ``cw``, this method punctures ``cw`` using the
        puncturing pattern defined in self.punct_pattern. Simply tiling
        self.punct_pattern and applying tf.boolean_mask(cw, mask_) would
        work; however, this fails in XLA mode as the output dimension of
        that operation is unknown.

        Hence, idx is obtained from `tf.where(self.punct_pattern)` during
        initialization, so that the dimension of idx is known at graph
        creation time. During call(), idx is then tiled and row offsets are
        added to it (the indices tensor), which achieves the same result as
        applying a tiled boolean_mask.
        """
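        # Worked example with illustrative numbers (not taken from
        # puncture_pattern): if one puncturing period spans 2 time steps of
        # 3 streams and keeps 4 positions, e.g. idx = [[0,0],[0,1],[1,0],[1,2]],
        # then for cw_n = 6 time steps mask_reps = 3, idx is tiled 3 times and
        # row offsets 0, 2, 4 are added, reproducing a tiled boolean mask
        # without a data-dependent output shape.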
        # cw shape: (bs, n, 3)- transpose to (n, 3, bs)
        cw = tf.transpose(cw, perm=[1, 2, 0])
        cw_n = cw.get_shape()[0]

        punct_period = self.punct_pattern.shape[0]
        mask_reps = cw_n//punct_period
        idx = tf.tile(self.punct_idx, [mask_reps, 1])

        idx_per_period = self.punct_idx.shape[0]
        idx_per_time = idx_per_period/punct_period

        # When tiling punct_pattern doesn't cover cw, delta_times > 0
        delta_times  = cw_n - (mask_reps * punct_period)
        delta_idx_rows = int(delta_times*idx_per_time)

        time_offset = punct_period * tf.range(mask_reps)[None,:]
        row_idx = tf.transpose(tf.tile(time_offset,[idx_per_period,1]))
        row_idx = tf.reshape(row_idx, (-1, 1))

        total_indices = mask_reps*idx_per_period + delta_idx_rows
        col_idx = tf.zeros((total_indices,1), tf.int32)

        if delta_times > 0:
            idx = tf.concat([idx, self.punct_idx[:delta_idx_rows]], axis=0)
            # Additional index row offsets if delta_times > 0
            time_n = punct_period*mask_reps
            row_idx_delta = tf.tile(
                                tf.range(time_n, time_n+delta_times)[None, :],
                                [delta_idx_rows, 1])
            row_idx = tf.concat([row_idx, row_idx_delta], axis=0)

        idx_offset = tf.cast(tf.concat([row_idx, col_idx], axis=1), tf.int64)
        idx = tf.add(idx, idx_offset)

        cw = tf.gather_nd(cw, idx)
        cw = tf.transpose(cw)
        return cw

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """Build layer and check dimensions.

        Args:
            input_shape: shape of input tensor (...,k).
        """
        self._k = input_shape[-1]
        self._n = int(self._k/self._coderate_desired)
        if self._interleaver_type == '3GPP':
            assert self._k <= 6144, "3GPP Turbo codes define interleavers " \
                "only up to frame lengths of 6144."

        # Number of encoding periods/state transitions.
        # Not equal to _k if _conv_k > 1.
        self.num_syms = int(self._k//self._conv_k)

    def call(self, inputs):
        """Turbo code encoding function.
        Args:
            inputs (tf.float32): Information tensor of shape `[...,k]`.

        Returns:
            `tf.float32`: Encoded codeword tensor of shape `[...,n]`.
        """
        tf.debugging.assert_greater(tf.rank(inputs), 1)

        if inputs.shape[-1] != self._k:
            self.build(inputs.shape)

        if self._terminate:
            num_term_bits_ = int(
                self.turbo_term.get_num_term_syms()/self._coderate_conv)
            num_term_bits_punct = int(
                num_term_bits_*self._coderate_conv/self._coderate_desired)
        else:
            num_term_bits_ = 0
            num_term_bits_punct = 0

        output_shape = inputs.get_shape().as_list()
        output_shape[0] = -1
        output_shape[-1] = self._n + num_term_bits_punct

        preterm_n = int(self._k/self._coderate_conv)
        msg = tf.cast(tf.reshape(inputs, [-1, self._k]), tf.int32)
        msg2 = self.internal_interleaver(msg)

        cw1_ = self.convencoder(msg)
        cw2_ = self.convencoder(msg2)

        cw1, term1 = cw1_[:, :preterm_n], cw1_[:, preterm_n:]
        cw2, term2 = cw2_[:, :preterm_n], cw2_[:, preterm_n:]

        # Gather parity stream from 2nd enc
        par_idx = tf.range(1, preterm_n, delta=self._conv_n)
        cw2_par = tf.gather(cw2, indices=par_idx, axis=-1)

        cw1 = tf.reshape(cw1,(-1, self._k, self._conv_n))
        cw2_par = tf.reshape(cw2_par, (-1, self._k, 1))

        # Concatenate 2nd enc parity to _conv_n streams from first encoder
        cw = tf.concat([cw1, cw2_par], axis=-1)
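        # cw now has shape (batch, k, 3); per information bit the three
        # streams are the systematic bit and the parity bits of the first
        # and second constituent encoders (the classic rate-1/3 Turbo
        # structure).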

        if self.terminate:
            term_syms_turbo = self.turbo_term.termbits_conv2turbo(term1, term2)
            term_syms_turbo = tf.reshape(
                term_syms_turbo, (-1, num_term_bits_//2, 3))
            cw = tf.concat([cw, term_syms_turbo], axis=-2)

        if self.punct_pattern is not None:
            cw = self._puncture_cw(cw)

        cw = tf.cast(cw, self.output_dtype)
        cw_reshaped = tf.reshape(cw, output_shape)
        return cw_reshaped