equalization.py (20362B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Class definition and functions related to OFDM channel equalization"""

import tensorflow as tf
from tensorflow.keras.layers import Layer
import sionna
from sionna.utils import flatten_dims, split_dim, flatten_last_dims, expand_to_rank
from sionna.mimo import lmmse_equalizer, zf_equalizer, mf_equalizer
from sionna.ofdm import RemoveNulledSubcarriers


class OFDMEqualizer(Layer):
    # pylint: disable=line-too-long
    r"""OFDMEqualizer(equalizer, resource_grid, stream_management, dtype=tf.complex64, **kwargs)

    Layer that wraps a MIMO equalizer for use with the OFDM waveform.

    The parameter ``equalizer`` is a callable (e.g., a function) that
    implements a MIMO equalization algorithm for arbitrary batch dimensions.

    This class pre-processes the received resource grid ``y`` and channel
    estimate ``h_hat``, and computes for each receiver the
    noise-plus-interference covariance matrix according to the OFDM and stream
    configuration provided by the ``resource_grid`` and
    ``stream_management``, which also accounts for the channel
    estimation error variance ``err_var``. These quantities serve as input
    to the equalization algorithm that is implemented by the callable ``equalizer``.
    This layer computes soft-symbol estimates together with effective noise
    variances for all streams which can, e.g., be used by a
    :class:`~sionna.mapping.Demapper` to obtain LLRs.

    Note
    -----
    The callable ``equalizer`` must take three inputs:

    * **y** ([...,num_rx_ant], tf.complex) -- 1+D tensor containing the received signals.
    * **h** ([...,num_rx_ant,num_streams_per_rx], tf.complex) -- 2+D tensor containing the channel matrices.
    * **s** ([...,num_rx_ant,num_rx_ant], tf.complex) -- 2+D tensor containing the noise-plus-interference covariance matrices.

    It must generate two outputs:

    * **x_hat** ([...,num_streams_per_rx], tf.complex) -- 1+D tensor representing the estimated symbol vectors.
    * **no_eff** (tf.float) -- Tensor of the same shape as ``x_hat`` containing the effective noise variance estimates.

    Parameters
    ----------
    equalizer : Callable
        Callable object (e.g., a function) that implements a MIMO equalization
        algorithm for arbitrary batch dimensions

    resource_grid : ResourceGrid
        Instance of :class:`~sionna.ofdm.ResourceGrid`

    stream_management : StreamManagement
        Instance of :class:`~sionna.mimo.StreamManagement`

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    (y, h_hat, err_var, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
        Received OFDM resource grid after cyclic prefix removal and FFT

    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates for all streams from all transmitters

    err_var : [Broadcastable to shape of ``h_hat``], tf.float
        Variance of the channel estimation error

    no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
        Variance of the AWGN

    Output
    ------
    x_hat : [batch_size, num_tx, num_streams, num_data_symbols], tf.complex
        Estimated symbols

    no_eff : [batch_size, num_tx, num_streams, num_data_symbols], tf.float
        Effective noise variance for each estimated symbol
    """
    def __init__(self,
                 equalizer,
                 resource_grid,
                 stream_management,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        assert callable(equalizer)
        assert isinstance(resource_grid, sionna.ofdm.ResourceGrid)
        assert isinstance(stream_management, sionna.mimo.StreamManagement)
        self._equalizer = equalizer
        self._resource_grid = resource_grid
        self._stream_management = stream_management
        self._removed_nulled_scs = RemoveNulledSubcarriers(self._resource_grid)

        # Precompute indices to extract data symbols.
        # The flattened pilot mask is sorted in ascending order and only the
        # first ``num_data_symbols`` indices are kept (presumably the mask is
        # nonzero at pilot positions -- confirm against PilotPattern).
        # The sort is requested to be stable explicitly, because the gather
        # in ``call`` relies on equal mask entries keeping their original
        # resource-grid order.
        mask = resource_grid.pilot_pattern.mask
        num_data_symbols = resource_grid.pilot_pattern.num_data_symbols
        data_ind = tf.argsort(flatten_last_dims(mask),
                              direction="ASCENDING",
                              stable=True)
        self._data_ind = data_ind[...,:num_data_symbols]

    def call(self, inputs):
        """Equalize a received resource grid.

        ``inputs`` is the tuple ``(y, h_hat, err_var, no)`` described in the
        class docstring. Returns the tuple ``(x_hat, no_eff)``.
        """

        y, h_hat, err_var, no = inputs
        # y has shape:
        # [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size]

        # h_hat has shape:
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams,...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]

        # err_var has a shape that is broadcastable to h_hat

        # no has shape [batch_size, num_rx, num_rx_ant]
        # or just the first n dimensions of this

        # Remove nulled subcarriers from y (guards, dc). New shape:
        # [batch_size, num_rx, num_rx_ant, ...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]
        y_eff = self._removed_nulled_scs(y)

        ####################################################
        ### Prepare the observation y for MIMO detection ###
        ####################################################
        # Transpose y_eff to put num_rx_ant last. New shape:
        # [batch_size, num_rx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers, num_rx_ant]
        y_dt = tf.transpose(y_eff, [0, 1, 3, 4, 2])
        y_dt = tf.cast(y_dt, self._dtype)

        ##############################################
        ### Prepare the err_var for MIMO detection ###
        ##############################################
        # New shape is:
        # [batch_size, num_rx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers, num_rx_ant, num_tx*num_streams]
        err_var_dt = tf.broadcast_to(err_var, tf.shape(h_hat))
        err_var_dt = tf.transpose(err_var_dt, [0, 1, 5, 6, 2, 3, 4])
        err_var_dt = flatten_last_dims(err_var_dt, 2)
        err_var_dt = tf.cast(err_var_dt, self._dtype)

        ###############################
        ### Construct MIMO channels ###
        ###############################

        # Reshape h_hat for the construction of desired/interfering channels:
        # [num_rx, num_tx, num_streams_per_tx, batch_size, num_rx_ant, ...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]
        perm = [1, 3, 4, 0, 2, 5, 6]
        h_dt = tf.transpose(h_hat, perm)

        # Flatten first three dimensions:
        # [num_rx*num_tx*num_streams_per_tx, batch_size, num_rx_ant, ...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]
        h_dt = flatten_dims(h_dt, 3, 0)

        # Gather desired and undesired channels
        ind_desired = self._stream_management.detection_desired_ind
        ind_undesired = self._stream_management.detection_undesired_ind
        h_dt_desired = tf.gather(h_dt, ind_desired, axis=0)
        h_dt_undesired = tf.gather(h_dt, ind_undesired, axis=0)

        # Split first dimension to separate RX and TX:
        # [num_rx, num_streams_per_rx, batch_size, num_rx_ant, ...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]
        h_dt_desired = split_dim(h_dt_desired,
                                 [self._stream_management.num_rx,
                                  self._stream_management.num_streams_per_rx],
                                 0)
        h_dt_undesired = split_dim(h_dt_undesired,
                                   [self._stream_management.num_rx, -1], 0)

        # Permute dims to
        # [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers,..
        #  ..., num_rx_ant, num_streams_per_rx(num_interfering_streams_per_rx)]
        perm = [2, 0, 4, 5, 3, 1]
        h_dt_desired = tf.transpose(h_dt_desired, perm)
        h_dt_desired = tf.cast(h_dt_desired, self._dtype)
        h_dt_undesired = tf.transpose(h_dt_undesired, perm)
        # Cast the interfering channels as well: they are combined below with
        # s_no and s_csi, which are already in the layer dtype. Without this
        # cast, an h_hat of different precision than the layer would lead to
        # a dtype mismatch when the covariance terms are summed.
        h_dt_undesired = tf.cast(h_dt_undesired, self._dtype)

        ##################################
        ### Prepare the noise variance ###
        ##################################
        # no is first broadcast to [batch_size, num_rx, num_rx_ant]
        # then the rank is expanded to that of y
        # then it is transposed like y to the final shape
        # [batch_size, num_rx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers, num_rx_ant]
        no_dt = expand_to_rank(no, 3, -1)
        no_dt = tf.broadcast_to(no_dt, tf.shape(y)[:3])
        no_dt = expand_to_rank(no_dt, tf.rank(y), -1)
        no_dt = tf.transpose(no_dt, [0,1,3,4,2])
        no_dt = tf.cast(no_dt, self._dtype)

        ##################################################
        ### Compute the interference covariance matrix ###
        ##################################################
        # Covariance of undesired transmitters
        s_inf = tf.matmul(h_dt_undesired, h_dt_undesired, adjoint_b=True)

        # Thermal noise
        s_no = tf.linalg.diag(no_dt)

        # Channel estimation errors
        # As we have only error variance information for each element,
        # we simply sum them across transmitters and build a
        # diagonal covariance matrix from this
        s_csi = tf.linalg.diag(tf.reduce_sum(err_var_dt, -1))

        # Final covariance matrix
        s = s_inf + s_no + s_csi
        s = tf.cast(s, self._dtype)

        ############################################################
        ### Compute symbol estimate and effective noise variance ###
        ############################################################
        # [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers,...
        #  ..., num_stream_per_rx]
        x_hat, no_eff = self._equalizer(y_dt, h_dt_desired, s)

        ################################################
        ### Extract data symbols for all detected TX ###
        ################################################
        # Transpose tensor to shape
        # [num_rx, num_streams_per_rx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers, batch_size]
        x_hat = tf.transpose(x_hat, [1, 4, 2, 3, 0])
        no_eff = tf.transpose(no_eff, [1, 4, 2, 3, 0])

        # Merge num_rx and num_streams_per_rx
        # [num_rx * num_streams_per_rx, num_ofdm_symbols,...
        #  ...,num_effective_subcarriers, batch_size]
        x_hat = flatten_dims(x_hat, 2, 0)
        no_eff = flatten_dims(no_eff, 2, 0)

        # Put first dimension into the right ordering
        stream_ind = self._stream_management.stream_ind
        x_hat = tf.gather(x_hat, stream_ind, axis=0)
        no_eff = tf.gather(no_eff, stream_ind, axis=0)

        # Reshape first dimensions to [num_tx, num_streams] so that
        # we can compare to the way the streams were created.
        # [num_tx, num_streams, num_ofdm_symbols, num_effective_subcarriers,...
        #  ..., batch_size]
        num_streams = self._stream_management.num_streams_per_tx
        num_tx = self._stream_management.num_tx
        x_hat = split_dim(x_hat, [num_tx, num_streams], 0)
        no_eff = split_dim(no_eff, [num_tx, num_streams], 0)

        # Flatten resource grid dimensions
        # [num_tx, num_streams, num_ofdm_symbols*num_effective_subcarriers,...
        #  ..., batch_size]
        x_hat = flatten_dims(x_hat, 2, 2)
        no_eff = flatten_dims(no_eff, 2, 2)

        # Broadcast no_eff to the shape of x_hat
        no_eff = tf.broadcast_to(no_eff, tf.shape(x_hat))

        # Gather data symbols
        # [num_tx, num_streams, num_data_symbols, batch_size]
        x_hat = tf.gather(x_hat, self._data_ind, batch_dims=2, axis=2)
        no_eff = tf.gather(no_eff, self._data_ind, batch_dims=2, axis=2)

        # Put batch_dim first
        # [batch_size, num_tx, num_streams, num_data_symbols]
        x_hat = tf.transpose(x_hat, [3, 0, 1, 2])
        no_eff = tf.transpose(no_eff, [3, 0, 1, 2])

        return (x_hat, no_eff)


class LMMSEEqualizer(OFDMEqualizer):
    # pylint: disable=line-too-long
    """LMMSEEqualizer(resource_grid, stream_management, whiten_interference=True, dtype=tf.complex64, **kwargs)

    LMMSE equalization for OFDM MIMO transmissions.

    This layer computes linear minimum mean squared error (LMMSE) equalization
    for OFDM MIMO transmissions. The OFDM and stream configuration are provided
    by a :class:`~sionna.ofdm.ResourceGrid` and
    :class:`~sionna.mimo.StreamManagement` instance, respectively. The
    detection algorithm is the :meth:`~sionna.mimo.lmmse_equalizer`. The layer
    computes soft-symbol estimates together with effective noise variances
    for all streams which can, e.g., be used by a
    :class:`~sionna.mapping.Demapper` to obtain LLRs.

    Parameters
    ----------
    resource_grid : ResourceGrid
        Instance of :class:`~sionna.ofdm.ResourceGrid`

    stream_management : StreamManagement
        Instance of :class:`~sionna.mimo.StreamManagement`

    whiten_interference : bool
        If `True` (default), the interference is first whitened before equalization.
        In this case, an alternative expression for the receive filter is used which
        can be numerically more stable.

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    (y, h_hat, err_var, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
        Received OFDM resource grid after cyclic prefix removal and FFT

    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates for all streams from all transmitters

    err_var : [Broadcastable to shape of ``h_hat``], tf.float
        Variance of the channel estimation error

    no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
        Variance of the AWGN

    Output
    ------
    x_hat : [batch_size, num_tx, num_streams, num_data_symbols], tf.complex
        Estimated symbols

    no_eff : [batch_size, num_tx, num_streams, num_data_symbols], tf.float
        Effective noise variance for each estimated symbol

    Note
    ----
    If you want to use this layer in Graph mode with XLA, i.e., within
    a function that is decorated with ``@tf.function(jit_compile=True)``,
    you must set ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """
    def __init__(self,
                 resource_grid,
                 stream_management,
                 whiten_interference=True,
                 dtype=tf.complex64,
                 **kwargs):

        # Bind ``whiten_interference`` so that the callable handed to the
        # base class has the three-argument (y, h, s) signature it expects.
        def equalizer(y, h, s):
            return lmmse_equalizer(y, h, s, whiten_interference)

        super().__init__(equalizer=equalizer,
                         resource_grid=resource_grid,
                         stream_management=stream_management,
                         dtype=dtype, **kwargs)


class ZFEqualizer(OFDMEqualizer):
    # pylint: disable=line-too-long
    """ZFEqualizer(resource_grid, stream_management, dtype=tf.complex64, **kwargs)

    ZF equalization for OFDM MIMO transmissions.

    This layer computes zero-forcing (ZF) equalization
    for OFDM MIMO transmissions. The OFDM and stream configuration are provided
    by a :class:`~sionna.ofdm.ResourceGrid` and
    :class:`~sionna.mimo.StreamManagement` instance, respectively. The
    detection algorithm is the :meth:`~sionna.mimo.zf_equalizer`. The layer
    computes soft-symbol estimates together with effective noise variances
    for all streams which can, e.g., be used by a
    :class:`~sionna.mapping.Demapper` to obtain LLRs.

    Parameters
    ----------
    resource_grid : ResourceGrid
        An instance of :class:`~sionna.ofdm.ResourceGrid`.

    stream_management : StreamManagement
        An instance of :class:`~sionna.mimo.StreamManagement`.

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    (y, h_hat, err_var, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
        Received OFDM resource grid after cyclic prefix removal and FFT

    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates for all streams from all transmitters

    err_var : [Broadcastable to shape of ``h_hat``], tf.float
        Variance of the channel estimation error

    no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
        Variance of the AWGN

    Output
    ------
    x_hat : [batch_size, num_tx, num_streams, num_data_symbols], tf.complex
        Estimated symbols

    no_eff : [batch_size, num_tx, num_streams, num_data_symbols], tf.float
        Effective noise variance for each estimated symbol

    Note
    ----
    If you want to use this layer in Graph mode with XLA, i.e., within
    a function that is decorated with ``@tf.function(jit_compile=True)``,
    you must set ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """
    def __init__(self,
                 resource_grid,
                 stream_management,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(equalizer=zf_equalizer,
                         resource_grid=resource_grid,
                         stream_management=stream_management,
                         dtype=dtype, **kwargs)


class MFEqualizer(OFDMEqualizer):
    # pylint: disable=line-too-long
    """MFEqualizer(resource_grid, stream_management, dtype=tf.complex64, **kwargs)

    MF equalization for OFDM MIMO transmissions.

    This layer computes matched filter (MF) equalization
    for OFDM MIMO transmissions. The OFDM and stream configuration are provided
    by a :class:`~sionna.ofdm.ResourceGrid` and
    :class:`~sionna.mimo.StreamManagement` instance, respectively. The
    detection algorithm is the :meth:`~sionna.mimo.mf_equalizer`. The layer
    computes soft-symbol estimates together with effective noise variances
    for all streams which can, e.g., be used by a
    :class:`~sionna.mapping.Demapper` to obtain LLRs.

    Parameters
    ----------
    resource_grid : ResourceGrid
        An instance of :class:`~sionna.ofdm.ResourceGrid`.

    stream_management : StreamManagement
        An instance of :class:`~sionna.mimo.StreamManagement`.

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    (y, h_hat, err_var, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
        Received OFDM resource grid after cyclic prefix removal and FFT

    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates for all streams from all transmitters

    err_var : [Broadcastable to shape of ``h_hat``], tf.float
        Variance of the channel estimation error

    no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
        Variance of the AWGN

    Output
    ------
    x_hat : [batch_size, num_tx, num_streams, num_data_symbols], tf.complex
        Estimated symbols

    no_eff : [batch_size, num_tx, num_streams, num_data_symbols], tf.float
        Effective noise variance for each estimated symbol

    Note
    ----
    If you want to use this layer in Graph mode with XLA, i.e., within
    a function that is decorated with ``@tf.function(jit_compile=True)``,
    you must set ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """
    def __init__(self,
                 resource_grid,
                 stream_management,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(equalizer=mf_equalizer,
                         resource_grid=resource_grid,
                         stream_management=stream_management,
                         dtype=dtype, **kwargs)