upsampling.py (1774B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layers implementing upsampling"""

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.experimental.numpy import swapaxes
from sionna.utils.tensors import flatten_last_dims


class Upsampling(Layer):
    """Upsampling(samples_per_symbol, axis=-1, **kwargs)

    Upsamples a tensor along a specified axis by inserting zeros
    between samples.

    Every input sample is followed by ``samples_per_symbol - 1`` zeros, so
    the sample at index ``k`` of the upsampled axis ends up at index
    ``k*samples_per_symbol`` of the output.

    Parameters
    ----------
    samples_per_symbol: int
        The upsampling factor. If ``samples_per_symbol`` is equal to `n`,
        then the upsampled axis will be `n`-times longer. Must be a
        positive integer; a value of 1 leaves the values unchanged.

    axis: int
        The dimension to be up-sampled. Must not be the first dimension.

    Input
    -----
    x : [...,n,...], tf.DType
        The tensor to be upsampled. `n` is the size of the `axis` dimension.

    Output
    ------
    y : [...,n*samples_per_symbol,...], same dtype as ``x``
        The upsampled tensor.
    """
    def __init__(self, samples_per_symbol, axis=-1, **kwargs):
        super().__init__(**kwargs)
        self._samples_per_symbol = samples_per_symbol
        self._axis = axis

    def build(self, input_shape):
        """Precompute the ``tf.pad`` paddings used in :meth:`call`.

        ``call`` appends a trailing axis of size 1 before padding, so one
        ``[before, after]`` pair is needed per input dimension (all zero)
        plus one pair for the new trailing axis, which receives
        ``samples_per_symbol - 1`` zeros after each sample.
        """
        rank = len(input_shape)
        self._paddings = [[0, 0] for _ in range(rank)]
        self._paddings.append([0, self._samples_per_symbol - 1])
        super().build(input_shape)

    def call(self, inputs):
        """Insert ``samples_per_symbol - 1`` zeros after every sample on ``axis``."""
        # Move the axis to upsample to the end so that the padding and
        # flattening below only need to touch the last dimension.
        x = swapaxes(inputs, self._axis, -1)
        # Turn every sample s into a row [s, 0, ..., 0] of length
        # samples_per_symbol: add a size-1 trailing axis, then zero-pad it.
        x = tf.expand_dims(x, -1)
        x = tf.pad(x,
                   self._paddings,
                   constant_values=tf.cast(0, dtype=x.dtype))
        # Merging the last two dimensions interleaves the zeros after
        # each original sample.
        x = flatten_last_dims(x, 2)
        # Put the upsampled axis back at its original position.
        return swapaxes(x, -1, self._axis)

    def get_config(self):
        """Return the layer configuration for Keras serialization.

        The constructor takes extra arguments (``samples_per_symbol`` and
        ``axis``). The tf.keras base implementation of ``get_config`` raises
        ``NotImplementedError`` for such layers unless it is overridden, so
        they are added to the base config here.
        """
        config = super().get_config()
        config.update({
            "samples_per_symbol": self._samples_per_symbol,
            "axis": self._axis,
        })
        return config