pusch_channel_estimation.py (7502B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""PUSCH Channel Estimation for the nr (5G) sub-package of the Sionna library.
"""
import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna.ofdm import LSChannelEstimator
from sionna.utils import expand_to_rank, split_dim

class PUSCHLSChannelEstimator(LSChannelEstimator, Layer):
    # pylint: disable=line-too-long
    r"""PUSCHLSChannelEstimator(resource_grid, dmrs_length, dmrs_additional_position, num_cdm_groups_without_data, interpolation_type="nn", interpolator=None, dtype=tf.complex64, **kwargs)

    Layer implementing least-squares (LS) channel estimation for NR PUSCH transmissions.

    After LS channel estimation at the pilot positions, the channel estimates
    and error variances are interpolated across the entire resource grid using
    a specified interpolation function.

    The implementation is similar to that of :class:`~sionna.ofdm.LSChannelEstimator`.
    However, it additionally takes into account the separation of streams in the same CDM group
    as defined in :class:`~sionna.nr.PUSCHDMRSConfig`. This is done through
    frequency and time averaging of adjacent LS channel estimates.

    Parameters
    ----------
    resource_grid : ResourceGrid
        An instance of :class:`~sionna.ofdm.ResourceGrid`

    dmrs_length : int, [1,2]
        Length of the DMRS, i.e., the number of consecutive DMRS OFDM symbols
        per DMRS position. If equal to 2, the LS estimates of each pair of
        adjacent DMRS OFDM symbols are averaged.
        See :class:`~sionna.nr.PUSCHDMRSConfig`.

    dmrs_additional_position : int, [0,1,2,3]
        Number of additional DMRS positions. The total number of DMRS OFDM
        symbols is ``dmrs_length*(dmrs_additional_position+1)``.
        See :class:`~sionna.nr.PUSCHDMRSConfig`.

    num_cdm_groups_without_data : int, [1,2,3]
        Number of CDM groups masked for data transmissions. Frequency
        averaging is performed over groups of
        ``2*num_cdm_groups_without_data`` adjacent pilot positions.
        See :class:`~sionna.nr.PUSCHDMRSConfig`.

    interpolation_type : One of ["nn", "lin", "lin_time_avg"], string
        The interpolation method to be used.
        It is ignored if ``interpolator`` is not `None`.
        Available options are :class:`~sionna.ofdm.NearestNeighborInterpolator` (`"nn"`)
        or :class:`~sionna.ofdm.LinearInterpolator` without (`"lin"`) or with
        averaging across OFDM symbols (`"lin_time_avg"`).
        Defaults to "nn".

    interpolator : BaseChannelInterpolator
        An instance of :class:`~sionna.ofdm.BaseChannelInterpolator`,
        such as :class:`~sionna.ofdm.LMMSEInterpolator`,
        or `None`. In the latter case, the interpolator specified
        by ``interpolation_type`` is used.
        Otherwise, the ``interpolator`` is used and ``interpolation_type``
        is ignored.
        Defaults to `None`.

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    (y, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols,fft_size], tf.complex
        Observed resource grid

    no : [batch_size, num_rx, num_rx_ant] or only the first n>=0 dims, tf.float
        Variance of the AWGN

    Output
    ------
    h_ls : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols,fft_size], tf.complex
        Channel estimates across the entire resource grid for all
        transmitters and streams

    err_var : Same shape as ``h_ls``, tf.float
        Channel estimation error variance across the entire resource grid
        for all transmitters and streams
    """
    def __init__(self,
                 resource_grid,
                 dmrs_length,
                 dmrs_additional_position,
                 num_cdm_groups_without_data,
                 interpolation_type="nn",
                 interpolator=None,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(resource_grid,
                         interpolation_type,
                         interpolator,
                         dtype, **kwargs)

        self._dmrs_length = dmrs_length
        self._dmrs_additional_position = dmrs_additional_position
        self._num_cdm_groups_without_data = num_cdm_groups_without_data

        # Total number of DMRS OFDM symbols: one group of `dmrs_length`
        # symbols for the front-loaded DMRS plus one per additional position
        self._num_dmrs_syms = self._dmrs_length \
                              * (self._dmrs_additional_position+1)

        # Number of pilot symbols per DMRS OFDM symbol.
        # Some pilot symbols can be zero (for masking).
        # Integer division: this is a count and must not go through floats.
        self._num_pilots_per_dmrs_sym = \
            int(self._pilot_pattern.pilots.shape[-1]) // self._num_dmrs_syms

    def estimate_at_pilot_locations(self, y_pilots, no):
        """Compute CDM-aware LS channel estimates at the pilot positions.

        The raw LS estimates ``y_pilots / pilots`` are averaged over each
        pair of adjacent DMRS OFDM symbols (only if ``dmrs_length==2``) and
        then averaged in frequency over groups of
        ``2*num_cdm_groups_without_data`` adjacent pilot positions. Each
        averaging step divides the error variance by two.

        Input
        -----
        y_pilots : [batch_size, num_rx, num_rx_ant, num_tx, num_streams, num_pilot_symbols], tf.complex
            Observed signals for the pilot-carrying resource elements.

        no : [batch_size, num_rx, num_rx_ant] or only the first n>=0 dims, tf.float
            Variance of the AWGN.

        Output
        ------
        h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams, num_pilot_symbols], tf.complex
            Averaged LS channel estimates. Entries whose estimate is zero
            before frequency averaging (e.g., zero-valued masked pilots)
            are kept at zero.

        err_var : Broadcastable to the shape of ``h_hat``, tf.float
            Error variance of the channel estimates.
        """
        pilots = self._pilot_pattern.pilots
        # Static pilot count; used instead of h_hat.shape[-1] so that the
        # frequency grouping does not depend on static-shape inference
        # surviving the dynamic reshapes below.
        num_pilot_symbols = int(pilots.shape[-1])

        # Compute LS channel estimates.
        # Masked pilots are zero, so a plain division would produce Inf/NaN.
        # divide_no_nan ("safe division") returns 0 instead; only the valid
        # estimates are considered during interpolation.
        # Pilots of shape [num_tx, num_streams, num_pilot_symbols] broadcast
        # automatically against y_pilots.
        h_ls = tf.math.divide_no_nan(y_pilots, pilots)
        h_ls_shape = tf.shape(h_ls)
        h_ls_rank = tf.rank(h_ls)

        # Error variance no/|p|^2, broadcastable to the shape of h_ls:
        # `no` is padded with trailing singleton dims, pilots with leading ones
        no = expand_to_rank(no, h_ls_rank, -1)
        pilots = expand_to_rank(pilots, h_ls_rank, 0)
        err_var = tf.math.divide_no_nan(no, tf.abs(pilots)**2)

        # In order to deal with CDM, we need to do (optional) time and
        # frequency averaging of the LS estimates
        h_hat = h_ls

        # (Optional) Time-averaging across adjacent DMRS OFDM symbols
        if self._dmrs_length == 2:
            # Reshape last dim to [num_dmrs_syms, num_pilots_per_dmrs_sym]
            h_hat = split_dim(h_hat,
                              [self._num_dmrs_syms,
                               self._num_pilots_per_dmrs_sym],
                              5)

            # Average each pair (0,1), (2,3), ... of DMRS symbols and write
            # the pair average back to both symbols of the pair
            h_hat = (h_hat[..., 0::2, :] + h_hat[..., 1::2, :]) \
                    / tf.cast(2, h_hat.dtype)
            h_hat = tf.repeat(h_hat, 2, axis=-2)
            h_hat = tf.reshape(h_hat, h_ls_shape)

            # Averaging two estimates halves the error variance
            err_var /= tf.cast(2, err_var.dtype)

        # Frequency-averaging between adjacent channel estimates.
        # Each group of n adjacent pilot positions includes the zeroed
        # elements. The group sum is divided by 2, which equals the mean of
        # the non-zero entries when a group contains exactly two of them.
        n = 2 * self._num_cdm_groups_without_data  # Group size
        k = num_pilot_symbols // n                  # Number of groups

        # Reshape last dimension to [k, n]
        h_hat = split_dim(h_hat, [k, n], 5)
        valid = tf.abs(h_hat) > 0  # Mask for irrelevant channel estimates
        h_hat = tf.reduce_sum(h_hat, axis=-1, keepdims=True) \
                / tf.cast(2, h_hat.dtype)
        h_hat = tf.repeat(h_hat, n, axis=-1)
        # Mask irrelevant channel estimates
        h_hat = tf.where(valid, h_hat, tf.cast(0, h_hat.dtype))
        h_hat = tf.reshape(h_hat, h_ls_shape)

        # Averaging two estimates halves the error variance
        err_var /= tf.cast(2, err_var.dtype)

        return h_hat, err_var