anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

pusch_channel_estimation.py (7502B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """PUSCH Channel Estimation for the nr (5G) sub-package of the Sionna library.
      6 """
      7 import tensorflow as tf
      8 from tensorflow.keras.layers import Layer
      9 from sionna.ofdm import LSChannelEstimator
     10 from sionna.utils import expand_to_rank, split_dim
     11 
     12 class PUSCHLSChannelEstimator(LSChannelEstimator, Layer):
     13     # pylint: disable=line-too-long
     14     r"""LSChannelEstimator(resource_grid, dmrs_length, dmrs_additional_position, num_cdm_groups_without_data, interpolation_type="nn", interpolator=None, dtype=tf.complex64, **kwargs)
     15 
     16     Layer implementing least-squares (LS) channel estimation for NR PUSCH Transmissions.
     17 
     18     After LS channel estimation at the pilot positions, the channel estimates
     19     and error variances are interpolated accross the entire resource grid using
     20     a specified interpolation function.
     21 
     22     The implementation is similar to that of :class:`~sionna.ofdm.LSChannelEstimator`.
     23     However, it additional takes into account the separation of streams in the same CDM group
     24     as defined in :class:`~sionna.nr.PUSCHDMRSConfig`. This is done through
     25     frequency and time averaging of adjacent LS channel estimates.
     26 
     27     Parameters
     28     ----------
     29     resource_grid : ResourceGrid
     30         An instance of :class:`~sionna.ofdm.ResourceGrid`
     31 
     32     dmrs_length : int, [1,2]
     33         Length of DMRS symbols. See :class:`~sionna.nr.PUSCHDMRSConfig`.
     34 
     35     dmrs_additional_position : int, [0,1,2,3]
     36         Number of additional DMRS symbols.
     37         See :class:`~sionna.nr.PUSCHDMRSConfig`.
     38 
     39     num_cdm_groups_without_data : int, [1,2,3]
     40         Number of CDM groups masked for data transmissions.
     41         See :class:`~sionna.nr.PUSCHDMRSConfig`.
     42 
     43     interpolation_type : One of ["nn", "lin", "lin_time_avg"], string
     44         The interpolation method to be used.
     45         It is ignored if ``interpolator`` is not `None`.
     46         Available options are :class:`~sionna.ofdm.NearestNeighborInterpolator` (`"nn`")
     47         or :class:`~sionna.ofdm.LinearInterpolator` without (`"lin"`) or with
     48         averaging across OFDM symbols (`"lin_time_avg"`).
     49         Defaults to "nn".
     50 
     51     interpolator : BaseChannelInterpolator
     52         An instance of :class:`~sionna.ofdm.BaseChannelInterpolator`,
     53         such as :class:`~sionna.ofdm.LMMSEInterpolator`,
     54         or `None`. In the latter case, the interpolator specified
     55         by ``interpolation_type`` is used.
     56         Otherwise, the ``interpolator`` is used and ``interpolation_type``
     57         is ignored.
     58         Defaults to `None`.
     59 
     60     dtype : tf.Dtype
     61         Datatype for internal calculations and the output dtype.
     62         Defaults to `tf.complex64`.
     63 
     64     Input
     65     -----
     66     (y, no) :
     67         Tuple:
     68 
     69     y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols,fft_size], tf.complex
     70         Observed resource grid
     71 
     72     no : [batch_size, num_rx, num_rx_ant] or only the first n>=0 dims, tf.float
     73         Variance of the AWGN
     74 
     75     Output
     76     ------
     77     h_ls : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols,fft_size], tf.complex
     78         Channel estimates across the entire resource grid for all
     79         transmitters and streams
     80 
     81     err_var : Same shape as ``h_ls``, tf.float
     82         Channel estimation error variance across the entire resource grid
     83         for all transmitters and streams
     84     """
     85     def __init__(self,
     86                  resource_grid,
     87                  dmrs_length,
     88                  dmrs_additional_position,
     89                  num_cdm_groups_without_data,
     90                  interpolation_type="nn",
     91                  interpolator=None,
     92                  dtype=tf.complex64,
     93                  **kwargs):
     94         super().__init__(resource_grid,
     95                          interpolation_type,
     96                          interpolator,
     97                          dtype, **kwargs)
     98 
     99         self._dmrs_length = dmrs_length
    100         self._dmrs_additional_position = dmrs_additional_position
    101         self._num_cdm_groups_without_data = num_cdm_groups_without_data
    102 
    103         # Number of DMRS OFDM symbols
    104         self._num_dmrs_syms = self._dmrs_length \
    105                               * (self._dmrs_additional_position+1)
    106 
    107         # Number of pilot symbols per DMRS OFDM symbol
    108         # Some pilot symbols can be zero (for masking)
    109         self._num_pilots_per_dmrs_sym = int(
    110                     self._pilot_pattern.pilots.shape[-1]/self._num_dmrs_syms)
    111 
    112     def estimate_at_pilot_locations(self, y_pilots, no):
    113         # y_pilots : [batch_size, num_rx, num_rx_ant, num_tx, num_streams,
    114         #               num_pilot_symbols], tf.complex
    115         #     The observed signals for the pilot-carrying resource elements.
    116 
    117         # no : [batch_size, num_rx, num_rx_ant] or only the first n>=0 dims,
    118         #   tf.float
    119         #     The variance of the AWGN.
    120 
    121         # Compute LS channel estimates
    122         # Note: Some might be Inf because pilots=0, but we do not care
    123         # as only the valid estimates will be considered during interpolation.
    124         # We do a save division to replace Inf by 0.
    125         # Broadcasting from pilots here is automatic since pilots have shape
    126         # [num_tx, num_streams, num_pilot_symbols]
    127         h_ls = tf.math.divide_no_nan(y_pilots, self._pilot_pattern.pilots)
    128         h_ls_shape = tf.shape(h_ls)
    129 
    130         # Compute error variance and broadcast to the same shape as h_ls
    131         # Expand rank of no for broadcasting
    132         no = expand_to_rank(no, tf.rank(h_ls), -1)
    133 
    134         # Expand rank of pilots for broadcasting
    135         pilots = expand_to_rank(self._pilot_pattern.pilots, tf.rank(h_ls), 0)
    136 
    137         # Compute error variance, broadcastable to the shape of h_ls
    138         err_var = tf.math.divide_no_nan(no, tf.abs(pilots)**2)
    139 
    140         # In order to deal with CDM, we need to do (optional) time and
    141         # frequency averaging of the LS estimates
    142         h_hat = h_ls
    143 
    144         # (Optional) Time-averaging across adjacent DMRS OFDM symbols
    145         if self._dmrs_length==2:
    146             # Reshape last dim to [num_dmrs_syms, num_pilots_per_dmrs_sym]
    147             h_hat = split_dim(h_hat, [self._num_dmrs_syms,
    148                                       self._num_pilots_per_dmrs_sym], 5)
    149 
    150             # Average adjacent DMRS symbols in time domain
    151             h_hat = (h_hat[...,0::2,:]+h_hat[...,1::2,:]) \
    152                      / tf.cast(2, h_hat.dtype)
    153             h_hat = tf.repeat(h_hat, 2, axis=-2)
    154             h_hat = tf.reshape(h_hat, h_ls_shape)
    155 
    156             # The error variance gets reduced by a factor of two
    157             err_var /= tf.cast(2, err_var.dtype)
    158 
    159         # Frequency-averaging between adjacent channel estimates
    160 
    161         # Compute number of elements across which frequency averaging should
    162         # be done. This includes the zeroed elements.
    163         n = 2*self._num_cdm_groups_without_data
    164         k = int(h_hat.shape[-1]/n) # Second dimension
    165 
    166         # Reshape last dimension to [k, n]
    167         h_hat = split_dim(h_hat, [k, n], 5)
    168         cond = tf.abs(h_hat)>0 # Mask for irrelevant channel estimates
    169         h_hat = tf.reduce_sum(h_hat, axis=-1, keepdims=True) \
    170                 / tf.cast(2,h_hat.dtype)
    171         h_hat = tf.repeat(h_hat, n, axis=-1)
    172         h_hat = tf.where(cond, h_hat, 0) # Mask irrelevant channel estimates
    173         h_hat = tf.reshape(h_hat, h_ls_shape)
    174 
    175         # The error variance gets reduced by a factor of two
    176         err_var /= tf.cast(2, err_var.dtype)
    177 
    178         return h_hat, err_var