anomaly-detection-material-parameters-calibration

Sionna parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

pilot_pattern.py (13332B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Class definition and functions related to pilot patterns"""
      6 
      7 import tensorflow as tf
      8 import numpy as np
      9 import matplotlib.pyplot as plt
     10 from matplotlib import colors
     11 from sionna.utils import QAMSource
     12 
     13 
     14 class PilotPattern():
     15     # pylint: disable=line-too-long
     16     r"""Class defining a pilot pattern for an OFDM ResourceGrid.
     17 
     18     This class defines a pilot pattern object that is used to configure
     19     an OFDM :class:`~sionna.ofdm.ResourceGrid`.
     20 
     21     Parameters
     22     ----------
     23     mask : [num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], bool
     24         Tensor indicating resource elements that are reserved for pilot transmissions.
     25 
     26     pilots : [num_tx, num_streams_per_tx, num_pilots], tf.complex
     27         The pilot symbols to be mapped onto the ``mask``.
     28 
     29     trainable : bool
     30         Indicates if ``pilots`` is a trainable `Variable`.
     31         Defaults to `False`.
     32 
     33     normalize : bool
     34         Indicates if the ``pilots`` should be normalized to an average
     35         energy of one across the last dimension. This can be useful to
     36         ensure that trainable ``pilots`` have a finite energy.
     37         Defaults to `False`.
     38 
     39     dtype : tf.Dtype
     40         Defines the datatype for internal calculations and the output
     41         dtype. Defaults to `tf.complex64`.
     42     """
     43     def __init__(self, mask, pilots, trainable=False, normalize=False,
     44                  dtype=tf.complex64):
     45         super().__init__()
     46         self._dtype = dtype
     47         self._mask = tf.cast(mask, tf.int32)
     48         self._pilots = tf.Variable(tf.cast(pilots, self._dtype), trainable)
     49         self.normalize = normalize
     50         self._check_settings()
     51 
     52     @property
     53     def num_tx(self):
     54         """Number of transmitters"""
     55         return self._mask.shape[0]
     56 
     57     @property
     58     def num_streams_per_tx(self):
     59         """Number of streams per transmitter"""
     60         return self._mask.shape[1]
     61 
     62     @ property
     63     def num_ofdm_symbols(self):
     64         """Number of OFDM symbols"""
     65         return self._mask.shape[2]
     66 
     67     @ property
     68     def num_effective_subcarriers(self):
     69         """Number of effectvie subcarriers"""
     70         return self._mask.shape[3]
     71 
     72     @property
     73     def num_pilot_symbols(self):
     74         """Number of pilot symbols per transmit stream."""
     75         return tf.shape(self._pilots)[-1]
     76 
     77     @property
     78     def num_data_symbols(self):
     79         """ Number of data symbols per transmit stream."""
     80         return tf.shape(self._mask)[-1]*tf.shape(self._mask)[-2] - \
     81                self.num_pilot_symbols
     82 
     83     @property
     84     def normalize(self):
     85         """Returns or sets the flag indicating if the pilots
     86            are normalized or not
     87         """
     88         return self._normalize
     89 
     90     @normalize.setter
     91     def normalize(self, value):
     92         self._normalize = tf.cast(value, tf.bool)
     93 
     94     @property
     95     def mask(self):
     96         """Mask of the pilot pattern"""
     97         return self._mask
     98 
     99     @property
    100     def pilots(self):
    101         """Returns or sets the possibly normalized tensor of pilot symbols.
    102            If pilots are normalized, the normalization will be applied
    103            after new values for pilots have been set. If this is
    104            not the desired behavior, turn normalization off.
    105         """
    106         def norm_pilots():
    107             scale = tf.abs(self._pilots)**2
    108             scale = 1/tf.sqrt(tf.reduce_mean(scale, axis=-1, keepdims=True))
    109             scale = tf.cast(scale, self._dtype)
    110             return scale*self._pilots
    111 
    112         return tf.cond(self.normalize, norm_pilots, lambda: self._pilots)
    113 
    114     @pilots.setter
    115     def pilots(self, value):
    116         self._pilots.assign(value)
    117 
    118     def _check_settings(self):
    119         """Validate that all properties define a valid pilot pattern."""
    120 
    121         assert tf.rank(self._mask)==4, "`mask` must have four dimensions."
    122         assert tf.rank(self._pilots)==3, "`pilots` must have three dimensions."
    123         assert np.array_equal(self._mask.shape[:2], self._pilots.shape[:2]), \
    124             "The first two dimensions of `mask` and `pilots` must be equal."
    125 
    126         num_pilots = tf.reduce_sum(self._mask, axis=(-2,-1))
    127         assert tf.reduce_min(num_pilots)==tf.reduce_max(num_pilots), \
    128             """The number of nonzero elements in the masks for all transmitters
    129             and streams must be identical."""
    130 
    131         assert self.num_pilot_symbols==tf.reduce_max(num_pilots), \
    132             """The shape of the last dimension of `pilots` must equal
    133             the number of non-zero entries within the last two
    134             dimensions of `mask`."""
    135 
    136         return True
    137 
    138     @property
    139     def trainable(self):
    140         """Returns if pilots are trainable or not"""
    141         return self._pilots.trainable
    142 
    143 
    144     def show(self, tx_ind=None, stream_ind=None, show_pilot_ind=False):
    145         """Visualizes the pilot patterns for some transmitters and streams.
    146 
    147         Input
    148         -----
    149         tx_ind : list, int
    150             Indicates the indices of transmitters to be included.
    151             Defaults to `None`, i.e., all transmitters included.
    152 
    153         stream_ind : list, int
    154             Indicates the indices of streams to be included.
    155             Defaults to `None`, i.e., all streams included.
    156 
    157         show_pilot_ind : bool
    158             Indicates if the indices of the pilot symbols should be shown.
    159 
    160         Output
    161         ------
    162         list : matplotlib.figure.Figure
    163             List of matplot figure objects showing each the pilot pattern
    164             from a specific transmitter and stream.
    165         """
    166         mask = self.mask.numpy()
    167         pilots = self.pilots.numpy()
    168 
    169         if tx_ind is None:
    170             tx_ind = range(0, self.num_tx)
    171         elif not isinstance(tx_ind, list):
    172             tx_ind = [tx_ind]
    173 
    174         if stream_ind is None:
    175             stream_ind = range(0, self.num_streams_per_tx)
    176         elif not isinstance(stream_ind, list):
    177             stream_ind = [stream_ind]
    178 
    179         figs = []
    180         for i in tx_ind:
    181             for j in stream_ind:
    182                 q = np.zeros_like(mask[0,0])
    183                 q[np.where(mask[i,j])] = (np.abs(pilots[i,j])==0) + 1
    184                 legend = ["Data", "Pilots", "Masked"]
    185                 fig = plt.figure()
    186                 plt.title(f"TX {i} - Stream {j}")
    187                 plt.xlabel("OFDM Symbol")
    188                 plt.ylabel("Subcarrier Index")
    189                 plt.xticks(range(0, q.shape[1]))
    190                 cmap = plt.cm.tab20c
    191                 b = np.arange(0, 4)
    192                 norm = colors.BoundaryNorm(b, cmap.N)
    193                 im = plt.imshow(np.transpose(q), origin="lower", aspect="auto", norm=norm, cmap=cmap)
    194                 cbar = plt.colorbar(im)
    195                 cbar.set_ticks(b[:-1]+0.5)
    196                 cbar.set_ticklabels(legend)
    197 
    198                 if show_pilot_ind:
    199                     c = 0
    200                     for t in range(self.num_ofdm_symbols):
    201                         for k in range(self.num_effective_subcarriers):
    202                             if mask[i,j][t,k]:
    203                                 if np.abs(pilots[i,j,c])>0:
    204                                     plt.annotate(c, [t, k])
    205                                 c+=1
    206                 figs.append(fig)
    207 
    208         return figs
    209 
    210 class EmptyPilotPattern(PilotPattern):
    211     """Creates an empty pilot pattern.
    212 
    213     Generates a instance of :class:`~sionna.ofdm.PilotPattern` with
    214     an empty ``mask`` and ``pilots``.
    215 
    216     Parameters
    217     ----------
    218     num_tx : int
    219         Number of transmitters.
    220 
    221     num_streams_per_tx : int
    222         Number of streams per transmitter.
    223 
    224     num_ofdm_symbols : int
    225         Number of OFDM symbols.
    226 
    227     num_effective_subcarriers : int
    228         Number of effective subcarriers
    229         that are available for the transmission of data and pilots.
    230         Note that this number is generally smaller than the ``fft_size``
    231         due to nulled subcarriers.
    232 
    233     dtype : tf.Dtype
    234         Defines the datatype for internal calculations and the output
    235         dtype. Defaults to `tf.complex64`.
    236     """
    237     def __init__(self,
    238                  num_tx,
    239                  num_streams_per_tx,
    240                  num_ofdm_symbols,
    241                  num_effective_subcarriers,
    242                  dtype=tf.complex64):
    243 
    244         assert num_tx > 0, \
    245             "`num_tx` must be positive`."
    246         assert num_streams_per_tx > 0, \
    247             "`num_streams_per_tx` must be positive`."
    248         assert num_ofdm_symbols > 0, \
    249             "`num_ofdm_symbols` must be positive`."
    250         assert num_effective_subcarriers > 0, \
    251             "`num_effective_subcarriers` must be positive`."
    252 
    253         shape = [num_tx, num_streams_per_tx, num_ofdm_symbols,
    254                       num_effective_subcarriers]
    255         mask = tf.zeros(shape, tf.bool)
    256         pilots = tf.zeros(shape[:2]+[0], dtype)
    257         super().__init__(mask, pilots, trainable=False, normalize=False,
    258                          dtype=dtype)
    259 
    260 class KroneckerPilotPattern(PilotPattern):
    261     """Simple orthogonal pilot pattern with Kronecker structure.
    262 
    263     This function generates an instance of :class:`~sionna.ofdm.PilotPattern`
    264     that allocates non-overlapping pilot sequences for all transmitters and
    265     streams on specified OFDM symbols. As the same pilot sequences are reused
    266     across those OFDM symbols, the resulting pilot pattern has a frequency-time
    267     Kronecker structure. This structure enables a very efficient implementation
    268     of the LMMSE channel estimator. Each pilot sequence is constructed from
    269     randomly drawn QPSK constellation points.
    270 
    271     Parameters
    272     ----------
    273     resource_grid : ResourceGrid
    274         An instance of a :class:`~sionna.ofdm.ResourceGrid`.
    275 
    276     pilot_ofdm_symbol_indices : list, int
    277         List of integers defining the OFDM symbol indices that are reserved
    278         for pilots.
    279 
    280     normalize : bool
    281         Indicates if the ``pilots`` should be normalized to an average
    282         energy of one across the last dimension.
    283         Defaults to `True`.
    284 
    285     seed : int
    286         Seed for the generation of the pilot sequence. Different seed values
    287         lead to different sequences. Defaults to 0.
    288 
    289     dtype : tf.Dtype
    290         Defines the datatype for internal calculations and the output
    291         dtype. Defaults to `tf.complex64`.
    292 
    293     Note
    294     ----
    295     It is required that the ``resource_grid``'s property
    296     ``num_effective_subcarriers`` is an
    297     integer multiple of ``num_tx * num_streams_per_tx``. This condition is
    298     required to ensure that all transmitters and streams get
    299     non-overlapping pilot sequences. For a large number of streams and/or
    300     transmitters, the pilot pattern becomes very sparse in the frequency
    301     domain.
    302 
    303     Examples
    304     --------
    305     >>> rg = ResourceGrid(num_ofdm_symbols=14,
    306     ...                   fft_size=64,
    307     ...                   subcarrier_spacing = 30e3,
    308     ...                   num_tx=4,
    309     ...                   num_streams_per_tx=2,
    310     ...                   pilot_pattern = "kronecker",
    311     ...                   pilot_ofdm_symbol_indices = [2, 11])
    312     >>> rg.pilot_pattern.show();
    313 
    314     .. image:: ../figures/kronecker_pilot_pattern.png
    315 
    316     """
    317     def __init__(self,
    318                  resource_grid,
    319                  pilot_ofdm_symbol_indices,
    320                  normalize=True,
    321                  seed=0,
    322                  dtype=tf.complex64):
    323 
    324         num_tx = resource_grid.num_tx
    325         num_streams_per_tx = resource_grid.num_streams_per_tx
    326         num_ofdm_symbols = resource_grid.num_ofdm_symbols
    327         num_effective_subcarriers = resource_grid.num_effective_subcarriers
    328         self._dtype = dtype
    329 
    330         # Number of OFDM symbols carrying pilots
    331         num_pilot_symbols = len(pilot_ofdm_symbol_indices)
    332 
    333         # Compute the total number of required orthogonal sequences
    334         num_seq = num_tx*num_streams_per_tx
    335 
    336         # Compute the length of a pilot sequence
    337         num_pilots = num_pilot_symbols*num_effective_subcarriers/num_seq
    338         assert (num_pilots/num_pilot_symbols)%1==0, \
    339             """`num_effective_subcarriers` must be an integer multiple of
    340             `num_tx`*`num_streams_per_tx`."""
    341 
    342         # Number of pilots per OFDM symbol
    343         num_pilots_per_symbol = int(num_pilots/num_pilot_symbols)
    344 
    345         # Prepare empty mask and pilots
    346         shape = [num_tx, num_streams_per_tx,
    347                  num_ofdm_symbols,num_effective_subcarriers]
    348         mask = np.zeros(shape, bool)
    349         shape[2] = num_pilot_symbols
    350         pilots = np.zeros(shape, np.complex64)
    351 
    352         # Populate all selected OFDM symbols in the mask
    353         mask[..., pilot_ofdm_symbol_indices, :] = True
    354 
    355         # Populate the pilots with random QPSK symbols
    356         qam_source = QAMSource(2, seed=seed, dtype=self._dtype)
    357         for i in range(num_tx):
    358             for j in range(num_streams_per_tx):
    359                 # Generate random QPSK symbols
    360                 p = qam_source([1,1,num_pilot_symbols,num_pilots_per_symbol])
    361 
    362                 # Place pilots spaced by num_seq to avoid overlap
    363                 pilots[i,j,:,i*num_streams_per_tx+j::num_seq] = p
    364 
    365         # Reshape the pilots tensor
    366         pilots = np.reshape(pilots, [num_tx, num_streams_per_tx, -1])
    367 
    368         super().__init__(mask, pilots, trainable=False,
    369                          normalize=normalize, dtype=self._dtype)