anomaly-detection-material-parameters-calibration

Sionna parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

resource_grid.py (19504B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Class definition and functions related to the resource grid"""
      6 
      7 import tensorflow as tf
      8 import numpy as np
      9 from tensorflow.keras.layers import Layer
     10 
     11 from .pilot_pattern import PilotPattern, EmptyPilotPattern, KroneckerPilotPattern # pylint: disable=line-too-long
     12 from sionna.utils import flatten_last_dims, flatten_dims, split_dim
     13 import matplotlib.pyplot as plt
     14 from matplotlib import colors
     15 
     16 
     17 class ResourceGrid():
     18     # pylint: disable=line-too-long
     19     r"""Defines a `ResourceGrid` spanning multiple OFDM symbols and subcarriers.
     20 
     21     Parameters
     22     ----------
     23         num_ofdm_symbols : int
     24             Number of OFDM symbols.
     25 
     26         fft_size : int
     27             FFT size (, i.e., the number of subcarriers).
     28 
     29         subcarrier_spacing : float
     30             The subcarrier spacing in Hz.
     31 
     32         num_tx : int
     33             Number of transmitters.
     34 
     35         num_streams_per_tx : int
     36             Number of streams per transmitter.
     37 
     38         cyclic_prefix_length : int
     39             Length of the cyclic prefix.
     40 
     41         num_guard_carriers : int
     42             List of two integers defining the number of guardcarriers at the
     43             left and right side of the resource grid.
     44 
     45         dc_null : bool
     46             Indicates if the DC carrier is nulled or not.
     47 
     48         pilot_pattern : One of [None, "kronecker", "empty", PilotPattern]
     49             An instance of :class:`~sionna.ofdm.PilotPattern`, a string
     50             shorthand for the :class:`~sionna.ofdm.KroneckerPilotPattern`
     51             or :class:`~sionna.ofdm.EmptyPilotPattern`, or `None`.
     52             Defaults to `None` which is equivalent to `"empty"`.
     53 
     54         pilot_ofdm_symbol_indices : List, int
     55             List of indices of OFDM symbols reserved for pilot transmissions.
     56             Only needed if ``pilot_pattern="kronecker"``. Defaults to `None`.
     57 
     58         dtype : tf.Dtype
     59             Defines the datatype for internal calculations and the output
     60             dtype. Defaults to `tf.complex64`.
     61     """
     62     def __init__(self,
     63                  num_ofdm_symbols,
     64                  fft_size,
     65                  subcarrier_spacing,
     66                  num_tx=1,
     67                  num_streams_per_tx=1,
     68                  cyclic_prefix_length=0,
     69                  num_guard_carriers=(0,0),
     70                  dc_null=False,
     71                  pilot_pattern=None,
     72                  pilot_ofdm_symbol_indices=None,
     73                  dtype=tf.complex64):
     74         super().__init__()
     75         self._dtype = dtype
     76         self._num_ofdm_symbols = num_ofdm_symbols
     77         self._fft_size = fft_size
     78         self._subcarrier_spacing = subcarrier_spacing
     79         self._cyclic_prefix_length = int(cyclic_prefix_length)
     80         self._num_tx = num_tx
     81         self._num_streams_per_tx = num_streams_per_tx
     82         self._num_guard_carriers = np.array(num_guard_carriers)
     83         self._dc_null = dc_null
     84         self._pilot_ofdm_symbol_indices = pilot_ofdm_symbol_indices
     85         self.pilot_pattern = pilot_pattern
     86         self._check_settings()
     87 
     88     @property
     89     def cyclic_prefix_length(self):
     90         """Length of the cyclic prefix."""
     91         return self._cyclic_prefix_length
     92 
     93     @property
     94     def num_tx(self):
     95         """Number of transmitters."""
     96         return self._num_tx
     97 
     98     @property
     99     def num_streams_per_tx(self):
    100         """Number of streams  per transmitter."""
    101         return self._num_streams_per_tx
    102 
    103     @property
    104     def num_ofdm_symbols(self):
    105         """The number of OFDM symbols of the resource grid."""
    106         return self._num_ofdm_symbols
    107 
    108     @property
    109     def num_resource_elements(self):
    110         """Number of resource elements."""
    111         return self._fft_size*self._num_ofdm_symbols
    112 
    113     @property
    114     def num_effective_subcarriers(self):
    115         """Number of subcarriers used for data and pilot transmissions."""
    116         n = self._fft_size - self._dc_null - np.sum(self._num_guard_carriers)
    117         return n
    118 
    119     @property
    120     def effective_subcarrier_ind(self):
    121         """Returns the indices of the effective subcarriers."""
    122         num_gc = self._num_guard_carriers
    123         sc_ind = range(num_gc[0], self.fft_size-num_gc[1])
    124         if self.dc_null:
    125             sc_ind = np.delete(sc_ind, self.dc_ind-num_gc[0])
    126         return sc_ind
    127 
    128     @property
    129     def num_data_symbols(self):
    130         """Number of resource elements used for data transmissions."""
    131         n = self.num_effective_subcarriers * self._num_ofdm_symbols - \
    132                self.num_pilot_symbols
    133         return tf.cast(n, tf.int32)
    134 
    135     @property
    136     def num_pilot_symbols(self):
    137         """Number of resource elements used for pilot symbols."""
    138         return self.pilot_pattern.num_pilot_symbols
    139 
    140     @property
    141     def num_zero_symbols(self):
    142         """Number of empty resource elements."""
    143         n = (self._fft_size-self.num_effective_subcarriers) * \
    144                self._num_ofdm_symbols
    145         return tf.cast(n, tf.int32)
    146 
    147     @property
    148     def num_guard_carriers(self):
    149         """Number of left and right guard carriers."""
    150         return self._num_guard_carriers
    151 
    152     @property
    153     def dc_ind(self):
    154         """Index of the DC subcarrier.
    155 
    156         If ``fft_size`` is odd, the index is (``fft_size``-1)/2.
    157         If ``fft_size`` is even, the index is ``fft_size``/2.
    158         """
    159         return int(self._fft_size/2 - (self._fft_size%2==1)/2)
    160 
    161     @property
    162     def fft_size(self):
    163         """The FFT size."""
    164         return self._fft_size
    165 
    166     @property
    167     def subcarrier_spacing(self):
    168         """The subcarrier spacing [Hz]."""
    169         return self._subcarrier_spacing
    170 
    171     @property
    172     def ofdm_symbol_duration(self):
    173         """Duration of an OFDM symbol with cyclic prefix [s]."""
    174         return (1. + self.cyclic_prefix_length/self.fft_size) \
    175                 / self.subcarrier_spacing
    176 
    177     @property
    178     def bandwidth(self):
    179         """The occupied bandwidth [Hz]: ``fft_size*subcarrier_spacing``."""
    180         return self.fft_size*self.subcarrier_spacing
    181 
    182     @property
    183     def num_time_samples(self):
    184         """The number of time-domain samples occupied by the resource grid."""
    185         return (self.fft_size + self.cyclic_prefix_length) \
    186                 * self._num_ofdm_symbols
    187 
    188     @property
    189     def dc_null(self):
    190         """Indicates if the DC carriers is nulled or not."""
    191         return self._dc_null
    192 
    193     @property
    194     def pilot_pattern(self):
    195         """The used PilotPattern."""
    196         return self._pilot_pattern
    197 
    198     @pilot_pattern.setter
    199     def pilot_pattern(self, value):
    200         if value is None:
    201             value = EmptyPilotPattern(self._num_tx,
    202                                       self._num_streams_per_tx,
    203                                       self._num_ofdm_symbols,
    204                                       self.num_effective_subcarriers,
    205                                       dtype=self._dtype)
    206         elif isinstance(value, PilotPattern):
    207             pass
    208         elif isinstance(value, str):
    209             assert value in ["kronecker", "empty"],\
    210                 "Unknown pilot pattern"
    211             if value=="empty":
    212                 value = EmptyPilotPattern(self._num_tx,
    213                                       self._num_streams_per_tx,
    214                                       self._num_ofdm_symbols,
    215                                       self.num_effective_subcarriers,
    216                                       dtype=self._dtype)
    217             elif value=="kronecker":
    218                 assert self._pilot_ofdm_symbol_indices is not None,\
    219                     "You must provide pilot_ofdm_symbol_indices."
    220                 value = KroneckerPilotPattern(self,
    221                         self._pilot_ofdm_symbol_indices, dtype=self._dtype)
    222         else:
    223             raise ValueError("Unsupported pilot_pattern")
    224         self._pilot_pattern = value
    225 
    226     def _check_settings(self):
    227         """Validate that all properties define a valid resource grid"""
    228         assert self._num_ofdm_symbols > 0, \
    229             "`num_ofdm_symbols` must be positive`."
    230         assert self._fft_size > 0, \
    231             "`fft_size` must be positive`."
    232         assert self._cyclic_prefix_length>=0, \
    233             "`cyclic_prefix_length must be nonnegative."
    234         assert self._cyclic_prefix_length<=self._fft_size, \
    235             "`cyclic_prefix_length cannot be longer than `fft_size`."
    236         assert self._num_tx > 0, \
    237             "`num_tx` must be positive`."
    238         assert self._num_streams_per_tx > 0, \
    239             "`num_streams_per_tx` must be positive`."
    240         assert len(self._num_guard_carriers)==2, \
    241             "`num_guard_carriers` must have two elements."
    242         assert np.all(np.greater_equal(self._num_guard_carriers, 0)), \
    243             "`num_guard_carriers` must have nonnegative entries."
    244         assert np.sum(self._num_guard_carriers)<=self._fft_size-self._dc_null,\
    245             "Total number of guardcarriers cannot be larger than `fft_size`."
    246         assert self._dtype in [tf.complex64, tf.complex128], \
    247             "dtype must be tf.complex64 or tf.complex128"
    248         return True
    249 
    250     def build_type_grid(self):
    251         """Returns a tensor indicating the type of each resource element.
    252 
    253         Resource elements can be one of
    254 
    255         - 0 : Data symbol
    256         - 1 : Pilot symbol
    257         - 2 : Guard carrier symbol
    258         - 3 : DC carrier symbol
    259 
    260         Output
    261         ------
    262         : [num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size], tf.int32
    263             Tensor indicating for each transmitter and stream the type of
    264             the resource elements of the corresponding resource grid.
    265             The type can be one of [0,1,2,3] as explained above.
    266         """
    267         shape = [self._num_tx, self._num_streams_per_tx, self._num_ofdm_symbols]
    268         gc_l = 2*tf.ones(shape+[self._num_guard_carriers[0]], tf.int32)
    269         gc_r = 2*tf.ones(shape+[self._num_guard_carriers[1]], tf.int32)
    270         dc   = 3*tf.ones(shape + [tf.cast(self._dc_null, tf.int32)], tf.int32)
    271         mask = self.pilot_pattern.mask
    272         split_ind = self.dc_ind-self._num_guard_carriers[0]
    273         rg_type = tf.concat([gc_l,                 # Left Guards
    274                              mask[...,:split_ind], # Data & pilots
    275                              dc,                   # DC
    276                              mask[...,split_ind:], # Data & pilots
    277                              gc_r], -1)            # Right guards
    278         return rg_type
    279 
    280     def show(self, tx_ind=0, tx_stream_ind=0):
    281         """Visualizes the resource grid for a specific transmitter and stream.
    282 
    283         Input
    284         -----
    285         tx_ind : int
    286             Indicates the transmitter index.
    287 
    288         tx_stream_ind : int
    289             Indicates the index of the stream.
    290 
    291         Output
    292         ------
    293         : `matplotlib.figure`
    294             A handle to a matplot figure object.
    295         """
    296         fig = plt.figure()
    297         data = self.build_type_grid()[tx_ind, tx_stream_ind]
    298         cmap = colors.ListedColormap([[60/256,8/256,72/256],
    299                               [45/256,91/256,128/256],
    300                               [45/256,172/256,111/256],
    301                               [250/256,228/256,62/256]])
    302         bounds=[0,1,2,3,4]
    303         norm = colors.BoundaryNorm(bounds, cmap.N)
    304         img = plt.imshow(np.transpose(data), interpolation="nearest",
    305                          origin="lower", cmap=cmap, norm=norm,
    306                          aspect="auto")
    307         cbar = plt.colorbar(img, ticks=[0.5, 1.5, 2.5,3.5],
    308                             orientation="vertical", shrink=0.8)
    309         cbar.set_ticklabels(["Data", "Pilot", "Guard carrier", "DC carrier"])
    310         plt.title("OFDM Resource Grid")
    311         plt.ylabel("Subcarrier Index")
    312         plt.xlabel("OFDM Symbol")
    313         plt.xticks(range(0, data.shape[0]))
    314 
    315         return fig
    316 
    317 class ResourceGridMapper(Layer):
    318     # pylint: disable=line-too-long
    319     r"""ResourceGridMapper(resource_grid, dtype=tf.complex64, **kwargs)
    320 
    321     Maps a tensor of modulated data symbols to a ResourceGrid.
    322 
    323     This layer takes as input a tensor of modulated data symbols
    324     and maps them together with pilot symbols onto an
    325     OFDM :class:`~sionna.ofdm.ResourceGrid`. The output can be
    326     converted to a time-domain signal with the
    327     :class:`~sionna.ofdm.Modulator` or further processed in the
    328     frequency domain.
    329 
    330     Parameters
    331     ----------
    332     resource_grid : ResourceGrid
    333         An instance of :class:`~sionna.ofdm.ResourceGrid`.
    334 
    335     dtype : tf.Dtype
    336         Datatype for internal calculations and the output dtype.
    337         Defaults to `tf.complex64`.
    338 
    339     Input
    340     -----
    341     : [batch_size, num_tx, num_streams_per_tx, num_data_symbols], tf.complex
    342         The modulated data symbols to be mapped onto the resource grid.
    343 
    344     Output
    345     ------
    346     : [batch_size, num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size], tf.complex
    347         The full OFDM resource grid in the frequency domain.
    348     """
    349     def __init__(self, resource_grid, dtype=tf.complex64, **kwargs):
    350         super().__init__(dtype=dtype, **kwargs)
    351         self._resource_grid = resource_grid
    352 
    353     def build(self, input_shape): # pylint: disable=unused-argument
    354         """Precompute a tensor of shape
    355         [num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size]
    356         which is prefilled with pilots and stores indices
    357         to scatter data symbols.
    358         """
    359         self._rg_type = self._resource_grid.build_type_grid()
    360         self._pilot_ind = tf.where(self._rg_type==1)
    361         self._data_ind = tf.where(self._rg_type==0)
    362 
    363     def call(self, inputs):
    364         # Map pilots on empty resource grid
    365         pilots = flatten_last_dims(self._resource_grid.pilot_pattern.pilots, 3)
    366         template = tf.scatter_nd(self._pilot_ind,
    367                                  pilots,
    368                                  self._rg_type.shape)
    369         template = tf.expand_dims(template, -1)
    370 
    371         # Broadcast the resource grid template to batch_size
    372         batch_size = tf.shape(inputs)[0]
    373         new_shape = tf.concat([tf.shape(template)[:-1], [batch_size]], 0)
    374         template = tf.broadcast_to(template, new_shape)
    375 
    376         # Flatten the inputs and put batch_dim last for scatter update
    377         inputs = tf.transpose(flatten_last_dims(inputs, 3))
    378         rg = tf.tensor_scatter_nd_update(template, self._data_ind, inputs)
    379         rg = tf.transpose(rg, [4, 0, 1, 2, 3])
    380 
    381         return rg
    382 
    383 class ResourceGridDemapper(Layer):
    384     # pylint: disable=line-too-long
    385     r"""ResourceGridDemapper(resource_grid, stream_management, dtype=tf.complex64, **kwargs)
    386 
    387     Extracts data-carrying resource elements from a resource grid.
    388 
    389     This layer takes as input an OFDM :class:`~sionna.ofdm.ResourceGrid` and
    390     extracts the data-carrying resource elements. In other words, it implements
    391     the reverse operation of :class:`~sionna.ofdm.ResourceGridMapper`.
    392 
    393     Parameters
    394     ----------
    395     resource_grid : ResourceGrid
    396         An instance of :class:`~sionna.ofdm.ResourceGrid`.
    397 
    398     stream_management : StreamManagement
    399         An instance of :class:`~sionna.mimo.StreamManagement`.
    400 
    401     dtype : tf.Dtype
    402         Datatype for internal calculations and the output dtype.
    403         Defaults to `tf.complex64`.
    404 
    405     Input
    406     -----
    407     : [batch_size, num_rx, num_streams_per_rx, num_ofdm_symbols, fft_size, data_dim]
    408         The full OFDM resource grid in the frequency domain.
    409         The last dimension `data_dim` is optional. If `data_dim`
    410         is used, it refers to the dimensionality of the data that should be
    411         demapped to individual streams. An example would be LLRs.
    412 
    413     Output
    414     ------
    415     : [batch_size, num_rx, num_streams_per_rx, num_data_symbols, data_dim]
    416         The data that were mapped into the resource grid.
    417         The last dimension `data_dim` is only returned if it was used for the
    418         input.
    419     """
    420     def __init__(self,
    421                  resource_grid,
    422                  stream_management,
    423                  dtype=tf.complex64,
    424                  **kwargs):
    425         super().__init__(dtype=dtype, **kwargs)
    426         self._stream_management = stream_management
    427         self._resource_grid = resource_grid
    428 
    429         # Precompute indices to extract data symbols
    430         mask = resource_grid.pilot_pattern.mask
    431         num_data_symbols = resource_grid.pilot_pattern.num_data_symbols
    432         data_ind = tf.argsort(flatten_last_dims(mask), direction="ASCENDING")
    433         self._data_ind = data_ind[...,:num_data_symbols]
    434 
    435     def call(self, y): # pylint: disable=arguments-renamed
    436 
    437         # y has shape
    438         # [batch_size, num_rx, num_streams_per_rx, num_ofdm_symbols,...
    439         # ..., fft_size, data_dim]
    440 
    441         # If data_dim is not provided, add a dummy dimension
    442         if len(y.shape)==5:
    443             y = tf.expand_dims(y, -1)
    444 
    445         # Remove nulled subcarriers from y (guards, dc). New shape:
    446         # [batch_size, num_rx, num_rx_ant, ...
    447         #  ..., num_ofdm_symbols, num_effective_subcarriers, data dim]
    448         y = tf.gather(y, self._resource_grid.effective_subcarrier_ind, axis=-2)
    449 
    450         # Transpose tensor to shape
    451         # [num_rx, num_streams_per_rx, num_ofdm_symbols,...
    452         #  ..., num_effective_subcarriers, data_dim, batch_size]
    453         y = tf.transpose(y, [1, 2, 3, 4, 5, 0])
    454 
    455         # Merge num_rx amd num_streams_per_rx
    456         # [num_rx * num_streams_per_rx, num_ofdm_symbols,...
    457         #  ...,num_effective_subcarriers, data_dim, batch_size]
    458         y = flatten_dims(y, 2, 0)
    459 
    460         # Put first dimension into the right ordering
    461         stream_ind = self._stream_management.stream_ind
    462         y = tf.gather(y, stream_ind, axis=0)
    463 
    464         # Reshape first dimensions to [num_tx, num_streams] so that
    465         # we can compared to the way the streams were created.
    466         # [num_tx, num_streams, num_ofdm_symbols, num_effective_subcarriers,...
    467         #  ..., data_dim, batch_size]
    468         num_streams = self._stream_management.num_streams_per_tx
    469         num_tx = self._stream_management.num_tx
    470         y = split_dim(y, [num_tx, num_streams], 0)
    471 
    472         # Flatten resource grid dimensions
    473         # [num_tx, num_streams, num_ofdm_symbols*num_effective_subcarriers,...
    474         #  ..., data_dim, batch_size]
    475         y = flatten_dims(y, 2, 2)
    476 
    477         # Gather data symbols
    478         # [num_tx, num_streams, num_data_symbols, data_dim, batch_size]
    479         y = tf.gather(y, self._data_ind, batch_dims=2, axis=2)
    480 
    481         # Put batch_dim first
    482         # [batch_size, num_tx, num_streams, num_data_symbols]
    483         y = tf.transpose(y, [4, 0, 1, 2, 3])
    484 
    485         # Squeeze data_dim
    486         if y.shape[-1]==1:
    487             y = tf.squeeze(y, -1)
    488 
    489         return y
    490 
    491 class RemoveNulledSubcarriers(Layer):
    492     # pylint: disable=line-too-long
    493     r"""RemoveNulledSubcarriers(resource_grid, **kwargs)
    494 
    495     Removes nulled guard and/or DC subcarriers from a resource grid.
    496 
    497     Parameters
    498     ----------
    499     resource_grid : ResourceGrid
    500         An instance of :class:`~sionna.ofdm.ResourceGrid`.
    501 
    502     Input
    503     -----
    504     : [batch_size, num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size], tf.complex64
    505         Full resource grid.
    506 
    507     Output
    508     ------
    509     : [batch_size, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex64
    510         Resource grid without nulled subcarriers.
    511     """
    512     def __init__(self, resource_grid, **kwargs):
    513         self._sc_ind = resource_grid.effective_subcarrier_ind
    514         super().__init__(**kwargs)
    515 
    516     def call(self, inputs):
    517         return tf.gather(inputs, self._sc_ind, axis=-1)