downsampling.py (2145B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layers implementing downsampling"""

from tensorflow.keras.layers import Layer
from tensorflow.experimental.numpy import swapaxes

class Downsampling(Layer):
    # pylint: disable=line-too-long
    """Downsampling(samples_per_symbol, offset=0, num_symbols=None, axis=-1, **kwargs)

    Downsamples a tensor along a specified axis by retaining one out of
    ``samples_per_symbol`` elements.

    Parameters
    ----------
    samples_per_symbol: int
        The downsampling factor. If ``samples_per_symbol`` is equal to `n`, then the
        downsampled axis will be (up to rounding) `n`-times shorter.

    offset: int
        Defines the index of the first element to be retained.
        Defaults to zero.

    num_symbols: int
        Defines the total number of symbols to be retained after
        downsampling.
        Defaults to None (i.e., the maximum possible number).

    axis: int
        The dimension to be downsampled. Must not be the first dimension.
        Defaults to -1.

    Input
    -----
    x : [...,n,...], tf.DType
        The tensor to be downsampled. `n` is the size of the `axis` dimension.

    Output
    ------
    y : [...,k,...], same dtype as ``x``
        The downsampled tensor, where (for ``0 <= offset < n``) ``k`` is
        min(ceil((``n``-``offset``)/``samples_per_symbol``), ``num_symbols``),
        i.e., the elements at indices ``offset``, ``offset+samples_per_symbol``,
        ``offset+2*samples_per_symbol``, ... are retained.
    """
    def __init__(self,
                 samples_per_symbol,
                 offset=0,
                 num_symbols=None,
                 axis=-1, **kwargs):
        super().__init__(**kwargs)
        self._samples_per_symbol = samples_per_symbol
        self._offset = offset
        self._num_symbols = num_symbols
        self._axis = axis

    def get_config(self):
        """Return the layer configuration so it can be re-created by Keras.

        Includes the base ``Layer`` configuration (name, dtype, ...) plus the
        constructor arguments of this layer, so that serialization, model
        saving and cloning can rebuild an equivalent layer.
        """
        config = super().get_config()
        config.update({
            "samples_per_symbol": self._samples_per_symbol,
            "offset": self._offset,
            "num_symbols": self._num_symbols,
            "axis": self._axis,
        })
        return config

    def call(self, inputs):
        # Put selected axis last so that it can be sliced with ``...``
        x = swapaxes(inputs, self._axis, -1)

        # Downsample: keep every ``samples_per_symbol``-th element starting at
        # ``offset``. This keeps ceil((n-offset)/samples_per_symbol) elements.
        x = x[..., self._offset::self._samples_per_symbol]

        # Optionally truncate; slicing past the end simply keeps everything,
        # which yields the min(...) length documented above.
        if self._num_symbols is not None:
            x = x[..., :self._num_symbols]

        # Put last axis back to its original position (swapaxes is symmetric)
        x = swapaxes(x, -1, self._axis)

        return x