anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

demodulator.py (7889B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Class definition for the OFDM Demodulator"""
      6 
      7 import tensorflow as tf
      8 from tensorflow.keras.layers import Layer
      9 from tensorflow.signal import fftshift
     10 from sionna.constants import PI
     11 from sionna.utils import expand_to_rank
     12 from sionna.signal import fft
     13 import numpy as np
     14 
     15 class OFDMDemodulator(Layer):
     16     # pylint: disable=line-too-long
     17     r"""
     18     OFDMDemodulator(fft_size, l_min, cyclic_prefix_length, **kwargs)
     19 
     20     Computes the frequency-domain representation of an OFDM waveform
     21     with cyclic prefix removal.
     22 
     23     The demodulator assumes that the input sequence is generated by the
     24     :class:`~sionna.channel.TimeChannel`. For a single pair of antennas,
     25     the received signal sequence is given as:
     26 
     27     .. math::
     28 
     29         y_b = \sum_{\ell =L_\text{min}}^{L_\text{max}} \bar{h}_\ell x_{b-\ell} + w_b, \quad b \in[L_\text{min}, N_B+L_\text{max}-1]
     30 
     31     where :math:`\bar{h}_\ell` are the discrete-time channel taps,
     32     :math:`x_{b}` is the the transmitted signal,
     33     and :math:`w_\ell` Gaussian noise.
     34 
     35     Starting from the first symbol, the demodulator cuts the input
     36     sequence into pieces of size ``cyclic_prefix_length + fft_size``,
     37     and throws away any trailing symbols. For each piece, the cyclic
     38     prefix is removed and the ``fft_size``-point discrete Fourier
     39     transform is computed. It is also possible that every OFDM symbol
     40     has a cyclic prefix of different length.
     41 
     42     Since the input sequence starts at time :math:`L_\text{min}`,
     43     the FFT-window has a timing offset of :math:`L_\text{min}` symbols,
     44     which leads to a subcarrier-dependent phase shift of
     45     :math:`e^{\frac{j2\pi k L_\text{min}}{N}}`, where :math:`k`
     46     is the subcarrier index, :math:`N` is the FFT size,
     47     and :math:`L_\text{min} \le 0` is the largest negative time lag of
     48     the discrete-time channel impulse response. This phase shift
     49     is removed in this layer, by explicitly multiplying
     50     each subcarrier by  :math:`e^{\frac{-j2\pi k L_\text{min}}{N}}`.
     51     This is a very important step to enable channel estimation with
     52     sparse pilot patterns that needs to interpolate the channel frequency
     53     response accross subcarriers. It also ensures that the
     54     channel frequency response `seen` by the time-domain channel
     55     is close to the :class:`~sionna.channel.OFDMChannel`.
     56 
     57     Parameters
     58     ----------
     59     fft_size : int
     60         FFT size (, i.e., the number of subcarriers).
     61 
     62     l_min : int
     63         The largest negative time lag of the discrete-time channel
     64         impulse response. It should be the same value as that used by the
     65         `cir_to_time_channel` function.
     66 
     67     cyclic_prefix_length : scalar or [num_ofdm_symbols], int
     68         Integer or vector of integers indicating the length of the
     69         cyclic prefix that is prepended to each OFDM symbol. None of its
     70         elements can be larger than the FFT size.
     71         Defaults to 0.
     72 
     73     Input
     74     -----
     75     :[...,num_ofdm_symbols*(fft_size+cyclic_prefix_length)+n] or [...,num_ofdm_symbols*fft_size+sum(cyclic_prefix_length)+n], tf.complex
     76         Tensor containing the time-domain signal along the last dimension.
     77         `n` is a nonnegative integer.
     78 
     79     Output
     80     ------
     81     :[...,num_ofdm_symbols,fft_size], tf.complex
     82         Tensor containing the OFDM resource grid along the last
     83         two dimension.
     84     """
     85 
     86     def __init__(self, fft_size, l_min, cyclic_prefix_length=0, **kwargs):
     87         super().__init__(**kwargs)
     88         self._fft_size = None
     89         self._l_min = None
     90         self._cyclic_prefix_length = None
     91         self.fft_size = fft_size
     92         self.l_min = l_min
     93         self.cyclic_prefix_length = cyclic_prefix_length
     94 
     95     @property
     96     def fft_size(self):
     97         return self._fft_size
     98 
     99     @fft_size.setter
    100     def fft_size(self, value):
    101         assert value>0, "`fft_size` must be positive."
    102         self._fft_size = int(value)
    103 
    104     @property
    105     def l_min(self):
    106         return self._l_min
    107 
    108     @l_min.setter
    109     def l_min(self, value):
    110         assert value<=0, "l_min must be nonpositive."
    111         self._l_min = int(value)
    112 
    113     @property
    114     def cyclic_prefix_length(self):
    115         return self._cyclic_prefix_length
    116 
    117     @cyclic_prefix_length.setter
    118     def cyclic_prefix_length(self, value):
    119         value = tf.cast(value, tf.int32)
    120         if not tf.reduce_all(value>=0):
    121             msg = "`cyclic_prefix_length` must be nonnegative."
    122             raise ValueError(msg)
    123         if not 0<= tf.rank(value)<=1:
    124             msg = "`cyclic_prefix_length` must be of rank 0 or 1"
    125             raise ValueError(msg)
    126         self._cyclic_prefix_length = value
    127 
    128     def build(self, input_shape): # pylint: disable=unused-argument
    129         # Compute phase correction terms to to channel
    130         tmp = -2 * PI * tf.cast(self.l_min, tf.float32) \
    131               / tf.cast(self.fft_size, tf.float32) \
    132               * tf.range(self.fft_size, dtype=tf.float32)
    133         self._phase_compensation = tf.exp(tf.complex(0., tmp))
    134 
    135         if len(self.cyclic_prefix_length.shape)==0:
    136             # Compute number of elements that will be truncated
    137             self._rest = np.mod(input_shape[-1],
    138                                     self.fft_size + self.cyclic_prefix_length)
    139 
    140             # Compute number of full OFDM symbols to be demodulated
    141             self._num_ofdm_symbols = np.floor_divide(
    142                                     input_shape[-1]-self._rest,
    143                                     self.fft_size + self.cyclic_prefix_length)
    144         else:
    145             # Deal with individual cp lengths for OFDM symbols
    146             # Compute the relevant indices to gather for
    147             # every OFDM symbol from the time domain input
    148             num_ofdm_symbols = self.cyclic_prefix_length.shape[0]
    149             row_lengths = self.cyclic_prefix_length + self.fft_size
    150             offsets = tf.math.cumsum(tf.concat([[0], row_lengths],
    151                                                axis=0)[:-1])
    152             offsets = tf.expand_dims(offsets, 1)
    153             ind = tf.repeat(tf.range(start=0,
    154                                      limit=self.fft_size)[tf.newaxis,:],
    155                             repeats=num_ofdm_symbols, axis=0)
    156             ind += self.cyclic_prefix_length[:, tf.newaxis]
    157             ind += offsets
    158             # [num_ofdm_symbols, fft_size]
    159             self._ind = ind
    160 
    161     def call(self, inputs):
    162         """Demodulate OFDM waveform onto a resource grid.
    163 
    164         Args:
    165             inputs (tf.complex64):
    166                 `[...,num_ofdm_symbols*(fft_size+cyclic_prefix_length)]`.
    167 
    168         Returns:
    169             `tf.complex64` : The demodulated inputs of shape
    170             `[...,num_ofdm_symbols, fft_size]`.
    171         """
    172         if len(self.cyclic_prefix_length.shape)==0:
    173             # Same CP length for all OFDM symbols
    174             # Cut last samples that do not fit into an OFDM symbol
    175             inputs = inputs if self._rest==0 else inputs[...,:-self._rest]
    176 
    177             # Reshape input to separate OFDM symbols
    178             new_shape = tf.concat(
    179                             [tf.shape(inputs)[:-1],
    180                             [self._num_ofdm_symbols],
    181                             [self.fft_size + self.cyclic_prefix_length]], 0)
    182             x = tf.reshape(inputs, new_shape)
    183 
    184             # Remove cyclic prefix
    185             x = x[...,self.cyclic_prefix_length:]
    186 
    187         else:
    188             # Individual CP length for OFDM symbols
    189             x = tf.gather(inputs, self._ind, axis=-1)
    190 
    191         # Compute FFT
    192         x = fft(x)
    193 
    194         # Apply phase shift compensation to all subcarriers
    195         rot = tf.cast(self._phase_compensation, x.dtype)
    196         rot = expand_to_rank(rot, tf.rank(x), 0)
    197         x = x * rot
    198 
    199         # Shift DC subcarrier to the middle
    200         x = fftshift(x, axes=-1)
    201 
    202         return x