filter.py (31796B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layers implementing filters"""

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.initializers import RandomNormal
from abc import ABC, abstractmethod
import matplotlib.pyplot as plt
import numpy as np
from . import convolve, Window, HannWindow, HammingWindow, BlackmanWindow, empirical_aclr


class Filter(ABC, Layer):
    # pylint: disable=line-too-long
    r"""Filter(span_in_symbols, samples_per_symbol, window=None, normalize=True, trainable=False, dtype=tf.float32, **kwargs)

    This is an abstract class for defining a filter of ``length`` K which can be
    applied to an input ``x`` of length N.

    The filter length K is equal to the filter span in symbols (``span_in_symbols``)
    multiplied by the oversampling factor (``samples_per_symbol``).
    If this product is even, a value of one will be added.

    The filter is applied through discrete convolution.

    An optional windowing function ``window`` can be applied to the filter.

    The `dtype` of the output is `tf.float` if both ``x`` and the filter coefficients have dtype `tf.float`.
    Otherwise, the dtype of the output is `tf.complex`.

    Three padding modes are available for applying the filter:

    *   "full" (default): Returns the convolution at each point of overlap between ``x`` and the filter.
        The length of the output is N + K - 1. Zero-padding of the input ``x`` is performed to
        compute the convolution at the borders.
    *   "same": Returns an output of the same length as the input ``x``. The convolution is computed such
        that the coefficients of the input ``x`` are centered on the coefficient of the filter with index
        (K-1)/2. Zero-padding of the input signal is performed to compute the convolution at the borders.
    *   "valid": Returns the convolution only at points where ``x`` and the filter completely overlap.
        The length of the output is N - K + 1.

    Parameters
    ----------
    span_in_symbols: int
        Filter span as measured by the number of symbols.

    samples_per_symbol: int
        Number of samples per symbol, i.e., the oversampling factor.

    window: Window or string (["hann", "hamming", "blackman"])
        Instance of :class:`~sionna.signal.Window` that is applied to the filter coefficients.
        Alternatively, a string indicating the window name can be provided. In this case,
        the chosen window will be instantiated with the default parameters. Custom windows
        must be provided as instance.

    normalize: bool
        If `True`, the filter is normalized to have unit power.
        Defaults to `True`.

    trainable: bool
        If `True`, the filter coefficients are trainable.
        Defaults to `False`.

    dtype: tf.DType
        The `dtype` of the filter coefficients.
        Defaults to `tf.float32`.

    Input
    -----
    x : [..., N], tf.complex or tf.float
        The input to which the filter is applied.
        The filter is applied along the last dimension.

    padding : string (["full", "valid", "same"])
        Padding mode for convolving ``x`` and the filter.
        Must be one of "full", "valid", or "same". Case insensitive.
        Defaults to "full".

    conjugate : bool
        If `True`, the complex conjugate of the filter is applied.
        Defaults to `False`.

    Output
    ------
    y : [...,M], tf.complex or tf.float
        Filtered input.
        It is `tf.float` only if both ``x`` and the filter are `tf.float`.
        It is `tf.complex` otherwise.
        The length M depends on the ``padding``.
    """
    def __init__(self,
                 span_in_symbols,
                 samples_per_symbol,
                 window=None,
                 normalize=True,
                 trainable=False,
                 dtype=tf.float32,
                 **kwargs):
        super().__init__(dtype=dtype, **kwargs)

        assert span_in_symbols>0, "span_in_symbols must be positive"
        self._span_in_symbols = span_in_symbols

        assert samples_per_symbol>0, "samples_per_symbol must be positive"
        self._samples_per_symbol = samples_per_symbol

        # Uses the property setter below, which validates the value and
        # instantiates string-named windows
        self.window = window

        assert isinstance(normalize, bool), "normalize must be bool"
        self._normalize = normalize

        assert isinstance(trainable, bool), "trainable must be bool"
        self._trainable = trainable

        # Subclasses provide the coefficients through _coefficients_source;
        # their number must match the (odd) filter length
        assert self.length==self._coefficients_source.shape[-1], \
                "The number of coefficients must match the filter length."

        # Real-valued filters are stored as a single variable.
        # Complex-valued filters are stored as a pair of real variables
        # (real and imaginary parts) so that they remain trainable.
        dtype = tf.as_dtype(self._dtype)
        if dtype.is_floating:
            self._coefficients = tf.Variable(self._coefficients_source,
                                             trainable=self.trainable)
        elif dtype.is_complex:
            c = self._coefficients_source
            self._coefficients = [tf.Variable(tf.math.real(c),
                                              trainable=self.trainable),
                                  tf.Variable(tf.math.imag(c),
                                              trainable=self.trainable)]

    @property
    def length(self):
        """The filter length in samples"""
        l = self._span_in_symbols*self._samples_per_symbol
        l = 2*(l//2)+1  # Force length to be the next odd number
        return l

    @property
    def window(self):
        """The window function that is applied to the filter coefficients.
        `None` if no window is applied."""
        return self._window

    @window.setter
    def window(self, value):
        # Accepts a window name ("hann", "hamming", "blackman"), a Window
        # instance, or None (no windowing)
        if isinstance(value, str):
            if value=="hann":
                self._window = HannWindow(self.length)
            elif value=="hamming":
                self._window = HammingWindow(self.length)
            elif value=="blackman":
                self._window = BlackmanWindow(self.length)
            else:
                raise AssertionError("Invalid window type")
        elif isinstance(value, Window) or value is None:
            self._window = value
        else:
            raise AssertionError("Invalid window type")

    @property
    def normalize(self):
        """`True` if the filter is normalized to have unit power. `False` otherwise."""
        return self._normalize

    @property
    def trainable(self):
        """`True` if the filter coefficients are trainable. `False` otherwise."""
        return self._trainable

    @property
    @abstractmethod
    def _coefficients_source(self):
        """Internal property that returns the (unnormalized) filter coefficients.
        Concrete classes that inherit from this one must implement this
        property."""
        pass

    @property
    def coefficients(self):
        """The filter coefficients (after normalization)"""
        h = self._coefficients
        dtype = tf.as_dtype(self.dtype)

        # Combine both real dimensions to complex if necessary
        if dtype.is_complex:
            h = tf.complex(h[0], h[1])

        # Apply window
        if self.window is not None:
            h = self._window(h)

        # Ensure unit L2-norm of the coefficients
        if self.normalize:
            energy = tf.reduce_sum(tf.square(tf.abs(h)))
            h = h / tf.cast(tf.sqrt(energy), h.dtype)

        return h

    @property
    def sampling_times(self):
        """Sampling times in multiples of the symbol duration"""
        # Symmetric time grid centered on zero (length is always odd)
        n_min = -(self.length//2)
        n_max = n_min + self.length
        t = np.arange(n_min, n_max, dtype=np.float32)
        t /= self._samples_per_symbol
        return t

    def show(self, response="impulse", scale="lin"):
        r"""Plot the impulse or magnitude response

        Plots the impulse response (time domain) or magnitude response
        (frequency domain) of the filter.

        For the computation of the magnitude response, a minimum DFT size
        of 1024 is assumed which is obtained through zero padding of
        the filter coefficients in the time domain.

        Input
        -----
        response: str, one of ["impulse", "magnitude"]
            The desired response type.
            Defaults to "impulse"

        scale: str, one of ["lin", "db"]
            The y-scale of the magnitude response.
            Can be "lin" (i.e., linear) or "db" (, i.e., Decibel).
            Defaults to "lin".
        """
        assert response in ["impulse", "magnitude"], "Invalid response"
        if response=="impulse":
            dtype = tf.as_dtype(self.dtype)
            plt.figure(figsize=(12,6))
            plt.plot(self.sampling_times, np.real(self.coefficients))
            if dtype.is_complex:
                plt.plot(self.sampling_times, np.imag(self.coefficients))
                plt.legend(["Real part", "Imaginary part"])
            plt.title("Impulse response")
            plt.grid()
            plt.xlabel(r"Normalized time $(t/T)$")
            plt.ylabel(r"$h(t)$")
            plt.xlim(self.sampling_times[0], self.sampling_times[-1])

        else:
            assert scale in ["lin", "db"], "Invalid scale"
            fft_size = max(1024, self.coefficients.shape[-1])
            h = np.fft.fft(self.coefficients, fft_size)
            h = np.fft.fftshift(h)
            h = np.abs(h)
            plt.figure(figsize=(12,6))
            if scale=="db":
                h = np.maximum(h, 1e-10)  # floor to avoid log10(0)
                # h is an amplitude (|H(f)|), so the dB conversion uses
                # 20*log10, not 10*log10 (which applies to power quantities)
                h = 20*np.log10(h)
                plt.ylabel(r"$|H(f)|$ (dB)")
            else:
                plt.ylabel(r"$|H(f)|$")
            f = np.linspace(-self._samples_per_symbol/2,
                            self._samples_per_symbol/2, fft_size)
            plt.plot(f, h)
            plt.title("Magnitude response")
            plt.grid()
            plt.xlabel(r"Normalized frequency $(f/W)$")
            plt.xlim(f[0], f[-1])

    @property
    def aclr(self):
        """ACLR of the filter

        This ACLR corresponds to what one would obtain from using
        this filter as pulse shaping filter on an i.i.d. sequence of symbols.
        The in-band is assumed to range from [-0.5, 0.5] in normalized
        frequency.
        """
        fft_size = 1024
        # Zero-pad the coefficients to the DFT size before computing the
        # empirical ACLR
        n = fft_size - tf.shape(self.coefficients)[-1]
        z = tf.zeros([n], self.coefficients.dtype)
        c = tf.cast(tf.concat([self.coefficients, z], -1), tf.complex64)
        return empirical_aclr(c, self._samples_per_symbol)

    def call(self, x, padding='full', conjugate=False):
        h = self.coefficients
        dtype = tf.as_dtype(self.dtype)
        # Conjugation only applies to complex-valued filters (e.g., for
        # matched filtering)
        if conjugate and dtype.is_complex:
            h = tf.math.conj(h)
        y = convolve(x, h, padding)
        return y


class RaisedCosineFilter(Filter):
    # pylint: disable=line-too-long
    r"""RaisedCosineFilter(span_in_symbols, samples_per_symbol, beta, window=None, normalize=True, trainable=False, dtype=tf.float32, **kwargs)

    Layer for applying a raised-cosine filter of ``length`` K
    to an input ``x`` of length N.

    The raised-cosine filter is defined by

    .. math::
        h(t) =
        \begin{cases}
        \frac{\pi}{4T} \text{sinc}\left(\frac{1}{2\beta}\right), & \text { if }t = \pm \frac{T}{2\beta}\\
        \frac{1}{T}\text{sinc}\left(\frac{t}{T}\right)\frac{\cos\left(\frac{\pi\beta t}{T}\right)}{1-\left(\frac{2\beta t}{T}\right)^2}, & \text{otherwise}
        \end{cases}

    where :math:`\beta` is the roll-off factor and :math:`T` the symbol duration.

    The filter length K is equal to the filter span in symbols (``span_in_symbols``)
    multiplied by the oversampling factor (``samples_per_symbol``).
    If this product is even, a value of one will be added.

    The filter is applied through discrete convolution.

    An optional windowing function ``window`` can be applied to the filter.

    The `dtype` of the output is `tf.float` if both ``x`` and the filter coefficients have dtype `tf.float`.
    Otherwise, the dtype of the output is `tf.complex`.

    Three padding modes are available for applying the filter:

    *   "full" (default): Returns the convolution at each point of overlap between ``x`` and the filter.
        The length of the output is N + K - 1. Zero-padding of the input ``x`` is performed to
        compute the convolution at the borders.
    *   "same": Returns an output of the same length as the input ``x``. The convolution is computed such
        that the coefficients of the input ``x`` are centered on the coefficient of the filter with index
        (K-1)/2. Zero-padding of the input signal is performed to compute the convolution at the borders.
    *   "valid": Returns the convolution only at points where ``x`` and the filter completely overlap.
        The length of the output is N - K + 1.

    Parameters
    ----------
    span_in_symbols: int
        Filter span as measured by the number of symbols.

    samples_per_symbol: int
        Number of samples per symbol, i.e., the oversampling factor.

    beta : float
        Roll-off factor.
        Must be in the range :math:`[0,1]`.

    window: Window or string (["hann", "hamming", "blackman"])
        Instance of :class:`~sionna.signal.Window` that is applied to the filter coefficients.
        Alternatively, a string indicating the window name can be provided. In this case,
        the chosen window will be instantiated with the default parameters. Custom windows
        must be provided as instance.

    normalize: bool
        If `True`, the filter is normalized to have unit power.
        Defaults to `True`.

    trainable: bool
        If `True`, the filter coefficients are trainable variables.
        Defaults to `False`.

    dtype: tf.DType
        The `dtype` of the filter coefficients.
        Defaults to `tf.float32`.

    Input
    -----
    x : [..., N], tf.complex or tf.float
        The input to which the filter is applied.
        The filter is applied along the last dimension.

    padding : string (["full", "valid", "same"])
        Padding mode for convolving ``x`` and the filter.
        Must be one of "full", "valid", or "same". Case insensitive.
        Defaults to "full".

    conjugate : bool
        If `True`, the complex conjugate of the filter is applied.
        Defaults to `False`.

    Output
    ------
    y : [...,M], tf.complex or tf.float
        Filtered input.
        It is `tf.float` only if both ``x`` and the filter are `tf.float`.
        It is `tf.complex` otherwise.
        The length M depends on the ``padding``.
    """
    def __init__(self,
                 span_in_symbols,
                 samples_per_symbol,
                 beta,
                 window=None,
                 normalize=True,
                 trainable=False,
                 dtype=tf.float32,
                 **kwargs):

        # beta must be validated before the base-class constructor runs,
        # because the latter evaluates _coefficients_source
        assert 0 <= beta <= 1, "beta must be from the interval [0,1]"
        self._beta = beta

        super().__init__(span_in_symbols,
                         samples_per_symbol,
                         window=window,
                         normalize=normalize,
                         trainable=trainable,
                         dtype=dtype,
                         **kwargs)

    @property
    def beta(self):
        """Roll-off factor"""
        return self._beta

    @property
    def _coefficients_source(self):
        h = self._raised_cosine(self.sampling_times,
                                1.0,
                                self.beta)
        h = tf.constant(h, self.dtype)
        return h

    def _raised_cosine(self, t, symbol_duration, beta):
        """Raised-cosine filter from Wikipedia
        https://en.wikipedia.org/wiki/Raised-cosine_filter"""
        h = np.zeros([len(t)], np.float32)
        for i, tt in enumerate(t):
            tt = np.abs(tt)
            # Special case t = +-T/(2*beta), where the generic formula is 0/0.
            # NOTE(review): exact float equality is used to detect the
            # singularity; for sampling times that do not represent the
            # singular point exactly, the generic branch stays finite.
            if beta>0 and (tt-np.abs(symbol_duration/2/beta)==0):
                h[i] = np.pi/4/symbol_duration*np.sinc(1/2/beta)
            else:
                h[i] = 1./symbol_duration*np.sinc(tt/symbol_duration)\
                    * np.cos(np.pi*beta*tt/symbol_duration)\
                    /(1-(2*beta*tt/symbol_duration)**2)
        return h


class RootRaisedCosineFilter(Filter):
    # pylint: disable=line-too-long
    r"""RootRaisedCosineFilter(span_in_symbols, samples_per_symbol, beta, window=None, normalize=True, trainable=False, dtype=tf.float32, **kwargs)

    Layer for applying a root-raised-cosine filter of ``length`` K
    to an input ``x`` of length N.

    The root-raised-cosine filter is defined by

    .. math::
        h(t) =
        \begin{cases}
        \frac{1}{T} \left(1 + \beta\left(\frac{4}{\pi}-1\right) \right), & \text { if }t = 0\\
        \frac{\beta}{T\sqrt{2}} \left[ \left(1+\frac{2}{\pi}\right)\sin\left(\frac{\pi}{4\beta}\right) + \left(1-\frac{2}{\pi}\right)\cos\left(\frac{\pi}{4\beta}\right) \right], & \text { if }t = \pm\frac{T}{4\beta} \\
        \frac{1}{T} \frac{\sin\left(\pi\frac{t}{T}(1-\beta)\right) + 4\beta\frac{t}{T}\cos\left(\pi\frac{t}{T}(1+\beta)\right)}{\pi\frac{t}{T}\left(1-\left(4\beta\frac{t}{T}\right)^2\right)}, & \text { otherwise}
        \end{cases}

    where :math:`\beta` is the roll-off factor and :math:`T` the symbol duration.

    The filter length K is equal to the filter span in symbols (``span_in_symbols``)
    multiplied by the oversampling factor (``samples_per_symbol``).
    If this product is even, a value of one will be added.

    The filter is applied through discrete convolution.

    An optional windowing function ``window`` can be applied to the filter.

    The `dtype` of the output is `tf.float` if both ``x`` and the filter coefficients have dtype `tf.float`.
    Otherwise, the dtype of the output is `tf.complex`.

    Three padding modes are available for applying the filter:

    *   "full" (default): Returns the convolution at each point of overlap between ``x`` and the filter.
        The length of the output is N + K - 1. Zero-padding of the input ``x`` is performed to
        compute the convolution at the borders.
    *   "same": Returns an output of the same length as the input ``x``. The convolution is computed such
        that the coefficients of the input ``x`` are centered on the coefficient of the filter with index
        (K-1)/2. Zero-padding of the input signal is performed to compute the convolution at the borders.
    *   "valid": Returns the convolution only at points where ``x`` and the filter completely overlap.
        The length of the output is N - K + 1.

    Parameters
    ----------
    span_in_symbols: int
        Filter span as measured by the number of symbols.

    samples_per_symbol: int
        Number of samples per symbol, i.e., the oversampling factor.

    beta : float
        Roll-off factor.
        Must be in the range :math:`[0,1]`.

    window: Window or string (["hann", "hamming", "blackman"])
        Instance of :class:`~sionna.signal.Window` that is applied to the filter coefficients.
        Alternatively, a string indicating the window name can be provided. In this case,
        the chosen window will be instantiated with the default parameters. Custom windows
        must be provided as instance.

    normalize: bool
        If `True`, the filter is normalized to have unit power.
        Defaults to `True`.

    trainable: bool
        If `True`, the filter coefficients are trainable variables.
        Defaults to `False`.

    dtype: tf.DType
        The `dtype` of the filter coefficients.
        Defaults to `tf.float32`.

    Input
    -----
    x : [..., N], tf.complex or tf.float
        The input to which the filter is applied.
        The filter is applied along the last dimension.

    padding : string (["full", "valid", "same"])
        Padding mode for convolving ``x`` and the filter.
        Must be one of "full", "valid", or "same". Case insensitive.
        Defaults to "full".

    conjugate : bool
        If `True`, the complex conjugate of the filter is applied.
        Defaults to `False`.

    Output
    ------
    y : [...,M], tf.complex or tf.float
        Filtered input.
        It is `tf.float` only if both ``x`` and the filter are `tf.float`.
        It is `tf.complex` otherwise.
        The length M depends on the ``padding``.
    """
    def __init__(self,
                 span_in_symbols,
                 samples_per_symbol,
                 beta,
                 window=None,
                 normalize=True,
                 trainable=False,
                 dtype=tf.float32,
                 **kwargs):

        # beta must be validated before the base-class constructor runs,
        # because the latter evaluates _coefficients_source
        assert 0 <= beta <= 1, "beta must be from the interval [0,1]"
        self._beta = beta

        super().__init__(span_in_symbols,
                         samples_per_symbol,
                         window=window,
                         normalize=normalize,
                         trainable=trainable,
                         dtype=dtype,
                         **kwargs)

    @property
    def beta(self):
        """Roll-off factor"""
        return self._beta

    @property
    def _coefficients_source(self):
        h = self._root_raised_cosine(self.sampling_times,
                                     1.0,
                                     self.beta)
        h = tf.constant(h, self.dtype)
        return h

    def _root_raised_cosine(self, t, symbol_duration, beta):
        """Root-raised-cosine filter from Wikipedia
        https://en.wikipedia.org/wiki/Root-raised-cosine_filter"""
        h = np.zeros([len(t)], np.float32)
        for i, tt in enumerate(t):
            tt = np.abs(tt)
            if tt==0:
                h[i] = 1/symbol_duration*(1+beta*(4/np.pi-1))
            # Special case t = +-T/(4*beta), where the generic formula is 0/0.
            # NOTE(review): exact float equality is used to detect the
            # singularity; see _raised_cosine for the same pattern.
            elif beta>0 and (tt-np.abs(symbol_duration/4/beta)==0):
                h[i] = beta/symbol_duration/np.sqrt(2)\
                    * ((1+2/np.pi)*np.sin(np.pi/4/beta) + \
                                (1-2/np.pi)*np.cos(np.pi/4/beta))
            else:
                h[i] = 1/symbol_duration\
                    / (np.pi*tt/symbol_duration*(1-(4*beta*tt/symbol_duration)**2))\
                    * (np.sin(np.pi*tt/symbol_duration*(1-beta)) + \
                            4*beta*tt/symbol_duration\
                            *np.cos(np.pi*tt/symbol_duration*(1+beta)))
        return h


class SincFilter(Filter):
    # pylint: disable=line-too-long
    r"""SincFilter(span_in_symbols, samples_per_symbol, window=None, normalize=True, trainable=False, dtype=tf.float32, **kwargs)

    Layer for applying a sinc filter of ``length`` K
    to an input ``x`` of length N.

    The sinc filter is defined by

    .. math::
        h(t) = \frac{1}{T}\text{sinc}\left(\frac{t}{T}\right)

    where :math:`T` the symbol duration.

    The filter length K is equal to the filter span in symbols (``span_in_symbols``)
    multiplied by the oversampling factor (``samples_per_symbol``).
    If this product is even, a value of one will be added.

    The filter is applied through discrete convolution.

    An optional windowing function ``window`` can be applied to the filter.

    The `dtype` of the output is `tf.float` if both ``x`` and the filter coefficients have dtype `tf.float`.
    Otherwise, the dtype of the output is `tf.complex`.

    Three padding modes are available for applying the filter:

    *   "full" (default): Returns the convolution at each point of overlap between ``x`` and the filter.
        The length of the output is N + K - 1. Zero-padding of the input ``x`` is performed to
        compute the convolution at the borders.
    *   "same": Returns an output of the same length as the input ``x``. The convolution is computed such
        that the coefficients of the input ``x`` are centered on the coefficient of the filter with index
        (K-1)/2. Zero-padding of the input signal is performed to compute the convolution at the borders.
    *   "valid": Returns the convolution only at points where ``x`` and the filter completely overlap.
        The length of the output is N - K + 1.

    Parameters
    ----------
    span_in_symbols: int
        Filter span as measured by the number of symbols.

    samples_per_symbol: int
        Number of samples per symbol, i.e., the oversampling factor.

    window: Window or string (["hann", "hamming", "blackman"])
        Instance of :class:`~sionna.signal.Window` that is applied to the filter coefficients.
        Alternatively, a string indicating the window name can be provided. In this case,
        the chosen window will be instantiated with the default parameters. Custom windows
        must be provided as instance.

    normalize: bool
        If `True`, the filter is normalized to have unit power.
        Defaults to `True`.

    trainable: bool
        If `True`, the filter coefficients are trainable variables.
        Defaults to `False`.

    dtype: tf.DType
        The `dtype` of the filter coefficients.
        Defaults to `tf.float32`.

    Input
    -----
    x : [..., N], tf.complex or tf.float
        The input to which the filter is applied.
        The filter is applied along the last dimension.

    padding : string (["full", "valid", "same"])
        Padding mode for convolving ``x`` and the filter.
        Must be one of "full", "valid", or "same". Case insensitive.
        Defaults to "full".

    conjugate : bool
        If `True`, the complex conjugate of the filter is applied.
        Defaults to `False`.

    Output
    ------
    y : [...,M], tf.complex or tf.float
        Filtered input.
        It is `tf.float` only if both ``x`` and the filter are `tf.float`.
        It is `tf.complex` otherwise.
        The length M depends on the ``padding``.
    """
    def __init__(self,
                 span_in_symbols,
                 samples_per_symbol,
                 window=None,
                 normalize=True,
                 trainable=False,
                 dtype=tf.float32,
                 **kwargs):
        super().__init__(span_in_symbols,
                         samples_per_symbol,
                         window=window,
                         normalize=normalize,
                         trainable=trainable,
                         dtype=dtype,
                         **kwargs)

    @property
    def _coefficients_source(self):
        h = self._sinc(self.sampling_times,
                       1.0)
        h = tf.constant(h, self.dtype)
        return h

    def _sinc(self, t, symbol_duration):
        """Sinc filter"""
        # np.sinc is the normalized sinc: sin(pi*x)/(pi*x)
        return 1/symbol_duration*np.sinc(t/symbol_duration)


class CustomFilter(Filter):
    # pylint: disable=line-too-long
    r"""CustomFilter(span_in_symbols=None, samples_per_symbol=None, coefficients=None, window=None, normalize=True, trainable=False, dtype=tf.float32, **kwargs)

    Layer for applying a custom filter of ``length`` K
    to an input ``x`` of length N.

    The filter length K is equal to the filter span in symbols (``span_in_symbols``)
    multiplied by the oversampling factor (``samples_per_symbol``).
    If this product is even, a value of one will be added.

    The filter is applied through discrete convolution.

    An optional windowing function ``window`` can be applied to the filter.

    The `dtype` of the output is `tf.float` if both ``x`` and the filter coefficients have dtype `tf.float`.
    Otherwise, the dtype of the output is `tf.complex`.

    Three padding modes are available for applying the filter:

    *   "full" (default): Returns the convolution at each point of overlap between ``x`` and the filter.
        The length of the output is N + K - 1. Zero-padding of the input ``x`` is performed to
        compute the convolution at the borders.
    *   "same": Returns an output of the same length as the input ``x``. The convolution is computed such
        that the coefficients of the input ``x`` are centered on the coefficient of the filter with index
        (K-1)/2. Zero-padding of the input signal is performed to compute the convolution at the borders.
    *   "valid": Returns the convolution only at points where ``x`` and the filter completely overlap.
        The length of the output is N - K + 1.

    Parameters
    ----------
    span_in_symbols: int
        Filter span as measured by the number of symbols.
        Only needs to be provided if ``coefficients`` is None.

    samples_per_symbol: int
        Number of samples per symbol, i.e., the oversampling factor.
        Must always be provided.

    coefficients: [K], tf.float or tf.complex
        Optional filter coefficients.
        If set to `None`, then a random filter of K is generated
        by sampling a Gaussian distribution. Defaults to `None`.

    window: Window or string (["hann", "hamming", "blackman"])
        Instance of :class:`~sionna.signal.Window` that is applied to the filter coefficients.
        Alternatively, a string indicating the window name can be provided. In this case,
        the chosen window will be instantiated with the default parameters. Custom windows
        must be provided as instance.

    normalize: bool
        If `True`, the filter is normalized to have unit power.
        Defaults to `True`.

    trainable: bool
        If `True`, the filter coefficients are trainable variables.
        Defaults to `False`.

    dtype: tf.DType
        The `dtype` of the filter coefficients.
        Defaults to `tf.float32`.

    Input
    -----
    x : [..., N], tf.complex or tf.float
        The input to which the filter is applied.
        The filter is applied along the last dimension.

    padding : string (["full", "valid", "same"])
        Padding mode for convolving ``x`` and the filter.
        Must be one of "full", "valid", or "same". Case insensitive.
        Defaults to "full".

    conjugate : bool
        If `True`, the complex conjugate of the filter is applied.
        Defaults to `False`.

    Output
    ------
    y : [...,M], tf.complex or tf.float
        Filtered input.
        It is `tf.float` only if both ``x`` and the filter are `tf.float`.
        It is `tf.complex` otherwise.
        The length M depends on the ``padding``.
    """
    def __init__(self,
                 span_in_symbols=None,
                 samples_per_symbol=None,
                 coefficients=None,
                 window=None,
                 normalize=True,
                 trainable=False,
                 dtype=tf.float32,
                 **kwargs):

        assert samples_per_symbol is not None and samples_per_symbol>0, \
            "samples_per_symbol must be positive"
        self._samples_per_symbol = samples_per_symbol

        if coefficients is None:
            assert span_in_symbols is not None and span_in_symbols>0, \
                "span_in_symbols must be positive"
            self._span_in_symbols = span_in_symbols

        if coefficients is not None:
            # The span is derived from the provided coefficients, which
            # must be of odd length (see Filter.length)
            l = coefficients.shape[-1]
            assert l%2==1, \
                "The number of coefficients must be odd"
            self._span_in_symbols = l//self._samples_per_symbol
        else:
            # No coefficients provided: sample a random Gaussian filter.
            # A complex dtype requires two independent real draws.
            if dtype.is_complex:
                h = RandomNormal()([2, self.length], dtype.real_dtype)
                coefficients = tf.complex(h[0], h[1])
            else:
                coefficients = RandomNormal()([self.length], dtype)

        # The base-class constructor reads these coefficients back through
        # _coefficients_source and initializes the variables properly
        self._h = tf.constant(coefficients, dtype)

        super().__init__(self._span_in_symbols,
                         self._samples_per_symbol,
                         window=window,
                         normalize=normalize,
                         trainable=trainable,
                         dtype=dtype,
                         **kwargs)

    @property
    def _coefficients_source(self):
        return self._h