anomaly-detection-material-parameters-calibration

Sionna parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

upsampling.py (1774B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Layers implementing upsampling"""
      6 
      7 import tensorflow as tf
      8 from tensorflow.keras.layers import Layer
      9 from tensorflow.experimental.numpy import swapaxes
     10 from sionna.utils.tensors import flatten_last_dims
     11 
     12 class Upsampling(Layer):
     13     """Upsampling(samples_per_symbol, axis=-1, **kwargs)
     14 
     15     Upsamples a tensor along a specified axis by inserting zeros
     16     between samples.
     17 
     18     Parameters
     19     ----------
     20     samples_per_symbol: int
     21         The upsampling factor. If ``samples_per_symbol`` is equal to `n`,
     22         then the upsampled axis will be `n`-times longer.
     23 
     24     axis: int
     25         The dimension to be up-sampled. Must not be the first dimension.
     26 
     27     Input
     28     -----
     29     x : [...,n,...], tf.DType
     30         The tensor to be upsampled. `n` is the size of the `axis` dimension.
     31 
     32     Output
     33     ------
     34     y : [...,n*samples_per_symbol,...], same dtype as ``x``
     35         The upsampled tensor.
     36     """
     37     def __init__(self, samples_per_symbol, axis=-1, **kwargs):
     38         super().__init__(**kwargs)
     39         self._samples_per_symbol = samples_per_symbol
     40         self._axis = axis
     41 
     42     def build(self, input_shape):
     43         paddings = []
     44         for _ in range(len(input_shape)):
     45             paddings.append([0, 0])
     46         paddings.append([0, self._samples_per_symbol-1])
     47         self._paddings = paddings
     48 
     49     def call(self, inputs):
     50         x = swapaxes(inputs, self._axis, -1)
     51         x = tf.expand_dims(x, -1)
     52         x = tf.pad(x,
     53                    self._paddings,
     54                    constant_values=tf.cast(0, dtype=x.dtype))
     55         x = flatten_last_dims(x, 2)
     56         x = swapaxes(x, -1, self._axis)
     57         return x