pilot_pattern.py (13332B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Class definition and functions related to pilot patterns"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
from sionna.utils import QAMSource


def _to_index_list(ind, num):
    """Normalize a transmitter/stream selection for :meth:`PilotPattern.show`.

    ``None`` selects all ``num`` indices, a scalar (Python or numpy integer)
    is wrapped into a one-element list, and any other iterable (list, tuple,
    range, array) is converted to a list.
    """
    if ind is None:
        return list(range(num))
    if np.ndim(ind) == 0:
        return [ind]
    return list(ind)


class PilotPattern():
    # pylint: disable=line-too-long
    r"""Class defining a pilot pattern for an OFDM ResourceGrid.

    This class defines a pilot pattern object that is used to configure
    an OFDM :class:`~sionna.ofdm.ResourceGrid`.

    Parameters
    ----------
    mask : [num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], bool
        Tensor indicating resource elements that are reserved for pilot transmissions.

    pilots : [num_tx, num_streams_per_tx, num_pilots], tf.complex
        The pilot symbols to be mapped onto the ``mask``.

    trainable : bool
        Indicates if ``pilots`` is a trainable `Variable`.
        Defaults to `False`.

    normalize : bool
        Indicates if the ``pilots`` should be normalized to an average
        energy of one across the last dimension. This can be useful to
        ensure that trainable ``pilots`` have a finite energy.
        Defaults to `False`.

    dtype : tf.Dtype
        Defines the datatype for internal calculations and the output
        dtype. Defaults to `tf.complex64`.
    """
    def __init__(self, mask, pilots, trainable=False, normalize=False,
                 dtype=tf.complex64):
        super().__init__()
        self._dtype = dtype
        # Stored as integers so that `_check_settings` can count the
        # reserved resource elements with `tf.reduce_sum`.
        self._mask = tf.cast(mask, tf.int32)
        self._pilots = tf.Variable(tf.cast(pilots, self._dtype), trainable)
        self.normalize = normalize
        self._check_settings()

    @property
    def num_tx(self):
        """Number of transmitters"""
        return self._mask.shape[0]

    @property
    def num_streams_per_tx(self):
        """Number of streams per transmitter"""
        return self._mask.shape[1]

    @property
    def num_ofdm_symbols(self):
        """Number of OFDM symbols"""
        return self._mask.shape[2]

    @property
    def num_effective_subcarriers(self):
        """Number of effective subcarriers"""
        return self._mask.shape[3]

    @property
    def num_pilot_symbols(self):
        """Number of pilot symbols per transmit stream."""
        return tf.shape(self._pilots)[-1]

    @property
    def num_data_symbols(self):
        """ Number of data symbols per transmit stream."""
        return tf.shape(self._mask)[-1]*tf.shape(self._mask)[-2] - \
               self.num_pilot_symbols

    @property
    def normalize(self):
        """Returns or sets the flag indicating if the pilots
           are normalized or not
        """
        return self._normalize

    @normalize.setter
    def normalize(self, value):
        self._normalize = tf.cast(value, tf.bool)

    @property
    def mask(self):
        """Mask of the pilot pattern"""
        return self._mask

    @property
    def pilots(self):
        """Returns or sets the possibly normalized tensor of pilot symbols.
           If pilots are normalized, the normalization will be applied
           after new values for pilots have been set. If this is
           not the desired behavior, turn normalization off.
        """
        def norm_pilots():
            # Scale each pilot sequence (last dimension) to unit average energy
            scale = tf.abs(self._pilots)**2
            scale = 1/tf.sqrt(tf.reduce_mean(scale, axis=-1, keepdims=True))
            scale = tf.cast(scale, self._dtype)
            return scale*self._pilots

        # `normalize` is stored as a boolean tensor, hence `tf.cond`
        return tf.cond(self.normalize, norm_pilots, lambda: self._pilots)

    @pilots.setter
    def pilots(self, value):
        self._pilots.assign(value)

    def _check_settings(self):
        """Validate that all properties define a valid pilot pattern."""

        assert tf.rank(self._mask)==4, "`mask` must have four dimensions."
        assert tf.rank(self._pilots)==3, "`pilots` must have three dimensions."
        assert np.array_equal(self._mask.shape[:2], self._pilots.shape[:2]), \
            "The first two dimensions of `mask` and `pilots` must be equal."

        num_pilots = tf.reduce_sum(self._mask, axis=(-2,-1))
        assert tf.reduce_min(num_pilots)==tf.reduce_max(num_pilots), \
            """The number of nonzero elements in the masks for all transmitters
            and streams must be identical."""

        assert self.num_pilot_symbols==tf.reduce_max(num_pilots), \
            """The shape of the last dimension of `pilots` must equal
            the number of non-zero entries within the last two
            dimensions of `mask`."""

        return True

    @property
    def trainable(self):
        """Returns if pilots are trainable or not"""
        return self._pilots.trainable

    def show(self, tx_ind=None, stream_ind=None, show_pilot_ind=False):
        """Visualizes the pilot patterns for some transmitters and streams.

        Input
        -----
        tx_ind : list, int
            Indicates the indices of transmitters to be included.
            A single index or any iterable of indices (list, tuple,
            range, array) is accepted.
            Defaults to `None`, i.e., all transmitters included.

        stream_ind : list, int
            Indicates the indices of streams to be included.
            A single index or any iterable of indices (list, tuple,
            range, array) is accepted.
            Defaults to `None`, i.e., all streams included.

        show_pilot_ind : bool
            Indicates if the indices of the pilot symbols should be shown.

        Output
        ------
        list : matplotlib.figure.Figure
            List of matplot figure objects showing each the pilot pattern
            from a specific transmitter and stream.
        """
        mask = self.mask.numpy()
        pilots = self.pilots.numpy()

        tx_ind = _to_index_list(tx_ind, self.num_tx)
        stream_ind = _to_index_list(stream_ind, self.num_streams_per_tx)

        # Color coding shared by all figures:
        # 0 = data, 1 = non-zero pilot, 2 = zero-valued ("masked") pilot
        legend = ["Data", "Pilots", "Masked"]
        cmap = plt.cm.tab20c
        b = np.arange(0, 4)
        norm = colors.BoundaryNorm(b, cmap.N)

        figs = []
        for i in tx_ind:
            for j in stream_ind:
                # q has shape [num_ofdm_symbols, num_effective_subcarriers].
                # Assumes the pilots are ordered like the row-major nonzero
                # entries of the mask (the order returned by np.where).
                q = np.zeros_like(mask[0,0])
                q[np.where(mask[i,j])] = (np.abs(pilots[i,j])==0) + 1
                fig = plt.figure()
                plt.title(f"TX {i} - Stream {j}")
                plt.xlabel("OFDM Symbol")
                plt.ylabel("Subcarrier Index")
                # q is transposed below, so the x-axis shows OFDM symbols
                plt.xticks(range(0, self.num_ofdm_symbols))
                im = plt.imshow(np.transpose(q), origin="lower",
                                aspect="auto", norm=norm, cmap=cmap)
                cbar = plt.colorbar(im)
                cbar.set_ticks(b[:-1]+0.5)
                cbar.set_ticklabels(legend)

                if show_pilot_ind:
                    # np.argwhere walks the mask in the same row-major order
                    # used above, so c is the index into pilots[i, j]
                    for c, (t, k) in enumerate(np.argwhere(mask[i,j])):
                        if np.abs(pilots[i,j,c]) > 0:
                            plt.annotate(c, [t, k])
                figs.append(fig)

        return figs


class EmptyPilotPattern(PilotPattern):
    """Creates an empty pilot pattern.

    Generates an instance of :class:`~sionna.ofdm.PilotPattern` with
    an empty ``mask`` and ``pilots``.

    Parameters
    ----------
    num_tx : int
        Number of transmitters.

    num_streams_per_tx : int
        Number of streams per transmitter.

    num_ofdm_symbols : int
        Number of OFDM symbols.

    num_effective_subcarriers : int
        Number of effective subcarriers
        that are available for the transmission of data and pilots.
        Note that this number is generally smaller than the ``fft_size``
        due to nulled subcarriers.

    dtype : tf.Dtype
        Defines the datatype for internal calculations and the output
        dtype. Defaults to `tf.complex64`.
    """
    def __init__(self,
                 num_tx,
                 num_streams_per_tx,
                 num_ofdm_symbols,
                 num_effective_subcarriers,
                 dtype=tf.complex64):

        assert num_tx > 0, \
            "`num_tx` must be positive."
        assert num_streams_per_tx > 0, \
            "`num_streams_per_tx` must be positive."
        assert num_ofdm_symbols > 0, \
            "`num_ofdm_symbols` must be positive."
        assert num_effective_subcarriers > 0, \
            "`num_effective_subcarriers` must be positive."

        shape = [num_tx, num_streams_per_tx, num_ofdm_symbols,
                 num_effective_subcarriers]
        # No resource element is reserved, hence zero pilots per stream
        mask = tf.zeros(shape, tf.bool)
        pilots = tf.zeros(shape[:2]+[0], dtype)
        super().__init__(mask, pilots, trainable=False, normalize=False,
                         dtype=dtype)


class KroneckerPilotPattern(PilotPattern):
    """Simple orthogonal pilot pattern with Kronecker structure.

    This function generates an instance of :class:`~sionna.ofdm.PilotPattern`
    that allocates non-overlapping pilot sequences for all transmitters and
    streams on specified OFDM symbols. As the same pilot sequences are reused
    across those OFDM symbols, the resulting pilot pattern has a frequency-time
    Kronecker structure. This structure enables a very efficient implementation
    of the LMMSE channel estimator. Each pilot sequence is constructed from
    randomly drawn QPSK constellation points.

    Parameters
    ----------
    resource_grid : ResourceGrid
        An instance of a :class:`~sionna.ofdm.ResourceGrid`.

    pilot_ofdm_symbol_indices : list, int
        List of integers defining the OFDM symbol indices that are reserved
        for pilots.

    normalize : bool
        Indicates if the ``pilots`` should be normalized to an average
        energy of one across the last dimension.
        Defaults to `True`.

    seed : int
        Seed for the generation of the pilot sequence. Different seed values
        lead to different sequences. Defaults to 0.

    dtype : tf.Dtype
        Defines the datatype for internal calculations and the output
        dtype. Defaults to `tf.complex64`.

    Note
    ----
    It is required that the ``resource_grid``'s property
    ``num_effective_subcarriers`` is an
    integer multiple of ``num_tx * num_streams_per_tx``. This condition is
    required to ensure that all transmitters and streams get
    non-overlapping pilot sequences. For a large number of streams and/or
    transmitters, the pilot pattern becomes very sparse in the frequency
    domain.

    Examples
    --------
    >>> rg = ResourceGrid(num_ofdm_symbols=14,
    ...                   fft_size=64,
    ...                   subcarrier_spacing = 30e3,
    ...                   num_tx=4,
    ...                   num_streams_per_tx=2,
    ...                   pilot_pattern = "kronecker",
    ...                   pilot_ofdm_symbol_indices = [2, 11])
    >>> rg.pilot_pattern.show();

    .. image:: ../figures/kronecker_pilot_pattern.png

    """
    def __init__(self,
                 resource_grid,
                 pilot_ofdm_symbol_indices,
                 normalize=True,
                 seed=0,
                 dtype=tf.complex64):

        num_tx = resource_grid.num_tx
        num_streams_per_tx = resource_grid.num_streams_per_tx
        num_ofdm_symbols = resource_grid.num_ofdm_symbols
        num_effective_subcarriers = resource_grid.num_effective_subcarriers
        self._dtype = dtype

        # Number of OFDM symbols carrying pilots
        num_pilot_symbols = len(pilot_ofdm_symbol_indices)

        # Compute the total number of required orthogonal sequences
        num_seq = num_tx*num_streams_per_tx

        # Each sequence occupies every num_seq-th subcarrier of a pilot
        # OFDM symbol. The check uses exact integer arithmetic instead of
        # float division.
        assert num_effective_subcarriers % num_seq == 0, \
            """`num_effective_subcarriers` must be an integer multiple of
            `num_tx`*`num_streams_per_tx`."""

        # Number of pilots per OFDM symbol
        num_pilots_per_symbol = num_effective_subcarriers // num_seq

        # Prepare empty mask and pilots
        shape = [num_tx, num_streams_per_tx,
                 num_ofdm_symbols, num_effective_subcarriers]
        mask = np.zeros(shape, bool)
        shape[2] = num_pilot_symbols
        # Allocate the buffer in the requested precision so that pilots
        # generated in e.g. tf.complex128 are not truncated to complex64
        pilots = np.zeros(shape, tf.as_dtype(self._dtype).as_numpy_dtype)

        # Populate all selected OFDM symbols in the mask
        mask[..., pilot_ofdm_symbol_indices, :] = True

        # Populate the pilots with random QPSK symbols
        qam_source = QAMSource(2, seed=seed, dtype=self._dtype)
        for i in range(num_tx):
            for j in range(num_streams_per_tx):
                # Generate random QPSK symbols
                p = qam_source([1,1,num_pilot_symbols,num_pilots_per_symbol])

                # Place pilots spaced by num_seq to avoid overlap
                pilots[i,j,:,i*num_streams_per_tx+j::num_seq] = p

        # Reshape the pilots tensor
        pilots = np.reshape(pilots, [num_tx, num_streams_per_tx, -1])

        super().__init__(mask, pilots, trainable=False,
                         normalize=normalize, dtype=self._dtype)