demodulator.py (7889B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Class definition for the OFDM Demodulator"""

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.signal import fftshift
from sionna.constants import PI
from sionna.utils import expand_to_rank
from sionna.signal import fft
import numpy as np

class OFDMDemodulator(Layer):
    # pylint: disable=line-too-long
    r"""
    OFDMDemodulator(fft_size, l_min, cyclic_prefix_length=0, **kwargs)

    Computes the frequency-domain representation of an OFDM waveform
    with cyclic prefix removal.

    The demodulator assumes that the input sequence is generated by the
    :class:`~sionna.channel.TimeChannel`. For a single pair of antennas,
    the received signal sequence is given as:

    .. math::

        y_b = \sum_{\ell =L_\text{min}}^{L_\text{max}} \bar{h}_\ell x_{b-\ell} + w_b, \quad b \in[L_\text{min}, N_B+L_\text{max}-1]

    where :math:`\bar{h}_\ell` are the discrete-time channel taps,
    :math:`x_{b}` is the transmitted signal,
    and :math:`w_b` is Gaussian noise.

    Starting from the first sample, the demodulator cuts the input
    sequence into pieces of size ``cyclic_prefix_length + fft_size``
    and discards any trailing samples. For each piece, the cyclic
    prefix is removed and the ``fft_size``-point discrete Fourier
    transform is computed. If ``cyclic_prefix_length`` is a vector,
    every OFDM symbol can have a cyclic prefix of a different length.

    Since the input sequence starts at time :math:`L_\text{min}`,
    the FFT window has a timing offset of :math:`L_\text{min}` samples,
    which leads to a subcarrier-dependent phase shift of
    :math:`e^{\frac{j2\pi k L_\text{min}}{N}}`, where :math:`k`
    is the subcarrier index, :math:`N` is the FFT size,
    and :math:`L_\text{min} \le 0` is the largest negative time lag of
    the discrete-time channel impulse response. This phase shift
    is removed in this layer by explicitly multiplying
    each subcarrier by :math:`e^{\frac{-j2\pi k L_\text{min}}{N}}`.
    This is a very important step to enable channel estimation with
    sparse pilot patterns that need to interpolate the channel frequency
    response across subcarriers. It also ensures that the
    channel frequency response `seen` by the time-domain channel
    is close to that of the :class:`~sionna.channel.OFDMChannel`.

    Parameters
    ----------
    fft_size : int
        FFT size, i.e., the number of subcarriers.

    l_min : int
        The largest negative time lag of the discrete-time channel
        impulse response. It should be the same value as that used by the
        `cir_to_time_channel` function.

    cyclic_prefix_length : scalar or [num_ofdm_symbols], int
        Integer or vector of integers indicating the length of the
        cyclic prefix that is prepended to each OFDM symbol. None of its
        elements should be larger than the FFT size.
        Defaults to 0.

    Input
    -----
    :[...,num_ofdm_symbols*(fft_size+cyclic_prefix_length)+n] or [...,num_ofdm_symbols*fft_size+sum(cyclic_prefix_length)+n], tf.complex
        Tensor containing the time-domain signal along the last dimension.
        `n` is a nonnegative integer; these trailing samples are discarded.

    Output
    ------
    :[...,num_ofdm_symbols,fft_size], tf.complex
        Tensor containing the OFDM resource grid along the last
        two dimensions.
    """

    def __init__(self, fft_size, l_min, cyclic_prefix_length=0, **kwargs):
        super().__init__(**kwargs)
        self._fft_size = None
        self._l_min = None
        self._cyclic_prefix_length = None
        # Assign through the property setters so that all values are validated
        self.fft_size = fft_size
        self.l_min = l_min
        self.cyclic_prefix_length = cyclic_prefix_length

    @property
    def fft_size(self):
        """int : FFT size, i.e., the number of subcarriers."""
        return self._fft_size

    @fft_size.setter
    def fft_size(self, value):
        # Explicit raise instead of ``assert`` so that the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        if not value > 0:
            raise AssertionError("`fft_size` must be positive.")
        self._fft_size = int(value)

    @property
    def l_min(self):
        """int : Largest negative time lag of the discrete-time channel."""
        return self._l_min

    @l_min.setter
    def l_min(self, value):
        # Explicit raise instead of ``assert`` (see ``fft_size``)
        if not value <= 0:
            raise AssertionError("l_min must be nonpositive.")
        self._l_min = int(value)

    @property
    def cyclic_prefix_length(self):
        """tf.int32 : Scalar or [num_ofdm_symbols] cyclic prefix length(s)."""
        return self._cyclic_prefix_length

    @cyclic_prefix_length.setter
    def cyclic_prefix_length(self, value):
        value = tf.cast(value, tf.int32)
        if not tf.reduce_all(value >= 0):
            msg = "`cyclic_prefix_length` must be nonnegative."
            raise ValueError(msg)
        if len(value.shape) > 1:
            msg = "`cyclic_prefix_length` must be of rank 0 or 1"
            raise ValueError(msg)
        self._cyclic_prefix_length = value

    def build(self, input_shape):
        """Precompute the phase correction and the symbol segmentation.

        Raises:
            ValueError: If the last input dimension is unknown (scalar
                cyclic prefix) or too short to hold all OFDM symbols
                (per-symbol cyclic prefix lengths).
        """
        # Phase factors exp(-j*2*pi*k*l_min/N), k = 0, ..., N-1, which undo
        # the phase shift caused by the FFT-window timing offset of l_min
        tmp = -2 * PI * tf.cast(self.l_min, tf.float32) \
              / tf.cast(self.fft_size, tf.float32) \
              * tf.range(self.fft_size, dtype=tf.float32)
        self._phase_compensation = tf.exp(tf.complex(0., tmp))

        num_samples = input_shape[-1]

        if len(self.cyclic_prefix_length.shape) == 0:
            # Same CP length for all OFDM symbols: the number of symbols and
            # of discarded trailing samples depend on the static input length.
            if num_samples is None:
                raise ValueError(
                    "The last dimension of the input must be known when "
                    "`cyclic_prefix_length` is a scalar.")
            # Keep all bookkeeping as Python ints so that shapes and slice
            # bounds used in call() are static and int32-compatible.
            self._cp_length = int(self.cyclic_prefix_length)
            self._symbol_length = self.fft_size + self._cp_length
            num_symbols, rest = np.divmod(int(num_samples),
                                          self._symbol_length)
            # Number of full OFDM symbols to be demodulated
            self._num_ofdm_symbols = int(num_symbols)
            # Number of trailing samples that will be truncated
            self._rest = int(rest)
        else:
            # Individual CP length per OFDM symbol: precompute the indices of
            # the fft_size useful samples of every OFDM symbol so that they
            # can be gathered from the time-domain input.
            row_lengths = self.cyclic_prefix_length + self.fft_size
            # Index of the first sample (start of CP) of every OFDM symbol
            symbol_starts = tf.math.cumsum(row_lengths, exclusive=True)
            # Index of the first sample after the CP of every OFDM symbol
            data_starts = symbol_starts + self.cyclic_prefix_length
            # [num_ofdm_symbols, fft_size]
            self._ind = data_starts[:, tf.newaxis] \
                        + tf.range(self.fft_size)[tf.newaxis, :]

            # tf.gather() fails on CPU but silently returns zeros on GPU for
            # out-of-range indices, so reject too-short inputs up front.
            required_samples = int(tf.reduce_sum(row_lengths))
            if num_samples is not None and num_samples < required_samples:
                raise ValueError(
                    f"The input must contain at least {required_samples} "
                    f"samples along its last dimension, but has "
                    f"{num_samples}.")

    def call(self, inputs):
        """Demodulate an OFDM waveform onto a resource grid.

        Args:
            inputs (tf.complex): Time-domain signal of shape
                `[...,num_ofdm_symbols*(fft_size+cyclic_prefix_length)+n]`
                for a scalar cyclic prefix length, or
                `[...,num_ofdm_symbols*fft_size+sum(cyclic_prefix_length)+n]`
                for per-symbol lengths. The `n` trailing samples are
                discarded.

        Returns:
            tf.complex: The resource grid of shape
            `[...,num_ofdm_symbols, fft_size]`.
        """
        if len(self.cyclic_prefix_length.shape) == 0:
            # Same CP length for all OFDM symbols
            # Cut last samples that do not fit into an OFDM symbol
            if self._rest > 0:
                inputs = inputs[..., :-self._rest]

            # Reshape input to separate OFDM symbols
            new_shape = tf.concat(
                [tf.shape(inputs)[:-1],
                 [self._num_ofdm_symbols, self._symbol_length]], 0)
            x = tf.reshape(inputs, new_shape)

            # Remove cyclic prefix
            x = x[..., self._cp_length:]

        else:
            # Individual CP length for OFDM symbols
            x = tf.gather(inputs, self._ind, axis=-1)

        # Compute FFT
        x = fft(x)

        # Apply phase shift compensation to all subcarriers
        rot = tf.cast(self._phase_compensation, x.dtype)
        rot = expand_to_rank(rot, tf.rank(x), 0)
        x = x * rot

        # Shift DC subcarrier to the middle
        x = fftshift(x, axes=-1)

        return x