anomaly-detection-material-parameters-calibration

Sionna parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

encoding.py (26847B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Layers for LDPC channel encoding and utility functions."""
      6 
      7 import tensorflow as tf
      8 import numpy as np
      9 import scipy as sp
     10 from tensorflow.keras.layers import Layer
     11 from importlib_resources import files, as_file
     12 from . import codes # pylint: disable=relative-beyond-top-level
     13 import numbers # to check if n, k are numbers
     14 
     15 from sionna.fec.linear import AllZeroEncoder as AllZeroEncoder_new
     16 
     17 class LDPC5GEncoder(Layer):
     18     # pylint: disable=line-too-long
     19     """LDPC5GEncoder(k, n, num_bits_per_symbol=None, dtype=tf.float32, **kwargs)
     20 
     21     5G NR LDPC Encoder following the 3GPP NR Initiative [3GPPTS38212_LDPC]_
     22     including rate-matching.
     23 
     24     The class inherits from the Keras layer class and can be used as layer in a
     25     Keras model.
     26 
     27     Parameters
     28     ----------
     29         k: int
     30             Defining the number of information bits per codeword.
     31 
     32         n: int
     33             Defining the desired codeword length.
     34 
     35         num_bits_per_symbol: int or None
     36             Defining the number of bits per QAM symbol. If this parameter is
     37             explicitly provided, the codeword will be interleaved after
     38             rate-matching as specified in Sec. 5.4.2.2 in [3GPPTS38212_LDPC]_.
     39 
     40         dtype: tf.DType
     41             Defaults to `tf.float32`. Defines the output datatype of the layer
     42             (internal precision remains `tf.uint8`).
     43 
     44     Input
     45     -----
     46         inputs: [...,k], tf.float32
     47             2+D tensor containing the information bits to be
     48             encoded.
     49 
     50     Output
     51     ------
     52         : [...,n], tf.float32
     53             2+D tensor of same shape as inputs besides last dimension has
     54             changed to `n` containing the encoded codeword bits.
     55 
     56     Attributes
     57     ----------
     58         k: int
     59             Defining the number of information bits per codeword.
     60 
     61         n: int
     62             Defining the desired codeword length.
     63 
     64         coderate: float
     65             Defining the coderate r= ``k`` / ``n``.
     66 
     67         n_ldpc: int
     68             An integer defining the total codeword length (before
     69             puncturing) of the lifted parity-check matrix.
     70 
     71         k_ldpc: int
     72             An integer defining the total information bit length
     73             (before zero removal) of the lifted parity-check matrix. Gap to
     74             ``k`` must be filled with so-called filler bits.
     75 
     76         num_bits_per_symbol: int or None.
     77             Defining the number of bits per QAM symbol. If this parameter is
     78             explicitly provided, the codeword will be interleaved after
     79             rate-matching as specified in Sec. 5.4.2.2 in [3GPPTS38212_LDPC]_.
     80 
     81         out_int: [n], ndarray of int
     82             Defining the rate-matching output interleaver sequence.
     83 
     84         out_int_inv: [n], ndarray of int
     85             Defining the inverse rate-matching output interleaver sequence.
     86 
     87         _check_input: bool
     88             A boolean that indicates whether the input vector
     89             during call of the layer should be checked for consistency (i.e.,
     90             binary).
     91 
     92         _bg: str
     93             Denoting the selected basegraph (either `bg1` or `bg2`).
     94 
     95         _z: int
     96             Denoting the lifting factor.
     97 
     98         _i_ls: int
     99             Defining which version of the basegraph to load.
    100             Can take values between 0 and 7.
    101 
    102         _k_b: int
    103             Defining the number of `information bit columns` in the
    104             basegraph. Determined by the code design procedure in
    105             [3GPPTS38212_LDPC]_.
    106 
    107         _bm: ndarray
    108             An ndarray defining the basegraph.
    109 
    110         _pcm: sp.sparse.csr_matrix
    111             A sparse matrix of shape `[k_ldpc-n_ldpc, n_ldpc]`
    112             containing the sparse parity-check matrix.
    113 
    114     Raises
    115     ------
    116         AssertionError
    117             If ``k`` is not `int`.
    118 
    119         AssertionError
    120             If ``n`` is not `int`.
    121 
    122         ValueError
    123             If ``code_length`` is not supported.
    124 
    125         ValueError
    126             If `dtype` is not supported.
    127 
    128         ValueError
    129             If ``inputs`` contains other values than `0` or `1`.
    130 
    131         InvalidArgumentError
    132             When rank(``inputs``)<2.
    133 
    134         InvalidArgumentError
    135             When shape of last dim is not ``k``.
    136 
    137     Note
    138     ----
    139         As specified in [3GPPTS38212_LDPC]_, the encoder also performs
    140         puncturing and shortening. Thus, the corresponding decoder needs to
    141         `invert` these operations, i.e., must be compatible with the 5G
    142         encoding scheme.
    143     """
    144 
    def __init__(self,
                 k,
                 n,
                 num_bits_per_symbol=None,
                 dtype=tf.float32,
                 **kwargs):
        """Validate the code parameters and precompute all matrices required
        for fast 5G LDPC encoding (basegraph selection, lifting, and the
        sub-matrices of the "RU" encoding method).

        Raises
        ------
        AssertionError
            If ``k`` or ``n`` is not a number.
        ValueError
            If ``dtype`` or the code parameters are not supported.
        """
        super().__init__(dtype=dtype, **kwargs)

        assert isinstance(k, numbers.Number), "k must be a number."
        assert isinstance(n, numbers.Number), "n must be a number."
        k = int(k) # k or n can be float (e.g. as result of n=k*r)
        n = int(n) # k or n can be float (e.g. as result of n=k*r)

        if dtype is not tf.float32:
            # Bugfix: the message previously said "decoder", but this class
            # is the encoder.
            print("Note: encoder uses tf.float32 for internal calculations.")

        if dtype not in (tf.float16, tf.float32, tf.float64, tf.int8,
            tf.int32, tf.int64, tf.uint8, tf.uint16, tf.uint32):
            raise ValueError("Unsupported dtype.")
        self._dtype = dtype

        if k>8448:
            raise ValueError("Unsupported code length (k too large).")
        if k<12:
            raise ValueError("Unsupported code length (k too small).")

        if n>(316*384):
            raise ValueError("Unsupported code length (n too large).")
        if n<=0:
            # Bugfix: n==0 previously passed the (n<0) check and crashed
            # below with a ZeroDivisionError when computing the coderate.
            raise ValueError("Unsupported code length (n must be positive).")

        # init encoder parameters
        self._k = k # number of input bits (= input shape)
        self._n = n # the desired length (= output shape)
        self._coderate = k / n
        self._check_input = True # check input for consistency (i.e., binary)

        # allow actual code rates slightly larger than 948/1024
        # to account for the quantization procedure in 38.214 5.1.3.1
        if self._coderate>(948/1024): # as specified in 38.212 5.4.2.1
            print(f"Warning: effective coderate r>948/1024 for n={n}, k={k}.")
        if self._coderate>(0.95): # as specified in 38.212 5.4.2.1
            raise ValueError(f"Unsupported coderate (r>0.95) for n={n}, k={k}.")
        if self._coderate<(1/5):
            # outer rep. coding currently not supported
            raise ValueError("Unsupported coderate (r<1/5).")

        # construct the basegraph according to 38.212
        self._bg = self._sel_basegraph(self._k, self._coderate)
        self._z, self._i_ls, self._k_b = self._sel_lifting(self._k, self._bg)
        self._bm = self._load_basegraph(self._i_ls, self._bg)

        # total number of codeword bits
        self._n_ldpc = self._bm.shape[1] * self._z
        # if K_real < K_target puncturing must be applied earlier
        self._k_ldpc = self._k_b * self._z

        # construct explicit graph via lifting
        pcm = self._lift_basegraph(self._bm, self._z)

        # init sub-matrices for fast encoding ("RU"-method)
        pcm_a, pcm_b_inv, pcm_c1, pcm_c2 = self._gen_submat(self._bm,
                                                            self._k_b,
                                                            self._z,
                                                            self._bg)

        self._pcm = pcm # store the sparse parity-check matrix (for decoding)

        # store indices for fast gathering (instead of explicit matmul)
        self._pcm_a_ind = self._mat_to_ind(pcm_a)
        self._pcm_b_inv_ind = self._mat_to_ind(pcm_b_inv)
        self._pcm_c1_ind = self._mat_to_ind(pcm_c1)
        self._pcm_c2_ind = self._mat_to_ind(pcm_c2)

        # optional output interleaver as in Sec. 5.4.2.2 of 38.212
        self._num_bits_per_symbol = num_bits_per_symbol
        if num_bits_per_symbol is not None:
            self._out_int, self._out_int_inv = self.generate_out_int(
                                    self._n, self._num_bits_per_symbol)
    225 
    226     #########################################
    227     # Public methods and properties
    228     #########################################
    229 
    @property
    def k(self):
        """Number of input information bits per codeword."""
        return self._k

    @property
    def n(self):
        """Number of output codeword bits after rate-matching."""
        return self._n

    @property
    def coderate(self):
        """Coderate r = ``k`` / ``n`` of the LDPC code after rate-matching."""
        return self._coderate

    @property
    def k_ldpc(self):
        """Number of information bits of the lifted parity-check matrix
        (before filler-bit removal); equals ``k_b * z``."""
        return self._k_ldpc

    @property
    def n_ldpc(self):
        """Number of LDPC codeword bits before rate-matching
        (i.e., of the lifted parity-check matrix)."""
        return self._n_ldpc

    @property
    def pcm(self):
        """Sparse parity-check matrix for the given code parameters."""
        return self._pcm

    @property
    def z(self):
        """Lifting factor of the basegraph."""
        return self._z

    @property
    def num_bits_per_symbol(self):
        """Modulation order used for the rate-matching output interleaver
        (or `None` if no interleaving is applied)."""
        return self._num_bits_per_symbol

    @property
    def out_int(self):
        """Output interleaver sequence as defined in 5.4.2.2."""
        return self._out_int

    @property
    def out_int_inv(self):
        """Inverse output interleaver sequence as defined in 5.4.2.2."""
        return self._out_int_inv
    278 
    279     #########################
    280     # Utility methods
    281     #########################
    282 
    283     def generate_out_int(self, n, num_bits_per_symbol):
    284         """"Generates LDPC output interleaver sequence as defined in
    285         Sec 5.4.2.2 in [3GPPTS38212_LDPC]_.
    286 
    287         Parameters
    288         ----------
    289         n: int
    290             Desired output sequence length.
    291 
    292         num_bits_per_symbol: int
    293             Number of symbols per QAM symbol, i.e., the modulation order.
    294 
    295         Output
    296         ------
    297         (perm_seq, perm_seq_inv):
    298             Tuple:
    299 
    300         perm_seq: ndarray of length n
    301             Containing the permuted indices.
    302 
    303         perm_seq_inv: ndarray of length n
    304             Containing the inverse permuted indices.
    305 
    306         Note
    307         ----
    308         The interleaver pattern depends on the modulation order and helps to
    309         reduce dependencies in bit-interleaved coded modulation (BICM) schemes.
    310         """
    311         # allow float inputs, but verify that they represent integer
    312         assert(n%1==0), "n must be int."
    313         assert(num_bits_per_symbol%1==0), "num_bits_per_symbol must be int."
    314         n = int(n)
    315         assert(n>0), "n must be a positive integer."
    316         assert(num_bits_per_symbol>0), \
    317                     "num_bits_per_symbol must be a positive integer."
    318         num_bits_per_symbol = int(num_bits_per_symbol)
    319 
    320         assert(n%num_bits_per_symbol==0),\
    321             "n must be a multiple of num_bits_per_symbol."
    322 
    323         # pattern as defined in Sec 5.4.2.2
    324         perm_seq = np.zeros(n, dtype=int)
    325         for j in range(int(n/num_bits_per_symbol)):
    326             for i in range(num_bits_per_symbol):
    327                 perm_seq[i + j*num_bits_per_symbol] \
    328                     = int(i * int(n/num_bits_per_symbol) + j)
    329 
    330         perm_seq_inv = np.argsort(perm_seq)
    331 
    332         return perm_seq, perm_seq_inv
    333 
    334     def _sel_basegraph(self, k, r):
    335         """Select basegraph according to [3GPPTS38212_LDPC]_."""
    336 
    337         if k <= 292:
    338             bg = "bg2"
    339         elif k <= 3824 and r <= 0.67:
    340             bg = "bg2"
    341         elif r <= 0.25:
    342             bg = "bg2"
    343         else:
    344             bg = "bg1"
    345 
    346         # add for consistency
    347         if bg=="bg1" and k>8448:
    348             raise ValueError("K is not supported by BG1 (too large).")
    349 
    350         if bg=="bg2" and k>3840:
    351             raise ValueError(
    352                 f"K is not supported by BG2 (too large) k ={k}.")
    353 
    354         if bg=="bg1" and r<1/3:
    355             raise ValueError("Only coderate>1/3 supported for BG1. \
    356             Remark: Repetition coding is currently not supported.")
    357 
    358         if bg=="bg2" and r<1/5:
    359             raise ValueError("Only coderate>1/5 supported for BG2. \
    360             Remark: Repetition coding is currently not supported.")
    361 
    362         return bg
    363 
    def _load_basegraph(self, i_ls, bg):
        """Helper to load basegraph from csv files.

        ``i_ls`` is the sub-index of the basegraph and fixed during lifting
        selection.
        """
        if i_ls > 7:
            raise ValueError("i_ls too large.")

        if i_ls < 0:
            raise ValueError("i_ls cannot be negative.")

        # csv files are taken from 38.212 and dimension is explicitly given
        if bg=="bg1":
            bm = np.zeros([46, 68]) - 1 # init matrix with -1 (None positions)
        elif bg=="bg2":
            bm = np.zeros([42, 52]) - 1 # init matrix with -1 (None positions)
        else:
            raise ValueError("Basegraph not supported.")

        # load the basegraph from csv format in folder "codes".
        # Bugfix: bind the extracted file path to a local variable instead of
        # assigning the attribute ``codes.csv`` on the imported module, which
        # mutated global module state and was not thread-safe.
        source = files(codes).joinpath(f"5G_{bg}.csv")
        with as_file(source) as csv_path:
            bg_csv = np.genfromtxt(csv_path, delimiter=";")

        # reconstruct BG for given i_ls
        r_ind = 0
        for r in np.arange(2, bg_csv.shape[0]):
            # first csv column is only set when a new basegraph row starts
            if not np.isnan(bg_csv[r, 0]):
                r_ind = int(bg_csv[r, 0])
            c_ind = int(bg_csv[r, 1]) # second column in csv is column index
            value = bg_csv[r, i_ls + 2] # i_ls entries start at offset 2
            bm[r_ind, c_ind] = value

        return bm
    401 
    402     def _lift_basegraph(self, bm, z):
    403         """Lift basegraph with lifting factor ``z`` and shifted identities as
    404         defined by the entries of ``bm``."""
    405 
    406         num_nonzero = np.sum(bm>=0) # num of non-neg elements in bm
    407 
    408         # init all non-zero row/column indices
    409         r_idx = np.zeros(z*num_nonzero)
    410         c_idx = np.zeros(z*num_nonzero)
    411         data = np.ones(z*num_nonzero)
    412 
    413         # row/column indices of identity matrix for lifting
    414         im = np.arange(z)
    415 
    416         idx = 0
    417         for r in range(bm.shape[0]):
    418             for c in range(bm.shape[1]):
    419                 if bm[r,c]==-1: # -1 is used as all-zero matrix placeholder
    420                     pass #do nothing (sparse)
    421                 else:
    422                     # roll matrix by bm[r,c]
    423                     c_roll = np.mod(im+bm[r,c], z)
    424                     # append rolled identity matrix to pcm
    425                     r_idx[idx*z:(idx+1)*z] = r*z + im
    426                     c_idx[idx*z:(idx+1)*z] = c*z + c_roll
    427                     idx += 1
    428 
    429         # generate lifted sparse matrix from indices
    430         pcm = sp.sparse.csr_matrix((data,(r_idx, c_idx)),
    431                                    shape=(z*bm.shape[0], z*bm.shape[1]))
    432         return pcm
    433 
    434     def _sel_lifting(self, k, bg):
    435         """Select lifting as defined in Sec. 5.2.2 in [3GPPTS38212_LDPC]_.
    436 
    437         We assume B < K_cb, thus B'= B and C = 1, i.e., no
    438         additional CRC is appended. Thus, K' = B'/C = B and B is our K.
    439 
    440         Z is the lifting factor.
    441         i_ls is the set index ranging from 0...7 (specifying the exact bg
    442         selection).
    443         k_b is the number of information bit columns in the basegraph.
    444         """
    445         # lifting set according to 38.212 Tab 5.3.2-1
    446         s_val = [[2, 4, 8, 16, 32, 64, 128, 256],
    447                 [3, 6, 12, 24, 48, 96, 192, 384],
    448                 [5, 10, 20, 40, 80, 160, 320],
    449                 [7, 14, 28, 56, 112, 224],
    450                 [9, 18, 36, 72, 144, 288],
    451                 [11, 22, 44, 88, 176, 352],
    452                 [13, 26, 52, 104, 208],
    453                 [15, 30, 60, 120, 240]]
    454 
    455         if bg == "bg1":
    456             k_b = 22
    457         else:
    458             if k > 640:
    459                 k_b = 10
    460             elif k > 560:
    461                 k_b = 9
    462             elif k > 192:
    463                 k_b = 8
    464             else:
    465                 k_b = 6
    466 
    467         # find the min of Z from Tab. 5.3.2-1 s.t. k_b*Z>=K'
    468         min_val = 100000
    469         z = 0
    470         i_ls = 0
    471         i = -1
    472         for s in s_val:
    473             i += 1
    474             for s1 in s:
    475                 x = k_b *s1
    476                 if  x >= k:
    477                     # valid solution
    478                     if x < min_val:
    479                         min_val = x
    480                         z = s1
    481                         i_ls = i
    482 
    483         # and set K=22*Z for bg1 and K=10Z for bg2
    484         if bg == "bg1":
    485             k_b = 22
    486         else:
    487             k_b = 10
    488 
    489         return z, i_ls, k_b
    490 
    491     def _gen_submat(self, bm, k_b, z, bg):
    492         """Split the basegraph into multiple sub-matrices such that efficient
    493         encoding is possible.
    494         """
    495         g = 4 # code property (always fixed for 5G)
    496         mb = bm.shape[0] # number of CN rows in basegraph (BG property)
    497 
    498         bm_a = bm[0:g, 0:k_b]
    499         bm_b = bm[0:g, k_b:(k_b+g)]
    500         bm_c1 = bm[g:mb, 0:k_b]
    501         bm_c2 = bm[g:mb, k_b:(k_b+g)]
    502 
    503         # H could be sliced immediately (but easier to implement if based on B)
    504         hm_a = self._lift_basegraph(bm_a, z)
    505 
    506         # not required for encoding, but helpful for debugging
    507         #hm_b = self._lift_basegraph(bm_b, z)
    508 
    509         hm_c1 = self._lift_basegraph(bm_c1, z)
    510         hm_c2 = self._lift_basegraph(bm_c2, z)
    511 
    512         hm_b_inv = self._find_hm_b_inv(bm_b, z, bg)
    513 
    514         return hm_a, hm_b_inv, hm_c1, hm_c2
    515 
    516     def _find_hm_b_inv(self, bm_b, z, bg):
    517         """ For encoding we need to find the inverse of `hm_b` such that
    518         `hm_b^-1 * hm_b = I`.
    519 
    520         Could be done sparse
    521         For BG1 the structure of hm_b is given as (for all values of i_ls)
    522         hm_b =
    523         [P_A I 0 0
    524          P_B I I 0
    525          0 0 I I
    526          P_A 0 0 I]
    527         where P_B and P_A are Shifted identities.
    528 
    529         The inverse can be found by solving a linear system of equations
    530         hm_b_inv =
    531         [P_B^-1, P_B^-1, P_B^-1, P_B^-1,
    532          I + P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1,
    533          P_A*P_B^-1, P_A*P_B^-1, I+P_A*P_B^-1, I+P_A*P_B^-1,
    534          P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1, I+P_A*P_B^-1].
    535 
    536 
    537         For bg2 the structure of hm_b is given as (for all values of i_ls)
    538         hm_b =
    539         [P_A I 0 0
    540          0 I I 0
    541          P_B 0 I I
    542          P_A 0 0 I]
    543         where P_B and P_A are Shifted identities
    544 
    545         The inverse can be found by solving a linear system of equations
    546         hm_b_inv =
    547         [P_B^-1, P_B^-1, P_B^-1, P_B^-1,
    548          I + P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1,
    549          I+P_A*P_B^-1, I+P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1,
    550          P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1, I+P_A*P_B^-1]
    551 
    552         Note: the inverse of B is simply a shifted identity matrix with
    553         negative shift direction.
    554         """
    555 
    556         # permutation indices
    557         pm_a= int(bm_b[0,0])
    558         if bg=="bg1":
    559             pm_b_inv = int(-bm_b[1, 0])
    560         else: # structure of B is slightly different for bg2
    561             pm_b_inv = int(-bm_b[2, 0])
    562 
    563         hm_b_inv = np.zeros([4*z, 4*z])
    564 
    565         im = np.eye(z)
    566 
    567         am = np.roll(im, pm_a, axis=1)
    568         b_inv = np.roll(im, pm_b_inv, axis=1)
    569         ab_inv = np.matmul(am, b_inv)
    570 
    571         # row 0
    572         hm_b_inv[0:z, 0:z] = b_inv
    573         hm_b_inv[0:z, z:2*z] = b_inv
    574         hm_b_inv[0:z, 2*z:3*z] = b_inv
    575         hm_b_inv[0:z, 3*z:4*z] = b_inv
    576 
    577         # row 1
    578         hm_b_inv[z:2*z, 0:z] = im + ab_inv
    579         hm_b_inv[z:2*z, z:2*z] = ab_inv
    580         hm_b_inv[z:2*z, 2*z:3*z] = ab_inv
    581         hm_b_inv[z:2*z, 3*z:4*z] = ab_inv
    582 
    583         # row 2
    584         if bg=="bg1":
    585             hm_b_inv[2*z:3*z, 0:z] = ab_inv
    586             hm_b_inv[2*z:3*z, z:2*z] = ab_inv
    587             hm_b_inv[2*z:3*z, 2*z:3*z] = im + ab_inv
    588             hm_b_inv[2*z:3*z, 3*z:4*z] = im + ab_inv
    589         else: # for bg2 the structure is slightly different
    590             hm_b_inv[2*z:3*z, 0:z] = im + ab_inv
    591             hm_b_inv[2*z:3*z, z:2*z] = im + ab_inv
    592             hm_b_inv[2*z:3*z, 2*z:3*z] = ab_inv
    593             hm_b_inv[2*z:3*z, 3*z:4*z] = ab_inv
    594 
    595         # row 3
    596         hm_b_inv[3*z:4*z, 0:z] = ab_inv
    597         hm_b_inv[3*z:4*z, z:2*z] = ab_inv
    598         hm_b_inv[3*z:4*z, 2*z:3*z] = ab_inv
    599         hm_b_inv[3*z:4*z, 3*z:4*z] = im + ab_inv
    600 
    601         # return results as sparse matrix
    602         return sp.sparse.csr_matrix(hm_b_inv)
    603 
    def _mat_to_ind(self, mat):
        """Transform a sparse matrix into a gather-index representation for
        tf.gather.

        Each result row lists the column positions of the non-zero entries
        of the corresponding matrix row. Rows with fewer entries (irregular
        degrees) are padded with ``n``, an index pointing to `last_ind+1`
        (the zero element appended in ``_matmul_gather``).
        """
        num_rows = mat.shape[0]
        num_cols = mat.shape[1]

        # transpose mat for sorted column format
        c_idx, r_idx, _ = sp.sparse.find(mat.transpose())

        # sort indices explicitly, as scipy.sparse.find changed from column
        # to row sorting in scipy>=1.11
        order = np.argsort(r_idx)
        c_idx = c_idx[order]
        r_idx = r_idx[order]

        # maximum number of non-zero entries per row
        n_max = np.max(mat.getnnz(axis=1))

        # init index array with num_cols (pointer to last_ind+1; acts as
        # default value for missing edges)
        gat_idx = np.zeros([num_rows, n_max]) + num_cols

        cur_row = -1
        cur_col = 0
        for pos, row in enumerate(r_idx):
            # restart the column counter whenever a new row starts
            if row != cur_row:
                cur_row = row
                cur_col = 0
            gat_idx[cur_row, cur_col] = c_idx[pos]
            cur_col += 1

        return tf.cast(tf.constant(gat_idx), tf.int32)
    638 
    def _matmul_gather(self, mat, vec):
        """Implements a fast sparse matmul via gather function."""
        # append one zero entry per batch sample so that the padding indices
        # (pointing to last_ind+1) gather a neutral element
        # (otherwise ragged Tensors would be required)
        batch_size = tf.shape(vec)[0]
        padded = tf.concat(
            [vec, tf.zeros([batch_size, 1], dtype=self.dtype)], 1)

        # gather all summands of each row and reduce them at once
        gathered = tf.gather(padded, mat, batch_dims=0, axis=1)
        return tf.reduce_sum(gathered, axis=-1)
    651 
    def _encode_fast(self, s):
        """Main encoding function based on gathering function.

        Implements the two-step "RU"-style encoding using the precomputed
        gather indices of the sub-matrices A, B^-1, C1 and C2; the mod-2
        reduction is deferred to the very end.
        """
        # first part of parity bits: gather with A, then with B^-1
        p_a = self._matmul_gather(self._pcm_a_ind, s)
        p_a = self._matmul_gather(self._pcm_b_inv_ind, p_a)

        # calc second part of parity bits p_b
        # second parities are given by C_1*s' + C_2*p_a' + p_b' = 0
        p_b_1 = self._matmul_gather(self._pcm_c1_ind, s)
        p_b_2 = self._matmul_gather(self._pcm_c2_ind, p_a)
        p_b = p_b_1 + p_b_2

        # full (unpunctured) codeword: systematic bits followed by parities
        c = tf.concat([s, p_a, p_b], 1)

        # faster implementation of mod-2 operation c = tf.math.mod(c, 2)
        # (extract the LSB via bitwise-and after casting to uint8)
        c_uint8 = tf.cast(c, tf.uint8)
        c_bin = tf.bitwise.bitwise_and(c_uint8, tf.constant(1, tf.uint8))
        c = tf.cast(c_bin, self.dtype)

        c = tf.expand_dims(c, axis=-1) # returns nx1 vector
        return c
    672 
    673     #########################
    674     # Keras layer functions
    675     #########################
    676 
    def build(self, input_shape):
        """Build layer and verify consistency of ``input_shape``."""
        # check if k and input shape match
        assert (input_shape[-1]==self._k), "Last dimension must be of length k."
        assert (len(input_shape)>=2), "Rank of input must be at least 2."
    682 
    def call(self, inputs):
        """5G LDPC encoding function including rate-matching.

        This function returns the encoded codewords as specified by the
        3GPP NR Initiative [3GPPTS38212_LDPC]_ including puncturing and
        shortening.

        Args:
            inputs (tf.float32): Tensor of shape `[...,k]` containing the
                information bits to be encoded.

        Returns:
            `tf.float32`: Tensor of shape `[...,n]`.

        Raises:
            ValueError: If ``inputs`` contains other values than `0` or `1`.

            InvalidArgumentError: When rank(``inputs``)<2.

            InvalidArgumentError: When shape of last dim is not ``k``.
        """

        tf.debugging.assert_type(inputs, self.dtype, "Invalid input dtype.")

        # Reshape inputs to [...,k]; all leading dims are flattened into the
        # batch dimension and restored at the end
        input_shape = inputs.get_shape().as_list()
        new_shape = [-1, input_shape[-1]]
        u = tf.reshape(inputs, new_shape)

        # assert if u is non binary
        if self._check_input:
            tf.debugging.assert_equal(
                tf.reduce_min(
                    tf.cast(
                        tf.logical_or(
                            tf.equal(u, tf.constant(0, self.dtype)),
                            tf.equal(u, tf.constant(1, self.dtype)),
                            ),
                        self.dtype)),
                tf.constant(1, self.dtype),
                "Input must be binary.")
            # input datatype consistency should be only evaluated once
            self._check_input = False

        batch_size = tf.shape(u)[0]

        # add "filler" bits to last positions to match info bit length k_ldpc
        u_fill = tf.concat([u,
                    tf.zeros([batch_size, self._k_ldpc-self._k], self.dtype)],
                            1)

        # use optimized encoding based on tf.gather
        c = self._encode_fast(u_fill)

        c = tf.reshape(c, [batch_size, self._n_ldpc]) # remove last dim

        # remove filler bits at pos (k, k_ldpc)
        c_no_filler1 = tf.slice(c, [0, 0], [batch_size, self._k])
        c_no_filler2 = tf.slice(c,
                               [0, self._k_ldpc],
                               [batch_size, self._n_ldpc-self._k_ldpc])

        c_no_filler = tf.concat([c_no_filler1, c_no_filler2], 1)

        # shorten the first 2*Z positions and end after n bits
        # (remaining parity bits can be used for IR-HARQ)
        c_short = tf.slice(c_no_filler, [0, 2*self._z], [batch_size, self.n])
        # incremental redundancy could be generated by accessing the last bits

        # if num_bits_per_symbol is provided, apply output interleaver as
        # specified in Sec. 5.4.2.2 in 38.212
        if self._num_bits_per_symbol is not None:
            c_short = tf.gather(c_short, self._out_int, axis=-1)

        # Reshape c_short so that it matches the original input dimensions
        output_shape = input_shape[0:-1] + [self.n]
        output_shape[0] = -1 # first dim can be dynamic (batch size)
        c_reshaped = tf.reshape(c_short, output_shape)

        return tf.cast(c_reshaped, self._dtype)
    761 
    762 
    763 ###########################################################
    764 # Deprecated aliases that will not be included in the next
    765 # major release
    766 ###########################################################
    767 
    768 def AllZeroEncoder(k,
    769                    n,
    770                    dtype=tf.float32,
    771                    **kwargs):
    772     print("Warning: The alias fec.ldpc.AllZeroEncoder will not be included in "\
    773           "Sionna 1.0. Please use sionna.fec.linear.AllZeroEncoder instead.")
    774     return AllZeroEncoder_new(k=k,
    775                               n=n,
    776                               dtype=dtype,
    777                               **kwargs)