anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

decoding.py (17046B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Layers for decoding of linear codes."""
      6 
      7 import tensorflow as tf
      8 import numpy as np
      9 import scipy as sp # for sparse H matrix computations
     10 from tensorflow.keras.layers import Layer
     11 from sionna.fec.utils import pcm2gm, int_mod_2, make_systematic
     12 from sionna.utils import hard_decisions
     13 import itertools
     14 
class OSDecoder(Layer):
    # pylint: disable=line-too-long
    r"""OSDecoder(enc_mat=None, t=0, is_pcm=False, encoder=None, dtype=tf.float32, **kwargs)

    Ordered statistics decoding (OSD) for binary, linear block codes.

    This layer implements the OSD algorithm as proposed in [Fossorier]_ and,
    thereby, approximates maximum likelihood decoding for a sufficiently large
    order :math:`t`. The algorithm works for arbitrary linear block codes, but
    has a high computational complexity for long codes.

    The algorithm consists of the following steps:

        1. Sort LLRs according to their reliability and apply the same column
        permutation to the generator matrix.

        2. Bring the permuted generator matrix into its systematic form
        (so-called *most-reliable basis*).

        3. Hard-decide and re-encode the :math:`k` most reliable bits and
        discard the remaining :math:`n-k` received positions.

        4. Generate all possible error patterns up to :math:`t` errors in the
        :math:`k` most reliable positions and find the most likely codeword
        within these candidates.

    This implementation of the OSD algorithm uses the LLR-based distance metric
    from [Stimming_LLR_OSD]_ which simplifies the handling of higher-order
    modulation schemes.

    The class inherits from the Keras layer class and can be used as layer in a
    Keras model.

    Parameters
    ----------
    enc_mat : [k, n] or [n-k, n], ndarray
        Binary generator matrix of shape `[k, n]`. If ``is_pcm`` is
        True, ``enc_mat`` is interpreted as parity-check matrix of shape
        `[n-k, n]`.

    t : int
        Order of the OSD algorithm

    is_pcm: bool
        Defaults to False. If True, ``enc_mat`` is interpreted as parity-check
        matrix.

    encoder: Layer
        Keras layer that implements a FEC encoder.
        If not None, ``enc_mat`` will be ignored and the code as specified by
        the encoder is used to initialize OSD.

    dtype: tf.DType
        Defaults to `tf.float32`. Defines the datatype for the output dtype.

    Input
    -----
    llrs_ch: [...,n], tf.float32
        2+D tensor containing the channel logits/llr values.

    Output
    ------
        : [...,n], tf.float32
            2+D Tensor of same shape as ``llrs_ch`` containing
            binary hard-decisions of all codeword bits.

    Note
    ----
    OS decoding is of high complexity and is only feasible for small values of
    :math:`t` as :math:`{n \choose t}` patterns must be evaluated. The
    advantage of OSD is that it works for arbitrary linear block codes and
    provides an estimate of the expected ML performance for sufficiently large
    :math:`t`. However, for some code families, more efficient decoding
    algorithms with close to ML performance exist which can exploit certain
    code specific properties. Examples of such decoders are the
    :class:`~sionna.fec.conv.ViterbiDecoder` algorithm for convolutional codes
    or the :class:`~sionna.fec.polar.decoding.PolarSCLDecoder` for Polar codes
    (for a sufficiently large list size).

    It is recommended to run the decoder in XLA mode as it
    significantly reduces the memory complexity.
    """

    def __init__(self,
                 enc_mat=None,
                 t=0,
                 is_pcm=False,
                 encoder=None,
                 dtype=tf.float32,
                 **kwargs):

        super().__init__(dtype=dtype, **kwargs)

        assert isinstance(is_pcm, bool), 'is_pcm must be bool.'

        self._llr_max = 100. # internal clipping value for llrs

        if enc_mat is not None:
            # check that gm is binary; dense and scipy-sparse (csr/csc)
            # matrices are supported
            if isinstance(enc_mat, np.ndarray):
                assert np.array_equal(enc_mat, enc_mat.astype(bool)), \
                    'PC matrix must be binary.'
            elif isinstance(enc_mat, sp.sparse.csr_matrix):
                assert np.array_equal(enc_mat.data, enc_mat.data.astype(bool)),\
                    'PC matrix must be binary.'
            elif isinstance(enc_mat, sp.sparse.csc_matrix):
                assert np.array_equal(enc_mat.data, enc_mat.data.astype(bool)),\
                    'PC matrix must be binary.'
            else:
                raise TypeError("Unsupported dtype of pcm.")

        if dtype not in (tf.float16, tf.float32, tf.float64):
            raise ValueError(
                'dtype must be {tf.float16, tf.float32, tf.float64}.')

        # t may be passed as float (e.g. 2.0) but must be integer-valued
        assert (int(t)==t), "t must be int."
        self._t = int(t)

        if encoder is not None:
            # test that encoder is already initialized (relevant for conv codes)
            if encoder.k is None:
                raise AttributeError("It seems as if the encoder is not "\
                                     "initialized or has no attribute k.")
            # encode identity matrix to get k basis vectors of the code,
            # i.e., recover the generator matrix from the encoder itself
            u = tf.expand_dims(tf.eye(encoder.k), axis=0)
            # encode and remove batch_dim
            self._gm = tf.cast(tf.squeeze(encoder(u), axis=0), self.dtype)
        else:
            assert (enc_mat is not None),\
                "enc_mat cannot be None if no encoder is provided."
            if is_pcm:
                # derive a generator matrix from the parity-check matrix
                gm = pcm2gm(enc_mat)
            else:
                # check if gm is of full rank (raise error otherwise);
                # the systematic form itself is discarded here
                make_systematic(enc_mat)
                gm = enc_mat
            self._gm = tf.constant(gm, dtype=self.dtype)

        self._k = self._gm.shape[0]
        self._n = self._gm.shape[1]

        # init error patterns
        num_patterns = self._num_error_patterns(self._n, self._t)

        # storage/computational complexity scales with n
        num_symbols = num_patterns * self._n
        if num_symbols>1e9: # number still to be optimized
            print(f"Note: Required memory complexity is large for the "\
                  f"given code parameters and t={t}. Please consider small " \
                  f"batch-sizes to keep the inference complexity small and " \
                  f"activate XLA mode if possible." )
        if num_symbols>1e11: # number still to be optimized
            raise ResourceWarning("Due to its high complexity, OSD is not " \
                                 "feasible for the selected parameters. " \
                                 "Please consider using a smaller value for t.")

        # pre-compute all error patterns of weight 1..t over the k MRB bits
        self._err_patterns = []
        for t_i in range(1, t+1):
            self._err_patterns.append(self._gen_error_patterns(self._k, t_i))

    #########################################
    # Public methods and properties
    #########################################

    @property
    def gm(self):
        """Generator matrix of the code"""
        return self._gm

    @property
    def n(self):
        """Codeword length"""
        return self._n

    @property
    def k(self):
        """Number of information bits per codeword"""
        return self._k

    @property
    def t(self):
        """Order of the OSD algorithm"""
        return self._t

    #########################
    # Utility methods
    #########################

    def _num_error_patterns(self, n, t):
        r"""Returns number of possible error patterns for t errors in n
        positions, i.e., calculates :math:`{n \choose t}`.

        Input
        -----
        n: int
            length of vector.

        t: int
            number of errors.
        """
        # exact=True returns a Python int (no float overflow for large n)
        return sp.special.comb(n, t, exact=True, repetition=False)

    def _gen_error_patterns(self, n, t):
        r"""Returns list of all possible error patterns for t errors in n
        positions.

        Input
        -----
        n: int
            Length of vector.

        t: int
            Number of errors.

        Output
        ------
        : [num_patterns, t], tf.int32
            Tensor of size `num_patterns`=:math:`{n \choose t}` containing the
            t error indices.
        """

        err_patterns = []
        for p in itertools.combinations(range(n), t):
            err_patterns.append(p)

        return tf.constant(err_patterns)

    def _get_dist(self, llr, c_hat):
        """Distance function used for ML candidate selection.

        Currently, the distance metric from Polar decoding [Stimming_LLR_OSD]_
        literature is implemented.

        Input
        -----
        llr: [bs, n], tf.float32
            Received llrs of the channel observations.

        c_hat: [bs, num_cand, n], tf.float32
            Candidate codewords for which the distance to ``llr`` shall be
            evaluated.

        Output
        ------
        : [bs, num_cand], tf.float32
            Distance between ``llr`` and ``c_hat`` for each of the `num_cand`
            codeword candidates.

        Reference
        ---------
        [Stimming_LLR_OSD] Alexios Balatsoukas-Stimming, Mani Bastani Parizi,
        Andreas Burg, "LLR-Based Successive Cancellation List Decoding
        of Polar Codes." IEEE Trans Signal Processing, 2015.
        """

        # broadcast llr to all codeword candidates
        llr = tf.expand_dims(llr, axis=1)
        llr_sign = llr * (-2.*c_hat + 1.) # apply BPSK mapping

        # soft penalty log(1+exp(.)); mean (vs. sum) scales by 1/n which
        # does not change the argmin over candidates of equal length n
        d = tf.math.log(1. + tf.exp(llr_sign))
        return tf.reduce_mean(d, axis=2)

    def _find_min_dist(self, llr_ch, ep, gm_mrb, c):
        r"""Find error pattern which leads to minimum distance.

        Input
        -----
        llr_ch: [bs, n], tf.float32
            Channel observations as llrs after mrb sorting.

        ep: [num_patterns, t], tf.int32
            Tensor of size `num_patterns`=:math:`{n \choose t}` containing the
            t error indices.

        gm_mrb: [bs, k, n] tf.float32
            Most reliable basis for each batch example.

        c: [bs, n], tf.float32
            Most reliable base codeword.

        Output
        ------
        : [bs], tf.float32
            Distance of the most likely codeword to ``llr_ch`` after testing all
            ``ep`` error patterns.

        : [bs, n], tf.float32
            The most likely codeword after testing against all ``ep`` error
            patterns.
        """

        # generate all test candidates for each possible error pattern:
        # flipping info bit i adds row i of the generator matrix (mod 2)
        e = tf.gather(gm_mrb, ep, axis=1)
        e = tf.reduce_sum(e, axis=2)
        e += tf.expand_dims(c, axis=1) # add to mrb codeword
        c_cand = int_mod_2(e) # apply modulo-2 operation

        # calculate distance for each candidate
        # where c_cand has shape [bs, num_patterns, n]
        d = self._get_dist(llr_ch, c_cand)

        # find candidate index with smallest metric (per batch example)
        idx = tf.argmin(d, axis=1)
        c_hat = tf.gather(c_cand, idx, batch_dims=1)
        d = tf.gather(d, idx, batch_dims=1)
        return d, c_hat

    def _find_mrb(self, gm):
        """Find most reliable basis for all generator matrices in batch.

        Performs batched Gaussian elimination over GF(2) via the pivot
        method; written with static shapes so it stays XLA-compatible.

        Input
        -----
        gm: [bs, k, n] tf.float32
            Generator matrix for each batch example.

        Output
        ------
        gm_mrb: [bs, k, n] tf.float32
            Most reliable basis in systematic form for each batch example.

        idx_sort: [bs, n] tf.int64
            Indices of column permutations applied during mrb calculation.
        """

        bs = tf.shape(gm)[0]
        s = gm.shape
        idx_pivot = tf.TensorArray(tf.int64, self._k, dynamic_size=False)

        #  bring gm in systematic form (by so-called pivot method)
        for idx_c in tf.range(self._k):

            # ensure shape to avoid XLA incompatibility with TF2.11 in tf.range
            gm = tf.ensure_shape(gm, s)

            # find pivot (i.e., first pos with index 1);
            # argmax returns the first maximum, i.e., the leftmost 1
            idx_p = tf.argmax(gm[:, idx_c, :], axis=-1)

            # store pivot position
            idx_pivot = idx_pivot.write(idx_c, idx_p)

            # and eliminate the column in all other rows:
            # r holds the pivot-column entry of every row
            r = tf.gather(gm, idx_p, batch_dims=1, axis=-1)

            # ignore idx_c row itself by adding all-zero row
            rz = tf.zeros((bs, 1), dtype=self.dtype)
            r = tf.concat([r[:,:idx_c], rz , r[:,idx_c+1:]], axis=1)

            # mask is zero at all rows where pivot position of this row is zero
            mask = tf.tile(tf.expand_dims(r, axis=-1), (1, 1, self._n))
            gm_off = tf.expand_dims(gm[:,idx_c,:], axis=1)

            # update all rows in parallel (XOR via add + mod 2)
            gm = int_mod_2(gm + mask * gm_off) # account for binary operations

        # pivot positions
        idx_pivot = tf.transpose(idx_pivot.stack())

        # find non-pivot positions (i.e., all indices that are not part of
        # idx_pivot)

        # solution 1: sets.difference() does not support XLA (unknown shapes)
        #idx_parity = tf.sets.difference(idx_range, idx_pivot)
        #idx_parity = tf.sparse.to_dense(idx_parity)
        #idx_pivot = tf.reshape(idx_pivot, (-1, self._n)) # ensure shape

        # solution 2: add large offset to pivot indices and sorting gives the
        # indices of interest
        idx_range = tf.tile(tf.expand_dims(
                                tf.range(self._n, dtype=tf.int64), axis=0),
                            (bs, 1))
        # large value to be added to irrelevant indices
        updates = self._n * tf.ones((bs, self._k), tf.int64)

        # generate indices for tf.scatter_nd_add
        s = tf.shape(idx_pivot, tf.int64)
        ii, _ = tf.meshgrid(tf.range(s[0]), tf.range(s[1]), indexing='ij')
        idx_updates = tf.stack([ii, idx_pivot], axis=-1)

        # add large value to pivot positions
        idx = tf.tensor_scatter_nd_add(idx_range, idx_updates, updates)

        # sort and slice first n-k indices (equals parity positions)
        idx_parity = tf.cast(tf.argsort(idx)[:,:self._n-self._k], tf.int64)

        # final permutation: k pivot columns first, then n-k parity columns
        idx_sort = tf.concat([idx_pivot, idx_parity], axis=1)

        # permute gm according to indices idx_sort
        gm = tf.gather(gm, idx_sort, batch_dims=1, axis=-1)

        return gm, idx_sort

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """Nothing to build, but check for valid shapes."""

        assert input_shape[-1]==self._n, "Invalid input shape."

    def call(self, inputs):
        r"""Applies ordered statistic decoding to inputs.

        Remark: the decoder is implemented with llr definition
        llr = p(x=1)/p(x=0).
        """

        # flatten batch-dim
        input_shape = tf.shape(inputs)
        llr_ch = tf.reshape(inputs, (-1, self._n))
        llr_ch = tf.cast(llr_ch, self.dtype)
        bs = tf.shape(llr_ch)[0]

        # clip inputs to keep exp() in _get_dist numerically stable
        llr_ch = tf.clip_by_value(llr_ch, -self._llr_max, self._llr_max)

        # step 1: sort LLRs (most reliable, i.e., largest magnitude, first)
        idx_sort = tf.argsort(tf.abs(llr_ch), direction="DESCENDING")

        # permute gm per batch sample individually
        gm = tf.broadcast_to(tf.expand_dims(self._gm, axis=0),
                             (bs, self._k,self._n))
        gm_sort = tf.gather(gm, idx_sort, batch_dims=1, axis=-1)

        # step 2: Find most reliable basis (MRB)
        gm_mrb, idx_mrb = self._find_mrb(gm_sort)

        # apply corresponding mrb permutations
        idx_sort = tf.gather(idx_sort, idx_mrb, batch_dims=1)
        llr_sort = tf.gather(llr_ch, idx_sort, batch_dims=1)

        # find inverse permutation for final output
        idx_sort_inv = tf.argsort(idx_sort)

        # step 3: hard-decide k most reliable positions and re-encode
        u_hd = hard_decisions(llr_sort[:,0:self._k])
        u_hd = tf.expand_dims(u_hd, axis=1)
        c = tf.squeeze(tf.matmul(u_hd, gm_mrb), axis=1)
        c = int_mod_2(c)

        # step 4: search for most likely pattern;
        # _get_dist expects a list of candidates, thus expand_dims to [bs, 1, n]
        d_best = self._get_dist(llr_sort, tf.expand_dims(c, axis=1))
        d_best = tf.squeeze(d_best, axis=1)
        c_hat_best = c

        # known in advance - can be unrolled
        for ep in self._err_patterns:
            # compute distance for all candidate codewords
            d, c_hat = self._find_min_dist(llr_sort, ep, gm_mrb, c)

            # select most likely candidate (keep running minimum)
            ind = tf.expand_dims(d<d_best, axis=1)
            c_hat_best = tf.where(ind, c_hat, c_hat_best)
            d_best = tf.where(d<d_best, d, d_best)

        # undo permutations for final codeword
        c_hat_best = tf.gather(c_hat_best, idx_sort_inv, axis=1, batch_dims=1)
        # restore original input shape
        c_hat = tf.reshape(c_hat_best, input_shape)

        return c_hat