anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

mapping.py (66390B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Layers for (de)mapping, constellation class, and utility functions"""
      6 
      7 import numpy as np
      8 import tensorflow as tf
      9 from tensorflow.keras.layers import Layer
     10 import matplotlib.pyplot as plt
     11 
     12 import sionna
     13 from sionna import config
     14 
     15 def pam_gray(b):
     16     # pylint: disable=line-too-long
     17     r"""Maps a vector of bits to a PAM constellation points with Gray labeling.
     18 
     19     This recursive function maps a binary vector to Gray-labelled PAM
     20     constellation points. It can be used to generated QAM constellations.
     21     The constellation is not normalized.
     22 
     23     Input
     24     -----
     25     b : [n], NumPy array
     26         Tensor with with binary entries.
     27 
     28     Output
     29     ------
     30     : signed int
     31         The PAM constellation point taking values in
     32         :math:`\{\pm 1,\pm 3,\dots,\pm (2^n-1)\}`.
     33 
     34     Note
     35     ----
     36     This algorithm is a recursive implementation of the expressions found in
     37     Section 5.1 of [3GPPTS38211]_. It is used in the 5G standard.
     38     """ # pylint: disable=C0301
     39 
     40     if len(b)>1:
     41         return (1-2*b[0])*(2**len(b[1:]) - pam_gray(b[1:]))
     42     return 1-2*b[0]
     43 
     44 def qam(num_bits_per_symbol, normalize=True):
     45     r"""Generates a QAM constellation.
     46 
     47     This function generates a complex-valued vector, where each element is
     48     a constellation point of an M-ary QAM constellation. The bit
     49     label of the ``n`` th point is given by the length-``num_bits_per_symbol``
     50     binary represenation of ``n``.
     51 
     52     Input
     53     -----
     54     num_bits_per_symbol : int
     55         The number of bits per constellation point.
     56         Must be a multiple of two, e.g., 2, 4, 6, 8, etc.
     57 
     58     normalize: bool
     59         If `True`, the constellation is normalized to have unit power.
     60         Defaults to `True`.
     61 
     62     Output
     63     ------
     64     : :math:`[2^{\text{num_bits_per_symbol}}]`, np.complex64
     65         The QAM constellation.
     66 
     67     Note
     68     ----
     69     The bit label of the nth constellation point is given by the binary
     70     representation of its position within the array and can be obtained
     71     through ``np.binary_repr(n, num_bits_per_symbol)``.
     72 
     73 
     74     The normalization factor of a QAM constellation is given in
     75     closed-form as:
     76 
     77     .. math::
     78         \sqrt{\frac{1}{2^{n-2}}\sum_{i=1}^{2^{n-1}}(2i-1)^2}
     79 
     80     where :math:`n= \text{num_bits_per_symbol}/2` is the number of bits
     81     per dimension.
     82 
     83     This algorithm is a recursive implementation of the expressions found in
     84     Section 5.1 of [3GPPTS38211]_. It is used in the 5G standard.
     85     """ # pylint: disable=C0301
     86 
     87     try:
     88         assert num_bits_per_symbol % 2 == 0 # is even
     89         assert num_bits_per_symbol >0 # is larger than zero
     90     except AssertionError as error:
     91         raise ValueError("num_bits_per_symbol must be a multiple of 2") \
     92         from error
     93     assert isinstance(normalize, bool), "normalize must be boolean"
     94 
     95     # Build constellation by iterating through all points
     96     c = np.zeros([2**num_bits_per_symbol], dtype=np.complex64)
     97     for i in range(0, 2**num_bits_per_symbol):
     98         b = np.array(list(np.binary_repr(i,num_bits_per_symbol)),
     99                      dtype=np.int16)
    100         c[i] = pam_gray(b[0::2]) + 1j*pam_gray(b[1::2]) # PAM in each dimension
    101 
    102     if normalize: # Normalize to unit energy
    103         n = int(num_bits_per_symbol/2)
    104         qam_var = 1/(2**(n-2))*np.sum(np.linspace(1,2**n-1, 2**(n-1))**2)
    105         c /= np.sqrt(qam_var)
    106     return c
    107 
    108 def pam(num_bits_per_symbol, normalize=True):
    109     r"""Generates a PAM constellation.
    110 
    111     This function generates a real-valued vector, where each element is
    112     a constellation point of an M-ary PAM constellation. The bit
    113     label of the ``n`` th point is given by the length-``num_bits_per_symbol``
    114     binary represenation of ``n``.
    115 
    116     Input
    117     -----
    118     num_bits_per_symbol : int
    119         The number of bits per constellation point.
    120         Must be positive.
    121 
    122     normalize: bool
    123         If `True`, the constellation is normalized to have unit power.
    124         Defaults to `True`.
    125 
    126     Output
    127     ------
    128     : :math:`[2^{\text{num_bits_per_symbol}}]`, np.float32
    129         The PAM constellation.
    130 
    131     Note
    132     ----
    133     The bit label of the nth constellation point is given by the binary
    134     representation of its position within the array and can be obtained
    135     through ``np.binary_repr(n, num_bits_per_symbol)``.
    136 
    137 
    138     The normalization factor of a PAM constellation is given in
    139     closed-form as:
    140 
    141     .. math::
    142         \sqrt{\frac{1}{2^{n-1}}\sum_{i=1}^{2^{n-1}}(2i-1)^2}
    143 
    144     where :math:`n= \text{num_bits_per_symbol}` is the number of bits
    145     per symbol.
    146 
    147     This algorithm is a recursive implementation of the expressions found in
    148     Section 5.1 of [3GPPTS38211]_. It is used in the 5G standard.
    149     """ # pylint: disable=C0301
    150 
    151     try:
    152         assert num_bits_per_symbol >0 # is larger than zero
    153     except AssertionError as error:
    154         raise ValueError("num_bits_per_symbol must be positive") \
    155         from error
    156     assert isinstance(normalize, bool), "normalize must be boolean"
    157 
    158     # Build constellation by iterating through all points
    159     c = np.zeros([2**num_bits_per_symbol], dtype=np.float32)
    160     for i in range(0, 2**num_bits_per_symbol):
    161         b = np.array(list(np.binary_repr(i,num_bits_per_symbol)),
    162                      dtype=np.int16)
    163         c[i] = pam_gray(b)
    164 
    165     if normalize: # Normalize to unit energy
    166         n = int(num_bits_per_symbol)
    167         pam_var = 1/(2**(n-1))*np.sum(np.linspace(1,2**n-1, 2**(n-1))**2)
    168         c /= np.sqrt(pam_var)
    169     return c
    170 
    171 class Constellation(Layer):
    172     # pylint: disable=line-too-long
    173     r"""
    174     Constellation(constellation_type, num_bits_per_symbol, initial_value=None, normalize=True, center=False, trainable=False, dtype=tf.complex64, **kwargs)
    175 
    176     Constellation that can be used by a (de)mapper.
    177 
    178     This class defines a constellation, i.e., a complex-valued vector of
    179     constellation points. A constellation can be trainable. The binary
    180     representation of the index of an element of this vector corresponds
    181     to the bit label of the constellation point. This implicit bit
    182     labeling is used by the ``Mapper`` and ``Demapper`` classes.
    183 
    184     Parameters
    185     ----------
    186     constellation_type : One of ["qam", "pam", "custom"], str
    187         For "custom", the constellation points are randomly initialized
    188         if no ``initial_value`` is provided.
    189 
    190     num_bits_per_symbol : int
    191         The number of bits per constellation symbol, e.g., 4 for QAM16.
    192 
    193     initial_value : :math:`[2^\text{num_bits_per_symbol}]`, NumPy array or Tensor
    194         Initial values of the constellation points. If ``normalize`` or
    195         ``center`` are `True`, the initial constellation might be changed.
    196 
    197     normalize : bool
    198         If `True`, the constellation is normalized to have unit power.
    199         Defaults to `True`.
    200 
    201     center : bool
    202         If `True`, the constellation is ensured to have zero mean.
    203         Defaults to `False`.
    204 
    205     trainable : bool
    206         If `True`, the constellation points are trainable variables.
    207         Defaults to `False`.
    208 
    209     dtype : [tf.complex64, tf.complex128], tf.DType
    210         The dtype of the constellation.
    211 
    212     Output
    213     ------
    214     : :math:`[2^\text{num_bits_per_symbol}]`, ``dtype``
    215         The constellation.
    216 
    217     Note
    218     ----
    219     One can create a trainable PAM/QAM constellation. This is
    220     equivalent to creating a custom trainable constellation which is
    221     initialized with PAM/QAM constellation points.
    222     """
    223     # pylint: enable=C0301
    224 
    225     def __init__(self,
    226                  constellation_type,
    227                  num_bits_per_symbol,
    228                  initial_value=None,
    229                  normalize=True,
    230                  center=False,
    231                  trainable=False,
    232                  dtype=tf.complex64,
    233                  **kwargs):
    234         super().__init__(**kwargs)
    235         assert dtype in [tf.complex64, tf.complex128],\
    236             "dtype must be tf.complex64 or tf.complex128"
    237         self._dtype = dtype
    238 
    239         assert constellation_type in ("qam", "pam", "custom"),\
    240             "Wrong constellation type"
    241         self._constellation_type = constellation_type
    242 
    243         assert isinstance(normalize, bool), "normalize must be boolean"
    244         self._normalize = normalize
    245 
    246         assert isinstance(center, bool), "center must be boolean"
    247         self._center = center
    248 
    249         assert isinstance(trainable, bool), "trainable must be boolean"
    250         self._trainable = trainable
    251 
    252         # allow float inputs that represent int
    253         assert isinstance(num_bits_per_symbol, (float,int)),\
    254             "num_bits_per_symbol must be integer"
    255         assert (num_bits_per_symbol%1==0),\
    256             "num_bits_per_symbol must be integer"
    257         num_bits_per_symbol = int(num_bits_per_symbol)
    258 
    259         if self._constellation_type=="qam":
    260             assert num_bits_per_symbol%2 == 0 and num_bits_per_symbol>0,\
    261                 "num_bits_per_symbol must be a multiple of 2"
    262             self._num_bits_per_symbol = int(num_bits_per_symbol)
    263 
    264             assert initial_value is None, "QAM must not have an initial value"
    265             points = qam(self._num_bits_per_symbol, normalize=self.normalize)
    266             points = tf.cast(points, self._dtype)
    267 
    268         if self._constellation_type=="pam":
    269             assert num_bits_per_symbol>0,\
    270                 "num_bits_per_symbol must be integer"
    271             self._num_bits_per_symbol = int(num_bits_per_symbol)
    272 
    273             assert initial_value is None, "PAM must not have an initial value"
    274             points = pam(self._num_bits_per_symbol, normalize=self.normalize)
    275             points = tf.cast(points, self._dtype)
    276 
    277         if self._constellation_type=="custom":
    278             assert num_bits_per_symbol>0,\
    279                 "num_bits_per_symbol must be integer"
    280             self._num_bits_per_symbol = int(num_bits_per_symbol)
    281 
    282             # Randomly initialize points if no initial_value is provided
    283             if initial_value is None:
    284                 points = config.tf_rng.uniform(  # pylint: disable=E1123
    285                                     [2, 2**self._num_bits_per_symbol],
    286                                      minval=-0.05, maxval=0.05,
    287                                     dtype=tf.as_dtype(self._dtype).real_dtype)
    288                 points  = tf.complex(points[0], points[1])
    289             else:
    290                 assert tf.rank(initial_value).numpy() == 1
    291                 assert tf.shape(initial_value)[0] == 2**num_bits_per_symbol,\
    292                     "initial_value must have shape [2**num_bits_per_symbol]"
    293                 points = tf.cast(initial_value, self._dtype)
    294         self._points = points
    295 
    296     def build(self, input_shape): #pylint: disable=unused-argument
    297         points = self._points
    298         points = tf.stack([tf.math.real(points),
    299                            tf.math.imag(points)], axis=0)
    300         if self._trainable:
    301             self._points = tf.Variable(points,
    302                                        trainable=self._trainable,
    303                                     dtype=tf.as_dtype(self._dtype).real_dtype)
    304         else:
    305             self._points = tf.constant(points,
    306                                     dtype=tf.as_dtype(self._dtype).real_dtype)
    307 
    308     # pylint: disable=no-self-argument
    309     def create_or_check_constellation(  constellation_type=None,
    310                                         num_bits_per_symbol=None,
    311                                         constellation=None,
    312                                         dtype=tf.complex64):
    313         # pylint: disable=line-too-long
    314         r"""Static method for conviently creating a constellation object or checking that an existing one
    315         is consistent with requested settings.
    316 
    317         If ``constellation`` is `None`, then this method creates a :class:`~sionna.mapping.Constellation`
    318         object of type ``constellation_type`` and with ``num_bits_per_symbol`` bits per symbol.
    319         Otherwise, this method checks that `constellation` is consistent with ``constellation_type`` and
    320         ``num_bits_per_symbol``. If it is, ``constellation`` is returned. Otherwise, an assertion is raised.
    321 
    322         Input
    323         ------
    324         constellation_type : One of ["qam", "pam", "custom"], str
    325             For "custom", an instance of :class:`~sionna.mapping.Constellation`
    326             must be provided.
    327 
    328         num_bits_per_symbol : int
    329             The number of bits per constellation symbol, e.g., 4 for QAM16.
    330             Only required for ``constellation_type`` in ["qam", "pam"].
    331 
    332         constellation :  Constellation
    333             An instance of :class:`~sionna.mapping.Constellation` or
    334             `None`. In the latter case, ``constellation_type``
    335             and ``num_bits_per_symbol`` must be provided.
    336 
    337         Output
    338         -------
    339         : :class:`~sionna.mapping.Constellation`
    340             A constellation object.
    341         """
    342         constellation_object = None
    343         if constellation is not None:
    344             assert constellation_type in [None, "custom"], \
    345                 """`constellation_type` must be "custom"."""
    346             assert num_bits_per_symbol in \
    347                      [None, constellation.num_bits_per_symbol], \
    348                 """`Wrong value of `num_bits_per_symbol.`"""
    349             assert constellation.dtype==dtype, \
    350                 "Constellation has wrong dtype."
    351             constellation_object = constellation
    352         else:
    353             assert constellation_type in ["qam", "pam"], \
    354                 "Wrong constellation type."
    355             assert num_bits_per_symbol is not None, \
    356                 "`num_bits_per_symbol` must be provided."
    357             constellation_object = Constellation(   constellation_type,
    358                                                     num_bits_per_symbol,
    359                                                     dtype=dtype)
    360         return constellation_object
    361 
    362     def call(self, inputs): #pylint: disable=unused-argument
    363         x = self._points
    364         x = tf.complex(x[0], x[1])
    365         if self._center:
    366             x = x - tf.reduce_mean(x)
    367         if self._normalize:
    368             energy = tf.reduce_mean(tf.square(tf.abs(x)))
    369             energy_sqrt = tf.complex(tf.sqrt(energy),
    370                                      tf.constant(0.,
    371                                     dtype=tf.as_dtype(self._dtype).real_dtype))
    372             x = x / energy_sqrt
    373         return x
    374 
    375     @property
    376     def normalize(self):
    377         """Indicates if the constellation is normalized or not."""
    378         return self._normalize
    379 
    380     @normalize.setter
    381     def normalize(self, value):
    382         assert isinstance(value, bool), "`normalize` must be boolean"
    383         self._normalize = value
    384 
    385     @property
    386     def center(self):
    387         """Indicates if the constellation is centered."""
    388         return self._center
    389 
    390     @center.setter
    391     def center(self, value):
    392         assert isinstance(value, bool), "`center` must be boolean"
    393         self._center = value
    394 
    395     @property
    396     def num_bits_per_symbol(self):
    397         """The number of bits per constellation symbol."""
    398         return self._num_bits_per_symbol
    399 
    400     @property
    401     def points(self):
    402         """The (possibly) centered and normalized constellation points."""
    403         return self(None)
    404 
    405     def show(self, labels=True, figsize=(7,7)):
    406         """Generate a scatter-plot of the constellation.
    407 
    408         Input
    409         -----
    410         labels : bool
    411             If `True`, the bit labels will be drawn next to each constellation
    412             point. Defaults to `True`.
    413 
    414         figsize : Two-element Tuple, float
    415             Width and height in inches. Defaults to `(7,7)`.
    416 
    417         Output
    418         ------
    419         : matplotlib.figure.Figure
    420             A handle to a matplot figure object.
    421         """
    422         maxval = np.max(np.abs(self.points))*1.05
    423         fig = plt.figure(figsize=figsize)
    424         ax = fig.add_subplot(111)
    425         plt.xlim(-maxval, maxval)
    426         plt.ylim(-maxval, maxval)
    427         plt.scatter(np.real(self.points), np.imag(self.points))
    428         ax.set_aspect("equal", adjustable="box")
    429         plt.xlabel("Real Part")
    430         plt.ylabel("Imaginary Part")
    431         plt.grid(True, which="both", axis="both")
    432         plt.title("Constellation Plot")
    433         if labels is True:
    434             for j, p in enumerate(self.points.numpy()):
    435                 plt.annotate(
    436                     np.binary_repr(j, self.num_bits_per_symbol),
    437                     (np.real(p), np.imag(p))
    438                 )
    439         return fig
    440 
    441 class Mapper(Layer):
    442     # pylint: disable=line-too-long
    443     r"""
    444     Mapper(constellation_type=None, num_bits_per_symbol=None, constellation=None, return_indices=False, dtype=tf.complex64, **kwargs)
    445 
    446     Maps binary tensors to points of a constellation.
    447 
    448     This class defines a layer that maps a tensor of binary values
    449     to a tensor of points from a provided constellation.
    450 
    451     Parameters
    452     ----------
    453     constellation_type : One of ["qam", "pam", "custom"], str
    454         For "custom", an instance of :class:`~sionna.mapping.Constellation`
    455         must be provided.
    456 
    457     num_bits_per_symbol : int
    458         The number of bits per constellation symbol, e.g., 4 for QAM16.
    459         Only required for ``constellation_type`` in ["qam", "pam"].
    460 
    461     constellation :  Constellation
    462         An instance of :class:`~sionna.mapping.Constellation` or
    463         `None`. In the latter case, ``constellation_type``
    464         and ``num_bits_per_symbol`` must be provided.
    465 
    466     return_indices : bool
    467         If enabled, symbol indices are additionally returned.
    468         Defaults to `False`.
    469 
    470     dtype : One of [tf.complex64, tf.complex128], tf.DType
    471         The output dtype. Defaults to tf.complex64.
    472 
    473     Input
    474     -----
    475     : [..., n], tf.float or tf.int
    476         Tensor with with binary entries.
    477 
    478     Output
    479     ------
    480     : [...,n/Constellation.num_bits_per_symbol], tf.complex
    481         The mapped constellation symbols.
    482 
    483     : [...,n/Constellation.num_bits_per_symbol], tf.int32
    484         The symbol indices corresponding to the constellation symbols.
    485         Only returned if ``return_indices`` is set to True.
    486 
    487 
    488     Note
    489     ----
    490     The last input dimension must be an integer multiple of the
    491     number of bits per constellation symbol.
    492     """
    493     def __init__(self,
    494                  constellation_type=None,
    495                  num_bits_per_symbol=None,
    496                  constellation=None,
    497                  return_indices=False,
    498                  dtype=tf.complex64,
    499                  **kwargs
    500                 ):
    501         super().__init__(dtype=dtype, **kwargs)
    502         assert dtype in [tf.complex64, tf.complex128],\
    503             "dtype must be tf.complex64 or tf.complex128"
    504 
    505         # Create constellation object
    506         self._constellation = Constellation.create_or_check_constellation(
    507                                                         constellation_type,
    508                                                         num_bits_per_symbol,
    509                                                         constellation,
    510                                                         dtype=dtype)
    511 
    512         self._return_indices = return_indices
    513 
    514         self._binary_base = 2**tf.constant(
    515                         range(self.constellation.num_bits_per_symbol-1,-1,-1))
    516 
    517     @property
    518     def constellation(self):
    519         """The Constellation used by the Mapper."""
    520         return self._constellation
    521 
    522     def call(self, inputs):
    523         tf.debugging.assert_greater_equal(tf.rank(inputs), 2,
    524             message="The input must have at least rank 2")
    525 
    526         # Reshape inputs to the desired format
    527         new_shape = [-1] + inputs.shape[1:-1].as_list() + \
    528            [int(inputs.shape[-1] / self.constellation.num_bits_per_symbol),
    529             self.constellation.num_bits_per_symbol]
    530         inputs_reshaped = tf.cast(tf.reshape(inputs, new_shape), tf.int32)
    531 
    532         # Convert the last dimension to an integer
    533         int_rep = tf.reduce_sum(inputs_reshaped * self._binary_base, axis=-1)
    534 
    535         # Map integers to constellation symbols
    536         x = tf.gather(self.constellation.points, int_rep, axis=0)
    537 
    538         if self._return_indices:
    539             return x, int_rep
    540         else:
    541             return x
    542 
    543 class SymbolLogits2LLRs(Layer):
    544     # pylint: disable=line-too-long
    545     r"""
    546     SymbolLogits2LLRs(method, num_bits_per_symbol, hard_out=False, with_prior=False, dtype=tf.float32, **kwargs)
    547 
    548     Computes log-likelihood ratios (LLRs) or hard-decisions on bits
    549     from a tensor of logits (i.e., unnormalized log-probabilities) on constellation points.
    550     If the flag ``with_prior`` is set, prior knowledge on the bits is assumed to be available.
    551 
    552     Parameters
    553     ----------
    554     method : One of ["app", "maxlog"], str
    555         The method used for computing the LLRs.
    556 
    557     num_bits_per_symbol : int
    558         The number of bits per constellation symbol, e.g., 4 for QAM16.
    559 
    560     hard_out : bool
    561         If `True`, the layer provides hard-decided bits instead of soft-values.
    562         Defaults to `False`.
    563 
    564     with_prior : bool
    565         If `True`, it is assumed that prior knowledge on the bits is available.
    566         This prior information is given as LLRs as an additional input to the layer.
    567         Defaults to `False`.
    568 
    569     dtype : One of [tf.float32, tf.float64] tf.DType (dtype)
    570         The dtype for the input and output.
    571         Defaults to `tf.float32`.
    572 
    573     Input
    574     -----
    575     logits or (logits, prior):
    576         Tuple:
    577 
    578     logits : [...,n, num_points], tf.float
    579         Logits on constellation points.
    580 
    581     prior : [num_bits_per_symbol] or [...n, num_bits_per_symbol], tf.float
    582         Prior for every bit as LLRs.
    583         It can be provided either as a tensor of shape `[num_bits_per_symbol]`
    584         for the entire input batch, or as a tensor that is "broadcastable"
    585         to `[..., n, num_bits_per_symbol]`.
    586         Only required if the ``with_prior`` flag is set.
    587 
    588     Output
    589     ------
    590     : [...,n, num_bits_per_symbol], tf.float
    591         LLRs or hard-decisions for every bit.
    592 
    593     Note
    594     ----
    595     With the "app" method, the LLR for the :math:`i\text{th}` bit
    596     is computed according to
    597 
    598     .. math::
    599         LLR(i) = \ln\left(\frac{\Pr\left(b_i=1\lvert \mathbf{z},\mathbf{p}\right)}{\Pr\left(b_i=0\lvert \mathbf{z},\mathbf{p}\right)}\right) =\ln\left(\frac{
    600                 \sum_{c\in\mathcal{C}_{i,1}} \Pr\left(c\lvert\mathbf{p}\right)
    601                 e^{z_c}
    602                 }{
    603                 \sum_{c\in\mathcal{C}_{i,0}} \Pr\left(c\lvert\mathbf{p}\right)
    604                 e^{z_c}
    605                 }\right)
    606 
    607     where :math:`\mathcal{C}_{i,1}` and :math:`\mathcal{C}_{i,0}` are the
    608     sets of :math:`2^K` constellation points for which the :math:`i\text{th}` bit is
    609     equal to 1 and 0, respectively. :math:`\mathbf{z} = \left[z_{c_0},\dots,z_{c_{2^K-1}}\right]` is the vector of logits on the constellation points, :math:`\mathbf{p} = \left[p_0,\dots,p_{K-1}\right]`
    610     is the vector of LLRs that serves as prior knowledge on the :math:`K` bits that are mapped to
    611     a constellation point and is set to :math:`\mathbf{0}` if no prior knowledge is assumed to be available,
    612     and :math:`\Pr(c\lvert\mathbf{p})` is the prior probability on the constellation symbol :math:`c`:
    613 
    614     .. math::
    615         \Pr\left(c\lvert\mathbf{p}\right) = \prod_{k=0}^{K-1} \Pr\left(b_k = \ell(c)_k \lvert\mathbf{p} \right)
    616         = \prod_{k=0}^{K-1} \text{sigmoid}\left(p_k \ell(c)_k\right)
    617 
    618     where :math:`\ell(c)_k` is the :math:`k^{th}` bit label of :math:`c`, where 0 is
    619     replaced by -1.
    620     The definition of the LLR has been
    621     chosen such that it is equivalent with that of logits. This is
    622     different from many textbooks in communications, where the LLR is
    623     defined as :math:`LLR(i) = \ln\left(\frac{\Pr\left(b_i=0\lvert y\right)}{\Pr\left(b_i=1\lvert y\right)}\right)`.
    624 
    625     With the "maxlog" method, LLRs for the :math:`i\text{th}` bit
    626     are approximated like
    627 
    628     .. math::
    629         \begin{align}
    630             LLR(i) &\approx\ln\left(\frac{
    631                 \max_{c\in\mathcal{C}_{i,1}} \Pr\left(c\lvert\mathbf{p}\right)
    632                     e^{z_c}
    633                 }{
    634                 \max_{c\in\mathcal{C}_{i,0}} \Pr\left(c\lvert\mathbf{p}\right)
    635                     e^{z_c}
    636                 }\right)
    637                 .
    638         \end{align}
    639     """
    640     def __init__(self,
    641                  method,
    642                  num_bits_per_symbol,
    643                  hard_out=False,
    644                  with_prior=False,
    645                  dtype=tf.float32,
    646                  **kwargs):
    647         super().__init__(dtype=dtype, **kwargs)
    648         assert method in ("app","maxlog"), "Unknown demapping method"
    649         self._method = method
    650         self._hard_out = hard_out
    651         self._num_bits_per_symbol = num_bits_per_symbol
    652         self._with_prior = with_prior
    653         num_points = int(2**num_bits_per_symbol)
    654 
    655         # Array composed of binary representations of all symbols indices
    656         a = np.zeros([num_points, num_bits_per_symbol])
    657         for i in range(0, num_points):
    658             a[i,:] = np.array(list(np.binary_repr(i, num_bits_per_symbol)),
    659                               dtype=np.int16)
    660 
    661         # Compute symbol indices for which the bits are 0 or 1
    662         c0 = np.zeros([int(num_points/2), num_bits_per_symbol])
    663         c1 = np.zeros([int(num_points/2), num_bits_per_symbol])
    664         for i in range(num_bits_per_symbol-1,-1,-1):
    665             c0[:,i] = np.where(a[:,i]==0)[0]
    666             c1[:,i] = np.where(a[:,i]==1)[0]
    667         self._c0 = tf.constant(c0, dtype=tf.int32) # Symbols with ith bit=0
    668         self._c1 = tf.constant(c1, dtype=tf.int32) # Symbols with ith bit=1
    669 
    670         if with_prior:
    671             # Array of labels from {-1, 1} of all symbols
    672             # [num_points, num_bits_per_symbol]
    673             a = 2*a-1
    674             self._a = tf.constant(a, dtype=dtype)
    675 
    676         # Determine the reduce function for LLR computation
    677         if self._method == "app":
    678             self._reduce = tf.reduce_logsumexp
    679         else:
    680             self._reduce = tf.reduce_max
    681 
    682     @property
    683     def num_bits_per_symbol(self):
    684         return self._num_bits_per_symbol
    685 
    def call(self, inputs):
        """Turn logits on constellation points into bit LLRs (or hard bits).

        ``inputs`` is either ``logits`` alone or, when the layer was built
        with ``with_prior=True``, the pair ``(logits, prior)``. The last axis
        of ``logits`` indexes constellation points; the returned tensor
        replaces it with an axis of length ``num_bits_per_symbol``.
        """
        if self._with_prior:
            logits, prior = inputs

            # Bring the bit-wise prior LLRs to shape [..., n or 1, 1, K] so
            # that they broadcast against the {-1, +1} labels of all points,
            # which are expanded to [..., 1, num_points, K].
            prior = tf.expand_dims(
                sionna.utils.expand_to_rank(prior, tf.rank(logits), axis=0),
                axis=-2)
            labels = sionna.utils.expand_to_rank(self._a, tf.rank(prior),
                                                 axis=0)

            # Log prior probability of every point, [..., n or 1, num_points].
            # Adding it to the logits before the gather yields exactly the
            # same element-wise sums as gathering both terms and adding them
            # afterwards.
            log_prior = tf.reduce_sum(tf.math.log_sigmoid(labels*prior),
                                      axis=-1)
            scores = log_prior + logits
        else:
            scores = inputs

        # Split the points according to the value of each bit:
        # shape [..., n, num_points/2, num_bits_per_symbol]
        scores_bit0 = tf.gather(scores, self._c0, axis=-1)
        scores_bit1 = tf.gather(scores, self._c1, axis=-1)

        # LLR(i) = log(Pr(b_i=1)/Pr(b_i=0)); `_reduce` is logsumexp for
        # "app" and max otherwise. Shape [..., n, num_bits_per_symbol]
        llr = (self._reduce(scores_bit1, axis=-2)
               - self._reduce(scores_bit0, axis=-2))

        return sionna.utils.hard_decisions(llr) if self._hard_out else llr
    732 
    733 class SymbolLogits2LLRsWithPrior(SymbolLogits2LLRs):
    734     # pylint: disable=line-too-long
    735     r"""
    736     SymbolLogits2LLRsWithPrior(method, num_bits_per_symbol, hard_out=False, dtype=tf.float32, **kwargs)
    737 
    738     Computes log-likelihood ratios (LLRs) or hard-decisions on bits
    739     from a tensor of logits (i.e., unnormalized log-probabilities) on constellation points,
    740     assuming that prior knowledge on the bits is available.
    741 
    742     This class is deprecated as the functionality has been integrated
    743     into :class:`~sionna.mapping.SymbolLogits2LLRs`.
    744 
    745     Parameters
    746     ----------
    747     method : One of ["app", "maxlog"], str
    748         The method used for computing the LLRs.
    749 
    750     num_bits_per_symbol : int
    751         The number of bits per constellation symbol, e.g., 4 for QAM16.
    752 
    753     hard_out : bool
    754         If `True`, the layer provides hard-decided bits instead of soft-values.
    755         Defaults to `False`.
    756 
    757     dtype : One of [tf.float32, tf.float64] tf.DType (dtype)
    758         The dtype for the input and output.
    759         Defaults to `tf.float32`.
    760 
    761     Input
    762     -----
    763     (logits, prior):
    764         Tuple:
    765 
    766     logits : [...,n, num_points], tf.float
    767         Logits on constellation points.
    768 
    769     prior : [num_bits_per_symbol] or [...n, num_bits_per_symbol], tf.float
    770         Prior for every bit as LLRs.
    771         It can be provided either as a tensor of shape `[num_bits_per_symbol]` for the
    772         entire input batch, or as a tensor that is "broadcastable"
    773         to `[..., n, num_bits_per_symbol]`.
    774 
    775     Output
    776     ------
    777     : [...,n, num_bits_per_symbol], tf.float
    778         LLRs or hard-decisions for every bit.
    779 
    780     Note
    781     ----
    782     With the "app" method, the LLR for the :math:`i\text{th}` bit
    783     is computed according to
    784 
    785     .. math::
    786         LLR(i) = \ln\left(\frac{\Pr\left(b_i=1\lvert \mathbf{z},\mathbf{p}\right)}{\Pr\left(b_i=0\lvert \mathbf{z},\mathbf{p}\right)}\right) =\ln\left(\frac{
    787                 \sum_{c\in\mathcal{C}_{i,1}} \Pr\left(c\lvert\mathbf{p}\right)
    788                 e^{z_c}
    789                 }{
    790                 \sum_{c\in\mathcal{C}_{i,0}} \Pr\left(c\lvert\mathbf{p}\right)
    791                 e^{z_c}
    792                 }\right)
    793 
    794     where :math:`\mathcal{C}_{i,1}` and :math:`\mathcal{C}_{i,0}` are the
    795     sets of :math:`2^K` constellation points for which the :math:`i\text{th}` bit is
    796     equal to 1 and 0, respectively. :math:`\mathbf{z} = \left[z_{c_0},\dots,z_{c_{2^K-1}}\right]` is the vector of logits on the constellation points, :math:`\mathbf{p} = \left[p_0,\dots,p_{K-1}\right]`
    797     is the vector of LLRs that serves as prior knowledge on the :math:`K` bits that are mapped to
    798     a constellation point,
    799     and :math:`\Pr(c\lvert\mathbf{p})` is the prior probability on the constellation symbol :math:`c`:
    800 
    801     .. math::
    802         \Pr\left(c\lvert\mathbf{p}\right) = \prod_{k=0}^{K-1} \Pr\left(b_k = \ell(c)_k \lvert\mathbf{p} \right)
    803         = \prod_{k=0}^{K-1} \text{sigmoid}\left(p_k \ell(c)_k\right)
    804 
    805     where :math:`\ell(c)_k` is the :math:`k^{th}` bit label of :math:`c`, where 0 is
    806     replaced by -1.
    807     The definition of the LLR has been
    808     chosen such that it is equivalent with that of logits. This is
    809     different from many textbooks in communications, where the LLR is
    810     defined as :math:`LLR(i) = \ln\left(\frac{\Pr\left(b_i=0\lvert y\right)}{\Pr\left(b_i=1\lvert y\right)}\right)`.
    811 
    812     With the "maxlog" method, LLRs for the :math:`i\text{th}` bit
    813     are approximated like
    814 
    815     .. math::
    816         \begin{align}
    817             LLR(i) &\approx\ln\left(\frac{
    818                 \max_{c\in\mathcal{C}_{i,1}} \Pr\left(c\lvert\mathbf{p}\right)
    819                     e^{z_c}
    820                 }{
    821                 \max_{c\in\mathcal{C}_{i,0}} \Pr\left(c\lvert\mathbf{p}\right)
    822                     e^{z_c}
    823                 }\right)
    824                 .
    825         \end{align}
    826     """
    827     def __init__(self,
    828                  method,
    829                  num_bits_per_symbol,
    830                  hard_out=False,
    831                  dtype=tf.float32,
    832                  **kwargs):
    833         super().__init__(method=method,
    834                          num_bits_per_symbol=num_bits_per_symbol,
    835                          hard_out=False,
    836                          with_prior=True,
    837                          dtype=tf.float32,
    838                          **kwargs)
    839 
    840 class Demapper(Layer):
    841     # pylint: disable=line-too-long
    842     r"""
    843     Demapper(demapping_method, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, with_prior=False, dtype=tf.complex64, **kwargs)
    844 
    845     Computes log-likelihood ratios (LLRs) or hard-decisions on bits
    846     for a tensor of received symbols.
    847     If the flag ``with_prior`` is set, prior knowledge on the bits is assumed to be available.
    848 
    849     This class defines a layer implementing different demapping
    850     functions. All demapping functions are fully differentiable when soft-decisions
    851     are computed.
    852 
    853     Parameters
    854     ----------
    855     demapping_method : One of ["app", "maxlog"], str
    856         The demapping method used.
    857 
    858     constellation_type : One of ["qam", "pam", "custom"], str
    859         For "custom", an instance of :class:`~sionna.mapping.Constellation`
    860         must be provided.
    861 
    862     num_bits_per_symbol : int
    863         The number of bits per constellation symbol, e.g., 4 for QAM16.
    864         Only required for ``constellation_type`` in ["qam", "pam"].
    865 
    866     constellation : Constellation
    867         An instance of :class:`~sionna.mapping.Constellation` or `None`.
    868         In the latter case, ``constellation_type``
    869         and ``num_bits_per_symbol`` must be provided.
    870 
    871     hard_out : bool
    872         If `True`, the demapper provides hard-decided bits instead of soft-values.
    873         Defaults to `False`.
    874 
    875     with_prior : bool
    876         If `True`, it is assumed that prior knowledge on the bits is available.
    877         This prior information is given as LLRs as an additional input to the layer.
    878         Defaults to `False`.
    879 
    880     dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
    881         The dtype of `y`. Defaults to tf.complex64.
    882         The output dtype is the corresponding real dtype (tf.float32 or tf.float64).
    883 
    884     Input
    885     -----
    886     (y,no) or (y, prior, no) :
    887         Tuple:
    888 
    889     y : [...,n], tf.complex
    890         The received symbols.
    891 
    892     prior : [num_bits_per_symbol] or [...,num_bits_per_symbol], tf.float
    893         Prior for every bit as LLRs.
    894         It can be provided either as a tensor of shape `[num_bits_per_symbol]` for the
    895         entire input batch, or as a tensor that is "broadcastable"
    896         to `[..., n, num_bits_per_symbol]`.
    897         Only required if the ``with_prior`` flag is set.
    898 
    899     no : Scalar or [...,n], tf.float
    900         The noise variance estimate. It can be provided either as scalar
    901         for the entire input batch or as a tensor that is "broadcastable" to
    902         ``y``.
    903 
    904     Output
    905     ------
    906     : [...,n*num_bits_per_symbol], tf.float
    907         LLRs or hard-decisions for every bit.
    908 
    909     Note
    910     ----
    911     With the "app" demapping method, the LLR for the :math:`i\text{th}` bit
    912     is computed according to
    913 
    914     .. math::
    915         LLR(i) = \ln\left(\frac{\Pr\left(b_i=1\lvert y,\mathbf{p}\right)}{\Pr\left(b_i=0\lvert y,\mathbf{p}\right)}\right) =\ln\left(\frac{
    916                 \sum_{c\in\mathcal{C}_{i,1}} \Pr\left(c\lvert\mathbf{p}\right)
    917                 \exp\left(-\frac{1}{N_o}\left|y-c\right|^2\right)
    918                 }{
    919                 \sum_{c\in\mathcal{C}_{i,0}} \Pr\left(c\lvert\mathbf{p}\right)
    920                 \exp\left(-\frac{1}{N_o}\left|y-c\right|^2\right)
    921                 }\right)
    922 
    923     where :math:`\mathcal{C}_{i,1}` and :math:`\mathcal{C}_{i,0}` are the
    924     sets of constellation points for which the :math:`i\text{th}` bit is
    925     equal to 1 and 0, respectively. :math:`\mathbf{p} = \left[p_0,\dots,p_{K-1}\right]`
    926     is the vector of LLRs that serves as prior knowledge on the :math:`K` bits that are mapped to
    927     a constellation point and is set to :math:`\mathbf{0}` if no prior knowledge is assumed to be available,
    928     and :math:`\Pr(c\lvert\mathbf{p})` is the prior probability on the constellation symbol :math:`c`:
    929 
    930     .. math::
    931         \Pr\left(c\lvert\mathbf{p}\right) = \prod_{k=0}^{K-1} \text{sigmoid}\left(p_k \ell(c)_k\right)
    932 
    933     where :math:`\ell(c)_k` is the :math:`k^{th}` bit label of :math:`c`, where 0 is
    934     replaced by -1.
    935     The definition of the LLR has been
    936     chosen such that it is equivalent with that of logits. This is
    937     different from many textbooks in communications, where the LLR is
    938     defined as :math:`LLR(i) = \ln\left(\frac{\Pr\left(b_i=0\lvert y\right)}{\Pr\left(b_i=1\lvert y\right)}\right)`.
    939 
    940     With the "maxlog" demapping method, LLRs for the :math:`i\text{th}` bit
    941     are approximated like
    942 
    943     .. math::
    944         \begin{align}
    945             LLR(i) &\approx\ln\left(\frac{
    946                 \max_{c\in\mathcal{C}_{i,1}} \Pr\left(c\lvert\mathbf{p}\right)
    947                     \exp\left(-\frac{1}{N_o}\left|y-c\right|^2\right)
    948                 }{
    949                 \max_{c\in\mathcal{C}_{i,0}} \Pr\left(c\lvert\mathbf{p}\right)
    950                     \exp\left(-\frac{1}{N_o}\left|y-c\right|^2\right)
    951                 }\right)\\
    952                 &= \max_{c\in\mathcal{C}_{i,0}}
    953                     \left(\ln\left(\Pr\left(c\lvert\mathbf{p}\right)\right)-\frac{|y-c|^2}{N_o}\right) -
    954                  \max_{c\in\mathcal{C}_{i,1}}\left( \ln\left(\Pr\left(c\lvert\mathbf{p}\right)\right) - \frac{|y-c|^2}{N_o}\right)
    955                 .
    956         \end{align}
    957     """
    958     def __init__(self,
    959                  demapping_method,
    960                  constellation_type=None,
    961                  num_bits_per_symbol=None,
    962                  constellation=None,
    963                  hard_out=False,
    964                  with_prior=False,
    965                  dtype=tf.complex64,
    966                  **kwargs):
    967         super().__init__(dtype=dtype, **kwargs)
    968         self._with_prior = with_prior
    969 
    970 
    971         # Create constellation object
    972         self._constellation = Constellation.create_or_check_constellation(
    973                                                         constellation_type,
    974                                                         num_bits_per_symbol,
    975                                                         constellation,
    976                                                         dtype=dtype)
    977         num_bits_per_symbol = self._constellation.num_bits_per_symbol
    978 
    979         self._logits2llrs = SymbolLogits2LLRs(demapping_method,
    980                                               num_bits_per_symbol,
    981                                               hard_out,
    982                                               with_prior,
    983                                               dtype.real_dtype,
    984                                               **kwargs)
    985 
    986         self._no_threshold = tf.cast(np.finfo(dtype.as_numpy_dtype).tiny, dtype.real_dtype)
    987 
    988     @property
    989     def constellation(self):
    990         return self._constellation
    991 
    992     def call(self, inputs):
    993         if self._with_prior:
    994             y, prior, no = inputs
    995         else:
    996             y, no = inputs
    997 
    998         # Reshape constellation points to [1,...1,num_points]
    999         points_shape = [1]*y.shape.rank + self.constellation.points.shape
   1000         points = tf.reshape(self.constellation.points, points_shape)
   1001 
   1002         # Compute squared distances from y to all points
   1003         # shape [...,n,num_points]
   1004         squared_dist = tf.pow(tf.abs(tf.expand_dims(y, axis=-1) - points), 2)
   1005 
   1006         # Add a dummy dimension for broadcasting. This is not needed when no
   1007         # is a scalar, but also does not do any harm.
   1008         no = tf.expand_dims(no, axis=-1)
   1009         # Deal with zero or very small values.
   1010         no = tf.math.maximum(no, self._no_threshold)
   1011 
   1012         # Compute exponents
   1013         exponents = -squared_dist/no
   1014 
   1015         if self._with_prior:
   1016             llr = self._logits2llrs([exponents, prior])
   1017         else:
   1018             llr = self._logits2llrs(exponents)
   1019 
   1020         # Reshape LLRs to [...,n*num_bits_per_symbol]
   1021         out_shape = tf.concat([tf.shape(y)[:-1],
   1022                                [y.shape[-1] * \
   1023                                 self.constellation.num_bits_per_symbol]], 0)
   1024         llr_reshaped = tf.reshape(llr, out_shape)
   1025 
   1026         return llr_reshaped
   1027 
   1028 class DemapperWithPrior(Demapper):
   1029     # pylint: disable=line-too-long
   1030     r"""
   1031     DemapperWithPrior(demapping_method, constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, dtype=tf.complex64, **kwargs)
   1032 
   1033     Computes log-likelihood ratios (LLRs) or hard-decisions on bits
   1034     for a tensor of received symbols, assuming that prior knowledge on the bits is available.
   1035 
   1036     This class defines a layer implementing different demapping
   1037     functions. All demapping functions are fully differentiable when soft-decisions
   1038     are computed.
   1039 
   1040     This class is deprecated as the functionality has been integrated
   1041     into :class:`~sionna.mapping.Demapper`.
   1042 
   1043     Parameters
   1044     ----------
   1045     demapping_method : One of ["app", "maxlog"], str
   1046         The demapping method used.
   1047 
   1048     constellation_type : One of ["qam", "pam", "custom"], str
   1049         For "custom", an instance of :class:`~sionna.mapping.Constellation`
   1050         must be provided.
   1051 
   1052     num_bits_per_symbol : int
   1053         The number of bits per constellation symbol, e.g., 4 for QAM16.
   1054         Only required for ``constellation_type`` in ["qam", "pam"].
   1055 
   1056     constellation : Constellation
   1057         An instance of :class:`~sionna.mapping.Constellation` or `None`.
   1058         In the latter case, ``constellation_type``
   1059         and ``num_bits_per_symbol`` must be provided.
   1060 
   1061     hard_out : bool
   1062         If `True`, the demapper provides hard-decided bits instead of soft-values.
   1063         Defaults to `False`.
   1064 
   1065     dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
   1066         The dtype of `y`. Defaults to tf.complex64.
   1067         The output dtype is the corresponding real dtype (tf.float32 or tf.float64).
   1068 
   1069     Input
   1070     -----
   1071     (y, prior, no) :
   1072         Tuple:
   1073 
   1074     y : [...,n], tf.complex
   1075         The received symbols.
   1076 
   1077     prior : [num_bits_per_symbol] or [...,num_bits_per_symbol], tf.float
   1078         Prior for every bit as LLRs.
   1079         It can be provided either as a tensor of shape `[num_bits_per_symbol]` for the
   1080         entire input batch, or as a tensor that is "broadcastable"
   1081         to `[..., n, num_bits_per_symbol]`.
   1082 
   1083     no : Scalar or [...,n], tf.float
   1084         The noise variance estimate. It can be provided either as scalar
   1085         for the entire input batch or as a tensor that is "broadcastable" to
   1086         ``y``.
   1087 
   1088     Output
   1089     ------
   1090     : [...,n*num_bits_per_symbol], tf.float
   1091         LLRs or hard-decisions for every bit.
   1092 
   1093     Note
   1094     ----
   1095     With the "app" demapping method, the LLR for the :math:`i\text{th}` bit
   1096     is computed according to
   1097 
   1098     .. math::
   1099         LLR(i) = \ln\left(\frac{\Pr\left(b_i=1\lvert y,\mathbf{p}\right)}{\Pr\left(b_i=0\lvert y,\mathbf{p}\right)}\right) =\ln\left(\frac{
   1100                 \sum_{c\in\mathcal{C}_{i,1}} \Pr\left(c\lvert\mathbf{p}\right)
   1101                 \exp\left(-\frac{1}{N_o}\left|y-c\right|^2\right)
   1102                 }{
   1103                 \sum_{c\in\mathcal{C}_{i,0}} \Pr\left(c\lvert\mathbf{p}\right)
   1104                 \exp\left(-\frac{1}{N_o}\left|y-c\right|^2\right)
   1105                 }\right)
   1106 
   1107     where :math:`\mathcal{C}_{i,1}` and :math:`\mathcal{C}_{i,0}` are the
   1108     sets of constellation points for which the :math:`i\text{th}` bit is
   1109     equal to 1 and 0, respectively. :math:`\mathbf{p} = \left[p_0,\dots,p_{K-1}\right]`
   1110     is the vector of LLRs that serves as prior knowledge on the :math:`K` bits that are mapped to
   1111     a constellation point,
   1112     and :math:`\Pr(c\lvert\mathbf{p})` is the prior probability on the constellation symbol :math:`c`:
   1113 
   1114     .. math::
   1115         \Pr\left(c\lvert\mathbf{p}\right) = \prod_{k=0}^{K-1} \text{sigmoid}\left(p_k \ell(c)_k\right)
   1116 
   1117     where :math:`\ell(c)_k` is the :math:`k^{th}` bit label of :math:`c`, where 0 is
   1118     replaced by -1.
   1119     The definition of the LLR has been
   1120     chosen such that it is equivalent with that of logits. This is
   1121     different from many textbooks in communications, where the LLR is
   1122     defined as :math:`LLR(i) = \ln\left(\frac{\Pr\left(b_i=0\lvert y\right)}{\Pr\left(b_i=1\lvert y\right)}\right)`.
   1123 
   1124     With the "maxlog" demapping method, LLRs for the :math:`i\text{th}` bit
   1125     are approximated like
   1126 
   1127     .. math::
   1128         \begin{align}
   1129             LLR(i) &\approx\ln\left(\frac{
   1130                 \max_{c\in\mathcal{C}_{i,1}} \Pr\left(c\lvert\mathbf{p}\right)
   1131                     \exp\left(-\frac{1}{N_o}\left|y-c\right|^2\right)
   1132                 }{
   1133                 \max_{c\in\mathcal{C}_{i,0}} \Pr\left(c\lvert\mathbf{p}\right)
   1134                     \exp\left(-\frac{1}{N_o}\left|y-c\right|^2\right)
   1135                 }\right)\\
   1136                 &= \max_{c\in\mathcal{C}_{i,0}}
   1137                     \left(\ln\left(\Pr\left(c\lvert\mathbf{p}\right)\right)-\frac{|y-c|^2}{N_o}\right) -
   1138                  \max_{c\in\mathcal{C}_{i,1}}\left( \ln\left(\Pr\left(c\lvert\mathbf{p}\right)\right) - \frac{|y-c|^2}{N_o}\right)
   1139                 .
   1140         \end{align}
   1141     """
   1142     def __init__(self,
   1143                  demapping_method,
   1144                  constellation_type=None,
   1145                  num_bits_per_symbol=None,
   1146                  constellation=None,
   1147                  hard_out=False,
   1148                  dtype=tf.complex64,
   1149                  **kwargs):
   1150         super().__init__(demapping_method=demapping_method,
   1151                          constellation_type=constellation_type,
   1152                          num_bits_per_symbol=num_bits_per_symbol,
   1153                          constellation=constellation,
   1154                          hard_out=hard_out,
   1155                          with_prior=True,
   1156                          dtype=dtype,
   1157                          **kwargs)
   1158 
   1159 class SymbolDemapper(Layer):
   1160     # pylint: disable=line-too-long
   1161     r"""
   1162     SymbolDemapper(constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, with_prior=False, dtype=tf.complex64, **kwargs)
   1163 
   1164     Computes normalized log-probabilities (logits) or hard-decisions on symbols
   1165     for a tensor of received symbols.
   1166     If the ``with_prior`` flag is set, prior knowldge on the transmitted constellation points is assumed to be available.
   1167     The demapping function is fully differentiable when soft-values are
   1168     computed.
   1169 
   1170     Parameters
   1171     ----------
   1172     constellation_type : One of ["qam", "pam", "custom"], str
   1173         For "custom", an instance of :class:`~sionna.mapping.Constellation`
   1174         must be provided.
   1175 
   1176     num_bits_per_symbol : int
   1177         The number of bits per constellation symbol, e.g., 4 for QAM16.
   1178         Only required for ``constellation_type`` in ["qam", "pam"].
   1179 
   1180     constellation : Constellation
   1181         An instance of :class:`~sionna.mapping.Constellation` or `None`.
   1182         In the latter case, ``constellation_type``
   1183         and ``num_bits_per_symbol`` must be provided.
   1184 
   1185     hard_out : bool
   1186         If `True`, the demapper provides hard-decided symbols instead of soft-values.
   1187         Defaults to `False`.
   1188 
   1189     with_prior : bool
   1190         If `True`, it is assumed that prior knowledge on the constellation points is available.
   1191         This prior information is given as log-probabilities (logits) as an additional input to the layer.
   1192         Defaults to `False`.
   1193 
   1194     dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
   1195         The dtype of `y`. Defaults to tf.complex64.
   1196         The output dtype is the corresponding real dtype (tf.float32 or tf.float64).
   1197 
   1198     Input
   1199     -----
   1200     (y, no) or (y, prior, no) :
   1201         Tuple:
   1202 
   1203     y : [...,n], tf.complex
   1204         The received symbols.
   1205 
   1206     prior : [num_points] or [...,num_points], tf.float
   1207         Prior for every symbol as log-probabilities (logits).
   1208         It can be provided either as a tensor of shape `[num_points]` for the
   1209         entire input batch, or as a tensor that is "broadcastable"
   1210         to `[..., n, num_points]`.
   1211         Only required if the ``with_prior`` flag is set.
   1212 
   1213     no : Scalar or [...,n], tf.float
   1214         The noise variance estimate. It can be provided either as scalar
   1215         for the entire input batch or as a tensor that is "broadcastable" to
   1216         ``y``.
   1217 
   1218     Output
   1219     ------
   1220     : [...,n, num_points] or [...,n], tf.float
   1221         A tensor of shape `[...,n, num_points]` of logits for every constellation
   1222         point if `hard_out` is set to `False`.
   1223         Otherwise, a tensor of shape `[...,n]` of hard-decisions on the symbols.
   1224 
   1225     Note
   1226     ----
   1227     The normalized log-probability for the constellation point :math:`c` is computed according to
   1228 
   1229     .. math::
   1230         \ln\left(\Pr\left(c \lvert y,\mathbf{p}\right)\right) = \ln\left( \frac{\exp\left(-\frac{|y-c|^2}{N_0} + p_c \right)}{\sum_{c'\in\mathcal{C}} \exp\left(-\frac{|y-c'|^2}{N_0} + p_{c'} \right)} \right)
   1231 
   1232     where :math:`\mathcal{C}` is the set of constellation points used for modulation,
   1233     and :math:`\mathbf{p} = \left\{p_c \lvert c \in \mathcal{C}\right\}` the prior information on constellation points given as log-probabilities
   1234     and which is set to :math:`\mathbf{0}` if no prior information on the constellation points is assumed to be available.
   1235     """
   1236     def __init__(self,
   1237                  constellation_type=None,
   1238                  num_bits_per_symbol=None,
   1239                  constellation=None,
   1240                  hard_out=False,
   1241                  with_prior=False,
   1242                  dtype=tf.complex64,
   1243                  **kwargs):
   1244         super().__init__(dtype=dtype, **kwargs)
   1245         self._hard_out = hard_out
   1246         self._with_prior = with_prior
   1247 
   1248         # Create constellation object
   1249         self._constellation = Constellation.create_or_check_constellation(
   1250                                                         constellation_type,
   1251                                                         num_bits_per_symbol,
   1252                                                         constellation,
   1253                                                         dtype=dtype)
   1254 
   1255     def call(self, inputs):
   1256         if self._with_prior:
   1257             y, prior, no = inputs
   1258         else:
   1259             y, no = inputs
   1260 
   1261         points = sionna.utils.expand_to_rank(self._constellation.points,
   1262                                              tf.rank(y)+1, axis=0)
   1263         y = tf.expand_dims(y, axis=-1)
   1264         d = tf.abs(y-points)
   1265 
   1266         no = sionna.utils.expand_to_rank(no, tf.rank(d), axis=-1)
   1267         exp = -d**2 / no
   1268 
   1269         if self._with_prior:
   1270             prior = sionna.utils.expand_to_rank(prior, tf.rank(exp), axis=0)
   1271             exp = exp + prior
   1272 
   1273         if self._hard_out:
   1274             return tf.argmax(exp, axis=-1, output_type=tf.int32)
   1275         else:
   1276             return tf.nn.log_softmax(exp, axis=-1)
   1277 
   1278 class SymbolDemapperWithPrior(SymbolDemapper):
   1279     # pylint: disable=line-too-long
   1280     r"""
   1281     SymbolDemapperWithPrior(constellation_type=None, num_bits_per_symbol=None, constellation=None, hard_out=False, dtype=tf.complex64, **kwargs)
   1282 
   1283     Computes normalized log-probabilities (logits) or hard-decisions on symbols
   1284     for a tensor of received symbols, assuming that prior knowledge on the constellation points is available.
   1285     The demapping function is fully differentiable when soft-values are
   1286     computed.
   1287 
   1288     This class is deprecated as the functionality has been integrated
   1289     into :class:`~sionna.mapping.SymbolDemapper`.
   1290 
   1291     Parameters
   1292     ----------
   1293     constellation_type : One of ["qam", "pam", "custom"], str
   1294         For "custom", an instance of :class:`~sionna.mapping.Constellation`
   1295         must be provided.
   1296 
   1297     num_bits_per_symbol : int
   1298         The number of bits per constellation symbol, e.g., 4 for QAM16.
   1299         Only required for ``constellation_type`` in ["qam", "pam"].
   1300 
   1301     constellation : Constellation
   1302         An instance of :class:`~sionna.mapping.Constellation` or `None`.
   1303         In the latter case, ``constellation_type``
   1304         and ``num_bits_per_symbol`` must be provided.
   1305 
   1306     hard_out : bool
   1307         If `True`, the demapper provides hard-decided symbols instead of soft-values.
   1308         Defaults to `False`.
   1309 
   1310     dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
   1311         The dtype of `y`. Defaults to tf.complex64.
   1312         The output dtype is the corresponding real dtype (tf.float32 or tf.float64).
   1313 
   1314     Input
   1315     -----
   1316     (y, prior, no) :
   1317         Tuple:
   1318 
   1319     y : [...,n], tf.complex
   1320         The received symbols.
   1321 
   1322     prior : [num_points] or [...,num_points], tf.float
   1323         Prior for every symbol as log-probabilities (logits).
   1324         It can be provided either as a tensor of shape `[num_points]` for the
   1325         entire input batch, or as a tensor that is "broadcastable"
   1326         to `[..., n, num_points]`.
   1327 
   1328     no : Scalar or [...,n], tf.float
   1329         The noise variance estimate. It can be provided either as scalar
   1330         for the entire input batch or as a tensor that is "broadcastable" to
   1331         ``y``.
   1332 
   1333     Output
   1334     ------
   1335     : [...,n, num_points] or [...,n], tf.float
   1336         A tensor of shape `[...,n, num_points]` of logits for every constellation
   1337         point if `hard_out` is set to `False`.
   1338         Otherwise, a tensor of shape `[...,n]` of hard-decisions on the symbols.
   1339 
   1340     Note
   1341     ----
   1342     The normalized log-probability for the constellation point :math:`c` is computed according to
   1343 
   1344     .. math::
   1345         \ln\left(\Pr\left(c \lvert y,\mathbf{p}\right)\right) = \ln\left( \frac{\exp\left(-\frac{|y-c|^2}{N_0} + p_c \right)}{\sum_{c'\in\mathcal{C}} \exp\left(-\frac{|y-c'|^2}{N_0} + p_{c'} \right)} \right)
   1346 
   1347     where :math:`\mathcal{C}` is the set of constellation points used for modulation,
   1348     and :math:`\mathbf{p} = \left\{p_c \lvert c \in \mathcal{C}\right\}` the prior information on constellation points given as log-probabilities.
   1349     """
   1350     def __init__(self,
   1351                  constellation_type=None,
   1352                  num_bits_per_symbol=None,
   1353                  constellation=None,
   1354                  hard_out=False,
   1355                  dtype=tf.complex64,
   1356                  **kwargs):
   1357         super().__init__(constellation_type=constellation_type,
   1358                          num_bits_per_symbol=num_bits_per_symbol,
   1359                          constellation=constellation,
   1360                          hard_out=hard_out,
   1361                          with_prior=True,
   1362                          dtype=dtype,
   1363                          **kwargs)
   1364 
   1365 class LLRs2SymbolLogits(Layer):
   1366     # pylint: disable=line-too-long
   1367     r"""
   1368     LLRs2SymbolLogits(num_bits_per_symbol, hard_out=False, dtype=tf.float32, **kwargs)
   1369 
   1370     Computes logits (i.e., unnormalized log-probabilities) or hard decisions
   1371     on constellation points from a tensor of log-likelihood ratios (LLRs) on bits.
   1372 
   1373     Parameters
   1374     ----------
   1375     num_bits_per_symbol : int
   1376         The number of bits per constellation symbol, e.g., 4 for QAM16.
   1377 
   1378     hard_out : bool
   1379         If `True`, the layer provides hard-decided constellation points instead of soft-values.
   1380         Defaults to `False`.
   1381 
   1382     dtype : One of [tf.float32, tf.float64] tf.DType (dtype)
   1383         The dtype for the input and output.
   1384         Defaults to `tf.float32`.
   1385 
   1386     Input
   1387     -----
   1388     llrs : [..., n, num_bits_per_symbol], tf.float
   1389         LLRs for every bit.
   1390 
   1391     Output
   1392     ------
   1393     : [...,n, num_points], tf.float or [..., n], tf.int32
   1394         Logits or hard-decisions on constellation points.
   1395 
   1396     Note
   1397     ----
   1398     The logit for the constellation :math:`c` point
   1399     is computed according to
   1400 
   1401     .. math::
   1402         \begin{align}
   1403             \log{\left(\Pr\left(c\lvert LLRs \right)\right)}
   1404                 &= \log{\left(\prod_{k=0}^{K-1} \Pr\left(b_k = \ell(c)_k \lvert LLRs \right)\right)}\\
   1405                 &= \log{\left(\prod_{k=0}^{K-1} \text{sigmoid}\left(LLR(k) \ell(c)_k\right)\right)}\\
   1406                 &= \sum_{k=0}^{K-1} \log{\left(\text{sigmoid}\left(LLR(k) \ell(c)_k\right)\right)}
   1407         \end{align}
   1408 
   1409     where :math:`\ell(c)_k` is the :math:`k^{th}` bit label of :math:`c`, where 0 is
   1410     replaced by -1.
   1411     The definition of the LLR has been
   1412     chosen such that it is equivalent with that of logits. This is
   1413     different from many textbooks in communications, where the LLR is
   1414     defined as :math:`LLR(i) = \ln\left(\frac{\Pr\left(b_i=0\lvert y\right)}{\Pr\left(b_i=1\lvert y\right)}\right)`.
   1415     """
   1416 
   1417     def __init__(self,
   1418                  num_bits_per_symbol,
   1419                  hard_out=False,
   1420                  dtype=tf.float32,
   1421                  **kwargs):
   1422         super().__init__(dtype=dtype, **kwargs)
   1423 
   1424         self._hard_out = hard_out
   1425         self._num_bits_per_symbol = num_bits_per_symbol
   1426         num_points = int(2**num_bits_per_symbol)
   1427 
   1428         # Array composed of binary representations of all symbols indices
   1429         a = np.zeros([num_points, num_bits_per_symbol])
   1430         for i in range(0, num_points):
   1431             a[i,:] = np.array(list(np.binary_repr(i, num_bits_per_symbol)),
   1432                               dtype=np.int16)
   1433 
   1434         # Array of labels from {-1, 1} of all symbols
   1435         # [num_points, num_bits_per_symbol]
   1436         a = 2*a-1
   1437         self._a = tf.constant(a, dtype=dtype)
   1438 
   1439     @property
   1440     def num_bits_per_symbol(self):
   1441         return self._num_bits_per_symbol
   1442 
   1443     def call(self, inputs):
   1444         llrs = inputs
   1445 
   1446         # Expand the symbol labeling to be broadcastable with prior
   1447         # shape [1, ..., 1, num_points, num_bits_per_symbol]
   1448         a = sionna.utils.expand_to_rank(self._a, tf.rank(llrs), axis=0)
   1449 
   1450         # Compute the prior probabilities on symbols exponents
   1451         # shape [..., 1, num_points]
   1452         llrs = tf.expand_dims(llrs, axis=-2)
   1453         logits = tf.reduce_sum(tf.math.log_sigmoid(a*llrs), axis=-1)
   1454 
   1455         if self._hard_out:
   1456             return tf.argmax(logits, axis=-1, output_type=tf.int32)
   1457         else:
   1458             return logits
   1459 
   1460 class SymbolLogits2Moments(Layer):
   1461     # pylint: disable=line-too-long
   1462     r"""
   1463     SymbolLogits2Moments(constellation_type=None, num_bits_per_symbol=None, constellation=None, dtype=tf.float32, **kwargs)
   1464 
   1465     Computes the mean and variance of a constellation from logits (unnormalized log-probabilities) on the
   1466     constellation points.
   1467 
   1468     More precisely, given a constellation :math:`\mathcal{C} = \left[ c_0,\dots,c_{N-1} \right]` of size :math:`N`, this layer computes the mean and variance
   1469     according to
   1470 
   1471     .. math::
   1472         \begin{align}
   1473             \mu &= \sum_{n = 0}^{N-1} c_n \Pr \left(c_n \lvert \mathbf{\ell} \right)\\
   1474             \nu &= \sum_{n = 0}^{N-1} \left( c_n - \mu \right)^2 \Pr \left(c_n \lvert \mathbf{\ell} \right)
   1475         \end{align}
   1476 
   1477 
   1478     where :math:`\mathbf{\ell} = \left[ \ell_0, \dots, \ell_{N-1} \right]` are the logits, and
   1479 
   1480     .. math::
   1481         \Pr \left(c_n \lvert \mathbf{\ell} \right) = \frac{\exp \left( \ell_n \right)}{\sum_{i=0}^{N-1} \exp \left( \ell_i \right) }.
   1482 
   1483     Parameters
   1484     ----------
   1485     constellation_type : One of ["qam", "pam", "custom"], str
   1486         For "custom", an instance of :class:`~sionna.mapping.Constellation`
   1487         must be provided.
   1488 
   1489     num_bits_per_symbol : int
   1490         The number of bits per constellation symbol, e.g., 4 for QAM16.
   1491         Only required for ``constellation_type`` in ["qam", "pam"].
   1492 
   1493     constellation : Constellation
   1494         An instance of :class:`~sionna.mapping.Constellation` or `None`.
   1495         In the latter case, ``constellation_type``
   1496         and ``num_bits_per_symbol`` must be provided.
   1497 
   1498     dtype : One of [tf.float32, tf.float64] tf.DType (dtype)
   1499         The dtype for the input and output.
   1500         Defaults to `tf.float32`.
   1501 
   1502     Input
   1503     -----
   1504     logits : [...,n, num_points], tf.float
   1505         Logits on constellation points.
   1506 
   1507     Output
   1508     ------
   1509     mean : [...,n], tf.float
   1510         Mean of the constellation.
   1511 
   1512     var : [...,n], tf.float
   1513         Variance of the constellation
   1514     """
   1515     def __init__(self,
   1516                  constellation_type=None,
   1517                  num_bits_per_symbol=None,
   1518                  constellation=None,
   1519                  dtype=tf.float32,
   1520                  **kwargs):
   1521         super().__init__(dtype=dtype, **kwargs)
   1522 
   1523         # Create constellation object
   1524         const_dtype = tf.complex64 if dtype is tf.float32 else tf.complex128
   1525         self._constellation = Constellation.create_or_check_constellation(
   1526                                                         constellation_type,
   1527                                                         num_bits_per_symbol,
   1528                                                         constellation,
   1529                                                         dtype=const_dtype)
   1530 
   1531     def __call__(self, logits):
   1532         p = tf.math.softmax(logits, axis=-1)
   1533         p_c = tf.complex(p, tf.cast(0.0, self.dtype))
   1534         points = self._constellation.points
   1535         points = sionna.utils.expand_to_rank(points, tf.rank(p), axis=0)
   1536 
   1537         mean = tf.reduce_sum(p_c*points, axis=-1, keepdims=True)
   1538         var = tf.reduce_sum(p*tf.square(tf.abs(points - mean)), axis=-1)
   1539         mean = tf.squeeze(mean, axis=-1)
   1540 
   1541         return mean, var
   1542 
   1543 class QAM2PAM:
   1544     r"""Transforms QAM symbol indices to PAM symbol indices.
   1545 
   1546     For indices in a QAM constellation, computes the corresponding indices
   1547     for the two PAM constellations corresponding the real and imaginary
   1548     components of the QAM constellation.
   1549 
   1550     Parameters
   1551     ----------
   1552     num_bits_per_symbol : int
   1553         The number of bits per QAM constellation symbol, e.g., 4 for QAM16.
   1554 
   1555     Input
   1556     -----
   1557     ind_qam : Tensor, tf.int
   1558         Indices in the QAM constellation
   1559 
   1560     Output
   1561     -------
   1562     ind_pam1 : Tensor, tf.int
   1563         Indices for the first component of the corresponding PAM modulation
   1564 
   1565     ind_pam2 : Tensor, tf.int
   1566         Indices for the first component of the corresponding PAM modulation
   1567     """
   1568     def __init__(self, num_bits_per_symbol):
   1569         base = [2**i for i in range(num_bits_per_symbol//2-1, -1, -1)]
   1570         base = np.array(base)
   1571         pam1_ind = np.zeros([2**num_bits_per_symbol], dtype=np.int32)
   1572         pam2_ind = np.zeros([2**num_bits_per_symbol], dtype=np.int32)
   1573         for i in range(0, 2**num_bits_per_symbol):
   1574             b = np.array(list(np.binary_repr(i,num_bits_per_symbol)),
   1575                          dtype=np.int32)
   1576             pam1_ind[i] = np.sum(b[0::2]*base)
   1577             pam2_ind[i] = np.sum(b[1::2]*base)
   1578         self._pam1_ind = tf.constant(pam1_ind, dtype=tf.int32)
   1579         self._pam2_ind = tf.constant(pam2_ind, dtype=tf.int32)
   1580 
   1581     def __call__(self, ind_qam):
   1582 
   1583         ind_pam1 = tf.gather(self._pam1_ind, ind_qam, axis=0)
   1584         ind_pam2 = tf.gather(self._pam2_ind, ind_qam, axis=0)
   1585 
   1586         return ind_pam1, ind_pam2
   1587 
   1588 class PAM2QAM:
   1589     r"""Transforms PAM symbol indices/logits to QAM symbol indices/logits.
   1590 
   1591     For two PAM constellation symbol indices or logits, corresponding to
   1592     the real and imaginary components of a QAM constellation,
   1593     compute the QAM symbol index or logits.
   1594 
   1595     Parameters
   1596     ----------
   1597     num_bits_per_symbol : int
   1598         Number of bits per QAM constellation symbol, e.g., 4 for QAM16
   1599 
   1600     hard_in_out : bool
   1601         Determines if inputs and outputs are indices or logits over
   1602         constellation symbols.
   1603         Defaults to `True`.
   1604 
   1605     Input
   1606     -----
   1607     pam1 : Tensor, tf.int, or [...,2**(num_bits_per_symbol/2)], tf.float
   1608         Indices or logits for the first PAM constellation
   1609 
   1610     pam2 : Tensor, tf.int, or [...,2**(num_bits_per_symbol/2)], tf.float
   1611         Indices or logits for the second PAM constellation
   1612 
   1613     Output
   1614     -------
   1615     qam : Tensor, tf.int, or [...,2**num_bits_per_symbol], tf.float
   1616         Indices or logits for the corresponding QAM constellation
   1617     """
   1618     def __init__(self, num_bits_per_symbol, hard_in_out=True):
   1619         num_pam_symbols = 2**(num_bits_per_symbol//2)
   1620         base = np.array([2**i for i in range(num_bits_per_symbol-1, -1, -1)])
   1621 
   1622         # Create an array of QAM symbol indices, index by two PAM indices
   1623         ind = np.zeros([num_pam_symbols, num_pam_symbols], np.int32)
   1624         for i in range(0, num_pam_symbols):
   1625             for j in range(0, num_pam_symbols):
   1626                 b1 = np.array(list(np.binary_repr(i,num_bits_per_symbol//2)),
   1627                               dtype=np.int16)
   1628                 b2 = np.array(list(np.binary_repr(j,num_bits_per_symbol//2)),
   1629                               dtype=np.int16)
   1630                 b = np.zeros([num_bits_per_symbol], np.int32)
   1631                 b[0::2] = b1
   1632                 b[1::2] = b2
   1633                 ind[i, j] = np.sum(b*base)
   1634         self._qam_ind = tf.constant(ind, dtype=tf.int32)
   1635         self._hard_in_out = hard_in_out
   1636 
   1637     def __call__(self, pam1, pam2):
   1638 
   1639         # PAM indices to QAM indices
   1640         if self._hard_in_out:
   1641             shape = tf.shape(pam1)
   1642             ind_pam1 = tf.reshape(pam1, [-1, 1])
   1643             ind_pam2 = tf.reshape(pam2, [-1, 1])
   1644             ind_pam = tf.concat([ind_pam1, ind_pam2], axis=-1)
   1645             ind_qam = tf.gather_nd(self._qam_ind, ind_pam)
   1646             ind_qam = tf.reshape(ind_qam, shape)
   1647             return ind_qam
   1648 
   1649         # PAM logits to QAM logits
   1650         else:
   1651             # Compute all combination of sums of logits
   1652             logits_mat = tf.expand_dims(pam1, -1) + tf.expand_dims(pam2, -2)
   1653 
   1654             # Flatten to a vector
   1655             logits = sionna.utils.flatten_last_dims(logits_mat)
   1656 
   1657             # Gather symbols in the correct order
   1658             gather_ind = tf.reshape(self._qam_ind, [-1])
   1659             logits = tf.gather(logits, gather_ind, axis=-1)
   1660             return logits
   1661 
   1662 class SymbolInds2Bits(Layer):
   1663     # pylint: disable=line-too-long
   1664     r"""SymbolInds2Bits(num_bits_per_symbol, dtype=tf.float32, **kwargs)
   1665 
   1666     Transforms symbol indices to their binary representations.
   1667 
   1668     Parameters
   1669     ----------
   1670     num_bits_per_symbol : int
   1671         Number of bits per constellation symbol
   1672 
   1673     dtype: tf.DType
   1674         Output dtype. Defaults to `tf.float32`.
   1675 
   1676     Input
   1677     -----
   1678     : Tensor, tf.int
   1679         Symbol indices
   1680 
   1681     Output
   1682     -----
   1683     : input.shape + [num_bits_per_symbol], dtype
   1684         Binary representation of symbol indices
   1685     """
   1686     def __init__(self,
   1687                num_bits_per_symbol,
   1688                dtype=tf.float32,
   1689                **kwargs):
   1690         super().__init__(dtype=dtype, **kwargs)
   1691         num_symbols = 2**num_bits_per_symbol
   1692         b = np.zeros([num_symbols, num_bits_per_symbol])
   1693         for i in range(0, num_symbols):
   1694             b[i,:] = np.array(list(np.binary_repr(i, num_bits_per_symbol)),
   1695                               dtype=np.int16)
   1696         self._bit_labels = tf.constant(b, self.dtype)
   1697 
   1698     def call(self, inputs):
   1699         symbol_ind = inputs
   1700         return tf.gather(self._bit_labels, symbol_ind)