plotting.py (16573B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Plotting functions for the Sionna library."""

import numpy as np
import matplotlib.pyplot as plt
from sionna.utils import sim_ber
from itertools import compress # to "filter" list

def plot_ber(snr_db,
             ber,
             legend="",
             ylabel="BER",
             title="Bit Error Rate",
             ebno=True,
             is_bler=None,
             xlim=None,
             ylim=None,
             save_fig=False,
             path=""):
    """Plot error-rates.

    Input
    -----
    snr_db: ndarray
        Array of floats defining the simulated SNR points.
        Can be also a list of multiple arrays. If ``ber`` is a list and
        ``snr_db`` is not, the same SNR points are used for every curve.

    ber: ndarray
        Array of floats defining the BER/BLER per SNR point.
        Can be also a list of multiple arrays (one array per curve).

    legend: str
        Defaults to "". Defining the legend entries. Can be
        either a string or a list of strings.

    ylabel: str
        Defaults to "BER". Defining the y-label.

    title: str
        Defaults to "Bit Error Rate". Defining the title of the figure.

    ebno: bool
        Defaults to True. If True, the x-label is set to
        "$E_b/N_0$ (dB)", otherwise to "$E_s/N_0$ (dB)".

    is_bler: None | bool | list of bool
        Defaults to None, i.e., no curve is treated as BLER. A single bool
        applies to all curves; a list must contain one entry per curve.
        Curves flagged as BLER are drawn dashed.

    xlim: tuple of floats
        Defaults to None. A tuple of two floats defining x-axis limits.

    ylim: tuple of floats
        Defaults to None. A tuple of two floats defining y-axis limits.

    save_fig: bool
        Defaults to False. If True, the figure is saved to ``path`` and
        closed afterwards.

    path: str
        Defaults to "". Defining the path to save the figure
        (iff ``save_fig`` is True).

    Output
    ------
    (fig, ax) :
        Tuple:

    fig : matplotlib.figure.Figure
        A matplotlib figure handle.

    ax : matplotlib.axes.Axes
        A matplotlib axes object.
    """

    # legend must be a list or string
    if not isinstance(legend, list):
        assert isinstance(legend, str), "legend must be str or list of str."
        legend = [legend]

    assert isinstance(title, str), "title must be str."

    # Normalize the input to a list of curves so that the single-curve and
    # the multi-curve case share one code path.
    if isinstance(ber, list):
        bers = ber
        # broadcast a single SNR array to all curves
        if isinstance(snr_db, list):
            snrs = snr_db
        else:
            snrs = [snr_db] * len(ber)
    else:
        bers = [ber]
        snrs = [snr_db]

    # Exactly one BLER flag per curve. A scalar flag is broadcast; wrapping
    # it in a one-element list and testing that list's truthiness would
    # wrongly dash a curve flagged with False.
    if is_bler is None:
        flags = [False] * len(bers)
    elif isinstance(is_bler, list):
        assert len(is_bler) == len(bers), "is_bler has invalid size."
        flags = is_bler
    else:
        assert isinstance(is_bler, bool), \
            "is_bler must be bool or list of bool."
        flags = [is_bler] * len(bers)

    # Use the axes object directly instead of the global pyplot state.
    fig, ax = plt.subplots(figsize=(16, 10))

    ax.tick_params(axis="both", labelsize=18)

    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    ax.set_title(title, fontsize=25)

    for idx, curve in enumerate(bers):
        line_style = "--" if flags[idx] else ""
        ax.semilogy(snrs[idx], curve, line_style, linewidth=2)

    ax.grid(which="both")
    if ebno:
        ax.set_xlabel(r"$E_b/N_0$ (dB)", fontsize=25)
    else:
        ax.set_xlabel(r"$E_s/N_0$ (dB)", fontsize=25)
    ax.set_ylabel(ylabel, fontsize=25)
    ax.legend(legend, fontsize=20)

    if save_fig:
        fig.savefig(path)
        # the figure is only needed as a file in this case
        plt.close(fig)

    return fig, ax

###### Plotting classes #######

class PlotBER():
    """Provides a plotting object to simulate and store BER/BLER curves.

    Parameters
    ----------
    title: str
        A string defining the title of the figure. Defaults to
        `"Bit/Block Error Rate"`.

    Input
    -----
    snr_db: float
        Python array (or list of Python arrays) of additional SNR values to be
        plotted.

    ber: float
        Python array (or list of Python arrays) of additional BERs
        corresponding to ``snr_db``.

    legend: str
        String (or list of strings) of legend entries.

    is_bler: bool
        A boolean (or list of booleans) defaults to False.
        If True, ``ber`` will be interpreted as BLER. A single boolean
        applies to all curves passed via ``ber``.

    show_ber: bool
        A boolean defaults to True. If True, BER curves will be plotted.

    show_bler: bool
        A boolean defaults to True. If True, BLER curves will be plotted.

    xlim: tuple of floats
        Defaults to None. A tuple of two floats defining x-axis limits.

    ylim: tuple of floats
        Defaults to None. A tuple of two floats defining y-axis limits.

    save_fig: bool
        A boolean defaults to False. If True, the figure
        is saved as file.

    path: str
        A string defining where to save the figure (if ``save_fig``
        is True).
    """

    def __init__(self, title="Bit/Block Error Rate"):

        assert isinstance(title, str), "title must be str."
        self._title = title

        # Stored curves; the four lists are kept index-aligned by
        # add(), simulate() and remove().
        self._bers = []
        self._snrs = []
        self._legends = []
        self._is_bler = []

    # pylint: disable=W0102
    def __call__(self,
                 snr_db=[],
                 ber=[],
                 legend=[],
                 is_bler=[],
                 show_ber=True,
                 show_bler=True,
                 xlim=None,
                 ylim=None,
                 save_fig=False,
                 path=""):
        """Plot all stored curves plus the optionally provided ones.

        See the class docstring for a description of the inputs. The
        (mutable) default arguments are never modified in place; new lists
        are always built by concatenation.
        """

        assert isinstance(path, str), "path must be str."
        assert isinstance(save_fig, bool), "save_fig must be bool."

        # broadcast snr if ber is list
        if isinstance(ber, list):
            if not isinstance(snr_db, list):
                snr_db = [snr_db]*len(ber)

        if not isinstance(snr_db, list):
            snrs = self._snrs + [snr_db]
        else:
            snrs = self._snrs + snr_db
        if not isinstance(ber, list):
            bers = self._bers + [ber]
        else:
            bers = self._bers + ber
        if not isinstance(legend, list):
            legends = self._legends + [legend]
        else:
            legends = self._legends + legend

        # One BLER flag is needed per curve. A scalar flag applies to every
        # newly provided curve; the default (empty list) marks the new
        # curves as BER, matching the documented default of False.
        num_new = len(bers) - len(self._bers)
        if not isinstance(is_bler, list):
            new_flags = [is_bler] * num_new
        elif not is_bler:
            new_flags = [False] * num_new
        else:
            new_flags = is_bler
        flags = self._is_bler + new_flags

        # deactivate BER/BLER curves on request
        if show_ber is False:
            keep = [bool(f) for f in flags]
            snrs, bers, legends, flags = [list(compress(seq, keep))
                                          for seq in (snrs, bers, legends,
                                                      flags)]

        if show_bler is False:
            # logical negation (np.invert would misbehave on int flags)
            keep = [not f for f in flags]
            snrs, bers, legends, flags = [list(compress(seq, keep))
                                          for seq in (snrs, bers, legends,
                                                      flags)]

        # set ylabel
        ylabel = "BER / BLER"
        if np.all(flags): # only BLERs to plot
            ylabel = "BLER"
        if not np.any(flags): # only BERs to plot
            ylabel = "BER"

        # and plot the results
        plot_ber(snr_db=snrs,
                 ber=bers,
                 legend=legends,
                 is_bler=flags,
                 title=self._title,
                 ylabel=ylabel,
                 xlim=xlim,
                 ylim=ylim,
                 save_fig=save_fig,
                 path=path)

    ####public methods
    @property
    def title(self):
        """Title of the plot."""
        return self._title

    @title.setter
    def title(self, title):
        """Set title of the plot."""
        assert isinstance(title, str), "title must be string"
        self._title = title

    @property
    def ber(self):
        """List containing all stored BER curves."""
        return self._bers

    @property
    def snr(self):
        """List containing all stored SNR curves."""
        return self._snrs

    @property
    def legend(self):
        """List containing all stored legend entries curves."""
        return self._legends

    @property
    def is_bler(self):
        """List of booleans indicating if ber shall be interpreted as BLER."""
        return self._is_bler

    def simulate(self,
                 mc_fun,
                 ebno_dbs,
                 batch_size,
                 max_mc_iter,
                 legend="",
                 add_ber=True,
                 add_bler=False,
                 soft_estimates=False,
                 num_target_bit_errors=None,
                 num_target_block_errors=None,
                 target_ber=None,
                 target_bler=None,
                 early_stop=True,
                 graph_mode=None,
                 distribute=None,
                 add_results=True,
                 forward_keyboard_interrupt=True,
                 show_fig=True,
                 verbose=True):
        # pylint: disable=line-too-long
        r"""Simulate BER/BLER curves for a given Keras model and save the results.

        Internally calls :class:`sionna.utils.sim_ber`.

        Input
        -----
        mc_fun:
            Callable that yields the transmitted bits `b` and the
            receiver's estimate `b_hat` for a given ``batch_size`` and
            ``ebno_db``. If ``soft_estimates`` is True, b_hat is interpreted as
            logit.

        ebno_dbs: ndarray of floats
            SNR points to be evaluated.

        batch_size: tf.int32
            Batch-size for evaluation.

        max_mc_iter: int
            Max. number of Monte-Carlo iterations per SNR point.

        legend: str
            Name to appear in legend.

        add_ber: bool
            Defaults to True. Indicate if BER should be added to plot.

        add_bler: bool
            Defaults to False. Indicate if BLER should be added
            to plot.

        soft_estimates: bool
            A boolean, defaults to False. If True, ``b_hat``
            is interpreted as logit and additional hard-decision is applied
            internally.

        num_target_bit_errors: int
            Target number of bit errors per SNR point until the simulation
            stops.

        num_target_block_errors: int
            Target number of block errors per SNR point until the simulation
            stops.

        target_ber: tf.float32
            Defaults to `None`. The simulation stops after the first SNR point
            which achieves a lower bit error rate as specified by
            ``target_ber``. This requires ``early_stop`` to be `True`.

        target_bler: tf.float32
            Defaults to `None`. The simulation stops after the first SNR point
            which achieves a lower block error rate as specified by
            ``target_bler``. This requires ``early_stop`` to be `True`.

        early_stop: bool
            A boolean defaults to True. If True, the simulation stops after the
            first error-free SNR point (i.e., no error occurred after
            ``max_mc_iter`` Monte-Carlo iterations).

        graph_mode: One of ["graph", "xla"], str
            A string describing the execution mode of ``mc_fun``.
            Defaults to `None`. In this case, ``mc_fun`` is executed as is.

        distribute: `None` (default) | "all" | list of indices | `tf.distribute.strategy`
            Distributes simulation on multiple parallel devices. If `None`,
            multi-device simulations are deactivated. If "all", the workload
            will be automatically distributed across all available GPUs via the
            `tf.distribute.MirroredStrategy`.
            If an explicit list of indices is provided, only the GPUs with the
            given indices will be used. Alternatively, a custom
            `tf.distribute.strategy` can be provided. Note that the same
            `batch_size` will be used for all GPUs in parallel, but the number
            of Monte-Carlo iterations ``max_mc_iter`` will be scaled by the
            number of devices such that the same number of total samples is
            simulated. However, all stopping conditions are still in-place
            which can cause slight differences in the total number of simulated
            samples.

        add_results: bool
            Defaults to True. If True, the simulation results will be appended
            to the internal list of results. If False, the results are only
            stored temporarily for plotting and removed afterwards.

        show_fig: bool
            Defaults to True. If True, a BER figure will be plotted.

        verbose: bool
            A boolean defaults to True. If True, the current progress will be
            printed.

        forward_keyboard_interrupt: bool
            A boolean defaults to True. If False, `KeyboardInterrupts` will be
            caught internally and not forwarded (i.e., will not stop outer
            loops). In this case, the simulation ends and returns the
            intermediate simulation results.

        Output
        ------
        (ber, bler):
            Tuple:

        ber: float
            The simulated bit-error rate.

        bler: float
            The simulated block-error rate.
        """

        ber, bler = sim_ber(
                        mc_fun,
                        ebno_dbs,
                        batch_size,
                        soft_estimates=soft_estimates,
                        max_mc_iter=max_mc_iter,
                        num_target_bit_errors=num_target_bit_errors,
                        num_target_block_errors=num_target_block_errors,
                        target_ber=target_ber,
                        target_bler=target_bler,
                        early_stop=early_stop,
                        graph_mode=graph_mode,
                        distribute=distribute,
                        verbose=verbose,
                        forward_keyboard_interrupt=forward_keyboard_interrupt)

        if add_ber:
            self._bers += [ber]
            self._snrs += [ebno_dbs]
            self._legends += [legend]
            self._is_bler += [False]

        if add_bler:
            self._bers += [bler]
            self._snrs += [ebno_dbs]
            self._legends += [legend + " (BLER)"]
            self._is_bler += [True]

        try:
            if show_fig:
                self()
        finally:
            # Remove the current curve(s) if add_results=False. This runs in
            # a finally block so that a plotting error cannot leave the
            # temporary curves stored. The BLER curve was appended last, so
            # it is removed first.
            if add_results is False:
                if add_bler:
                    self.remove(-1)
                if add_ber:
                    self.remove(-1)

        return ber, bler

    def add(self, ebno_db, ber, is_bler=False, legend=""):
        """Add static reference curves.

        Input
        -----
        ebno_db: float
            Python array or list of floats defining the SNR points.

        ber: float
            Python array or list of floats defining the BER corresponding
            to each SNR point.

        is_bler: bool
            A boolean defaults to False. If True, ``ber`` is interpreted as
            BLER.

        legend: str
            A string defining the text of the legend entry.
        """

        assert (len(ebno_db)==len(ber)), \
            "ebno_db and ber must have same number of elements."

        assert isinstance(legend, str), "legend must be str."
        assert isinstance(is_bler, bool), "is_bler must be bool."

        # concatenate curves
        self._bers += [ber]
        self._snrs += [ebno_db]
        self._legends += [legend]
        self._is_bler += [is_bler]

    def reset(self):
        """Remove all internal data."""
        self._bers = []
        self._snrs = []
        self._legends = []
        self._is_bler = []

    def remove(self, idx=-1):
        """Remove curve with index ``idx``.

        Input
        ------
        idx: int
            An integer defining the index of the dataset that should
            be removed. Negative indexing is possible.
        """

        assert isinstance(idx, int), "idx must be int."

        del self._bers[idx]
        del self._snrs[idx]
        del self._legends[idx]
        del self._is_bler[idx]