anomaly-detection-material-parameters-calibration

Sionna parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

plotting.py (16573B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """Plotting functions for the Sionna library."""
      6 
      7 import numpy as np
      8 import matplotlib.pyplot as plt
      9 from sionna.utils import sim_ber
     10 from itertools import compress # to "filter" list
     11 
     12 def plot_ber(snr_db,
     13              ber,
     14              legend="",
     15              ylabel="BER",
     16              title="Bit Error Rate",
     17              ebno=True,
     18              is_bler=None,
     19              xlim=None,
     20              ylim=None,
     21              save_fig=False,
     22              path=""):
     23     """Plot error-rates.
     24 
     25     Input
     26     -----
     27     snr_db: ndarray
     28         Array of floats defining the simulated SNR points.
     29         Can be also a list of multiple arrays.
     30 
     31     ber: ndarray
     32         Array of floats defining the BER/BLER per SNR point.
     33         Can be also a list of multiple arrays.
     34 
     35     legend: str
     36         Defaults to "". Defining the legend entries. Can be
     37         either a string or a list of strings.
     38 
     39     ylabel: str
     40         Defaults to "BER". Defining the y-label.
     41 
     42     title: str
     43         Defaults to "Bit Error Rate". Defining the title of the figure.
     44 
     45     ebno: bool
     46         Defaults to True. If True, the x-label is set to
     47         "EbNo [dB]" instead of "EsNo [dB]".
     48 
     49     is_bler: bool
     50         Defaults to False. If True, the corresponding curve is dashed.
     51 
     52     xlim: tuple of floats
     53         Defaults to None. A tuple of two floats defining x-axis limits.
     54 
     55     ylim: tuple of floats
     56         Defaults to None. A tuple of two floats defining y-axis limits.
     57 
     58     save_fig: bool
     59         Defaults to False. If True, the figure is saved as `.png`.
     60 
     61     path: str
     62         Defaults to "". Defining the path to save the figure
     63         (iff ``save_fig`` is True).
     64 
     65     Output
     66     ------
     67         (fig, ax) :
     68             Tuple:
     69 
     70         fig : matplotlib.figure.Figure
     71             A matplotlib figure handle.
     72 
     73         ax : matplotlib.axes.Axes
     74             A matplotlib axes object.
     75     """
     76 
     77     # legend must be a list or string
     78     if not isinstance(legend, list):
     79         assert isinstance(legend, str)
     80         legend = [legend]
     81 
     82     assert isinstance(title, str), "title must be str."
     83 
     84     # broadcast snr if ber is list
     85     if isinstance(ber, list):
     86         if not isinstance(snr_db, list):
     87             snr_db = [snr_db]*len(ber)
     88 
     89     # check that is_bler is list of same size and contains only bools
     90     if is_bler is None:
     91         if isinstance(ber, list):
     92             is_bler = [False] * len(ber) # init is_bler as list with False
     93         else:
     94             is_bler = False
     95     else:
     96         if isinstance(is_bler, list):
     97             assert (len(is_bler) == len(ber)), "is_bler has invalid size."
     98         else:
     99             assert isinstance(is_bler, bool), \
    100                 "is_bler must be bool or list of bool."
    101             is_bler = [is_bler] # change to list
    102 
    103     # tile snr_db if not list, but ber is list
    104 
    105     fig, ax = plt.subplots(figsize=(16,10))
    106 
    107     plt.xticks(fontsize=18)
    108     plt.yticks(fontsize=18)
    109 
    110     if xlim is not None:
    111         plt.xlim(xlim)
    112     if ylim is not None:
    113         plt.ylim(ylim)
    114 
    115     plt.title(title, fontsize=25)
    116     # return figure handle
    117     if isinstance(ber, list):
    118         for idx, b in enumerate(ber):
    119             if is_bler[idx]:
    120                 line_style = "--"
    121             else:
    122                 line_style = ""
    123             plt.semilogy(snr_db[idx], b, line_style, linewidth=2)
    124     else:
    125         if is_bler:
    126             line_style = "--"
    127         else:
    128             line_style = ""
    129         plt.semilogy(snr_db, ber, line_style, linewidth=2)
    130 
    131     plt.grid(which="both")
    132     if ebno:
    133         plt.xlabel(r"$E_b/N_0$ (dB)", fontsize=25)
    134     else:
    135         plt.xlabel(r"$E_s/N_0$ (dB)", fontsize=25)
    136     plt.ylabel(ylabel, fontsize=25)
    137     plt.legend(legend, fontsize=20)
    138     if save_fig:
    139         plt.savefig(path)
    140         plt.close(fig)
    141     else:
    142         #plt.close(fig)
    143         pass
    144     return fig, ax
    145 
    146 ###### Plotting classes #######
    147 
    148 class PlotBER():
    149     """Provides a plotting object to simulate and store BER/BLER curves.
    150 
    151     Parameters
    152     ----------
    153     title: str
    154         A string defining the title of the figure. Defaults to
    155         `"Bit/Block Error Rate"`.
    156 
    157     Input
    158     -----
    159     snr_db: float
    160         Python array (or list of Python arrays) of additional SNR values to be
    161         plotted.
    162 
    163     ber: float
    164         Python array (or list of Python arrays) of additional BERs
    165         corresponding to ``snr_db``.
    166 
    167     legend: str
    168         String (or list of strings) of legends entries.
    169 
    170     is_bler: bool
    171         A boolean (or list of booleans) defaults to False.
    172         If True, ``ber`` will be interpreted as BLER.
    173 
    174     show_ber: bool
    175         A boolean defaults to True. If True, BER curves will be plotted.
    176 
    177     show_bler: bool
    178         A boolean defaults to True. If True, BLER curves will be plotted.
    179 
    180     xlim: tuple of floats
    181         Defaults to None. A tuple of two floats defining x-axis limits.
    182 
    183     ylim: tuple of floats
    184         Defaults to None. A tuple of two floats defining y-axis limits.
    185 
    186     save_fig: bool
    187         A boolean defaults to False. If True, the figure
    188         is saved as file.
    189 
    190     path: str
    191         A string defining where to save the figure (if ``save_fig``
    192         is True).
    193     """
    194 
    195     def __init__(self, title="Bit/Block Error Rate"):
    196 
    197         assert isinstance(title, str), "title must be str."
    198         self._title = title
    199 
    200         # init lists
    201         self._bers = []
    202         self._snrs = []
    203         self._legends = []
    204         self._is_bler = []
    205 
    206     # pylint: disable=W0102
    207     def __call__(self,
    208                  snr_db=[],
    209                  ber=[],
    210                  legend=[],
    211                  is_bler=[],
    212                  show_ber=True,
    213                  show_bler=True,
    214                  xlim=None,
    215                  ylim=None,
    216                  save_fig=False,
    217                  path=""):
    218         """Plot BER curves.
    219 
    220         """
    221 
    222         assert isinstance(path, str), "path must be str."
    223         assert isinstance(save_fig, bool), "save_fig must be bool."
    224 
    225         # broadcast snr if ber is list
    226         if isinstance(ber, list):
    227             if not isinstance(snr_db, list):
    228                 snr_db = [snr_db]*len(ber)
    229 
    230         if not isinstance(snr_db, list):
    231             snrs = self._snrs + [snr_db]
    232         else:
    233             snrs = self._snrs + snr_db
    234         if not isinstance(ber, list):
    235             bers = self._bers + [ber]
    236         else:
    237             bers = self._bers + ber
    238         if not isinstance(legend, list):
    239             legends = self._legends + [legend]
    240         else:
    241             legends = self._legends + legend
    242         if not isinstance(is_bler, list):
    243             is_bler = self._is_bler + [is_bler]
    244         else:
    245             is_bler = self._is_bler + is_bler
    246 
    247         # deactivate BER/BLER
    248         if len(is_bler)>0: # ignore if object is empty
    249             if show_ber is False:
    250                 snrs = list(compress(snrs, is_bler))
    251                 bers = list(compress(bers, is_bler))
    252                 legends = list(compress(legends, is_bler))
    253                 is_bler = list(compress(is_bler, is_bler))
    254 
    255             if show_bler is False:
    256                 snrs = list(compress(snrs, np.invert(is_bler)))
    257                 bers = list(compress(bers, np.invert(is_bler)))
    258                 legends = list(compress(legends, np.invert(is_bler)))
    259                 is_bler = list(compress(is_bler, np.invert(is_bler)))
    260 
    261         # set ylabel
    262         ylabel = "BER / BLER"
    263         if np.all(is_bler): # only BLERs to plot
    264             ylabel = "BLER"
    265         if not np.any(is_bler): # only BERs to plot
    266             ylabel = "BER"
    267 
    268         # and plot the results
    269         plot_ber(snr_db=snrs,
    270                  ber=bers,
    271                  legend=legends,
    272                  is_bler=is_bler,
    273                  title=self._title,
    274                  ylabel=ylabel,
    275                  xlim=xlim,
    276                  ylim=ylim,
    277                  save_fig=save_fig,
    278                  path=path)
    279 
    280     ####public methods
    281     @property
    282     def title(self):
    283         """Title of the plot."""
    284         return self._title
    285 
    286     @title.setter
    287     def title(self, title):
    288         """Set title of the plot."""
    289         assert isinstance(title, str), "title must be string"
    290         self._title = title
    291 
    292     @property
    293     def ber(self):
    294         """List containing all stored BER curves."""
    295         return self._bers
    296 
    297     @property
    298     def snr(self):
    299         """List containing all stored SNR curves."""
    300         return self._snrs
    301 
    302     @property
    303     def legend(self):
    304         """List containing all stored legend entries curves."""
    305         return self._legends
    306 
    307     @property
    308     def is_bler(self):
    309         """List of booleans indicating if ber shall be interpreted as BLER."""
    310         return self._is_bler
    311 
    312     def simulate(self,
    313                  mc_fun,
    314                  ebno_dbs,
    315                  batch_size,
    316                  max_mc_iter,
    317                  legend="",
    318                  add_ber=True,
    319                  add_bler=False,
    320                  soft_estimates=False,
    321                  num_target_bit_errors=None,
    322                  num_target_block_errors=None,
    323                  target_ber=None,
    324                  target_bler=None,
    325                  early_stop=True,
    326                  graph_mode=None,
    327                  distribute=None,
    328                  add_results=True,
    329                  forward_keyboard_interrupt=True,
    330                  show_fig=True,
    331                  verbose=True):
    332         # pylint: disable=line-too-long
    333         r"""Simulate BER/BLER curves for given Keras model and saves the results.
    334 
    335         Internally calls :class:`sionna.utils.sim_ber`.
    336 
    337         Input
    338         -----
    339         mc_fun:
    340             Callable that yields the transmitted bits `b` and the
    341             receiver's estimate `b_hat` for a given ``batch_size`` and
    342             ``ebno_db``. If ``soft_estimates`` is True, b_hat is interpreted as
    343             logit.
    344 
    345         ebno_dbs: ndarray of floats
    346             SNR points to be evaluated.
    347 
    348         batch_size: tf.int32
    349             Batch-size for evaluation.
    350 
    351         max_mc_iter: int
    352             Max. number of Monte-Carlo iterations per SNR point.
    353 
    354         legend: str
    355             Name to appear in legend.
    356 
    357         add_ber: bool
    358             Defaults to True. Indicate if BER should be added to plot.
    359 
    360         add_bler: bool
    361             Defaults to False. Indicate if BLER should be added
    362             to plot.
    363 
    364         soft_estimates: bool
    365             A boolean, defaults to False. If True, ``b_hat``
    366             is interpreted as logit and additional hard-decision is applied
    367             internally.
    368 
    369         num_target_bit_errors: int
    370             Target number of bit errors per SNR point until the simulation
    371             stops.
    372 
    373         num_target_block_errors: int
    374             Target number of block errors per SNR point until the simulation
    375             stops.
    376 
    377         target_ber: tf.float32
    378             Defaults to `None`. The simulation stops after the first SNR point
    379             which achieves a lower bit error rate as specified by
    380             ``target_ber``. This requires ``early_stop`` to be `True`.
    381 
    382         target_bler: tf.float32
    383             Defaults to `None`. The simulation stops after the first SNR point
    384             which achieves a lower block error rate as specified by
    385             ``target_bler``.  This requires ``early_stop`` to be `True`.
    386 
    387         early_stop: bool
    388             A boolean defaults to True. If True, the simulation stops after the
    389             first error-free SNR point (i.e., no error occurred after
    390             ``max_mc_iter`` Monte-Carlo iterations).
    391 
    392         graph_mode: One of ["graph", "xla"], str
    393             A string describing the execution mode of ``mc_fun``.
    394             Defaults to `None`. In this case, ``mc_fun`` is executed as is.
    395 
    396         distribute: `None` (default) | "all" | list of indices | `tf.distribute.strategy`
    397             Distributes simulation on multiple parallel devices. If `None`,
    398             multi-device simulations are deactivated. If "all", the workload
    399             will be automatically distributed across all available GPUs via the
    400             `tf.distribute.MirroredStrategy`.
    401             If an explicit list of indices is provided, only the GPUs with the
    402             given indices will be used. Alternatively, a custom
    403             `tf.distribute.strategy` can be provided. Note that the same
    404             `batch_size` will be used for all GPUs in parallel, but the number
    405             of Monte-Carlo iterations ``max_mc_iter`` will be scaled by the
    406             number of devices such that the same number of total samples is
    407             simulated. However, all stopping conditions are still in-place
    408             which can cause slight differences in the total number of simulated
    409             samples.
    410 
    411         add_results: bool
    412             Defaults to True. If True, the simulation results will be appended
    413             to the internal list of results.
    414 
    415         show_fig: bool
    416             Defaults to True. If True, a BER figure will be plotted.
    417 
    418         verbose: bool
    419             A boolean defaults to True. If True, the current progress will be
    420             printed.
    421 
    422         forward_keyboard_interrupt: bool
    423             A boolean defaults to True. If False, `KeyboardInterrupts` will be
    424             catched internally and not forwarded (e.g., will not stop outer
    425             loops). If False, the simulation ends and returns the intermediate
    426             simulation results.
    427 
    428         Output
    429         ------
    430         (ber, bler):
    431             Tuple:
    432 
    433         ber: float
    434             The simulated bit-error rate.
    435 
    436         bler: float
    437             The simulated block-error rate.
    438         """
    439 
    440         ber, bler = sim_ber(
    441                         mc_fun,
    442                         ebno_dbs,
    443                         batch_size,
    444                         soft_estimates=soft_estimates,
    445                         max_mc_iter=max_mc_iter,
    446                         num_target_bit_errors=num_target_bit_errors,
    447                         num_target_block_errors=num_target_block_errors,
    448                         target_ber=target_ber,
    449                         target_bler=target_bler,
    450                         early_stop=early_stop,
    451                         graph_mode=graph_mode,
    452                         distribute=distribute,
    453                         verbose=verbose,
    454                         forward_keyboard_interrupt=forward_keyboard_interrupt)
    455 
    456         if add_ber:
    457             self._bers += [ber]
    458             self._snrs +=  [ebno_dbs]
    459             self._legends += [legend]
    460             self._is_bler += [False]
    461 
    462         if add_bler:
    463             self._bers += [bler]
    464             self._snrs +=  [ebno_dbs]
    465             self._legends += [legend + " (BLER)"]
    466             self._is_bler += [True]
    467 
    468         if show_fig:
    469             self()
    470 
    471         # remove current curve if add_results=False
    472         if add_results is False:
    473             if add_bler:
    474                 self.remove(-1)
    475             if add_ber:
    476                 self.remove(-1)
    477 
    478         return ber, bler
    479 
    480     def add(self, ebno_db, ber, is_bler=False, legend=""):
    481         """Add static reference curves.
    482 
    483         Input
    484         -----
    485         ebno_db: float
    486             Python array or list of floats defining the SNR points.
    487 
    488         ber: float
    489             Python array or list of floats defining the BER corresponding
    490             to each SNR point.
    491 
    492         is_bler: bool
    493             A boolean defaults to False. If True, ``ber`` is interpreted as
    494             BLER.
    495 
    496         legend: str
    497             A string defining the text of the legend entry.
    498         """
    499 
    500         assert (len(ebno_db)==len(ber)), \
    501             "ebno_db and ber must have same number of elements."
    502 
    503         assert isinstance(legend, str), "legend must be str."
    504         assert isinstance(is_bler, bool), "is_bler must be bool."
    505 
    506         # concatenate curves
    507         self._bers += [ber]
    508         self._snrs +=  [ebno_db]
    509         self._legends += [legend]
    510         self._is_bler += [is_bler]
    511 
    512     def reset(self):
    513         """Remove all internal data."""
    514         self._bers = []
    515         self._snrs = []
    516         self._legends = []
    517         self._is_bler = []
    518 
    519     def remove(self, idx=-1):
    520         """Remove curve with index ``idx``.
    521 
    522         Input
    523         ------
    524         idx: int
    525             An integer defining the index of the dataset that should
    526             be removed. Negative indexing is possible.
    527         """
    528 
    529         assert isinstance(idx, int), "id must be int."
    530 
    531         del self._bers[idx]
    532         del self._snrs[idx]
    533         del self._legends[idx]
    534         del self._is_bler[idx]
    535