anomaly-detection-material-parameters-calibration

Sionna parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

downsampling.py (2145B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Layers implementing downsampling"""
      6 
      7 from tensorflow.keras.layers import Layer
      8 from tensorflow.experimental.numpy import swapaxes
      9 
     10 class Downsampling(Layer):
     11     # pylint: disable=line-too-long
     12     """Downsampling(samples_per_symbol, offset=0, num_symbols=None, axis=-1, **kwargs)
     13 
     14     Downsamples a tensor along a specified axis by retaining one out of
     15     ``samples_per_symbol`` elements.
     16 
     17     Parameters
     18     ----------
     19     samples_per_symbol: int
     20         The downsampling factor. If ``samples_per_symbol`` is equal to `n`, then the
     21         downsampled axis will be `n`-times shorter.
     22 
     23     offset: int
     24         Defines the index of the first element to be retained.
     25         Defaults to zero.
     26 
     27     num_symbols: int
     28         Defines the total number of symbols to be retained after
     29         downsampling.
     30         Defaults to None (i.e., the maximum possible number).
     31 
     32     axis: int
     33         The dimension to be downsampled. Must not be the first dimension.
     34 
     35     Input
     36     -----
     37     x : [...,n,...], tf.DType
     38         The tensor to be downsampled. `n` is the size of the `axis` dimension.
     39 
     40     Output
     41     ------
     42     y : [...,k,...], same dtype as ``x``
     43         The downsampled tensor, where ``k``
     44         is min((``n``-``offset``)//``samples_per_symbol``, ``num_symbols``).
     45     """
     46     def __init__(self,
     47                  samples_per_symbol,
     48                  offset=0,
     49                  num_symbols=None,
     50                  axis=-1, **kwargs):
     51         super().__init__(**kwargs)
     52         self._samples_per_symbol = samples_per_symbol
     53         self._offset = offset
     54         self._num_symbols = num_symbols
     55         self._axis = axis
     56 
     57     def call(self, inputs):
     58         # Put selected axis last
     59         x = swapaxes(inputs, self._axis, -1)
     60 
     61         # Downsample
     62         x = x[...,self._offset::self._samples_per_symbol]
     63 
     64         if self._num_symbols is not None:
     65             x = x[...,:self._num_symbols]
     66 
     67         # Put last axis to original position
     68         x = swapaxes(x, -1, self._axis)
     69 
     70         return x