anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

encoding.py (29176B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Layers for Polar encoding including 5G compliant rate-matching and CRC
      6 concatenation."""
      7 
      8 from sionna.fec.crc import CRCEncoder
      9 from sionna.fec.polar.utils import generate_5g_ranking
     10 from numpy.core.numerictypes import issubdtype
     11 import tensorflow as tf
     12 import numpy as np
     13 from tensorflow.keras.layers import Layer
     14 import numbers
     15 
     16 class PolarEncoder(Layer):
     17     """PolarEncoder(frozen_pos, n, dtype=tf.float32)
     18 
     19     Polar encoder for given code parameters.
     20 
     21     This layer performs polar encoding for the given ``k`` information bits and
     22     the `frozen set` (i.e., indices of frozen positions) specified by
     23     ``frozen_pos``.
     24 
     25     The class inherits from the Keras layer class and can be used as layer in a
     26     Keras model.
     27 
     28     Parameters
     29     ----------
     30         frozen_pos: ndarray
     31             Array of `int` defining the `n-k` frozen indices, i.e., information
     32             bits are mapped onto the `k` complementary positions.
     33 
     34         n: int
     35             Defining the codeword length.
     36 
     37         dtype: tf.DType
     38             Defaults to `tf.float32`. Defines the output datatype of the layer
     39             (internal precision is `tf.uint8`).
     40 
     41     Input
     42     -----
     43         inputs: [...,k], tf.float32
     44             2+D tensor containing the information bits to be encoded.
     45 
     46     Output
     47     ------
     48         : [...,n], tf.float32
     49             2+D tensor containing the codeword bits.
     50 
     51     Raises
     52     ------
     53         AssertionError
     54             ``k`` and ``n`` must be positive integers and ``k`` must be smaller
     55             (or equal) than ``n``.
     56 
     57         AssertionError
     58             If ``n`` is not a power of 2.
     59 
     60         AssertionError
     61             If the number of elements in ``frozen_pos`` is great than ``n``.
     62 
     63         AssertionError
     64             If ``frozen_pos`` does not consists of `int`.
     65 
     66         ValueError
     67             If ``dtype`` is not supported.
     68 
     69         ValueError
     70             If ``inputs`` contains other values than `0` or `1`.
     71 
     72         TypeError
     73             If ``inputs`` is not `tf.float32`.
     74 
     75         InvalidArgumentError
     76             When rank(``inputs``)<2.
     77 
     78         InvalidArgumentError
     79             When shape of last dim is not ``k``.
     80 
     81     Note
     82     ----
     83         As commonly done, we assume frozen bits are set to `0`. Please note
     84         that - although its practical relevance is only little - setting frozen
     85         bits to `1` may result in `affine` codes instead of linear code as the
     86         `all-zero` codeword is not necessarily part of the code any more.
     87     """
     88 
     89     def __init__(self,
     90                  frozen_pos,
     91                  n,
     92                  dtype=tf.float32):
     93 
     94         if dtype not in (tf.float16, tf.float32, tf.float64, tf.int8,
     95             tf.int32, tf.int64, tf.uint8, tf.uint16, tf.uint32):
     96             raise ValueError("Unsupported dtype.")
     97 
     98         super().__init__(dtype=dtype)
     99 
    100         assert isinstance(n, numbers.Number), "n must be a number."
    101         n = int(n) # n can be float (e.g. as result of n=k*r)
    102         assert issubdtype(frozen_pos.dtype, int), "frozen_pos must \
    103                                                    consist of ints."
    104         assert len(frozen_pos)<=n, "Number of elements in frozen_pos cannot \
    105                                    be greater than n."
    106 
    107         assert np.log2(n)==int(np.log2(n)), "n must be a power of 2."
    108 
    109         self._k = n - len(frozen_pos)
    110         self._n = n
    111         self._frozen_pos = frozen_pos
    112 
    113         # generate info positions
    114         self._info_pos = np.setdiff1d(np.arange(self._n), frozen_pos)
    115         assert self._k==len(self._info_pos), "Internal error: invalid " \
    116                                               "info_pos generated."
    117 
    118         self._check_input = True # check input for bin. values during first call
    119 
    120         self._nb_stages = int(np.log2(self._n))
    121         self._ind_gather = self._gen_indices(self._n)
    122 
    123     #########################################
    124     # Public methods and properties
    125     #########################################
    126 
    127     @property
    128     def k(self):
    129         """Number of information bits."""
    130         return self._k
    131 
    132     @property
    133     def n(self):
    134         """Codeword length."""
    135         return self._n
    136 
    137     @property
    138     def frozen_pos(self):
    139         """Frozen positions for Polar decoding."""
    140         return self._frozen_pos
    141 
    142     @property
    143     def info_pos(self):
    144         """Information bit positions for Polar encoding."""
    145         return self._info_pos
    146 
    147     #########################
    148     # Utility methods
    149     #########################
    150 
    151     def _gen_indices(self, n):
    152         """Pre-calculate encoding indices stage-wise for tf.gather.
    153         """
    154 
    155         nb_stages = int(np.log2(n))
    156         # last position denotes empty placeholder (points to element n+1)
    157         ind_gather = np.ones([nb_stages, n+1]) * n
    158 
    159         for s in range(nb_stages):
    160             ind_range = np.arange(int(n/2))
    161             ind_dest = ind_range * 2 - np.mod(ind_range, 2**(s))
    162             ind_origin = ind_dest + 2**s
    163             ind_gather[s, ind_dest] = ind_origin # and update gather indices
    164 
    165         ind_gather = tf.constant(ind_gather, dtype=tf.int32)
    166 
    167         return ind_gather
    168 
    169     #########################
    170     # Keras layer functions
    171     #########################
    172 
    173     def build(self, input_shape):
    174         """build and check if ``k`` and ``input_shape`` match."""
    175         assert (input_shape[-1]==self._k), "Invalid input shape."
    176 
    177     def call(self, inputs):
    178         """Polar encoding function.
    179 
    180         This function returns the polar encoded codewords for the given
    181         information bits ``inputs``.
    182 
    183         Args:
    184             inputs (tf.float32): Tensor of shape `[...,k]` containing the
    185             information bits to be encoded.
    186 
    187         Returns:
    188             `tf.float32`: Tensor of shape `[...,n]`.
    189 
    190         Raises:
    191             ValueError: If ``inputs`` contains other values than `0` or `1`.
    192 
    193             TypeError: If ``inputs`` is not `tf.float32`.
    194 
    195             InvalidArgumentError: When rank(``inputs``)<2.
    196 
    197             InvalidArgumentError: When shape of last dim is not ``k``.
    198         """
    199 
    200         tf.debugging.assert_type(inputs, self.dtype,
    201                                  "Invalid input dtype.")
    202 
    203         # Reshape inputs to [...,k]
    204         tf.debugging.assert_greater(tf.rank(inputs), 1)
    205         input_shape = inputs.shape
    206         new_shape = [-1, input_shape[-1]]
    207         u = tf.reshape(inputs, new_shape)
    208 
    209         # last dim must be of length k
    210         tf.debugging.assert_equal(tf.shape(u)[-1],
    211                                   self._k,
    212                                   "Last dimension must be of length k.")
    213 
    214         # assert if binary=True and u is non binary
    215         if self._check_input:
    216             u_test = tf.cast(u, tf.float32) # only for internal check
    217             tf.debugging.assert_equal(tf.reduce_min(
    218                                         tf.cast(
    219                                             tf.logical_or(
    220                                                 tf.equal(u_test, 0.),
    221                                                 tf.equal(u_test, 1.)),
    222                                         tf.float32)),
    223                                       1.,
    224                                       "Input must be binary.")
    225             # input datatype consistency should be only evaluated once
    226             self._check_input = False
    227 
    228         # copy info bits to information set; other positions are frozen (=0)
    229 
    230         # return an all-zero tensor of shape [n,...]
    231         c = tf.zeros([self._n, tf.shape(u)[0]], self.dtype)
    232 
    233         # u has shape bs x k, we now want k x bs
    234         u_transpose = tf.transpose(u, (1,0)) # batch dim to last pos
    235 
    236         # index vector has at least two axis (= index_depth)
    237         info_pos_tf = tf.expand_dims(self.info_pos, axis=1)
    238 
    239         c = tf.tensor_scatter_nd_update(c, info_pos_tf, u_transpose)
    240         c = tf.transpose(c, (1,0))
    241         x_nan = tf.zeros([tf.shape(c)[0] ,1], self.dtype)
    242         x = tf.concat([c, x_nan], 1)
    243         x = tf.cast(x, tf.uint8)
    244 
    245         # loop over all stages
    246         for s in range(self._nb_stages):
    247             ind_helper = self._ind_gather[s,:]
    248             x_add = tf.gather(x, ind_helper, batch_dims=0, axis=1)
    249             #x = tf.math.logical_xor(x, x_add) # does not work well with XLA
    250             x = tf.bitwise.bitwise_xor(x, x_add)
    251 
    252         # remove last position
    253         c_out = x[:,0:self._n]
    254 
    255         # restore original shape
    256         input_shape_list = input_shape.as_list()
    257         output_shape = input_shape_list[0:-1] + [self._n]
    258         output_shape[0] = -1 # to support dynamic shapes
    259         c_reshaped = tf.reshape(c_out, output_shape)
    260 
    261         # cast to dtype for compatibility with other components
    262         return tf.cast(c_reshaped, self.dtype)
    263 
    264 class Polar5GEncoder(PolarEncoder):
    265     # pylint: disable=line-too-long
    266     """Polar5GEncoder(k, n, verbose=False, channel_type="uplink", dtype=tf.float32)
    267 
    268     5G compliant Polar encoder including rate-matching following [3GPPTS38212]_
    269     for the uplink scenario (`UCI`) and downlink scenario (`DCI`).
    270 
    271     This layer performs polar encoding for ``k`` information bits and
    272     rate-matching such that the codeword lengths is ``n``. This includes the CRC
    273     concatenation and the interleaving as defined in [3GPPTS38212]_.
    274 
    275     Note: `block segmentation` is currently not supported (`I_seq=False`).
    276 
    277     We follow the basic structure from Fig. 6 in [Bioglio_Design]_.
    278 
    279     ..  figure:: ../figures/PolarEncoding5G.png
    280 
    281         Fig. 1: Implemented 5G Polar encoding chain following Fig. 6 in
    282         [Bioglio_Design]_ for the uplink (`I_BIL` = `True`) and the downlink
    283         (`I_IL` = `True`) scenario without `block segmentation`.
    284 
    285     For further details, we refer to [3GPPTS38212]_, [Bioglio_Design]_ and
    286     [Hui_ChannelCoding]_.
    287 
    288     The class inherits from the Keras layer class and can be used as layer in a
    289     Keras model. Further, the class inherits from PolarEncoder.
    290 
    291     Parameters
    292     ----------
    293         k: int
    294             Defining the number of information bit per codeword.
    295 
    296         n: int
    297             Defining the codeword length.
    298 
    299         channel_type: str
    300             Defaults to "uplink". Can be "uplink" or "downlink".
    301 
    302         verbose: bool
    303             Defaults to False. If True, rate-matching parameters will be
    304             printed.
    305 
    306         dtype: tf.DType
    307             Defaults to tf.float32. Defines the output datatype of the layer
    308             (internal precision remains tf.uint8).
    309 
    310     Input
    311     -----
    312         inputs: [...,k], tf.float32
    313             2+D tensor containing the information bits to be encoded.
    314 
    315     Output
    316     ------
    317         : [...,n], tf.float32
    318             2+D tensor containing the codeword bits.
    319 
    320     Raises
    321     ------
    322         AssertionError
    323             ``k`` and ``n`` must be positive integers and ``k`` must be smaller
    324             (or equal) than ``n``.
    325 
    326         AssertionError
    327             If ``n`` and ``k`` are invalid code parameters (see [3GPPTS38212]_).
    328 
    329         AssertionError
    330             If ``verbose`` is not `bool`.
    331 
    332         ValueError
    333             If ``dtype`` is not supported.
    334 
    335     Note
    336     ----
    337         The encoder supports the `uplink` Polar coding (`UCI`) scheme from
    338         [3GPPTS38212]_ and the `downlink` Polar coding (`DCI`) [3GPPTS38212]_,
    339         respectively.
    340 
    341         For `12 <= k <= 19` the 3 additional parity bits as defined in
    342         [3GPPTS38212]_ are not implemented as it would also require a
    343         modified decoding procedure to materialize the potential gains.
    344 
    345         `Code segmentation` is currently not supported and, thus, ``n`` is
    346         limited to a maximum length of 1088 codeword bits.
    347 
    348         For the downlink scenario, the input length is limited to `k <= 140`
    349         information bits due to the limited input bit interleaver size
    350         [3GPPTS38212]_.
    351 
    352         For simplicity, the implementation does not exactly re-implement the
    353         `DCI` scheme from [3GPPTS38212]_. This implementation neglects the
    354         `all-one` initialization of the CRC shift register and the scrambling of the CRC parity bits with the `RNTI`.
    355     """
    356 
    357     def __init__(self,
    358                  k,
    359                  n,
    360                  channel_type="uplink",
    361                  verbose=False,
    362                  dtype=tf.float32,):
    363 
    364         if dtype not in (tf.float16, tf.float32, tf.float64, tf.int8,
    365             tf.int32, tf.int64, tf.uint8, tf.uint16, tf.uint32):
    366             raise ValueError("Unsupported dtype.")
    367 
    368         assert isinstance(k, numbers.Number), "k must be a number."
    369         assert isinstance(n, numbers.Number), "n must be a number."
    370         k = int(k) # k or n can be float (e.g. as result of n=k*r)
    371         n = int(n) # k or n can be float (e.g. as result of n=k*r)
    372         assert n>=k, "Invalid coderate (>1)."
    373         assert isinstance(verbose, bool), "verbose must be bool."
    374 
    375         assert channel_type in ("uplink","downlink"), \
    376                                          "Unsupported channel_type."
    377         self._channel_type = channel_type
    378 
    379         self._k_target = k
    380         self._n_target = n
    381         self._verbose = verbose
    382 
    383          # Initialize rate-matcher
    384         crc_degree, n_polar, frozen_pos, idx_rm, idx_input  = \
    385             self._init_rate_match(k, n)
    386 
    387         self._frozen_pos = frozen_pos # Required for decoder
    388         self._ind_rate_matching = idx_rm # Index for gather-based rate-matching
    389         self._ind_input_int = idx_input # Index for input interleaver
    390 
    391         # Initialize CRC encoder
    392         self._enc_crc = CRCEncoder(crc_degree, dtype=dtype)
    393 
    394         # Init super-class (PolarEncoder)
    395         super().__init__(frozen_pos, n_polar, dtype=dtype)
    396 
    397     #########################################
    398     # Public methods and properties
    399     #########################################
    400 
    401     @property
    402     def enc_crc(self):
    403         """CRC encoder layer used for CRC concatenation."""
    404         return self._enc_crc
    405 
    406     @property
    407     def k_target(self):
    408         """Number of information bits including rate-matching."""
    409         return self._k_target
    410 
    411     @property
    412     def n_target(self):
    413         """Codeword length including rate-matching."""
    414         return self._n_target
    415 
    416     @property
    417     def k_polar(self):
    418         """Number of information bits of the underlying Polar code."""
    419         return self._k
    420 
    421     @property
    422     def n_polar(self):
    423         """Codeword length of the underlying Polar code."""
    424         return self._n
    425 
    426     @property
    427     def k(self):
    428         """Number of information bits including rate-matching."""
    429         return self._k_target
    430 
    431     @property
    432     def n(self):
    433         """Codeword length including rate-matching."""
    434         return self._n_target
    435 
    436     def subblock_interleaving(self, u):
    437         """Input bit interleaving as defined in Sec 5.4.1.1 [3GPPTS38212]_.
    438 
    439         Input
    440         -----
    441             u: ndarray
    442                 1D array to be interleaved. Length of ``u`` must be a multiple
    443                 of 32.
    444 
    445         Output
    446         ------
    447             : ndarray
    448                 Interleaved version of ``u`` with same shape and dtype as ``u``.
    449 
    450         Raises
    451         ------
    452             AssertionError
    453                 If length of ``u`` is not a multiple of 32.
    454 
    455         """
    456 
    457         k = u.shape[-1]
    458         assert np.mod(k,32)==0, \
    459             "length for sub-block interleaving must be a multiple of 32."
    460         y = np.zeros_like(u)
    461 
    462         # Permutation according to Tab 5.4.1.1.1-1 in 38.212
    463         perm = np.array([0, 1, 2, 4, 3, 5, 6, 7, 8, 16, 9, 17, 10, 18, 11, 19,
    464                          12, 20, 13, 21, 14, 22, 15, 23, 24, 25, 26, 28, 27,
    465                          29, 30, 31])
    466 
    467         for n in range(k):
    468             i = int(np.floor(32*n/k))
    469             j = perm[i] * k/32 + np.mod(n, k/32)
    470             j = int(j)
    471             y[n] = u[j]
    472 
    473         return y
    474 
    475     def channel_interleaver(self, c):
    476         """Triangular interleaver following Sec. 5.4.1.3 in [3GPPTS38212]_.
    477 
    478         Input
    479         -----
    480             c: ndarray
    481                 1D array to be interleaved.
    482 
    483         Output
    484         ------
    485             : ndarray
    486                 Interleaved version of ``c`` with same shape and dtype as ``c``.
    487 
    488         """
    489 
    490         n = c.shape[-1] # Denoted as E in 38.212
    491         c_int = np.zeros_like(c)
    492 
    493         # Find smallest T s.t. T*(T+1)/2 >= n
    494         t = 0
    495         while t*(t+1)/2 < n:
    496             t +=1
    497 
    498         v = np.zeros([t, t])
    499         ind_k = 0
    500         for ind_i in range(t):
    501             for ind_j in range(t-ind_i):
    502                 if ind_k < n:
    503                     v[ind_i, ind_j] = c[ind_k]
    504                 else:
    505                     v[ind_i, ind_j] = np.nan # NULL
    506                 # Store nothing otherwise
    507                 ind_k += 1
    508         ind_k = 0
    509         for ind_j in range(t):
    510             for ind_i in range(t-ind_j):
    511                 if not np.isnan(v[ind_i, ind_j]):
    512                     c_int[ind_k] = v[ind_i, ind_j]
    513                     ind_k += 1
    514         return c_int
    515 
    516     def input_interleaver(self, c):
    517         """Input interleaver following Sec. 5.4.1.1 in [3GPPTS38212]_.
    518 
    519         Input
    520         -----
    521             c: ndarray
    522                 1D array to be interleaved.
    523 
    524         Output
    525         ------
    526             : ndarray
    527                 Interleaved version of ``c`` with same shape and dtype as ``c``.
    528 
    529         """
    530         # 38.212 Table 5.3.1.1-1
    531         p_il_max_table = [0, 2, 4, 7, 9, 14, 19, 20, 24, 25, 26, 28, 31, 34,
    532             42, 45, 49, 50, 51, 53, 54, 56, 58, 59, 61, 62, 65, 66, 67, 69,
    533             70, 71, 72, 76, 77, 81, 82, 83, 87, 88, 89, 91, 93, 95, 98, 101,
    534             104, 106, 108, 110, 111, 113, 115, 118, 119, 120, 122, 123, 126,
    535             127, 129, 132, 134, 138, 139, 140, 1, 3, 5, 8, 10, 15, 21, 27, 29,
    536             32, 35, 43, 46, 52, 55, 57, 60, 63, 68, 73, 78, 84, 90, 92, 94, 96,
    537             99, 102, 105, 107, 109, 112, 114, 116, 121, 124, 128, 130, 133,
    538             135, 141, 6, 11, 16, 22, 30, 33, 36, 44, 47, 64, 74, 79, 85, 97,
    539             100, 103, 117, 125, 131, 136, 142, 12, 17, 23, 37, 48, 75, 80, 86,
    540             137, 143, 13, 18, 38, 144, 39, 145, 40, 146, 41, 147, 148, 149,
    541             150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162,
    542             163]
    543         k_il_max = 164
    544         k = len(c)
    545         assert k<=k_il_max, "Input interleaver only defined for length of 164."
    546         c_apo = np.empty(k, 'int')
    547         i = 0
    548         for p_il_max in p_il_max_table:
    549             if p_il_max >= (k_il_max - k):
    550                 c_apo[i] = c[p_il_max - (k_il_max - k)]
    551                 i += 1
    552         return c_apo
    553 
    554     #########################
    555     # Utility methods
    556     #########################
    557 
    558     def _init_rate_match(self, k_target, n_target):
    559         """Implementing polar rate matching according to [3GPPTS38212]_.
    560 
    561         Please note that this part of the code only runs during the
    562         initialization and, thus, is not performance critical. For easier
    563         alignment and traceability with the standard document [3GPPTS38212]_
    564         the implementation prefers `for loop`-based indexing.
    565 
    566         The relation of terminology between [3GPPTS38212]_ and this code is
    567         given as:
    568         `A`...`k_target`
    569         `E`...`n_target`
    570         `K`...`k_polar`
    571         `N`...`n_polar`
    572         `L`...`k_crc`.
    573         """
    574 
    575         # Check input for consistency (see Sec. 6.3.1.2.1 for UL)
    576 
    577         # currently not relevant (segmentation not supported)
    578         # assert k_target<=1706, "Maximum supported codeword length for" \
    579         # "Polar  coding is 1706."
    580 
    581         assert n_target >= k_target, "n must be larger or equal k."
    582         assert n_target >= 18, \
    583                         "n<18 is not supported by the 5G Polar coding scheme."
    584         assert k_target <= 1013, \
    585             "k too large - no codeword segmentation supported at the moment."
    586         assert n_target <= 1088, \
    587             "n too large - no codeword segmentation supported at the moment."
    588 
    589         # Select CRC polynomials (see Sec. 6.3.1.2.1 for UL)
    590         if self._channel_type=="uplink":
    591 
    592             if 12<=k_target<=19:
    593                 crc_pol = "CRC6"
    594                 k_crc = 6
    595             elif k_target >=20:
    596                 crc_pol = "CRC11"
    597                 k_crc = 11
    598             else:
    599                 raise ValueError("k_target<12 is not supported in 5G NR for " \
    600                     "the uplink; please use 'channel coding of small block  " \
    601                     "lengths' scheme from Sec. 5.3.3 in 3GPP 38.212 instead.")
    602 
    603             # PC bit for k_target = 12-19 bits (see Sec. 6.3.1.3.1 for UL)
    604             n_pc = 0
    605             #n_pc_wm = 0
    606             if k_target<=19:
    607                 #n_pc = 3
    608                 n_pc = 0 # Currently deactivated
    609                 print("Warning: For 12<=k<=19 additional 3 parity-check bits " \
    610                     "are defined in 38.212. They are currently not " \
    611                     "implemented by this encoder and, thus, ignored.")
    612                 if n_target-k_target>175:
    613                     #n_pc_wm = 1 # not implemented
    614                     pass
    615 
    616         else: # downlink channel
    617             # for downlink CRC24 is used
    618             # remark: in PDCCH messages are limited to k=140
    619             # as the input interleaver does not support longer sequences
    620             assert k_target <= 140, \
    621                "k too large for downlink channel configuration."
    622             assert n_target >= 25, \
    623                 "n too small for downlink channel configuration with 24 bit " \
    624                 "CRC."
    625             assert n_target <= 576, \
    626             "n too large for downlink channel configuration."
    627             crc_pol = "CRC24C" # following 7.3.2
    628             k_crc = 24
    629             n_pc = 0
    630 
    631         # No input interleaving for uplink needed
    632 
    633         # Calculate Polar payload length (CRC bits are treated as info bits)
    634         k_polar = k_target + k_crc + n_pc
    635 
    636         assert k_polar <= n_target, "Device is not expected to be configured " \
    637                                     "with k_polar + k_crc + n_pc > n_target."
    638 
    639         # Select polar mother code length n_polar
    640         n_min = 5
    641         n_max = 10 # For uplink; otherwise 9
    642 
    643         # Select rate-matching scheme following Sec. 5.3.1
    644         if (n_target <= ((9/8) * 2**(np.ceil(np.log2(n_target))-1)) and
    645             k_polar/n_target < 9/16):
    646             n1 = np.ceil(np.log2(n_target))-1
    647         else:
    648             n1 = np.ceil(np.log2(n_target))
    649         n2 = np.ceil(np.log2(8*k_polar)) #Lower bound such that rate > 1/8
    650         n_polar = int(2**np.max((np.min([n1, n2, n_max]), n_min)))
    651 
    652         # Puncturing and shortening as defined in Sec. 5.4.1.1
    653         prefrozen_pos = [] # List containing the pre-frozen indices
    654         if n_target < n_polar:
    655             if k_polar/n_target <= 7/16:
    656                 # Puncturing
    657                 if self._verbose:
    658                     print("Using puncturing for rate-matching.")
    659                 n_int =  32 * np.ceil((n_polar-n_target) / 32)
    660                 int_pattern = self.subblock_interleaving(np.arange(n_int))
    661                 for i in range(n_polar-n_target):
    662                     # Freeze additional bits
    663                     prefrozen_pos.append(int(int_pattern[i]))
    664                 if n_target >= 3*n_polar/4:
    665                     t = int(np.ceil(3/4*n_polar - n_target/2) - 1)
    666                 else:
    667                     t = int(np.ceil(9/16*n_polar - n_target/4) - 1)
    668                 # Extra freezing
    669                 for i in range(t):
    670                     prefrozen_pos.append(i)
    671             else:
    672                 # Shortening ("through" sub-block interleaver)
    673                 if self._verbose:
    674                     print("Using shortening for rate-matching.")
    675                 n_int =  32 * np.ceil((n_polar) / 32)
    676                 int_pattern = self.subblock_interleaving(np.arange(n_int))
    677                 for i in range(n_target, n_polar):
    678                     prefrozen_pos.append(int_pattern[i])
    679 
    680         # Remove duplicates
    681         prefrozen_pos = np.unique(prefrozen_pos)
    682 
    683         # Find the remaining n_polar - k_polar - |frozen_set|
    684 
    685         # Load full channel ranking
    686         ch_ranking, _ = generate_5g_ranking(0, n_polar, sort=False)
    687 
    688         # Remove positions that are already frozen by `pre-freezing` stage
    689         info_cand = np.setdiff1d(ch_ranking, prefrozen_pos, assume_unique=True)
    690 
    691         # Identify k_polar most reliable positions from candidate positions
    692         info_pos = []
    693         for i in range(k_polar):
    694             info_pos.append(info_cand[-i-1])
    695 
    696         # Sort and create frozen positions for n_polar indices (no shortening)
    697         info_pos = np.sort(info_pos).astype(int)
    698         frozen_pos = np.setdiff1d(np.arange(n_polar),
    699                                   info_pos,
    700                                   assume_unique=True)
    701 
    702         # For downlink only: generate input bit interleaver
    703         if self._channel_type=="downlink":
    704             if self._verbose:
    705                 print("Using input bit interleaver for downlink.")
    706             ind_input_int = self.input_interleaver(np.arange(k_polar))
    707         else:
    708             ind_input_int = None
    709 
    710         # Generate tf.gather indices for sub-block interleaver
    711         ind_sub_int = self.subblock_interleaving(np.arange(n_polar))
    712 
    713         # Rate matching via circular buffer as defined in Sec. 5.4.1.2
    714         c_int = np.arange(n_polar)
    715         idx_c_matched = np.zeros([n_target])
    716         if n_target >= n_polar:
    717             # Repetition coding
    718             if self._verbose:
    719                 print("Using repetition coding for rate-matching")
    720             for ind in range(n_target):
    721                 idx_c_matched[ind] = c_int[np.mod(ind, n_polar)]
    722         else:
    723             if k_polar/n_target <= 7/16:
    724                 # Puncturing
    725                 for ind in range(n_target):
    726                     idx_c_matched[ind] = c_int[ind+n_polar-n_target]
    727             else:
    728                 # Shortening
    729                 for ind in range(n_target):
    730                     idx_c_matched[ind] = c_int[ind]
    731 
    732         # For uplink only: generate input bit interleaver
    733         if self._channel_type=="uplink":
    734             if self._verbose:
    735                 print("Using channel interleaver for uplink.")
    736             ind_channel_int = self.channel_interleaver(np.arange(n_target))
    737 
    738             # Combine indices for single tf.gather operation
    739             ind_t = idx_c_matched[ind_channel_int].astype(int)
    740             idx_rate_matched = ind_sub_int[ind_t]
    741         else: # no channel interleaver for downlink
    742             idx_rate_matched = ind_sub_int[idx_c_matched.astype(int)]
    743 
    744         if self._verbose:
    745             print("Code parameters after rate-matching: " \
    746                   f"k = {k_target}, n = {n_target}")
    747             print(f"Polar mother code: k_polar = {k_polar}, " \
    748                   f"n_polar = {n_polar}")
    749             print("Using", crc_pol)
    750             print("Frozen positions: ", frozen_pos)
    751             print("Channel type: " + self._channel_type)
    752 
    753         return crc_pol, n_polar, frozen_pos, idx_rate_matched, ind_input_int
    754 
    755     #########################
    756     # Keras layer functions
    757     #########################
    758 
    759     def build(self, input_shape):
    760         """Build and check if ``k`` and ``input_shape`` match."""
    761         assert (input_shape[-1]==self._k_target), "Invalid input shape."
    762 
    763     def call(self, inputs):
    764         """Polar encoding function including rate-matching and CRC encoding.
    765 
    766         This function returns the polar encoded codewords for the given
    767         information bits ``inputs`` following [3GPPTS38212]_ including
    768         rate-matching.
    769 
    770         Args:
    771             inputs (tf.float32): Tensor of shape `[...,k]` containing the
    772             information bits to be encoded.
    773 
    774         Returns:
    775             `tf.float32`: Tensor of shape `[...,n]`.
    776 
    777         Raises:
    778             TypeError: If ``inputs`` is not `tf.float32`.
    779 
    780             InvalidArgumentError: When rank(``inputs``)<2.
    781 
    782             InvalidArgumentError: When shape of last dim is not ``k``.
    783         """
    784 
    785         # Reshape inputs to [...,k]
    786         tf.debugging.assert_greater(tf.rank(inputs), 1)
    787         input_shape = inputs.shape
    788         new_shape = [-1, input_shape[-1]]
    789         u = tf.reshape(inputs, new_shape)
    790 
    791         # Consistency check (i.e., binary) of inputs will be done in super_class
    792 
    793         # CRC encode
    794         u_crc = self._enc_crc(u)
    795 
    796         # For downlink only: apply input bit interleaver
    797         if self._channel_type=="downlink":
    798             u_crc = tf.gather(u_crc, self._ind_input_int, axis=-1)
    799 
    800         # Encode bits (= channel allocation + Polar transform)
    801         c = super().call(u_crc)
    802 
    803         # Sub-block interleaving with 32 sub-blocks as in Sec. 5.4.1.1
    804         # Rate matching via circular buffer as defined in Sec. 5.4.1.2
    805         # For uplink only: channel interleaving (i_bil=True)
    806         c_matched = tf.gather(c, self._ind_rate_matching, axis=1)
    807 
    808         # Restore original shape
    809         input_shape_list = input_shape.as_list()
    810         output_shape = input_shape_list[0:-1] + [self._n_target]
    811         output_shape[0] = -1 # To support dynamic shapes
    812         c_reshaped = tf.reshape(c_matched, output_shape)
    813 
    814         return c_reshaped