anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

decoding.py (90204B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Layers for Polar decoding such as successive cancellation (SC), successive
      6 cancellation list (SCL) and iterative belief propagation (BP) decoding."""
      7 
      8 import tensorflow as tf
      9 import numpy as np
     10 from numpy.core.numerictypes import issubdtype
     11 import warnings
     12 from tensorflow.keras.layers import Layer
     13 from sionna.fec.crc import CRCDecoder, CRCEncoder
     14 from sionna.fec.polar.encoding import Polar5GEncoder
     15 import numbers
     16 
     17 class PolarSCDecoder(Layer):
     18     """PolarSCDecoder(frozen_pos, n, output_dtype=tf.float32, **kwargs)
     19 
     20     Successive cancellation (SC) decoder [Arikan_Polar]_ for Polar codes and
     21     Polar-like codes.
     22 
     23     The class inherits from the Keras layer class and can be used as layer in a
     24     Keras model.
     25 
     26     Parameters
     27     ----------
     28         frozen_pos: ndarray
     29             Array of `int` defining the ``n-k`` indices of the frozen positions.
     30 
     31         n: int
     32             Defining the codeword length.
     33 
     34        output_dtype: tf.DType
     35         Defaults to tf.float32. Defines the output datatype of the layer
     36         (internal precision remains tf.float32).
     37 
     38     Input
     39     -----
     40         inputs: [...,n], tf.float32
     41             2+D tensor containing the channel LLR values (as logits).
     42 
     43     Output
     44     ------
     45         : [...,k], tf.float32
     46             2+D tensor  containing hard-decided estimations of all ``k``
     47             information bits.
     48 
     49     Raises
     50     ------
     51         AssertionError
     52             If ``n`` is not `int`.
     53 
     54         AssertionError
     55             If ``n`` is not a power of 2.
     56 
     57         AssertionError
     58             If the number of elements in ``frozen_pos`` is greater than ``n``.
     59 
     60         AssertionError
     61             If ``frozen_pos`` does not consists of `int`.
     62 
     63         ValueError
     64             If ``output_dtype`` is not {tf.float16, tf.float32, tf.float64}.
     65 
     66     Note
     67     ----
     68         This layer implements the SC decoder as described in
     69         [Arikan_Polar]_. However, the implementation follows the `recursive
     70         tree` [Gross_Fast_SCL]_ terminology and combines nodes for increased
     71         throughputs without changing the outcome of the algorithm.
     72 
     73         As commonly done, we assume frozen bits are set to `0`. Please note
     74         that - although its practical relevance is only little - setting frozen
     75         bits to `1` may result in `affine` codes instead of linear code as the
     76         `all-zero` codeword is not necessarily part of the code any more.
     77 
     78     """
     79 
     80     def __init__(self, frozen_pos, n, output_dtype=tf.float32, **kwargs):
     81 
     82         if output_dtype not in (tf.float16, tf.float32, tf.float64):
     83             raise ValueError(
     84                 'output_dtype must be {tf.float16, tf.float32, tf.float64}.')
     85 
     86         if output_dtype is not tf.float32:
     87             print('Note: decoder uses tf.float32 for internal calculations.')
     88 
     89         super().__init__(dtype=output_dtype, **kwargs)
     90         self._output_dtype = output_dtype
     91 
     92         # assert error if r>1 or k, n are negativ
     93         assert isinstance(n, numbers.Number), "n must be a number."
     94         n = int(n) # n can be float (e.g. as result of n=k*r)
     95 
     96         assert issubdtype(frozen_pos.dtype, int), "frozen_pos contains non int."
     97         assert len(frozen_pos)<=n, "Num. of elements in frozen_pos cannot " \
     98             "be greater than n."
     99         assert np.log2(n)==int(np.log2(n)), "n must be a power of 2."
    100 
    101         # store internal attributes
    102         self._n = n
    103         self._frozen_pos = frozen_pos
    104         self._k = self._n - len(self._frozen_pos)
    105         self._info_pos = np.setdiff1d(np.arange(self._n), self._frozen_pos)
    106         assert self._k==len(self._info_pos), "Internal error: invalid " \
    107                                               "info_pos generated."
    108         self._llr_max = 30. # internal max LLR value (uncritical for SC dec)
    109         # and create a frozen bit vector for simpler encoding
    110         self._frozen_ind = np.zeros(self._n)
    111         self._frozen_ind[self._frozen_pos] = 1
    112 
    113         # enable graph pruning
    114         self._use_fast_sc = False
    115 
    116     #########################################
    117     # Public methods and properties
    118     #########################################
    119 
    120     @property
    121     def n(self):
    122         """Codeword length."""
    123         return self._n
    124 
    125     @property
    126     def k(self):
    127         """Number of information bits."""
    128         return self._k
    129 
    130     @property
    131     def frozen_pos(self):
    132         """Frozen positions for Polar decoding."""
    133         return self._frozen_pos
    134 
    135     @property
    136     def info_pos(self):
    137         """Information bit positions for Polar encoding."""
    138         return self._info_pos
    139 
    140     @property
    141     def llr_max(self):
    142         """Maximum LLR value for internal calculations."""
    143         return self._llr_max
    144 
    145     @property
    146     def output_dtype(self):
    147         """Output dtype of decoder."""
    148         return self._output_dtype
    149 
    150     #########################
    151     # Utility methods
    152     #########################
    153 
    154     def _cn_op_tf(self, x, y):
    155         """Check-node update (boxplus) for LLR inputs.
    156 
    157         Operations are performed element-wise.
    158 
    159         See [Stimming_LLR]_ and [Hashemi_SSCL]_ for detailed equations.
    160         """
    161         x_in = tf.clip_by_value(x,
    162                                 clip_value_min=-self._llr_max,
    163                                 clip_value_max=self._llr_max)
    164         y_in = tf.clip_by_value(y,
    165                                 clip_value_min=-self._llr_max,
    166                                 clip_value_max=self._llr_max)
    167 
    168         # avoid division for numerical stability
    169         llr_out = tf.math.log(1 + tf.math.exp(x_in + y_in))
    170         llr_out -= tf.math.log(tf.math.exp(x_in) + tf.math.exp(y_in))
    171 
    172         return llr_out
    173 
    174     def _vn_op_tf(self, x, y, u_hat):
    175         """VN update for LLR inputs."""
    176         return tf.multiply((1-2*u_hat), x) + y
    177 
    178     def _polar_decode_sc_tf(self, llr_ch, frozen_ind):
    179         """Recursive SC decoding function.
    180 
    181         Recursively branch decoding tree and split into decoding of `upper`
    182         and `lower` path until reaching a leaf node.
    183 
    184         The function returns the u_hat decisions at stage `0` and the bit
    185         decisions of the intermediate stage `s` (i.e., the re-encoded version of
    186         `u_hat` until the current stage `s`).
    187 
    188         Note:
    189             This decoder parallelizes over the batch-dimension, i.e., the tree
    190             is processed for all samples in the batch in parallel. This yields a
    191             higher throughput, but does not improve the latency.
    192         """
    193 
    194         # calculate current codeword length
    195         n = len(frozen_ind)
    196 
    197         # branch if leaf is not reached yet
    198         if n>1:
    199             if self._use_fast_sc:
    200                 if np.sum(frozen_ind)==n:
    201                     #print("rate-0 detected! Length: ", n)
    202                     u_hat = tf.zeros_like(llr_ch)
    203                     return u_hat, u_hat
    204 
    205             llr_ch1 = llr_ch[...,0:int(n/2)]
    206             llr_ch2 = llr_ch[...,int(n/2):]
    207             frozen_ind1 = frozen_ind[0:int(n/2)]
    208             frozen_ind2 = frozen_ind[int(n/2):]
    209 
    210             # upper path
    211             x_llr1_in = self._cn_op_tf(llr_ch1, llr_ch2)
    212 
    213             # and call the decoding function (with upper half)
    214             u_hat1, u_hat1_up = self._polar_decode_sc_tf(x_llr1_in, frozen_ind1)
    215 
    216             # lower path
    217             x_llr2_in = self._vn_op_tf(llr_ch1, llr_ch2, u_hat1_up)
    218             # and call the decoding function again (with lower half)
    219             u_hat2, u_hat2_up = self._polar_decode_sc_tf(x_llr2_in, frozen_ind2)
    220 
    221             # combine u_hat from both branches
    222             u_hat = tf.concat([u_hat1, u_hat2], -1)
    223 
    224             # calculate re-encoded version of u_hat at current stage
    225             # u_hat1_up = tf.math.mod(u_hat1_up + u_hat2_up, 2)
    226             # combine u_hat via bitwise_xor (more efficient than mod2)
    227             u_hat1_up_int = tf.cast(u_hat1_up, tf.int8)
    228             u_hat2_up_int = tf.cast(u_hat2_up, tf.int8)
    229             u_hat1_up_int = tf.bitwise.bitwise_xor(u_hat1_up_int,
    230                                                    u_hat2_up_int)
    231             u_hat1_up = tf.cast(u_hat1_up_int , tf.float32)
    232             u_hat_up = tf.concat([u_hat1_up, u_hat2_up], -1)
    233 
    234         else: # if leaf is reached perform basic decoding op (=decision)
    235 
    236             if frozen_ind==1: # position is frozen
    237                 u_hat = tf.expand_dims(tf.zeros_like(llr_ch[:,0]), axis=-1)
    238                 u_hat_up = u_hat
    239             else: # otherwise hard decide
    240                 u_hat = 0.5 * (1. - tf.sign(llr_ch))
    241                 #remove "exact 0 llrs" leading to u_hat=0.5
    242                 u_hat = tf.where(tf.equal(u_hat, 0.5),
    243                                  tf.ones_like(u_hat),
    244                                  u_hat)
    245                 u_hat_up = u_hat
    246         return u_hat, u_hat_up
    247 
    248     #########################
    249     # Keras layer functions
    250     #########################
    251 
    252     def build(self, input_shape):
    253         """Check if shape of input is invalid."""
    254         assert (input_shape[-1]==self._n), "Invalid input shape."
    255         assert (len(input_shape)>=2), 'Inputs must have at least 2 dimensions.'
    256 
    257     def call(self, inputs):
    258         """Successive cancellation (SC) decoding function.
    259 
    260         Performs successive cancellation decoding and returns the estimated
    261         information bits.
    262 
    263         Args:
    264             inputs (tf.float32): Tensor of shape `[...,n]` containing the
    265                 channel LLR values (as logits).
    266 
    267         Returns:
    268             `tf.float32`: Tensor of shape `[...,k]` containing
    269             hard-decided estimations of all ``k`` information bits.
    270 
    271         Raises:
    272             ValueError: If ``inputs`` is not of shape `[..., n]`
    273                 or `dtype` is not `tf.float32`.
    274 
    275             InvalidArgumentError: When rank(``inputs``)<2.
    276 
    277         Note:
    278             This function recursively unrolls the SC decoding tree, thus,
    279             for larger values of ``n`` building the decoding graph can become
    280             time consuming.
    281         """
    282 
    283         tf.debugging.assert_type(inputs, self.dtype, 'Invalid input dtype.')
    284         # internal calculations still in tf.float32
    285         inputs = tf.cast(inputs, tf.float32)
    286 
    287         # last dim must be of length n
    288         tf.debugging.assert_equal(tf.shape(inputs)[-1],
    289                                   self._n,
    290                                   "Last input dimension must be of length n.")
    291 
    292         # Reshape inputs to [-1, n]
    293         tf.debugging.assert_greater(tf.rank(inputs), 1)
    294         input_shape = inputs.shape
    295         new_shape = [-1, self._n]
    296         llr_ch = tf.reshape(inputs, new_shape)
    297 
    298         llr_ch = -1. * llr_ch # logits are converted into "true" llrs
    299 
    300         # and decode
    301         u_hat_n, _ = self._polar_decode_sc_tf(llr_ch, self._frozen_ind)
    302 
    303         # and recover the k information bit positions
    304         u_hat = tf.gather(u_hat_n, self._info_pos, axis=1)
    305 
    306         # and reconstruct input shape
    307         output_shape = input_shape.as_list()
    308         output_shape[-1] = self.k
    309         output_shape[0] = -1 # first dim can be dynamic (None)
    310         u_hat_reshape = tf.reshape(u_hat, output_shape)
    311         return tf.cast(u_hat_reshape, self._output_dtype)
    312 
    313 class PolarSCLDecoder(Layer):
    314     # pylint: disable=line-too-long
    315     """PolarSCLDecoder(frozen_pos, n, list_size=8, crc_degree=None, use_hybrid_sc=False, use_fast_scl=True, cpu_only=False, use_scatter=False, ind_iil_inv=None, return_crc_status=False, output_dtype=tf.float32, **kwargs)
    316 
    317     Successive cancellation list (SCL) decoder [Tal_SCL]_ for Polar codes
    318     and Polar-like codes.
    319 
    320     The class inherits from the Keras layer class and can be used as layer in a
    321     Keras model.
    322 
    323     Parameters
    324     ----------
    325         frozen_pos: ndarray
    326             Array of `int` defining the ``n-k`` indices of the frozen positions.
    327 
    328         n: int
    329             Defining the codeword length.
    330 
    331         list_size: int
    332             Defaults to 8. Defines the list size of the decoder.
    333 
    334         crc_degree: str
    335             Defining the CRC polynomial to be used. Can be any value from
    336             `{CRC24A, CRC24B, CRC24C, CRC16, CRC11, CRC6}`.
    337 
    338         use_hybrid_sc: bool
    339             Defaults to False. If True, SC decoding is applied and only the
    340             codewords with invalid CRC are decoded with SCL. This option
    341             requires an outer CRC specified via ``crc_degree``.
    342             Remark: hybrid_sc does not support XLA optimization, i.e.,
    343             `@tf.function(jit_compile=True)`.
    344 
    345         use_fast_scl: bool
    346             Defaults to True. If True, Tree pruning is used to
    347             reduce the decoding complexity. The output is equivalent to the
    348             non-pruned version (besides numerical differences).
    349 
    350         cpu_only: bool
    351             Defaults to False. If True, `tf.py_function` embedding
    352             is used and the decoder runs on the CPU. This option is usually
    353             slower, but also more memory efficient and, in particular,
    354             recommended for larger blocklengths. Remark: cpu_only does not
    355             support XLA optimization `@tf.function(jit_compile=True)`.
    356 
    357         use_scatter: bool
    358             Defaults to False. If True, `tf.tensor_scatter_update` is used for
    359             tensor updates. This option is usually slower, but more memory
    360             efficient.
    361 
    362         ind_iil_inv : None or [k+k_crc], int or tf.int
    363             Defaults to None. If not `None`, the sequence is used as inverse
    364             input bit interleaver before evaluating the CRC.
    365             Remark: this only affects the CRC evaluation but the output
    366             sequence is not permuted.
    367 
    368         return_crc_status: bool
    369             Defaults to False. If True, the decoder additionally returns the
    370             CRC status indicating if a codeword was (most likely) correctly
    371             recovered. This is only available if ``crc_degree`` is not None.
    372 
    373         output_dtype: tf.DType
    374             Defaults to tf.float32. Defines the output datatype of the layer
    375             (internal precision remains tf.float32).
    376 
    377     Input
    378     -----
    379         inputs: [...,n], tf.float32
    380             2+D tensor containing the channel LLR values (as logits).
    381 
    382     Output
    383     ------
    384         b_hat : [...,k], tf.float32
    385             2+D tensor containing hard-decided estimations of all `k`
    386             information bits.
    387 
    388         crc_status : [...], tf.bool
    389             CRC status indicating if a codeword was (most likely) correctly
    390             recovered. This is only returned if ``return_crc_status`` is True.
    391             Note that false positives are possible.
    392 
    393     Raises:
    394         AssertionError
    395             If ``n`` is not `int`.
    396 
    397         AssertionError
    398             If ``n`` is not a power of 2.
    399 
    400         AssertionError
    401             If the number of elements in ``frozen_pos`` is greater than ``n``.
    402 
    403         AssertionError
    404             If ``frozen_pos`` does not consists of `int`.
    405 
    406         AssertionError
    407             If ``list_size`` is not `int`.
    408 
    409         AssertionError
    410             If ``cpu_only`` is not `bool`.
    411 
    412         AssertionError
    413             If ``use_scatter`` is not `bool`.
    414 
    415         AssertionError
    416             If ``use_fast_scl`` is not `bool`.
    417 
    418         AssertionError
    419             If ``use_hybrid_sc`` is not `bool`.
    420 
    421         AssertionError
    422             If ``list_size`` is not a power of 2.
    423 
    424         ValueError
    425             If ``output_dtype`` is not {tf.float16, tf.float32, tf.
    426             float64}.
    427 
    428         ValueError
    429             If ``inputs`` is not of shape `[..., n]` or `dtype` is not
    430             correct.
    431 
    432         InvalidArgumentError
    433             When rank(``inputs``)<2.
    434 
    435     Note
    436     ----
    437         This layer implements the successive cancellation list (SCL) decoder
    438         as described in [Tal_SCL]_ but uses LLR-based message updates
    439         [Stimming_LLR]_. The implementation follows the notation from
    440         [Gross_Fast_SCL]_, [Hashemi_SSCL]_. If option `use_fast_scl` is active
    441         tree pruning is used and tree nodes are combined if possible (see
    442         [Hashemi_SSCL]_ for details).
    443 
    444         Implementing SCL decoding as TensorFlow graph is a difficult task that
    445         requires several design tradeoffs to match the TF constraints while
    446         maintaining a reasonable throughput. Thus, the decoder minimizes
    447         the `control flow` as much as possible, leading to a strong memory
    448         occupation (e.g., due to full path duplication after each decision).
    449         For longer code lengths, the complexity of the decoding graph becomes
    450         large and we recommend to use the `CPU_only` option that uses an
    451         embedded Numpy decoder. Further, this function recursively unrolls the
    452         SCL decoding tree, thus, for larger values of ``n`` building the
    453         decoding graph can become time consuming. Please consider the
    454         ``cpu_only`` option if building the graph takes too long.
    455 
    456         A hybrid SC/SCL decoder as proposed in [Cammerer_Hybrid_SCL]_ (using SC
    457         instead of BP) can be activated with option ``use_hybrid_sc`` iff an
    458         outer CRC is available. Please note that the results are not exactly
    459         SCL performance caused by the false positive rate of the CRC.
    460 
    461         As commonly done, we assume frozen bits are set to `0`. Please note
    462         that - although its practical relevance is only little - setting frozen
    463         bits to `1` may result in `affine` codes instead of linear code as the
    464         `all-zero` codeword is not necessarily part of the code any more.
    465     """
    466 
    def __init__(self,
                 frozen_pos,
                 n,
                 list_size=8,
                 crc_degree=None,
                 use_hybrid_sc=False,
                 use_fast_scl=True,
                 cpu_only=False,
                 use_scatter=False,
                 ind_iil_inv=None,
                 return_crc_status=False,
                 output_dtype=tf.float32,
                 **kwargs):

        if output_dtype not in (tf.float16, tf.float32, tf.float64):
            raise ValueError(
                'output_dtype must be {tf.float16, tf.float32, tf.float64}.')

        if output_dtype is not tf.float32:
            print('Note: decoder uses tf.float32 for internal calculations.')

        super().__init__(dtype=output_dtype, **kwargs)
        self._output_dtype = output_dtype

        # assert error if r>1 or k, n are negative
        assert isinstance(n, numbers.Number), "n must be a number."
        n = int(n) # n can be float (e.g. as result of n=k*r)
        assert isinstance(list_size, int), "list_size must be integer."
        assert isinstance(cpu_only, bool), "cpu_only must be bool."
        assert isinstance(use_scatter, bool), "use_scatter must be bool."
        assert isinstance(use_fast_scl, bool), "use_fast_scl must be bool."
        assert isinstance(use_hybrid_sc, bool), "use_hybrid_sc must be bool."
        assert isinstance(return_crc_status, bool), \
                                            "return_crc_status must be bool."

        # np.issubdtype is the public API (numpy.core.numerictypes was
        # removed in NumPy 2.0); np.integer also accepts unsigned int dtypes
        assert np.issubdtype(frozen_pos.dtype, np.integer), \
                                            "frozen_pos contains non int."
        assert len(frozen_pos)<=n, "Num. of elements in frozen_pos cannot " \
            "be greater than n."
        assert np.log2(n)==int(np.log2(n)), "n must be a power of 2."
        assert np.log2(list_size)==int(np.log2(list_size)), \
                                    "list_size must be a power of 2."

        # CPU mode is recommended for larger values of n
        if n>128 and cpu_only is False and use_hybrid_sc is False:
            warnings.warn("Required resource allocation is large " \
            "for the selected blocklength. Consider option `cpu_only=True`.")

        # CPU mode is recommended for larger values of L
        if list_size>32 and cpu_only is False and use_hybrid_sc is False:
            warnings.warn("Resource allocation is high for the " \
            "selected list_size. Consider option `cpu_only=True`.")

        # internal decoder parameters
        self._use_fast_scl = use_fast_scl # optimize rate-0 and rep nodes
        self._use_scatter = use_scatter # slower but more memory friendly
        self._cpu_only = cpu_only # run numpy decoder
        self._use_hybrid_sc = use_hybrid_sc

        # store internal attributes
        self._n = n
        self._frozen_pos = frozen_pos
        self._k = self._n - len(self._frozen_pos)
        self._list_size = list_size
        self._info_pos = np.setdiff1d(np.arange(self._n), self._frozen_pos)
        self._llr_max = 30. # internal max LLR value (not very critical for SC)
        assert self._k==len(self._info_pos), "Internal error: invalid " \
                                             "info_pos generated."
        # create a frozen bit vector
        self._frozen_ind = np.zeros(self._n)
        self._frozen_ind[self._frozen_pos] = 1
        self._cw_ind = np.arange(self._n)
        self._n_stages = int(np.log2(self._n)) # number of decoding stages

        # init CRC check (if needed)
        if crc_degree is not None:
            self._use_crc = True
            self._crc_decoder = CRCDecoder(CRCEncoder(crc_degree))
            self._k_crc = self._crc_decoder.encoder.crc_length
        else:
            self._use_crc = False
            self._k_crc = 0
        assert self._k>=self._k_crc, "Value of k is too small for \
            given CRC_degree."

        # returning the CRC status requires an outer CRC; raise directly
        # (the previous dead attribute assignment before the raise is removed)
        if (crc_degree is None) and return_crc_status:
            raise ValueError("Returning CRC status requires given crc_degree.")
        self._return_crc_status = return_crc_status

        # store the inverse interleaver pattern
        if ind_iil_inv is not None:
            # NOTE(review): self._k already includes the CRC bits here
            # (k = n - len(frozen_pos)), hence "length k+k_crc" in the message
            assert (ind_iil_inv.shape[0]==self._k), \
                    "ind_int must be of length k+k_crc."
            self._ind_iil_inv = ind_iil_inv
            self._iil = True
        else:
            self._iil = False

        # use SC decoder first and use numpy-based SCL as "afterburner"
        if self._use_hybrid_sc:
            self._decoder_sc = PolarSCDecoder(frozen_pos, n)
            # Note: CRC required to detect SC success
            if not self._use_crc:
                raise ValueError("Hybrid SC requires outer CRC.")
    574 
    575     #########################################
    576     # Public methods and properties
    577     #########################################
    578 
    @property
    def n(self):
        """Codeword length ``n``."""
        return self._n

    @property
    def k(self):
        """Number of information bits ``k`` (incl. CRC bits, if any)."""
        return self._k

    @property
    def k_crc(self):
        """Number of CRC bits (`0` if no CRC is used)."""
        return self._k_crc

    @property
    def frozen_pos(self):
        """Frozen positions for Polar decoding."""
        return self._frozen_pos

    @property
    def info_pos(self):
        """Information bit positions for Polar encoding."""
        return self._info_pos

    @property
    def llr_max(self):
        """Maximum LLR value used for clipping in internal calculations."""
        return self._llr_max

    @property
    def list_size(self):
        """List size for SCL decoding."""
        return self._list_size

    @property
    def output_dtype(self):
        """Output dtype of decoder."""
        return self._output_dtype
    618 
    619     #####################################
    620     # Helper functions for the TF decoder
    621     #####################################
    622 
    def _update_rate0_code(self, msg_pm, msg_uhat, msg_llr, cw_ind):
        """Process a rate-0 (all-frozen) sub-code covering ``cw_ind``.

        Implements eq. (26) in [Hashemi_SSCL]_: every bit of the sub-block
        is known to be `0`, so only the path metric is updated.

        Remark: bits are not explicitly set to `0` as ``msg_uhat`` is
        initialized with `0` already; ``msg_llr`` is returned unchanged.
        """
        sub_len = len(cw_ind)
        stage = int(np.log2(sub_len))

        # LLRs of this sub-block at the current stage, clipped to the
        # internal maximum for numerical stability
        sub_llr = tf.clip_by_value(
                        tf.gather(msg_llr[:, :, stage, :], cw_ind, axis=2),
                        clip_value_min=-self._llr_max,
                        clip_value_max=self._llr_max)

        # accumulate softplus(-llr) over the complete sub-block of
        # length sub_len into the path metric
        msg_pm += tf.reduce_sum(tf.math.softplus(-sub_llr), axis=-1)

        return msg_pm, msg_uhat, msg_llr
    644 
    def _update_rep_code(self, msg_pm, msg_uhat, msg_llr, cw_ind):
        """Update rep. code (i.e., only rightmost bit is non-frozen)
        sub-code at positions ``cw_ind``.

        See Eq. (31) in [Hashemi_SSCL]_.

        The decoder dimension (axis 1) holds 2*list_size entries: the first
        list_size paths decide the repeated bit as `0`, the second half as
        `1` (ones are written into their codeword positions below).

        Remark: bits are not explicitly set to `0` as ``msg_uhat`` is
        initialized with `0` already.
        """
        n = len(cw_ind)
        stage_ind = int(np.log2(n))

        # update PM; clip LLRs for numerical stability
        llr = tf.gather(msg_llr[:, :, stage_ind, :], cw_ind, axis=2)
        llr_in = tf.clip_by_value(llr,
                                  clip_value_min=-self._llr_max,
                                  clip_value_max=self._llr_max)

        # upper branch has negative llr values (bit is 1)
        llr_low =  llr_in[:, :self._list_size, :]
        llr_up = - llr_in[:, self._list_size:, :]
        llr_pm = tf.concat([llr_low, llr_up], 1)
        pm_val = tf.math.softplus(-1.*llr_pm)
        msg_pm += tf.reduce_sum(pm_val, axis=-1)

        # lower half of the decoder dimension stays all-zero (bit=0)
        msg_uhat1 = msg_uhat[:, :self._list_size, :, :]
        # for the upper half, slice out everything left/right of cw_ind
        # at the current stage so ones can be inserted in between
        msg_uhat21 = tf.expand_dims(
                        msg_uhat[:, self._list_size:, stage_ind, :cw_ind[0]],
                        axis=2)

        msg_uhat22= tf.expand_dims(
                        msg_uhat[:, self._list_size:, stage_ind, cw_ind[-1]+1:],
                        axis=2)
        # ones to insert (bit=1 hypothesis for the whole sub-block)
        msg_ones = tf.ones([tf.shape(msg_uhat)[0], self._list_size, 1, n],
                            tf.float32)

        # re-assemble the current stage, then splice it back between the
        # untouched lower and upper stages
        msg_uhat23 = tf.concat([msg_uhat21, msg_ones, msg_uhat22], 3)
        msg_uhat24_1 = msg_uhat[:, self._list_size:, :stage_ind, :]
        msg_uhat24_2 = msg_uhat[:, self._list_size:, stage_ind+1:, :]

        msg_uhat2 = tf.concat([msg_uhat24_1, msg_uhat23, msg_uhat24_2], 2)
        msg_uhat = tf.concat([msg_uhat1, msg_uhat2], 1)

        # branch last bit and update pm at pos cw_ind[-1]
        msg_uhat = self._update_single_bit([cw_ind[-1]], msg_uhat)
        msg_pm, msg_uhat, msg_llr = self._sort_decoders(msg_pm,
                                                        msg_uhat,
                                                        msg_llr)
        msg_uhat, msg_llr, msg_pm = self._duplicate_paths(msg_uhat,
                                                          msg_llr,
                                                          msg_pm)
        return msg_pm, msg_uhat, msg_llr
    698 
    def _update_single_bit(self, ind_u, msg_uhat):
        """Update single bit at position ``ind_u`` for all decoders.

        The decoders in the upper half of the list (indices
        ``list_size`` to ``2*list_size-1``) get the bit at stage 0 set
        to `1`; the lower half keeps `0`, realizing the path branching.

        Remark: bits are not explicitly set to `0` as ``msg_uhat`` is
        initialized with `0` already.

        Remark: Two versions are implemented (throughput vs. graph complexity):
        1.) use tensor_scatter_nd_update
        2.) explicitly split graph and concatenate again
        """
        # position is non-frozen
        if self._frozen_ind[ind_u[0]]==0:

            # msg_uhat[:, ind_up, 0, ind_u] = 1
            if self._use_scatter:
                # decoder indices of the upper half of the list
                ind_dec = np.arange(self._list_size, 2*self._list_size, 1)
                ind_stage = np.array([0])

                # transpose such that batch dim can be broadcasted
                msg_uhat_t = tf.transpose(msg_uhat, [1, 3, 2, 0])

                # generate index grid
                ind_u = tf.cast(ind_u, tf.int64)
                grid = tf.meshgrid(ind_dec, ind_u, ind_stage)
                ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 3])

                # one update row per index tuple, broadcast over the batch dim
                updates = tf.ones([ind.shape[0], tf.shape(msg_uhat)[0]])
                msg_uhat_s = tf.tensor_scatter_nd_update(msg_uhat_t,
                                                         ind,
                                                         updates)
                # and restore original order
                msg_uhat = tf.transpose(msg_uhat_s, [3, 0, 2, 1])
            else:
                # alternative solution with split/concatenation of graph
                # lower half of the list stays untouched
                msg_uhat1 = msg_uhat[:, :self._list_size, :, :]
                # stage-0 bits of the upper half, left of the updated position
                msg_uhat21 = tf.expand_dims(
                                msg_uhat[:, self._list_size:, 0, :ind_u[0]],
                                axis=2)

                # stage-0 bits of the upper half, right of the updated position
                msg_uhat22= tf.expand_dims(
                                msg_uhat[:, self._list_size:, 0, ind_u[0]+1:],
                                axis=2)
                # ones to insert
                msg_ones = tf.ones_like(tf.reshape(
                                msg_uhat[:, self._list_size:, 0, ind_u[0]],
                                [-1, self._list_size, 1, 1]))

                # re-assemble stage 0 and stack with the untouched stages >= 1
                msg_uhat23 = tf.concat([msg_uhat21, msg_ones, msg_uhat22], 3)
                msg_uhat24 = msg_uhat[:, self._list_size:, 1:, :]

                msg_uhat2 = tf.concat([msg_uhat23, msg_uhat24], 2)
                msg_uhat = tf.concat([msg_uhat1, msg_uhat2], 1)

        return msg_uhat
    753 
    def _update_pm(self, ind_u, msg_uhat, msg_llr, msg_pm):
        """Update the path metric of all decoders after deciding bit ``ind_u``.

        Implements Eq. (10) from [Stimming_LLR]_: every path is penalized by
        log(1 + exp(-(1-2*u_hat)*llr)), i.e., whenever its bit decision
        disagrees with the sign of the corresponding LLR.
        """
        pos = ind_u[0]
        bit = msg_uhat[:, :, 0, pos]
        llr = tf.clip_by_value(msg_llr[:, :, 0, pos],
                               clip_value_min=-self._llr_max,
                               clip_value_max=self._llr_max)

        # map bit {0,1} to sign {+1,-1} of the penalty argument
        sign = 1. - 2. * bit

        # softplus(-x) = log(1 + exp(-x)) is numerically stable
        return msg_pm + tf.math.softplus(-sign * llr)
    769 
    def _sort_decoders(self, msg_pm, msg_uhat, msg_llr):
        """Sort all decoders in ascending order of their path metric.

        The same per-batch-sample permutation is applied to the path
        metrics, the bit estimates and the LLR messages.
        """
        order = tf.argsort(msg_pm, axis=-1)

        def permute(tensor):
            # re-order along the decoder (list) dimension, per batch sample
            return tf.gather(tensor, order, batch_dims=1, axis=None)

        return permute(msg_pm), permute(msg_uhat), permute(msg_llr)
    780 
    def _cn_op(self, x, y):
        """Check-node update (boxplus) for LLR inputs, element-wise.

        Computes log((1+e^(x+y)) / (e^x+e^y)) without an explicit division
        for numerical stability; see [Stimming_LLR]_ and [Hashemi_SSCL]_
        for detailed equations.
        """
        a = tf.clip_by_value(x,
                             clip_value_min=-self._llr_max,
                             clip_value_max=self._llr_max)
        b = tf.clip_by_value(y,
                             clip_value_min=-self._llr_max,
                             clip_value_max=self._llr_max)

        # numerator: log(1 + e^(a+b)) via softplus
        num = tf.math.softplus(a + b)
        # denominator: log(e^a + e^b) via logsumexp
        den = tf.math.reduce_logsumexp(tf.stack([a, b], axis=-1), axis=-1)

        return num - den
    803 
    def _vn_op(self, x, y, u_hat):
        """Variable-node update for LLR inputs, element-wise.

        The partial-sum bit ``u_hat`` flips the sign of ``x`` before it is
        added to ``y``; see [Stimming_LLR]_ and [Hashemi_SSCL]_ for
        detailed equations.
        """
        # map bit {0,1} to sign {+1,-1}
        sign = 1. - 2. * u_hat
        return y + sign * x
    812 
    def _duplicate_paths(self, msg_uhat, msg_llr, msg_pm):
        """Copy the ``list_size`` best decoders into the lower half.

        After sorting, the surviving paths occupy indices 0..list_size-1;
        tiling them by a factor of two overwrites (kills) the other paths.
        """
        keep = self._list_size

        msg_uhat = tf.tile(msg_uhat[:, :keep, :, :], [1, 2, 1, 1])
        msg_llr = tf.tile(msg_llr[:, :keep, :, :], [1, 2, 1, 1])
        msg_pm = tf.tile(msg_pm[:, :keep], [1, 2])

        return msg_uhat, msg_llr, msg_pm
    821 
    def _update_left_branch(self, msg_llr, stage_ind, cw_ind_left,cw_ind_right):
        """Update messages of left branch.

        Applies the check-node (boxplus) update to the parent LLRs at stage
        ``stage_ind`` and writes the result to the left-child positions
        ``cw_ind_left`` at stage ``stage_ind-1``.

        Remark: Two versions are implemented (throughput vs. graph complexity):
        1.) use tensor_scatter_nd_update
        2.) explicitly split graph and concatenate again
        """

        # parent-stage LLRs feeding the left child
        llr_left_in = tf.gather(msg_llr[:, :, stage_ind, :],
                                cw_ind_left,
                                axis=2)
        llr_right_in = tf.gather(msg_llr[:, :, stage_ind, :],
                                 cw_ind_right,
                                 axis=2)

        llr_left_out = self._cn_op(llr_left_in, llr_right_in)

        if self._use_scatter:
            # self.msg_llr[:, :, stage_ind-1, cw_ind_left] = llr_left_out

            # transpose such that batch-dim can be broadcasted
            msg_llr_t = tf.transpose(msg_llr, [2, 3, 1, 0])
            llr_left_out_s = tf.transpose(llr_left_out, [2, 1, 0])

            # generate index grid
            stage_ind = tf.cast(stage_ind, tf.int64)
            cw_ind_left = tf.cast(cw_ind_left, tf.int64)
            grid = tf.meshgrid(stage_ind-1, cw_ind_left)
            ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 2])

            # update values
            msg_llr_s = tf.tensor_scatter_nd_update(msg_llr_t,
                                                    ind,
                                                    llr_left_out_s)

            # and restore original order
            msg_llr = tf.transpose(msg_llr_s, [3, 2, 0, 1])
        else:
            # alternative solution with split/concatenation of graph
            # llr_left = msg_llr[:, :, stage_ind, cw_ind_left]
            # untouched positions left of the sub-code at stage_ind-1
            llr_left0 = tf.gather(msg_llr[:, :, stage_ind-1, :],
                                  np.arange(0, cw_ind_left[0]),
                                  axis=2)

            # untouched right-child positions and positions right of them
            llr_right = tf.gather(msg_llr[:, :, stage_ind-1, :],
                                  cw_ind_right,
                                  axis=2)
            llr_right1 = tf.gather(msg_llr[:, :, stage_ind-1, :],
                                   np.arange(cw_ind_right[-1] +1, self._n),
                                   axis=2)

            # re-assemble the full stage with the new left-child LLRs
            llr_s = tf.concat([llr_left0,
                               llr_left_out,
                               llr_right,
                               llr_right1], 2)

            llr_s = tf.expand_dims(llr_s, axis=2)

            # splice the rebuilt stage back between the other stages
            msg_llr1 = msg_llr[:, :, 0:stage_ind-1, :]
            msg_llr2 = msg_llr[:, :, stage_ind:, :]
            msg_llr = tf.concat([msg_llr1, llr_s, msg_llr2], 2)

        return msg_llr
    885 
    def _update_right_branch(self, msg_llr, msg_uhat, stage_ind, cw_ind_left,
                             cw_ind_right):
        """Update messages for right branch.

        Applies the variable-node update: the right-child LLRs at stage
        ``stage_ind-1`` are computed from the parent LLRs at ``stage_ind``
        and the already decoded partial sums of the left child.

        Remark: Two versions are implemented (throughput vs. graph complexity):
        1.) use tensor_scatter_nd_update
        2.) explicitly split graph and concatenate again
        """
        # partial sums of the (already decoded) left child
        u_hat_left_up = tf.gather(msg_uhat[:, :, stage_ind-1, :],
                                  cw_ind_left,
                                  axis=2)

        llr_left_in = tf.gather(msg_llr[:, :, stage_ind, :],
                                cw_ind_left,
                                axis=2)

        llr_right = tf.gather(msg_llr[:, :, stage_ind, :],
                              cw_ind_right,
                              axis=2)

        llr_right_out = self._vn_op(llr_left_in, llr_right, u_hat_left_up)

        if self._use_scatter:
            # transpose such that batch dim can be broadcasted
            msg_llr_t = tf.transpose(msg_llr, [2, 3, 1, 0])
            llr_right_out_s = tf.transpose(llr_right_out, [2, 1, 0])

            # generate index grid; cast the indices that are used in the
            # grid to int64 (consistent with _update_left_branch).
            # Fix: the cast result was previously assigned to cw_ind_left
            # instead of cw_ind_right, leaving the meshgrid input un-cast.
            stage_ind = tf.cast(stage_ind, tf.int64)
            cw_ind_right = tf.cast(cw_ind_right, tf.int64)
            grid = tf.meshgrid(stage_ind-1, cw_ind_right)
            ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 2])

            msg_llr_s = tf.tensor_scatter_nd_update(msg_llr_t,
                                                    ind,
                                                    llr_right_out_s)

            # and restore original order
            msg_llr = tf.transpose(msg_llr_s, [3, 2, 0, 1])
        else:
            # alternative solution with split/concatenation of graph
            # llr_left = msg_llr[:, :, stage_ind, cw_ind_left]
            # untouched positions left of the sub-code at stage_ind-1
            llr_left0 = tf.gather(msg_llr[:, :, stage_ind-1, :],
                                  np.arange(0, cw_ind_left[0]),
                                  axis=2)
            # untouched left-child positions
            llr_left = tf.gather(msg_llr[:, :, stage_ind-1, :],
                                 cw_ind_left,
                                 axis=2)
            # untouched positions right of the sub-code
            llr_right1 = tf.gather(msg_llr[:, :, stage_ind-1, :],
                                   np.arange(cw_ind_right[-1]+1, self._n),
                                   axis=2)

            # re-assemble the full stage with the new right-child LLRs
            llr_s = tf.concat([llr_left0, llr_left, llr_right_out,llr_right1],2)
            llr_s = tf.expand_dims(llr_s, axis=2)

            # splice the rebuilt stage back between the other stages
            msg_llr1 = msg_llr[:, :, 0:stage_ind-1, :]
            msg_llr2 = msg_llr[:, :, stage_ind:, :]

            msg_llr = tf.concat([msg_llr1, llr_s, msg_llr2], 2)

        return msg_llr
    947 
    def _update_branch_u(self, msg_uhat, stage_ind, cw_ind_left, cw_ind_right):
        """Update ``u_hat`` messages after executing both branches.

        Propagates the partial sums of both children from stage
        ``stage_ind-1`` up to stage ``stage_ind``: the left-child bits are
        XORed with the right-child bits, the right-child bits are copied.

        Remark: Two versions are implemented (throughput vs. graph complexity):
        1.) use tensor_scatter_nd_update
        2.) explicitly split graph and concatenate again
        """
        u_hat_left_up = tf.gather(msg_uhat[:, :, stage_ind-1, :],
                                  cw_ind_left,
                                  axis=2)

        u_hat_right_up = tf.gather(msg_uhat[:, :, stage_ind-1, :],
                                   cw_ind_right,
                                   axis=2)

        # combine u_hat via bitwise_xor (more efficient than mod2)
        u_hat_left_up_int = tf.cast(u_hat_left_up, tf.int32)
        u_hat_right_up_int = tf.cast(u_hat_right_up, tf.int32)
        u_hat_left = tf.bitwise.bitwise_xor(u_hat_left_up_int,
                                            u_hat_right_up_int)
        u_hat_left = tf.cast(u_hat_left, tf.float32)

        if self._use_scatter:
            # all positions of this sub-code at the current stage
            cw_ind = np.concatenate([cw_ind_left, cw_ind_right])

            u_hat = tf.concat([u_hat_left, u_hat_right_up], -1)

            # self.msg_llr[:, stage_ind-1, cw_ind_left] = llr_left_out

            # transpose such that batch dim can be broadcasted
            msg_uhat_t = tf.transpose(msg_uhat, [2, 3, 1, 0])
            u_hat_s = tf.transpose(u_hat, [2, 1, 0])

            # generate index grid
            stage_ind = tf.cast(stage_ind, tf.int64)
            cw_ind = tf.cast(cw_ind, tf.int64)
            grid = tf.meshgrid(stage_ind, cw_ind)
            ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 2])

            msg_uhat_s = tf.tensor_scatter_nd_update(msg_uhat_t,
                                                     ind,
                                                     u_hat_s)

            # and restore original order
            msg_uhat = tf.transpose(msg_uhat_s, [3, 2, 0, 1])
        else:
            # alternative solution with split/concatenation of graph
            # untouched positions left/right of the sub-code at stage_ind
            u_hat_left_0 = tf.gather(msg_uhat[:, :, stage_ind, :],
                                     np.arange(0, cw_ind_left[0]),
                                     axis=2)
            u_hat_right_1 = tf.gather(msg_uhat[:, :, stage_ind, :],
                                      np.arange(cw_ind_right[-1]+1, self._n),
                                      axis=2)

            # re-assemble the full stage with the combined partial sums
            u_hat = tf.concat([u_hat_left_0,
                               u_hat_left,
                               u_hat_right_up,
                               u_hat_right_1], 2)

            # provide u_hat for next higher stage
            msg_uhat1 = msg_uhat[:, :, 0:stage_ind, :]
            msg_uhat2 = msg_uhat[:, :, stage_ind+1:, :]
            u_hat = tf.expand_dims(u_hat, axis=2)

            msg_uhat = tf.concat([msg_uhat1, u_hat, msg_uhat2], 2)

        return msg_uhat
   1015 
    def _polar_decode_scl(self, cw_ind, msg_uhat, msg_llr, msg_pm):
        """Recursive decoding function for SCL decoding.

        We follow the terminology from [Hashemi_SSCL]_ and [Stimming_LLR]_
        and branch the messages into a `left` and `right` update paths until
        reaching a leaf node.

        Tree pruning as proposed in [Hashemi_SSCL]_ is used to minimize the
        tree depth while maintaining the same output.

        Note: the three message tensors are passed through functionally
        (no in-place state) and the updated versions are returned.
        """
        # current sub-code length and stage index (= tree depth)
        n = len(cw_ind)
        stage_ind = int(np.log2(n))

        # recursively branch through decoding tree
        if n>1:
            # prune tree if rate-0 subcode is detected
            if self._use_fast_scl:
                # rate-0: all positions frozen; only the PM must be updated
                if np.sum(self._frozen_ind[cw_ind])==n:
                    msg_pm, msg_uhat, msg_llr = self._update_rate0_code(msg_pm,
                                                                       msg_uhat,
                                                                       msg_llr,
                                                                       cw_ind)
                    return msg_uhat, msg_llr, msg_pm

                # repetition code: only the last position is non-frozen
                if (self._frozen_ind[cw_ind[-1]]==0 and
                    np.sum(self._frozen_ind[cw_ind[:-1]])==n-1):
                    msg_pm, msg_uhat, msg_llr, = self._update_rep_code(msg_pm,
                                                                       msg_uhat,
                                                                       msg_llr,
                                                                       cw_ind)
                    return msg_uhat, msg_llr, msg_pm

            # split index into left and right part
            cw_ind_left = cw_ind[0:int(n/2)]
            cw_ind_right = cw_ind[int(n/2):]

            # ----- left branch -----
            msg_llr = self. _update_left_branch(msg_llr,
                                                stage_ind,
                                                cw_ind_left,
                                                cw_ind_right)

            # call sub-graph decoder of left branch
            msg_uhat, msg_llr, msg_pm = self._polar_decode_scl(cw_ind_left,
                                                               msg_uhat,
                                                               msg_llr,
                                                               msg_pm)

            # ----- right branch -----
            msg_llr = self._update_right_branch(msg_llr,
                                                msg_uhat,
                                                stage_ind,
                                                cw_ind_left,
                                                cw_ind_right)

            # call sub-graph decoder of right branch
            msg_uhat, msg_llr, msg_pm = self._polar_decode_scl(cw_ind_right,
                                                               msg_uhat,
                                                               msg_llr,
                                                               msg_pm)
            # update uhat at current stage
            msg_uhat = self._update_branch_u(msg_uhat,
                                             stage_ind,
                                             cw_ind_left,
                                             cw_ind_right)

        # if leaf is reached perform basic decoding op (=decision)
        else:
            # update bit value at current position
            msg_uhat = self._update_single_bit(cw_ind, msg_uhat)

            # update PM
            msg_pm = self._update_pm(cw_ind, msg_uhat, msg_llr, msg_pm)

            if self._frozen_ind[cw_ind]==0: # position is non-frozen
                # sort list
                msg_pm, msg_uhat, msg_llr = self._sort_decoders(msg_pm,
                                                                msg_uhat,
                                                                msg_llr)

                # duplicate l best decoders to pos l:2*l (kill other decoders)
                msg_uhat, msg_llr, msg_pm = self._duplicate_paths(msg_uhat,
                                                                  msg_llr,
                                                                  msg_pm)

        return msg_uhat, msg_llr, msg_pm
   1103 
    def _decode_tf(self, llr_ch):
        """Main decoding function in TF.

        Initializes memory and calls recursive decoding function.

        Input: ``llr_ch`` with shape [batch_size, n] (channel LLRs).
        Returns a list ``[msg_uhat, msg_pm]`` with the (sorted) bit
        estimates of all ``2*list_size`` decoders and their path metrics.
        """

        batch_size = tf.shape(llr_ch)[0]

        # allocate memory for all 2*list_size decoders
        # msg_uhat: [batch, 2*list_size, n_stages+1, n]
        msg_uhat = tf.zeros([batch_size,
                             2*self._list_size,
                             self._n_stages+1,
                             self._n])
        # msg_llr gets its last stage appended below (llr_ch)
        msg_llr = tf.zeros([batch_size,
                            2*self._list_size,
                            self._n_stages,
                            self._n])
        # init all 2*l decoders with same llr_ch
        llr_ch = tf.reshape(llr_ch, [-1, 1, 1, self._n])
        llr_ch = tf.tile(llr_ch,[1, 2*self._list_size, 1, 1])

        # init last stage with llr_ch
        msg_llr = tf.concat([msg_llr, llr_ch], 2)

        # init all remaining L-1 decoders with high penalty
        # (only one decoder per list half starts as the active path)
        pm0 = tf.zeros([batch_size, 1])
        pm1 = self._llr_max * tf.ones([batch_size, self._list_size-1])
        msg_pm = tf.concat([pm0, pm1, pm0, pm1], 1)

        # and call recursive graph function
        msg_uhat, msg_llr, msg_pm = self._polar_decode_scl(self._cw_ind,
                                                           msg_uhat,
                                                           msg_llr,
                                                           msg_pm)

        # and sort output
        msg_pm, msg_uhat, msg_llr = self._sort_decoders(msg_pm,
                                                        msg_uhat,
                                                        msg_llr)
        return [msg_uhat, msg_pm]
   1144 
   1145     ####################################
   1146     # Helper functions for Numpy decoder
   1147     ####################################
   1148 
   1149     def _update_rate0_code_np(self, cw_ind):
   1150         """Update rate-0 (i.e., all frozen) sub-code at pos ``cw_ind`` in Numpy.
   1151 
   1152         See Eq. (26) in [Hashemi_SSCL]_.
   1153         """
   1154         n = len(cw_ind)
   1155         stage_ind = int(np.log2(n))
   1156 
   1157         # update PM for each batch sample
   1158         ind = np.expand_dims(self._dec_pointer, axis=-1)
   1159         llr_in = np.take_along_axis(self.msg_llr[:, :, stage_ind, cw_ind],
   1160                                     ind,
   1161                                     axis=1)
   1162 
   1163         llr_clip = np.maximum(np.minimum(llr_in, self._llr_max), -self._llr_max)
   1164         pm_val = np.log(1 + np.exp(-llr_clip))
   1165         self.msg_pm += np.sum(pm_val, axis=-1)
   1166 
    def _update_rep_code_np(self, cw_ind):
        """Update rep. code (i.e., only rightmost bit is non-frozen)
        sub-code at position ``cw_ind`` in Numpy.

        The whole repetition sub-code is decided in one step: the lower
        half of the list keeps the all-zero hypothesis, the upper half
        takes the all-one hypothesis, then the list is sorted and pruned.

        See Eq. (31) in [Hashemi_SSCL]_.
        """
        n = len(cw_ind)
        stage_ind = int(np.log2(n))
        bs = self._dec_pointer.shape[0]

        # update PM
        # gather the sub-code LLRs of each sample in decoder-pointer order
        llr = np.zeros([bs, 2*self._list_size, n])
        for i in range(bs):
            llr_i = self.msg_llr[i, self._dec_pointer[i, :], stage_ind, :]
            llr[i, :, :] = llr_i[:, cw_ind]

        # upper branch has negative llr values (bit is 1)
        llr[:, self._list_size:, :] = - llr[:, self._list_size:, :]
        llr_in = np.maximum(np.minimum(llr, self._llr_max), -self._llr_max)
        pm_val = np.sum(np.log(1 + np.exp(-llr_in)), axis=-1)
        self.msg_pm += pm_val

        # set all bits of the upper-half decoders to 1 (all-one hypothesis)
        for i in range(bs):
            ind_dec = self._dec_pointer[i, self._list_size:]
            for j in cw_ind:
                self.msg_uhat[i, ind_dec, stage_ind, j] = 1

        # branch last bit and update pm at pos cw_ind[-1]
        self._update_single_bit_np([cw_ind[-1]])
        self._sort_decoders_np()
        self._duplicate_paths_np()
   1198 
   1199     def _update_single_bit_np(self, ind_u):
   1200         """Update single bit at position ``ind_u`` of all decoders in Numpy."""
   1201 
   1202         if self._frozen_ind[ind_u]==0: # position is non-frozen
   1203             ind_dec = np.expand_dims(self._dec_pointer[:, self._list_size:],
   1204                                      axis=-1)
   1205             uhat_slice = self.msg_uhat[:, :, 0, ind_u]
   1206             np.put_along_axis(uhat_slice, ind_dec, 1., axis=1)
   1207             self.msg_uhat[:, :, 0, ind_u] = uhat_slice
   1208 
   1209 
   1210     def _update_pm_np(self, ind_u):
   1211         """ Update path metric of all decoders at bit position ``ind_u`` in
   1212         Numpy.
   1213 
   1214         We apply Eq. (10) from [Stimming_LLR]_.
   1215         """
   1216         ind = np.expand_dims(self._dec_pointer, axis=-1)
   1217         u_hat = np.take_along_axis(self.msg_uhat[:, :, 0, ind_u], ind, axis=1)
   1218         u_hat = np.squeeze(u_hat, axis=-1)
   1219         llr_in = np.take_along_axis(self.msg_llr[:, :, 0, ind_u], ind, axis=1)
   1220         llr_in = np.squeeze(llr_in, axis=-1)
   1221 
   1222         llr_clip = np.maximum(np.minimum(llr_in, self._llr_max), -self._llr_max)
   1223         self.msg_pm += np.log(1 + np.exp(-np.multiply((1-2*u_hat), llr_clip)))
   1224 
   1225     def _sort_decoders_np(self):
   1226         """Sort decoders according to their path metric."""
   1227 
   1228         ind = np.argsort(self.msg_pm, axis=-1)
   1229         self.msg_pm = np.take_along_axis(self.msg_pm, ind, axis=1)
   1230         self._dec_pointer = np.take_along_axis(self._dec_pointer, ind, axis=1)
   1231 
   1232     def _cn_op_np(self, x, y):
   1233         """Check node update (boxplus) for LLRs in Numpy.
   1234 
   1235         See [Stimming_LLR]_ and [Hashemi_SSCL]_ for detailed equations.
   1236         """
   1237         x_in = np.maximum(np.minimum(x, self._llr_max), -self._llr_max)
   1238         y_in = np.maximum(np.minimum(y, self._llr_max), -self._llr_max)
   1239 
   1240         # avoid division for numerical stability
   1241         llr_out = np.log(1 + np.exp(x_in + y_in))
   1242         llr_out -= np.log(np.exp(x_in) + np.exp(y_in))
   1243 
   1244         return llr_out
   1245 
   1246     def _vn_op_np(self, x, y, u_hat):
   1247         """Variable node update (boxplus) for LLRs in Numpy."""
   1248         return np.multiply((1-2*u_hat), x) + y
   1249 
   1250     def _duplicate_paths_np(self):
   1251         """Copy first ``list_size``/2 paths into lower part in Numpy.
   1252 
   1253         Decoder indices are encoded in ``self._dec_pointer``.
   1254         """
   1255         ind_low = self._dec_pointer[:, :self._list_size]
   1256         ind_up = self._dec_pointer[:, self._list_size:]
   1257 
   1258         for i in range(ind_up.shape[0]):
   1259             self.msg_uhat[i, ind_up[i,:], :, :] = self.msg_uhat[i,
   1260                                                                 ind_low[i,:],
   1261                                                                 :, :]
   1262             self.msg_llr[i, ind_up[i,:],:,:] = self.msg_llr[i, ind_low[i,:],:,:]
   1263 
   1264         # pm must be sorted directly (not accessed via pointer)
   1265         self.msg_pm[:, self._list_size:] = self.msg_pm[:, :self._list_size]
   1266 
    def _polar_decode_scl_np(self, cw_ind):
        """Recursive decoding function in Numpy.

        We follow the terminology from [Hashemi_SSCL]_ and [Stimming_LLR]_
        and branch the messages into a `left` and `right` update paths until
        reaching a leaf node.

        Tree pruning as proposed in [Hashemi_SSCL]_ is used to minimize the
        tree depth while maintaining the same output.

        Note: operates in-place on ``self.msg_uhat``, ``self.msg_llr`` and
        ``self.msg_pm``; hence, nothing is returned.
        """
        n = len(cw_ind)
        stage_ind = int(np.log2(n))

        # recursively branch through decoding tree
        if n>1:
            # prune tree if rate-0 subcode or rep-code is detected
            if self._use_fast_scl:
                if np.sum(self._frozen_ind[cw_ind])==n:
                    # rate0 code detected
                    self._update_rate0_code_np(cw_ind)
                    return
                if (self._frozen_ind[cw_ind[-1]]==0 and
                    np.sum(self._frozen_ind[cw_ind[:-1]])==n-1):
                    # rep code detected
                    self._update_rep_code_np(cw_ind)
                    return
            cw_ind_left = cw_ind[0:int(n/2)]
            cw_ind_right = cw_ind[int(n/2):]

            # ----- left branch -----
            # check-node (f-function) update towards the left child
            llr_left = self.msg_llr[:, :, stage_ind, cw_ind_left]
            llr_right = self.msg_llr[:, :, stage_ind, cw_ind_right]

            self.msg_llr[:, :, stage_ind-1, cw_ind_left] = self._cn_op_np(
                                                                    llr_left,
                                                                    llr_right)

            # call left branch decoder
            self._polar_decode_scl_np(cw_ind_left)

            # ----- right branch -----
            # variable-node (g-function) update with the left partial sums
            u_hat_left_up = self.msg_uhat[:, :, stage_ind-1, cw_ind_left]
            llr_left = self.msg_llr[:, :, stage_ind, cw_ind_left]
            llr_right = self.msg_llr[:, :, stage_ind, cw_ind_right]

            self.msg_llr[:, :, stage_ind-1, cw_ind_right] = self._vn_op_np(
                                                                llr_left,
                                                                llr_right,
                                                                u_hat_left_up)

            # call right branch decoder
            self._polar_decode_scl_np(cw_ind_right)

            # combine u_hat
            u_hat_left_up = self.msg_uhat[:, :, stage_ind-1, cw_ind_left]
            u_hat_right_up = self.msg_uhat[:, :, stage_ind-1, cw_ind_right]

            # u_hat_left_up XOR u_hat_right_up
            # ("!=" on {0,1} arrays; "+ 0" casts the bool result to int)
            u_hat_left =  (u_hat_left_up != u_hat_right_up) + 0

            u_hat = np.concatenate([u_hat_left, u_hat_right_up], axis=-1)

            # provide u_hat for next higher stage
            self.msg_uhat[:, :, stage_ind,  cw_ind] = u_hat

        else: # if leaf is reached perform basic decoding op (=decision)

            self._update_single_bit_np(cw_ind)

            # update PM
            self._update_pm_np(cw_ind)

            # position is non-frozen
            if self._frozen_ind[cw_ind]==0:
                # sort list
                self._sort_decoders_np()
                # duplicate the best list_size decoders
                self._duplicate_paths_np()
        return
   1346 
   1347     def _decode_np_batch(self, llr_ch):
   1348         """Decode batch of ``llr_ch`` with Numpy decoder."""
   1349 
   1350         bs = llr_ch.shape[0]
   1351 
   1352         # allocate memory for all 2*list_size decoders
   1353         self.msg_uhat = np.zeros([bs,
   1354                                   2*self._list_size,
   1355                                   self._n_stages+1,
   1356                                   self._n])
   1357         self.msg_llr = np.zeros([bs,
   1358                                  2*self._list_size,
   1359                                  self._n_stages+1,
   1360                                  self._n])
   1361         self.msg_pm = np.zeros([bs,
   1362                                 2*self._list_size])
   1363 
   1364         # L-1 decoders start with high penalty
   1365         self.msg_pm[:,1:self._list_size] = self._llr_max
   1366         # same for the second half of the L-1 decoders
   1367         self.msg_pm[:,self._list_size+1:] = self._llr_max
   1368 
   1369         # use pointers to avoid in-memory sorting
   1370         self._dec_pointer = np.arange(2*self._list_size)
   1371         self._dec_pointer = np.tile(np.expand_dims(self._dec_pointer, axis=0),
   1372                                     [bs,1])
   1373 
   1374         # init llr_ch (broadcast via list dimension)
   1375         self.msg_llr[:, :, self._n_stages, :] = np.expand_dims(llr_ch, axis=1)
   1376 
   1377         # call recursive graph function
   1378         self._polar_decode_scl_np(self._cw_ind)
   1379 
   1380         # select most likely candidate
   1381         self._sort_decoders_np()
   1382 
   1383         # remove pointers
   1384         for ind in range(bs):
   1385             self.msg_uhat[ind, :, :, :] = self.msg_uhat[ind,
   1386                                                         self._dec_pointer[ind],
   1387                                                         :, :]
   1388         return self.msg_uhat, self.msg_pm
   1389 
   1390     def _decode_np_hybrid(self, llr_ch, u_hat_sc, crc_valid):
   1391         """Hybrid SCL decoding stage that decodes iff CRC from previous SC
   1392         decoding attempt failed.
   1393 
   1394         This option avoids the usage of the high-complexity SCL decoder in cases
   1395         where SC would be sufficient. For further details we refer to
   1396         [Cammerer_Hybrid_SCL]_ (we use SC instead of the proposed BP stage).
   1397 
   1398         Remark: This decoder does not exactly implement SCL as the CRC
   1399         can be false positive after the SC stage. However, in these cases
   1400         SCL+CRC may also yield the wrong results.
   1401 
   1402         Remark 2: Due to the excessive control flow (if/else) and the
   1403         varying batch-sizes, this function is only available as Numpy
   1404         decoder (i.e., runs on the CPU).
   1405         """
   1406 
   1407         bs = llr_ch.shape[0]
   1408         crc_valid = np.squeeze(crc_valid, axis=-1)
   1409         # index of codewords that need SCL decoding
   1410         ind_invalid = np.arange(bs)[np.invert(crc_valid)]
   1411 
   1412         # init SCL decoder for bs_hyb samples requiring SCL dec.
   1413         llr_ch_hyb = np.take(llr_ch, ind_invalid, axis=0)
   1414         msg_uhat_hyb, msg_pm_hyb = self._decode_np_batch(llr_ch_hyb)
   1415 
   1416         # merge results with previously decoded SC results
   1417         msg_uhat = np.zeros([bs, 2*self._list_size, 1, self._n])
   1418         msg_pm = np.ones([bs, 2*self._list_size]) * self._llr_max * self.k
   1419         msg_pm[:, 0] = 0
   1420 
   1421         # copy SC data
   1422         msg_uhat[:, 0, 0, self._info_pos] = u_hat_sc
   1423 
   1424         ind_hyb = 0
   1425         for ind in range(bs):
   1426             if not crc_valid[ind]:
   1427                 #copy data from SCL
   1428                 msg_uhat[ind, :, 0, :] = msg_uhat_hyb[ind_hyb, :, 0, :]
   1429                 msg_pm[ind, :] = msg_pm_hyb[ind_hyb, :]
   1430                 ind_hyb += 1
   1431 
   1432         return msg_uhat, msg_pm
   1433 
   1434     #########################
   1435     # Keras layer functions
   1436     #########################
   1437 
    def build(self, input_shape):
        """Build and check if shape of input is invalid."""
        # last dimension must equal the codeword length n
        assert (input_shape[-1]==self._n), "Invalid input shape."
        # at least a batch dimension is required
        assert (len(input_shape)>=2), 'Inputs must have at least 2 dimensions.'
   1442 
    def call(self, inputs):
        """Successive cancellation list (SCL) decoding function.

        This function performs successive cancellation list decoding
        and returns the estimated information bits.

        An outer CRC can be applied optionally by setting ``crc_degree``.

        Args:
            inputs (tf.float32): Tensor of shape `[...,n]` containing the
                channel LLR values (as logits).

        Returns:
            `tf.float32`: Tensor of shape `[...,k]` containing
            hard-decided estimations of all ``k`` information bits.

        Raises:
            ValueError: If ``inputs`` is not of shape `[..., n]`
                or `dtype` is not `tf.float32`.

            InvalidArgumentError: When rank(``inputs``)<2.

        Note:
            This function recursively unrolls the SCL decoding tree, thus,
            for larger values of ``n`` building the decoding graph can become
            time consuming. Please consider the ``cpu_only`` option instead.
        """

        tf.debugging.assert_type(inputs, self._output_dtype,
                                 "Invalid input dtype.")
        # internal calculations still in tf.float32
        inputs = tf.cast(inputs, tf.float32)

        # last dim must be of length n
        tf.debugging.assert_equal(tf.shape(inputs)[-1],
                                  self._n,
                                  "Last input dimension must be of length n.")

        # Reshape inputs to [-1, n]
        tf.debugging.assert_greater(tf.rank(inputs), 1)
        input_shape = inputs.shape
        new_shape = [-1, self._n]
        llr_ch = tf.reshape(inputs, new_shape)

        llr_ch = -1. * llr_ch # logits are converted into "true" llrs

        # if activated use Numpy decoder
        # (hybrid mode: cheap SC first, SCL only where the CRC failed)
        if self._use_hybrid_sc:
            # use SC decoder to decode first
            # note: SC decoder expects logits, hence the sign flip back
            u_hat = self._decoder_sc(-llr_ch)
            _, crc_valid = self._crc_decoder(u_hat)
            # Numpy decoding runs on the CPU via py_function
            msg_uhat, msg_pm = tf.py_function(func=self._decode_np_hybrid,
                                              inp=[llr_ch, u_hat, crc_valid],
                                              Tout=[tf.float32, tf.float32])
            # note: return shape is only 1 in 3. dim (to avoid copy overhead)
            msg_uhat = tf.reshape(msg_uhat, [-1, 2*self._list_size, 1, self._n])
            msg_pm = tf.reshape(msg_pm, [-1, 2*self._list_size])
        else:
            if self._cpu_only:
                msg_uhat, msg_pm = tf.py_function(func=self._decode_np_batch,
                                                  inp=[llr_ch],
                                                  Tout=[tf.float32, tf.float32])
                # restore shape information (lost through py_function)
                msg_uhat = tf.reshape(msg_uhat,
                            [-1, 2*self._list_size, self._n_stages+1, self._n])
                msg_pm = tf.reshape(msg_pm, [-1, 2*self._list_size])
            else:
                msg_uhat, msg_pm = self._decode_tf(llr_ch)

        # check CRC (and remove CRC parity bits)
        if self._use_crc:
            # extract info bits of every list candidate
            u_hat_list = tf.gather(msg_uhat[:, :, 0, :],
                                   self._info_pos,
                                   axis=-1)
            # undo input bit interleaving
            # remark: the output is not interleaved for compatibility with SC
            if self._iil:
                u_hat_list_crc = tf.gather(u_hat_list,
                                           self._ind_iil_inv,
                                           axis=-1)
            else: # no interleaving applied
                u_hat_list_crc = u_hat_list

            _, crc_valid = self._crc_decoder(u_hat_list_crc)
            # add penalty to pm if CRC fails
            # (pushes CRC-failing candidates behind all CRC-passing ones)
            pm_penalty = ((1. - tf.cast(crc_valid, tf.float32))
                       * self._llr_max * self.k)
            msg_pm += tf.squeeze(pm_penalty, axis=2)

        # select most likely candidate (smallest path metric)
        cand_ind = tf.argmin(msg_pm, axis=-1)
        c_hat = tf.gather(msg_uhat[:, :, 0, :], cand_ind, axis=1, batch_dims=1)
        u_hat = tf.gather(c_hat, self._info_pos, axis=-1)

        # and reconstruct input shape
        output_shape = input_shape.as_list()
        output_shape[-1] = self.k
        output_shape[0] = -1 # first dim can be dynamic (None)
        u_hat_reshape = tf.reshape(u_hat, output_shape)

        if self._return_crc_status:
            # reconstruct CRC status of the selected candidate
            crc_status = tf.gather(crc_valid, cand_ind, axis=1, batch_dims=1)
            # reconstruct shape
            output_shape.pop() # remove last dimension
            crc_status = tf.reshape(crc_status, output_shape)

            crc_status = tf.cast(crc_status, self._output_dtype)
            # return info bits and CRC status
            return tf.cast(u_hat_reshape, self._output_dtype), crc_status
        else: # return only info bits
            return tf.cast(u_hat_reshape, self._output_dtype)
   1555 
   1556 
   1557 class PolarBPDecoder(Layer):
   1558     # pylint: disable=line-too-long
   1559     """PolarBPDecoder(frozen_pos, n, num_iter=20, hard_out=True, output_dtype=tf.float32, **kwargs)
   1560 
   1561     Belief propagation (BP) decoder for Polar codes [Arikan_Polar]_ and
   1562     Polar-like codes based on [Arikan_BP]_ and [Forney_Graphs]_.
   1563 
   1564     The class inherits from the Keras layer class and can be used as layer in a
   1565     Keras model.
   1566 
   1567     Remark: The PolarBPDecoder does currently not support XLA.
   1568 
   1569     Parameters
   1570     ----------
   1571         frozen_pos: ndarray
   1572             Array of `int` defining the ``n-k`` indices of the frozen positions.
   1573 
   1574         n: int
   1575             Defining the codeword length.
   1576 
   1577         num_iter: int
   1578             Defining the number of decoder iterations (no early stopping used
   1579             at the moment).
   1580 
   1581         hard_out: bool
   1582             Defaults to True. If True, the decoder provides hard-decided
   1583             information bits instead of soft-values.
   1584 
   1585         output_dtype: tf.DType
   1586             Defaults to tf.float32. Defines the output datatype of the layer
   1587             (internal precision remains tf.float32).
   1588 
   1589     Input
   1590     -----
   1591         inputs: [...,n], tf.float32
   1592             2+D tensor containing the channel logits/llr values.
   1593 
   1594     Output
   1595     ------
   1596         : [...,k], tf.float32
   1597             2+D tensor containing bit-wise soft-estimates
   1598             (or hard-decided bit-values) of all ``k`` information bits.
   1599 
   1600     Raises
   1601     ------
   1602         AssertionError
   1603             If ``n`` is not `int`.
   1604 
   1605         AssertionError
   1606             If ``n`` is not a power of 2.
   1607 
   1608         AssertionError
   1609             If the number of elements in ``frozen_pos`` is greater than ``n``.
   1610 
   1611         AssertionError
   1612             If ``frozen_pos`` does not consists of `int`.
   1613 
   1614         AssertionError
   1615             If ``hard_out`` is not `bool`.
   1616 
   1617         ValueError
   1618             If ``output_dtype`` is not {tf.float16, tf.float32, tf.float64}.
   1619 
   1620         AssertionError
   1621             If ``num_iter`` is not `int`.
   1622 
   1623         AssertionError
   1624             If ``num_iter`` is not a positive value.
   1625 
   1626     Note
   1627     ----
   1628         This decoder is fully differentiable and, thus, well-suited for
   1629         gradient descent-based learning tasks such as `learned code design`
   1630         [Ebada_Design]_.
   1631 
   1632         As commonly done, we assume frozen bits are set to `0`. Please note
   1633         that - although its practical relevance is only little - setting frozen
   1634         bits to `1` may result in `affine` codes instead of linear code as the
   1635         `all-zero` codeword is not necessarily part of the code any more.
   1636 
   1637     """
   1638 
   1639     def __init__(self,
   1640                  frozen_pos,
   1641                  n,
   1642                  num_iter=20,
   1643                  hard_out=True,
   1644                  output_dtype=tf.float32,
   1645                  **kwargs):
   1646 
   1647         if output_dtype not in (tf.float16, tf.float32, tf.float64):
   1648             raise ValueError(
   1649                 'output_dtype must be {tf.float16, tf.float32, tf.float64}.')
   1650 
   1651         if output_dtype is not tf.float32:
   1652             print('Note: decoder uses tf.float32 for internal calculations.')
   1653 
   1654         super().__init__(dtype=output_dtype, **kwargs)
   1655         self._output_dtype = output_dtype
   1656 
   1657         # assert error if r>1 or k, n are negative
   1658         assert isinstance(n, numbers.Number), "n must be a number."
   1659         n = int(n) # n can be float (e.g. as result of n=k*r)
   1660         assert issubdtype(frozen_pos.dtype, int), "frozen_pos contains non int."
   1661         assert len(frozen_pos)<=n, "Num. of elements in frozen_pos cannot " \
   1662             "be greater than n."
   1663         assert np.log2(n)==int(np.log2(n)), "n must be a power of 2."
   1664 
   1665         assert isinstance(hard_out, bool), "hard_out must be boolean."
   1666 
   1667         # store internal attributes
   1668         self._n = n
   1669         self._frozen_pos = frozen_pos
   1670         self._k = self._n - len(self._frozen_pos)
   1671         self._info_pos = np.setdiff1d(np.arange(self._n), self._frozen_pos)
   1672         assert self._k==len(self._info_pos), "Internal error: invalid " \
   1673                                              "info_pos generated."
   1674 
   1675         assert isinstance(num_iter, int), "num_iter must be integer."
   1676         assert num_iter>0, "num_iter must be a positive value."
   1677         self._num_iter = tf.constant(num_iter, dtype=tf.int32)
   1678 
   1679         self._llr_max = 19.3 # internal max LLR value
   1680         self._hard_out = hard_out
   1681 
   1682         # depth of decoding graph
   1683         self._n_stages = int(np.log2(self._n))
   1684 
   1685     #########################################
   1686     # Public methods and properties
   1687     #########################################
   1688 
   1689     @property
   1690     def n(self):
   1691         """Codeword length."""
   1692         return self._n
   1693 
   1694     @property
   1695     def k(self):
   1696         """Number of information bits."""
   1697         return self._k
   1698 
   1699     @property
   1700     def frozen_pos(self):
   1701         """Frozen positions for Polar decoding."""
   1702         return self._frozen_pos
   1703 
   1704     @property
   1705     def info_pos(self):
   1706         """Information bit positions for Polar encoding."""
   1707         return self._info_pos
   1708 
   1709     @property
   1710     def llr_max(self):
   1711         """Maximum LLR value for internal calculations."""
   1712         return self._llr_max
   1713 
   1714     @property
   1715     def num_iter(self):
   1716         """Number of decoding iterations."""
   1717         return self._num_iter
   1718 
   1719     @property
   1720     def hard_out(self):
   1721         """Indicates if decoder hard-decides outputs."""
   1722         return self._hard_out
   1723 
   1724     @property
   1725     def output_dtype(self):
   1726         """Output dtype of decoder."""
   1727         return self._output_dtype
   1728 
   1729     @num_iter.setter
   1730     def num_iter(self, num_iter):
   1731         "Number of decoding iterations."
   1732         assert isinstance(num_iter, int), 'num_iter must be int.'
   1733         assert num_iter>=0, 'num_iter cannot be negative.'
   1734         self._num_iter = tf.constant(num_iter, dtype=tf.int32)
   1735 
   1736     #########################
   1737     # Utility methods
   1738     #########################
   1739 
   1740     def _boxplus_tf(self, x, y):
   1741         """Check-node update (boxplus) for LLR inputs.
   1742 
   1743         Operations are performed element-wise.
   1744         """
   1745         x_in = tf.clip_by_value(x,
   1746                                 clip_value_min=-self._llr_max,
   1747                                 clip_value_max=self._llr_max)
   1748         y_in = tf.clip_by_value(y,
   1749                                 clip_value_min=-self._llr_max,
   1750                                 clip_value_max=self._llr_max)
   1751 
   1752         # avoid division for numerical stability
   1753         llr_out = tf.math.log(1 + tf.math.exp(x_in + y_in))
   1754         llr_out -= tf.math.log(tf.math.exp(x_in) + tf.math.exp(y_in))
   1755 
   1756         return llr_out
   1757 
   1758     def _decode_bp(self, llr_ch, num_iter):
   1759         """Iterative BP decoding function with LLR-values.
   1760 
   1761         Args:
   1762             llr_ch (tf.float32): Tensor of shape `[batch_size, n]` containing
   1763                 the channel logits/llr values where `batch_size` denotes the
   1764                 batch-size.
   1765 
   1766             num_iter (int): Defining the number of decoder iteration
   1767                 (no early stopping used at the moment).
   1768         Returns:
   1769             `tf.float32`: Tensor of shape `[batch_size, k]` containing
   1770             bit-wise soft-estimates (or hard-decided bit-values) of all
   1771             information bits.
   1772         """
   1773 
   1774         bs = tf.shape(llr_ch)[0]
   1775 
   1776         # store intermediate Tensors in TensorArray
   1777         msg_l = tf.TensorArray(tf.float32,
   1778                                size=num_iter*(self._n_stages+1),
   1779                                dynamic_size=False,
   1780                                clear_after_read=False)
   1781 
   1782         msg_r = tf.TensorArray(tf.float32,
   1783                                size=num_iter*(self._n_stages+1),
   1784                                dynamic_size=False,
   1785                                clear_after_read=False)
   1786 
   1787         # init frozen positions with infinity
   1788         msg_r_in = np.zeros([1, self._n])
   1789         msg_r_in[:, self._frozen_pos] = self._llr_max
   1790         # copy for all batch-samples
   1791         msg_r_in = tf.tile(tf.constant(msg_r_in, tf.float32), [bs, 1])
   1792 
   1793         # perform decoding iterations
   1794         for ind_it in tf.range(self._num_iter):
   1795             # update left-to-right messages
   1796             for ind_s in range(self._n_stages):
   1797                 # calc indices
   1798                 ind_range = np.arange(int(self._n/2))
   1799                 ind_1 = ind_range * 2 - np.mod(ind_range, 2**ind_s)
   1800                 ind_2 = ind_1 + 2**ind_s
   1801                 # simplify gather with concatenated outputs
   1802                 ind_inv = np.argsort(np.concatenate([ind_1, ind_2], axis=0))
   1803 
   1804                 # load incoming l messages
   1805                 if ind_s==self._n_stages-1:
   1806                     l1_in = tf.gather(llr_ch, ind_1, axis=1)
   1807                     l2_in = tf.gather(llr_ch, ind_2, axis=1)
   1808                 elif ind_it==0:
   1809                     l1_in = tf.zeros([bs, int(self._n/2)])
   1810                     l2_in = tf.zeros([bs, int(self._n/2)])
   1811                 else:
   1812                     l_in = msg_l.read((ind_s+1) + (ind_it-1)*(self._n_stages+1))
   1813                     l1_in = tf.gather(l_in, ind_1, axis=1)
   1814                     l2_in = tf.gather(l_in, ind_2, axis=1)
   1815 
   1816                 # load incoming r messages
   1817                 if ind_s==0:
   1818                     r1_in = tf.gather(msg_r_in, ind_1, axis=1)
   1819                     r2_in = tf.gather(msg_r_in, ind_2, axis=1)
   1820                 else:
   1821                     r_in = msg_r.read(ind_s + ind_it*(self._n_stages+1))
   1822                     r1_in = tf.gather(r_in, ind_1, axis=1)
   1823                     r2_in = tf.gather(r_in, ind_2, axis=1)
   1824 
   1825                 r1_out = self._boxplus_tf(r1_in, l2_in + r2_in)
   1826                 r2_out = self._boxplus_tf(r1_in, l1_in) + r2_in
   1827 
   1828                 # and re-concatenate output
   1829                 r_out = tf.concat([r1_out, r2_out], 1)
   1830                 r_out = tf.gather(r_out, ind_inv, axis=1)
   1831                 msg_r = msg_r.write((ind_s+1)
   1832                                      + ind_it*(self._n_stages+1), r_out)
   1833 
   1834             # update right-to-left messages
   1835             for ind_s in range(self._n_stages-1, -1, -1):
   1836                 ind_range = np.arange(int(self._n/2))
   1837                 ind_1 = ind_range * 2 - np.mod(ind_range, 2**ind_s)
   1838                 ind_2 = ind_1 + 2**ind_s
   1839                 ind_inv = np.argsort(np.concatenate([ind_1, ind_2], axis=0))
   1840 
   1841                 # load messages
   1842                 if ind_s==self._n_stages-1:
   1843                     l1_in = tf.gather(llr_ch, ind_1, axis=1)
   1844                     l2_in = tf.gather(llr_ch, ind_2, axis=1)
   1845                 else:
   1846                     l_in = msg_l.read((ind_s+1)+ind_it*(self._n_stages+1))
   1847                     l1_in = tf.gather(l_in, ind_1, axis=1)
   1848                     l2_in = tf.gather(l_in, ind_2, axis=1)
   1849 
   1850                 if ind_s==0:
   1851                     r1_in = tf.gather(msg_r_in, ind_1, axis=1)
   1852                     r2_in = tf.gather(msg_r_in, ind_2, axis=1)
   1853                 else:
   1854                     r_in = msg_r.read(ind_s + ind_it*(self._n_stages+1))
   1855                     r1_in = tf.gather(r_in, ind_1, axis=1)
   1856                     r2_in = tf.gather(r_in, ind_2, axis=1)
   1857 
   1858                 # node update functions
   1859                 l1_out = self._boxplus_tf(l1_in, l2_in + r2_in)
   1860                 l2_out = self._boxplus_tf(r1_in, l1_in) + l2_in
   1861 
   1862                 l_out = tf.concat([l1_out, l2_out], 1)
   1863                 l_out = tf.gather(l_out, ind_inv, axis=1)
   1864                 msg_l = msg_l.write(ind_s + ind_it*(self._n_stages+1), l_out)
   1865 
   1866         # recover u_hat
   1867         u_hat = tf.gather(msg_l.read((num_iter-1)*(self._n_stages+1)),
   1868                           self._info_pos,
   1869                           axis=1)
   1870         # if active, hard-decide output bits
   1871         if self._hard_out:
   1872             u_hat = tf.where(u_hat>0, 0., 1.)
   1873         else: # re-transform soft output to logits (instead of llrs)
   1874             u_hat = -1. * u_hat
   1875         return u_hat
   1876 
   1877     #########################
   1878     # Keras layer functions
   1879     #########################
   1880 
   1881     def build(self, input_shape):
   1882         """Build and check if shape of input is invalid."""
   1883         assert (input_shape[-1]==self._n), "Invalid input shape"
   1884         assert (len(input_shape)>=2), 'Inputs must have at least 2 dimensions.'
   1885 
   1886     def call(self, inputs):
   1887         """Iterative BP decoding function.
   1888 
   1889         This function performs `num_iter` belief propagation decoding iterations
   1890         and returns the estimated information bits.
   1891 
   1892         Args:
   1893             inputs (tf.float32): Tensor of shape `[...,n]` containing the
   1894                 channel logits/llr values.
   1895 
   1896         Returns:
   1897             `tf.float32`: Tensor of shape `[...,k]` containing
   1898                 bit-wise soft-estimates (or hard-decided bit-values) of all
   1899                 ``k`` information bits.
   1900 
   1901         Raises:
   1902             ValueError: If ``inputs`` is not of shape `[..., n]`
   1903                 or `dtype` is not `output_dtype`.
   1904 
   1905             InvalidArgumentError: When rank(``inputs``)<2.
   1906 
   1907         Note:
   1908             This function recursively unrolls the BP decoding graph, thus,
   1909             for larger values of ``n`` or more iterations, building the
   1910             decoding graph can become time and memory consuming.
   1911         """
   1912 
   1913         tf.debugging.assert_type(inputs, self._output_dtype,
   1914                                  "Invalid input dtype.")
   1915         # internal calculations still in tf.float32
   1916         inputs = tf.cast(inputs, tf.float32)
   1917 
   1918         # Reshape inputs to [-1, n]
   1919         input_shape = inputs.shape
   1920         new_shape = [-1, self._n]
   1921         llr_ch = tf.reshape(inputs, new_shape)
   1922 
   1923         llr_ch = -1. * llr_ch # logits are converted into "true" llrs
   1924 
   1925         # and decode
   1926         u_hat = self._decode_bp(llr_ch, self._num_iter)
   1927 
   1928         # and reconstruct input shape
   1929         output_shape = input_shape.as_list()
   1930         output_shape[-1] = self.k
   1931         output_shape[0] = -1 # first dim can be dynamic (None)
   1932         u_hat_reshape = tf.reshape(u_hat, output_shape)
   1933         return tf.cast(u_hat_reshape, self._output_dtype)
   1934 
   1935 
   1936 class Polar5GDecoder(Layer):
   1937     # pylint: disable=line-too-long
   1938     """Polar5GDecoder(enc_polar, dec_type="SC", list_size=8, num_iter=20,return_crc_status=False, output_dtype=tf.float32, **kwargs)
   1939 
   1940     Wrapper for 5G compliant decoding including rate-recovery and CRC removal.
   1941 
   1942     The class inherits from the Keras layer class and can be used as layer in a
   1943     Keras model.
   1944 
   1945     Parameters
   1946     ----------
   1947         enc_polar: Polar5GEncoder
   1948             Instance of the :class:`~sionna.fec.polar.encoding.Polar5GEncoder`
   1949             used for encoding including rate-matching.
   1950 
   1951         dec_type: str
   1952             Defaults to `"SC"`. Defining the decoder to be used.
   1953             Must be one of the following `{"SC", "SCL", "hybSCL", "BP"}`.
   1954 
   1955         list_size: int
   1956             Defaults to 8. Defining the list size `iff` list-decoding is used.
   1957             Only required for ``dec_types`` `{"SCL", "hybSCL"}`.
   1958 
   1959         num_iter: int
   1960             Defaults to 20. Defining the number of BP iterations. Only required
   1961             for ``dec_type`` `"BP"`.
   1962 
   1963         return_crc_status: bool
   1964             Defaults to False. If True, the decoder additionally returns the
   1965             CRC status indicating if a codeword was (most likely) correctly
   1966             recovered.
   1967 
   1968         output_dtype: tf.DType
   1969             Defaults to tf.float32. Defines the output datatype of the layer
   1970             (internal precision remains tf.float32).
   1971 
   1972     Input
   1973     -----
   1974         inputs: [...,n], tf.float32
   1975             2+D tensor containing the channel logits/llr values.
   1976 
   1977     Output
   1978     ------
   1979 
   1980         b_hat : [...,k], tf.float32
   1981             2+D tensor containing hard-decided estimations of all `k`
   1982             information bits.
   1983 
   1984         crc_status : [...], tf.bool
   1985             CRC status indicating if a codeword was (most likely) correctly
   1986             recovered. This is only returned if ``return_crc_status`` is True.
   1987             Note that false positives are possible.
   1988     Raises
   1989     ------
   1990         AssertionError
   1991             If ``enc_polar`` is not `Polar5GEncoder`.
   1992 
   1993         ValueError
   1994             If ``dec_type`` is not `{"SC", "SCL", "SCL8", "SCL32", "hybSCL",
   1995             "BP"}`.
   1996 
   1997         AssertionError
   1998             If ``dec_type`` is not `str`.
   1999 
   2000         ValueError
   2001             If ``inputs`` is not of shape `[..., n]` or `dtype` is not
   2002             the same as ``output_dtype``.
   2003 
   2004         InvalidArgumentError
   2005             When rank(``inputs``)<2.
   2006 
   2007     Note
   2008     ----
   2009         This layer supports the uplink and downlink Polar rate-matching scheme
   2010         without `codeword segmentation`.
   2011 
   2012         Although the decoding `list size` is not provided by 3GPP
   2013         [3GPPTS38212]_, the consortium has agreed on a `list size` of 8 for the
   2014         5G decoding reference curves [Bioglio_Design]_.
   2015 
   2016         All list-decoders apply `CRC-aided` decoding, however, the non-list
   2017         decoders (`"SC"` and `"BP"`) cannot materialize the CRC leading to an
   2018         effective rate-loss.
   2019 
   2020     """
   2021 
   2022     def __init__(self,
   2023                  enc_polar,
   2024                  dec_type="SC",
   2025                  list_size=8,
   2026                  num_iter=20,
   2027                  return_crc_status=False,
   2028                  output_dtype=tf.float32,
   2029                  **kwargs):
   2030 
   2031         if output_dtype not in (tf.float16, tf.float32, tf.float64):
   2032             raise ValueError(
   2033                 'output_dtype must be {tf.float16, tf.float32, tf.float64}.')
   2034 
   2035         if output_dtype is not tf.float32:
   2036             print('Note: decoder uses tf.float32 for internal calculations.')
   2037         self._output_dtype = output_dtype
   2038 
   2039         super().__init__(dtype=output_dtype, **kwargs)
   2040 
   2041         assert isinstance(enc_polar, Polar5GEncoder), \
   2042                                     "enc_polar must be Polar5GEncoder."
   2043         assert isinstance(dec_type, str), "dec_type must be str."
   2044         # list_size and num_iter are not checked here (done during decoder init)
   2045 
   2046         # Store internal attributes
   2047         self._n_target = enc_polar.n_target
   2048         self._k_target = enc_polar.k_target
   2049         self._n_polar = enc_polar.n_polar
   2050         self._k_polar = enc_polar.k_polar
   2051         self._k_crc = enc_polar.enc_crc.crc_length
   2052         self._bil = enc_polar._channel_type == "uplink"
   2053         self._iil = enc_polar._channel_type == "downlink"
   2054         self._llr_max = 100 # Internal max LLR value (for punctured positions)
   2055         self._enc_polar = enc_polar
   2056         self._dec_type = dec_type
   2057 
   2058         # Initialize the de-interleaver patterns
   2059         self._init_interleavers()
   2060 
   2061         # Initialize decoder
   2062         if dec_type=="SC":
   2063             print("Warning: 5G Polar codes use an integrated CRC that " \
   2064                   "cannot be materialized with SC decoding and, thus, " \
   2065                   "causes a degraded performance. Please consider SCL " \
   2066                   "decoding instead.")
   2067             self._polar_dec = PolarSCDecoder(self._enc_polar.frozen_pos,
   2068                                              self._n_polar)
   2069         elif dec_type=="SCL":
   2070             self._polar_dec = PolarSCLDecoder(self._enc_polar.frozen_pos,
   2071                                 self._n_polar,
   2072                                 crc_degree=self._enc_polar.enc_crc.crc_degree,
   2073                                 list_size=list_size,
   2074                                 ind_iil_inv = self.ind_iil_inv)
   2075         elif dec_type=="hybSCL":
   2076             self._polar_dec = PolarSCLDecoder(self._enc_polar.frozen_pos,
   2077                                 self._n_polar,
   2078                                 crc_degree=self._enc_polar.enc_crc.crc_degree,
   2079                                 list_size=list_size,
   2080                                 use_hybrid_sc=True,
   2081                                 ind_iil_inv = self.ind_iil_inv)
   2082         elif dec_type=="BP":
   2083             print("Warning: 5G Polar codes use an integrated CRC that " \
   2084                   "cannot be materialized with BP decoding and, thus, " \
   2085                   "causes a degraded performance. Please consider SCL " \
   2086                   " decoding instead.")
   2087             assert isinstance(num_iter, int), "num_iter must be int."
   2088             assert num_iter > 0, "num_iter must be positive."
   2089             self._num_iter = num_iter
   2090             self._polar_dec = PolarBPDecoder(self._enc_polar.frozen_pos,
   2091                                              self._n_polar,
   2092                                              num_iter=num_iter,
   2093                                              hard_out=True)
   2094         else:
   2095             raise ValueError("Unknown value for dec_type.")
   2096 
   2097         assert isinstance(return_crc_status, bool), \
   2098                                             "return_crc_status must be bool."
   2099 
   2100         self._return_crc_status = return_crc_status
   2101         if self._return_crc_status: # init crc decoder
   2102             if dec_type in ("SCL", "hybSCL"):
   2103                 # re-use CRC decoder from list decoder
   2104                 self._dec_crc = self._polar_dec._crc_decoder
   2105             else: # init new CRC decoder for BP and SC
   2106                 self._dec_crc = CRCDecoder(self._enc_polar._enc_crc)
   2107 
   2108     #########################################
   2109     # Public methods and properties
   2110     #########################################
   2111 
    @property
    def k_target(self):
        """Number of information bits including rate-matching.

        Taken from ``enc_polar.k_target`` of the associated encoder.
        """
        return self._k_target
   2116 
    @property
    def n_target(self):
        """Codeword length including rate-matching.

        Taken from ``enc_polar.n_target`` of the associated encoder.
        """
        return self._n_target
   2121 
    @property
    def k_polar(self):
        """Number of information bits of mother Polar code.

        Taken from ``enc_polar.k_polar`` of the associated encoder.
        """
        return self._k_polar
   2126 
    @property
    def n_polar(self):
        """Codeword length of mother Polar code.

        Taken from ``enc_polar.n_polar`` of the associated encoder.
        """
        return self._n_polar
   2131 
   2132     @property
   2133     def frozen_pos(self):
   2134         """Frozen positions for Polar decoding."""
   2135         return self._frozen_pos
   2136 
   2137     @property
   2138     def info_pos(self):
   2139         """Information bit positions for Polar encoding."""
   2140         return self._info_pos
   2141 
    @property
    def llr_max(self):
        """Maximum LLR value for internal calculations.

        Fixed to 100 in ``__init__``; its negated value is used to mark
        shortened positions during rate-recovery in ``call``.
        """
        return self._llr_max
   2146 
    @property
    def dec_type(self):
        """Decoder type used for decoding as str.

        One of {"SC", "SCL", "hybSCL", "BP"} as selected in ``__init__``.
        """
        return self._dec_type
   2151 
    @property
    def polar_dec(self):
        """Decoder instance used for decoding.

        A `PolarSCDecoder`, `PolarSCLDecoder` or `PolarBPDecoder`
        depending on ``dec_type``.
        """
        return self._polar_dec
   2156 
    @property
    def output_dtype(self):
        """Output dtype of decoder.

        One of {tf.float16, tf.float32, tf.float64}; internal
        calculations always run in tf.float32.
        """
        return self._output_dtype
   2161 
   2162     #########################
   2163     # Utility methods
   2164     #########################
   2165 
   2166     def _init_interleavers(self):
   2167         """Initialize inverse interleaver patterns for rate-recovery."""
   2168 
   2169         # Channel interleaver
   2170         ind_ch_int = self._enc_polar.channel_interleaver(
   2171                                                 np.arange(self._n_target))
   2172         self.ind_ch_int_inv = np.argsort(ind_ch_int) # Find inverse perm
   2173 
   2174         # Sub-block interleaver
   2175         ind_sub_int = self._enc_polar.subblock_interleaving(
   2176                                                 np.arange(self._n_polar))
   2177         self.ind_sub_int_inv = np.argsort(ind_sub_int) # Find inverse perm
   2178 
   2179         # input bit interleaver
   2180         if self._iil:
   2181             self.ind_iil_inv = np.argsort(self._enc_polar.input_interleaver(
   2182                                                 np.arange(self._k_polar)))
   2183         else:
   2184             self.ind_iil_inv = None
   2185     #########################
   2186     # Keras layer functions
   2187     #########################
   2188 
   2189     def build(self, input_shape):
   2190         """Build and check if shape of input is invalid."""
   2191         assert (input_shape[-1]==self._n_target), "Invalid input shape."
   2192         assert (len(input_shape)>=2), 'Inputs must have at least 2 dimensions.'
   2193 
    def call(self, inputs):
        """Polar decoding and rate-recovery for 5G Polar codes.

        Undoes the 5G NR rate-matching steps (channel de-interleaving,
        de-puncturing / de-shortening / repetition combining, sub-block
        de-interleaving), runs the configured core Polar decoder and
        finally removes (or evaluates) the CRC bits.

        Args:
            inputs (tf.float32): Tensor of shape `[...,n]` containing the
                channel logits/llr values.

        Returns:
            `tf.float32`: Tensor of shape `[...,k]` containing
                hard-decided estimates of all ``k`` information bits.
                If ``return_crc_status`` is True, a second tensor of shape
                `[...]` with the per-codeword CRC status is also returned.

        Raises:
            ValueError: If ``inputs`` is not of shape `[..., n]`
                or `dtype` is not `output_dtype`.

            InvalidArgumentError: When rank(``inputs``)<2.
        """

        tf.debugging.assert_type(inputs, self._output_dtype,
                                 "Invalid input dtype.")
        # internal calculations still in tf.float32
        inputs = tf.cast(inputs, tf.float32)

        # Reshape inputs to [-1, n] so all de-interleaving can operate on
        # axis 1; the original batch shape is restored at the end.
        tf.debugging.assert_greater(tf.rank(inputs), 1)
        input_shape = inputs.shape
        new_shape = [-1, self._n_target]
        llr_ch = tf.reshape(inputs, new_shape)

        # Note: logits are not inverted here; this is done in the decoder itself

        # 1.) Undo channel interleaving (uplink only; see _bil in __init__)
        if self._bil:
            llr_deint = tf.gather(llr_ch, self.ind_ch_int_inv, axis=1)
        else:
            llr_deint = llr_ch

        # 2.) Remove puncturing, shortening, repetition (see Sec. 5.4.1.2)
        # a) Puncturing: set LLRs to 0
        # b) Shortening: set LLRs to infinity
        # c) Repetition: combine LLRs
        if self._n_target >= self._n_polar:
            # Repetition coding
            # Add the last n_rep positions to the first llr positions
            # (llr_3 has exactly n_rep columns, so llr_1+llr_3 is aligned)
            n_rep = self._n_target - self._n_polar
            llr_1 = llr_deint[:,:n_rep]
            llr_2 = llr_deint[:,n_rep:self._n_polar]
            llr_3 = llr_deint[:,self._n_polar:]
            llr_dematched = tf.concat([llr_1+llr_3, llr_2], 1)
        else:
            # Rate threshold 7/16 selects puncturing vs. shortening
            if self._k_polar/self._n_target <= 7/16:
                # Puncturing
                # Append n_polar - n_target "zero" llrs to first positions
                llr_zero = tf.zeros([tf.shape(llr_deint)[0],
                                     self._n_polar-self._n_target])
                llr_dematched = tf.concat([llr_zero, llr_deint], 1)
            else:
                # Shortening
                # Append n_polar - n_target "-infinity" llrs to last positions
                # Remark: we still operate with logits here, thus the neg. sign
                llr_infty = -self._llr_max * tf.ones([tf.shape(llr_deint)[0],
                                                self._n_polar-self._n_target])
                llr_dematched = tf.concat([llr_deint, llr_infty], 1)

        # 3.) Remove subblock interleaving
        llr_dec = tf.gather(llr_dematched, self.ind_sub_int_inv, axis=1)

        # 4.) Run main decoder
        u_hat_crc = self._polar_dec(llr_dec)

        # 5.) Shortening should be implicitly recovered by decoder

        # 6.) Remove input bit interleaving for downlink channels only
        if self._iil:
            u_hat_crc = tf.gather(u_hat_crc, self.ind_iil_inv, axis=1)

        # 7.) Evaluate or remove CRC (and PC)
        if self._return_crc_status:
            # for compatibility with SC/BP, a dedicated CRC decoder is
            # used here (instead of accessing the internal SCL)
            u_hat, crc_status = self._dec_crc(u_hat_crc)
        else: # just remove CRC bits
            # NOTE(review): assumes self._k_crc > 0 — with a zero-length CRC
            # this slice would remove the whole message; confirm the encoder
            # always attaches a CRC.
            u_hat = u_hat_crc[:,:-self._k_crc]

        # And reconstruct input shape
        output_shape = input_shape.as_list()
        output_shape[-1] = self._k_target
        output_shape[0] = -1 # First dim can be dynamic (None)
        u_hat_reshape = tf.reshape(u_hat, output_shape)
        # and cast to output dtype
        u_hat_reshape = tf.cast(u_hat_reshape, dtype=self._output_dtype)

        if self._return_crc_status:
            # reconstruct CRC shape (same batch dims, last axis dropped)
            output_shape.pop() # remove last dimension
            crc_status = tf.reshape(crc_status, output_shape)
            crc_status = tf.cast(crc_status, dtype=self._output_dtype)
            return u_hat_reshape, crc_status

        else:
            return u_hat_reshape