anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

filter.py (31796B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Layers implementing filters"""
      6 
      7 import tensorflow as tf
      8 from tensorflow.keras.layers import Layer
      9 from tensorflow.keras.initializers import RandomNormal
     10 from abc import ABC, abstractmethod
     11 import matplotlib.pyplot as plt
     12 import numpy as np
     13 from . import convolve, Window, HannWindow, HammingWindow, BlackmanWindow, empirical_aclr
     14 
     15 
     16 class Filter(ABC, Layer):
     17     # pylint: disable=line-too-long
     18     r"""Filter(span_in_symbols, samples_per_symbol, window=None, normalize=True, trainable=False, dtype=tf.float32, **kwargs)
     19 
     20     This is an abtract class for defining a filter of ``length`` K which can be
     21     applied to an input ``x`` of length N.
     22 
     23     The filter length K is equal to the filter span in symbols (``span_in_symbols``)
     24     multiplied by the oversampling factor (``samples_per_symbol``).
     25     If this product is even, a value of one will be added.
     26 
     27     The filter is applied through discrete convolution.
     28 
     29     An optional windowing function ``window`` can be applied to the filter.
     30 
     31     The `dtype` of the output is `tf.float` if both ``x`` and the filter coefficients have dtype `tf.float`.
     32     Otherwise, the dtype of the output is `tf.complex`.
     33 
     34     Three padding modes are available for applying the filter:
     35 
     36     *   "full" (default): Returns the convolution at each point of overlap between ``x`` and the filter.
     37         The length of the output is N + K - 1. Zero-padding of the input ``x`` is performed to
     38         compute the convolution at the borders.
     39     *   "same": Returns an output of the same length as the input ``x``. The convolution is computed such
     40         that the coefficients of the input ``x`` are centered on the coefficient of the filter with index
     41         (K-1)/2. Zero-padding of the input signal is performed to compute the convolution at the borders.
     42     *   "valid": Returns the convolution only at points where ``x`` and the filter completely overlap.
     43         The length of the output is N - K + 1.
     44 
     45     Parameters
     46     ----------
     47     span_in_symbols: int
     48         Filter span as measured by the number of symbols.
     49 
     50     samples_per_symbol: int
     51         Number of samples per symbol, i.e., the oversampling factor.
     52 
     53     window: Window or string (["hann", "hamming", "blackman"])
     54         Instance of :class:`~sionna.signal.Window` that is applied to the filter coefficients.
     55         Alternatively, a string indicating the window name can be provided. In this case,
     56         the chosen window will be instantiated with the default parameters. Custom windows
     57         must be provided as instance.
     58 
     59     normalize: bool
     60         If `True`, the filter is normalized to have unit power.
     61         Defaults to `True`.
     62 
     63     trainable: bool
     64         If `True`, the filter coefficients are trainable.
     65         Defaults to `False`.
     66 
     67     dtype: tf.DType
     68         The `dtype` of the filter coefficients.
     69         Defaults to `tf.float32`.
     70 
     71     Input
     72     -----
     73     x : [..., N], tf.complex or tf.float
     74         The input to which the filter is applied.
     75         The filter is applied along the last dimension.
     76 
     77     padding : string (["full", "valid", "same"])
     78         Padding mode for convolving ``x`` and the filter.
     79         Must be one of "full", "valid", or "same". Case insensitive.
     80         Defaults to "full".
     81 
     82     conjugate : bool
     83         If `True`, the complex conjugate of the filter is applied.
     84         Defaults to `False`.
     85 
     86     Output
     87     ------
     88     y : [...,M], tf.complex or tf.float
     89         Filtered input.
     90         It is `tf.float` only if both ``x`` and the filter are `tf.float`.
     91         It is `tf.complex` otherwise.
     92         The length M depends on the ``padding``.
     93     """
     94     def __init__(self,
     95                  span_in_symbols,
     96                  samples_per_symbol,
     97                  window=None,
     98                  normalize=True,
     99                  trainable=False,
    100                  dtype=tf.float32,
    101                  **kwargs):
    102         super().__init__(dtype=dtype, **kwargs)
    103 
    104         assert span_in_symbols>0, "span_in_symbols must be positive"
    105         self._span_in_symbols = span_in_symbols
    106 
    107         assert samples_per_symbol>0, "samples_per_symbol must be positive"
    108         self._samples_per_symbol = samples_per_symbol
    109 
    110         self.window = window
    111 
    112         assert isinstance(normalize, bool), "normalize must be bool"
    113         self._normalize = normalize
    114 
    115         assert isinstance(trainable, bool), "trainable must be bool"
    116         self._trainable = trainable
    117 
    118         assert self.length==self._coefficients_source.shape[-1], \
    119             "The number of coefficients must match the filter length."
    120 
    121         dtype = tf.as_dtype(self._dtype)
    122         if dtype.is_floating:
    123             self._coefficients = tf.Variable(self._coefficients_source,
    124                                                 trainable=self.trainable)
    125         elif dtype.is_complex:
    126             c = self._coefficients_source
    127             self._coefficients = [  tf.Variable(tf.math.real(c),
    128                                                 trainable=self.trainable),
    129                                     tf.Variable(tf.math.imag(c),
    130                                                 trainable=self.trainable)]
    131 
    132     @property
    133     def length(self):
    134         """The filter length in samples"""
    135         l = self._span_in_symbols*self._samples_per_symbol
    136         l = 2*(l//2)+1 # Force length to be the next odd number
    137         return l
    138 
    139     @property
    140     def window(self):
    141         """The window function that is applied to the filter coefficients. `None` if no window is applied."""
    142         return self._window
    143 
    144     @window.setter
    145     def window(self, value):
    146         if isinstance(value, str):
    147             if value=="hann":
    148                 self._window = HannWindow(self.length)
    149             elif value=="hamming":
    150                 self._window = HammingWindow(self.length)
    151             elif value=="blackman":
    152                 self._window = BlackmanWindow(self.length)
    153             else:
    154                 raise AssertionError("Invalid window type")
    155         elif isinstance(value, Window) or value is None:
    156             self._window = value
    157         else:
    158             raise AssertionError("Invalid window type")
    159 
    160     @property
    161     def normalize(self):
    162         """`True` if the filter is normalized to have unit power. `False` otherwise."""
    163         return self._normalize
    164 
    165     @property
    166     def trainable(self):
    167         """`True` if the filter coefficients are trainable. `False` otherwise."""
    168         return self._trainable
    169 
    170     @property
    171     @abstractmethod
    172     def _coefficients_source(self):
    173         """Internal property that returns the (unormalized) filter coefficients.
    174         Concrete classes that inherits from this one must implement this
    175         property."""
    176         pass
    177 
    178     @property
    179     def coefficients(self):
    180         """The filter coefficients (after normalization)"""
    181         h = self._coefficients
    182         dtype = tf.as_dtype(self.dtype)
    183 
    184         # Combine both real dimensions to complex if necessary
    185         if dtype.is_complex:
    186             h = tf.complex(h[0], h[1])
    187 
    188         # Apply window
    189         if self.window is not None:
    190             h = self._window(h)
    191 
    192         # Ensure unit L2-norm of the coefficients
    193         if self.normalize:
    194             energy = tf.reduce_sum(tf.square(tf.abs(h)))
    195             h = h / tf.cast(tf.sqrt(energy), h.dtype)
    196 
    197         return h
    198 
    199     @property
    200     def sampling_times(self):
    201         """Sampling times in multiples of the symbol duration"""
    202         n_min = -(self.length//2)
    203         n_max = n_min + self.length
    204         t = np.arange(n_min, n_max, dtype=np.float32)
    205         t /= self._samples_per_symbol
    206         return t
    207 
    208     def show(self, response="impulse", scale="lin"):
    209         r"""Plot the impulse or magnitude response
    210 
    211         Plots the impulse response (time domain) or magnitude response
    212         (frequency domain) of the filter.
    213 
    214         For the computation of the magnitude response, a minimum DFT size
    215         of 1024 is assumed which is obtained through zero padding of
    216         the filter coefficients in the time domain.
    217 
    218         Input
    219         -----
    220         response: str, one of ["impulse", "magnitude"]
    221             The desired response type.
    222             Defaults to "impulse"
    223 
    224         scale: str, one of ["lin", "db"]
    225             The y-scale of the magnitude response.
    226             Can be "lin" (i.e., linear) or "db" (, i.e., Decibel).
    227             Defaults to "lin".
    228         """
    229         assert response in ["impulse", "magnitude"], "Invalid response"
    230         if response=="impulse":
    231             dtype = tf.as_dtype(self.dtype)
    232             plt.figure(figsize=(12,6))
    233             plt.plot(self.sampling_times, np.real(self.coefficients))
    234             if dtype.is_complex:
    235                 plt.plot(self.sampling_times, np.imag(self.coefficients))
    236                 plt.legend(["Real part", "Imaginary part"])
    237             plt.title("Impulse response")
    238             plt.grid()
    239             plt.xlabel(r"Normalized time $(t/T)$")
    240             plt.ylabel(r"$h(t)$")
    241             plt.xlim(self.sampling_times[0], self.sampling_times[-1])
    242 
    243         else:
    244             assert scale in ["lin", "db"], "Invalid scale"
    245             fft_size = max(1024, self.coefficients.shape[-1])
    246             h = np.fft.fft(self.coefficients, fft_size)
    247             h = np.fft.fftshift(h)
    248             h = np.abs(h)
    249             plt.figure(figsize=(12,6))
    250             if scale=="db":
    251                 h = np.maximum(h, 1e-10)
    252                 h = 10*np.log10(h)
    253                 plt.ylabel(r"$|H(f)|$ (dB)")
    254             else:
    255                 plt.ylabel(r"$|H(f)|$")
    256             f = np.linspace(-self._samples_per_symbol/2,
    257                             self._samples_per_symbol/2, fft_size)
    258             plt.plot(f, h)
    259             plt.title("Magnitude response")
    260             plt.grid()
    261             plt.xlabel(r"Normalized frequency $(f/W)$")
    262             plt.xlim(f[0], f[-1])
    263 
    264     @property
    265     def aclr(self):
    266         """ACLR of the filter
    267 
    268         This ACLR corresponds to what one would obtain from using
    269         this filter as pulse shaping filter on an i.i.d. sequence of symbols.
    270         The in-band is assumed to range from [-0.5, 0.5] in normalized
    271         frequency.
    272         """
    273         fft_size = 1024
    274         n = fft_size - tf.shape(self.coefficients)[-1]
    275         z = tf.zeros([n], self.coefficients.dtype)
    276         c = tf.cast(tf.concat([self.coefficients, z], -1), tf.complex64)
    277         return empirical_aclr(c, self._samples_per_symbol)
    278 
    279     def call(self, x, padding='full', conjugate=False):
    280         h = self.coefficients
    281         dtype = tf.as_dtype(self.dtype)
    282         if conjugate and dtype.is_complex:
    283             h = tf.math.conj(h)
    284         y = convolve(x,h,padding)
    285         return y
    286 
    287 
    288 class RaisedCosineFilter(Filter):
    289     # pylint: disable=line-too-long
    290     r"""RaisedCosineFilter(span_in_symbols, samples_per_symbol, beta, window=None, normalize=True, trainable=False, dtype=tf.float32, **kwargs)
    291 
    292     Layer for applying a raised-cosine filter of ``length`` K
    293     to an input ``x`` of length N.
    294 
    295     The raised-cosine filter is defined by
    296 
    297     .. math::
    298         h(t) =
    299         \begin{cases}
    300         \frac{\pi}{4T} \text{sinc}\left(\frac{1}{2\beta}\right), & \text { if }t = \pm \frac{T}{2\beta}\\
    301         \frac{1}{T}\text{sinc}\left(\frac{t}{T}\right)\frac{\cos\left(\frac{\pi\beta t}{T}\right)}{1-\left(\frac{2\beta t}{T}\right)^2}, & \text{otherwise}
    302         \end{cases}
    303 
    304     where :math:`\beta` is the roll-off factor and :math:`T` the symbol duration.
    305 
    306     The filter length K is equal to the filter span in symbols (``span_in_symbols``)
    307     multiplied by the oversampling factor (``samples_per_symbol``).
    308     If this product is even, a value of one will be added.
    309 
    310     The filter is applied through discrete convolution.
    311 
    312     An optional windowing function ``window`` can be applied to the filter.
    313 
    314     The `dtype` of the output is `tf.float` if both ``x`` and the filter coefficients have dtype `tf.float`.
    315     Otherwise, the dtype of the output is `tf.complex`.
    316 
    317     Three padding modes are available for applying the filter:
    318 
    319     *   "full" (default): Returns the convolution at each point of overlap between ``x`` and the filter.
    320         The length of the output is N + K - 1. Zero-padding of the input ``x`` is performed to
    321         compute the convolution at the borders.
    322     *   "same": Returns an output of the same length as the input ``x``. The convolution is computed such
    323         that the coefficients of the input ``x`` are centered on the coefficient of the filter with index
    324         (K-1)/2. Zero-padding of the input signal is performed to compute the convolution at the borders.
    325     *   "valid": Returns the convolution only at points where ``x`` and the filter completely overlap.
    326         The length of the output is N - K + 1.
    327 
    328     Parameters
    329     ----------
    330     span_in_symbols: int
    331         Filter span as measured by the number of symbols.
    332 
    333     samples_per_symbol: int
    334         Number of samples per symbol, i.e., the oversampling factor.
    335 
    336     beta : float
    337         Roll-off factor.
    338         Must be in the range :math:`[0,1]`.
    339 
    340     window: Window or string (["hann", "hamming", "blackman"])
    341         Instance of :class:`~sionna.signal.Window` that is applied to the filter coefficients.
    342         Alternatively, a string indicating the window name can be provided. In this case,
    343         the chosen window will be instantiated with the default parameters. Custom windows
    344         must be provided as instance.
    345 
    346     normalize: bool
    347         If `True`, the filter is normalized to have unit power.
    348         Defaults to `True`.
    349 
    350     trainable: bool
    351         If `True`, the filter coefficients are trainable variables.
    352         Defaults to `False`.
    353 
    354     dtype: tf.DType
    355         The `dtype` of the filter coefficients.
    356         Defaults to `tf.float32`.
    357 
    358     Input
    359     -----
    360     x : [..., N], tf.complex or tf.float
    361         The input to which the filter is applied.
    362         The filter is applied along the last dimension.
    363 
    364     padding : string (["full", "valid", "same"])
    365         Padding mode for convolving ``x`` and the filter.
    366         Must be one of "full", "valid", or "same".
    367         Defaults to "full".
    368 
    369     conjugate : bool
    370         If `True`, the complex conjugate of the filter is applied.
    371         Defaults to `False`.
    372 
    373     Output
    374     ------
    375     y : [...,M], tf.complex or tf.float
    376         Filtered input.
    377         It is `tf.float` only if both ``x`` and the filter are `tf.float`.
    378         It is `tf.complex` otherwise.
    379         The length M depends on the ``padding``.
    380     """
    381     def __init__(self,
    382                  span_in_symbols,
    383                  samples_per_symbol,
    384                  beta,
    385                  window=None,
    386                  normalize=True,
    387                  trainable=False,
    388                  dtype=tf.float32,
    389                  **kwargs):
    390 
    391         assert 0 <= beta <= 1, "beta must be from the intervall [0,1]"
    392         self._beta = beta
    393 
    394         super().__init__(span_in_symbols,
    395                          samples_per_symbol,
    396                          window=window,
    397                          normalize=normalize,
    398                          trainable=trainable,
    399                          dtype=dtype,
    400                          **kwargs)
    401 
    402     @property
    403     def beta(self):
    404         """Roll-off factor"""
    405         return self._beta
    406 
    407     @property
    408     def _coefficients_source(self):
    409         h = self._raised_cosine(self.sampling_times,
    410                                 1.0,
    411                                 self.beta)
    412         h = tf.constant(h, self.dtype)
    413         return h
    414 
    415     def _raised_cosine(self, t, symbol_duration, beta):
    416         """Raised-cosine filter from Wikipedia
    417         https://en.wikipedia.org/wiki/Raised-cosine_filter"""
    418         h = np.zeros([len(t)], np.float32)
    419         for i, tt in enumerate(t):
    420             tt = np.abs(tt)
    421             if beta>0 and (tt-np.abs(symbol_duration/2/beta)==0):
    422                 h[i] = np.pi/4/symbol_duration*np.sinc(1/2/beta)
    423             else:
    424                 h[i] = 1./symbol_duration*np.sinc(tt/symbol_duration)\
    425                     * np.cos(np.pi*beta*tt/symbol_duration)\
    426                     /(1-(2*beta*tt/symbol_duration)**2)
    427         return h
    428 
    429 
    430 class RootRaisedCosineFilter(Filter):
    431     # pylint: disable=line-too-long
    432     r"""RootRaisedCosineFilter(span_in_symbols, samples_per_symbol, beta, window=None, normalize=True, trainable=False, dtype=tf.float32, **kwargs)
    433 
    434     Layer for applying a root-raised-cosine filter of ``length`` K
    435     to an input ``x`` of length N.
    436 
    437     The root-raised-cosine filter is defined by
    438 
    439     .. math::
    440         h(t) =
    441         \begin{cases}
    442         \frac{1}{T} \left(1 + \beta\left(\frac{4}{\pi}-1\right) \right), & \text { if }t = 0\\
    443         \frac{\beta}{T\sqrt{2}} \left[ \left(1+\frac{2}{\pi}\right)\sin\left(\frac{\pi}{4\beta}\right) + \left(1-\frac{2}{\pi}\right)\cos\left(\frac{\pi}{4\beta}\right) \right], & \text { if }t = \pm\frac{T}{4\beta} \\
    444         \frac{1}{T} \frac{\sin\left(\pi\frac{t}{T}(1-\beta)\right) + 4\beta\frac{t}{T}\cos\left(\pi\frac{t}{T}(1+\beta)\right)}{\pi\frac{t}{T}\left(1-\left(4\beta\frac{t}{T}\right)^2\right)}, & \text { otherwise}
    445         \end{cases}
    446 
    447     where :math:`\beta` is the roll-off factor and :math:`T` the symbol duration.
    448 
    449     The filter length K is equal to the filter span in symbols (``span_in_symbols``)
    450     multiplied by the oversampling factor (``samples_per_symbol``).
    451     If this product is even, a value of one will be added.
    452 
    453     The filter is applied through discrete convolution.
    454 
    455     An optional windowing function ``window`` can be applied to the filter.
    456 
    457     The `dtype` of the output is `tf.float` if both ``x`` and the filter coefficients have dtype `tf.float`.
    458     Otherwise, the dtype of the output is `tf.complex`.
    459 
    460     Three padding modes are available for applying the filter:
    461 
    462     *   "full" (default): Returns the convolution at each point of overlap between ``x`` and the filter.
    463         The length of the output is N + K - 1. Zero-padding of the input ``x`` is performed to
    464         compute the convolution at the borders.
    465     *   "same": Returns an output of the same length as the input ``x``. The convolution is computed such
    466         that the coefficients of the input ``x`` are centered on the coefficient of the filter with index
    467         (K-1)/2. Zero-padding of the input signal is performed to compute the convolution at the borders.
    468     *   "valid": Returns the convolution only at points where ``x`` and the filter completely overlap.
    469         The length of the output is N - K + 1.
    470 
    471     Parameters
    472     ----------
    473     span_in_symbols: int
    474         Filter span as measured by the number of symbols.
    475 
    476     samples_per_symbol: int
    477         Number of samples per symbol, i.e., the oversampling factor.
    478 
    479     beta : float
    480         Roll-off factor.
    481         Must be in the range :math:`[0,1]`.
    482 
    483     window: Window or string (["hann", "hamming", "blackman"])
    484         Instance of :class:`~sionna.signal.Window` that is applied to the filter coefficients.
    485         Alternatively, a string indicating the window name can be provided. In this case,
    486         the chosen window will be instantiated with the default parameters. Custom windows
    487         must be provided as instance.
    488 
    489     normalize: bool
    490         If `True`, the filter is normalized to have unit power.
    491         Defaults to `True`.
    492 
    493     trainable: bool
    494         If `True`, the filter coefficients are trainable variables.
    495         Defaults to `False`.
    496 
    497     dtype: tf.DType
    498         The `dtype` of the filter coefficients.
    499         Defaults to `tf.float32`.
    500 
    501     Input
    502     -----
    503     x : [..., N], tf.complex or tf.float
    504         The input to which the filter is applied.
    505         The filter is applied along the last dimension.
    506 
    507     padding : string (["full", "valid", "same"])
    508         Padding mode for convolving ``x`` and the filter.
    509         Must be one of "full", "valid", or "same". Case insensitive.
    510         Defaults to "full".
    511 
    512     conjugate : bool
    513         If `True`, the complex conjugate of the filter is applied.
    514         Defaults to `False`.
    515 
    516     Output
    517     ------
    518     y : [...,M], tf.complex or tf.float
    519         Filtered input.
    520         It is `tf.float` only if both ``x`` and the filter are `tf.float`.
    521         It is `tf.complex` otherwise.
    522         The length M depends on the ``padding``.
    523     """
    524     def __init__(self,
    525                  span_in_symbols,
    526                  samples_per_symbol,
    527                  beta,
    528                  window=None,
    529                  normalize=True,
    530                  trainable=False,
    531                  dtype=tf.float32,
    532                  **kwargs):
    533 
    534         assert 0 <= beta <= 1, "beta must be from the intervall [0,1]"
    535         self._beta = beta
    536 
    537         super().__init__(span_in_symbols,
    538                          samples_per_symbol,
    539                          window=window,
    540                          normalize=normalize,
    541                          trainable=trainable,
    542                          dtype=dtype,
    543                          **kwargs)
    544 
    545     @property
    546     def beta(self):
    547         """Roll-off factor"""
    548         return self._beta
    549 
    550     @property
    551     def _coefficients_source(self):
    552         h = self._root_raised_cosine(self.sampling_times,
    553                                      1.0,
    554                                      self.beta)
    555         h = tf.constant(h, self.dtype)
    556         return h
    557 
    558     def _root_raised_cosine(self, t, symbol_duration, beta):
    559         """Root-raised-cosine filter from Wikipedia
    560             https://en.wikipedia.org/wiki/Root-raised-cosine_filter"""
    561         h = np.zeros([len(t)], np.float32)
    562         for i, tt in enumerate(t):
    563             tt = np.abs(tt)
    564             if tt==0:
    565                 h[i] = 1/symbol_duration*(1+beta*(4/np.pi-1))
    566             elif beta>0 and (tt-np.abs(symbol_duration/4/beta)==0):
    567                 h[i] = beta/symbol_duration/np.sqrt(2)\
    568                     * ((1+2/np.pi)*np.sin(np.pi/4/beta) + \
    569                                             (1-2/np.pi)*np.cos(np.pi/4/beta))
    570             else:
    571                 h[i] = 1/symbol_duration\
    572                 / (np.pi*tt/symbol_duration*(1-(4*beta*tt/symbol_duration)**2))\
    573                 * (np.sin(np.pi*tt/symbol_duration*(1-beta)) + \
    574                 4*beta*tt/symbol_duration\
    575                 *np.cos(np.pi*tt/symbol_duration*(1+beta)))
    576         return h
    577 
    578 
    579 class SincFilter(Filter):
    580     # pylint: disable=line-too-long
    581     r"""SincFilter(span_in_symbols, samples_per_symbol, window=None, normalize=True, trainable=False, dtype=tf.float32, **kwargs)
    582 
    583     Layer for applying a sinc filter of ``length`` K
    584     to an input ``x`` of length N.
    585 
    586     The sinc filter is defined by
    587 
    588     .. math::
    589         h(t) = \frac{1}{T}\text{sinc}\left(\frac{t}{T}\right)
    590 
    591     where :math:`T` the symbol duration.
    592 
    593     The filter length K is equal to the filter span in symbols (``span_in_symbols``)
    594     multiplied by the oversampling factor (``samples_per_symbol``).
    595     If this product is even, a value of one will be added.
    596 
    597     The filter is applied through discrete convolution.
    598 
    599     An optional windowing function ``window`` can be applied to the filter.
    600 
    601     The `dtype` of the output is `tf.float` if both ``x`` and the filter coefficients have dtype `tf.float`.
    602     Otherwise, the dtype of the output is `tf.complex`.
    603 
    604     Three padding modes are available for applying the filter:
    605 
    606     *   "full" (default): Returns the convolution at each point of overlap between ``x`` and the filter.
    607         The length of the output is N + K - 1. Zero-padding of the input ``x`` is performed to
    608         compute the convolution at the borders.
    609     *   "same": Returns an output of the same length as the input ``x``. The convolution is computed such
    610         that the coefficients of the input ``x`` are centered on the coefficient of the filter with index
    611         (K-1)/2. Zero-padding of the input signal is performed to compute the convolution at the borders.
    612     *   "valid": Returns the convolution only at points where ``x`` and the filter completely overlap.
    613         The length of the output is N - K + 1.
    614 
    615     Parameters
    616     ----------
    617     span_in_symbols: int
    618         Filter span as measured by the number of symbols.
    619 
    620     samples_per_symbol: int
    621         Number of samples per symbol, i.e., the oversampling factor.
    622 
    623     window: Window or string (["hann", "hamming", "blackman"])
    624         Instance of :class:`~sionna.signal.Window` that is applied to the filter coefficients.
    625         Alternatively, a string indicating the window name can be provided. In this case,
    626         the chosen window will be instantiated with the default parameters. Custom windows
    627         must be provided as instance.
    628 
    629     normalize: bool
    630         If `True`, the filter is normalized to have unit power.
    631         Defaults to `True`.
    632 
    633     trainable: bool
    634         If `True`, the filter coefficients are trainable variables.
    635         Defaults to `False`.
    636 
    637     dtype: tf.DType
    638         The `dtype` of the filter coefficients.
    639         Defaults to `tf.float32`.
    640 
    641     Input
    642     -----
    643     x : [..., N], tf.complex or tf.float
    644         The input to which the filter is applied.
    645         The filter is applied along the last dimension.
    646 
    647     padding : string (["full", "valid", "same"])
    648         Padding mode for convolving ``x`` and the filter.
    649         Must be one of "full", "valid", or "same". Case insensitive.
    650         Defaults to "full".
    651 
    652     conjugate : bool
    653         If `True`, the complex conjugate of the filter is applied.
    654         Defaults to `False`.
    655 
    656     Output
    657     ------
    658     y : [...,M], tf.complex or tf.float
    659         Filtered input.
    660         It is `tf.float` only if both ``x`` and the filter are `tf.float`.
    661         It is `tf.complex` otherwise.
    662         The length M depends on the ``padding``.
    663     """
    664     def __init__(self,
    665                  span_in_symbols,
    666                  samples_per_symbol,
    667                  window=None,
    668                  normalize=True,
    669                  trainable=False,
    670                  dtype=tf.float32,
    671                  **kwargs):
    672         super().__init__(span_in_symbols,
    673                          samples_per_symbol,
    674                          window=window,
    675                          normalize=normalize,
    676                          trainable=trainable,
    677                          dtype=dtype,
    678                          **kwargs)
    679 
    @property
    def _coefficients_source(self):
        """Raw sinc taps sampled at ``self.sampling_times``.

        The sinc is evaluated with a unit symbol duration (1.0) and the result
        is returned as a constant tensor of ``self.dtype``. No windowing or
        normalization is applied in this property.
        """
        return tf.constant(self._sinc(self.sampling_times, 1.0), self.dtype)
    686 
    687     def _sinc(self, t, symbol_duration):
    688         """Sinc filter"""
    689         return 1/symbol_duration*np.sinc(t/symbol_duration)
    690 
    691 
    692 class CustomFilter(Filter):
    693     # pylint: disable=line-too-long
    694     r"""CustomFilter(span_in_symbols=None, samples_per_symbol=None, coefficients=None, window=None, normalize=True, trainable=False, dtype=tf.float32, **kwargs)
    695 
    696     Layer for applying a custom filter of ``length`` K
    697     to an input ``x`` of length N.
    698 
    699     The filter length K is equal to the filter span in symbols (``span_in_symbols``)
    700     multiplied by the oversampling factor (``samples_per_symbol``).
    701     If this product is even, a value of one will be added.
    702 
    703     The filter is applied through discrete convolution.
    704 
    705     An optional windowing function ``window`` can be applied to the filter.
    706 
    707     The `dtype` of the output is `tf.float` if both ``x`` and the filter coefficients have dtype `tf.float`.
    708     Otherwise, the dtype of the output is `tf.complex`.
    709 
    710     Three padding modes are available for applying the filter:
    711 
    712     *   "full" (default): Returns the convolution at each point of overlap between ``x`` and the filter.
    713         The length of the output is N + K - 1. Zero-padding of the input ``x`` is performed to
    714         compute the convolution at the borders.
    715     *   "same": Returns an output of the same length as the input ``x``. The convolution is computed such
    716         that the coefficients of the input ``x`` are centered on the coefficient of the filter with index
    717         (K-1)/2. Zero-padding of the input signal is performed to compute the convolution at the borders.
    718     *   "valid": Returns the convolution only at points where ``x`` and the filter completely overlap.
    719         The length of the output is N - K + 1.
    720 
    721     Parameters
    722     ----------
    723     span_in_symbols: int
    724         Filter span as measured by the number of symbols.
    725         Only needs to be provided if ``coefficients`` is None.
    726 
    727     samples_per_symbol: int
    728         Number of samples per symbol, i.e., the oversampling factor.
    729         Must always be provided.
    730 
    731     coefficients: [K], tf.float or tf.complex
    732         Optional filter coefficients.
    733         If set to `None`, then a random filter of K is generated
    734         by sampling a Gaussian distribution. Defaults to `None`.
    735 
    736     window: Window or string (["hann", "hamming", "blackman"])
    737         Instance of :class:`~sionna.signal.Window` that is applied to the filter coefficients.
    738         Alternatively, a string indicating the window name can be provided. In this case,
    739         the chosen window will be instantiated with the default parameters. Custom windows
    740         must be provided as instance.
    741 
    742     normalize: bool
    743         If `True`, the filter is normalized to have unit power.
    744         Defaults to `True`.
    745 
    746     trainable: bool
    747         If `True`, the filter coefficients are trainable variables.
    748         Defaults to `False`.
    749 
    750     dtype: tf.DType
    751         The `dtype` of the filter coefficients.
    752         Defaults to `tf.float32`.
    753 
    754     Input
    755     -----
    756     x : [..., N], tf.complex or tf.float
    757         The input to which the filter is applied.
    758         The filter is applied along the last dimension.
    759 
    760     padding : string (["full", "valid", "same"])
    761         Padding mode for convolving ``x`` and the filter.
    762         Must be one of "full", "valid", or "same". Case insensitive.
    763         Defaults to "full".
    764 
    765     conjugate : bool
    766         If `True`, the complex conjugate of the filter is applied.
    767         Defaults to `False`.
    768 
    769     Output
    770     ------
    771     y : [...,M], tf.complex or tf.float
    772         Filtered input.
    773         It is `tf.float` only if both ``x`` and the filter are `tf.float`.
    774         It is `tf.complex` otherwise.
    775         The length M depends on the ``padding``.
    776     """
    777     def __init__(self,
    778                  span_in_symbols=None,
    779                  samples_per_symbol=None,
    780                  coefficients=None,
    781                  window=None,
    782                  normalize=True,
    783                  trainable=False,
    784                  dtype=tf.float32,
    785                  **kwargs):
    786 
    787         assert samples_per_symbol is not None and samples_per_symbol>0, \
    788         "samples_per_symbol must be positive"
    789         self._samples_per_symbol = samples_per_symbol
    790 
    791         if coefficients is None:
    792             assert span_in_symbols is not None and span_in_symbols>0, \
    793                 "span_in_symbols must be positive"
    794             self._span_in_symbols = span_in_symbols
    795 
    796         if coefficients is not None:
    797             l = coefficients.shape[-1]
    798             assert l%2==1, \
    799                 "The number of coefficients must be odd"
    800             self._span_in_symbols = l//self._samples_per_symbol
    801         else:
    802             if dtype.is_complex:
    803                 h = RandomNormal()([2, self.length], dtype.real_dtype)
    804                 coefficients = tf.complex(h[0], h[1])
    805             else:
    806                 coefficients = RandomNormal()([self.length], dtype)
    807 
    808         # Coefficients setter initialises coefficients properly
    809         self._h = tf.constant(coefficients, dtype)
    810 
    811         super().__init__(self._span_in_symbols,
    812                          self._samples_per_symbol,
    813                          window=window,
    814                          normalize=normalize,
    815                          trainable=trainable,
    816                          dtype=dtype,
    817                          **kwargs)
    818 
    819     @property
    820     def _coefficients_source(self):
    821         return self._h