resource_grid.py (19504B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Class definition and functions related to the resource grid"""

import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Layer

from .pilot_pattern import PilotPattern, EmptyPilotPattern, KroneckerPilotPattern # pylint: disable=line-too-long
from sionna.utils import flatten_last_dims, flatten_dims, split_dim
import matplotlib.pyplot as plt
from matplotlib import colors


class ResourceGrid():
    # pylint: disable=line-too-long
    r"""Defines a `ResourceGrid` spanning multiple OFDM symbols and subcarriers.

    Parameters
    ----------
    num_ofdm_symbols : int
        Number of OFDM symbols.

    fft_size : int
        FFT size, i.e., the number of subcarriers.

    subcarrier_spacing : float
        The subcarrier spacing in Hz.

    num_tx : int
        Number of transmitters.

    num_streams_per_tx : int
        Number of streams per transmitter.

    cyclic_prefix_length : int
        Length of the cyclic prefix.

    num_guard_carriers : (int, int)
        List of two integers defining the number of guard carriers at the
        left and right side of the resource grid.

    dc_null : bool
        Indicates if the DC carrier is nulled or not. If `True`, the DC
        carrier must not be covered by the guard carriers.

    pilot_pattern : One of [None, "kronecker", "empty", PilotPattern]
        An instance of :class:`~sionna.ofdm.PilotPattern`, a string
        shorthand for the :class:`~sionna.ofdm.KroneckerPilotPattern`
        or :class:`~sionna.ofdm.EmptyPilotPattern`, or `None`.
        Defaults to `None` which is equivalent to `"empty"`.

    pilot_ofdm_symbol_indices : List, int
        List of indices of OFDM symbols reserved for pilot transmissions.
        Only needed if ``pilot_pattern="kronecker"``. Defaults to `None`.

    dtype : tf.Dtype
        Defines the datatype for internal calculations and the output
        dtype. Defaults to `tf.complex64`.
    """
    def __init__(self,
                 num_ofdm_symbols,
                 fft_size,
                 subcarrier_spacing,
                 num_tx=1,
                 num_streams_per_tx=1,
                 cyclic_prefix_length=0,
                 num_guard_carriers=(0,0),
                 dc_null=False,
                 pilot_pattern=None,
                 pilot_ofdm_symbol_indices=None,
                 dtype=tf.complex64):
        super().__init__()
        self._dtype = dtype
        self._num_ofdm_symbols = num_ofdm_symbols
        self._fft_size = fft_size
        self._subcarrier_spacing = subcarrier_spacing
        self._cyclic_prefix_length = int(cyclic_prefix_length)
        self._num_tx = num_tx
        self._num_streams_per_tx = num_streams_per_tx
        self._num_guard_carriers = np.array(num_guard_carriers)
        self._dc_null = dc_null
        self._pilot_ofdm_symbol_indices = pilot_ofdm_symbol_indices
        # Validate the grid geometry *before* building the pilot pattern:
        # the pilot pattern is sized with ``num_effective_subcarriers``,
        # which is meaningless (e.g., negative) for an invalid grid and
        # would otherwise fail with a confusing error.
        self._check_settings()
        self.pilot_pattern = pilot_pattern

    @property
    def cyclic_prefix_length(self):
        """Length of the cyclic prefix."""
        return self._cyclic_prefix_length

    @property
    def num_tx(self):
        """Number of transmitters."""
        return self._num_tx

    @property
    def num_streams_per_tx(self):
        """Number of streams per transmitter."""
        return self._num_streams_per_tx

    @property
    def num_ofdm_symbols(self):
        """The number of OFDM symbols of the resource grid."""
        return self._num_ofdm_symbols

    @property
    def num_resource_elements(self):
        """Number of resource elements."""
        return self._fft_size*self._num_ofdm_symbols

    @property
    def num_effective_subcarriers(self):
        """Number of subcarriers used for data and pilot transmissions."""
        # ``dc_null`` is a bool and counts as 0 or 1 nulled carrier.
        n = self._fft_size - self._dc_null - np.sum(self._num_guard_carriers)
        return n

    @property
    def effective_subcarrier_ind(self):
        """Returns the indices of the effective subcarriers.

        A ``range`` is returned if the DC carrier is not nulled, otherwise a
        NumPy array with the DC index removed.
        """
        num_gc = self._num_guard_carriers
        sc_ind = range(num_gc[0], self.fft_size-num_gc[1])
        if self.dc_null:
            # Position of the DC carrier relative to the first effective one
            sc_ind = np.delete(sc_ind, self.dc_ind-num_gc[0])
        return sc_ind

    @property
    def num_data_symbols(self):
        """Number of resource elements used for data transmissions."""
        n = self.num_effective_subcarriers * self._num_ofdm_symbols - \
               self.num_pilot_symbols
        return tf.cast(n, tf.int32)

    @property
    def num_pilot_symbols(self):
        """Number of resource elements used for pilot symbols."""
        return self.pilot_pattern.num_pilot_symbols

    @property
    def num_zero_symbols(self):
        """Number of empty resource elements."""
        n = (self._fft_size-self.num_effective_subcarriers) * \
               self._num_ofdm_symbols
        return tf.cast(n, tf.int32)

    @property
    def num_guard_carriers(self):
        """Number of left and right guard carriers."""
        return self._num_guard_carriers

    @property
    def dc_ind(self):
        """Index of the DC subcarrier.

        If ``fft_size`` is odd, the index is (``fft_size``-1)/2.
        If ``fft_size`` is even, the index is ``fft_size``/2.
        """
        # Floor division yields exactly both cases above.
        return int(self._fft_size // 2)

    @property
    def fft_size(self):
        """The FFT size."""
        return self._fft_size

    @property
    def subcarrier_spacing(self):
        """The subcarrier spacing [Hz]."""
        return self._subcarrier_spacing

    @property
    def ofdm_symbol_duration(self):
        """Duration of an OFDM symbol with cyclic prefix [s]."""
        return (1. + self.cyclic_prefix_length/self.fft_size) \
                / self.subcarrier_spacing

    @property
    def bandwidth(self):
        """The occupied bandwidth [Hz]: ``fft_size*subcarrier_spacing``."""
        return self.fft_size*self.subcarrier_spacing

    @property
    def num_time_samples(self):
        """The number of time-domain samples occupied by the resource grid."""
        return (self.fft_size + self.cyclic_prefix_length) \
                * self._num_ofdm_symbols

    @property
    def dc_null(self):
        """Indicates if the DC carrier is nulled or not."""
        return self._dc_null

    @property
    def pilot_pattern(self):
        """The used PilotPattern."""
        return self._pilot_pattern

    @pilot_pattern.setter
    def pilot_pattern(self, value):
        # Resolve `None` and the string shorthands into PilotPattern objects.
        if value is None:
            value = EmptyPilotPattern(self._num_tx,
                                      self._num_streams_per_tx,
                                      self._num_ofdm_symbols,
                                      self.num_effective_subcarriers,
                                      dtype=self._dtype)
        elif isinstance(value, PilotPattern):
            pass
        elif isinstance(value, str):
            assert value in ["kronecker", "empty"],\
                "Unknown pilot pattern"
            if value=="empty":
                value = EmptyPilotPattern(self._num_tx,
                                          self._num_streams_per_tx,
                                          self._num_ofdm_symbols,
                                          self.num_effective_subcarriers,
                                          dtype=self._dtype)
            elif value=="kronecker":
                assert self._pilot_ofdm_symbol_indices is not None,\
                    "You must provide pilot_ofdm_symbol_indices."
                value = KroneckerPilotPattern(self,
                        self._pilot_ofdm_symbol_indices, dtype=self._dtype)
        else:
            raise ValueError("Unsupported pilot_pattern")
        self._pilot_pattern = value

    def _check_settings(self):
        """Validate that all properties define a valid resource grid"""
        assert self._num_ofdm_symbols > 0, \
            "`num_ofdm_symbols` must be positive."
        assert self._fft_size > 0, \
            "`fft_size` must be positive."
        assert self._cyclic_prefix_length>=0, \
            "`cyclic_prefix_length` must be nonnegative."
        assert self._cyclic_prefix_length<=self._fft_size, \
            "`cyclic_prefix_length` cannot be longer than `fft_size`."
        assert self._num_tx > 0, \
            "`num_tx` must be positive."
        assert self._num_streams_per_tx > 0, \
            "`num_streams_per_tx` must be positive."
        assert len(self._num_guard_carriers)==2, \
            "`num_guard_carriers` must have two elements."
        assert np.all(np.greater_equal(self._num_guard_carriers, 0)), \
            "`num_guard_carriers` must have nonnegative entries."
        assert np.sum(self._num_guard_carriers)<=self._fft_size-self._dc_null,\
            "Total number of guardcarriers cannot be larger than `fft_size`."
        if self._dc_null:
            # A nulled DC carrier inside a guard band would be counted twice
            # in `num_effective_subcarriers` and would make
            # `build_type_grid`/`effective_subcarrier_ind` place or remove
            # the DC carrier at the wrong position.
            num_gc = self._num_guard_carriers
            assert num_gc[0] <= self.dc_ind < self._fft_size - num_gc[1], \
                "`num_guard_carriers` must not cover the DC carrier when `dc_null` is True."
        assert self._dtype in [tf.complex64, tf.complex128], \
            "dtype must be tf.complex64 or tf.complex128"
        return True

    def build_type_grid(self):
        """Returns a tensor indicating the type of each resource element.

        Resource elements can be one of

        - 0 : Data symbol
        - 1 : Pilot symbol
        - 2 : Guard carrier symbol
        - 3 : DC carrier symbol

        Output
        ------
        : [num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size], tf.int32
            Tensor indicating for each transmitter and stream the type of
            the resource elements of the corresponding resource grid.
            The type can be one of [0,1,2,3] as explained above.
        """
        shape = [self._num_tx, self._num_streams_per_tx, self._num_ofdm_symbols]
        gc_l = 2*tf.ones(shape+[self._num_guard_carriers[0]], tf.int32)
        gc_r = 2*tf.ones(shape+[self._num_guard_carriers[1]], tf.int32)
        # Zero-width when the DC carrier is not nulled
        dc = 3*tf.ones(shape + [tf.cast(self._dc_null, tf.int32)], tf.int32)
        # The pilot-pattern mask only covers the effective subcarriers
        mask = self.pilot_pattern.mask
        # Position of the DC carrier within the effective subcarriers
        split_ind = self.dc_ind-self._num_guard_carriers[0]
        rg_type = tf.concat([gc_l,                 # Left Guards
                             mask[...,:split_ind], # Data & pilots
                             dc,                   # DC
                             mask[...,split_ind:], # Data & pilots
                             gc_r], -1)            # Right guards
        return rg_type

    def show(self, tx_ind=0, tx_stream_ind=0):
        """Visualizes the resource grid for a specific transmitter and stream.

        Input
        -----
        tx_ind : int
            Indicates the transmitter index.

        tx_stream_ind : int
            Indicates the index of the stream.

        Output
        ------
        : `matplotlib.figure`
            A handle to a matplot figure object.
        """
        fig = plt.figure()
        data = self.build_type_grid()[tx_ind, tx_stream_ind]
        # One color per resource-element type (0..3)
        cmap = colors.ListedColormap([[60/256,8/256,72/256],
                              [45/256,91/256,128/256],
                              [45/256,172/256,111/256],
                              [250/256,228/256,62/256]])
        bounds=[0,1,2,3,4]
        norm = colors.BoundaryNorm(bounds, cmap.N)
        img = plt.imshow(np.transpose(data), interpolation="nearest",
                         origin="lower", cmap=cmap, norm=norm,
                         aspect="auto")
        cbar = plt.colorbar(img, ticks=[0.5, 1.5, 2.5,3.5],
                            orientation="vertical", shrink=0.8)
        cbar.set_ticklabels(["Data", "Pilot", "Guard carrier", "DC carrier"])
        plt.title("OFDM Resource Grid")
        plt.ylabel("Subcarrier Index")
        plt.xlabel("OFDM Symbol")
        plt.xticks(range(0, data.shape[0]))

        return fig

class ResourceGridMapper(Layer):
    # pylint: disable=line-too-long
    r"""ResourceGridMapper(resource_grid, dtype=tf.complex64, **kwargs)

    Maps a tensor of modulated data symbols to a ResourceGrid.

    This layer takes as input a tensor of modulated data symbols
    and maps them together with pilot symbols onto an
    OFDM :class:`~sionna.ofdm.ResourceGrid`. The output can be
    converted to a time-domain signal with the
    :class:`~sionna.ofdm.Modulator` or further processed in the
    frequency domain.

    Parameters
    ----------
    resource_grid : ResourceGrid
        An instance of :class:`~sionna.ofdm.ResourceGrid`.

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    : [batch_size, num_tx, num_streams_per_tx, num_data_symbols], tf.complex
        The modulated data symbols to be mapped onto the resource grid.

    Output
    ------
    : [batch_size, num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size], tf.complex
        The full OFDM resource grid in the frequency domain.
    """
    def __init__(self, resource_grid, dtype=tf.complex64, **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        self._resource_grid = resource_grid

    def build(self, input_shape): # pylint: disable=unused-argument
        """Precompute the resource-element type tensor of shape
        [num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size]
        together with the indices of the pilot and data resource
        elements, which are used in `call` to scatter the symbols.
        """
        self._rg_type = self._resource_grid.build_type_grid()
        self._pilot_ind = tf.where(self._rg_type==1)
        self._data_ind = tf.where(self._rg_type==0)

    def call(self, inputs):
        # Map pilots on empty resource grid. The template is rebuilt on
        # every call so that changes of the pilot values are picked up.
        # NOTE(review): assumes the flattened pilots are ordered like the
        # row-major positions returned by tf.where(rg_type==1).
        pilots = flatten_last_dims(self._resource_grid.pilot_pattern.pilots, 3)
        template = tf.scatter_nd(self._pilot_ind,
                                 pilots,
                                 self._rg_type.shape)
        template = tf.expand_dims(template, -1)

        # Broadcast the resource grid template to batch_size
        batch_size = tf.shape(inputs)[0]
        new_shape = tf.concat([tf.shape(template)[:-1], [batch_size]], 0)
        template = tf.broadcast_to(template, new_shape)

        # Flatten the inputs and put batch_dim last for scatter update
        inputs = tf.transpose(flatten_last_dims(inputs, 3))
        rg = tf.tensor_scatter_nd_update(template, self._data_ind, inputs)
        # Move batch_dim back to the front
        rg = tf.transpose(rg, [4, 0, 1, 2, 3])

        return rg

class ResourceGridDemapper(Layer):
    # pylint: disable=line-too-long
    r"""ResourceGridDemapper(resource_grid, stream_management, dtype=tf.complex64, **kwargs)

    Extracts data-carrying resource elements from a resource grid.

    This layer takes as input an OFDM :class:`~sionna.ofdm.ResourceGrid` and
    extracts the data-carrying resource elements. In other words, it implements
    the reverse operation of :class:`~sionna.ofdm.ResourceGridMapper`.

    Parameters
    ----------
    resource_grid : ResourceGrid
        An instance of :class:`~sionna.ofdm.ResourceGrid`.

    stream_management : StreamManagement
        An instance of :class:`~sionna.mimo.StreamManagement`.

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    : [batch_size, num_rx, num_streams_per_rx, num_ofdm_symbols, fft_size, data_dim]
        The full OFDM resource grid in the frequency domain.
        The last dimension `data_dim` is optional. If `data_dim`
        is used, it refers to the dimensionality of the data that should be
        demapped to individual streams. An example would be LLRs.

    Output
    ------
    : [batch_size, num_rx, num_streams_per_rx, num_data_symbols, data_dim]
        The data that were mapped into the resource grid.
        The last dimension `data_dim` is only returned if it was used for the
        input and is larger than one (a trailing dimension of size one is
        always squeezed).
    """
    def __init__(self,
                 resource_grid,
                 stream_management,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        self._stream_management = stream_management
        self._resource_grid = resource_grid

        # Precompute indices to extract data symbols. Data positions have
        # mask value 0, so an ascending sort puts them first. A stable sort
        # is required so that the data positions keep their row-major order,
        # matching the order in which ResourceGridMapper scatters them.
        mask = resource_grid.pilot_pattern.mask
        num_data_symbols = resource_grid.pilot_pattern.num_data_symbols
        data_ind = tf.argsort(flatten_last_dims(mask),
                              direction="ASCENDING",
                              stable=True)
        self._data_ind = data_ind[...,:num_data_symbols]

    def call(self, y): # pylint: disable=arguments-renamed

        # y has shape
        # [batch_size, num_rx, num_streams_per_rx, num_ofdm_symbols,...
        # ..., fft_size, data_dim]

        # If data_dim is not provided, add a dummy dimension
        if len(y.shape)==5:
            y = tf.expand_dims(y, -1)

        # Remove nulled subcarriers from y (guards, dc). New shape:
        # [batch_size, num_rx, num_streams_per_rx, ...
        #  ..., num_ofdm_symbols, num_effective_subcarriers, data dim]
        y = tf.gather(y, self._resource_grid.effective_subcarrier_ind, axis=-2)

        # Transpose tensor to shape
        # [num_rx, num_streams_per_rx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers, data_dim, batch_size]
        y = tf.transpose(y, [1, 2, 3, 4, 5, 0])

        # Merge num_rx and num_streams_per_rx
        # [num_rx * num_streams_per_rx, num_ofdm_symbols,...
        #  ...,num_effective_subcarriers, data_dim, batch_size]
        y = flatten_dims(y, 2, 0)

        # Put first dimension into the right ordering
        stream_ind = self._stream_management.stream_ind
        y = tf.gather(y, stream_ind, axis=0)

        # Reshape first dimensions to [num_tx, num_streams] so that
        # we can compare to the way the streams were created.
        # [num_tx, num_streams, num_ofdm_symbols, num_effective_subcarriers,...
        #  ..., data_dim, batch_size]
        num_streams = self._stream_management.num_streams_per_tx
        num_tx = self._stream_management.num_tx
        y = split_dim(y, [num_tx, num_streams], 0)

        # Flatten resource grid dimensions
        # [num_tx, num_streams, num_ofdm_symbols*num_effective_subcarriers,...
        #  ..., data_dim, batch_size]
        y = flatten_dims(y, 2, 2)

        # Gather data symbols
        # [num_tx, num_streams, num_data_symbols, data_dim, batch_size]
        y = tf.gather(y, self._data_ind, batch_dims=2, axis=2)

        # Put batch_dim first
        # [batch_size, num_tx, num_streams, num_data_symbols, data_dim]
        y = tf.transpose(y, [4, 0, 1, 2, 3])

        # Squeeze data_dim. NOTE: this also removes an explicitly provided
        # data_dim of size one.
        if y.shape[-1]==1:
            y = tf.squeeze(y, -1)

        return y

class RemoveNulledSubcarriers(Layer):
    # pylint: disable=line-too-long
    r"""RemoveNulledSubcarriers(resource_grid, **kwargs)

    Removes nulled guard and/or DC subcarriers from a resource grid.

    Parameters
    ----------
    resource_grid : ResourceGrid
        An instance of :class:`~sionna.ofdm.ResourceGrid`.

    Input
    -----
    : [batch_size, num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size], tf.complex64
        Full resource grid.

    Output
    ------
    : [batch_size, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex64
        Resource grid without nulled subcarriers.
    """
    def __init__(self, resource_grid, **kwargs):
        # Initialize the Keras base class before assigning attributes so
        # that Keras attribute tracking is set up first.
        super().__init__(**kwargs)
        self._sc_ind = resource_grid.effective_subcarrier_ind

    def call(self, inputs):
        # Keep only the effective subcarriers along the last axis
        return tf.gather(inputs, self._sc_ind, axis=-1)