anomaly-detection-material-parameters-calibration

Sionna parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration

misc.py (38120B)


#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Miscellaneous utility functions of the Sionna package."""

import time
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.experimental.numpy import log10 as _log10
from tensorflow.experimental.numpy import log2 as _log2
import sionna
from sionna.utils.metrics import count_errors, count_block_errors
from sionna.mapping import Mapper, Constellation
from sionna import signal

def ebnodb2no(ebno_db, num_bits_per_symbol, coderate, resource_grid=None):
    r"""Compute the noise variance `No` for a given `Eb/No` in dB.

    The function takes into account the number of coded bits per constellation
    symbol, the coderate, as well as possible additional overheads related to
    OFDM transmissions, such as the cyclic prefix and pilots.

    The value of `No` is computed according to the following expression

    .. math::
        N_o = \left(\frac{E_b}{N_o} \frac{r M}{E_s}\right)^{-1}

    where :math:`2^M` is the constellation size, i.e., :math:`M` is the
    average number of coded bits per constellation symbol,
    :math:`E_s=1` is the average energy per constellation symbol,
    :math:`r\in(0,1]` is the coderate,
    :math:`E_b` is the energy per information bit,
    and :math:`N_o` is the noise power spectral density.
    For OFDM transmissions, :math:`E_s` is scaled
    according to the ratio between the total number of resource elements in
    a resource grid with non-zero energy and the number
    of resource elements used for data transmission. The additional energy
    transmitted during the cyclic prefix is also taken into account, as is
    the number of transmitted streams per transmitter.

    Input
    -----
    ebno_db : float
        The `Eb/No` value in dB.

    num_bits_per_symbol : int
        The number of bits per symbol.

    coderate : float
        The coderate used.

    resource_grid : ResourceGrid
        An (optional) instance of :class:`~sionna.ofdm.ResourceGrid`
        for OFDM transmissions.

    Output
    ------
    : float
        The value of :math:`N_o` in linear scale.
    """

    if tf.is_tensor(ebno_db):
        dtype = ebno_db.dtype
    else:
        dtype = tf.float32

    ebno = tf.math.pow(tf.cast(10., dtype), ebno_db/10.)

    energy_per_symbol = 1
    if resource_grid is not None:
        # Divide energy per symbol by the number of transmitted streams
        energy_per_symbol /= resource_grid.num_streams_per_tx

        # Number of nonzero energy symbols.
        # We do not account for the nulled DC and guard carriers.
        cp_overhead = resource_grid.cyclic_prefix_length \
                      / resource_grid.fft_size
        num_syms = resource_grid.num_ofdm_symbols * (1 + cp_overhead) \
                    * resource_grid.num_effective_subcarriers
        energy_per_symbol *= num_syms / resource_grid.num_data_symbols

    no = 1/(ebno * coderate * tf.cast(num_bits_per_symbol, dtype) \
          / tf.cast(energy_per_symbol, dtype))

    return no

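# Usage sketch (editor's illustration, not part of the original module):
# converting Eb/No = 4 dB to a noise variance for 16-QAM (4 bits/symbol)
# with a rate-1/2 code. With Es = 1 and no resource grid, the formula above
# gives No = 1 / (10**(4/10) * 0.5 * 4) ~ 0.199. The values are assumptions
# chosen for illustration only.
def _example_ebnodb2no():
    no = ebnodb2no(ebno_db=4.0, num_bits_per_symbol=4, coderate=0.5)
    return no  # ~0.199 in linear scale
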
def hard_decisions(llr):
    """Transforms LLRs into hard decisions.

    Positive values are mapped to :math:`1`.
    Nonpositive values are mapped to :math:`0`.

    Input
    -----
    llr : any non-complex tf.DType
        Tensor of LLRs.

    Output
    ------
    : Same shape and dtype as ``llr``
        The hard decisions.
    """
    zero = tf.constant(0, dtype=llr.dtype)

    return tf.cast(tf.math.greater(llr, zero), dtype=llr.dtype)

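# Minimal sketch (editor's illustration): hard-deciding a small tensor of
# LLRs. Positive LLRs become 1, zero and negative LLRs become 0, matching
# the docstring above.
def _example_hard_decisions():
    llr = tf.constant([-1.3, 0.0, 2.7])
    return hard_decisions(llr)  # -> [0., 0., 1.]
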
def log10(x):
    # pylint: disable=C0301
    """TensorFlow implementation of NumPy's `log10` function.

    Simple extension to `tf.experimental.numpy.log10`
    which casts the result to the `dtype` of the input.
    For more details see the `TensorFlow <https://www.tensorflow.org/api_docs/python/tf/experimental/numpy/log10>`__ and `NumPy <https://numpy.org/doc/1.16/reference/generated/numpy.log10.html>`__ documentation.
    """
    return tf.cast(_log10(x), x.dtype)

def log2(x):
    # pylint: disable=C0301
    """TensorFlow implementation of NumPy's `log2` function.

    Simple extension to `tf.experimental.numpy.log2`
    which casts the result to the `dtype` of the input.
    For more details see the `TensorFlow <https://www.tensorflow.org/api_docs/python/tf/experimental/numpy/log2>`_ and `NumPy <https://numpy.org/doc/1.16/reference/generated/numpy.log2.html>`_ documentation.
    """
    return tf.cast(_log2(x), x.dtype)

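# Quick sketch (editor's illustration): unlike the raw
# tf.experimental.numpy functions, these wrappers preserve the input dtype,
# e.g., float32 in -> float32 out.
def _example_log_helpers():
    x = tf.constant([1.0, 10.0, 1024.0])
    return log10(x), log2(x)  # ([0., 1., ~3.01], [0., ~3.32, 10.])
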
class BinarySource(Layer):
    """BinarySource(dtype=tf.float32, seed=None, **kwargs)

    Layer generating random binary tensors.

    Parameters
    ----------
    dtype : tf.DType
        Defines the output datatype of the layer.
        Defaults to `tf.float32`.

    seed : int or None
        Set the seed for the random generator used to generate the bits.
        Set to `None` for random initialization of the RNG.

    Input
    -----
    shape : 1D tensor/array/list, int
        The desired shape of the output tensor.

    Output
    ------
    : ``shape``, ``dtype``
        Tensor filled with random binary values.
    """
    def __init__(self, dtype=tf.float32, seed=None, **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        self._seed = seed
        if self._seed is not None:
            self._rng = tf.random.Generator.from_seed(self._seed)
        else:
            self._rng = sionna.config.tf_rng

    def call(self, inputs):
        return tf.cast(self._rng.uniform(inputs, 0, 2, tf.int32),
                       dtype=super().dtype)

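# Usage sketch (editor's illustration): drawing a reproducible batch of
# random bits. The shape argument is passed at call time, as documented
# above; the seed value 42 is an arbitrary assumption.
def _example_binary_source():
    source = BinarySource(seed=42)
    return source([2, 8])  # [2, 8] tensor of 0./1. values, dtype float32
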
class SymbolSource(Layer):
    # pylint: disable=line-too-long
    r"""SymbolSource(constellation_type=None, num_bits_per_symbol=None, constellation=None, return_indices=False, return_bits=False, seed=None, dtype=tf.complex64, **kwargs)

    Layer generating a tensor of arbitrary shape filled with random constellation symbols.
    Optionally, the symbol indices and/or binary representations of the
    constellation symbols can be returned.

    Parameters
    ----------
    constellation_type : One of ["qam", "pam", "custom"], str
        For "custom", an instance of :class:`~sionna.mapping.Constellation`
        must be provided.

    num_bits_per_symbol : int
        The number of bits per constellation symbol.
        Only required for ``constellation_type`` in ["qam", "pam"].

    constellation : Constellation
        An instance of :class:`~sionna.mapping.Constellation` or
        `None`. In the latter case, ``constellation_type``
        and ``num_bits_per_symbol`` must be provided.

    return_indices : bool
        If enabled, the function also returns the symbol indices.
        Defaults to `False`.

    return_bits : bool
        If enabled, the function also returns the binary symbol
        representations (i.e., bit labels).
        Defaults to `False`.

    seed : int or None
        The seed for the random generator.
        `None` leads to a random initialization of the RNG.
        Defaults to `None`.

    dtype : One of [tf.complex64, tf.complex128], tf.DType
        The output dtype. Defaults to tf.complex64.

    Input
    -----
    shape : 1D tensor/array/list, int
        The desired shape of the output tensor.

    Output
    ------
    symbols : ``shape``, ``dtype``
        Tensor filled with random symbols of the chosen ``constellation_type``.

    symbol_indices : ``shape``, tf.int32
        Tensor filled with the symbol indices.
        Only returned if ``return_indices`` is `True`.

    bits : [``shape``, ``num_bits_per_symbol``], tf.int32
        Tensor filled with the binary symbol representations (i.e., bit labels).
        Only returned if ``return_bits`` is `True`.
    """
    def __init__(self,
                 constellation_type=None,
                 num_bits_per_symbol=None,
                 constellation=None,
                 return_indices=False,
                 return_bits=False,
                 seed=None,
                 dtype=tf.complex64,
                 **kwargs
                ):
        super().__init__(dtype=dtype, **kwargs)
        constellation = Constellation.create_or_check_constellation(
            constellation_type,
            num_bits_per_symbol,
            constellation,
            dtype)
        self._num_bits_per_symbol = constellation.num_bits_per_symbol
        self._return_indices = return_indices
        self._return_bits = return_bits
        self._binary_source = BinarySource(seed=seed, dtype=dtype.real_dtype)
        self._mapper = Mapper(constellation=constellation,
                              return_indices=return_indices,
                              dtype=dtype)

    def call(self, inputs):
        shape = tf.concat([inputs, [self._num_bits_per_symbol]], axis=-1)
        b = self._binary_source(tf.cast(shape, tf.int32))
        if self._return_indices:
            x, ind = self._mapper(b)
        else:
            x = self._mapper(b)

        result = tf.squeeze(x, -1)
        if self._return_indices or self._return_bits:
            result = [result]
        if self._return_indices:
            result.append(tf.squeeze(ind, -1))
        if self._return_bits:
            result.append(b)

        return result

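# Usage sketch (editor's illustration): drawing QPSK symbols together with
# their indices and bit labels. Shapes follow the Output section above:
# x is [4, 16] complex64, ind is [4, 16] int32, b is [4, 16, 2]. The seed
# and shape values are arbitrary assumptions.
def _example_symbol_source():
    source = SymbolSource(constellation_type="qam", num_bits_per_symbol=2,
                          return_indices=True, return_bits=True, seed=7)
    x, ind, b = source([4, 16])
    return x, ind, b
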
class QAMSource(SymbolSource):
    # pylint: disable=line-too-long
    r"""QAMSource(num_bits_per_symbol=None, return_indices=False, return_bits=False, seed=None, dtype=tf.complex64, **kwargs)

    Layer generating a tensor of arbitrary shape filled with random QAM symbols.
    Optionally, the symbol indices and/or binary representations of the
    constellation symbols can be returned.

    Parameters
    ----------
    num_bits_per_symbol : int
        The number of bits per constellation symbol, e.g., 4 for QAM16.

    return_indices : bool
        If enabled, the function also returns the symbol indices.
        Defaults to `False`.

    return_bits : bool
        If enabled, the function also returns the binary symbol
        representations (i.e., bit labels).
        Defaults to `False`.

    seed : int or None
        The seed for the random generator.
        `None` leads to a random initialization of the RNG.
        Defaults to `None`.

    dtype : One of [tf.complex64, tf.complex128], tf.DType
        The output dtype. Defaults to tf.complex64.

    Input
    -----
    shape : 1D tensor/array/list, int
        The desired shape of the output tensor.

    Output
    ------
    symbols : ``shape``, ``dtype``
        Tensor filled with random QAM symbols.

    symbol_indices : ``shape``, tf.int32
        Tensor filled with the symbol indices.
        Only returned if ``return_indices`` is `True`.

    bits : [``shape``, ``num_bits_per_symbol``], tf.int32
        Tensor filled with the binary symbol representations (i.e., bit labels).
        Only returned if ``return_bits`` is `True`.
    """
    def __init__(self,
                 num_bits_per_symbol=None,
                 return_indices=False,
                 return_bits=False,
                 seed=None,
                 dtype=tf.complex64,
                 **kwargs
                ):
        super().__init__(constellation_type="qam",
                         num_bits_per_symbol=num_bits_per_symbol,
                         return_indices=return_indices,
                         return_bits=return_bits,
                         seed=seed,
                         dtype=dtype,
                         **kwargs)

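# Usage sketch (editor's illustration): 100 random unit-energy 16-QAM
# symbols. PAMSource below is used the same way, e.g., with
# num_bits_per_symbol=1 for BPSK.
def _example_qam_source():
    qam = QAMSource(num_bits_per_symbol=4)  # 16-QAM
    return qam([100])
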
class PAMSource(SymbolSource):
    # pylint: disable=line-too-long
    r"""PAMSource(num_bits_per_symbol=None, return_indices=False, return_bits=False, seed=None, dtype=tf.complex64, **kwargs)

    Layer generating a tensor of arbitrary shape filled with random PAM symbols.
    Optionally, the symbol indices and/or binary representations of the
    constellation symbols can be returned.

    Parameters
    ----------
    num_bits_per_symbol : int
        The number of bits per constellation symbol, e.g., 1 for BPSK.

    return_indices : bool
        If enabled, the function also returns the symbol indices.
        Defaults to `False`.

    return_bits : bool
        If enabled, the function also returns the binary symbol
        representations (i.e., bit labels).
        Defaults to `False`.

    seed : int or None
        The seed for the random generator.
        `None` leads to a random initialization of the RNG.
        Defaults to `None`.

    dtype : One of [tf.complex64, tf.complex128], tf.DType
        The output dtype. Defaults to tf.complex64.

    Input
    -----
    shape : 1D tensor/array/list, int
        The desired shape of the output tensor.

    Output
    ------
    symbols : ``shape``, ``dtype``
        Tensor filled with random PAM symbols.

    symbol_indices : ``shape``, tf.int32
        Tensor filled with the symbol indices.
        Only returned if ``return_indices`` is `True`.

    bits : [``shape``, ``num_bits_per_symbol``], tf.int32
        Tensor filled with the binary symbol representations (i.e., bit labels).
        Only returned if ``return_bits`` is `True`.
    """
    def __init__(self,
                 num_bits_per_symbol=None,
                 return_indices=False,
                 return_bits=False,
                 seed=None,
                 dtype=tf.complex64,
                 **kwargs
                ):
        super().__init__(constellation_type="pam",
                         num_bits_per_symbol=num_bits_per_symbol,
                         return_indices=return_indices,
                         return_bits=return_bits,
                         seed=seed,
                         dtype=dtype,
                         **kwargs)

def sim_ber(mc_fun,
            ebno_dbs,
            batch_size,
            max_mc_iter,
            soft_estimates=False,
            num_target_bit_errors=None,
            num_target_block_errors=None,
            target_ber=None,
            target_bler=None,
            early_stop=True,
            graph_mode=None,
            distribute=None,
            verbose=True,
            forward_keyboard_interrupt=True,
            callback=None,
            dtype=tf.complex64):
    # pylint: disable=line-too-long
    """Simulates until a target number of errors is reached and returns BER/BLER.

    The simulation continues with the next SNR point if either
    ``num_target_bit_errors`` bit errors or ``num_target_block_errors`` block
    errors have been reached. Further, it continues with the next SNR point
    after ``max_mc_iter`` batches of size ``batch_size`` have been simulated.
    Early stopping allows the simulation to stop after the first error-free
    SNR point or once a given ``target_ber`` or ``target_bler`` is reached.

    Input
    -----
    mc_fun: callable
        Callable that yields the transmitted bits `b` and the
        receiver's estimate `b_hat` for a given ``batch_size`` and
        ``ebno_db``. If ``soft_estimates`` is `True`, `b_hat` is interpreted
        as logits.

    ebno_dbs: tf.float32
        A tensor containing the SNR points to be evaluated.

    batch_size: tf.int32
        Batch-size for evaluation.

    max_mc_iter: tf.int32
        Maximum number of Monte-Carlo iterations per SNR point.

    soft_estimates: bool
        A boolean, defaults to `False`. If `True`, `b_hat`
        is interpreted as logits and an additional hard-decision is applied
        internally.

    num_target_bit_errors: tf.int32
        Defaults to `None`. Target number of bit errors per SNR point before
        the simulation continues to the next SNR point.

    num_target_block_errors: tf.int32
        Defaults to `None`. Target number of block errors per SNR point
        before the simulation continues to the next SNR point.

    target_ber: tf.float32
        Defaults to `None`. The simulation stops after the first SNR point
        that achieves a bit error rate lower than ``target_ber``.
        This requires ``early_stop`` to be `True`.

    target_bler: tf.float32
        Defaults to `None`. The simulation stops after the first SNR point
        that achieves a block error rate lower than ``target_bler``.
        This requires ``early_stop`` to be `True`.

    early_stop: bool
        A boolean, defaults to `True`. If `True`, the simulation stops after
        the first error-free SNR point (i.e., no error occurred after
        ``max_mc_iter`` Monte-Carlo iterations).

    graph_mode: One of ["graph", "xla"], str
        A string describing the execution mode of ``mc_fun``.
        Defaults to `None`. In this case, ``mc_fun`` is executed as is.

    distribute: `None` (default) | "all" | list of indices | `tf.distribute.Strategy`
        Distributes the simulation on multiple parallel devices. If `None`,
        multi-device simulations are deactivated. If "all", the workload will
        be automatically distributed across all available GPUs via the
        `tf.distribute.MirroredStrategy`.
        If an explicit list of indices is provided, only the GPUs with the
        given indices will be used. Alternatively, a custom
        `tf.distribute.Strategy` can be provided. Note that the same
        `batch_size` will be used for all GPUs in parallel, but the number of
        Monte-Carlo iterations ``max_mc_iter`` will be scaled by the number
        of devices such that the same number of total samples is simulated.
        However, all stopping conditions are still in place, which can cause
        slight differences in the total number of simulated samples.

    verbose: bool
        A boolean, defaults to `True`. If `True`, the current progress will
        be printed.

    forward_keyboard_interrupt: bool
        A boolean, defaults to `True`. If `False`, KeyboardInterrupts will be
        caught internally and not forwarded (e.g., will not stop outer loops).
        In this case, the simulation ends early and returns the intermediate
        simulation results.

    callback: `None` (default) | callable
        If specified, ``callback`` will be called after each Monte-Carlo step.
        Can be used for logging or advanced early stopping. The input
        signature of ``callback`` must match `callback(mc_iter, snr_idx,
        ebno_dbs, bit_errors, block_errors, nb_bits, nb_blocks)` where
        ``mc_iter`` denotes the number of processed batches for the current
        SNR point, ``snr_idx`` is the index of the current SNR point,
        ``ebno_dbs`` is the vector of all SNR points to be evaluated,
        ``bit_errors`` the vector of numbers of bit errors for each SNR
        point, ``block_errors`` the vector of numbers of block errors,
        ``nb_bits`` the vector of numbers of simulated bits, and
        ``nb_blocks`` the vector of numbers of simulated blocks. If
        ``callback`` returns `sim_ber.CALLBACK_NEXT_SNR`, early stopping is
        triggered and the simulation continues with the next SNR point. If
        ``callback`` returns `sim_ber.CALLBACK_STOP`, the simulation is
        stopped immediately. If ``callback`` returns
        `sim_ber.CALLBACK_CONTINUE`, the simulation continues.

    dtype: tf.complex64
        Datatype used as input/output of the callable ``mc_fun``.

    Output
    ------
    (ber, bler) :
        Tuple:

    ber: tf.float32
        The bit-error rate.

    bler: tf.float32
        The block-error rate.

    Raises
    ------
    AssertionError
        If ``soft_estimates`` is not bool.

    AssertionError
        If ``dtype`` is not complex.

    Note
    ----
    This function is implemented based on tensors to allow
    full compatibility with tf.function(). However, to run simulations
    in graph mode, the provided ``mc_fun`` must use the `@tf.function()`
    decorator.
    """

    # utility function to print progress
    def _print_progress(is_final, rt, idx_snr, idx_it, header_text=None):
        """Print a summary of the current simulation progress.

        Input
        -----
        is_final: bool
            A boolean. If `True`, the progress is printed into a new line.
        rt: float
            The runtime of the current SNR point in seconds.
        idx_snr: int
            Index of the current SNR point.
        idx_it: int
            Current iteration index.
        header_text: list of str
            If not `None`, these elements will be printed instead of the
            current progress. Can be used to generate the table header.
        """
        # set carriage return if not final step
        if is_final:
            end_str = "\n"
        else:
            end_str = "\r"

        # prepare to print table header
        if header_text is not None:
            row_text = header_text
            end_str = "\n"
        else:
            # calculate intermediate ber / bler
            ber_np = (tf.cast(bit_errors[idx_snr], tf.float64)
                        / tf.cast(nb_bits[idx_snr], tf.float64)).numpy()
            ber_np = np.nan_to_num(ber_np) # avoid nan for first point
            bler_np = (tf.cast(block_errors[idx_snr], tf.float64)
                        / tf.cast(nb_blocks[idx_snr], tf.float64)).numpy()
            bler_np = np.nan_to_num(bler_np) # avoid nan for first point

            # load status level
            # print current iter if simulation is still running
            if status[idx_snr]==0:
                status_txt = f"iter: {idx_it:.0f}/{max_mc_iter:.0f}"
            else:
                status_txt = status_levels[int(status[idx_snr])]

            # generate list with all elements to be printed
            row_text = [str(np.round(ebno_dbs[idx_snr].numpy(), 3)),
                        f"{ber_np:.4e}",
                        f"{bler_np:.4e}",
                        np.round(bit_errors[idx_snr].numpy(), 0),
                        np.round(nb_bits[idx_snr].numpy(), 0),
                        np.round(block_errors[idx_snr].numpy(), 0),
                        np.round(nb_blocks[idx_snr].numpy(), 0),
                        np.round(rt, 1),
                        status_txt]

        # pylint: disable=line-too-long, consider-using-f-string
        print("{: >9} |{: >11} |{: >11} |{: >12} |{: >12} |{: >13} |{: >12} |{: >12} |{: >10}".format(*row_text), end=end_str)

    # distributed execution should not be done in Eager mode
    # XLA mode seems to have difficulties with TF2.13
    @tf.function(jit_compile=False)
    def _run_distributed(strategy, mc_fun, batch_size, ebno_db):
        # use tf.distribute to execute on parallel devices (=replicas)
        outputs_rep = strategy.run(mc_fun,
                                   args=(batch_size, ebno_db))
        # copy replicas back to single device
        b = strategy.gather(outputs_rep[0], axis=0)
        b_hat = strategy.gather(outputs_rep[1], axis=0)
        return b, b_hat

    # init table headers
    header_text = ["EbNo [dB]", "BER", "BLER", "bit errors",
                   "num bits", "block errors", "num blocks",
                   "runtime [s]", "status"]

    # replace status by text
    status_levels = ["not simulated", # status=0
            "reached max iter       ", # status=1; spacing for impr. layout
            "no errors - early stop", # status=2
            "reached target bit errors", # status=3
            "reached target block errors", # status=4
            "reached target BER - early stop", # status=5
            "reached target BLER - early stop", # status=6
            "callback triggered stopping"] # status=7

    # check inputs for consistency
    assert isinstance(early_stop, bool), "early_stop must be bool."
    assert isinstance(soft_estimates, bool), "soft_estimates must be bool."
    assert dtype.is_complex, "dtype must be a complex type."
    assert isinstance(verbose, bool), "verbose must be bool."

    # target_ber / target_bler only works if early stop is activated
    if target_ber is not None:
        if not early_stop:
            print("Warning: early stop is deactivated. Thus, target_ber " \
                  "is ignored.")
    else:
        target_ber = -1. # deactivate early stopping condition
    if target_bler is not None:
        if not early_stop:
            print("Warning: early stop is deactivated. Thus, target_bler " \
                  "is ignored.")
    else:
        target_bler = -1. # deactivate early stopping condition

    if graph_mode is None:
        graph_mode = "default" # applies default graph mode
    assert isinstance(graph_mode, str), "graph_mode must be str."

    if graph_mode=="default":
        pass # nothing to do
    elif graph_mode=="graph":
        # avoid retracing -> check if mc_fun is already a function
        if not isinstance(mc_fun, tf.types.experimental.GenericFunction):
            mc_fun = tf.function(mc_fun,
                                 jit_compile=False,
                                 experimental_follow_type_hints=True)
    elif graph_mode=="xla":
        # avoid retracing -> check if mc_fun is already a function
        if not isinstance(mc_fun, tf.types.experimental.GenericFunction) or \
           not mc_fun.function_spec.jit_compile:
            mc_fun = tf.function(mc_fun,
                                 jit_compile=True,
                                 experimental_follow_type_hints=True)
    else:
        raise TypeError("Unknown graph_mode selected.")

    # support multi-device simulations by using the tf.distribute package
    if len(tf.config.list_logical_devices('GPU'))==0:
        run_multigpu = False
        distribute = None
    if distribute is None: # disabled per default
        run_multigpu = False
    # use strategy if explicitly provided
    elif isinstance(distribute, tf.distribute.Strategy):
        run_multigpu = True
        strategy = distribute # distribute is already a tf.distribute.Strategy
    else:
        run_multigpu = True
        # use all available gpus
        if distribute=="all":
            gpus = tf.config.list_logical_devices('GPU')
        # mask active GPUs if indices are provided
        elif isinstance(distribute, (tuple, list)):
            gpus_avail = tf.config.list_logical_devices('GPU')
            gpus = [gpus_avail[i] for i in distribute if i < len(gpus_avail)]
        else:
            raise ValueError("Unknown value for distribute.")

        # deactivate logging of tf.device placement
        if verbose:
            print("Setting tf.debugging.set_log_device_placement to False.")
        tf.debugging.set_log_device_placement(False)
        # we reduce to the first device by default
        strategy = tf.distribute.MirroredStrategy(gpus,
                            cross_device_ops=tf.distribute.ReductionToOneDevice(
                                                reduce_to_device=gpus[0].name))

    # reduce max_mc_iter if multi_gpu simulations are activated
    if run_multigpu:
        num_replicas = strategy.num_replicas_in_sync # pylint: disable=possibly-used-before-assignment
        max_mc_iter = int(np.ceil(max_mc_iter/num_replicas))
        print(f"Distributing simulation across {num_replicas} devices.")
        print(f"Reducing max_mc_iter to {max_mc_iter}")

    ebno_dbs = tf.cast(ebno_dbs, dtype.real_dtype)
    batch_size = tf.cast(batch_size, tf.int32)
    num_points = tf.shape(ebno_dbs)[0]
    bit_errors = tf.Variable(tf.zeros([num_points], dtype=tf.int64),
                             dtype=tf.int64)
    block_errors = tf.Variable(tf.zeros([num_points], dtype=tf.int64),
                               dtype=tf.int64)
    nb_bits = tf.Variable(tf.zeros([num_points], dtype=tf.int64),
                          dtype=tf.int64)
    nb_blocks = tf.Variable(tf.zeros([num_points], dtype=tf.int64),
                            dtype=tf.int64)

    # track status of simulation (early termination etc.)
    status = np.zeros(num_points)

    # measure runtime per SNR point
    runtime = np.zeros(num_points)

    # ensure num_target_errors is a tensor
    if num_target_bit_errors is not None:
        num_target_bit_errors = tf.cast(num_target_bit_errors, tf.int64)
    if num_target_block_errors is not None:
        num_target_block_errors = tf.cast(num_target_block_errors, tf.int64)

    try:
        # simulate until a target number of errors is reached
        for i in tf.range(num_points):
            runtime[i] = time.perf_counter() # save start time
            iter_count = -1 # for print in verbose mode
            for ii in tf.range(max_mc_iter):

                iter_count += 1

                if run_multigpu: # distributed execution
                    b, b_hat = _run_distributed(strategy,
                                                mc_fun,
                                                batch_size,
                                                ebno_dbs[i])
                else:
                    outputs = mc_fun(batch_size=batch_size, ebno_db=ebno_dbs[i])
                    # assume the first and second return values are b and b_hat
                    # other returns are ignored
                    b = outputs[0]
                    b_hat = outputs[1]

                if soft_estimates:
                    b_hat = hard_decisions(b_hat)

                # count errors
                bit_e = count_errors(b, b_hat)
                block_e = count_block_errors(b, b_hat)

                # count total number of bits
                bit_n = tf.size(b)
                block_n = tf.size(b[...,-1])

                # update variables
                bit_errors = tf.tensor_scatter_nd_add(bit_errors, [[i]],
                                                    tf.cast([bit_e], tf.int64))
                block_errors = tf.tensor_scatter_nd_add(block_errors, [[i]],
                                                tf.cast([block_e], tf.int64))
                nb_bits = tf.tensor_scatter_nd_add(nb_bits, [[i]],
                                                    tf.cast([bit_n], tf.int64))
                nb_blocks = tf.tensor_scatter_nd_add(nb_blocks, [[i]],
                                                tf.cast([block_n], tf.int64))

                cb_state = sim_ber.CALLBACK_CONTINUE
                if callback is not None:
                    cb_state = callback(ii, i, ebno_dbs, bit_errors,
                                        block_errors, nb_bits,
                                        nb_blocks)
                    if cb_state in (sim_ber.CALLBACK_STOP,
                                    sim_ber.CALLBACK_NEXT_SNR):
                        # stop runtime timer
                        runtime[i] = time.perf_counter() - runtime[i]
                        status[i] = 7 # change internal status for summary
                        break # stop simulation for this SNR point

                # print progress summary
                if verbose:
                    # print summary header during first iteration
                    if i==0 and iter_count==0:
                        _print_progress(is_final=True,
                                        rt=0,
                                        idx_snr=0,
                                        idx_it=0,
                                        header_text=header_text)
                        # print separator after headline
                        print('-' * 135)

                    # evaluate current runtime
                    rt = time.perf_counter() - runtime[i]
                    # print current progress
                    _print_progress(is_final=False, idx_snr=i, idx_it=ii, rt=rt)

                # bit-error based stopping cond.
                if num_target_bit_errors is not None:
                    if tf.greater_equal(bit_errors[i], num_target_bit_errors):
                        status[i] = 3 # change internal status for summary
                        # stop runtime timer
                        runtime[i] = time.perf_counter() - runtime[i]
                        break # enough errors for this SNR point simulated

                # block-error based stopping cond.
                if num_target_block_errors is not None:
                    if tf.greater_equal(block_errors[i],
                                        num_target_block_errors):
                        # stop runtime timer
                        runtime[i] = time.perf_counter() - runtime[i]
                        status[i] = 4 # change internal status for summary
                        break # enough errors for this SNR point simulated

                # max iter has been reached -> continue with next SNR point
                if iter_count==max_mc_iter-1: # all iterations are done
                    # stop runtime timer
                    runtime[i] = time.perf_counter() - runtime[i]
                    status[i] = 1 # change internal status for summary

            # print results again AFTER last iteration / early stop (new status)
            if verbose:
                _print_progress(is_final=True,
                                idx_snr=i,
                                idx_it=iter_count,
                                rt=runtime[i])

            # early stop if no error occurred or target_ber/target_bler reached
            if early_stop: # only if early stop is active
                if block_errors[i]==0:
                    status[i] = 2 # change internal status for summary
                    if verbose:
                        print("\nSimulation stopped as no error occurred " \
                              f"@ EbNo = {ebno_dbs[i].numpy():.1f} dB.\n")
                    break

                # check for target_ber / target_bler
                ber_true = bit_errors[i] / nb_bits[i]
                bler_true = block_errors[i] / nb_blocks[i]
                if ber_true < target_ber:
                    status[i] = 5 # change internal status for summary
                    if verbose:
                        print("\nSimulation stopped as target BER is reached " \
                              f"@ EbNo = {ebno_dbs[i].numpy():.1f} dB.\n")
                    break
                if bler_true < target_bler:
                    status[i] = 6 # change internal status for summary
                    if verbose:
                        print("\nSimulation stopped as target BLER is " \
                              f"reached @ EbNo = {ebno_dbs[i].numpy():.1f} " \
                              "dB.\n")
                    break

            # allow callback to end the entire simulation
            if cb_state is sim_ber.CALLBACK_STOP:
                status[i] = 7 # change internal status for summary
                if verbose:
                    print("\nSimulation stopped by callback function " \
                          f"@ EbNo = {ebno_dbs[i].numpy():.1f} dB.\n")
                break

    # Stop if KeyboardInterrupt is detected and set remaining SNR points to -1
    except KeyboardInterrupt as e:

        # Raise the interrupt again to stop outer loops
        if forward_keyboard_interrupt:
            raise e

        print("\nSimulation stopped by the user " \
              f"@ EbNo = {ebno_dbs[i].numpy()} dB.")
        # overwrite remaining BER / BLER positions with -1
        for idx in range(i+1, num_points):
            bit_errors = tf.tensor_scatter_nd_update(bit_errors, [[idx]],
                                                     tf.cast([-1], tf.int64))
            block_errors = tf.tensor_scatter_nd_update(block_errors, [[idx]],
                                                       tf.cast([-1], tf.int64))
            nb_bits = tf.tensor_scatter_nd_update(nb_bits, [[idx]],
                                                  tf.cast([1], tf.int64))
            nb_blocks = tf.tensor_scatter_nd_update(nb_blocks, [[idx]],
                                                    tf.cast([1], tf.int64))

    # calculate BER / BLER
    ber = tf.cast(bit_errors, tf.float64) / tf.cast(nb_bits, tf.float64)
    bler = tf.cast(block_errors, tf.float64) / tf.cast(nb_blocks, tf.float64)

    # replace nans (from early stop)
    ber = tf.where(tf.math.is_nan(ber), tf.zeros_like(ber), ber)
    bler = tf.where(tf.math.is_nan(bler), tf.zeros_like(bler), bler)

    return ber, bler

sim_ber.CALLBACK_CONTINUE = None
sim_ber.CALLBACK_STOP = 2
sim_ber.CALLBACK_NEXT_SNR = 1

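# End-to-end sketch (editor's illustration; all values are assumptions):
# estimating the BER/BLER of uncoded BPSK over a real AWGN channel. mc_fun
# returns the transmitted bits and the noisy channel output as soft
# estimates; with soft_estimates=True, sim_ber applies hard_decisions()
# internally, so positive outputs decide for bit 1.
def _example_sim_ber():
    binary_source = BinarySource()
    def mc_fun(batch_size, ebno_db):
        no = ebnodb2no(ebno_db, num_bits_per_symbol=1, coderate=1.0)
        b = binary_source([batch_size, 256])  # 256 bits per block
        x = 2. * b - 1.                       # BPSK mapping: 0 -> -1, 1 -> +1
        n = sionna.config.tf_rng.normal(tf.shape(x),
                                        stddev=tf.sqrt(no / 2.))
        return b, x + n                       # channel output as soft values
    return sim_ber(mc_fun,
                   ebno_dbs=np.arange(0., 8., 2.),
                   batch_size=1000,
                   max_mc_iter=100,
                   num_target_block_errors=100,
                   soft_estimates=True)
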
def complex_normal(shape, var=1.0, dtype=tf.complex64):
    r"""Generates a tensor of complex normal random variables.

    Input
    -----
    shape : tf.shape, or list
        The desired shape.

    var : float
        The total variance, i.e., the real and imaginary parts each
        have variance ``var/2``.

    dtype: tf.complex
        The desired dtype. Defaults to `tf.complex64`.

    Output
    ------
    : ``shape``, ``dtype``
        Tensor of complex normal random variables.
    """
    # Half the variance for each dimension
    var_dim = tf.cast(var, dtype.real_dtype)/tf.cast(2, dtype.real_dtype)
    stddev = tf.sqrt(var_dim)

    # Generate complex Gaussian noise with the right variance
    xr = sionna.config.tf_rng.normal(shape, stddev=stddev,
                                     dtype=dtype.real_dtype)
    xi = sionna.config.tf_rng.normal(shape, stddev=stddev,
                                     dtype=dtype.real_dtype)
    x = tf.complex(xr, xi)

    return x

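# Quick sketch (editor's illustration): with var=2.0, the real and imaginary
# parts each carry variance ~1.0, so the empirical total variance is ~2.0.
def _example_complex_normal():
    x = complex_normal([100000], var=2.0)
    total_var = tf.math.reduce_variance(tf.math.real(x)) \
                + tf.math.reduce_variance(tf.math.imag(x))
    return total_var  # ~2.0
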
###########################################################
# Deprecated aliases that will not be included in the next
# major release
###########################################################

def fft(tensor, axis=-1):
    print("Warning: The alias utils.fft will not be included in Sionna 1.0."
          " Please use signal.fft instead.")
    return signal.fft(tensor, axis)


def ifft(tensor, axis=-1):
    print("Warning: The alias utils.ifft will not be included in Sionna 1.0."
          " Please use signal.ifft instead.")
    return signal.ifft(tensor, axis)


def empirical_psd(x, show=True, oversampling=1.0, ylim=(-30,3)):
    print("Warning: The alias utils.empirical_psd will not be included in"
          " Sionna 1.0. Please use signal.empirical_psd instead.")
    return signal.empirical_psd(x, show, oversampling, ylim)