detection.py (55944B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Class definition and functions related to OFDM channel equalization"""

import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna.utils import flatten_dims, split_dim, flatten_last_dims, expand_to_rank
from sionna.ofdm import RemoveNulledSubcarriers
from sionna.mimo import MaximumLikelihoodDetectorWithPrior as MaximumLikelihoodDetectorWithPrior_
from sionna.mimo import MaximumLikelihoodDetector as MaximumLikelihoodDetector_
from sionna.mimo import LinearDetector as LinearDetector_
from sionna.mimo import KBestDetector as KBestDetector_
from sionna.mimo import EPDetector as EPDetector_
from sionna.mimo import MMSEPICDetector as MMSEPICDetector_
from sionna.mapping import Constellation


class OFDMDetector(Layer):
    # pylint: disable=line-too-long
    r"""OFDMDetector(detector, output, resource_grid, stream_management, dtype=tf.complex64, **kwargs)

    Layer that wraps a MIMO detector for use with the OFDM waveform.

    The parameter ``detector`` is a callable (e.g., a function) that
    implements a MIMO detection algorithm for arbitrary batch dimensions.

    This class pre-processes the received resource grid ``y`` and channel
    estimate ``h_hat``, and computes for each receiver the
    noise-plus-interference covariance matrix according to the OFDM and stream
    configuration provided by the ``resource_grid`` and
    ``stream_management``, which also accounts for the channel
    estimation error variance ``err_var``. These quantities serve as input to the detection
    algorithm that is implemented by ``detector``.
    Both detection of symbols or bits with either soft- or hard-decisions are supported.

    Note
    -----
    The callable ``detector`` must take as input a tuple :math:`(\mathbf{y}, \mathbf{h}, \mathbf{s})` such that:

    * **y** ([...,num_rx_ant], tf.complex) -- 1+D tensor containing the received signals.
    * **h** ([...,num_rx_ant,num_streams_per_rx], tf.complex) -- 2+D tensor containing the channel matrices.
    * **s** ([...,num_rx_ant,num_rx_ant], tf.complex) -- 2+D tensor containing the noise-plus-interference covariance matrices.

    It must generate one of the following outputs depending on the value of ``output``:

    * **b_hat** ([..., num_streams_per_rx, num_bits_per_symbol], tf.float) -- LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`.
    * **x_hat** ([..., num_streams_per_rx, num_points], tf.float) or ([..., num_streams_per_rx], tf.int) -- Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`. Hard-decisions correspond to the symbol indices.

    Parameters
    ----------
    detector : Callable
        Callable object (e.g., a function) that implements a MIMO detection
        algorithm for arbitrary batch dimensions. Either one of the existing detectors, e.g.,
        :class:`~sionna.mimo.LinearDetector`, :class:`~sionna.mimo.MaximumLikelihoodDetector`, or
        :class:`~sionna.mimo.KBestDetector` can be used, or a custom detector
        callable provided that has the same input/output specification.

    output : One of ["bit", "symbol"], str
        Type of output, either bits or symbols

    resource_grid : ResourceGrid
        Instance of :class:`~sionna.ofdm.ResourceGrid`

    stream_management : StreamManagement
        Instance of :class:`~sionna.mimo.StreamManagement`

    dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
        The dtype of `y`. Defaults to tf.complex64.
        The output dtype is the corresponding real dtype (tf.float32 or tf.float64).

    Input
    ------
    (y, h_hat, err_var, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
        Received OFDM resource grid after cyclic prefix removal and FFT

    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates for all streams from all transmitters

    err_var : [Broadcastable to shape of ``h_hat``], tf.float
        Variance of the channel estimation error

    no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
        Variance of the AWGN

    Output
    ------
    One of:

    : [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol], tf.float
        LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`

    : [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float or [batch_size, num_tx, num_streams, num_data_symbols], tf.int
        Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`.
        Hard-decisions correspond to the symbol indices.
    """
    def __init__(self,
                 detector,
                 output,
                 resource_grid,
                 stream_management,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        self._detector = detector
        self._resource_grid = resource_grid
        self._stream_management = stream_management
        self._removed_nulled_scs = RemoveNulledSubcarriers(self._resource_grid)
        self._output = output

        # Precompute indices to extract data symbols.
        # Sorting the flattened pilot mask in ascending order moves the
        # data-carrying resource elements (mask == 0) to the front. The sort
        # must be stable so that these indices stay in resource grid order;
        # this is requested explicitly rather than relying on the current
        # default implementation.
        mask = resource_grid.pilot_pattern.mask
        num_data_symbols = resource_grid.pilot_pattern.num_data_symbols
        data_ind = tf.argsort(flatten_last_dims(mask),
                              direction="ASCENDING",
                              stable=True)
        self._data_ind = data_ind[...,:num_data_symbols]

    def _preprocess_inputs(self, y, h_hat, err_var, no):
        """Pre-process the received signal and compute the
        noise-plus-interference covariance matrix

        The inputs have the shapes documented for the layer inputs.
        Returns the tuple ``(y_dt, h_dt_desired, s)`` that is passed to the
        wrapped MIMO detector: the observations with ``num_rx_ant`` as last
        dimension, the channel matrices of the desired streams, and the
        noise-plus-interference covariance matrices, all cast to the dtype
        of the layer.
        """

        # Remove nulled subcarriers from y (guards, dc). New shape:
        # [batch_size, num_rx, num_rx_ant, ...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]
        y_eff = self._removed_nulled_scs(y)

        ####################################################
        ### Prepare the observation y for MIMO detection ###
        ####################################################
        # Transpose y_eff to put num_rx_ant last. New shape:
        # [batch_size, num_rx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers, num_rx_ant]
        y_dt = tf.transpose(y_eff, [0, 1, 3, 4, 2])
        y_dt = tf.cast(y_dt, self._dtype)

        ##############################################
        ### Prepare the err_var for MIMO detection ###
        ##############################################
        # New shape is:
        # [batch_size, num_rx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers, num_rx_ant, num_tx*num_streams]
        err_var_dt = tf.broadcast_to(err_var, tf.shape(h_hat))
        err_var_dt = tf.transpose(err_var_dt, [0, 1, 5, 6, 2, 3, 4])
        err_var_dt = flatten_last_dims(err_var_dt, 2)
        err_var_dt = tf.cast(err_var_dt, self._dtype)

        ###############################
        ### Construct MIMO channels ###
        ###############################

        # Reshape h_hat for the construction of desired/interfering channels:
        # [num_rx, num_tx, num_streams_per_tx, batch_size, num_rx_ant, ...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]
        perm = [1, 3, 4, 0, 2, 5, 6]
        h_dt = tf.transpose(h_hat, perm)

        # Flatten first three dimensions:
        # [num_rx*num_tx*num_streams_per_tx, batch_size, num_rx_ant, ...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]
        h_dt = flatten_dims(h_dt, 3, 0)

        # Gather desired and undesired channels
        ind_desired = self._stream_management.detection_desired_ind
        ind_undesired = self._stream_management.detection_undesired_ind
        h_dt_desired = tf.gather(h_dt, ind_desired, axis=0)
        h_dt_undesired = tf.gather(h_dt, ind_undesired, axis=0)

        # Split first dimension to separate RX and TX:
        # [num_rx, num_streams_per_rx, batch_size, num_rx_ant, ...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]
        h_dt_desired = split_dim(h_dt_desired,
                                 [self._stream_management.num_rx,
                                  self._stream_management.num_streams_per_rx],
                                 0)
        h_dt_undesired = split_dim(h_dt_undesired,
                                   [self._stream_management.num_rx, -1], 0)

        # Permute dims to
        # [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers,..
        #  ..., num_rx_ant, num_streams_per_rx(num_interfering_streams_per_rx)]
        perm = [2, 0, 4, 5, 3, 1]
        h_dt_desired = tf.transpose(h_dt_desired, perm)
        h_dt_desired = tf.cast(h_dt_desired, self._dtype)
        h_dt_undesired = tf.transpose(h_dt_undesired, perm)

        ##################################
        ### Prepare the noise variance ###
        ##################################
        # no is first broadcast to [batch_size, num_rx, num_rx_ant]
        # then the rank is expanded to that of y
        # then it is transposed like y to the final shape
        # [batch_size, num_rx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers, num_rx_ant]
        no_dt = expand_to_rank(no, 3, -1)
        no_dt = tf.broadcast_to(no_dt, tf.shape(y)[:3])
        no_dt = expand_to_rank(no_dt, tf.rank(y), -1)
        no_dt = tf.transpose(no_dt, [0,1,3,4,2])
        no_dt = tf.cast(no_dt, self._dtype)

        ##################################################
        ### Compute the interference covariance matrix ###
        ##################################################
        # Covariance of undesired transmitters
        s_inf = tf.matmul(h_dt_undesired, h_dt_undesired, adjoint_b=True)

        # Thermal noise
        s_no = tf.linalg.diag(no_dt)

        # Channel estimation errors
        # As we have only error variance information for each element,
        # we simply sum them across transmitters and build a
        # diagonal covariance matrix from this
        s_csi = tf.linalg.diag(tf.reduce_sum(err_var_dt, -1))

        # Final covariance matrix
        s = s_inf + s_no + s_csi
        s = tf.cast(s, self._dtype)

        return y_dt, h_dt_desired, s

    def _extract_datasymbols(self, z):
        """Extract data symbols for all detected TX

        ``z`` is the output of the wrapped detector. The returned tensor
        only contains the data-carrying resource elements, with the streams
        arranged as [batch_size, num_tx, num_streams, ...].
        """

        # If output is symbols with hard decision, the rank is 5 and not 6 as
        # for other cases. The tensor rank is therefore expanded with one extra
        # dimension, which is removed later.
        rank_expanded = len(z.shape) < 6
        z = expand_to_rank(z, 6, -1)

        # Transpose tensor to shape
        # [num_rx, num_streams_per_rx, num_ofdm_symbols,
        #  num_effective_subcarriers, num_bits_per_symbol or num_points,
        #  batch_size]
        z = tf.transpose(z, [1, 4, 2, 3, 5, 0])

        # Merge num_rx and num_streams_per_rx
        # [num_rx * num_streams_per_rx, num_ofdm_symbols,
        #  num_effective_subcarriers, num_bits_per_symbol or num_points,
        #  batch_size]
        z = flatten_dims(z, 2, 0)

        # Put first dimension into the right ordering
        stream_ind = self._stream_management.stream_ind
        z = tf.gather(z, stream_ind, axis=0)

        # Reshape first dimensions to [num_tx, num_streams] so that
        # we can compare to the way the streams were created.
        # [num_tx, num_streams, num_ofdm_symbols, num_effective_subcarriers,
        #  num_bits_per_symbol or num_points, batch_size]
        num_streams = self._stream_management.num_streams_per_tx
        num_tx = self._stream_management.num_tx
        z = split_dim(z, [num_tx, num_streams], 0)

        # Flatten resource grid dimensions
        # [num_tx, num_streams, num_ofdm_symbols*num_effective_subcarrier,
        #  num_bits_per_symbol or num_points, batch_size]
        z = flatten_dims(z, 2, 2)

        # Gather data symbols
        # [num_tx, num_streams, num_data_symbols,
        #  num_bits_per_symbol or num_points, batch_size]
        z = tf.gather(z, self._data_ind, batch_dims=2, axis=2)

        # Put batch_dim first
        # [batch_size, num_tx, num_streams,
        #  num_data_symbols, num_bits_per_symbol or num_points]
        z = tf.transpose(z, [4, 0, 1, 2, 3])

        # Reshape LLRs to
        # [batch_size, num_tx, num_streams,
        #  n = num_data_symbols*num_bits_per_symbol]
        # if output is LLRs on bits
        if self._output == 'bit':
            z = flatten_dims(z, 2, 3)
        # Remove dummy dimension if output is symbols with hard decision
        if rank_expanded:
            z = tf.squeeze(z, axis=-1)

        return z

    def call(self, inputs):
        y, h_hat, err_var, no = inputs
        # y has shape:
        # [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size]

        # h_hat has shape:
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams,...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]

        # err_var has a shape that is broadcastable to h_hat

        # no has shape [batch_size, num_rx, num_rx_ant]
        # or just the first n dimensions of this

        ################################
        ### Pre-process the inputs
        ################################
        y_dt, h_dt_desired, s = self._preprocess_inputs(y, h_hat, err_var, no)

        #################################
        ### Detection
        #################################
        z = self._detector([y_dt, h_dt_desired, s])

        ##############################################
        ### Extract data symbols for all detected TX
        ##############################################
        z = self._extract_datasymbols(z)

        return z
class OFDMDetectorWithPrior(OFDMDetector):
    # pylint: disable=line-too-long
    r"""OFDMDetectorWithPrior(detector, output, resource_grid, stream_management, constellation_type, num_bits_per_symbol, constellation, dtype=tf.complex64, **kwargs)

    Layer that wraps a MIMO detector that assumes prior knowledge of the bits or
    constellation points is available, for use with the OFDM waveform.

    The parameter ``detector`` is a callable (e.g., a function) that
    implements a MIMO detection algorithm with prior for arbitrary batch
    dimensions.

    This class pre-processes the received resource grid ``y``, channel
    estimate ``h_hat``, and the prior information ``prior``, and computes for each receiver the
    noise-plus-interference covariance matrix according to the OFDM and stream
    configuration provided by the ``resource_grid`` and
    ``stream_management``, which also accounts for the channel
    estimation error variance ``err_var``. These quantities serve as input to the detection
    algorithm that is implemented by ``detector``.
    Both detection of symbols or bits with either soft- or hard-decisions are supported.

    Note
    -----
    The callable ``detector`` must take as input a tuple :math:`(\mathbf{y}, \mathbf{h}, \mathbf{prior}, \mathbf{s})` such that:

    * **y** ([...,num_rx_ant], tf.complex) -- 1+D tensor containing the received signals.
    * **h** ([...,num_rx_ant,num_streams_per_rx], tf.complex) -- 2+D tensor containing the channel matrices.
    * **prior** ([...,num_streams_per_rx,num_bits_per_symbol] or [...,num_streams_per_rx,num_points], tf.float) -- Prior for the transmitted signals. If ``output`` equals "bit", then LLRs for the transmitted bits are expected. If ``output`` equals "symbol", then logits for the transmitted constellation points are expected.
    * **s** ([...,num_rx_ant,num_rx_ant], tf.complex) -- 2+D tensor containing the noise-plus-interference covariance matrices.

    It must generate one of the following outputs depending on the value of ``output``:

    * **b_hat** ([..., num_streams_per_rx, num_bits_per_symbol], tf.float) -- LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`.
    * **x_hat** ([..., num_streams_per_rx, num_points], tf.float) or ([..., num_streams_per_rx], tf.int) -- Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`. Hard-decisions correspond to the symbol indices.

    Parameters
    ----------
    detector : Callable
        Callable object (e.g., a function) that implements a MIMO detection
        algorithm with prior for arbitrary batch dimensions. Either the existing detector
        :class:`~sionna.mimo.MaximumLikelihoodDetectorWithPrior` can be used, or a custom detector
        callable provided that has the same input/output specification.

    output : One of ["bit", "symbol"], str
        Type of output, either bits or symbols

    resource_grid : ResourceGrid
        Instance of :class:`~sionna.ofdm.ResourceGrid`

    stream_management : StreamManagement
        Instance of :class:`~sionna.mimo.StreamManagement`

    constellation_type : One of ["qam", "pam", "custom"], str
        For "custom", an instance of :class:`~sionna.mapping.Constellation`
        must be provided.

    num_bits_per_symbol : int
        Number of bits per constellation symbol, e.g., 4 for QAM16.
        Only required for ``constellation_type`` in ["qam", "pam"].

    constellation : Constellation
        Instance of :class:`~sionna.mapping.Constellation` or `None`.
        In the latter case, ``constellation_type``
        and ``num_bits_per_symbol`` must be provided.

    dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
        The dtype of `y`. Defaults to tf.complex64.
        The output dtype is the corresponding real dtype (tf.float32 or tf.float64).

    Input
    ------
    (y, h_hat, prior, err_var, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
        Received OFDM resource grid after cyclic prefix removal and FFT

    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates for all streams from all transmitters

    prior : [batch_size, num_tx, num_streams, num_data_symbols x num_bits_per_symbol] or [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float
        Prior of the transmitted signals.
        If ``output`` equals "bit", LLRs of the transmitted bits are expected.
        If ``output`` equals "symbol", logits of the transmitted constellation points are expected.

    err_var : [Broadcastable to shape of ``h_hat``], tf.float
        Variance of the channel estimation error

    no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
        Variance of the AWGN

    Output
    ------
    One of:

    : [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol], tf.float
        LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`.

    : [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float or [batch_size, num_tx, num_streams, num_data_symbols], tf.int
        Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`.
        Hard-decisions correspond to the symbol indices.
    """
    def __init__(self,
                 detector,
                 output,
                 resource_grid,
                 stream_management,
                 constellation_type=None,
                 num_bits_per_symbol=None,
                 constellation=None,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(detector=detector,
                         output=output,
                         resource_grid=resource_grid,
                         stream_management=stream_management,
                         dtype=dtype,
                         **kwargs)

        # Constellation used to infer the number of bits per symbol
        self._constellation = Constellation.create_or_check_constellation(
            constellation_type,
            num_bits_per_symbol,
            constellation,
            dtype=dtype)

        # Positions of the data-carrying resource elements, used to scatter
        # the priors onto a resource grid. Nulled subcarriers (nulled DC and
        # guard carriers) are stripped first so that the indices refer to the
        # effective subcarriers only.
        type_grid = resource_grid.build_type_grid()
        strip_nulled = RemoveNulledSubcarriers(resource_grid)
        is_data = strip_nulled(type_grid) == 0
        self._data_ind_scatter = tf.where(is_data)

    # Overrides the call() method of the base class `OFDMDetector`
    def call(self, inputs):
        # Shapes of the inputs:
        #   y       : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols,
        #              fft_size]
        #   h_hat   : [batch_size, num_rx, num_rx_ant, num_tx, num_streams,
        #              num_ofdm_symbols, num_effective_subcarriers]
        #   prior   : [batch_size, num_tx, num_streams,
        #              num_data_symbols x num_bits_per_symbol]
        #             or [batch_size, num_tx, num_streams, num_data_symbols,
        #                 num_points]
        #   err_var : broadcastable to the shape of h_hat
        #   no      : [batch_size, num_rx, num_rx_ant] or only the first n
        #             dimensions of this
        y, h_hat, prior, err_var, no = inputs
        rg = self._resource_grid

        # --- Pre-process the inputs ---
        y_dt, h_dt_desired, s = self._preprocess_inputs(y, h_hat, err_var, no)

        # --- Prepare the prior ---
        # Bring the prior to the shape
        # [batch_size, num_tx, num_streams_per_tx, num_data_symbols,
        #  num_bits_per_symbol/num_points]
        if self._output == 'bit':
            bit_dims = [rg.num_data_symbols,
                        self._constellation.num_bits_per_symbol]
            prior = split_dim(prior, bit_dims, 3)

        # All-zero resource grid that receives the priors:
        # [num_tx, num_streams_per_tx, num_ofdm_symbols,
        #  num_effective_subcarriers, num_bits_per_symbol/num_points,
        #  batch_size]
        prior_shape = tf.shape(prior)
        grid_shape = [rg.num_tx,
                      rg.num_streams_per_tx,
                      rg.num_ofdm_symbols,
                      rg.num_effective_subcarriers,
                      prior_shape[-1],
                      prior_shape[0]]
        real_dtype = tf.as_dtype(self._dtype).real_dtype
        template = tf.zeros(grid_shape, real_dtype)

        # Move the batch dimension last:
        # [num_tx, num_streams_per_tx, num_data_symbols,
        #  num_bits_per_symbol/num_points, batch_size]
        prior = tf.transpose(prior, [1, 2, 3, 4, 0])

        # Merge the first three dimensions so that every row matches one
        # entry of the precomputed scatter indices, then write the rows to
        # the data positions of the grid:
        # [num_tx, num_streams_per_tx, num_ofdm_symbols,
        #  num_effective_subcarriers, num_bits_per_symbol/num_points,
        #  batch_size]
        prior = flatten_dims(prior, 3, 0)
        prior = tf.tensor_scatter_nd_update(template,
                                            self._data_ind_scatter,
                                            prior)

        # [batch_size, num_ofdm_symbols, num_effective_subcarriers,
        #  num_tx*num_streams_per_tx, num_bits_per_symbol/num_points]
        prior = tf.transpose(prior, [5, 2, 3, 0, 1, 4])
        prior = flatten_dims(prior, 2, 3)

        # Insert and replicate the receiver dimension:
        # [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers,
        #  num_tx*num_streams_per_tx, num_bits_per_symbol/num_points]
        num_rx = tf.shape(y)[1]
        prior = tf.expand_dims(prior, axis=1)
        prior = tf.tile(prior, [1, num_rx, 1, 1, 1, 1])

        # --- Detection with prior ---
        z = self._detector([y_dt, h_dt_desired, prior, s])

        # --- Extract data symbols for all detected TX ---
        return self._extract_datasymbols(z)
class MaximumLikelihoodDetector(OFDMDetector):
    # pylint: disable=line-too-long
    r"""MaximumLikelihoodDetector(output, demapping_method, resource_grid, stream_management, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, dtype=tf.complex64, **kwargs)

    Maximum-likelihood (ML) detection for OFDM MIMO transmissions.

    This layer implements maximum-likelihood (ML) detection
    for OFDM MIMO transmissions. Both ML detection of symbols or bits with either
    soft- or hard-decisions are supported. The OFDM and stream configuration are provided
    by a :class:`~sionna.ofdm.ResourceGrid` and
    :class:`~sionna.mimo.StreamManagement` instance, respectively. The
    actual detector is an instance of :class:`~sionna.mimo.MaximumLikelihoodDetector`.

    Parameters
    ----------
    output : One of ["bit", "symbol"], str
        Type of output, either bits or symbols. Whether soft- or
        hard-decisions are returned can be configured with the
        ``hard_out`` flag.

    demapping_method : One of ["app", "maxlog"], str
        Demapping method used

    resource_grid : ResourceGrid
        Instance of :class:`~sionna.ofdm.ResourceGrid`

    stream_management : StreamManagement
        Instance of :class:`~sionna.mimo.StreamManagement`

    constellation_type : One of ["qam", "pam", "custom"], str
        For "custom", an instance of :class:`~sionna.mapping.Constellation`
        must be provided.

    num_bits_per_symbol : int
        Number of bits per constellation symbol, e.g., 4 for QAM16.
        Only required for ``constellation_type`` in ["qam", "pam"].

    constellation : Constellation
        Instance of :class:`~sionna.mapping.Constellation` or `None`.
        In the latter case, ``constellation_type``
        and ``num_bits_per_symbol`` must be provided.

    hard_out : bool
        If `True`, the detector computes hard-decided bit values or
        constellation point indices instead of soft-values.
        Defaults to `False`.

    dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
        The dtype of `y`. Defaults to tf.complex64.
        The output dtype is the corresponding real dtype (tf.float32 or tf.float64).

    Input
    ------
    (y, h_hat, err_var, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
        Received OFDM resource grid after cyclic prefix removal and FFT

    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates for all streams from all transmitters

    err_var : [Broadcastable to shape of ``h_hat``], tf.float
        Variance of the channel estimation error

    no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
        Variance of the AWGN

    Output
    ------
    One of:

    : [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol], tf.float
        LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`.

    : [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float or [batch_size, num_tx, num_streams, num_data_symbols], tf.int
        Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`.
        Hard-decisions correspond to the symbol indices.

    Note
    ----
    If you want to use this layer in Graph mode with XLA, i.e., within
    a function that is decorated with ``@tf.function(jit_compile=True)``,
    you must set ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """

    def __init__(self,
                 output,
                 demapping_method,
                 resource_grid,
                 stream_management,
                 constellation_type=None,
                 num_bits_per_symbol=None,
                 constellation=None,
                 hard_out=False,
                 dtype=tf.complex64,
                 **kwargs):

        # Constellation settings handed to the wrapped MIMO detector
        constellation_args = {
            "constellation_type": constellation_type,
            "num_bits_per_symbol": num_bits_per_symbol,
            "constellation": constellation,
        }

        # The wrapped MIMO detector detects all streams of one receiver
        # jointly. Note that the extra keyword arguments are forwarded both
        # to the MIMO detector and to the wrapping OFDM layer.
        ml_detector = MaximumLikelihoodDetector_(
            output=output,
            demapping_method=demapping_method,
            num_streams=stream_management.num_streams_per_rx,
            hard_out=hard_out,
            dtype=dtype,
            **constellation_args,
            **kwargs)

        super().__init__(detector=ml_detector,
                         output=output,
                         resource_grid=resource_grid,
                         stream_management=stream_management,
                         dtype=dtype,
                         **kwargs)
class MaximumLikelihoodDetectorWithPrior(OFDMDetectorWithPrior):
    # pylint: disable=line-too-long
    r"""MaximumLikelihoodDetectorWithPrior(output, demapping_method, resource_grid, stream_management, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, dtype=tf.complex64, **kwargs)

    Maximum-likelihood (ML) detection for OFDM MIMO transmissions, assuming prior
    knowledge of the bits or constellation points is available.

    This layer implements maximum-likelihood (ML) detection
    for OFDM MIMO transmissions assuming prior knowledge on the transmitted data is available.
    Both ML detection of symbols or bits with either
    soft- or hard-decisions are supported. The OFDM and stream configuration are provided
    by a :class:`~sionna.ofdm.ResourceGrid` and
    :class:`~sionna.mimo.StreamManagement` instance, respectively. The
    actual detector is an instance of :class:`~sionna.mimo.MaximumLikelihoodDetectorWithPrior`.

    Parameters
    ----------
    output : One of ["bit", "symbol"], str
        Type of output, either bits or symbols. Whether soft- or
        hard-decisions are returned can be configured with the
        ``hard_out`` flag.

    demapping_method : One of ["app", "maxlog"], str
        Demapping method used

    resource_grid : ResourceGrid
        Instance of :class:`~sionna.ofdm.ResourceGrid`

    stream_management : StreamManagement
        Instance of :class:`~sionna.mimo.StreamManagement`

    constellation_type : One of ["qam", "pam", "custom"], str
        For "custom", an instance of :class:`~sionna.mapping.Constellation`
        must be provided.

    num_bits_per_symbol : int
        Number of bits per constellation symbol, e.g., 4 for QAM16.
        Only required for ``constellation_type`` in ["qam", "pam"].

    constellation : Constellation
        Instance of :class:`~sionna.mapping.Constellation` or `None`.
        In the latter case, ``constellation_type``
        and ``num_bits_per_symbol`` must be provided.

    hard_out : bool
        If `True`, the detector computes hard-decided bit values or
        constellation point indices instead of soft-values.
        Defaults to `False`.

    dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
        The dtype of `y`. Defaults to tf.complex64.
        The output dtype is the corresponding real dtype (tf.float32 or tf.float64).

    Input
    ------
    (y, h_hat, prior, err_var, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
        Received OFDM resource grid after cyclic prefix removal and FFT

    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates for all streams from all transmitters

    prior : [batch_size, num_tx, num_streams, num_data_symbols x num_bits_per_symbol] or [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float
        Prior of the transmitted signals.
        If ``output`` equals "bit", LLRs of the transmitted bits are expected.
        If ``output`` equals "symbol", logits of the transmitted constellation points are expected.

    err_var : [Broadcastable to shape of ``h_hat``], tf.float
        Variance of the channel estimation error

    no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
        Variance of the AWGN

    Output
    ------
    One of:

    : [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol], tf.float
        LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`.

    : [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float or [batch_size, num_tx, num_streams, num_data_symbols], tf.int
        Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`.
        Hard-decisions correspond to the symbol indices.

    Note
    ----
    If you want to use this layer in Graph mode with XLA, i.e., within
    a function that is decorated with ``@tf.function(jit_compile=True)``,
    you must set ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """

    def __init__(self,
                 output,
                 demapping_method,
                 resource_grid,
                 stream_management,
                 constellation_type=None,
                 num_bits_per_symbol=None,
                 constellation=None,
                 hard_out=False,
                 dtype=tf.complex64,
                 **kwargs):

        # Constellation settings are needed twice: by the wrapped MIMO
        # detector and by the OFDM wrapper that reshapes the prior.
        constellation_args = {
            "constellation_type": constellation_type,
            "num_bits_per_symbol": num_bits_per_symbol,
            "constellation": constellation,
        }

        # The wrapped MIMO detector detects all streams of one receiver
        # jointly. Note that the extra keyword arguments are forwarded both
        # to the MIMO detector and to the wrapping OFDM layer.
        ml_detector = MaximumLikelihoodDetectorWithPrior_(
            output=output,
            demapping_method=demapping_method,
            num_streams=stream_management.num_streams_per_rx,
            hard_out=hard_out,
            dtype=dtype,
            **constellation_args,
            **kwargs)

        super().__init__(detector=ml_detector,
                         output=output,
                         resource_grid=resource_grid,
                         stream_management=stream_management,
                         dtype=dtype,
                         **constellation_args,
                         **kwargs)
Either one of the existing equalizers, e.g., 795 :func:`~sionna.mimo.lmmse_equalizer`, :func:`~sionna.mimo.zf_equalizer`, or 796 :func:`~sionna.mimo.mf_equalizer` can be used, or a custom equalizer 797 function provided that has the same input/output specification. 798 799 output : One of ["bit", "symbol"], str 800 Type of output, either bits or symbols. Whether soft- or 801 hard-decisions are returned can be configured with the 802 ``hard_out`` flag. 803 804 demapping_method : One of ["app", "maxlog"], str 805 Demapping method used 806 807 resource_grid : ResourceGrid 808 Instance of :class:`~sionna.ofdm.ResourceGrid` 809 810 stream_management : StreamManagement 811 Instance of :class:`~sionna.mimo.StreamManagement` 812 813 constellation_type : One of ["qam", "pam", "custom"], str 814 For "custom", an instance of :class:`~sionna.mapping.Constellation` 815 must be provided. 816 817 num_bits_per_symbol : int 818 Number of bits per constellation symbol, e.g., 4 for QAM16. 819 Only required for ``constellation_type`` in ["qam", "pam"]. 820 821 constellation : Constellation 822 Instance of :class:`~sionna.mapping.Constellation` or `None`. 823 In the latter case, ``constellation_type`` 824 and ``num_bits_per_symbol`` must be provided. 825 826 hard_out : bool 827 If `True`, the detector computes hard-decided bit values or 828 constellation point indices instead of soft-values. 829 Defaults to `False`. 830 831 dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype) 832 The dtype of `y`. Defaults to tf.complex64. 833 The output dtype is the corresponding real dtype (tf.float32 or tf.float64). 
834 835 Input 836 ------ 837 (y, h_hat, err_var, no) : 838 Tuple: 839 840 y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex 841 Received OFDM resource grid after cyclic prefix removal and FFT 842 843 h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex 844 Channel estimates for all streams from all transmitters 845 846 err_var : [Broadcastable to shape of ``h_hat``], tf.float 847 Variance of the channel estimation error 848 849 no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float 850 Variance of the AWGN 851 852 Output 853 ------ 854 One of: 855 856 : [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol], tf.float 857 LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`. 858 859 : [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float or [batch_size, num_tx, num_streams, num_data_symbols], tf.int 860 Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`. 861 Hard-decisions correspond to the symbol indices. 862 863 Note 864 ---- 865 If you want to use this layer in Graph mode with XLA, i.e., within 866 a function that is decorated with ``@tf.function(jit_compile=True)``, 867 you must set ``sionna.Config.xla_compat=true``. 868 See :py:attr:`~sionna.Config.xla_compat`. 
869 """ 870 871 def __init__(self, 872 equalizer, 873 output, 874 demapping_method, 875 resource_grid, 876 stream_management, 877 constellation_type=None, 878 num_bits_per_symbol=None, 879 constellation=None, 880 hard_out=False, 881 dtype=tf.complex64, 882 **kwargs): 883 884 # Instantiate the linear detector 885 detector = LinearDetector_(equalizer=equalizer, 886 output=output, 887 demapping_method=demapping_method, 888 constellation_type=constellation_type, 889 num_bits_per_symbol=num_bits_per_symbol, 890 constellation=constellation, 891 hard_out=hard_out, 892 dtype=dtype, 893 **kwargs) 894 895 super().__init__(detector=detector, 896 output=output, 897 resource_grid=resource_grid, 898 stream_management=stream_management, 899 dtype=dtype, 900 **kwargs) 901 902 903 class KBestDetector(OFDMDetector): 904 # pylint: disable=line-too-long 905 r"""KBestDetector(output, num_streams, k, resource_grid, stream_management, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, use_real_rep=False, list2llr=None, dtype=tf.complex64, **kwargs) 906 907 This layer wraps the MIMO K-Best detector for use with the OFDM waveform. 908 909 Both detection of symbols or bits with either 910 soft- or hard-decisions are supported. The OFDM and stream configuration are provided 911 by a :class:`~sionna.ofdm.ResourceGrid` and 912 :class:`~sionna.mimo.StreamManagement` instance, respectively. The 913 actual detector is an instance of :class:`~sionna.mimo.KBestDetector`. 914 915 Parameters 916 ---------- 917 output : One of ["bit", "symbol"], str 918 Type of output, either bits or symbols. Whether soft- or 919 hard-decisions are returned can be configured with the 920 ``hard_out`` flag. 921 922 num_streams : tf.int 923 Number of transmitted streams 924 925 k : tf.int 926 Number of paths to keep. Cannot be larger than the 927 number of constellation points to the power of the number of 928 streams. 
929 930 resource_grid : ResourceGrid 931 Instance of :class:`~sionna.ofdm.ResourceGrid` 932 933 stream_management : StreamManagement 934 Instance of :class:`~sionna.mimo.StreamManagement` 935 936 constellation_type : One of ["qam", "pam", "custom"], str 937 For "custom", an instance of :class:`~sionna.mapping.Constellation` 938 must be provided. 939 940 num_bits_per_symbol : int 941 Number of bits per constellation symbol, e.g., 4 for QAM16. 942 Only required for ``constellation_type`` in ["qam", "pam"]. 943 944 constellation : Constellation 945 Instance of :class:`~sionna.mapping.Constellation` or `None`. 946 In the latter case, ``constellation_type`` 947 and ``num_bits_per_symbol`` must be provided. 948 949 hard_out : bool 950 If `True`, the detector computes hard-decided bit values or 951 constellation point indices instead of soft-values. 952 Defaults to `False`. 953 954 use_real_rep : bool 955 If `True`, the detector use the real-valued equivalent representation 956 of the channel. Note that this only works with a QAM constellation. 957 Defaults to `False`. 958 959 list2llr: `None` or instance of :class:`~sionna.mimo.List2LLR` 960 The function to be used to compute LLRs from a list of candidate solutions. 961 If `None`, the default solution :class:`~sionna.mimo.List2LLRSimple` 962 is used. 963 964 dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype) 965 The dtype of `y`. Defaults to tf.complex64. 966 The output dtype is the corresponding real dtype (tf.float32 or tf.float64). 
967 968 Input 969 ------ 970 (y, h_hat, err_var, no) : 971 Tuple: 972 973 y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex 974 Received OFDM resource grid after cyclic prefix removal and FFT 975 976 h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex 977 Channel estimates for all streams from all transmitters 978 979 err_var : [Broadcastable to shape of ``h_hat``], tf.float 980 Variance of the channel estimation error 981 982 no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float 983 Variance of the AWGN 984 985 Output 986 ------ 987 One of: 988 989 : [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol], tf.float 990 LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`. 991 992 : [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float or [batch_size, num_tx, num_streams, num_data_symbols], tf.int 993 Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`. 994 Hard-decisions correspond to the symbol indices. 995 996 Note 997 ---- 998 If you want to use this layer in Graph mode with XLA, i.e., within 999 a function that is decorated with ``@tf.function(jit_compile=True)``, 1000 you must set ``sionna.Config.xla_compat=true``. 1001 See :py:attr:`~sionna.Config.xla_compat`. 
1002 """ 1003 1004 def __init__(self, 1005 output, 1006 num_streams, 1007 k, 1008 resource_grid, 1009 stream_management, 1010 constellation_type=None, 1011 num_bits_per_symbol=None, 1012 constellation=None, 1013 hard_out=False, 1014 use_real_rep=False, 1015 list2llr="default", 1016 dtype=tf.complex64, 1017 **kwargs): 1018 1019 # Instantiate the K-Best detector 1020 detector = KBestDetector_(output=output, 1021 num_streams=num_streams, 1022 k=k, 1023 constellation_type=constellation_type, 1024 num_bits_per_symbol=num_bits_per_symbol, 1025 constellation=constellation, 1026 hard_out=hard_out, 1027 use_real_rep=use_real_rep, 1028 list2llr=list2llr, 1029 dtype=dtype, 1030 **kwargs) 1031 1032 super().__init__(detector=detector, 1033 output=output, 1034 resource_grid=resource_grid, 1035 stream_management=stream_management, 1036 dtype=dtype, 1037 **kwargs) 1038 1039 1040 class EPDetector(OFDMDetector): 1041 # pylint: disable=line-too-long 1042 r"""EPDetector(output, resource_grid, stream_management, num_bits_per_symbol, hard_out=False, l=10, beta=0.9, dtype=tf.complex64, **kwargs) 1043 1044 This layer wraps the MIMO EP detector for use with the OFDM waveform. 1045 1046 Both detection of symbols or bits with either 1047 soft- or hard-decisions are supported. The OFDM and stream configuration are provided 1048 by a :class:`~sionna.ofdm.ResourceGrid` and 1049 :class:`~sionna.mimo.StreamManagement` instance, respectively. The 1050 actual detector is an instance of :class:`~sionna.mimo.EPDetector`. 1051 1052 Parameters 1053 ---------- 1054 output : One of ["bit", "symbol"], str 1055 Type of output, either bits or symbols. Whether soft- or 1056 hard-decisions are returned can be configured with the 1057 ``hard_out`` flag. 
1058 1059 resource_grid : ResourceGrid 1060 Instance of :class:`~sionna.ofdm.ResourceGrid` 1061 1062 stream_management : StreamManagement 1063 Instance of :class:`~sionna.mimo.StreamManagement` 1064 1065 num_bits_per_symbol : int 1066 Number of bits per constellation symbol, e.g., 4 for QAM16. 1067 Only required for ``constellation_type`` in ["qam", "pam"]. 1068 1069 hard_out : bool 1070 If `True`, the detector computes hard-decided bit values or 1071 constellation point indices instead of soft-values. 1072 Defaults to `False`. 1073 1074 l : int 1075 Number of iterations. Defaults to 10. 1076 1077 beta : float 1078 Parameter :math:`\beta\in[0,1]` for update smoothing. 1079 Defaults to 0.9. 1080 1081 dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype) 1082 Precision used for internal computations. Defaults to ``tf.complex64``. 1083 Especially for large MIMO setups, the precision can make a significant 1084 performance difference. 1085 1086 Input 1087 ------ 1088 (y, h_hat, err_var, no) : 1089 Tuple: 1090 1091 y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex 1092 Received OFDM resource grid after cyclic prefix removal and FFT 1093 1094 h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex 1095 Channel estimates for all streams from all transmitters 1096 1097 err_var : [Broadcastable to shape of ``h_hat``], tf.float 1098 Variance of the channel estimation error 1099 1100 no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float 1101 Variance of the AWGN 1102 1103 Output 1104 ------ 1105 One of: 1106 1107 : [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol], tf.float 1108 LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`. 
1109 1110 : [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float or [batch_size, num_tx, num_streams, num_data_symbols], tf.int 1111 Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`. 1112 Hard-decisions correspond to the symbol indices. 1113 1114 Note 1115 ---- 1116 For numerical stability, we do not recommend to use this function in Graph 1117 mode with XLA, i.e., within a function that is decorated with 1118 ``@tf.function(jit_compile=True)``. 1119 However, it is possible to do so by setting 1120 ``sionna.Config.xla_compat=true``. 1121 See :py:attr:`~sionna.Config.xla_compat`. 1122 """ 1123 def __init__(self, 1124 output, 1125 resource_grid, 1126 stream_management, 1127 num_bits_per_symbol=None, 1128 hard_out=False, 1129 l=10, 1130 beta=0.9, 1131 dtype=tf.complex64, 1132 **kwargs): 1133 1134 # Instantiate the EP detector 1135 detector = EPDetector_(output=output, 1136 num_bits_per_symbol=num_bits_per_symbol, 1137 hard_out=hard_out, 1138 l=l, 1139 beta=beta, 1140 dtype=dtype, 1141 **kwargs) 1142 1143 super().__init__(detector=detector, 1144 output=output, 1145 resource_grid=resource_grid, 1146 stream_management=stream_management, 1147 dtype=dtype, 1148 **kwargs) 1149 1150 class MMSEPICDetector(OFDMDetectorWithPrior): 1151 # pylint: disable=line-too-long 1152 r"""MMSEPICDetector(output, resource_grid, stream_management, demapping_method="maxlog", num_iter=1, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, dtype=tf.complex64, **kwargs) 1153 1154 This layer wraps the MIMO MMSE PIC detector for use with the OFDM waveform. 1155 1156 Both detection of symbols or bits with either 1157 soft- or hard-decisions are supported. The OFDM and stream configuration are provided 1158 by a :class:`~sionna.ofdm.ResourceGrid` and 1159 :class:`~sionna.mimo.StreamManagement` instance, respectively. 
The 1160 actual detector is an instance of :class:`~sionna.mimo.MMSEPICDetector`. 1161 1162 Parameters 1163 ---------- 1164 output : One of ["bit", "symbol"], str 1165 Type of output, either bits or symbols. Whether soft- or 1166 hard-decisions are returned can be configured with the 1167 ``hard_out`` flag. 1168 1169 resource_grid : ResourceGrid 1170 Instance of :class:`~sionna.ofdm.ResourceGrid` 1171 1172 stream_management : StreamManagement 1173 Instance of :class:`~sionna.mimo.StreamManagement` 1174 1175 demapping_method : One of ["app", "maxlog"], str 1176 The demapping method used. 1177 Defaults to "maxlog". 1178 1179 num_iter : int 1180 Number of MMSE PIC iterations. 1181 Defaults to 1. 1182 1183 constellation_type : One of ["qam", "pam", "custom"], str 1184 For "custom", an instance of :class:`~sionna.mapping.Constellation` 1185 must be provided. 1186 1187 num_bits_per_symbol : int 1188 The number of bits per constellation symbol, e.g., 4 for QAM16. 1189 Only required for ``constellation_type`` in ["qam", "pam"]. 1190 1191 constellation : Constellation 1192 An instance of :class:`~sionna.mapping.Constellation` or `None`. 1193 In the latter case, ``constellation_type`` 1194 and ``num_bits_per_symbol`` must be provided. 1195 1196 hard_out : bool 1197 If `True`, the detector computes hard-decided bit values or 1198 constellation point indices instead of soft-values. 1199 Defaults to `False`. 1200 1201 dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype) 1202 Precision used for internal computations. Defaults to ``tf.complex64``. 1203 Especially for large MIMO setups, the precision can make a significant 1204 performance difference. 
1205 1206 Input 1207 ------ 1208 (y, h_hat, prior, err_var, no) : 1209 Tuple: 1210 1211 y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex 1212 Received OFDM resource grid after cyclic prefix removal and FFT 1213 1214 h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex 1215 Channel estimates for all streams from all transmitters 1216 1217 prior : [batch_size, num_tx, num_streams, num_data_symbols x num_bits_per_symbol] or [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float 1218 Prior of the transmitted signals. 1219 If ``output`` equals "bit", LLRs of the transmitted bits are expected. 1220 If ``output`` equals "symbol", logits of the transmitted constellation points are expected. 1221 1222 err_var : [Broadcastable to shape of ``h_hat``], tf.float 1223 Variance of the channel estimation error 1224 1225 no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float 1226 Variance of the AWGN 1227 1228 Output 1229 ------ 1230 One of: 1231 1232 : [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol], tf.float 1233 LLRs or hard-decisions for every bit of every stream, if ``output`` equals `"bit"`. 1234 1235 : [batch_size, num_tx, num_streams, num_data_symbols, num_points], tf.float or [batch_size, num_tx, num_streams, num_data_symbols], tf.int 1236 Logits or hard-decisions for constellation symbols for every stream, if ``output`` equals `"symbol"`. 1237 Hard-decisions correspond to the symbol indices. 1238 1239 Note 1240 ---- 1241 For numerical stability, we do not recommend to use this function in Graph 1242 mode with XLA, i.e., within a function that is decorated with 1243 ``@tf.function(jit_compile=True)``. 1244 However, it is possible to do so by setting 1245 ``sionna.Config.xla_compat=true``. 1246 See :py:attr:`~sionna.Config.xla_compat`. 
1247 """ 1248 def __init__(self, 1249 output, 1250 resource_grid, 1251 stream_management, 1252 demapping_method="maxlog", 1253 num_iter=1, 1254 constellation_type=None, 1255 num_bits_per_symbol=None, 1256 constellation=None, 1257 hard_out=False, 1258 dtype=tf.complex64, 1259 **kwargs): 1260 1261 # Instantiate the EP detector 1262 detector = MMSEPICDetector_(output=output, 1263 demapping_method=demapping_method, 1264 num_iter=num_iter, 1265 constellation_type=constellation_type, 1266 num_bits_per_symbol=num_bits_per_symbol, 1267 constellation=constellation, 1268 hard_out=hard_out, 1269 dtype=dtype, 1270 **kwargs) 1271 1272 super().__init__(detector=detector, 1273 output=output, 1274 resource_grid=resource_grid, 1275 stream_management=stream_management, 1276 constellation_type=constellation_type, 1277 num_bits_per_symbol=num_bits_per_symbol, 1278 constellation=constellation, 1279 dtype=dtype, 1280 **kwargs)