anomaly-detection-material-parameters-calibration

Sionna parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration

decoding.py (37972B)


#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layers for convolutional code decoding: Viterbi and BCJR."""

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna.fec.utils import int2bin
from sionna.fec.conv.utils import polynomial_selector, Trellis


class ViterbiDecoder(Layer):
    # pylint: disable=line-too-long
    r"""ViterbiDecoder(encoder=None, gen_poly=None, rate=1/2, constraint_length=3, rsc=False, terminate=False, method='soft_llr', return_info_bits=True, output_dtype=tf.float32, **kwargs)

    Implements the Viterbi decoding algorithm [Viterbi]_ that returns an
    estimate of the information bits for a noisy convolutional codeword.
    Takes as input either LLR values (``method`` = `soft_llr`) or hard bit
    values (``method`` = `hard`) and returns a hard-decided estimate of the
    information bits.

    The class inherits from the Keras layer class and can be used as a layer
    in a Keras model.

    Parameters
    ----------
    encoder: :class:`~sionna.fec.conv.encoding.ConvEncoder`
        If ``encoder`` is provided as input, the following input parameters
        are not required and will be ignored: ``gen_poly``, ``rate``,
        ``constraint_length``, ``rsc``, ``terminate``. They will be inferred
        from the ``encoder`` object itself. If ``encoder`` is `None`, the
        above parameters must be provided explicitly.

    gen_poly: tuple
        Tuple of strings with each string being a 0,1 sequence. If `None`,
        ``rate`` and ``constraint_length`` must be provided.

    rate: float
        Valid values are 1/3 and 1/2. Only required if ``gen_poly`` is `None`.

    constraint_length: int
        Valid values are between 3 and 8 inclusive. Only required if
        ``gen_poly`` is `None`.

    rsc: boolean
        Boolean flag indicating whether the encoder is recursive-systematic
        for the given generator polynomials.
        `True` indicates the encoder is recursive-systematic.
        `False` indicates the encoder is feed-forward non-systematic.

    terminate: boolean
        Boolean flag indicating whether the codeword is terminated.
        `True` indicates the codeword is terminated to the all-zero state.
        `False` indicates the codeword is not terminated.

    method: str
        Valid values are `soft_llr` or `hard`. In computing the path
        metrics, `soft_llr` expects channel LLRs as inputs, whereas `hard`
        assumes a `binary symmetric channel` (BSC) with 0/1 values as
        inputs. In case of `hard`, ``inputs`` will be quantized to 0/1
        values.

    return_info_bits: boolean
        Defaults to `True`. If `True`, only the ``k`` information bit
        estimates are returned. If `False`, estimates of the codeword
        symbols along the decoded path are returned instead.

    output_dtype: tf.DType
        Defaults to tf.float32. Defines the output datatype of the layer.

    Input
    -----
        inputs: [...,n], tf.float32
            2+D tensor containing the (noisy) channel output symbols where
            `n` denotes the codeword length

    Output
    ------
        : [...,k], tf.float32
            2+D tensor containing the estimates of the information bit
            tensor

    Note
    ----
    A full implementation of the decoder rather than a windowed approach
    is used. For a given codeword of duration `T`, the path metric is
    computed from time `0` to `T` and the path with the optimal metric at
    time `T` is selected. The optimal path is then traced back from `T` to
    `0` to output the estimate of the information bit vector used to encode.
    For larger codewords, note that the current method is sub-optimal
    in terms of memory utilization and latency.
    """
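
    # Usage sketch (illustrative, not part of the original source). Assumes
    # `ConvEncoder` from sionna.fec.conv and channel LLRs `llr_ch` with last
    # dimension n:
    #
    #   enc = ConvEncoder(rate=1/2, constraint_length=3, terminate=True)
    #   dec = ViterbiDecoder(encoder=enc, method='soft_llr')
    #   u_hat = dec(llr_ch)   # [..., k] hard bit estimates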

    def __init__(self,
                 encoder=None,
                 gen_poly=None,
                 rate=1/2,
                 constraint_length=3,
                 rsc=False,
                 terminate=False,
                 method='soft_llr',
                 return_info_bits=True,
                 output_dtype=tf.float32,
                 **kwargs):

        super().__init__(**kwargs)
        if encoder is not None:
            self._gen_poly = encoder.gen_poly
            self._trellis = encoder.trellis
            self._terminate = encoder.terminate
        else:
            if gen_poly is not None:
                assert all(isinstance(poly, str) for poly in gen_poly), \
                    "Each polynomial must be a string."
                assert all(len(poly)==len(gen_poly[0]) for poly in gen_poly), \
                    "Each polynomial must be of same length."
                assert all(all(
                    char in ['0','1'] for char in poly) for poly in gen_poly),\
                    "Each polynomial must be a string of 0's and 1's."
                self._gen_poly = gen_poly
            else:
                valid_rates = (1/2, 1/3)
                valid_constraint_length = (3, 4, 5, 6, 7, 8)

                assert constraint_length in valid_constraint_length, \
                    "Constraint length must be between 3 and 8."
                assert rate in valid_rates, \
                    "Rate must be 1/3 or 1/2."
                self._gen_poly = polynomial_selector(rate, constraint_length)

            # init Trellis parameters
            self._trellis = Trellis(self.gen_poly, rsc=rsc)
            self._terminate = terminate

        self._coderate_desired = 1/len(self.gen_poly)
        self._mu = len(self._gen_poly[0])-1
        assert method in ('soft_llr', 'hard'), \
            "method must be `soft_llr` or `hard`."

        # conv_k denotes number of input bit streams
        # can only be 1 in current implementation
        self._conv_k = self._trellis.conv_k

        # conv_n denotes number of output bits for conv_k input bits
        self._conv_n = self._trellis.conv_n

        self._k = None
        self._n = None
        # num_syms denotes the number of encoding periods or state transitions
        self._num_syms = None

        self._ni = 2**self._conv_k
        self._no = 2**self._conv_n
        self._ns = self._trellis.ns

        self._method = method
        self._return_info_bits = return_info_bits
        self.output_dtype = output_dtype
        # If the i->j state transition emits symbol k, tf.gather with
        # ipst_op_idx gathers element (i,k) of the input into row j.
        self.ipst_op_idx = self._mask_by_tonode()

    #########################################
    # Public methods and properties
    #########################################

    @property
    def gen_poly(self):
        """Generator polynomial used by the encoder"""
        return self._gen_poly

    @property
    def coderate(self):
        """Rate of the code used in the encoder"""
        if self.terminate and self._n is None:
            print("Note that, due to termination, the true coderate is lower "
                  "than the returned design rate. "
                  "The exact true rate depends on the value of n and "
                  "hence cannot be computed before the first call().")
            self._coderate = self._coderate_desired
        elif self.terminate and self._n is not None:
            k = self._coderate_desired*self._n - self._mu
            self._coderate = k/self._n
        else:
            # without termination, the design rate is exact
            self._coderate = self._coderate_desired
        return self._coderate

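    # Illustrative example (not from the original source): for a terminated
    # rate-1/2 code with constraint_length=3 (mu=2) and n=20 coded bits,
    # k = 0.5*20 - 2 = 8, so the true coderate is 8/20 = 0.4 rather than
    # the design rate 1/2.
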
    @property
    def trellis(self):
        """Trellis object used during encoding"""
        return self._trellis

    @property
    def terminate(self):
        """Indicates if the encoder is terminated during codeword generation"""
        return self._terminate

    @property
    def k(self):
        """Number of information bits per codeword"""
        if self._k is None:
            print("Note: The value of k cannot be computed before the first "
                  "call().")
        return self._k

    @property
    def n(self):
        """Number of codeword bits"""
        if self._n is None:
            print("Note: The value of n cannot be computed before the first "
                  "call().")
        return self._n

    #########################
    # Utility functions
    #########################

    def _mask_by_tonode(self):
        r"""
        Returns an index tensor of shape (_ns, _ni, 2) whose row j holds the
        (prev_state, output) pairs of all transitions into state j. When
        applied as a tf.gather_nd index on an _ns x _no metric matrix
        ((i,k) denoting the metric for prev_st=i and output=k), the result
        is sorted by next state: row j of the output holds the _ni possible
        metrics for a transition into state j.
        """
        cnst = self._ns * self._ni
        from_nodes_vec = tf.reshape(self._trellis.from_nodes, (cnst,))
        op_idx = tf.reshape(self._trellis.op_by_tonode, (cnst,))
        st_op_idx = tf.transpose(tf.stack([from_nodes_vec, op_idx]))
        st_op_idx = tf.reshape(st_op_idx[None,:,:], (self._ns, self._ni, 2))

        return st_op_idx

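    # Illustrative example (not from the original source): if state j can be
    # reached from states s0 and s1 via transitions emitting output symbols
    # o0 and o1, then row j of the returned index tensor is
    # [[s0, o0], [s1, o1]], so tf.gather_nd lines up the metrics of both
    # incoming branches of state j in row j of its output.
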
    def _update_fwd(self, init_cm, bm_mat):
        """
        Forward pass of the Viterbi algorithm (add-compare-select).
        Given the initial cumulative metrics init_cm and the branch metrics
        bm_mat, returns TensorArrays of the cumulative metrics and the
        traceback states for each symbol period.
        """
        state_vec = tf.tile(tf.range(self._ns, dtype=tf.int32)[None,:],
                            [tf.shape(init_cm)[0], 1])
        ipst_op_mask = tf.tile(self.ipst_op_idx[None,:],
                               [tf.shape(init_cm)[0], 1, 1, 1])

        cm_ta = tf.TensorArray(tf.float32, size=self._num_syms,
                               dynamic_size=False, clear_after_read=False)
        tb_ta = tf.TensorArray(tf.int32, size=self._num_syms,
                               dynamic_size=False, clear_after_read=False)

        prev_cm = init_cm
        for idx in tf.range(0, self._n, self._conv_n):
            sym = idx//self._conv_n
            metrics_t = bm_mat[..., sym]
            # Ns x No matrix: (s,j) is the path metric at state s for
            # transition output j
            sum_metric = prev_cm[:,:,None] + metrics_t[:,None,:]
            sum_metric_bytonode = tf.gather_nd(sum_metric, ipst_op_mask,
                                               batch_dims=1)

            # compare: index of the minimum-metric incoming branch
            tb_state_idx = tf.math.argmin(sum_metric_bytonode, axis=2)
            tb_state_idx = tf.cast(tb_state_idx, tf.int32)

            # select: predecessor state of the surviving branch
            from_st_idx = tf.transpose(tf.stack([state_vec, tb_state_idx]),
                                       perm=[1, 2, 0])

            tb_states = tf.gather_nd(self._trellis.from_nodes, from_st_idx)
            cum_t = tf.math.reduce_min(sum_metric_bytonode, axis=2)

            cm_ta = cm_ta.write(sym, cum_t)
            tb_ta = tb_ta.write(sym, tb_states)

            prev_cm = cum_t

        return cm_ta, tb_ta

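    # Added commentary (not in the original file): each loop iteration above
    # is the classic add-compare-select step of the Viterbi algorithm, i.e.,
    # cm_t[j] = min_i (cm_{t-1}[i] + bm_t(i->j)) over the incoming branches
    # of state j, with the argmin predecessor recorded for the traceback.
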
    def _op_bits_path(self, paths):
        r"""
        Given a path, compute the input bit stream that results in the path.
        Used in call() where the input is the optimal path (sequence of
        states), such as the path returned by _optimal_path.
        """
        paths = tf.cast(paths, tf.int32)
        ip_bits = tf.TensorArray(tf.int32,
                                 size=paths.shape[-1]-1,
                                 dynamic_size=False,
                                 clear_after_read=False)
        dec_syms = tf.TensorArray(tf.int32,
                                  size=paths.shape[-1]-1,
                                  dynamic_size=False,
                                  clear_after_read=False)
        ni = self._trellis.ni
        ip_sym_mask = tf.range(ni)[None, :]

        for sym in tf.range(1, paths.shape[-1]):

            # gather index from paths to enable XLA
            # replaces p_idx = paths[:,sym-1:sym+1]
            p_idx = tf.gather(paths, [sym-1, sym], axis=-1)
            dec_ = tf.gather_nd(self._trellis.op_mat, p_idx)

            dec_syms = dec_syms.write(sym-1, value=dec_)
            # bs x ni boolean tensor. Each row has a True and a False. True
            # corresponds to the input bit which produced the next state
            # (t=sym)
            match_st = tf.math.equal(
                tf.gather(self._trellis.to_nodes, paths[:, sym-1]),
                tf.tile(paths[:, sym][:, None], [1, 2])
                )

            # tf.boolean_mask throws an error in XLA mode, hence
            # ip_bit = tf.boolean_mask(ip_sym_mask, match_st)
            # is replaced by a tf.where-based selection
            ip_bit_ = tf.where(match_st,
                               ip_sym_mask,
                               tf.zeros_like(ip_sym_mask))
            ip_bit = tf.reduce_sum(ip_bit_, axis=-1)
            ip_bits = ip_bits.write(sym-1, ip_bit)

        ip_bit_vec_est = tf.transpose(ip_bits.stack())
        ip_sym_vec_est = tf.transpose(dec_syms.stack())

        return ip_bit_vec_est, ip_sym_vec_est

    def _optimal_path(self, cm_, tb_):
        r"""
        Compute the optimal path (state at each time t) given tensors cm_ &
        tb_ of shapes (None, Ns, T). Output is of shape (None, T).
        cm_: cumulative metrics for each state at time t (0 to T)
        tb_: traceback state for each state at time t (0 to T)
        """
        # tb_ and cm_ are of shape (batch x self._ns x num_syms)
        assert tb_.get_shape()[1] == self._ns, "Invalid shape."
        optst_ta = tf.TensorArray(tf.int32, size=tb_.shape[-1],
                                  dynamic_size=False,
                                  clear_after_read=False)
        if self._terminate:
            opt_term_state = tf.zeros((tf.shape(cm_)[0],), tf.int32)
        else:
            opt_term_state = tf.cast(tf.argmin(cm_[:, :, -1], axis=1),
                                     tf.int32)
        optst_ta = optst_ta.write(tb_.shape[-1]-1, opt_term_state)

        for sym in tf.range(tb_.shape[-1]-1, 0, -1):
            opt_st = optst_ta.read(sym)[:,None]

            idx_ = tf.concat([tf.range(tf.shape(cm_)[0])[:,None], opt_st],
                             axis=1)
            opt_st_tminus1 = tf.gather_nd(tb_[:, :, sym], idx_)

            optst_ta = optst_ta.write(sym-1, opt_st_tminus1)

        return tf.transpose(optst_ta.stack())

    def _bmcalc(self, y):
        """
        Calculate branch metrics for a given noisy codeword tensor.
        For each time period t, _bmcalc computes the distance of symbol
        vector y[t] from each possible output symbol.

        If `method` is "soft_llr", the metric is the sum of the channel LLRs
        weighted by the bit signs of the candidate output symbol.
        If `method` is "hard", the metric is the L1 (Hamming) distance.
        """

        op_bits = np.stack(
                    [int2bin(op, self._conv_n) for op in range(self._no)])
        op_mat = tf.cast(tf.tile(op_bits, [1, self._num_syms]), tf.float32)
        op_mat = tf.expand_dims(op_mat, axis=0)
        y = tf.expand_dims(y, axis=1)
        if self._method == 'soft_llr':
            op_mat_sign = 1 - 2.*op_mat
            llr_sign = -1. * tf.math.multiply(y, op_mat_sign)
            llr_sign = tf.reshape(llr_sign,
                (-1, self._no, self._num_syms, self._conv_n))
            # Sum of LLR*(sign of bit) for each symbol
            bm = tf.math.reduce_sum(llr_sign, axis=-1)

        else: # method == 'hard'
            diffabs = tf.math.abs(y-op_mat)
            diffabs = tf.reshape(diffabs,
                                 (-1, self._no, self._num_syms, self._conv_n))
            # Manhattan distance of symbols
            bm = tf.math.reduce_sum(diffabs, axis=-1)

        return bm

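    # Illustrative example (not from the original source): for conv_n=2 the
    # candidate output symbols are 00, 01, 10, 11. With soft_llr decoding,
    # the per-bit contribution to the branch metric works out to
    # llr*(1-2*bit), where llr is the channel LLR passed to the layer (the
    # sign flips in call() and above cancel), so the most likely symbol
    # yields the smallest metric.
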
    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """Build layer and check dimensions."""
        # rank must be at least two
        tf.debugging.assert_greater_equal(len(input_shape), 2)

        self._n = input_shape[-1]

        assert self._n % self._conv_n == 0, \
            "Length of the codeword must be divisible by the number of " \
            "output bits per symbol."

        self._num_syms = int(self._n*self._coderate_desired)

        self._num_term_syms = self._mu if self.terminate else 0
        self._k = self._num_syms - self._num_term_syms

    def call(self, inputs):
        """
        Viterbi decoding function.

        inputs is the (noisy) codeword tensor where the last dimension should
        equal n. All the leading dimensions are assumed to be batch
        dimensions.
        """
        LARGEDIST = 2.**20 # pylint: disable=invalid-name

        tf.debugging.assert_type(inputs, tf.float32,
                                 message="input must be tf.float32.")
        if self._method == 'hard':
            inputs = tf.math.floormod(tf.cast(inputs, tf.int32), 2)
        elif self._method == 'soft_llr':
            inputs = -1. * inputs

        inputs = tf.cast(inputs, tf.float32)

        output_shape = inputs.get_shape().as_list()
        y_resh = tf.reshape(inputs, [-1, self._n])
        output_shape[0] = -1
        if self._return_info_bits:
            output_shape[-1] = self._k # assign k to the last dimension
        else:
            output_shape[-1] = self._n
        # Branch metrics matrix for a given y
        bm_mat = self._bmcalc(y_resh)

        # start in the all-zero state: all other states get a large metric
        init_cm_np = np.full((self._ns,), LARGEDIST)
        init_cm_np[0] = 0.0
        prev_cm_ = tf.convert_to_tensor(init_cm_np, dtype=tf.float32)
        prev_cm = tf.tile(prev_cm_[None,:], [tf.shape(y_resh)[0], 1])

        cm_ta, tb_ta = self._update_fwd(prev_cm, bm_mat)

        cm = tf.transpose(cm_ta.stack(), perm=[1,2,0])
        tb = tf.transpose(tb_ta.stack(), perm=[1,2,0])
        del cm_ta, tb_ta

        zero_st = tf.zeros((tf.shape(y_resh)[0], 1), tf.int32)
        opt_path = self._optimal_path(cm, tb)
        opt_path = tf.concat((zero_st, opt_path), axis=1)
        del cm, tb
        msghat, cwhat = self._op_bits_path(opt_path)
        if self._return_info_bits:
            msghat = msghat[...,:self._k]
            output = tf.cast(msghat, self.output_dtype)
        else:
            output = tf.cast(cwhat, self.output_dtype)
        output_reshaped = tf.reshape(output, output_shape)

        return output_reshaped


class BCJRDecoder(Layer):
    # pylint: disable=line-too-long
    r"""BCJRDecoder(encoder=None, gen_poly=None, rate=1/2, constraint_length=3, rsc=False, terminate=False, hard_out=True, algorithm='map', output_dtype=tf.float32, **kwargs)

    Implements the BCJR decoding algorithm [BCJR]_ that returns an
    estimate of the information bits for a noisy convolutional codeword.
    Takes as input either channel LLRs or a tuple
    (channel LLRs, a priori LLRs) and returns an estimate of the information
    bits, either as output LLRs (``hard_out`` = `False`) or as hard-decided
    bits (``hard_out`` = `True`).

    The class inherits from the Keras layer class and can be used as a layer
    in a Keras model.

    Parameters
    ----------
    encoder: :class:`~sionna.fec.conv.encoding.ConvEncoder`
        If ``encoder`` is provided as input, the following input parameters
        are not required and will be ignored: ``gen_poly``, ``rate``,
        ``constraint_length``, ``rsc``, ``terminate``. They will be inferred
        from the ``encoder`` object itself. If ``encoder`` is `None`, the
        above parameters must be provided explicitly.

    gen_poly: tuple
        Tuple of strings with each string being a 0,1 sequence. If `None`,
        ``rate`` and ``constraint_length`` must be provided.

    rate: float
        Valid values are 1/3 and 1/2. Only required if ``gen_poly`` is `None`.

    constraint_length: int
        Valid values are between 3 and 8 inclusive. Only required if
        ``gen_poly`` is `None`.

    rsc: boolean
        Boolean flag indicating whether the encoder is recursive-systematic
        for the given generator polynomials.
        `True` indicates the encoder is recursive-systematic.
        `False` indicates the encoder is feed-forward non-systematic.

    terminate: boolean
        Boolean flag indicating whether the codeword is terminated.
        `True` indicates the codeword is terminated to the all-zero state.
        `False` indicates the codeword is not terminated.

    hard_out: boolean
        Boolean flag indicating whether to output hard or soft decisions on
        the decoded information vector.
        `True` implies a hard-decoded information vector of 0/1's as output.
        `False` implies the output is the decoded LLRs of the information
        bits.

    algorithm: str
        Defaults to `map`. Indicates the implemented BCJR algorithm,
        where `map` denotes the exact MAP algorithm, `log` indicates the
        exact MAP implementation, but in log-domain, and
        `maxlog` indicates the approximated MAP implementation in log-domain,
        where :math:`\log(e^{a}+e^{b}) \approx \max(a,b)`.

    output_dtype: tf.DType
        Defaults to tf.float32. Defines the output datatype of the layer.

    Input
    -----
    llr_ch or (llr_ch, llr_a):
        Tensor or tuple:

    llr_ch: [...,n], tf.float32
        2+D tensor containing the (noisy) channel
        LLRs, where `n` denotes the codeword length

    llr_a: [...,k], tf.float32
        2+D tensor containing the a priori information of each information
        bit. Implicitly assumed to be 0 if only ``llr_ch`` is provided.

    Output
    ------
    : [...,k], tf.float32
        2+D tensor containing the estimates of the information bits, as
        LLRs or hard decisions depending on ``hard_out``

    """
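
    # Usage sketch (illustrative, not part of the original source). Assumes
    # `ConvEncoder` from sionna.fec.conv and channel LLRs `llr_ch` with last
    # dimension n:
    #
    #   enc = ConvEncoder(rate=1/2, constraint_length=3, terminate=True)
    #   dec = BCJRDecoder(encoder=enc, algorithm='maxlog', hard_out=True)
    #   u_hat = dec(llr_ch)             # [..., k] hard bit estimates
    #   # with a priori knowledge:  u_hat = dec((llr_ch, llr_a))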

    def __init__(self,
                 encoder=None,
                 gen_poly=None,
                 rate=1/2,
                 constraint_length=3,
                 rsc=False,
                 terminate=False,
                 hard_out=True,
                 algorithm='map',
                 output_dtype=tf.float32,
                 **kwargs):

        super().__init__(**kwargs)
        if encoder is not None:
            self._gen_poly = encoder.gen_poly
            self._trellis = encoder.trellis
            self._terminate = encoder.terminate
        else:
            if gen_poly is not None:
                assert all(isinstance(poly, str) for poly in gen_poly), \
                    "Each polynomial must be a string."
                assert all(len(poly)==len(gen_poly[0]) for poly in gen_poly), \
                    "Each polynomial must be of same length."
                assert all(all(
                    char in ['0','1'] for char in poly) for poly in gen_poly),\
                    "Each polynomial must be a string of 0's and 1's."
                self._gen_poly = gen_poly
            else:
                valid_rates = (1/2, 1/3)
                valid_constraint_length = (3, 4, 5, 6, 7, 8)

                assert constraint_length in valid_constraint_length, \
                    "Constraint length must be between 3 and 8."
                assert rate in valid_rates, \
                    "Rate must be 1/3 or 1/2."
                self._gen_poly = polynomial_selector(rate, constraint_length)

            # init Trellis parameters
            self._trellis = Trellis(self.gen_poly, rsc=rsc)
            self._terminate = terminate

        valid_algorithms = ['map', 'log', 'maxlog']
        assert algorithm in valid_algorithms, \
            "algorithm must be one of map, log or maxlog"

        self._coderate_desired = 1/len(self._gen_poly)
        self._mu = len(self._gen_poly[0])-1

        self._num_term_bits = None
        self._num_term_syms = None

        # conv_k denotes number of input bit streams
        # can only be 1 in current implementation
        self._conv_k = self._trellis.conv_k
        assert self._conv_k == 1

        # conv_n denotes number of output bits for conv_k input bits
        self._conv_n = self._trellis.conv_n

        # Length of the info-bit vector. Equal to _num_syms if
        # terminate=False, else < _num_syms
        self._k = None
        # Length of the codeword, including termination bits
        self._n = None
        # num_syms denotes the number of encoding periods or state transitions
        self._num_syms = None

        self._ni = 2**self._conv_k
        self._no = 2**self._conv_n
        self._ns = self._trellis.ns

        self._hard_out = hard_out
        self._algorithm = algorithm

        self._output_dtype = output_dtype
        self.ipst_op_idx, self.ipst_ip_idx = self._mask_by_tonode()

    #########################################
    # Public methods and properties
    #########################################

    @property
    def gen_poly(self):
        """Generator polynomial used by the encoder"""
        return self._gen_poly

    @property
    def coderate(self):
        """Rate of the code used in the encoder"""
        if self.terminate and self._n is None:
            print("Note that, due to termination, the true coderate is lower "
                  "than the returned design rate. "
                  "The exact true rate depends on the value of n and "
                  "hence cannot be computed before the first call().")
            self._coderate = self._coderate_desired
        elif self.terminate and self._n is not None:
            k = self._coderate_desired*self._n - self._mu
            self._coderate = k/self._n
        else:
            # without termination, the design rate is exact
            self._coderate = self._coderate_desired
        return self._coderate

    @property
    def trellis(self):
        """Trellis object used during encoding"""
        return self._trellis

    @property
    def terminate(self):
        """Indicates if the encoder is terminated during codeword generation"""
        return self._terminate

    @property
    def k(self):
        """Number of information bits per codeword"""
        if self._k is None:
            print("Note: The value of k cannot be computed before the first "
                  "call().")
        return self._k

    @property
    def n(self):
        """Number of codeword bits"""
        if self._n is None:
            print("Note: The value of n cannot be computed before the first "
                  "call().")
        return self._n

    #########################
    # Utility functions
    #########################

    def _mask_by_tonode(self):
        """
        Assume i->j is a valid state transition that is driven by info-bit b
        and emits symbol k. Returns the following two index tensors of shape
        (_ns, _ni, 2):
        - st_op_idx: row j contains the (i,k) tuples
        - st_ip_idx: row j contains the (i,b) tuples

        When applied as a tf.gather_nd index on an _ns x _no matrix, the
        output is a matrix sorted by next state.

        E.g., tf.gather_nd applied to "input" (shape _ns x _no) with mask
        - st_op_idx gathers input[i][k] into row j,
        - st_ip_idx gathers input[i][b] into row j.
        """

        cnst = self._ns * self._ni
        from_nodes_vec = tf.reshape(self._trellis.from_nodes, (cnst,))
        op_idx = tf.reshape(self._trellis.op_by_tonode, (cnst,))
        st_op_idx = tf.transpose(tf.stack([from_nodes_vec, op_idx]))
        st_op_idx = tf.reshape(st_op_idx[None,:,:], (self._ns, self._ni, 2))

        ip_idx = tf.reshape(self._trellis.ip_by_tonode, (cnst,))
        st_ip_idx = tf.transpose(tf.stack([from_nodes_vec, ip_idx]))
        st_ip_idx = tf.reshape(st_ip_idx[None,:,:], (self._ns, self._ni, 2))

        return st_op_idx, st_ip_idx

    def _bmcalc(self, llr_in):
        """
        Calculate branch (gamma) metrics for a given noisy codeword tensor.
        For each time period t, _bmcalc computes the "distance" of symbol
        vector y[t] from each possible output symbol, i.e.,
        (2*Eb/N0) * sum_i x_i*y_i for i=1,2,...,conv_n.

        The above metric is used in the calculation of gamma. The input is
        assumed to be an LLR, which is nothing but 2*Eb*y/N0.
        """
        op_bits = np.stack(
                    [int2bin(op, self._conv_n) for op in range(self._no)])
        op_mat = tf.cast(tf.tile(op_bits, [1, self._num_syms]), tf.float32)
        op_mat = tf.expand_dims(op_mat, axis=0)
        llr_in = tf.expand_dims(llr_in, axis=1)
        op_mat_sign = 1. - 2. * op_mat

        llr_sign = tf.math.multiply(llr_in, op_mat_sign)
        half_llr_sign = tf.reshape(0.5 * llr_sign,
            (-1, self._no, self._num_syms, self._conv_n))

        if self._algorithm in ['log', 'maxlog']:
            bm = tf.math.reduce_sum(half_llr_sign, axis=-1)
        else:
            bm = tf.math.exp(tf.math.reduce_sum(half_llr_sign, axis=-1))

        return bm

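    # Added commentary (not in the original file): for the log-domain
    # algorithms ('log', 'maxlog') the metric above is used directly as an
    # additive term of gamma, while the exact 'map' algorithm exponentiates
    # it first, since gamma is a product of probabilities in the linear
    # domain.
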
    def _initialize(self, llr_ch):
        """
        Initialize the alpha and beta state metrics. The all-zero state is
        certain at t=0. If the code is terminated, the same holds for the
        final state; otherwise all terminal states are equally likely.
        """
        if self._algorithm in ['log', 'maxlog']:
            init_vals = -np.inf, 0.0
        else:
            init_vals = 0.0, 1.0
        alpha_init_np = np.full((self._ns,), init_vals[0])
        alpha_init_np[0] = init_vals[1]

        beta_init_np = alpha_init_np
        if not self._terminate:
            eq_prob = 1./self._ns
            if self._algorithm in ['log', 'maxlog']:
                eq_prob = np.log(eq_prob)
            beta_init_np = np.full((self._ns,), eq_prob)

        alpha_init = tf.convert_to_tensor(alpha_init_np, dtype=tf.float32)
        alpha_init = tf.tile(alpha_init[None,:], [tf.shape(llr_ch)[0], 1])
        beta_init = tf.convert_to_tensor(beta_init_np, dtype=tf.float32)
        beta_init = tf.tile(beta_init[None,:], [tf.shape(llr_ch)[0], 1])
        return alpha_init, beta_init

    def _update_fwd(self, alph_init, bm_mat, llr):
        """
        Run the forward update from time t=0 to t=num_syms-1.
        At each time t, computes alpha_t using alpha_t-1 and gamma_t.

        Returns a TensorArray of alpha_t for t = 0,1,...,num_syms
        """
        alph_ta = tf.TensorArray(tf.float32, size=self._num_syms+1,
                                 dynamic_size=False, clear_after_read=False)
        alph_prev = tf.cast(alph_init, tf.float32)

        # (bs, _ns, _ni, 2) matrix
        ipst_ip_mask = tf.tile(
            self.ipst_ip_idx[None,:], [tf.shape(alph_init)[0], 1, 1, 1])
        # (bs, _ns, _ni) matrix, by from-state
        op_mask = tf.tile(self.trellis.op_by_fromnode[None,:,:],
                          [tf.shape(alph_init)[0], 1, 1])
        ipbit_mat = tf.tile(tf.range(self._ni)[None, None, :],
                            [tf.shape(alph_init)[0], self._ns, 1])
        ipbitsign_mat = 1. - 2. * tf.cast(ipbit_mat, tf.float32)
        alph_ta = alph_ta.write(0, alph_prev)
        for t in tf.range(self._num_syms):
            bm_t = bm_mat[..., t]
            llr_t = 0.5 * llr[..., t][:, None, None]

            bm_byfromst = tf.gather(bm_t, op_mask, batch_dims=1)
            signed_half_llr = tf.math.multiply(
                tf.tile(llr_t, [1, self._ns, self._ni]), ipbitsign_mat)
            if self._algorithm in ['log', 'maxlog']:
                llr_byfromst = signed_half_llr
                gamma_byfromst = llr_byfromst + bm_byfromst
                alph_gam_prod = gamma_byfromst + alph_prev[:,:,None]
            else:
                llr_byfromst = tf.math.exp(signed_half_llr)
                gamma_byfromst = tf.multiply(llr_byfromst, bm_byfromst)
                alph_gam_prod = tf.math.multiply(gamma_byfromst,
                                                 alph_prev[:,:,None])

            alphgam_bytost = tf.gather_nd(alph_gam_prod,
                                          ipst_ip_mask,
                                          batch_dims=1)
            if self._algorithm == 'map':
                alph_t = tf.math.reduce_sum(alphgam_bytost, axis=-1)
                alph_t_sum = tf.reduce_sum(alph_t, axis=-1)
                alph_t = tf.divide(alph_t,
                                   tf.tile(alph_t_sum[:,None], [1, self._ns]))
            elif self._algorithm == 'log':
                alph_t = tf.math.reduce_logsumexp(alphgam_bytost, axis=-1)
            else:  # self._algorithm == 'maxlog'
                alph_t = tf.math.reduce_max(alphgam_bytost, axis=-1)

            alph_prev = alph_t
            alph_ta = alph_ta.write(t+1, alph_t)
        return alph_ta

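    # Added commentary (not in the original file): the per-step division by
    # alph_t_sum in the 'map' branch normalizes the state metrics so the
    # recursion stays within float32 range; the common factor cancels in the
    # final LLR ratio and hence does not change the decoder output.
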
    def _update_bwd(self, beta_init, bm_mat, llr, alpha_ta):
        """
        Run the backward update from time t=num_syms-1 to t=0.
        At each time t, computes beta_t-1 using beta_t and gamma_t.

        Returns the LLRs of the information bits for t=0,1,...,num_syms-1
        """

        beta_next = beta_init
        llr_op_ta = tf.TensorArray(tf.float32,
                                   size=self._num_syms,
                                   dynamic_size=False,
                                   clear_after_read=False)
        beta_next = tf.cast(beta_next, tf.float32)

        # (bs, _ns, _ni) matrix, by from-state
        op_mask = tf.tile(self.trellis.op_by_fromnode[None,:,:],
                          [tf.shape(beta_init)[0], 1, 1])
        tonode_mask = tf.tile(self.trellis.to_nodes[None,:,:],
                              [tf.shape(beta_init)[0], 1, 1])

        ipbit_mat = tf.tile(tf.range(self._ni)[None, None, :],
                            [tf.shape(beta_init)[0], self._ns, 1])
        ipbitsign_mat = 1.0 - 2.0 * tf.cast(ipbit_mat, tf.float32)

        for t in tf.range(self._num_syms-1, -1, -1):
            bm_t = bm_mat[..., t]
            llr_t = 0.5 * llr[..., t][:, None, None]
            signed_half_llr = tf.math.multiply(
                tf.tile(llr_t, [1, self._ns, self._ni]), ipbitsign_mat)
            bm_byfromst = tf.gather(bm_t, op_mask, batch_dims=1)

            if self._algorithm in ['log', 'maxlog']:
                llr_byfromst = signed_half_llr
                gamma_byfromst = tf.math.add(llr_byfromst, bm_byfromst)
            else:
                llr_byfromst = tf.math.exp(signed_half_llr)
                gamma_byfromst = tf.multiply(llr_byfromst, bm_byfromst)

            beta_bytonode = tf.gather(beta_next, tonode_mask, batch_dims=1)

            if self._algorithm not in ['log', 'maxlog']:
                beta_gam_prod = tf.math.multiply(gamma_byfromst, beta_bytonode)
                beta_t = tf.math.reduce_sum(beta_gam_prod, axis=-1)
                beta_t_sum = tf.reduce_sum(beta_t, axis=-1)
                beta_t = tf.divide(beta_t,
                                   tf.tile(beta_t_sum[:,None], [1, self._ns]))
            elif self._algorithm == 'log':
                beta_gam_prod = gamma_byfromst + beta_bytonode
                beta_t = tf.math.reduce_logsumexp(beta_gam_prod, axis=-1,
                                                  keepdims=False)
            else:  # self._algorithm == 'maxlog'
                beta_gam_prod = gamma_byfromst + beta_bytonode
                beta_t = tf.math.reduce_max(beta_gam_prod, axis=-1)

            alph_t = alpha_ta.read(t)
            if self._algorithm not in ['log', 'maxlog']:
                llr_op_t0 = tf.math.multiply(
                                tf.math.multiply(alph_t, gamma_byfromst[...,0]),
                                beta_bytonode[...,0])
                llr_op_t1 = tf.math.multiply(
                                tf.math.multiply(alph_t, gamma_byfromst[...,1]),
                                beta_bytonode[...,1])
                llr_op_t = tf.math.log(
                    tf.divide(tf.reduce_sum(llr_op_t0, axis=-1),
                              tf.reduce_sum(llr_op_t1, axis=-1)))
            else:
                llr_op_t0 = alph_t + gamma_byfromst[...,0] + beta_bytonode[...,0]
                llr_op_t1 = alph_t + gamma_byfromst[...,1] + beta_bytonode[...,1]
                if self._algorithm == 'log':
                    llr_op_t = tf.math.subtract(
                        tf.math.reduce_logsumexp(llr_op_t0, axis=-1),
                        tf.math.reduce_logsumexp(llr_op_t1, axis=-1))
                else:
                    llr_op_t = tf.math.subtract(
                        tf.math.reduce_max(llr_op_t0, axis=-1),
                        tf.math.reduce_max(llr_op_t1, axis=-1))

            llr_op_ta = llr_op_ta.write(t, llr_op_t)
            beta_next = beta_t

        llr_op = tf.transpose(llr_op_ta.stack())
        return llr_op

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """Build layer and check dimensions."""
        # rank must be at least two
        tf.debugging.assert_greater_equal(len(input_shape), 2)

        if isinstance(input_shape, tf.TensorShape):
            self._n = input_shape[-1]
        else:
            self._n = input_shape[0][-1]

        self._num_syms = int(self._n*self._coderate_desired)

        self._num_term_syms = self._mu if self._terminate else 0
        self._num_term_bits = int(self._num_term_syms/self._coderate_desired)

        self._k = self._num_syms - self._num_term_syms

    def call(self, inputs):
        """
        BCJR decoding function.
        inputs is the (noisy) codeword tensor where the last dimension should
        equal n. All the leading dimensions are assumed to be batch
        dimensions.
        """
        if isinstance(inputs, (tuple, list)):
            assert len(inputs) == 2
            llr_ch, llr_apr = inputs
        else:
            tf.debugging.assert_greater(tf.rank(inputs), 1)
            llr_ch = inputs
            llr_apr = None

        tf.debugging.assert_type(llr_ch,
                                 tf.float32,
                                 message="input must be tf.float32.")

        output_shape = llr_ch.get_shape().as_list()

        # allow different codeword lengths in eager mode
        if output_shape[-1] != self._n:
            if isinstance(inputs, (tuple, list)):
                self.build((inputs[0].get_shape(),
                            inputs[1].get_shape()))
            else:
                self.build(llr_ch.get_shape())

        output_shape[0] = -1
        output_shape[-1] = self._k # assign k to the last dimension
        llr_ch = tf.reshape(llr_ch, [-1, self._n])

        if llr_apr is None:
            llr_apr = tf.zeros((tf.shape(llr_ch)[0], self._num_syms),
                               dtype=tf.float32)
        llr_ch = -1. * llr_ch
        llr_apr = -1. * llr_apr

        # Branch metrics matrix for a given y
        bm_mat = self._bmcalc(llr_ch)
        alpha_init, beta_init = self._initialize(llr_ch)

        alph_ta = self._update_fwd(alpha_init, bm_mat, llr_apr)
        llr_op = self._update_bwd(beta_init, bm_mat, llr_apr, alph_ta)

        msghat = -1. * llr_op[...,:self._k]
        if self._hard_out: # hard decide decoder output if required
            msghat = tf.less(0.0, msghat)
        msghat = tf.cast(msghat, self._output_dtype)
        msghat_reshaped = tf.reshape(msghat, output_shape)

        return msghat_reshaped
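

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (added for illustration; not part of the original
# module). Assumes `ConvEncoder` is importable from sionna.fec.conv, BPSK
# over real AWGN with noise power `no`, and the logit LLR convention
# llr = log p(b=1)/p(b=0), which for the mapping x = 1-2c gives
# llr = -2*y/no.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from sionna.fec.conv import ConvEncoder  # assumed import path

    batch, k, no = 10, 64, 0.5
    enc = ConvEncoder(rate=1/2, constraint_length=5, terminate=True)
    dec_vit = ViterbiDecoder(encoder=enc, method='soft_llr')
    dec_bcjr = BCJRDecoder(encoder=enc, algorithm='maxlog', hard_out=True)

    # random information bits, encoding, BPSK mapping and AWGN
    u = tf.cast(tf.random.uniform((batch, k), 0, 2, dtype=tf.int32),
                tf.float32)
    c = enc(u)
    x = 1. - 2.*c
    y = x + tf.sqrt(no) * tf.random.normal(tf.shape(x))
    llr_ch = -2. * y / no

    u_hat_vit = dec_vit(llr_ch)
    u_hat_bcjr = dec_bcjr(llr_ch)
    print("Viterbi BER:", tf.reduce_mean(tf.abs(u - u_hat_vit)).numpy())
    print("BCJR BER:   ", tf.reduce_mean(tf.abs(u - u_hat_bcjr)).numpy())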