utils.py (21186B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Utility functions and layers for the MIMO package."""

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
from abc import ABC, abstractmethod
from sionna.utils import matrix_sqrt_inv, expand_to_rank, insert_dims

def complex2real_vector(z):
    # pylint: disable=line-too-long
    r"""Transforms a complex-valued vector into its real-valued equivalent.

    Transforms the last dimension of a complex-valued tensor into
    its real-valued equivalent by stacking the real and imaginary
    parts on top of each other.

    For a vector :math:`\mathbf{z}\in \mathbb{C}^M` with real and imaginary
    parts :math:`\mathbf{x}\in \mathbb{R}^M` and
    :math:`\mathbf{y}\in \mathbb{R}^M`, respectively, this function returns
    the vector :math:`\left[\mathbf{x}^{\mathsf{T}}, \mathbf{y}^{\mathsf{T}} \right ]^{\mathsf{T}}\in\mathbb{R}^{2M}`.

    Input
    -----
    : [...,M], tf.complex

    Output
    ------
    : [...,2M], tf.complex.real_dtype
    """
    x = tf.math.real(z)
    y = tf.math.imag(z)
    return tf.concat([x, y], axis=-1)

def real2complex_vector(z):
    # pylint: disable=line-too-long
    r"""Transforms a real-valued vector into its complex-valued equivalent.

    Transforms the last dimension of a real-valued tensor into
    its complex-valued equivalent by interpreting the first half
    as the real and the second half as the imaginary part.

    For a vector :math:`\mathbf{z}=\left[\mathbf{x}^{\mathsf{T}}, \mathbf{y}^{\mathsf{T}} \right ]^{\mathsf{T}}\in \mathbb{R}^{2M}`
    with :math:`\mathbf{x}\in \mathbb{R}^M` and :math:`\mathbf{y}\in \mathbb{R}^M`,
    this function returns
    the vector :math:`\mathbf{x}+j\mathbf{y}\in\mathbb{C}^M`.

    Input
    -----
    : [...,2M], tf.float

    Output
    ------
    : [...,M], tf.complex
    """
    # Split the last dimension into two equal halves: real and imaginary part
    x, y = tf.split(z, 2, -1)
    return tf.complex(x, y)

def complex2real_matrix(z):
    # pylint: disable=line-too-long
    r"""Transforms a complex-valued matrix into its real-valued equivalent.

    Transforms the last two dimensions of a complex-valued tensor into
    their real-valued matrix equivalent representation.

    For a matrix :math:`\mathbf{Z}\in \mathbb{C}^{M\times K}` with real and imaginary
    parts :math:`\mathbf{X}\in \mathbb{R}^{M\times K}` and
    :math:`\mathbf{Y}\in \mathbb{R}^{M\times K}`, respectively, this function returns
    the matrix :math:`\tilde{\mathbf{Z}}\in \mathbb{R}^{2M\times 2K}`, given as

    .. math::

        \tilde{\mathbf{Z}} = \begin{pmatrix}
                                \mathbf{X} & -\mathbf{Y}\\
                                \mathbf{Y} & \mathbf{X}
                             \end{pmatrix}.

    Input
    -----
    : [...,M,K], tf.complex

    Output
    ------
    : [...,2M, 2K], tf.complex.real_dtype
    """
    x = tf.math.real(z)
    y = tf.math.imag(z)
    row1 = tf.concat([x, -y], axis=-1)
    row2 = tf.concat([y, x], axis=-1)
    return tf.concat([row1, row2], axis=-2)

def real2complex_matrix(z):
    # pylint: disable=line-too-long
    r"""Transforms a real-valued matrix into its complex-valued equivalent.

    Transforms the last two dimensions of a real-valued tensor into
    their complex-valued matrix equivalent representation.

    For a matrix :math:`\tilde{\mathbf{Z}}\in \mathbb{R}^{2M\times 2K}`,
    satisfying

    .. math::

        \tilde{\mathbf{Z}} = \begin{pmatrix}
                                \mathbf{X} & -\mathbf{Y}\\
                                \mathbf{Y} & \mathbf{X}
                             \end{pmatrix}

    with :math:`\mathbf{X}\in \mathbb{R}^{M\times K}` and
    :math:`\mathbf{Y}\in \mathbb{R}^{M\times K}`, this function returns
    the matrix :math:`\mathbf{Z}=\mathbf{X}+j\mathbf{Y}\in\mathbb{C}^{M\times K}`.

    Input
    -----
    : [...,2M,2K], tf.float

    Output
    ------
    : [...,M,K], tf.complex
    """
    m = tf.shape(z)[-2]//2
    k = tf.shape(z)[-1]//2
    # Only the left column of blocks is needed: X is top-left, Y is bottom-left
    x = z[...,:m,:k]
    y = z[...,m:,:k]
    return tf.complex(x, y)

def complex2real_covariance(r):
    # pylint: disable=line-too-long
    r"""Transforms a complex-valued covariance matrix to its real-valued equivalent.

    Assume a proper complex random variable :math:`\mathbf{z}\in\mathbb{C}^M` [ProperRV]_
    with covariance matrix :math:`\mathbf{R}\in\mathbb{C}^{M\times M}`
    and real and imaginary parts :math:`\mathbf{x}\in \mathbb{R}^M` and
    :math:`\mathbf{y}\in \mathbb{R}^M`, respectively.
    This function transforms the given :math:`\mathbf{R}` into the covariance matrix of the real-valued equivalent
    vector :math:`\tilde{\mathbf{z}}=\left[\mathbf{x}^{\mathsf{T}}, \mathbf{y}^{\mathsf{T}} \right ]^{\mathsf{T}}\in\mathbb{R}^{2M}`, which
    is computed as [CovProperRV]_

    .. math::

        \mathbb{E}\left[\tilde{\mathbf{z}}\tilde{\mathbf{z}}^{\mathsf{T}} \right] =
        \begin{pmatrix}
            \frac12\Re\{\mathbf{R}\} & -\frac12\Im\{\mathbf{R}\}\\
            \frac12\Im\{\mathbf{R}\} & \frac12\Re\{\mathbf{R}\}
        \end{pmatrix}.

    Input
    -----
    : [...,M,M], tf.complex

    Output
    ------
    : [...,2M, 2M], tf.complex.real_dtype
    """
    q = complex2real_matrix(r)
    # Each real dimension carries half of the complex variance
    scale = tf.cast(2, q.dtype)
    return q/scale

def real2complex_covariance(q):
    # pylint: disable=line-too-long
    r"""Transforms a real-valued covariance matrix to its complex-valued equivalent.

    Assume a proper complex random variable :math:`\mathbf{z}\in\mathbb{C}^M` [ProperRV]_
    with covariance matrix :math:`\mathbf{R}\in\mathbb{C}^{M\times M}`
    and real and imaginary parts :math:`\mathbf{x}\in \mathbb{R}^M` and
    :math:`\mathbf{y}\in \mathbb{R}^M`, respectively.
    This function transforms the given covariance matrix of the real-valued equivalent
    vector :math:`\tilde{\mathbf{z}}=\left[\mathbf{x}^{\mathsf{T}}, \mathbf{y}^{\mathsf{T}} \right ]^{\mathsf{T}}\in\mathbb{R}^{2M}`, which
    is given as [CovProperRV]_

    .. math::

        \mathbb{E}\left[\tilde{\mathbf{z}}\tilde{\mathbf{z}}^{\mathsf{T}} \right] =
        \begin{pmatrix}
            \frac12\Re\{\mathbf{R}\} & -\frac12\Im\{\mathbf{R}\}\\
            \frac12\Im\{\mathbf{R}\} & \frac12\Re\{\mathbf{R}\}
        \end{pmatrix},

    into its complex-valued equivalent :math:`\mathbf{R}`.

    Input
    -----
    : [...,2M,2M], tf.float

    Output
    ------
    : [...,M, M], tf.complex
    """
    r = real2complex_matrix(q)
    # Undo the factor 1/2 applied by complex2real_covariance
    scale = tf.cast(2, r.dtype)
    return r*scale

def complex2real_channel(y, h, s):
    # pylint: disable=line-too-long
    r"""Transforms a complex-valued MIMO channel into its real-valued equivalent.

    Assume the canonical MIMO channel model

    .. math::

        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    :math:`\mathbf{x}\in\mathbb{C}^K` is the vector of transmitted symbols,
    :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a noise vector with covariance
    matrix :math:`\mathbf{S}\in\mathbb{C}^{M\times M}`.

    This function returns the real-valued equivalent representations of
    :math:`\mathbf{y}`, :math:`\mathbf{H}`, and :math:`\mathbf{S}`,
    which are used by a wide variety of MIMO detection algorithms (Section VII) [YH2015]_.
    These are obtained by applying :meth:`~sionna.mimo.complex2real_vector` to :math:`\mathbf{y}`,
    :meth:`~sionna.mimo.complex2real_matrix` to :math:`\mathbf{H}`,
    and :meth:`~sionna.mimo.complex2real_covariance` to :math:`\mathbf{S}`.

    Input
    -----
    y : [...,M], tf.complex
        1+D tensor containing the received signals.

    h : [...,M,K], tf.complex
        2+D tensor containing the channel matrices.

    s : [...,M,M], tf.complex
        2+D tensor containing the noise covariance matrices.

    Output
    ------
    : [...,2M], tf.complex.real_dtype
        1+D tensor containing the real-valued equivalent received signals.

    : [...,2M,2K], tf.complex.real_dtype
        2+D tensor containing the real-valued equivalent channel matrices.

    : [...,2M,2M], tf.complex.real_dtype
        2+D tensor containing the real-valued equivalent noise covariance matrices.
    """
    yr = complex2real_vector(y)
    hr = complex2real_matrix(h)
    sr = complex2real_covariance(s)
    return yr, hr, sr

def real2complex_channel(y, h, s):
    # pylint: disable=line-too-long
    r"""Transforms a real-valued MIMO channel into its complex-valued equivalent.

    Assume the canonical MIMO channel model

    .. math::

        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    :math:`\mathbf{x}\in\mathbb{C}^K` is the vector of transmitted symbols,
    :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix,
    and :math:`\mathbf{n}\in\mathbb{C}^M` is a noise vector with covariance
    matrix :math:`\mathbf{S}\in\mathbb{C}^{M\times M}`.

    This function transforms the real-valued equivalent representations of
    :math:`\mathbf{y}`, :math:`\mathbf{H}`, and :math:`\mathbf{S}`, as, e.g.,
    obtained with the function :meth:`~sionna.mimo.complex2real_channel`,
    back to their complex-valued equivalents (Section VII) [YH2015]_.

    Input
    -----
    y : [...,2M], tf.float
        1+D tensor containing the real-valued received signals.

    h : [...,2M,2K], tf.float
        2+D tensor containing the real-valued channel matrices.

    s : [...,2M,2M], tf.float
        2+D tensor containing the real-valued noise covariance matrices.

    Output
    ------
    : [...,M], tf.complex
        1+D tensor containing the complex-valued equivalent received signals.

    : [...,M,K], tf.complex
        2+D tensor containing the complex-valued equivalent channel matrices.

    : [...,M,M], tf.complex
        2+D tensor containing the complex-valued equivalent noise covariance matrices.
    """
    yc = real2complex_vector(y)
    hc = real2complex_matrix(h)
    sc = real2complex_covariance(s)
    return yc, hc, sc

def whiten_channel(y, h, s, return_s=True):
    # pylint: disable=line-too-long
    r"""Whitens a canonical MIMO channel.

    Assume the canonical MIMO channel model

    .. math::

        \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}

    where :math:`\mathbf{y}\in\mathbb{C}^M(\mathbb{R}^M)` is the received signal vector,
    :math:`\mathbf{x}\in\mathbb{C}^K(\mathbb{R}^K)` is the vector of transmitted symbols,
    :math:`\mathbf{H}\in\mathbb{C}^{M\times K}(\mathbb{R}^{M\times K})` is the known channel matrix,
    and :math:`\mathbf{n}\in\mathbb{C}^M(\mathbb{R}^M)` is a noise vector with covariance
    matrix :math:`\mathbf{S}\in\mathbb{C}^{M\times M}(\mathbb{R}^{M\times M})`.

    This function whitens this channel by multiplying :math:`\mathbf{y}` and
    :math:`\mathbf{H}` from the left by :math:`\mathbf{S}^{-\frac{1}{2}}`.
    Optionally, the whitened noise covariance matrix :math:`\mathbf{I}_M`
    can be returned.

    Input
    -----
    y : [...,M], tf.float or tf.complex
        1+D tensor containing the received signals.

    h : [...,M,K], tf.float or tf.complex
        2+D tensor containing the channel matrices.

    s : [...,M,M], tf.float or tf.complex
        2+D tensor containing the noise covariance matrices.

    return_s : bool
        If `True`, the whitened covariance matrix is returned.
        Defaults to `True`.

    Output
    ------
    : [...,M], tf.float or tf.complex
        1+D tensor containing the whitened received signals.

    : [...,M,K], tf.float or tf.complex
        2+D tensor containing the whitened channel matrices.

    : [...,M,M], tf.float or tf.complex
        2+D tensor containing the whitened noise covariance matrices.
        Only returned if ``return_s`` is `True`.
    """
    # Compute whitening matrix
    s_inv_1_2 = matrix_sqrt_inv(s)
    # Prepend singleton dims so that the whitening matrix broadcasts against h
    s_inv_1_2 = expand_to_rank(s_inv_1_2, tf.rank(h), 0)

    # Whiten observation and channel matrix
    yw = tf.expand_dims(y, -1)
    yw = tf.matmul(s_inv_1_2, yw)
    yw = tf.squeeze(yw, axis=-1)

    hw = tf.matmul(s_inv_1_2, h)

    if return_s:
        # Ideal interference covariance matrix after whitening
        sw = tf.eye(tf.shape(s)[-2], dtype=s.dtype)
        sw = expand_to_rank(sw, tf.rank(s), 0)
        return yw, hw, sw
    else:
        return yw, hw


class List2LLR(ABC):
    # pylint: disable=line-too-long
    r"""List2LLR()

    Abstract class defining a callable to compute LLRs from a list of
    candidate vectors (or paths) provided by a MIMO detector.

    The following channel model is assumed

    .. math::
        \bar{\mathbf{y}} = \mathbf{R}\bar{\mathbf{x}} + \bar{\mathbf{n}}

    where :math:`\bar{\mathbf{y}}\in\mathbb{C}^S` are the channel outputs,
    :math:`\mathbf{R}\in\mathbb{C}^{S\times S}` is an upper-triangular matrix,
    :math:`\bar{\mathbf{x}}\in\mathbb{C}^S` is the transmitted vector whose entries
    are uniformly and independently drawn from the constellation :math:`\mathcal{C}`,
    and :math:`\bar{\mathbf{n}}\in\mathbb{C}^S` is white noise
    with :math:`\mathbb{E}\left[\bar{\mathbf{n}}\right]=\mathbf{0}` and
    :math:`\mathbb{E}\left[\bar{\mathbf{n}}\bar{\mathbf{n}}^{\mathsf{H}}\right]=\mathbf{I}`.

    It is assumed that a MIMO detector such as :class:`~sionna.mimo.KBestDetector`
    produces :math:`K` candidate solutions :math:`\bar{\mathbf{x}}_k\in\mathcal{C}^S`
    and their associated distance metrics :math:`d_k=\lVert \bar{\mathbf{y}} - \mathbf{R}\bar{\mathbf{x}}_k \rVert^2`
    for :math:`k=1,\dots,K`. This layer can also be used with the real-valued representation of the channel.

    Input
    -----
    (y, r, dists, path_inds, path_syms) :
        Tuple:

    y : [...,M], tf.complex or tf.float
        Channel outputs of the whitened channel

    r : [...,num_streams, num_streams], same dtype as ``y``
        Upper triangular channel matrix of the whitened channel

    dists : [...,num_paths], tf.float
        Distance metric for each path (or candidate)

    path_inds : [...,num_paths,num_streams], tf.int32
        Symbol indices for every stream of every path (or candidate)

    path_syms : [...,num_paths,num_streams], same dtype as ``y``
        Constellation symbol for every stream of every path (or candidate)

    Output
    ------
    llr : [...,num_streams,num_bits_per_symbol], tf.float
        LLRs for all bits of every stream

    Note
    ----
    An implementation of this class does not need to make use of all of
    the provided inputs which enable various different implementations.
    """
    @abstractmethod
    def __call__(self, inputs):
        raise NotImplementedError

class List2LLRSimple(Layer, List2LLR):
    # pylint: disable=line-too-long
    r"""List2LLRSimple(num_bits_per_symbol, llr_clip_val=20.0, **kwargs)

    Computes LLRs from a list of candidate vectors (or paths) provided by a MIMO detector.

    The following channel model is assumed:

    .. math::
        \bar{\mathbf{y}} = \mathbf{R}\bar{\mathbf{x}} + \bar{\mathbf{n}}

    where :math:`\bar{\mathbf{y}}\in\mathbb{C}^S` are the channel outputs,
    :math:`\mathbf{R}\in\mathbb{C}^{S\times S}` is an upper-triangular matrix,
    :math:`\bar{\mathbf{x}}\in\mathbb{C}^S` is the transmitted vector whose entries
    are uniformly and independently drawn from the constellation :math:`\mathcal{C}`,
    and :math:`\bar{\mathbf{n}}\in\mathbb{C}^S` is white noise
    with :math:`\mathbb{E}\left[\bar{\mathbf{n}}\right]=\mathbf{0}` and
    :math:`\mathbb{E}\left[\bar{\mathbf{n}}\bar{\mathbf{n}}^{\mathsf{H}}\right]=\mathbf{I}`.

    It is assumed that a MIMO detector such as :class:`~sionna.mimo.KBestDetector`
    produces :math:`K` candidate solutions :math:`\bar{\mathbf{x}}_k\in\mathcal{C}^S`
    and their associated distance metrics :math:`d_k=\lVert \bar{\mathbf{y}} - \mathbf{R}\bar{\mathbf{x}}_k \rVert^2`
    for :math:`k=1,\dots,K`. This layer can also be used with the real-valued representation of the channel.

    The LLR for the :math:`i\text{th}` bit of the :math:`k\text{th}` stream is computed as

    .. math::
        \begin{align}
            LLR(k,i) &= \log\left(\frac{\Pr(b_{k,i}=1|\bar{\mathbf{y}},\mathbf{R})}{\Pr(b_{k,i}=0|\bar{\mathbf{y}},\mathbf{R})}\right)\\
                &\approx \min_{j \in  \mathcal{C}_{k,i,0}}d_j - \min_{j \in  \mathcal{C}_{k,i,1}}d_j
        \end{align}

    where :math:`\mathcal{C}_{k,i,1}` and :math:`\mathcal{C}_{k,i,0}` are the set of indices
    in the list of candidates for which the :math:`i\text{th}` bit of the :math:`k\text{th}`
    stream is equal to 1 and 0, respectively. The LLRs are clipped to :math:`\pm LLR_\text{clip}`
    which can be configured through the parameter ``llr_clip_val``.

    If :math:`\mathcal{C}_{k,i,0}` is empty, :math:`LLR(k,i)=LLR_\text{clip}`;
    if :math:`\mathcal{C}_{k,i,1}` is empty, :math:`LLR(k,i)=-LLR_\text{clip}`.

    Parameters
    ----------
    num_bits_per_symbol : int
        Number of bits per constellation symbol

    llr_clip_val : float
        The absolute values of LLRs are clipped to this value.
        Defaults to 20.0. Can also be a trainable variable.

    Input
    -----
    (y, r, dists, path_inds, path_syms) :
        Tuple:

    y : [...,M], tf.complex or tf.float
        Channel outputs of the whitened channel

    r : [...,num_streams, num_streams], same dtype as ``y``
        Upper triangular channel matrix of the whitened channel

    dists : [...,num_paths], tf.float
        Distance metric for each path (or candidate)

    path_inds : [...,num_paths,num_streams], tf.int32
        Symbol indices for every stream of every path (or candidate)

    path_syms : [...,num_paths,num_streams], same dtype as ``y``
        Constellation symbol for every stream of every path (or candidate)

    Output
    ------
    llr : [...,num_streams,num_bits_per_symbol], tf.float
        LLRs for all bits of every stream
    """
    def __init__(self,
                 num_bits_per_symbol,
                 llr_clip_val=20.0,
                 **kwargs):
        super().__init__(**kwargs)

        # Array composed of binary representations of all symbols indices
        # (row i holds the bits of index i, most significant bit first)
        num_points = 2**num_bits_per_symbol
        a = np.zeros([num_points, num_bits_per_symbol])
        for i in range(num_points):
            a[i, :] = np.array(list(np.binary_repr(i, num_bits_per_symbol)),
                               dtype=np.int32)

        # Compute symbol indices for which the bits are 0 or 1, e.g.,:
        # The ith column of c0 provides all symbol indices for which
        # the ith bit is 0. Exactly half of all indices have a given bit
        # equal to 0 (resp. 1).
        c0 = np.zeros([num_points // 2, num_bits_per_symbol])
        c1 = np.zeros([num_points // 2, num_bits_per_symbol])
        for i in range(num_bits_per_symbol):
            c0[:,i] = np.where(a[:,i]==0)[0]
            c1[:,i] = np.where(a[:,i]==1)[0]

        # Convert to tensor and add dummy dimensions needed for broadcasting
        # c0/c1: [1, 1, 1, num_points/2, num_bits_per_symbol]
        self._c0 = expand_to_rank(tf.constant(c0, tf.int32), 5, 0)
        self._c1 = expand_to_rank(tf.constant(c1, tf.int32), 5, 0)

        # Assign this absolute value to all LLRs without counter-hypothesis
        self.llr_clip_val = llr_clip_val

    @property
    def llr_clip_val(self):
        """Absolute value to which the LLRs are clipped"""
        return self._llr_clip_val

    @llr_clip_val.setter
    def llr_clip_val(self, value):
        self._llr_clip_val = value

    def __call__(self, inputs):

        # dists : [..., num_paths]
        # path_inds : [..., num_paths, num_streams]
        dists, path_inds = inputs[2:4]

        # Scaled by 0.5 to account for the reduced noise power in each complex
        # dimension if real channel representation is used.
        if inputs[0].dtype.is_floating:
            dists = dists/2.0

        # Compute for every symbol in every path which bits are 0 or 1
        # path_inds: [..., num_paths, num_streams, 1, 1]
        # b0/b1: [..., num_paths, num_streams, num_bits_per_symbol]
        # The reduce_any op is forced to run in XLA mode to be able to
        # work with very large tensors. There seems to be an int32 indexing
        # issue for all TF reduce CUDA kernels.
        path_inds = insert_dims(path_inds, 2, axis=-1)
        b0 = tf.equal(path_inds, self._c0)
        b1 = tf.equal(path_inds, self._c1)
        b0 = tf.function(tf.reduce_any, jit_compile=True)(b0, axis=-2)
        b1 = tf.function(tf.reduce_any, jit_compile=True)(b1, axis=-2)

        # Compute distances for all bits in all paths, set distance to inf
        # if the bit does not have the correct value
        # dists: [..., num_paths, 1, 1]
        # d0/d1: [..., num_paths, num_streams, num_bits_per_symbol]
        dists = expand_to_rank(dists, tf.rank(b0), axis=-1)
        d0 = tf.where(b0, dists, tf.constant(np.inf, dists.dtype))
        d1 = tf.where(b1, dists, tf.constant(np.inf, dists.dtype))

        # Compute minimum distance for each bit in each stream by reducing
        # over the path dimension. A negative axis is used so that an
        # arbitrary number of leading batch dimensions is supported
        # (identical to axis=1 for a single batch dimension).
        # l0/l1: [..., num_streams, num_bits_per_symbol]
        l0 = tf.reduce_min(d0, axis=-3)
        l1 = tf.reduce_min(d1, axis=-3)

        # Compute LLRs (max-log approximation); an empty hypothesis set
        # yields +/-inf, which is mapped to +/-llr_clip_val below
        llr = l0-l1

        # Clip LLRs
        llr = tf.clip_by_value(llr, -self.llr_clip_val, self.llr_clip_val)

        return llr