anomaly-detection-material-parameters-calibration

Sionna parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

decoding.py (55273B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Layers for channel decoding and utility functions."""
      6 
      7 import tensorflow as tf
      8 import numpy as np
      9 import scipy as sp # for sparse H matrix computations
     10 from tensorflow.keras.layers import Layer
     11 from sionna.fec.ldpc.encoding import LDPC5GEncoder
     12 from sionna.fec.utils import llr2mi
     13 import matplotlib.pyplot as plt
     14 
     15 class LDPCBPDecoder(Layer):
     16     # pylint: disable=line-too-long
     17     r"""LDPCBPDecoder(pcm, trainable=False, cn_type='boxplus-phi', hard_out=True, track_exit=False, num_iter=20, stateful=False,output_dtype=tf.float32, **kwargs)
     18 
     19     Iterative belief propagation decoder for low-density parity-check (LDPC)
     20     codes and other `codes on graphs`.
     21 
     22     This class defines a generic belief propagation decoder for decoding
     23     with arbitrary parity-check matrices. It can be used to iteratively
     24     estimate/recover the transmitted codeword (or information bits) based on the
     25     LLR-values of the received noisy codeword observation.
     26 
     27     The decoder implements the flooding SPA algorithm [Ryan]_, i.e., all nodes
     28     are updated in a parallel fashion. Different check node update functions are
     29     available
     30 
     31     (1) `boxplus`
     32 
     33         .. math::
     34             y_{j \to i} = 2 \operatorname{tanh}^{-1} \left( \prod_{i' \in \mathcal{N}_(j) \setminus i} \operatorname{tanh} \left( \frac{x_{i' \to j}}{2} \right) \right)
     35 
     36     (2) `boxplus-phi`
     37 
     38         .. math::
     39             y_{j \to i} = \alpha_{j \to i} \cdot \phi \left( \sum_{i' \in \mathcal{N}_(j) \setminus i} \phi \left( |x_{i' \to j}|\right) \right)
     40 
     41         with :math:`\phi(x)=-\operatorname{log}(\operatorname{tanh} \left(\frac{x}{2}) \right)`
     42 
     43     (3) `minsum`
     44 
     45         .. math::
     46             \qquad y_{j \to i} = \alpha_{j \to i} \cdot {min}_{i' \in \mathcal{N}_(j) \setminus i} \left(|x_{i' \to j}|\right)
     47 
     48     where :math:`y_{j \to i}` denotes the message from check node (CN) *j* to
     49     variable node (VN) *i* and :math:`x_{i \to j}` from VN *i* to CN *j*,
     50     respectively. Further, :math:`\mathcal{N}_(j)` denotes all indices of
     51     connected VNs to CN *j* and
     52 
     53     .. math::
     54         \alpha_{j \to i} = \prod_{i' \in \mathcal{N}_(j) \setminus i} \operatorname{sign}(x_{i' \to j})
     55 
     56     is the sign of the outgoing message. For further details we refer to
     57     [Ryan]_.
     58 
     59     Note that for full 5G 3GPP NR compatibility, the correct puncturing and
     60     shortening patterns must be applied (cf. [Richardson]_ for details), this
     61     can be done by :class:`~sionna.fec.ldpc.decoding.LDPC5GEncoder` and
     62     :class:`~sionna.fec.ldpc.decoding.LDPC5GDecoder`, respectively.
     63 
     64     If required, the decoder can be made trainable and is fully differentiable
     65     by following the concept of `weighted BP` [Nachmani]_ as shown in Fig. 1
     66     leading to
     67 
     68     .. math::
     69         y_{j \to i} = 2 \operatorname{tanh}^{-1} \left( \prod_{i' \in \mathcal{N}_(j) \setminus i} \operatorname{tanh} \left( \frac{\textcolor{red}{w_{i' \to j}} \cdot x_{i' \to j}}{2} \right) \right)
     70 
     71     where :math:`w_{i \to j}` denotes the trainable weight of message :math:`x_{i \to j}`.
     72     Please note that the training of some check node types may be not supported.
     73 
     74     ..  figure:: ../figures/weighted_bp.png
     75 
     76         Fig. 1: Weighted BP as proposed in [Nachmani]_.
     77 
     78     For numerical stability, the decoder applies LLR clipping of
     79     +/- 20 to the input LLRs.
     80 
     81     The class inherits from the Keras layer class and can be used as layer in a
     82     Keras model.
     83 
     84     Parameters
     85     ----------
     86         pcm: ndarray
     87             An ndarray of shape `[n-k, n]` defining the parity-check matrix
     88             consisting only of `0` or `1` entries. Can be also of type `scipy.
     89             sparse.csr_matrix` or `scipy.sparse.csc_matrix`.
     90 
     91         trainable: bool
     92             Defaults to False. If True, every outgoing variable node message is
     93             scaled with a trainable scalar.
     94 
     95         cn_type: str
     96             A string defaults to '"boxplus-phi"'. One of
     97             {`"boxplus"`, `"boxplus-phi"`, `"minsum"`} where
     98             '"boxplus"' implements the single-parity-check APP decoding rule.
     99             '"boxplus-phi"' implements the numerically more stable version of
    100             boxplus [Ryan]_.
    101             '"minsum"' implements the min-approximation of the CN
    102             update rule [Ryan]_.
    103 
    104         hard_out: bool
    105             Defaults to True. If True, the decoder provides hard-decided
    106             codeword bits instead of soft-values.
    107 
    108         track_exit: bool
    109             Defaults to False. If True, the decoder tracks EXIT
    110             characteristics. Note that this requires the all-zero
    111             CW as input.
    112 
    113         num_iter: int
     114             Defining the number of decoder iterations (no early stopping used at
    115             the moment!).
    116 
    117         stateful: bool
    118             Defaults to False. If True, the internal VN messages ``msg_vn``
    119             from the last decoding iteration are returned, and ``msg_vn`` or
    120             `None` needs to be given as a second input when calling the decoder.
    121             This is required for iterative demapping and decoding.
    122 
    123         output_dtype: tf.DType
    124             Defaults to tf.float32. Defines the output datatype of the layer
    125             (internal precision remains tf.float32).
    126 
    127     Input
    128     -----
    129     llrs_ch or (llrs_ch, msg_vn):
    130         Tensor or Tuple (only required if ``stateful`` is True):
    131 
    132     llrs_ch: [...,n], tf.float32
    133         2+D tensor containing the channel logits/llr values.
    134 
    135     msg_vn: None or RaggedTensor, tf.float32
    136         Ragged tensor of VN messages.
    137         Required only if ``stateful`` is True.
    138 
    139     Output
    140     ------
    141         : [...,n], tf.float32
    142             2+D Tensor of same shape as ``inputs`` containing
    143             bit-wise soft-estimates (or hard-decided bit-values) of all
    144             codeword bits.
    145 
    146         : RaggedTensor, tf.float32:
    147             Tensor of VN messages.
    148             Returned only if ``stateful`` is set to True.
    149 
    150     Attributes
    151     ----------
    152         pcm: ndarray
    153             An ndarray of shape `[n-k, n]` defining the parity-check matrix
    154             consisting only of `0` or `1` entries. Can be also of type `scipy.
    155             sparse.csr_matrix` or `scipy.sparse.csc_matrix`.
    156 
    157         num_cns: int
    158             Defining the number of check nodes.
    159 
    160         num_vns: int
    161             Defining the number of variable nodes.
    162 
    163         num_edges: int
    164             Defining the total number of edges.
    165 
    166         trainable: bool
    167             If True, the decoder uses trainable weights.
    168 
    169         _atanh_clip_value: float
    170             Defining the internal clipping value before the atanh is applied
    171             (relates to the CN update).
    172 
    173         _cn_type: str
    174             Defining the CN update function type.
    175 
    176         _cn_update:
    177             A function defining the CN update.
    178 
    179         _hard_out: bool
    180             If True, the decoder outputs hard-decided bits.
    181 
    182         _cn_con: ndarray
    183             An ndarray of shape `[num_edges]` defining all edges from check
    184             node perspective.
    185 
    186         _vn_con: ndarray
    187             An ndarray of shape `[num_edges]` defining all edges from variable
    188             node perspective.
    189 
    190         _vn_mask_tf: tf.float32
    191             A ragged Tensor of shape `[num_vns, None]` defining the incoming
    192             message indices per VN. The second dimension is ragged and depends
    193             on the node degree.
    194 
    195         _cn_mask_tf: tf.float32
    196             A ragged Tensor of shape `[num_cns, None]` defining the incoming
    197             message indices per CN. The second dimension is ragged and depends
    198             on the node degree.
    199 
    200         _ind_cn: ndarray
    201             An ndarray of shape `[num_edges]` defining the permutation index to
    202             rearrange messages from variable into check node perspective.
    203 
    204         _ind_cn_inv: ndarray
    205             An ndarray of shape `[num_edges]` defining the permutation index to
    206             rearrange messages from check into variable node perspective.
    207 
    208         _vn_row_splits: ndarray
    209             An ndarray of shape `[num_vns+1]` defining the row split positions
    210             of a 1D vector consisting of all edges messages. Used to build a
    211             ragged Tensor of incoming VN messages.
    212 
    213         _cn_row_splits: ndarray
    214             An ndarray of shape `[num_cns+1]` defining the row split positions
    215             of a 1D vector consisting of all edges messages. Used to build a
    216             ragged Tensor of incoming CN messages.
    217 
    218         _edge_weights: tf.float32
    219             A Tensor of shape `[num_edges]` defining a (trainable) weight per
    220             outgoing VN message.
    221 
    222     Raises:
    223         ValueError
    224             If the shape of ``pcm`` is invalid or contains other values than
    225             `0` or `1` or dtype is not `tf.float32`.
    226 
    227         ValueError
    228             If ``num_iter`` is not an integer greater (or equal) `0`.
    229 
    230         ValueError
    231             If ``output_dtype`` is not
    232             {tf.float16, tf.float32, tf.float64}.
    233 
    234         ValueError
    235             If ``inputs`` is not of shape `[batch_size, n]`.
    236 
    237         InvalidArgumentError
    238             When rank(``inputs``)<2.
    239     Note
    240     ----
    241         As decoding input logits
    242         :math:`\operatorname{log} \frac{p(x=1)}{p(x=0)}` are
    243         assumed for compatibility with the learning framework, but internally
    244         log-likelihood ratios (LLRs) with definition :math:`\operatorname{log} \frac{p(x=0)}{p(x=1)}` are used.
    245 
    246         The decoder is not (particularly) optimized for quasi-cyclic (QC) LDPC
    247         codes and, thus, supports arbitrary parity-check matrices.
    248 
    249         The decoder is implemented by using '"ragged Tensors"' [TF_ragged]_ to
    250         account for arbitrary node degrees. To avoid a performance degradation
    251         caused by a severe indexing overhead, the batch-dimension is shifted to
    252         the last dimension during decoding.
    253 
    254         If the decoder is made trainable [Nachmani]_, for performance
    255         improvements only variable to check node messages are scaled as the VN
    256         operation is linear and, thus, would not increase the expressive power
    257         of the weights.
    258 
    259     """
    260 
    def __init__(self,
                 pcm,
                 trainable=False,
                 cn_type='boxplus-phi',
                 hard_out=True,
                 track_exit=False,
                 num_iter=20,
                 stateful=False,
                 output_dtype=tf.float32,
                 **kwargs):
        """Validate the inputs and precompute the Tanner-graph structure.

        Derives the edge lists from the parity-check matrix, the
        permutation indices between VN and CN perspective, the ragged
        row-split positions, and selects the CN update function.
        """
        super().__init__(dtype=output_dtype, **kwargs)

        # NOTE: assert-based validation is stripped under `python -O`
        assert isinstance(trainable, bool), 'trainable must be bool.'
        assert isinstance(hard_out, bool), 'hard_out must be bool.'
        assert isinstance(track_exit, bool), 'track_exit must be bool.'
        assert isinstance(cn_type, str) , 'cn_type must be str.'
        assert isinstance(num_iter, int), 'num_iter must be int.'
        assert num_iter>=0, 'num_iter cannot be negative.'
        assert isinstance(stateful, bool), 'stateful must be bool.'
        assert isinstance(output_dtype, tf.DType), \
                                'output_dtype must be tf.Dtype.'

        # pcm must be binary; only dense ndarray and csr/csc sparse
        # matrices are supported (other sparse formats raise TypeError)
        if isinstance(pcm, np.ndarray):
            assert np.array_equal(pcm, pcm.astype(bool)), 'PC matrix \
                must be binary.'
        elif isinstance(pcm, sp.sparse.csr_matrix):
            assert np.array_equal(pcm.data, pcm.data.astype(bool)), \
                'PC matrix must be binary.'
        elif isinstance(pcm, sp.sparse.csc_matrix):
            assert np.array_equal(pcm.data, pcm.data.astype(bool)), \
                'PC matrix must be binary.'
        else:
            raise TypeError("Unsupported dtype of pcm.")

        if output_dtype not in (tf.float16, tf.float32, tf.float64):
            raise ValueError(
                'output_dtype must be {tf.float16, tf.float32, tf.float64}.')

        if output_dtype is not tf.float32:
            print('Note: decoder uses tf.float32 for internal calculations.')

        # init decoder parameters
        self._pcm = pcm
        self._trainable = trainable
        self._cn_type = cn_type
        self._hard_out = hard_out
        self._track_exit = track_exit
        # stored as tf.constant so it is usable inside tf.function graphs
        self._num_iter = tf.constant(num_iter, dtype=tf.int32)
        self._stateful = stateful
        self._output_dtype = output_dtype

        # clipping value for the atanh function is applied (tf.float32 is used)
        self._atanh_clip_value = 1 - 1e-7
        # internal value for llr clipping
        self._llr_max = tf.constant(20., tf.float32)

        # init code parameters
        self._num_cns = pcm.shape[0] # total number of check nodes
        self._num_vns = pcm.shape[1] # total number of variable nodes

        # make pcm sparse first if ndarray is provided
        if isinstance(pcm, np.ndarray):
            pcm = sp.sparse.csr_matrix(pcm)

        # find all edges from variable and check node perspective
        self._cn_con, self._vn_con, _ = sp.sparse.find(pcm)

        # sort indices explicitly, as scipy.sparse.find changed from column to
        # row sorting in scipy>=1.11
        idx = np.argsort(self._vn_con)
        self._cn_con = self._cn_con[idx]
        self._vn_con = self._vn_con[idx]

        # number of edges equals number of non-zero elements in the
        # parity-check matrix
        self._num_edges = len(self._vn_con)

        # permutation index to rearrange messages into check node perspective
        self._ind_cn = np.argsort(self._cn_con)

        # inverse permutation index to rearrange messages back into variable
        # node perspective
        self._ind_cn_inv = np.argsort(self._ind_cn)

        # generate row masks (array of integers defining the row split pos.)
        self._vn_row_splits = self._gen_node_mask_row(self._vn_con)
        self._cn_row_splits = self._gen_node_mask_row(
                                                    self._cn_con[self._ind_cn])
        # pre-load the CN function for performance reasons
        if self._cn_type=='boxplus':
            # check node update using the tanh function
            self._cn_update = self._cn_update_tanh
        elif self._cn_type=='boxplus-phi':
            # check node update using the "_phi" function
            self._cn_update = self._cn_update_phi
        elif self._cn_type=='minsum':
            # check node update using the min-sum approximation
            self._cn_update = self._cn_update_minsum
        else:
            raise ValueError('Unknown node type.')

        # init trainable weights if needed
        self._has_weights = False  # indicates if trainable weights exist
        if self._trainable:
            # one trainable scalar per edge, initialized to 1 (neutral)
            self._has_weights = True
            self._edge_weights = tf.Variable(tf.ones(self._num_edges),
                                             trainable=self._trainable,
                                             dtype=tf.float32)

        # track mutual information during decoding
        self._ie_c = 0
        self._ie_v = 0
    374 
    #########################################
    # Public methods and properties
    #########################################

    @property
    def pcm(self):
        """Parity-check matrix of LDPC code."""
        return self._pcm

    @property
    def num_cns(self):
        """Number of check nodes."""
        return self._num_cns

    @property
    def num_vns(self):
        """Number of variable nodes."""
        return self._num_vns

    @property
    def num_edges(self):
        """Number of edges in decoding graph."""
        return self._num_edges

    @property
    def has_weights(self):
        """Indicates if decoder has trainable weights."""
        return self._has_weights

    @property
    def edge_weights(self):
        """Trainable weights of the BP decoder.

        Returns an empty list if the decoder has no trainable weights.
        """
        if not self._has_weights:
            return []
        else:
            return self._edge_weights

    @property
    def output_dtype(self):
        """Output dtype of decoder."""
        return self._output_dtype

    @property
    def ie_c(self):
        "Extrinsic mutual information at check node."
        return self._ie_c

    @property
    def ie_v(self):
        "Extrinsic mutual information at variable node."
        return self._ie_v

    @property
    def num_iter(self):
        "Number of decoding iterations."
        return self._num_iter

    @num_iter.setter
    def num_iter(self, num_iter):
        "Set the number of decoding iterations."
        # NOTE: assert-based validation is stripped under `python -O`
        assert isinstance(num_iter, int), 'num_iter must be int.'
        assert num_iter>=0, 'num_iter cannot be negative.'
        # stored as tf.constant so it is usable inside tf.function graphs
        self._num_iter = tf.constant(num_iter, dtype=tf.int32)

    @property
    def llr_max(self):
        """Max LLR value used for internal calculations and rate-matching."""
        return self._llr_max

    @llr_max.setter
    def llr_max(self, value):
        """Set the max LLR value used internally and for rate-matching."""
        assert value>=0, 'llr_max cannot be negative.'
        self._llr_max = tf.cast(value, dtype=tf.float32)
    449 
    def show_weights(self, size=7):
        """Show histogram of trainable weights.

        Prints a message instead if the decoder has no trainable weights.

        Input
        -----
            size: float
                Figure size of the matplotlib figure.

        """
        # guard clause: nothing to plot without trainable weights
        if not self._has_weights:
            print("No weights to show.")
            return

        weight_vals = self._edge_weights.numpy()

        plt.figure(figsize=(size, size))
        plt.hist(weight_vals, density=True, bins=20, align='mid')
        plt.xlabel('weight value')
        plt.ylabel('density')
        plt.grid(True, which='both', axis='both')
        plt.title('Weight Distribution')
    471 
    472     #########################
    473     # Utility methods
    474     #########################
    475 
    def _gen_node_mask(self, con):
        """Group message (edge) indices by node index.

        Returns a nested list whose entry ``k`` contains the indices of all
        messages belonging to node ``k``. Assumes consecutive node indices
        in ``con``, i.e., every node has degree >= 1 — TODO confirm for
        parity-check matrices containing all-zero rows/columns.
        """
        order = np.argsort(con)
        con_sorted = con[order]

        node_mask = []
        node_idx = 0
        cur_mask = []
        for pos, msg_idx in enumerate(order[:self._num_edges]):
            if con_sorted[pos] == node_idx:
                # message belongs to the current node
                cur_mask.append(msg_idx)
            else:
                # a new node starts; store the finished group
                node_mask.append(cur_mask)
                cur_mask = [msg_idx]
                node_idx += 1
        node_mask.append(cur_mask)
        return node_mask
    496 
    def _gen_node_mask_row(self, con):
        """Compute row-split positions for a flat vector of edge messages.

        ``con`` contains the (sorted, ascending) node index of each edge.
        The returned list starts at 0, contains the start position of each
        node's messages and ends with ``num_edges`` as final delimiter, as
        required to build a ragged Tensor of incoming node messages.

        Fix: the original ``if`` advanced ``cur_node`` by at most one per
        edge and therefore emitted too few split positions whenever the
        node index jumped by more than one (i.e., an interior zero-degree
        node). The ``while`` below emits one (empty) row per skipped node.
        For matrices where every node has degree >= 1 the output is
        unchanged. NOTE(review): trailing zero-degree nodes cannot be
        detected here because the total node count is not passed in.
        """
        node_mask = [0] # the first element indicates the first node index (=0)

        cur_node = 0
        for i in range(self._num_edges):
            # emit a split for every node boundary crossed (a jump of k
            # creates k splits, i.e., k-1 empty rows for degree-0 nodes)
            while con[i] != cur_node:
                node_mask.append(i)
                cur_node += 1
        node_mask.append(self._num_edges) # last element must be the number of
        # elements (delimiter)
        return node_mask
    513 
    def _vn_update(self, msg, llr_ch):
        """ Variable node update function.

        This function implements the (extrinsic) variable node update
        function. It takes the sum over all incoming messages ``msg`` excluding
        the intrinsic (= outgoing) message itself.

        Additionally, the channel LLR ``llr_ch`` is added to each message.

        ``msg`` is a ragged tensor of per-VN incoming messages (one row per
        VN; presumably shape `[num_vns, None, batch_size]` — verify against
        the caller) and ``llr_ch`` holds one channel LLR per VN.
        """
        # aggregate all incoming messages per node
        x = tf.reduce_sum(msg, axis=1)
        x = tf.add(x, llr_ch)

        # TF2.9 does not support XLA for the addition of ragged tensors
        # the following code provides a workaround that supports XLA

        # subtract extrinsic message from node value
        # x = tf.expand_dims(x, axis=1)
        # x = tf.add(-msg, x)
        # flat-value formulation: for each edge, gather its VN's total
        # (via value_rowids) and subtract the edge's own incoming message,
        # which yields the extrinsic outgoing message
        x = tf.ragged.map_flat_values(lambda x, y, row_ind :
                                      x + tf.gather(y, row_ind),
                                      -1.*msg,
                                      x,
                                      msg.value_rowids())
        return x
    539 
    def _where_ragged(self, msg):
        """Replace exact-zero entries of ``msg`` by a small epsilon (1e-12).

        Must be called via `map_flat_values` on the flat values of a ragged
        tensor; protects the subsequent product/division against zeros."""
        eps = tf.ones_like(msg) * 1e-12
        return tf.where(tf.equal(msg, 0), eps, msg)
    544 
    def _where_ragged_inv(self, msg):
        """Map near-zero entries of ``msg`` (|x| < 1e-7) back to exact `0`.

        Inverse of `_where_ragged`; must be called via `map_flat_values`."""
        is_small = tf.less(tf.abs(msg), 1e-7)
        return tf.where(is_small, tf.zeros_like(msg), msg)
    552 
    def _cn_update_tanh(self, msg):
        """Check node update function implementing the exact boxplus operation.

        This function implements the (extrinsic) check node update
        function. It calculates the boxplus function over all incoming messages
        "msg" excluding the intrinsic (=outgoing) message itself.
        The exact boxplus function is implemented by using the tanh function.

        The input is expected to be a ragged Tensor of shape
        `[num_cns, None, batch_size]`.

        Note that for numerical stability clipping is applied.
        """

        msg = msg / 2
        # tanh is not overloaded for ragged tensors
        msg = tf.ragged.map_flat_values(tf.tanh, msg) # tanh is not overloaded

        # for ragged tensors; map to flat tensor first
        # replace exact zeros by 1e-12 so the per-row product below stays
        # non-zero and the per-edge division (via **-1) stays finite
        msg = tf.ragged.map_flat_values(self._where_ragged, msg)

        # product of all tanh-values per check node
        msg_prod = tf.reduce_prod(msg, axis=1)

        # TF2.9 does not support XLA for the multiplication of ragged tensors
        # the following code provides a workaround that supports XLA

        # ^-1 to avoid division
        # Note this is (potentially) numerically unstable
        # msg = msg**-1 * tf.expand_dims(msg_prod, axis=1) # remove own edge

        # flat-value formulation: divide each CN's product (gathered via
        # value_rowids) by the edge's own contribution -> extrinsic product
        msg = tf.ragged.map_flat_values(lambda x, y, row_ind :
                                        x * tf.gather(y, row_ind),
                                        msg**-1,
                                        msg_prod,
                                        msg.value_rowids())

        # Overwrite small (numerical zeros) message values with exact zero
        # these are introduced by the previous "_where_ragged" operation
        # this is required to keep the product stable (cf. _phi_update for log
        # sum implementation)
        msg = tf.ragged.map_flat_values(self._where_ragged_inv, msg)

        # clip to (-1, 1) open interval before atanh to avoid +/- inf
        msg = tf.clip_by_value(msg,
                               clip_value_min=-self._atanh_clip_value,
                               clip_value_max=self._atanh_clip_value)

        # atanh is not overloaded for ragged tensors
        msg = 2 * tf.ragged.map_flat_values(tf.atanh, msg)
        return msg
    602 
    def _phi(self, x):
        """Element-wise `"_phi"` function as defined in [Ryan]_.

        Computes :math:`\\phi(x) = -\\log(\\tanh(x/2))
        = \\log((e^x+1)/(e^x-1))` in a numerically stable way; used by the
        check node update. The clipping bounds are tuned for tf.float32.
        """
        x = tf.clip_by_value(x, clip_value_min=8.5e-8, clip_value_max=16.635532)
        # evaluate exp(x) once and take the difference of the two logs
        ex = tf.math.exp(x)
        return tf.math.log(ex + 1) - tf.math.log(ex - 1)
    612 
    def _cn_update_phi(self, msg):
        """Check node update function implementing the exact boxplus operation.

        This function implements the (extrinsic) check node update function
        based on the numerically more stable `"_phi"` function (cf. [Ryan]_).
        It calculates the boxplus function over all incoming messages ``msg``
        excluding the intrinsic (=outgoing) message itself.
        The exact boxplus function is implemented by using the `"_phi"` function
        as in [Ryan]_.

        The input is expected to be a ragged Tensor of shape
        `[num_cns, None, batch_size]`.

        Note that for numerical stability clipping is applied.
        """

        sign_val = tf.sign(msg)

        # TF2.14 does not support XLA for tf.where and ragged tensors in
        # CPU mode. The following code provides a workaround that supports XLA
        # sign_val = tf.where(tf.equal(sign_val, 0),
        #                    tf.ones_like(sign_val),
        #                    sign_val)
        # map sign 0 (from exact-zero messages) to +1
        sign_val = tf.ragged.map_flat_values(lambda x :
                                             tf.where(tf.equal(x, 0),
                                             tf.ones_like(x),x),
                                             sign_val)

        # product of all signs per check node
        sign_node = tf.reduce_prod(sign_val, axis=1)

        # TF2.9 does not support XLA for the multiplication of ragged tensors
        # the following code provides a workaround that supports XLA

        # sign_val = sign_val * tf.expand_dims(sign_node, axis=1)
        # extrinsic sign per edge: node sign (gathered via value_rowids)
        # times the edge's own sign (sign^2 = 1 removes the own edge)
        sign_val = tf.ragged.map_flat_values(lambda x, y, row_ind :
                                             x * tf.gather(y, row_ind),
                                             sign_val,
                                             sign_node,
                                             sign_val.value_rowids())

        msg = tf.ragged.map_flat_values(tf.abs, msg) # remove sign

        # apply _phi element-wise (does not support ragged Tensors)
        msg = tf.ragged.map_flat_values(self._phi, msg)
        # sum of phi-values per check node
        msg_sum = tf.reduce_sum(msg, axis=1)

        # TF2.9 does not support XLA for the addition of ragged tensors
        # the following code provides a workaround that supports XLA

        # msg = tf.add( -msg, tf.expand_dims(msg_sum, axis=1)) # remove own edge
        # extrinsic sum per edge: node sum minus the edge's own phi-value
        msg = tf.ragged.map_flat_values(lambda x, y, row_ind :
                                        x + tf.gather(y, row_ind),
                                        -1.*msg,
                                        msg_sum,
                                        msg.value_rowids())

        # apply _phi element-wise (does not support ragged Tensors)
        # gradient of the sign is stopped; only the magnitude path is trained
        msg = self._stop_ragged_gradient(sign_val) * tf.ragged.map_flat_values(
                                                            self._phi, msg)
        return msg
    673 
    def _stop_ragged_gradient(self, rt):
        """Stop the gradient of a ragged tensor ``rt``.

        Workaround, as TF 2.5 does not support gradient stopping on ragged
        tensors directly; operates on the flat values instead."""
        detached = tf.stop_gradient(rt.flat_values)
        return rt.with_flat_values(detached)
    678 
    def _sign_val_minsum(self, msg):
        """Return the sign of ``msg`` with sign `0` mapped to `+1`.

        Used during min-sum decoding; must be called via
        `map_flat_values`."""
        s = tf.sign(msg)
        return tf.where(tf.equal(s, 0), tf.ones_like(s), s)
    688 
    def _cn_update_minsum(self, msg):
        """ Check node update function implementing the min-sum approximation.

        This function approximates the (extrinsic) check node update
        function based on the min-sum approximation (cf. [Ryan]_).
        It calculates the "extrinsic" min function over all incoming messages
        ``msg`` excluding the intrinsic (=outgoing) message itself.

        The input is expected to be a ragged Tensor of shape
        `[num_vns, None, batch_size]`.
        NOTE(review): each ragged row holds the incoming messages of one
        check node (cf. the CN-perspective gather in ``call``), so the
        first dimension is presumably `num_cns` — verify.
        """

        # a constant used to overwrite the first min
        # (must be much larger than the clipped LLR magnitude below)
        LARGE_VAL = 10000. # pylint: disable=invalid-name

        # clip values for numerical stability
        msg = tf.clip_by_value(msg,
                               clip_value_min=-self._llr_max,
                               clip_value_max=self._llr_max)

        # calculate sign of outgoing msg and the node
        sign_val = tf.ragged.map_flat_values(self._sign_val_minsum, msg)
        sign_node = tf.reduce_prod(sign_val, axis=1)

        # TF2.9 does not support XLA for the multiplication of ragged tensors
        # the following code provides a workaround that supports XLA

        # sign_val = self._stop_ragged_gradient(sign_val) \
        #             * tf.expand_dims(sign_node, axis=1)
        sign_val = tf.ragged.map_flat_values(
                                        lambda x, y, row_ind:
                                        tf.multiply(x, tf.gather(y, row_ind)),
                                        self._stop_ragged_gradient(sign_val),
                                        sign_node,
                                        sign_val.value_rowids())

        # remove sign from messages
        msg = tf.ragged.map_flat_values(tf.abs, msg)

        # Calculate the extrinsic minimum per CN, i.e., for each message of
        # index i, find the smallest and the second smallest value.
        # However, in some cases the second smallest value may equal the
        # smallest value (multiplicity of mins).
        # Please note that this needs to be applied to raggedTensors, e.g.,
        # tf.top_k() is currently not supported and all ops must support graph
        # and XLA mode.

        # find min_value per node
        min_val = tf.reduce_min(msg, axis=1, keepdims=True)

        # TF2.9 does not support XLA for the subtraction of ragged tensors
        # the following code provides a workaround that supports XLA

        # and subtract min; the new array contains zero at the min positions
        # benefits from broadcasting; all other values are positive
        msg_min1 = tf.ragged.map_flat_values(lambda x, y, row_ind:
                                             x - tf.gather(y, row_ind),
                                             msg,
                                             tf.squeeze(min_val, axis=1),
                                             msg.value_rowids())

        # replace 0 (=min positions) with large value to ignore it for further
        # min calculations
        msg = tf.ragged.map_flat_values(
                            lambda x: tf.where(tf.equal(x, 0), LARGE_VAL, x),
                            msg_min1)

        # find the second smallest element (we add min_val as this has been
        # subtracted before)
        min_val_2 = tf.reduce_min(msg, axis=1, keepdims=True) + min_val

        # Detect duplicated minima (i.e., min_val occurs at two incoming
        # messages). As the LLRs per node are <LLR_MAX and we have
        # replace at least 1 position (position with message "min_val") by
        # LARGE_VAL, it holds for the sum < LARGE_VAL + node_degree*LLR_MAX.
        # If the sum > 2*LARGE_VAL, the multiplicity of the min is at least 2.
        node_sum = tf.reduce_sum(msg, axis=1, keepdims=True) - (2*LARGE_VAL-1.)
        # indicator that duplicated min was detected (per node)
        double_min = 0.5*(1-tf.sign(node_sum))

        # if a duplicate min occurred, both edges must have min_val, otherwise
        # the second smallest value is taken
        min_val_e = (1-double_min) * min_val + (double_min) * min_val_2

        # replace all values with min_val except the position where the min
        # occurred (=extrinsic min).

        # no XLA support for TF 2.15
        # msg_e = tf.where(msg==LARGE_VAL, min_val_e, min_val)

        # broadcast the per-node minima back to the flat per-edge layout
        min_1 = tf.squeeze(tf.gather(min_val, msg.value_rowids()), axis=1)
        min_e = tf.squeeze(tf.gather(min_val_e, msg.value_rowids()), axis=1)
        # edges that held the min (now LARGE_VAL) get the extrinsic min;
        # all other edges get the overall min of their node
        msg_e = tf.ragged.map_flat_values(
                        lambda x: tf.where(x==LARGE_VAL, min_e, min_1),
                        msg)

        # it seems like tf.where does not set the shape of tf.ragged properly
        # we need to ensure the shape manually
        msg_e = tf.ragged.map_flat_values(
                        lambda x: tf.ensure_shape(x, msg.flat_values.shape),
                        msg_e)

        # TF2.9 does not support XLA for the multiplication of ragged tensors
        # the following code provides a workaround that supports XLA

        # and apply sign
        #msg = sign_val * msg_e
        msg = tf.ragged.map_flat_values(tf.multiply,
                                        sign_val,
                                        msg_e)

        return msg
    801 
    def _mult_weights(self, x):
        """Scale all edge messages with the trainable weights (weighted BP)."""
        # move the edge dimension last so the weight vector broadcasts
        x_t = tf.transpose(x, (1, 0))
        weighted = tf.math.multiply(x_t, self._edge_weights)
        # restore the original [num_edges, batch] layout
        return tf.transpose(weighted, (1, 0))
    809 
    810     #########################
    811     # Keras layer functions
    812     #########################
    813 
    814     def build(self, input_shape):
    815         # Raise AssertionError if shape of x is invalid
    816         if self._stateful:
    817             assert(len(input_shape)==2), \
    818                 "For stateful decoding, a tuple of two inputs is expected."
    819             input_shape = input_shape[0]
    820 
    821         assert (input_shape[-1]==self._num_vns), \
    822                             'Last dimension must be of length n.'
    823         assert (len(input_shape)>=2), 'The inputs must have at least rank 2.'
    824 
    def call(self, inputs):
        """Iterative BP decoding function.

        This function performs ``num_iter`` belief propagation decoding
        iterations and returns the estimated codeword.

        Args:
        llr_ch or (llr_ch, msg_vn):

            llr_ch (tf.float32): Tensor of shape `[...,n]` containing the
                channel logits/llr values.

            msg_vn (tf.float32) : Ragged tensor containing the VN
                messages, or None. Required if ``stateful`` is set to True.

        Returns:
            `tf.float32`: Tensor of shape `[...,n]` containing
            bit-wise soft-estimates (or hard-decided bit-values) of all
            codeword bits. If ``stateful`` is True, a tuple
            ``(x_out, msg_vn)`` is returned where ``msg_vn`` is the final
            ragged tensor of VN messages (decoder state).

        Raises:
            ValueError: If ``inputs`` is not of shape `[batch_size, n]`.

            InvalidArgumentError: When rank(``inputs``)<2.
        """

        # Extract inputs; in stateful mode the previous decoder state
        # (VN messages) is provided alongside the channel LLRs
        if self._stateful:
            llr_ch, msg_vn = inputs
        else:
            llr_ch = inputs

        tf.debugging.assert_type(llr_ch, self.dtype, 'Invalid input dtype.')

        # internal calculations still in tf.float32
        llr_ch = tf.cast(llr_ch, tf.float32)

        # clip llrs for numerical stability
        llr_ch = tf.clip_by_value(llr_ch,
                                  clip_value_min=-self._llr_max,
                                  clip_value_max=self._llr_max)

        # last dim must be of length n
        tf.debugging.assert_equal(tf.shape(llr_ch)[-1],
                                  self._num_vns,
                                  'Last dimension must be of length n.')

        # flatten all leading dimensions into a single batch dimension
        llr_ch_shape = llr_ch.get_shape().as_list()
        new_shape = [-1, self._num_vns]
        llr_ch_reshaped = tf.reshape(llr_ch, new_shape)

        # must be done during call, as XLA fails otherwise due to ragged
        # indices placed on the CPU device.
        # create permutation index from cn perspective
        self._cn_mask_tf = tf.ragged.constant(self._gen_node_mask(self._cn_con),
                                              row_splits_dtype=tf.int32)

        # batch dimension is last dimension due to ragged tensor representation
        llr_ch = tf.transpose(llr_ch_reshaped, (1,0))

        llr_ch = -1. * llr_ch # logits are converted into "true" llrs

        # init internal decoder state if not explicitly
        # provided (e.g., required to restore decoder state for iterative
        # detection and decoding)
        # load internal state from previous iteration
        # required for iterative det./dec.
        if not self._stateful or msg_vn is None:
            # fresh state: one all-zero message per edge and batch sample
            msg_shape = tf.stack([tf.constant(self._num_edges),
                                   tf.shape(llr_ch)[1]],
                                   axis=0)
            msg_vn = tf.zeros(msg_shape, dtype=tf.float32)
        else:
            msg_vn = msg_vn.flat_values

        # track exit decoding trajectory; requires all-zero cw?
        if self._track_exit:
            self._ie_c = tf.zeros(self._num_iter+1)
            self._ie_v = tf.zeros(self._num_iter+1)

        # perform one decoding iteration
        # Remark: msg_vn cannot be ragged as input for tf.while_loop as
        # otherwise XLA will not be supported (with TF 2.5)
        def dec_iter(llr_ch, msg_vn, it):
            """One flooding iteration: VN update, CN update, re-permute."""
            it += 1

            # restore ragged structure (one row per VN) from flat messages
            msg_vn = tf.RaggedTensor.from_row_splits(
                        values=msg_vn,
                        row_splits=tf.constant(self._vn_row_splits, tf.int32))
            # variable node update
            msg_vn = self._vn_update(msg_vn, llr_ch)

            # track exit decoding trajectory; requires all-zero cw
            if self._track_exit:
                # neg values as different llr def is expected
                mi = llr2mi(-1. * msg_vn.flat_values)
                self._ie_v = tf.tensor_scatter_nd_add(self._ie_v,
                                                     tf.reshape(it, (1, 1)),
                                                     tf.reshape(mi, (1)))

            # scale outgoing vn messages (weighted BP); only if activated
            if self._has_weights:
                msg_vn = tf.ragged.map_flat_values(self._mult_weights,
                                                   msg_vn)
            # permute edges into CN perspective
            msg_cn = tf.gather(msg_vn.flat_values, self._cn_mask_tf, axis=None)

            # check node update using the pre-defined function
            msg_cn = self._cn_update(msg_cn)

            # track exit decoding trajectory; requires all-zero cw?
            if self._track_exit:
                # neg values as different llr def is expected
                mi = llr2mi(-1.*msg_cn.flat_values)
                # update pos i+1 such that first iter is stored as 0
                self._ie_c = tf.tensor_scatter_nd_add(self._ie_c,
                                                     tf.reshape(it, (1, 1)),
                                                     tf.reshape(mi, (1)))

            # re-permute edges to variable node perspective
            msg_vn = tf.gather(msg_cn.flat_values, self._ind_cn_inv, axis=None)
            return llr_ch, msg_vn, it

        # stopping condition (required for tf.while_loop)
        def dec_stop(llr_ch, msg_vn, it): # pylint: disable=W0613
            """Continue while the iteration counter is below num_iter."""
            return tf.less(it, self._num_iter)

        # start decoding iterations
        it = tf.constant(0)
        # maximum_iterations required for XLA
        _, msg_vn, _ = tf.while_loop(dec_stop,
                                     dec_iter,
                                     (llr_ch, msg_vn, it),
                                     parallel_iterations=1,
                                     maximum_iterations=self._num_iter)


        # raggedTensor for final marginalization
        msg_vn = tf.RaggedTensor.from_row_splits(
                        values=msg_vn,
                        row_splits=tf.constant(self._vn_row_splits, tf.int32))

        # marginalize and remove ragged Tensor
        x_hat = tf.add(llr_ch, tf.reduce_sum(msg_vn, axis=1))

        # restore batch dimension to first dimension
        x_hat = tf.transpose(x_hat, (1,0))

        x_hat = -1. * x_hat # convert llrs back into logits

        if self._hard_out: # hard decide decoder output if required
            x_hat = tf.cast(tf.less(0.0, x_hat), self._output_dtype)

        # Reshape c_short so that it matches the original input dimensions
        output_shape = llr_ch_shape
        output_shape[0] = -1 # overwrite batch dim (can be None in Keras)

        x_reshaped = tf.reshape(x_hat, output_shape)

        # cast output to output_dtype
        x_out = tf.cast(x_reshaped, self._output_dtype)

        if not self._stateful:
            return x_out
        else:
            return x_out, msg_vn
    991 
    992 class LDPC5GDecoder(LDPCBPDecoder):
    993     # pylint: disable=line-too-long
    994     r"""LDPC5GDecoder(encoder, trainable=False, cn_type='boxplus-phi', hard_out=True, track_exit=False, return_infobits=True, prune_pcm=True, num_iter=20, stateful=False, output_dtype=tf.float32, **kwargs)
    995 
    996     (Iterative) belief propagation decoder for 5G NR LDPC codes.
    997 
    998     Inherits from :class:`~sionna.fec.ldpc.decoding.LDPCBPDecoder` and provides
    999     a wrapper for 5G compatibility, i.e., automatically handles puncturing and
   1000     shortening according to [3GPPTS38212_LDPC]_.
   1001 
   1002     Note that for full 5G 3GPP NR compatibility, the correct puncturing and
   1003     shortening patterns must be applied and, thus, the encoder object is
   1004     required as input.
   1005 
   1006     If required the decoder can be made trainable and is differentiable
   1007     (the training of some check node types may be not supported) following the
   1008     concept of "weighted BP" [Nachmani]_.
   1009 
   1010     For numerical stability, the decoder applies LLR clipping of
   1011     +/- 20 to the input LLRs.
   1012 
   1013     The class inherits from the Keras layer class and can be used as layer in a
   1014     Keras model.
   1015 
   1016     Parameters
   1017     ----------
   1018         encoder: LDPC5GEncoder
   1019             An instance of :class:`~sionna.fec.ldpc.encoding.LDPC5GEncoder`
   1020             containing the correct code parameters.
   1021 
   1022         trainable: bool
   1023             Defaults to False. If True, every outgoing variable node message is
   1024             scaled with a trainable scalar.
   1025 
   1026         cn_type: str
   1027             A string defaults to '"boxplus-phi"'. One of
   1028             {`"boxplus"`, `"boxplus-phi"`, `"minsum"`} where
   1029             '"boxplus"' implements the single-parity-check APP decoding rule.
            '"boxplus-phi"' implements the numerically more stable version of
   1031             boxplus [Ryan]_.
   1032             '"minsum"' implements the min-approximation of the CN
   1033             update rule [Ryan]_.
   1034 
   1035         hard_out: bool
   1036             Defaults to True. If True, the decoder provides hard-decided
   1037             codeword bits instead of soft-values.
   1038 
   1039         track_exit: bool
   1040             Defaults to False. If True, the decoder tracks EXIT characteristics.
   1041             Note that this requires the all-zero CW as input.
   1042 
   1043         return_infobits: bool
   1044             Defaults to True. If True, only the `k` info bits (soft or
   1045             hard-decided) are returned. Otherwise all `n` positions are
   1046             returned.
   1047 
   1048         prune_pcm: bool
   1049             Defaults to True. If True, all punctured degree-1 VNs and
   1050             connected check nodes are removed from the decoding graph (see
   1051             [Cammerer]_ for details). Besides numerical differences, this should
            yield the same decoding result but improves the decoding
            throughput and reduces the memory footprint.
   1054 
   1055         num_iter: int
            Defines the number of decoder iterations (no early stopping is
            used at the moment).
   1058 
   1059         stateful: bool
   1060             Defaults to False. If True, the internal VN messages ``msg_vn``
   1061             from the last decoding iteration are returned, and ``msg_vn`` or
   1062             `None` needs to be given as a second input when calling the decoder.
   1063             This is required for iterative demapping and decoding.
   1064 
   1065         output_dtype: tf.DType
   1066             Defaults to tf.float32. Defines the output datatype of the layer
   1067             (internal precision remains tf.float32).
   1068 
   1069     Input
   1070     -----
   1071     llrs_ch or (llrs_ch, msg_vn):
   1072         Tensor or Tuple (only required if ``stateful`` is True):
   1073 
   1074     llrs_ch: [...,n], tf.float32
   1075         2+D tensor containing the channel logits/llr values.
   1076 
   1077     msg_vn: None or RaggedTensor, tf.float32
   1078         Ragged tensor of VN messages.
   1079         Required only if ``stateful`` is True.
   1080 
   1081     Output
   1082     ------
   1083         : [...,n] or [...,k], tf.float32
   1084             2+D Tensor of same shape as ``inputs`` containing
   1085             bit-wise soft-estimates (or hard-decided bit-values) of all
   1086             codeword bits. If ``return_infobits`` is True, only the `k`
   1087             information bits are returned.
   1088 
   1089         : RaggedTensor, tf.float32:
   1090             Tensor of VN messages.
   1091             Returned only if ``stateful`` is set to True.
   1092     Raises
   1093     ------
   1094         ValueError
   1095             If the shape of ``pcm`` is invalid or contains other
   1096             values than `0` or `1`.
   1097 
   1098         AssertionError
   1099             If ``trainable`` is not `bool`.
   1100 
   1101         AssertionError
   1102             If ``track_exit`` is not `bool`.
   1103 
   1104         AssertionError
   1105             If ``hard_out`` is not `bool`.
   1106 
   1107         AssertionError
   1108             If ``return_infobits`` is not `bool`.
   1109 
   1110         AssertionError
   1111             If ``encoder`` is not an instance of
   1112             :class:`~sionna.fec.ldpc.encoding.LDPC5GEncoder`.
   1113 
   1114         ValueError
   1115             If ``output_dtype`` is not {tf.float16, tf.float32, tf.
   1116             float64}.
   1117 
   1118         ValueError
   1119             If ``inputs`` is not of shape `[batch_size, n]`.
   1120 
   1121         ValueError
   1122             If ``num_iter`` is not an integer greater (or equal) `0`.
   1123 
   1124         InvalidArgumentError
   1125             When rank(``inputs``)<2.
   1126 
   1127     Note
   1128     ----
   1129         As decoding input logits
   1130         :math:`\operatorname{log} \frac{p(x=1)}{p(x=0)}` are assumed for
   1131         compatibility with the learning framework, but
   1132         internally llrs with definition
   1133         :math:`\operatorname{log} \frac{p(x=0)}{p(x=1)}` are used.
   1134 
   1135         The decoder is not (particularly) optimized for Quasi-cyclic (QC) LDPC
   1136         codes and, thus, supports arbitrary parity-check matrices.
   1137 
   1138         The decoder is implemented by using '"ragged Tensors"' [TF_ragged]_ to
   1139         account for arbitrary node degrees. To avoid a performance degradation
   1140         caused by a severe indexing overhead, the batch-dimension is shifted to
   1141         the last dimension during decoding.
   1142 
   1143         If the decoder is made trainable [Nachmani]_, for performance
   1144         improvements only variable to check node messages are scaled as the VN
   1145         operation is linear and, thus, would not increase the expressive power
   1146         of the weights.
   1147     """
   1148 
   1149     def __init__(self,
   1150                  encoder,
   1151                  trainable=False,
   1152                  cn_type='boxplus-phi',
   1153                  hard_out=True,
   1154                  track_exit=False,
   1155                  return_infobits=True,
   1156                  prune_pcm=True,
   1157                  num_iter=20,
   1158                  stateful=False,
   1159                  output_dtype=tf.float32,
   1160                  **kwargs):
   1161 
   1162         # needs the 5G Encoder to access all 5G parameters
   1163         assert isinstance(encoder, LDPC5GEncoder), 'encoder must \
   1164                           be of class LDPC5GEncoder.'
   1165         self._encoder = encoder
   1166         pcm = encoder.pcm
   1167 
   1168         assert isinstance(return_infobits, bool), 'return_info must be bool.'
   1169         self._return_infobits = return_infobits
   1170 
   1171         assert isinstance(output_dtype, tf.DType), \
   1172                                 'output_dtype must be tf.DType.'
   1173         if output_dtype not in (tf.float16, tf.float32, tf.float64):
   1174             raise ValueError(
   1175                 'output_dtype must be {tf.float16, tf.float32, tf.float64}.')
   1176         self._output_dtype = output_dtype
   1177 
   1178         assert isinstance(stateful, bool), 'stateful must be bool.'
   1179         self._stateful = stateful
   1180 
   1181         assert isinstance(prune_pcm, bool), 'prune_pcm must be bool.'
   1182         # prune punctured degree-1 VNs and connected CNs. A punctured
   1183         # VN-1 node will always "send" llr=0 to the connected CN. Thus, this
   1184         # CN will only send 0 messages to all other VNs, i.e., does not
   1185         # contribute to the decoding process.
   1186         self._prune_pcm = prune_pcm
   1187         if prune_pcm:
   1188             # find index of first position with only degree-1 VN
   1189             dv = np.sum(pcm, axis=0) # VN degree
   1190             last_pos = encoder._n_ldpc
   1191             for idx in range(encoder._n_ldpc-1, 0, -1):
   1192                 if dv[0, idx]==1:
   1193                     last_pos = idx
   1194                 else:
   1195                     break
   1196             # number of filler bits
   1197             k_filler = self.encoder.k_ldpc - self.encoder.k
   1198             # number of punctured bits
   1199             nb_punc_bits = ((self.encoder.n_ldpc - k_filler)
   1200                                      - self.encoder.n - 2*self.encoder.z)
   1201             # effective codeword length after pruning of vn-1 nodes
   1202             self._n_pruned = np.max((last_pos, encoder._n_ldpc - nb_punc_bits))
   1203             self._nb_pruned_nodes = encoder._n_ldpc - self._n_pruned
   1204             # remove last CNs and VNs from pcm
   1205             pcm = pcm[:-self._nb_pruned_nodes, :-self._nb_pruned_nodes]
   1206 
   1207             #check for consistency
   1208             assert(self._nb_pruned_nodes>=0), "Internal error: number of \
   1209                         pruned nodes must be positive."
   1210         else:
   1211             self._nb_pruned_nodes = 0
   1212             # no pruning; same length as before
   1213             self._n_pruned = encoder._n_ldpc
   1214 
   1215         super().__init__(pcm,
   1216                          trainable,
   1217                          cn_type,
   1218                          hard_out,
   1219                          track_exit,
   1220                          num_iter=num_iter,
   1221                          stateful=stateful,
   1222                          output_dtype=output_dtype,
   1223                          **kwargs)
   1224 
   1225     #########################################
   1226     # Public methods and properties
   1227     #########################################
   1228 
    @property
    def encoder(self):
        """LDPC Encoder used for rate-matching/recovery.

        Returns the `LDPC5GEncoder` instance that was passed to the
        constructor.
        """
        return self._encoder
   1233 
   1234     #########################
   1235     # Keras layer functions
   1236     #########################
   1237 
   1238     def build(self, input_shape):
   1239         """Build model."""
   1240         if self._stateful:
   1241             assert(len(input_shape)==2), \
   1242                 "For stateful decoding, a tuple of two inputs is expected."
   1243             input_shape = input_shape[0]
   1244 
   1245         # check input dimensions for consistency
   1246         assert (input_shape[-1]==self.encoder.n), \
   1247                                 'Last dimension must be of length n.'
   1248         assert (len(input_shape)>=2), 'The inputs must have at least rank 2.'
   1249 
   1250         self._old_shape_5g = input_shape
   1251 
   1252     def call(self, inputs):
   1253         """Iterative BP decoding function.
   1254 
   1255         This function performs ``num_iter`` belief propagation decoding
   1256         iterations and returns the estimated codeword.
   1257 
   1258         Args:
   1259             inputs (tf.float32): Tensor of shape `[...,n]` containing the
   1260                 channel logits/llr values.
   1261 
   1262         Returns:
   1263             `tf.float32`: Tensor of shape `[...,n]` or `[...,k]`
   1264             (``return_infobits`` is True) containing bit-wise soft-estimates
   1265             (or hard-decided bit-values) of all codeword bits (or info
   1266             bits, respectively).
   1267 
   1268         Raises:
   1269             ValueError: If ``inputs`` is not of shape `[batch_size, n]`.
   1270 
   1271             ValueError: If ``num_iter`` is not an integer greater (or equal)
   1272                 `0`.
   1273 
   1274             InvalidArgumentError: When rank(``inputs``)<2.
   1275         """
   1276 
   1277         # Extract inputs
   1278         if self._stateful:
   1279             llr_ch, msg_vn = inputs
   1280         else:
   1281             llr_ch = inputs
   1282 
   1283         tf.debugging.assert_type(llr_ch, self.dtype, 'Invalid input dtype.')
   1284 
   1285         llr_ch_shape = llr_ch.get_shape().as_list()
   1286         new_shape = [-1, llr_ch_shape[-1]]
   1287         llr_ch_reshaped = tf.reshape(llr_ch, new_shape)
   1288         batch_size = tf.shape(llr_ch_reshaped)[0]
   1289 
   1290         # invert if rate-matching output interleaver was applied as defined in
   1291         # Sec. 5.4.2.2 in 38.212
   1292         if self._encoder.num_bits_per_symbol is not None:
   1293             llr_ch_reshaped = tf.gather(llr_ch_reshaped,
   1294                                         self._encoder.out_int_inv,
   1295                                         axis=-1)
   1296 
   1297 
   1298         # undo puncturing of the first 2*Z bit positions
   1299         llr_5g = tf.concat(
   1300             [tf.zeros([batch_size, 2*self.encoder.z], self._output_dtype),
   1301                           llr_ch_reshaped],
   1302                           1)
   1303 
   1304         # undo puncturing of the last positions
   1305         # total length must be n_ldpc, while llr_ch has length n
   1306         # first 2*z positions are already added
   1307         # -> add n_ldpc - n - 2Z punctured positions
   1308         k_filler = self.encoder.k_ldpc - self.encoder.k # number of filler bits
   1309         nb_punc_bits = ((self.encoder.n_ldpc - k_filler)
   1310                                      - self.encoder.n - 2*self.encoder.z)
   1311 
   1312 
   1313         llr_5g = tf.concat([llr_5g,
   1314                    tf.zeros([batch_size, nb_punc_bits - self._nb_pruned_nodes],
   1315                             self._output_dtype)],
   1316                             1)
   1317 
   1318         # undo shortening (= add 0 positions after k bits, i.e. LLR=LLR_max)
   1319         # the first k positions are the systematic bits
   1320         x1 = tf.slice(llr_5g, [0,0], [batch_size, self.encoder.k])
   1321 
   1322         # parity part
   1323         nb_par_bits = (self.encoder.n_ldpc - k_filler
   1324                        - self.encoder.k - self._nb_pruned_nodes)
   1325         x2 = tf.slice(llr_5g,
   1326                       [0, self.encoder.k],
   1327                       [batch_size, nb_par_bits])
   1328 
   1329         # negative sign due to logit definition
   1330         z = -tf.cast(self._llr_max, self._output_dtype) \
   1331             * tf.ones([batch_size, k_filler], self._output_dtype)
   1332 
   1333         llr_5g = tf.concat([x1, z, x2], 1)
   1334 
   1335         # and execute the decoder
   1336         if not self._stateful:
   1337             x_hat = super().call(llr_5g)
   1338         else:
   1339             x_hat,msg_vn = super().call([llr_5g, msg_vn]) # pylint: disable=used-before-assignment
   1340 
   1341         if self._return_infobits: # return only info bits
   1342             # reconstruct u_hat # code is systematic
   1343             u_hat = tf.slice(x_hat, [0,0], [batch_size, self.encoder.k])
   1344             # Reshape u_hat so that it matches the original input dimensions
   1345             output_shape = llr_ch_shape[0:-1] + [self.encoder.k]
   1346             # overwrite first dimension as this could be None (Keras)
   1347             output_shape[0] = -1
   1348             u_reshaped = tf.reshape(u_hat, output_shape)
   1349 
   1350             # enable other output datatypes than tf.float32
   1351             u_out = tf.cast(u_reshaped, self._output_dtype)
   1352 
   1353             if not self._stateful:
   1354                 return u_out
   1355             else:
   1356                 return u_out, msg_vn
   1357 
   1358         else: # return all codeword bits
   1359             # the transmitted CW bits are not the same as used during decoding
   1360             # cf. last parts of 5G encoding function
   1361 
   1362             # remove last dim
   1363             x = tf.reshape(x_hat, [batch_size, self._n_pruned])
   1364 
   1365             # remove filler bits at pos (k, k_ldpc)
   1366             x_no_filler1 = tf.slice(x, [0, 0], [batch_size, self.encoder.k])
   1367 
   1368             x_no_filler2 = tf.slice(x,
   1369                                     [0, self.encoder.k_ldpc],
   1370                                     [batch_size,
   1371                                     self._n_pruned-self.encoder.k_ldpc])
   1372 
   1373             x_no_filler = tf.concat([x_no_filler1, x_no_filler2], 1)
   1374 
   1375             # shorten the first 2*Z positions and end after n bits
   1376             x_short = tf.slice(x_no_filler,
   1377                                [0, 2*self.encoder.z],
   1378                                [batch_size, self.encoder.n])
   1379 
   1380             # if used, apply rate-matching output interleaver again as
   1381             # Sec. 5.4.2.2 in 38.212
   1382             if self._encoder.num_bits_per_symbol is not None:
   1383                 x_short = tf.gather(x_short, self._encoder.out_int, axis=-1)
   1384 
   1385             # Reshape x_short so that it matches the original input dimensions
   1386             # overwrite first dimension as this could be None (Keras)
   1387             llr_ch_shape[0] = -1
   1388             x_short= tf.reshape(x_short, llr_ch_shape)
   1389 
   1390             # enable other output datatypes than tf.float32
   1391             x_out = tf.cast(x_short, self._output_dtype)
   1392 
   1393             if not self._stateful:
   1394                 return x_out
   1395             else:
   1396                 return x_out, msg_vn