anomaly-detection-material-parameters-calibration

Sionna parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

modulator.py (4423B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Class definition for the OFDM Modulator"""
      6 
      7 import tensorflow as tf
      8 from tensorflow.keras.layers import Layer
      9 from tensorflow.signal import ifftshift
     10 from sionna.utils import flatten_last_dims
     11 from sionna.signal import ifft
     12 
     13 
     14 class OFDMModulator(Layer):
     15     # pylint: disable=line-too-long
     16     """
     17     OFDMModulator(cyclic_prefix_length=0, **kwargs)
     18 
     19     Computes the time-domain representation of an OFDM resource grid
     20     with (optional) cyclic prefix
     21 
     22     Parameters
     23     ----------
     24     cyclic_prefix_length : scalar or [num_ofdm_symbols], int
     25         Integer or vector of integers indicating the length of the
     26         cyclic prefix that is prepended to each OFDM symbol. None of its
     27         elements can be larger than the FFT size.
     28         Defaults to 0.
     29 
     30     Input
     31     -----
     32     : [...,num_ofdm_symbols,fft_size], tf.complex
     33         Resource grid in the frequency domain
     34 
     35     Output
     36     ------
     37     : [...,num_ofdm_symbols*(fft_size+cyclic_prefix_length)] or [...,num_ofdm_symbols*fft_size+sum(cyclic_prefix_length)], tf.complex
     38         Time-domain OFDM signal
     39     """
     40 
     41     def __init__(self, cyclic_prefix_length=0, **kwargs):
     42         super().__init__(**kwargs)
     43         self._cyclic_prefix_length = None
     44         self.cyclic_prefix_length = cyclic_prefix_length
     45 
     46     @property
     47     def cyclic_prefix_length(self):
     48         """
     49         scalar or [num_ofdm_symbols], int : Get/set the cyclic prefix length
     50         """
     51         return self._cyclic_prefix_length
     52 
     53     @cyclic_prefix_length.setter
     54     def cyclic_prefix_length(self, value):
     55         value = tf.cast(value, tf.int32)
     56         if not tf.reduce_all(value>=0):
     57             msg = "`cyclic_prefix_length` must be nonnegative."
     58             raise ValueError(msg)
     59         if not 0<= tf.rank(value)<=1:
     60             msg = "`cyclic_prefix_length` must be of rank 0 or 1"
     61             raise ValueError(msg)
     62         self._cyclic_prefix_length = value
     63 
     64     def build(self, input_shape):
     65         num_ofdm_symbols, fft_size = input_shape[-2:]
     66         if not tf.reduce_all(self.cyclic_prefix_length<=fft_size):
     67             msg = "`cyclic_prefix_length` cannot be larger than `fft_size`."
     68             raise ValueError(msg)
     69         if len(self.cyclic_prefix_length.shape)==1:
     70             if not self.cyclic_prefix_length.shape[0]==num_ofdm_symbols:
     71                 msg = "`cyclic_prefix_length` must be of size [num_ofdm_symbols]"
     72                 raise ValueError(msg)
     73 
     74             # Compute indices of CP symbols
     75             # These are offset by the number of the OFDM symbol
     76             # [num_ofdm_symbols, 1]
     77             offsets = tf.expand_dims(tf.range(1, num_ofdm_symbols+1)*fft_size,
     78                                      1)
     79             # [num_ofdm_symbols, None] (ragged tensor)
     80             cp_ind = tf.ragged.range(starts=-self.cyclic_prefix_length,
     81                                      limits=0) + offsets
     82 
     83             # Compute indices of symbols containing the actual sequence
     84             # [num_ofdm_symbols, fft_size]
     85             data_ind = tf.repeat(tf.expand_dims(tf.range(0, fft_size), 0),
     86                                  num_ofdm_symbols, 0) + offsets - fft_size
     87 
     88             # Concat CP and sequence indices
     89             # [num_ofdm_symbols, None]
     90             ind = tf.concat([cp_ind, data_ind], axis=-1)
     91 
     92             # Flatten in time domain
     93             # [num_ofdm_symbols *fft_size + sum(cyclic_prefix_length)]
     94             self._ind = ind.flat_values
     95 
     96     def call(self, inputs):
     97 
     98         # Shift DC subcarrier to first position
     99         x_freq = ifftshift(inputs, axes=-1)
    100 
    101         # Compute IFFT along the last dimension
    102         x_time = ifft(x_freq)
    103 
    104         if len(self.cyclic_prefix_length.shape)==1:
    105             # Individual CP length per OFDM symbol
    106 
    107             # Flatten last two dimensions
    108             x_time = flatten_last_dims(x_time, 2)
    109 
    110             # Gather full time-domain signal
    111             return tf.gather(x_time, self._ind, axis=-1)
    112 
    113         else:
    114             # Same CP length for all OFDM symbols
    115 
    116             # Obtain cyclic prefix
    117             cp = x_time[...,tf.shape(x_time)[-1]-self._cyclic_prefix_length:]
    118 
    119             # Prepend cyclic prefix
    120             x_time = tf.concat([cp, x_time], -1)
    121 
    122             # Serialize last two dimensions
    123             return  flatten_last_dims(x_time, 2)