anomaly-detection-material-parameters-calibration

Sionna material-parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

interleaving.py (33144B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Layers for interleaving and utility functions"""
      6 
      7 import numpy as np
      8 import tensorflow as tf
      9 from tensorflow.keras.layers import Layer
     10 from importlib_resources import files, as_file
     11 from sionna import config
     12 from sionna.fec.turbo import coeffs
     13 
     14 class RowColumnInterleaver(Layer):
     15      # pylint: disable=line-too-long
     16     r"""RowColumnInterleaver(row_depth, axis=-1, inverse=False, dtype=tf.float32, **kwargs)
     17 
     18     Interleaves a sequence of inputs via row/column swapping.
     19 
     20     The class inherits from the Keras layer class and can be used as layer in a
     21     Keras model.
     22 
     23     Parameters
     24     ----------
     25         row_depth: int
     26             The row depth, i.e., how many values per row can be stored.
     27 
     28         axis: int
     29             The dimension that should be interleaved. First dimension
     30             (`axis=0`) is not allowed.
     31 
     32         inverse: bool
     33             A boolean defaults to False. If True, the inverse permutation is
     34             performed.
     35 
     36         dtype: tf.DType
     37             Defaults to `tf.float32`. Defines the datatype for internal
     38             calculations and the output dtype.
     39 
     40     Input
     41     -----
     42         inputs: tf.DType
     43             2+D tensor of arbitrary shape and arbitrary dtype. Must have at
     44             least rank two.
     45 
     46     Output
     47     ------
     48          : tf.DType
     49             2+D tensor of same shape and dtype as ``inputs``.
     50 
     51     Raises
     52     ------
     53          AssertionError
     54             If ``axis`` is not an integer.
     55 
     56          AssertionError
     57             If ``row_depth`` is not an integer.
     58 
     59          AssertionError
     60             If ``axis`` > number of input dimensions.
     61 
     62     Note
     63     ----
     64         If the sequence length is not a multiple of ``row_depth``, additional
     65         filler bits are used for the last row that will be removed internally.
     66         However, for the last positions the interleaving distance may be
     67         slightly degraded.
     68 
     69         To permute the batch dimension, expand_dims at `axis=0`, interleave and
     70         remove new dimension.
     71     """
     72 
     73     def __init__(self,
     74                  row_depth,
     75                  axis=-1,
     76                  inverse=False,
     77                  dtype=tf.float32,
     78                  **kwargs):
     79 
     80         super().__init__(dtype=dtype, **kwargs)
     81 
     82         # store perm_seq
     83         self._perm_seq = None # initalized during build
     84         self._perm_seq_inv = None # initalized during build
     85 
     86         assert isinstance(axis, int), "axis must be int."
     87         self._axis = axis
     88 
     89         assert isinstance(row_depth, int), "row_depth must be int."
     90         self._row_depth = row_depth
     91 
     92         assert isinstance(inverse, bool), "inverse must be bool."
     93         self._inverse = inverse
     94 
     95         # cannot be changed, only required for associated interleaver
     96         self._keep_state = True
     97 
     98     #########################################
     99     # Public methods and properties
    100     #########################################
    101 
    102     @property
    103     def axis(self):
    104         """Axis to be permuted."""
    105         return self._axis
    106 
    107     @property
    108     def row_depth(self):
    109         """Row depth of the row-column interleaver."""
    110         return self._row_depth
    111 
    112     @property
    113     def perm_seq(self):
    114         """Permutation sequence."""
    115         return self._perm_seq
    116 
    117     @property
    118     def perm_seq_inv(self):
    119         """Inverse permutation sequence."""
    120         return self._perm_seq_inv
    121 
    122     @property
    123     def keep_state(self):
    124         """Row-column interleaver always uses same internal state."""
    125         return True
    126 
    127     def call_inverse(self, inputs):
    128         """Implements deinterleaver function corresponding to call().
    129 
    130         Input
    131         -----
    132             inputs: tf.DType
    133                 2+D tensor of arbitrary shape and arbitrary dtype. Must have at
    134                 least rank two.
    135 
    136         Output
    137         ------
    138             : tf.DType
    139                 2+D tensor of same shape and dtype as ``inputs``.
    140         """
    141         input_shape = inputs.shape
    142 
    143         x = tf.gather(inputs, self._perm_seq_inv, axis=self._axis)
    144 
    145         x = tf.ensure_shape(x, input_shape)
    146         return x
    147 
    148     #########################
    149     # Utility methods
    150     #########################
    151 
    152     def _generate_perm_rc(self, n_seq, r_depth):
    153         """Generates a row/column permutation to initialize an rc-interleaver.
    154 
    155         If required last positions use "filler" positions.
    156 
    157         Args:
    158             N_seq (int): An integer defining the sequence length to interleave.
    159 
    160             r_depth (int): An integer defining the depth of the interleaver.
    161         """
    162 
    163         # round to next multiple of r_depth
    164         n = tf.cast((tf.math.ceil(n_seq/r_depth)*r_depth), tf.int32)
    165         nb_rows = tf.cast(n/r_depth, tf.int64)
    166 
    167         ind = tf.range(n, dtype=tf.int32)
    168 
    169         # rearange in row/colum format
    170         ind_rc = tf.reshape(ind, [nb_rows,-1])
    171 
    172         # and interleave via row/column swapping
    173         ind_cr = tf.transpose(ind_rc, (1,0))
    174 
    175         # read out indices in column/row ordering
    176         perm_seq_filler= tf.reshape(ind_cr, [-1])
    177 
    178         # remove filler positions
    179         mask = tf.math.less(perm_seq_filler, n_seq)
    180         perm_seq = tf.boolean_mask(perm_seq_filler, mask)
    181         perm_seq_inv= tf.argsort(perm_seq)
    182         return perm_seq, perm_seq_inv
    183 
    184     #########################
    185     # Keras layer functions
    186     #########################
    187 
    188     def build(self, input_shape):
    189         assert self._axis < len(input_shape), "Axis does match input shape"
    190         # init rand sequences during build
    191         assert input_shape[self._axis] is not None, "Unknown shape at req. dim"
    192         p, pi = self._generate_perm_rc(input_shape[self._axis], self._row_depth)
    193         self._perm_seq = p
    194         self._perm_seq_inv = pi
    195 
    196     def call(self, inputs):
    197         """interleaving function
    198 
    199         This function returns the permuted version of inputs.
    200 
    201         Args:
    202             inputs (tf.float32): Tensor of arbitrary shape. Must have at least
    203                 rank two.
    204 
    205         Returns:
    206             `tf.float32`: Tensor of same shape as the input.
    207 
    208         """
    209 
    210         input_shape = inputs.shape
    211 
    212         # re-init if shape has changed, update perm_seq
    213         if inputs.shape[self._axis] != self._perm_seq.shape[0]:
    214             self.build(inputs.shape)
    215 
    216         if self._inverse:
    217             x = tf.gather(inputs, self._perm_seq_inv, axis=self._axis)
    218         else:
    219             x = tf.gather(inputs, self._perm_seq, axis=self._axis)
    220 
    221         x = tf.ensure_shape(x, input_shape)
    222         return x
    223 
    224 
    225 class RandomInterleaver(Layer):
    226     # pylint: disable=line-too-long
    227     """RandomInterleaver(seed=None, keep_batch_constant=True, inverse=False, keep_state=True, axis=-1, dtype=tf.float32, **kwargs)
    228 
    229     Random interleaver permuting a sequence of input symbols.
    230 
    231     The class inherits from the Keras layer class and can be used as layer in a
    232     Keras model.
    233 
    234     Parameters
    235     ----------
    236         seed: int
    237             Integer defining the random seed used if option ``keep_state`` is
    238             True.
    239 
    240         keep_batch_constant: bool
    241             Defaults to True. If set to True each sample in the batch uses the
    242             same permutation. Otherwise, unique permutations per batch sample
    243             are generate (slower).
    244 
    245         inverse: bool
    246             A boolean defaults to False. If True, the inverse permutation is
    247             performed.
    248 
    249         keep_state: bool
    250             A boolean defaults to True. If True, the permutation is fixed for
    251             multiple calls (defined by ``seed`` attribute).
    252 
    253         axis: int
    254             Defaults to `-1`. The dimension that should be interleaved.
    255             First dimension (`axis=0`) is not allowed.
    256 
    257         dtype: tf.DType
    258             Defaults to `tf.float32`. Defines the datatype for internal
    259             calculations and the output dtype.
    260 
    261     Input
    262     -----
    263         (x, seed):
    264             Either Tuple ``(x, seed)`` or ``x`` only (no tuple) if the internal
    265             seed should be used:
    266 
    267         x: tf.DType
    268             2+D tensor of arbitrary shape and dtype.
    269         seed: int
    270             An integer defining the state of the random number
    271             generator. If explicitly given, the global internal seed is
    272             replaced by this seed. Can be used to realize random
    273             interleaver/deinterleaver pairs (call with same random seed).
    274 
    275     Output
    276     ------
    277         : tf.DType
    278             2+D tensor of same shape and dtype as the input ``x``.
    279 
    280     Raises
    281     ------
    282         AssertionError
    283             If ``axis`` is not `int`.
    284 
    285         AssertionError
    286             If ``seed`` is not `None` or `int`.
    287 
    288         AssertionError
    289             If ``axis`` > number of input dimensions.
    290 
    291         AssertionError
    292             If ``inverse`` is not bool.
    293 
    294         AssertionError
    295             If ``keep_state`` is not bool.
    296 
    297         AssertionError
    298             If ``keep_batch_constant`` is not bool.
    299 
    300         InvalidArgumentError
    301             When rank(``x``)<2.
    302 
    303     Note
    304     ----
    305         To permute the batch dimension, expand_dims at ``axis=0``, interleave
    306         and remove new dimension.
    307 
    308         The interleaver layer is stateless, i.e., the seed is either random
    309         during each call or must be explicitly provided during init/call.
    310         This simplifies XLA/graph execution.
    311 
    312         This is NOT the 5G interleaver sequence.
    313     """
    314 
    315     def __init__(self,
    316                 seed=None,
    317                 keep_batch_constant=True,
    318                 inverse=False,
    319                 keep_state=True,
    320                 axis=-1,
    321                 dtype=tf.float32,
    322                 **kwargs):
    323 
    324         super().__init__(dtype=dtype, **kwargs)
    325 
    326         # verify and store attributes
    327         assert isinstance(keep_batch_constant, bool), \
    328             "keep_batch_constant must be bool."
    329         self._keep_batch_constant = keep_batch_constant
    330 
    331         assert isinstance(axis, int), "axis must be int."
    332         assert axis!=0, "Cannot permute batch_dim."
    333         self._axis=axis
    334 
    335         # a global seed is stored and used if called with keep_state=True
    336         if seed is not None:
    337             assert isinstance(seed, int), "seed must be int."
    338         else:
    339             # generate random seed if no value is provided
    340             seed = int(np.random.uniform(0, 2**31-1))
    341 
    342         # if keep_state==True this seed is used to generate scrambling sequences
    343         self._seed = (1337, seed)
    344 
    345         assert isinstance(inverse, bool), "inverse must be boolean"
    346         self._inverse = inverse
    347         assert isinstance(keep_state, bool), "keep_state must be boolean"
    348         self._keep_state = keep_state
    349 
    350         if self._keep_state is False and self._inverse is True:
    351             print("Note: keep_state=False and, thus, a new realization of " \
    352                 "the interleaver is generated during each call. Thus, " \
    353                 "the inverse interleaver does not correspond to a previous " \
    354                 "interleaver call.")
    355 
    356     #########################################
    357     # Public methods and properties
    358     #########################################
    359 
    360     @property
    361     def seed(self):
    362         """Seed to generate random sequence."""
    363         return self._seed[1] # only return the non-fixed seed
    364 
    365     @property
    366     def axis(self):
    367         """Axis to be permuted."""
    368         return self._axis
    369 
    370     @property
    371     def keep_state(self):
    372         """Generate new random seed per call."""
    373         return self._keep_state
    374 
    375 
    376     def find_s_min(self, seed, seq_length, s_min_stop=0):
    377         r"""Find :math:`S` parameter such that :math:`\pi(i)-\pi(j)>S` for all
    378         :math:`i-j<S`. This can be used to find optimized interleaver patterns.
    379 
    380         ``s_min_stop`` is an additional stopping condition, i.e., stop if
    381         current :math:`S` is already smaller than ``s_min_stop``.
    382 
    383         Please note that this is a Numpy utility function and usually not part
    384         of the graph.
    385 
    386         Input
    387         -----
    388             seed: int
    389                 seed to draw random permutation that shall be analyzed.
    390 
    391             seq_length: int
    392                 length of permutation sequence to be analyzed.
    393 
    394             s_min_stop: int
    395                 Defaults to 0. Enables early stop if already current s_min< ``s_min_stop`` .
    396         Output
    397         ------
    398             : float
    399                 The S-parameter for the given ``seed``.
    400         """
    401 
    402         assert isinstance(seed, int), "seed must be int."
    403         assert isinstance(seq_length, int), "seq_length must be int."
    404         assert isinstance(s_min_stop, int), "s_min_stop must be int."
    405 
    406         seed = (1337, seed)
    407         perm_seq = self._generate_perm_full(seed, seq_length, batch_size=1)
    408         perm_seq = tf.squeeze(perm_seq, axis=0).numpy()
    409         s_min = seq_length
    410         for i in range(len(perm_seq)): # search for all positions in perm_seq
    411             for j in range(-s_min,s_min,1): # search dist
    412                 if j==0: # ignore identity
    413                     continue
    414                 if i+j>=0 and i+j<seq_length:
    415                     d = np.abs(perm_seq[i] - perm_seq[i+j])
    416                     if d<=np.abs(j):
    417                         s_min = np.min([s_min, np.abs(j)])
    418                     if d<s_min and np.abs(j)<s_min:
    419                         s_min = np.min([s_min, d])
    420             # early stop
    421             if s_min<=s_min_stop:
    422                 break
    423         return int(s_min)
    424 
    425     def call_inverse(self, inputs):
    426         """Implements deinterleaver function corresponding to call().
    427 
    428         Input
    429         -----
    430             (x, seed):
    431                 Either Tuple ``(x, seed)`` or ``x`` only (no tuple) if the internal
    432                 seed should be used:
    433 
    434             x: tf.DType
    435                 2+D tensor of arbitrary shape and dtype.
    436             seed: int
    437                 An integer defining the state of the random number
    438                 generator. If explicitly given, the global internal seed is
    439                 replaced by this seed. Can be used to realize random
    440                 interleaver/deinterleaver pairs (call with same random seed).
    441 
    442         Output
    443         ------
    444             : tf.DType
    445                 2+D tensor of same shape and dtype as the input ``x``.
    446 
    447         Raises
    448         ------
    449             InvalidArgumentError
    450                 When rank(``x``)<2.
    451 
    452             ValueError
    453                 If ``keep_state`` is False and no explicit seed is provided.
    454 
    455         Note
    456         ----
    457             In case of inverse interleaving (e.g., at the receiver),
    458             ``keep_state`` should be True as otherwise a new permutation is
    459             generated and the output is not equal to the original sequence.
    460             Alternatively, an explicit seed must be provided as function
    461             argument.
    462         """
    463 
    464         if isinstance(inputs, (tuple, list)):
    465             if len(inputs)==1: # if user wants to call with call([x])
    466                 seed = None
    467                 x = inputs
    468             elif len(inputs)==2:
    469                 x, seed = inputs
    470             else:
    471                 raise TypeError("inputs cannot have more than 2 entries.")
    472         else:
    473             seed = None
    474             x = inputs
    475 
    476         input_shape = x.shape
    477         tf.debugging.assert_greater(tf.rank(x), 1)
    478 
    479         # use seed if explicit seed is provided
    480         if seed is not None:
    481             seed = (tf.constant(1337), tf.cast(seed, tf.int32))
    482         elif self._keep_state:
    483             # use sequence as defined by seed
    484             seed = self._seed
    485         else:
    486             # This mode is not supported for
    487             raise ValueError("Deinterleaving not possible for random " \
    488                 "seeds per call (keep_state=False) without explicitly " \
    489                 "providing the seed as inputs.")
    490         # select if each sample in batch needs own perm (computational complex!)
    491         if self._keep_batch_constant:
    492             batch_size = 1
    493         else:
    494             batch_size = tf.shape(x)[0]
    495 
    496         perm_seq = self._generate_perm_full(seed,
    497                                             tf.shape(x)[self._axis],
    498                                             batch_size,
    499                                             inverse=True) # activate inverse
    500 
    501         if self._keep_batch_constant:
    502             # broadcast single sequence over complete batch
    503             perm_seq = tf.squeeze(perm_seq, axis=0) # remove batch_dim
    504             x = tf.gather(x, perm_seq, batch_dims=0, axis=self._axis)
    505         else:
    506             x = tf.gather(x, perm_seq, batch_dims=1, axis=self._axis)
    507 
    508         # set explicitly for keras models
    509         x = tf.ensure_shape(x, input_shape)
    510         return x
    511 
    512     #########################
    513     # Utility methods
    514     #########################
    515 
    516     def _generate_perm_full(self, seed, seq_length, batch_size, inverse=False):
    517         """Generates a random permutation for the interleaver.
    518 
    519         Args:
    520             seed (int): A shape [2] Tensor, the seed to the random number
    521                 generator.
    522 
    523             seq_length (int): The length of the sequence to be permuted.
    524 
    525             batch_size (int): The batch size (=number of independent
    526                 permutations).
    527 
    528             inverse (bool): Defaults to False. If True, the inverse permutation
    529                 for the given seed is generated.
    530         """
    531         rand_seq = tf.random.stateless_uniform([batch_size, seq_length],
    532                                                 seed,
    533                                                 minval=0,
    534                                                 maxval=1,
    535                                                 dtype=tf.float32)
    536 
    537         perm_seq =  tf.argsort(rand_seq, axis=-1)
    538 
    539         if inverse:
    540             # cast to tf.float32 due to improved performance
    541             perm_seq = tf.cast(perm_seq, tf.float32)
    542             perm_seq = tf.argsort(perm_seq, axis=-1)
    543 
    544         return perm_seq
    545 
    546     #########################
    547     # Keras layer functions
    548     #########################
    549 
    550     def build(self, input_shape):
    551         """Build Keras layer and check consistency of dimensions."""
    552         if isinstance(input_shape, list):
    553             input_shape=input_shape[0]
    554 
    555         assert self._axis < len(input_shape), "Axis does not match input shape."
    556         assert len(input_shape) > 1, "At least two dims are required."
    557 
    558     def call(self, inputs):
    559         """Interleaving function.
    560 
    561         This function returns the permuted version of ``inputs``.
    562 
    563         Args:
    564             inputs (List): ``[x, seed]``, where
    565             ``x`` (tf.float32): Tensor of arbitrary shape. Must have at
    566                 least rank two.
    567             ``seed`` (int): An integer defining the state of the random number
    568                 generator. If explicitly given, the global internal seed is
    569                 replaced by this seed. Can be used the realize random
    570                 interleaver/deinterleaver pairs (call with same random seed).
    571 
    572 
    573         Returns:
    574             `tf.float32`: Tensor of same shape as the input.
    575 
    576         Raises:
    577             InvalidArgumentError
    578                 When rank(``x``)<2.
    579 
    580             AssertionError
    581                 If ``seed`` is not None or int.
    582 
    583         Note:
    584             In case of inverse interleaving (e.g., at the receiver),
    585             ``keep_state`` should be True as otherwise a new permutation is
    586             generated and the output is not equal to the original sequence.
    587             Alternatively, an explicit seed must be provided as function
    588             argument.
    589         """
    590 
    591         if isinstance(inputs, (tuple, list)):
    592             if len(inputs)==1: # if user wants to call with call([x])
    593                 seed = None
    594                 x = inputs
    595             elif len(inputs)==2:
    596                 x, seed = inputs
    597             else:
    598                 raise TypeError("inputs cannot have more than 2 entries.")
    599         else:
    600             seed = None
    601             x = inputs
    602 
    603         input_shape = x.shape
    604         tf.debugging.assert_greater(tf.rank(x), 1)
    605 
    606         # use seed if explicit seed is provided
    607         if seed is not None:
    608             seed = (tf.constant(1337), tf.cast(seed, tf.int32))
    609         # only generate a new random sequence if keep_state==False
    610         elif self._keep_state:
    611             # use sequence as defined by seed
    612             seed = self._seed
    613         else:
    614             # generate new seed for each call
    615             # Note: not necessarily random if XLA is active
    616             seed = config.tf_rng.uniform([2],
    617                                          minval=0,
    618                                          maxval=2**31-1,
    619                                          dtype=tf.int32)
    620         # select if each sample in batch needs own perm (computational complex!)
    621         if self._keep_batch_constant:
    622             batch_size = 1
    623         else:
    624             batch_size = tf.shape(x)[0]
    625 
    626         perm_seq = self._generate_perm_full(seed,
    627                                             tf.shape(x)[self._axis],
    628                                             batch_size,
    629                                             self._inverse)
    630 
    631         if self._keep_batch_constant:
    632             # broadcast single sequence over complete batch
    633             perm_seq = tf.squeeze(perm_seq, axis=0) # remove batch_dim
    634             x = tf.gather(x, perm_seq, batch_dims=0, axis=self._axis)
    635         else:
    636             x = tf.gather(x, perm_seq, batch_dims=1, axis=self._axis)
    637 
    638         # set explicitly for keras models
    639         x = tf.ensure_shape(x, input_shape)
    640         return x
    641 
    642 
    643 class Deinterleaver(Layer):
    644     """Deinterleaver(interleaver, dtype=None, **kwargs)
    645 
    646     Deinterleaver that reverts the interleaver for a given input sequence.
    647 
    648     The class inherits from the Keras layer class and can be used as layer in a
    649     Keras model.
    650 
    651     Parameters
    652     ----------
    653         interleaver: Interleaver
    654             Associated Interleaver which shall be deinterleaved by this layer.
    655             Can be either
    656             :class:`~sionna.fec.interleaving.RandomInterleaver` or
    657             :class:`~sionna.fec.interleaving.RowColumnInterleaver`.
    658 
    659         dtype: None or tf.DType
    660             Defaults to `None`. Defines the datatype for internal calculations
    661             and the output dtype. If no explicit dtype is provided the dtype
    662             from the associated interleaver is used.
    663 
    664     Input
    665     -----
    666         (x, seed):
    667             Either Tuple ``(x, seed)`` or ``x`` only (no tuple) if the internal
    668             seed should be used:
    669 
    670         x: tf.DType
    671             2+D tensor of arbitrary shape.
    672         seed: int
    673             An integer defining the state of the random number
    674             generator. If explicitly given, the global internal seed is
    675             replaced by this seed. Can be used to realize random
    676             interleaver/deinterleaver pairs (call with same random seed).
    677 
    678     Output
    679     ------
    680         : tf.DType
    681             2+D tensor of same shape and dtype as the input ``x``.
    682 
    683     Raises
    684     ------
    685         AssertionError
    686             If ``interleaver`` is not a valid instance of Interleaver.
    687 
    688     Note
    689     ----
    690         This layer provides a wrapper of the inverse interleaver function.
    691     """
    692 
    693     def __init__(self,
    694                  interleaver,
    695                  dtype=None,
    696                  **kwargs):
    697 
    698         if not isinstance(interleaver,
    699                           (RandomInterleaver,
    700                           RowColumnInterleaver,
    701                           Turbo3GPPInterleaver)):
    702             raise ValueError("interleaver is not a valid interleaver instance.")
    703         self._interleaver = interleaver
    704 
    705         # if dtype is None, use same dtype as associated interleaver
    706         if dtype is None:
    707             dtype = self._interleaver.dtype
    708 
    709         super().__init__(dtype=dtype, **kwargs)
    710 
    711         if self._interleaver._keep_state is False:
    712             print("Warning: deinterleaver requires interleaver to have " \
    713             "keep_state=True or to explicitly provide the seed as inputs.")
    714 
    715     #########################################
    716     # Public methods and properties
    717     #########################################
    718 
    719     @property
    720     def interleaver(self):
    721         """Associated interleaver instance."""
    722         return self._interleaver
    723 
    724     #########################
    725     # Utility methods
    726     #########################
    727 
    728     #########################
    729     # Keras layer functions
    730     #########################
    731 
    732     def build(self, input_shape):
    733         """build layer"""
    734         pass
    735 
    736     def call(self, inputs):
    737         """deinterleaving function.
    738 
    739         This function returns the permuted version of inputs.
    740 
    741         Args:
    742             inputs (tf.float32): Tensor of arbitrary shape. Must have at least
    743                 rank two.
    744 
    745         Returns:
    746             `tf.float32`: Tensor of same shape as the input.
    747         """
    748 
    749         x = self._interleaver.call_inverse(inputs)
    750 
    751         x = tf.cast(x, super().dtype) # cast output to correct dtype
    752         return x
    753 
    754 
    755 class Turbo3GPPInterleaver(Layer):
    756     # pylint: disable=line-too-long
    757     """Turbo3GPPInterleaver(inverse=False, axis=-1, dtype=tf.float32, **kwargs)
    758 
    759     Interleaver as used in the 3GPP Turbo codes [3GPPTS36212_I]_ and, thus,
    760     the maximum length is given as 6144 elements (only for the dimension as
    761     specific by ``axis``).
    762 
    763     The class inherits from the Keras layer class and can be used as layer in a
    764     Keras model.
    765 
    766     Parameters
    767     ----------
    768         inverse: bool
    769             A boolean defaults to False. If True, the inverse permutation is
    770             performed.
    771 
    772         axis: int
    773             Defaults to `-1`. The dimension that should be interleaved.
    774             First dimension (`axis=0`) is not allowed.
    775 
    776         dtype: tf.DType
    777             Defaults to `tf.float32`. Defines the datatype for internal
    778             calculations and the output dtype.
    779 
    780     Input
    781     -----
    782         x: tf.DType
    783             2+D tensor of arbitrary shape and dtype.
    784 
    785     Output
    786     ------
    787         : tf.DType
    788             2+D tensor of same shape and dtype as the input ``x``.
    789 
    790     Raises
    791     ------
    792         AssertionError
    793             If ``axis`` is not `int`.
    794 
    795         AssertionError
    796             If ``axis`` > number of input dimensions.
    797 
    798         AssertionError
    799             If ``inverse`` is not bool.
    800 
    801         InvalidArgumentError
    802             When rank(``x``)<2.
    803 
    804     Note
    805     ----
    806         Note that this implementation slightly deviates from the 3GPP
    807         standard [3GPPTS36212_I]_ in a sense that zero-padding is introduced
    808         for cases when the exact interleaver length is not supported by the
    809         standard.
    810     """
    811 
    def __init__(self,
                 inverse=False,
                 axis=-1,
                 dtype=tf.float32,
                 **kwargs):
        """Initialize the layer and load the 3GPP interleaver coefficients.

        See the class docstring for the parameter documentation.
        """
        super().__init__(dtype=dtype, **kwargs)

        assert isinstance(axis, int), "axis must be int."
        assert axis!=0, "Cannot permute batch dimension."
        self._axis = axis
        self._keep_state = True # only required for deinterleaver
        self.frame_size = None

        assert isinstance(inverse, bool), "inverse must be boolean"
        self._inverse = inverse

        # load interleaver patterns as defined in the 3GPP standard
        self.coeffs_dict = {}
        source = files(coeffs).joinpath("turbo_coeffs.csv")
        # Bind the extracted file path to a LOCAL name. The original code
        # used ``with as_file(source) as coeffs.csv:`` which silently set a
        # ``csv`` attribute on the imported ``coeffs`` module — a global side
        # effect on a shared module object.
        with as_file(source) as csv_path:
            csv_reader = np.genfromtxt(csv_path, delimiter=",")
            for line_count, row in enumerate(csv_reader):
                if line_count > 0: # ignore first line (=header)
                    # row columns used: row[1]=block length k,
                    # row[2]=f1, row[3]=f2 (cf. _generate_perm_full)
                    self.coeffs_dict[int(row[1])] = (int(row[2]), int(row[3]))
    838     #########################################
    839     # Public methods and properties
    840     #########################################
    841 
    @property
    def axis(self):
        """Dimension along which the (de-)interleaving is applied."""
        return self._axis
    846 
    def find_s_min(self, frame_size, s_min_stop=0):
        r"""Find :math:`S` parameter such that :math:`\pi(i)-\pi(j)>S` for all
        :math:`i-j<S`. This can be used to find optimized interleaver patterns.

        ``s_min_stop`` is an additional stopping condition, i.e., stop if
        current :math:`S` is already smaller than ``s_min_stop``.

        Please note that this is a Numpy utility function and usually not part
        of the graph.

        Input
        -----
        frame_size: int
            length of interleaver.

        s_min_stop: int
            Defaults to 0. Enables early stop if already current
            s_min<``s_min_stop``.

        Output
        ------
        : int
            The S-parameter for the given ``frame_size``.
        """

        assert isinstance(s_min_stop, int), "s_min_stop must be int."
        assert isinstance(frame_size, int), "frame_size must be int."
        assert(frame_size<6145), "Interleaver not defined for this frame_size."

        # materialize the permutation as a NumPy array for fast scalar access
        perm_seq = self._generate_perm_full(frame_size)
        perm_seq = perm_seq.numpy()
        s_min = frame_size # upper bound; shrinks whenever a closer pair is found

        for i in range(len(perm_seq)): # search for all positions in perm_seq
            # only input distances below the current bound can improve s_min,
            # hence the shrinking window [-s_min, s_min)
            for j in range(-s_min,s_min,1): # search dist
                if j==0: # ignore identity
                    continue
                if i+j>=0 and i+j<frame_size: # stay within the frame
                    # distance of the two positions after permutation
                    d = np.abs(perm_seq[i] - perm_seq[i+j])
                    if d<=np.abs(j):
                        s_min = np.min([s_min, np.abs(j)])
                    if d<s_min and np.abs(j)<s_min:
                        s_min = np.min([s_min, d])
            # early stop
            if s_min<=s_min_stop:
                break
        return int(s_min)
    894 
    def call_inverse(self, inputs):
        """Implements deinterleaver function corresponding to call().

        Input
        -----
         x: tf.DType
            2+D tensor of arbitrary shape and dtype.

        Output
        ------
        : tf.DType
            2+D tensor of same shape and dtype as the input ``x``.

        Raises
        ------
        InvalidArgumentError
            When rank(``x``)<2.
        """

        if isinstance(inputs, (tuple, list)):
            if len(inputs)==1: # if user wants to call with call([x])
                # unwrap the single tensor; keeping the list (as the original
                # ``x = inputs`` did) breaks the ``x.shape`` access below
                x = inputs[0]
            else:
                raise TypeError("inputs cannot have more than 1 entry.")
        else:
            x = inputs

        input_shape = x.shape
        frame_size = input_shape[self._axis]

        # activate inverse
        perm_seq = self._generate_perm_full(frame_size, inverse=True)
        x = tf.gather(x, perm_seq, batch_dims=0, axis=self._axis)

        # set explicitly for keras models
        x = tf.ensure_shape(x, input_shape)
        return x
    932 
    933     #########################
    934     # Utility methods
    935     #########################
    936 
    def _generate_perm_full(self, frame_size, inverse=False):
        """Generates the 3GPP Turbo-code permutation for ``frame_size``.

        The pattern is the quadratic permutation pi(i) = (f1*i + f2*i^2) mod k
        with coefficients (f1, f2) taken from ``self.coeffs_dict``. If
        ``frame_size`` is not directly supported, the next larger supported
        block length k is used and indices >= ``frame_size`` are pruned
        (cf. the zero-padding note in the class docstring).

        Args:
            frame_size (int): The length of the sequence to be permuted.

            inverse (bool): Defaults to False. If True, the inverse
                permutation is generated.

        Returns:
            1-D `tf.Tensor` of permutation indices.

        Raises:
            ValueError: If ``frame_size`` exceeds the largest supported
                block length.
        """
        k = frame_size
        if k not in self.coeffs_dict:
            # fall back to the smallest supported block length >= k
            geqk_sizes = sorted(x for x in self.coeffs_dict if x >= k)
            if not geqk_sizes:
                # previously this only printed a message and then crashed
                # with an opaque KeyError below; raise a clear error instead
                raise ValueError(
                    "Input frame size too large for 3GPP Turbo Interleaver.")
            k = geqk_sizes[0]
        f1, f2 = self.coeffs_dict[k]
        perm_seq = [(f1 * i + f2* (i**2))%k for i in range(k)]

        if frame_size < k:
            # prune indices that fall outside the requested frame size
            perm_seq = [x for x in perm_seq if x < frame_size]

        perm_seq = tf.convert_to_tensor(perm_seq)
        if inverse:
            # cast to tf.float32 due to improved sorting performance
            perm_seq = tf.cast(perm_seq, tf.float32)
            perm_seq = tf.argsort(perm_seq, axis=-1)

        return perm_seq
    968 
    969     #########################
    970     # Keras layer functions
    971     #########################
    972 
    def build(self, input_shape):
        """Validate the input dimensions when the Keras layer is built."""
        # Keras may hand over a list of shapes; validate the first entry
        if isinstance(input_shape, list):
            input_shape = input_shape[0]

        num_dims = len(input_shape)
        assert self.axis < num_dims, "Axis does not match input shape."
        assert num_dims > 1, "At least two dims are required."

        # the 3GPP coefficient table covers block lengths up to 6144
        frame_size = input_shape[self._axis]
        assert frame_size < 6145, \
            "3GPP Turbo Interleaver is defined for block lengths up to 6144."
    984 
    def call(self, inputs):
        """Interleaving function.

        This function returns the permuted version of ``inputs``.

        Args:
            inputs: 2+D tensor of arbitrary shape and dtype (or a
                single-element list/tuple containing it).

        Returns:
            Tensor of same shape and dtype as ``inputs``.
        """

        if isinstance(inputs, (tuple, list)):
            if len(inputs)==1: # if user wants to call with call([x])
                # unwrap the single tensor; keeping the list (as the original
                # ``x = inputs`` did) breaks the ``x.shape`` access below
                x = inputs[0]
            else:
                raise TypeError("inputs cannot have more than 1 entry.")
        else:
            x = inputs

        input_shape = x.shape
        frame_size = input_shape[self._axis]

        perm_seq = self._generate_perm_full(frame_size, self._inverse)
        x = tf.gather(x, perm_seq, batch_dims=0, axis=self._axis)

        # set explicitly for keras models
        x = tf.ensure_shape(x, input_shape)
        return x