anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

utils.py (21186B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Utility functions and layers for the MIMO package."""
      6 
      7 import numpy as np
      8 import tensorflow as tf
      9 from tensorflow.keras.layers import Layer
     10 from abc import ABC, abstractmethod
     11 from sionna.utils import matrix_sqrt_inv, expand_to_rank, insert_dims
     12 
     13 def complex2real_vector(z):
     14     # pylint: disable=line-too-long
     15     r"""Transforms a complex-valued vector into its real-valued equivalent.
     16 
     17     Transforms the last dimension of a complex-valued tensor into
     18     its real-valued equivalent by stacking the real and imaginary
     19     parts on top of each other.
     20 
     21     For a vector :math:`\mathbf{z}\in \mathbb{C}^M` with real and imaginary
     22     parts :math:`\mathbf{x}\in \mathbb{R}^M` and
     23     :math:`\mathbf{y}\in \mathbb{R}^M`, respectively, this function returns
     24     the vector :math:`\left[\mathbf{x}^{\mathsf{T}}, \mathbf{y}^{\mathsf{T}} \right ]^{\mathsf{T}}\in\mathbb{R}^{2M}`.
     25 
     26     Input
     27     -----
     28     : [...,M], tf.complex
     29 
     30     Output
     31     ------
     32     : [...,2M], tf.complex.real_dtype
     33     """
     34     x = tf.math.real(z)
     35     y = tf.math.imag(z)
     36     return tf.concat([x, y], axis=-1)
     37 
     38 def real2complex_vector(z):
     39 # pylint: disable=line-too-long
     40     r"""Transforms a real-valued vector into its complex-valued equivalent.
     41 
     42     Transforms the last dimension of a real-valued tensor into
     43     its complex-valued equivalent by interpreting the first half
     44     as the real and the second half as the imaginary part.
     45 
     46     For a vector :math:`\mathbf{z}=\left[\mathbf{x}^{\mathsf{T}}, \mathbf{y}^{\mathsf{T}} \right ]^{\mathsf{T}}\in \mathbb{R}^{2M}`
     47     with :math:`\mathbf{x}\in \mathbb{R}^M` and :math:`\mathbf{y}\in \mathbb{R}^M`,
     48     this function returns
     49     the vector :math:`\mathbf{x}+j\mathbf{y}\in\mathbb{C}^M`.
     50 
     51     Input
     52     -----
     53     : [...,2M], tf.float
     54 
     55     Output
     56     ------
     57     : [...,M], tf.complex
     58     """
     59     x, y = tf.split(z, 2, -1)
     60     return tf.complex(x, y)
     61 
     62 def complex2real_matrix(z):
     63     # pylint: disable=line-too-long
     64     r"""Transforms a complex-valued matrix into its real-valued equivalent.
     65 
     66     Transforms the last two dimensions of a complex-valued tensor into
     67     their real-valued matrix equivalent representation.
     68 
     69     For a matrix :math:`\mathbf{Z}\in \mathbb{C}^{M\times K}` with real and imaginary
     70     parts :math:`\mathbf{X}\in \mathbb{R}^{M\times K}` and
     71     :math:`\mathbf{Y}\in \mathbb{R}^{M\times K}`, respectively, this function returns
     72     the matrix :math:`\tilde{\mathbf{Z}}\in \mathbb{R}^{2M\times 2K}`, given as
     73 
     74     .. math::
     75 
     76         \tilde{\mathbf{Z}} = \begin{pmatrix}
     77                                 \mathbf{X} & -\mathbf{Y}\\
     78                                 \mathbf{Y} & \mathbf{X}
     79                              \end{pmatrix}.
     80 
     81     Input
     82     -----
     83     : [...,M,K], tf.complex
     84 
     85     Output
     86     ------
     87     : [...,2M, 2K], tf.complex.real_dtype
     88     """
     89     x = tf.math.real(z)
     90     y = tf.math.imag(z)
     91     row1 = tf.concat([x, -y], axis=-1)
     92     row2 = tf.concat([y, x], axis=-1)
     93     return tf.concat([row1, row2], axis=-2)
     94 
     95 def real2complex_matrix(z):
     96     # pylint: disable=line-too-long
     97     r"""Transforms a real-valued matrix into its complex-valued equivalent.
     98 
     99     Transforms the last two dimensions of a real-valued tensor into
    100     their complex-valued matrix equivalent representation.
    101 
    102     For a matrix :math:`\tilde{\mathbf{Z}}\in \mathbb{R}^{2M\times 2K}`,
    103     satisfying
    104 
    105     .. math::
    106 
    107         \tilde{\mathbf{Z}} = \begin{pmatrix}
    108                                 \mathbf{X} & -\mathbf{Y}\\
    109                                 \mathbf{Y} & \mathbf{X}
    110                              \end{pmatrix}
    111 
    112     with :math:`\mathbf{X}\in \mathbb{R}^{M\times K}` and
    113     :math:`\mathbf{Y}\in \mathbb{R}^{M\times K}`, this function returns
    114     the matrix :math:`\mathbf{Z}=\mathbf{X}+j\mathbf{Y}\in\mathbb{C}^{M\times K}`.
    115 
    116     Input
    117     -----
    118     : [...,2M,2K], tf.float
    119 
    120     Output
    121     ------
    122     : [...,M, 2], tf.complex
    123     """
    124     m = tf.shape(z)[-2]//2
    125     k = tf.shape(z)[-1]//2
    126     x = z[...,:m,:k]
    127     y = z[...,m:,:k]
    128     return tf.complex(x, y)
    129 
    130 def complex2real_covariance(r):
    131     # pylint: disable=line-too-long
    132     r"""Transforms a complex-valued covariance matrix to its real-valued equivalent.
    133 
    134     Assume a proper complex random variable :math:`\mathbf{z}\in\mathbb{C}^M` [ProperRV]_
    135     with covariance matrix :math:`\mathbf{R}= \in\mathbb{C}^{M\times M}`
    136     and real and imaginary parts :math:`\mathbf{x}\in \mathbb{R}^M` and
    137     :math:`\mathbf{y}\in \mathbb{R}^M`, respectively.
    138     This function transforms the given :math:`\mathbf{R}` into the covariance matrix of the real-valued equivalent
    139     vector :math:`\tilde{\mathbf{z}}=\left[\mathbf{x}^{\mathsf{T}}, \mathbf{y}^{\mathsf{T}} \right ]^{\mathsf{T}}\in\mathbb{R}^{2M}`, which
    140     is computed as [CovProperRV]_
    141 
    142     .. math::
    143 
    144         \mathbb{E}\left[\tilde{\mathbf{z}}\tilde{\mathbf{z}}^{\mathsf{H}} \right] =
    145         \begin{pmatrix}
    146             \frac12\Re\{\mathbf{R}\} & -\frac12\Im\{\mathbf{R}\}\\
    147             \frac12\Im\{\mathbf{R}\} & \frac12\Re\{\mathbf{R}\}
    148         \end{pmatrix}.
    149 
    150     Input
    151     -----
    152     : [...,M,M], tf.complex
    153 
    154     Output
    155     ------
    156     : [...,2M, 2M], tf.complex.real_dtype
    157     """
    158     q = complex2real_matrix(r)
    159     scale = tf.cast(2, q.dtype)
    160     return q/scale
    161 
    162 def real2complex_covariance(q):
    163     # pylint: disable=line-too-long
    164     r"""Transforms a real-valued covariance matrix to its complex-valued equivalent.
    165 
    166     Assume a proper complex random variable :math:`\mathbf{z}\in\mathbb{C}^M` [ProperRV]_
    167     with covariance matrix :math:`\mathbf{R}= \in\mathbb{C}^{M\times M}`
    168     and real and imaginary parts :math:`\mathbf{x}\in \mathbb{R}^M` and
    169     :math:`\mathbf{y}\in \mathbb{R}^M`, respectively.
    170     This function transforms the given covariance matrix of the real-valued equivalent
    171     vector :math:`\tilde{\mathbf{z}}=\left[\mathbf{x}^{\mathsf{T}}, \mathbf{y}^{\mathsf{T}} \right ]^{\mathsf{T}}\in\mathbb{R}^{2M}`, which
    172     is given as [CovProperRV]_
    173 
    174     .. math::
    175 
    176         \mathbb{E}\left[\tilde{\mathbf{z}}\tilde{\mathbf{z}}^{\mathsf{H}} \right] =
    177         \begin{pmatrix}
    178             \frac12\Re\{\mathbf{R}\} & -\frac12\Im\{\mathbf{R}\}\\
    179             \frac12\Im\{\mathbf{R}\} & \frac12\Re\{\mathbf{R}\}
    180         \end{pmatrix},
    181 
    182     into is complex-valued equivalent :math:`\mathbf{R}`.
    183 
    184     Input
    185     -----
    186     : [...,2M,2M], tf.float
    187 
    188     Output
    189     ------
    190     : [...,M, M], tf.complex
    191     """
    192     r = real2complex_matrix(q)
    193     scale = tf.cast(2, r.dtype)
    194     return r*scale
    195 
    196 def complex2real_channel(y, h, s):
    197     # pylint: disable=line-too-long
    198     r"""Transforms a complex-valued MIMO channel into its real-valued equivalent.
    199 
    200     Assume the canonical MIMO channel model
    201 
    202     .. math::
    203 
    204         \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}
    205 
    206     where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    207     :math:`\mathbf{x}\in\mathbb{C}^K` is the vector of transmitted symbols,
    208     :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix,
    209     and :math:`\mathbf{n}\in\mathbb{C}^M` is a noise vector with covariance
    210     matrix :math:`\mathbf{S}\in\mathbb{C}^{M\times M}`.
    211 
    212     This function returns the real-valued equivalent representations of
    213     :math:`\mathbf{y}`, :math:`\mathbf{H}`, and :math:`\mathbf{S}`,
    214     which are used by a wide variety of MIMO detection algorithms (Section VII) [YH2015]_.
    215     These are obtained by applying :meth:`~sionna.mimo.complex2real_vector` to :math:`\mathbf{y}`,
    216     :meth:`~sionna.mimo.complex2real_matrix` to :math:`\mathbf{H}`,
    217     and :meth:`~sionna.mimo.complex2real_covariance` to :math:`\mathbf{S}`.
    218 
    219     Input
    220     -----
    221     y : [...,M], tf.complex
    222         1+D tensor containing the received signals.
    223 
    224     h : [...,M,K], tf.complex
    225         2+D tensor containing the channel matrices.
    226 
    227     s : [...,M,M], tf.complex
    228         2+D tensor containing the noise covariance matrices.
    229 
    230     Output
    231     ------
    232     : [...,2M], tf.complex.real_dtype
    233         1+D tensor containing the real-valued equivalent received signals.
    234 
    235     : [...,2M,2K], tf.complex.real_dtype
    236         2+D tensor containing the real-valued equivalent channel matrices.
    237 
    238     : [...,2M,2M], tf.complex.real_dtype
    239         2+D tensor containing the real-valued equivalent noise covariance matrices.
    240     """
    241     yr = complex2real_vector(y)
    242     hr = complex2real_matrix(h)
    243     sr = complex2real_covariance(s)
    244     return yr, hr, sr
    245 
    246 def real2complex_channel(y, h, s):
    247     # pylint: disable=line-too-long
    248     r"""Transforms a real-valued MIMO channel into its complex-valued equivalent.
    249 
    250     Assume the canonical MIMO channel model
    251 
    252     .. math::
    253 
    254         \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}
    255 
    256     where :math:`\mathbf{y}\in\mathbb{C}^M` is the received signal vector,
    257     :math:`\mathbf{x}\in\mathbb{C}^K` is the vector of transmitted symbols,
    258     :math:`\mathbf{H}\in\mathbb{C}^{M\times K}` is the known channel matrix,
    259     and :math:`\mathbf{n}\in\mathbb{C}^M` is a noise vector with covariance
    260     matrix :math:`\mathbf{S}\in\mathbb{C}^{M\times M}`.
    261 
    262     This function transforms the real-valued equivalent representations of
    263     :math:`\mathbf{y}`, :math:`\mathbf{H}`, and :math:`\mathbf{S}`, as, e.g.,
    264     obtained with the function :meth:`~sionna.mimo.complex2real_channel`,
    265     back to their complex-valued equivalents (Section VII) [YH2015]_.
    266 
    267     Input
    268     -----
    269     y : [...,2M], tf.float
    270         1+D tensor containing the real-valued received signals.
    271 
    272     h : [...,2M,2K], tf.float
    273         2+D tensor containing the real-valued channel matrices.
    274 
    275     s : [...,2M,2M], tf.float
    276         2+D tensor containing the real-valued noise covariance matrices.
    277 
    278     Output
    279     ------
    280     : [...,M], tf.complex
    281         1+D tensor containing the complex-valued equivalent received signals.
    282 
    283     : [...,M,K], tf.complex
    284         2+D tensor containing the complex-valued equivalent channel matrices.
    285 
    286     : [...,M,M], tf.complex
    287         2+D tensor containing the complex-valued equivalent noise covariance matrices.
    288     """
    289     yc = real2complex_vector(y)
    290     hc = real2complex_matrix(h)
    291     sc = real2complex_covariance(s)
    292     return yc, hc, sc
    293 
    294 def whiten_channel(y, h, s, return_s=True):
    295     # pylint: disable=line-too-long
    296     r"""Whitens a canonical MIMO channel.
    297 
    298     Assume the canonical MIMO channel model
    299 
    300     .. math::
    301 
    302         \mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}
    303 
    304     where :math:`\mathbf{y}\in\mathbb{C}^M(\mathbb{R}^M)` is the received signal vector,
    305     :math:`\mathbf{x}\in\mathbb{C}^K(\mathbb{R}^K)` is the vector of transmitted symbols,
    306     :math:`\mathbf{H}\in\mathbb{C}^{M\times K}(\mathbb{R}^{M\times K})` is the known channel matrix,
    307     and :math:`\mathbf{n}\in\mathbb{C}^M(\mathbb{R}^M)` is a noise vector with covariance
    308     matrix :math:`\mathbf{S}\in\mathbb{C}^{M\times M}(\mathbb{R}^{M\times M})`.
    309 
    310     This function whitens this channel by multiplying :math:`\mathbf{y}` and
    311     :math:`\mathbf{H}` from the left by :math:`\mathbf{S}^{-\frac{1}{2}}`.
    312     Optionally, the whitened noise covariance matrix :math:`\mathbf{I}_M`
    313     can be returned.
    314 
    315     Input
    316     -----
    317     y : [...,M], tf.float or tf.complex
    318         1+D tensor containing the received signals.
    319 
    320     h : [...,M,K], tf.float or tf.complex
    321         2+D tensor containing the  channel matrices.
    322 
    323     s : [...,M,M], tf.float or complex
    324         2+D tensor containing the noise covariance matrices.
    325 
    326     return_s : bool
    327         If `True`, the whitened covariance matrix is returned.
    328         Defaults to `True`.
    329 
    330     Output
    331     ------
    332     : [...,M], tf.float or tf.complex
    333         1+D tensor containing the whitened received signals.
    334 
    335     : [...,M,K], tf.float or tf.complex
    336         2+D tensor containing the whitened channel matrices.
    337 
    338     : [...,M,M], tf.float or tf.complex
    339         2+D tensor containing the whitened noise covariance matrices.
    340         Only returned if ``return_s`` is `True`.
    341     """
    342     # Compute whitening matrix
    343     s_inv_1_2 = matrix_sqrt_inv(s)
    344     s_inv_1_2 = expand_to_rank(s_inv_1_2, tf.rank(h), 0)
    345 
    346     # Whiten obervation and channel matrix
    347     yw = tf.expand_dims(y, -1)
    348     yw = tf.matmul(s_inv_1_2, yw)
    349     yw = tf.squeeze(yw, axis=-1)
    350 
    351     hw = tf.matmul(s_inv_1_2, h)
    352 
    353     if return_s:
    354         # Ideal interference covariance matrix after whitening
    355         sw = tf.eye(tf.shape(s)[-2], dtype=s.dtype)
    356         sw = expand_to_rank(sw, tf.rank(s), 0)
    357         return yw, hw, sw
    358     else:
    359         return yw, hw
    360 
    361 
    362 class List2LLR(ABC):
    363     # pylint: disable=line-too-long
    364     r"""List2LLR()
    365 
    366     Abstract class defining a callable to compute LLRs from a list of
    367     candidate vectors (or paths) provided by a MIMO detector.
    368 
    369     The following channel model is assumed
    370 
    371     .. math::
    372         \bar{\mathbf{y}} = \mathbf{R}\bar{\mathbf{x}} + \bar{\mathbf{n}}
    373 
    374     where :math:`\bar{\mathbf{y}}\in\mathbb{C}^S` are the channel outputs,
    375     :math:`\mathbf{R}\in\mathbb{C}^{S\times S}` is an upper-triangular matrix,
    376     :math:`\bar{\mathbf{x}}\in\mathbb{C}^S` is the transmitted vector whose entries
    377     are uniformly and independently drawn from the constellation :math:`\mathcal{C}`,
    378     and :math:`\bar{\mathbf{n}}\in\mathbb{C}^S` is white noise
    379     with :math:`\mathbb{E}\left[\bar{\mathbf{n}}\right]=\mathbf{0}` and
    380     :math:`\mathbb{E}\left[\bar{\mathbf{n}}\bar{\mathbf{n}}^{\mathsf{H}}\right]=\mathbf{I}`.
    381 
    382     It is assumed that a MIMO detector such as :class:`~sionna.mimo.KBestDetector`
    383     produces :math:`K` candidate solutions :math:`\bar{\mathbf{x}}_k\in\mathcal{C}^S`
    384     and their associated distance metrics :math:`d_k=\lVert \bar{\mathbf{y}} - \mathbf{R}\bar{\mathbf{x}}_k \rVert^2`
    385     for :math:`k=1,\dots,K`. This layer can also be used with the real-valued representation of the channel.
    386 
    387     Input
    388     -----
    389     (y, r, dists, path_inds, path_syms) :
    390         Tuple:
    391 
    392     y : [...,M], tf.complex or tf.float
    393         Channel outputs of the whitened channel
    394 
    395     r : [...,num_streams, num_streams], same dtype as ``y``
    396         Upper triangular channel matrix of the whitened channel
    397 
    398     dists : [...,num_paths], tf.float
    399         Distance metric for each path (or candidate)
    400 
    401     path_inds : [...,num_paths,num_streams], tf.int32
    402         Symbol indices for every stream of every path (or candidate)
    403 
    404     path_syms : [...,num_path,num_streams], same dtype as ``y``
    405         Constellation symbol for every stream of every path (or candidate)
    406 
    407     Output
    408     ------
    409     llr : [...num_streams,num_bits_per_symbol], tf.float
    410         LLRs for all bits of every stream
    411 
    412     Note
    413     ----
    414     An implementation of this class does not need to make use of all of
    415     the provided inputs which enable various different implementations.
    416     """
    417     @abstractmethod
    418     def __call__(self, inputs):
    419         raise NotImplementedError
    420 
    421 class List2LLRSimple(Layer, List2LLR):
    422     # pylint: disable=line-too-long
    423     r"""List2LLRSimple(num_bits_per_symbol, llr_clip_val=20.0, **kwargs)
    424 
    425     Computes LLRs from a list of candidate vectors (or paths) provided by a MIMO detector.
    426 
    427     The following channel model is assumed:
    428 
    429     .. math::
    430         \bar{\mathbf{y}} = \mathbf{R}\bar{\mathbf{x}} + \bar{\mathbf{n}}
    431 
    432     where :math:`\bar{\mathbf{y}}\in\mathbb{C}^S` are the channel outputs,
    433     :math:`\mathbf{R}\in\mathbb{C}^{S\times S}` is an upper-triangular matrix,
    434     :math:`\bar{\mathbf{x}}\in\mathbb{C}^S` is the transmitted vector whose entries
    435     are uniformly and independently drawn from the constellation :math:`\mathcal{C}`,
    436     and :math:`\bar{\mathbf{n}}\in\mathbb{C}^S` is white noise
    437     with :math:`\mathbb{E}\left[\bar{\mathbf{n}}\right]=\mathbf{0}` and
    438     :math:`\mathbb{E}\left[\bar{\mathbf{n}}\bar{\mathbf{n}}^{\mathsf{H}}\right]=\mathbf{I}`.
    439 
    440     It is assumed that a MIMO detector such as :class:`~sionna.mimo.KBestDetector`
    441     produces :math:`K` candidate solutions :math:`\bar{\mathbf{x}}_k\in\mathcal{C}^S`
    442     and their associated distance metrics :math:`d_k=\lVert \bar{\mathbf{y}} - \mathbf{R}\bar{\mathbf{x}}_k \rVert^2`
    443     for :math:`k=1,\dots,K`. This layer can also be used with the real-valued representation of the channel.
    444 
    445     The LLR for the :math:`i\text{th}` bit of the :math:`k\text{th}` stream is computed as
    446 
    447     .. math::
    448         \begin{align}
    449             LLR(k,i) &= \log\left(\frac{\Pr(b_{k,i}=1|\bar{\mathbf{y}},\mathbf{R})}{\Pr(b_{k,i}=0|\bar{\mathbf{y}},\mathbf{R})}\right)\\
    450                 &\approx \min_{j \in  \mathcal{C}_{k,i,0}}d_j - \min_{j \in  \mathcal{C}_{k,i,1}}d_j
    451         \end{align}
    452 
    453     where :math:`\mathcal{C}_{k,i,1}` and :math:`\mathcal{C}_{k,i,0}` are the set of indices
    454     in the list of candidates for which the :math:`i\text{th}` bit of the :math:`k\text{th}`
    455     stream is equal to 1 and 0, respectively. The LLRs are clipped to :math:`\pm LLR_\text{clip}`
    456     which can be configured through the parameter ``llr_clip_val``.
    457 
    458     If :math:`\mathcal{C}_{k,i,0}` is empty, :math:`LLR(k,i)=LLR_\text{clip}`;
    459     if :math:`\mathcal{C}_{k,i,1}` is empty, :math:`LLR(k,i)=-LLR_\text{clip}`.
    460 
    461     Parameters
    462     ----------
    463     num_bits_per_symbol : int
    464         Number of bits per constellation symbol
    465 
    466     llr_clip_val : float
    467         The absolute values of LLRs are clipped to this value.
    468         Defaults to 20.0. Can also be a trainable variable.
    469 
    470     Input
    471     -----
    472     (y, r, dists, path_inds, path_syms) :
    473         Tuple:
    474 
    475     y : [...,M], tf.complex or tf.float
    476         Channel outputs of the whitened channel
    477 
    478     r : [...,num_streams, num_streams], same dtype as ``y``
    479         Upper triangular channel matrix of the whitened channel
    480 
    481     dists : [...,num_paths], tf.float
    482         Distance metric for each path (or candidate)
    483 
    484     path_inds : [...,num_paths,num_streams], tf.int32
    485         Symbol indices for every stream of every path (or candidate)
    486 
    487     path_syms : [...,num_path,num_streams], same dtype as ``y``
    488         Constellation symbol for every stream of every path (or candidate)
    489 
    490     Output
    491     ------
    492     llr : [...num_streams,num_bits_per_symbol], tf.float
    493         LLRs for all bits of every stream
    494     """
    495     def __init__(self,
    496                  num_bits_per_symbol,
    497                  llr_clip_val=20.0,
    498                  **kwargs):
    499         super().__init__(**kwargs)
    500 
    501         # Array composed of binary representations of all symbols indices
    502         num_points = 2**num_bits_per_symbol
    503         a = np.zeros([num_points, num_bits_per_symbol])
    504         for i in range(num_points):
    505             a[i, :] = np.array(list(np.binary_repr(i, num_bits_per_symbol)),
    506                                dtype=np.int32)
    507 
    508         # Compute symbol indices for which the bits are 0 or 1, e.g.,:
    509         # The ith column of c0 provides all symbol indices for which
    510         # the ith bit is 0.
    511         c0 = np.zeros([int(num_points/2), num_bits_per_symbol])
    512         c1 = np.zeros([int(num_points/2), num_bits_per_symbol])
    513         for i in range(num_bits_per_symbol):
    514             c0[:,i] = np.where(a[:,i]==0)[0]
    515             c1[:,i] = np.where(a[:,i]==1)[0]
    516 
    517         # Convert to tensor and add dummy dimensions needed for broadcasting
    518         self._c0 = expand_to_rank(tf.constant(c0, tf.int32), 5, 0)
    519         self._c1 = expand_to_rank(tf.constant(c1, tf.int32), 5, 0)
    520 
    521         # Assign this absolute value to all LLRs without counter-hypothesis
    522         self.llr_clip_val = llr_clip_val
    523 
    524     @property
    525     def llr_clip_val(self):
    526         return self._llr_clip_val
    527 
    528     @llr_clip_val.setter
    529     def llr_clip_val(self, value):
    530         self._llr_clip_val = value
    531 
    532     def __call__(self, inputs):
    533 
    534         # dists :     [batch_size, num_paths]
    535         # path_inds : [batch_size, num_paths, num_streams]
    536         dists, path_inds = inputs[2:4]
    537 
    538         # Scaled by 0.5 to account for the reduced noise power in each complex
    539         # dimension if real channel representation is used.
    540         if inputs[0].dtype.is_floating:
    541             dists = dists/2.0
    542 
    543         # Compute for every symbol in every path which bits are 0 or 1
    544         # b0/b1: [batch_size, num_path, num_streams, num_bits_per_symbol]
    545         # The reduce_any op is forced to run in XLA mode to be able to
    546         # work with very large tensors. There seems to an int32 indexing issue
    547         # for all TF reduce CUDA kernels.
    548         path_inds = insert_dims(path_inds, 2, axis=-1)
    549         b0 = tf.equal(path_inds, self._c0)
    550         b1 = tf.equal(path_inds, self._c1)
    551         b0 = tf.function(tf.reduce_any, jit_compile=True)(b0, axis=-2)
    552         b1 = tf.function(tf.reduce_any, jit_compile=True)(b1, axis=-2)
    553 
    554         # Compute distances for all bits in all paths, set distance to inf
    555         # if the bit does not have the correct value
    556         dists = expand_to_rank(dists, tf.rank(b0), axis=-1)
    557         d0 = tf.where(b0, dists, tf.constant(np.inf, dists.dtype))
    558         d1 = tf.where(b1, dists, tf.constant(np.inf, dists.dtype))
    559 
    560         # Compute minimum distance for each bit in each stream
    561         # l0/l1: [batch_size, num_streams, num_bits_per_symbol]
    562         l0 = tf.reduce_min(d0, axis=1)
    563         l1 = tf.reduce_min(d1, axis=1)
    564 
    565         # Compute LLRs
    566         llr = l0-l1
    567 
    568         #  Clip LLRs
    569         llr = tf.clip_by_value(llr, -self.llr_clip_val, self.llr_clip_val)
    570 
    571         return llr
    572