window.py (15451B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layers implementing windowing functions"""

import tensorflow as tf
from tensorflow.keras.layers import Layer
from abc import ABC, abstractmethod
from sionna.utils.tensors import expand_to_rank
import matplotlib.pyplot as plt
import numpy as np

class Window(ABC, Layer):
    # pylint: disable=line-too-long
    r"""Window(length, trainable=False, normalize=False, dtype=tf.float32, **kwargs)

    This is an abstract class for defining and applying a window function of length ``length`` to an input ``x`` of the same length.

    The window function is applied through element-wise multiplication.

    The window function is real-valued, i.e., has `tf.float` as `dtype`.
    The `dtype` of the output is the same as the `dtype` of the input ``x`` to which the window function is applied.
    The window function and the input must have the same precision.

    Parameters
    ----------
    length: int
        Window length (number of samples).

    trainable: bool
        If `True`, the window coefficients are trainable variables.
        Defaults to `False`.

    normalize: bool
        If `True`, the window is normalized to have unit average power
        per coefficient.
        Defaults to `False`.

    dtype: tf.DType
        The `dtype` of the filter coefficients.
        Must be either `tf.float32` or `tf.float64`.
        Defaults to `tf.float32`.

    Input
    -----
    x : [..., N], tf.complex or tf.float
        The input to which the window function is applied.
        The window function is applied along the last dimension.
        The length of the last dimension ``N`` must be the same as the ``length`` of the window function.

    Output
    ------
    y : [...,N], tf.complex or tf.float
        Output of the windowing operation.
        The output has the same shape and `dtype` as the input ``x``.
    """

    def __init__(self,
                 length,
                 trainable=False,
                 normalize=False,
                 dtype=tf.float32,
                 **kwargs):
        super().__init__(dtype=dtype, **kwargs)

        assert length>0, "Length must be positive"
        self._length = length

        assert isinstance(trainable, bool), "trainable must be bool"
        self._trainable = trainable

        assert isinstance(normalize, bool), "normalize must be bool"
        self._normalize = normalize

        # Check the dtype as resolved by Keras (`self.dtype`) rather than
        # calling `.is_floating` on the raw argument, so that dtypes given
        # as strings (e.g. "float32") are accepted just like by Keras.
        assert tf.as_dtype(self.dtype).is_floating,\
            "`dtype` must be either `tf.float32` or `tf.float64`"

        # `_coefficients_source` is evaluated only now, after `self._length`
        # and the Keras dtype have been set, because the concrete windows
        # read both of them.
        self._coefficients = tf.Variable(self._coefficients_source,
                                         trainable=self.trainable,
                                         dtype=tf.as_dtype(self.dtype))

    @property
    @abstractmethod
    def _coefficients_source(self):
        """Internal property that returns the (unnormalized) window
        coefficients. Concrete classes that inherit from this one must
        implement this property."""
        pass

    @property
    def coefficients(self):
        """The window coefficients (after normalization)"""
        w = self._coefficients

        # Normalize if requested: scale so that the mean of the squared
        # coefficients equals one.
        if self.normalize:
            energy = tf.reduce_mean(tf.square(w))
            w = w / tf.sqrt(energy)

        return w

    @property
    def length(self):
        "Window length in number of samples"
        return self._length

    @property
    def trainable(self):
        "`True` if the window coefficients are trainable. `False` otherwise."
        return self._trainable

    @property
    def normalize(self):
        """`True` if the window is normalized to have unit average power per coefficient. `False`
        otherwise."""
        return self._normalize

    def show(self, samples_per_symbol, domain="time", scale="lin"):
        r"""Plot the window in time or frequency domain

        For the computation of the Fourier transform, a minimum DFT size
        of 1024 is assumed which is obtained through zero padding of
        the window coefficients in the time domain.

        Input
        -----
        samples_per_symbol: int
            Number of samples per symbol, i.e., the oversampling factor.

        domain: str, one of ["time", "frequency"]
            The desired domain.
            Defaults to "time"

        scale: str, one of ["lin", "db"]
            The y-scale of the magnitude in the frequency domain.
            Can be "lin" (i.e., linear) or "db" (i.e., decibel, computed
            as :math:`20\log_{10}|W(f)|`).
            Defaults to "lin".
        """
        assert domain in ["time", "frequency"], "Invalid domain"
        # Sampling times, centered around zero and normalized by the
        # symbol duration
        n_min = -(self.length//2)
        n_max = n_min + self.length
        sampling_times = np.arange(n_min, n_max, dtype=np.float32)
        sampling_times /= samples_per_symbol

        if domain=="time":
            plt.figure(figsize=(12,6))
            plt.plot(sampling_times, np.real(self.coefficients.numpy()))
            plt.title("Time domain")
            plt.grid()
            plt.xlabel(r"Normalized time $(t/T)$")
            plt.ylabel(r"$w(t)$")
            plt.xlim(sampling_times[0], sampling_times[-1])
        else:
            assert scale in ["lin", "db"], "Invalid scale"
            # Zero-pad to at least 1024 points for a smooth spectrum
            fft_size = max(1024, self.coefficients.shape[-1])
            h = np.fft.fft(self.coefficients.numpy(), fft_size)
            h = np.fft.fftshift(h)
            h = np.abs(h)
            plt.figure(figsize=(12,6))
            if scale=="db":
                # Floor avoids log10(0); |W(f)| is an amplitude, hence the
                # factor 20 (10 would only apply to the power |W(f)|^2).
                h = np.maximum(h, 1e-10)
                h = 20*np.log10(h)
                plt.ylabel(r"$|W(f)|$ (dB)")
            else:
                plt.ylabel(r"$|W(f)|$")
            # Exact frequencies of the fftshift-ed DFT bins, normalized by the
            # symbol rate: k*samples_per_symbol/fft_size, ascending from
            # -samples_per_symbol/2. np.linspace over [-sps/2, sps/2] would
            # include both endpoints and misplace every bin.
            f = np.fft.fftshift(np.fft.fftfreq(fft_size))*samples_per_symbol
            plt.plot(f, h)
            plt.title("Frequency domain")
            plt.grid()
            plt.xlabel(r"Normalized frequency $(f/W)$")
            plt.xlim(f[0], f[-1])

    def call(self, x):
        x_dtype = tf.as_dtype(x.dtype)

        # Expand to the same rank as the input for broadcasting
        w = self.coefficients
        w = expand_to_rank(w, tf.rank(x), 0)

        if x_dtype.is_floating:
            y = x*w
        else: # Complex
            # Promote the real window to complex (zero imaginary part) so
            # that the multiplication is well-typed.
            w = tf.complex(w, tf.zeros_like(w))
            y = w*x

        return y


class CustomWindow(Window):
    # pylint: disable=line-too-long
    r"""CustomWindow(length, coefficients=None, trainable=False, normalize=False, dtype=tf.float32, **kwargs)

    Layer for defining and applying a custom window function of length ``length`` to an input ``x`` of the same length.

    The window function is applied through element-wise multiplication.

    The window function is real-valued, i.e., has `tf.float` as `dtype`.
    The `dtype` of the output is the same as the `dtype` of the input ``x`` to which the window function is applied.
    The window function and the input must have the same precision.

    The window coefficients can be set through the ``coefficients`` parameter.
    If not provided, random window coefficients are generated by sampling a Gaussian distribution.

    Parameters
    ----------
    length: int
        Window length (number of samples).

    coefficients: [N], tf.float
        Optional window coefficients.
        If set to `None`, then a random window of length ``length`` is generated by sampling a Gaussian distribution.
        Defaults to `None`.

    trainable: bool
        If `True`, the window coefficients are trainable variables.
        Defaults to `False`.

    normalize: bool
        If `True`, the window is normalized to have unit average power
        per coefficient.
        Defaults to `False`.

    dtype: tf.DType
        The `dtype` of the filter coefficients.
        Must be either `tf.float32` or `tf.float64`.
        Defaults to `tf.float32`.

    Input
    -----
    x : [..., N], tf.complex or tf.float
        The input to which the window function is applied.
        The window function is applied along the last dimension.
        The length of the last dimension ``N`` must be the same as the ``length`` of the window function.

    Output
    ------
    y : [...,N], tf.complex or tf.float
        Output of the windowing operation.
        The output has the same shape and `dtype` as the input ``x``.
    """

    def __init__(self,
                 length,
                 coefficients=None,
                 trainable=False,
                 normalize=False,
                 dtype=tf.float32,
                 **kwargs):

        # `self._c` must exist before calling the parent constructor, which
        # reads it through `_coefficients_source`.
        if coefficients is not None:
            assert len(coefficients) == length,\
                "specified `length` does not match the one of `coefficients`"
            self._c = tf.constant(coefficients, dtype=dtype)
        else:
            self._c = tf.keras.initializers.RandomNormal()([length], dtype)

        super().__init__(length,
                         trainable,
                         normalize,
                         dtype,
                         **kwargs)

    @property
    def _coefficients_source(self):
        return self._c


class HannWindow(Window):
    # pylint: disable=line-too-long
    r"""HannWindow(length, trainable=False, normalize=False, dtype=tf.float32, **kwargs)

    Layer for applying a Hann window function of length ``length`` to an input ``x`` of the same length.

    The window function is applied through element-wise multiplication.

    The window function is real-valued, i.e., has `tf.float` as `dtype`.
    The `dtype` of the output is the same as the `dtype` of the input ``x`` to which the window function is applied.
    The window function and the input must have the same precision.

    The Hann window is defined by

    .. math::
        w_n = \sin^2 \left( \frac{\pi n}{N} \right), 0 \leq n \leq N-1

    where :math:`N` is the window length.

    Parameters
    ----------
    length: int
        Window length (number of samples).

    trainable: bool
        If `True`, the window coefficients are trainable variables.
        Defaults to `False`.

    normalize: bool
        If `True`, the window is normalized to have unit average power
        per coefficient.
        Defaults to `False`.

    dtype: tf.DType
        The `dtype` of the filter coefficients.
        Must be either `tf.float32` or `tf.float64`.
        Defaults to `tf.float32`.

    Input
    -----
    x : [..., N], tf.complex or tf.float
        The input to which the window function is applied.
        The window function is applied along the last dimension.
        The length of the last dimension ``N`` must be the same as the ``length`` of the window function.

    Output
    ------
    y : [...,N], tf.complex or tf.float
        Output of the windowing operation.
        The output has the same shape and `dtype` as the input ``x``.
    """

    @property
    def _coefficients_source(self):
        n = np.arange(self.length)
        coefficients = np.square(np.sin(np.pi*n/self.length))
        return tf.constant(coefficients, self.dtype)


class HammingWindow(Window):
    # pylint: disable=line-too-long
    r"""HammingWindow(length, trainable=False, normalize=False, dtype=tf.float32, **kwargs)

    Layer for applying a Hamming window function of length ``length`` to an input ``x`` of the same length.

    The window function is applied through element-wise multiplication.

    The window function is real-valued, i.e., has `tf.float` as `dtype`.
    The `dtype` of the output is the same as the `dtype` of the input ``x`` to which the window function is applied.
    The window function and the input must have the same precision.

    The Hamming window is defined by

    .. math::
        w_n = a_0 - (1-a_0) \cos \left( \frac{2 \pi n}{N} \right), 0 \leq n \leq N-1

    where :math:`N` is the window length and :math:`a_0 = \frac{25}{46}`.

    Parameters
    ----------
    length: int
        Window length (number of samples).

    trainable: bool
        If `True`, the window coefficients are trainable variables.
        Defaults to `False`.

    normalize: bool
        If `True`, the window is normalized to have unit average power
        per coefficient.
        Defaults to `False`.

    dtype: tf.DType
        The `dtype` of the filter coefficients.
        Must be either `tf.float32` or `tf.float64`.
        Defaults to `tf.float32`.

    Input
    -----
    x : [..., N], tf.complex or tf.float
        The input to which the window function is applied.
        The window function is applied along the last dimension.
        The length of the last dimension ``N`` must be the same as the ``length`` of the window function.

    Output
    ------
    y : [...,N], tf.complex or tf.float
        Output of the windowing operation.
        The output has the same shape and `dtype` as the input ``x``.
    """

    @property
    def _coefficients_source(self):
        n = self.length
        nn = np.arange(n)
        a0 = 25./46.
        a1 = 1. - a0
        coefficients = a0 - a1*np.cos(2.*np.pi*nn/n)
        return tf.constant(coefficients, self.dtype)


class BlackmanWindow(Window):
    # pylint: disable=line-too-long
    r"""BlackmanWindow(length, trainable=False, normalize=False, dtype=tf.float32, **kwargs)

    Layer for applying a Blackman window function of length ``length`` to an input ``x`` of the same length.

    The window function is applied through element-wise multiplication.

    The window function is real-valued, i.e., has `tf.float` as `dtype`.
    The `dtype` of the output is the same as the `dtype` of the input ``x`` to which the window function is applied.
    The window function and the input must have the same precision.

    The Blackman window is defined by

    .. math::
        w_n = a_0 - a_1 \cos \left( \frac{2 \pi n}{N} \right) + a_2 \cos \left( \frac{4 \pi n}{N} \right), 0 \leq n \leq N-1

    where :math:`N` is the window length, :math:`a_0 = \frac{7938}{18608}`, :math:`a_1 = \frac{9240}{18608}`, and :math:`a_2 = \frac{1430}{18608}`.

    Parameters
    ----------
    length: int
        Window length (number of samples).

    trainable: bool
        If `True`, the window coefficients are trainable variables.
        Defaults to `False`.

    normalize: bool
        If `True`, the window is normalized to have unit average power
        per coefficient.
        Defaults to `False`.

    dtype: tf.DType
        The `dtype` of the filter coefficients.
        Must be either `tf.float32` or `tf.float64`.
        Defaults to `tf.float32`.

    Input
    -----
    x : [..., N], tf.complex or tf.float
        The input to which the window function is applied.
        The window function is applied along the last dimension.
        The length of the last dimension ``N`` must be the same as the ``length`` of the window function.

    Output
    ------
    y : [...,N], tf.complex or tf.float
        Output of the windowing operation.
        The output has the same shape and `dtype` as the input ``x``.
    """

    @property
    def _coefficients_source(self):
        n = self.length
        nn = np.arange(n)
        a0 = 7938./18608.
        a1 = 9240./18608.
        a2 = 1430./18608.
        coefficients = a0 - a1*np.cos(2.*np.pi*nn/n) + a2*np.cos(4.*np.pi*nn/n)
        return tf.constant(coefficients, self.dtype)