anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

pusch_pilot_pattern.py (4253B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """PUSCH pilot pattern for the nr (5G) sub-package of the Sionna library.
      6 """
      7 import warnings
      8 from collections.abc import Sequence
      9 import tensorflow as tf
     10 import numpy as np
     11 from sionna.ofdm import PilotPattern
     12 from .pusch_config import PUSCHConfig
     13 
     14 class PUSCHPilotPattern(PilotPattern):
     15     # pylint: disable=line-too-long
     16     r"""Class defining a pilot pattern for NR PUSCH.
     17 
     18     This class defines a :class:`~sionna.ofdm.PilotPattern`
     19     that is used to configure an OFDM :class:`~sionna.ofdm.ResourceGrid`.
     20 
     21     For every transmitter, a separte :class:`~sionna.nr.PUSCHConfig`
     22     needs to be provided from which the pilot pattern will be created.
     23 
     24     Parameters
     25     ----------
     26     pusch_configs : instance or list of :class:`~sionna.nr.PUSCHConfig`
     27         PUSCH Configurations according to which the pilot pattern
     28         will created. One configuration is needed for each transmitter.
     29 
     30     dtype : tf.Dtype
     31         Defines the datatype for internal calculations and the output
     32         dtype. Defaults to `tf.complex64`.
     33     """
     34     def __init__(self,
     35                  pusch_configs,
     36                  dtype=tf.complex64):
     37 
     38         # Check correct type of pusch_configs
     39         if isinstance(pusch_configs, PUSCHConfig):
     40             pusch_configs = [pusch_configs]
     41         elif isinstance(pusch_configs, Sequence):
     42             for c in pusch_configs:
     43                 assert isinstance(c, PUSCHConfig), \
     44                     "Each element of pusch_configs must be a valide PUSCHConfig"
     45         else:
     46             raise ValueError("Invalid value for pusch_configs")
     47 
     48         # Check validity of provided pusch_configs
     49         num_tx = len(pusch_configs)
     50         num_streams_per_tx = pusch_configs[0].num_layers
     51         dmrs_grid = pusch_configs[0].dmrs_grid
     52         num_subcarriers = dmrs_grid[0].shape[0]
     53         num_ofdm_symbols = pusch_configs[0].l_d
     54         precoding = pusch_configs[0].precoding
     55         dmrs_ports = []
     56         num_pilots = np.sum(pusch_configs[0].dmrs_mask)
     57         for pusch_config in pusch_configs:
     58             assert pusch_config.num_layers==num_streams_per_tx, \
     59                 "All pusch_configs must have the same number of layers"
     60             assert pusch_config.dmrs_grid[0].shape[0]==num_subcarriers, \
     61                 "All pusch_configs must have the same number of subcarriers"
     62             assert pusch_config.l_d==num_ofdm_symbols, \
     63                 "All pusch_configs must have the same number of OFDM symbols"
     64             assert pusch_config.precoding==precoding, \
     65                 "All pusch_configs must have a the same precoding method"
     66             assert np.sum(pusch_config.dmrs_mask)==num_pilots, \
     67                 "All pusch_configs must have a the same number of masked REs"
     68             with warnings.catch_warnings():
     69                 warnings.simplefilter('always')
     70                 for port in pusch_config.dmrs.dmrs_port_set:
     71                     if port in dmrs_ports:
     72                         msg = f"DMRS port {port} used by multiple transmitters"
     73                         warnings.warn(msg)
     74             dmrs_ports += pusch_config.dmrs.dmrs_port_set
     75 
     76         # Create mask and pilots tensors
     77         mask = np.zeros([num_tx,
     78                          num_streams_per_tx,
     79                          num_ofdm_symbols,
     80                          num_subcarriers], bool)
     81         num_pilots = np.sum(pusch_configs[0].dmrs_mask)
     82         pilots = np.zeros([num_tx, num_streams_per_tx, num_pilots], complex)
     83         for i, pusch_config in enumerate(pusch_configs):
     84             for j in range(num_streams_per_tx):
     85                 ind0, ind1 = pusch_config.symbol_allocation
     86                 mask[i,j] = np.transpose(
     87                                 pusch_config.dmrs_mask[:, ind0:ind0+ind1])
     88                 dmrs_grid = np.transpose(
     89                                 pusch_config.dmrs_grid[j, :, ind0:ind0+ind1])
     90                 pilots[i,j] = dmrs_grid[np.where(mask[i,j])]
     91 
     92         # Init PilotPattern class
     93         super().__init__(mask, pilots,
     94                          trainable=False,
     95                          normalize=False,
     96                          dtype=dtype)