anomaly-detection-material-parameters-calibration

Sionna parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

coverage_map.py (32447B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """
      6 Class that stores coverage map
      7 """
      8 import matplotlib as mpl
      9 from matplotlib.colors import from_levels_and_colors
     10 import matplotlib.pyplot as plt
     11 import numpy as np
     12 import tensorflow as tf
     13 
     14 from sionna.utils import expand_to_rank, insert_dims, log10
     15 from .utils import rotation_matrix, mitsuba_rectangle_to_world, watt_to_dbm
     16 import warnings
     17 
     18 
     19 class CoverageMap:
     20     # pylint: disable=line-too-long
     21     r"""
     22     CoverageMap()
     23 
     24     Stores the simulated coverage maps
     25 
     26     A coverage map is generated for the loaded scene for all transmitters using
     27     :meth:`~sionna.rt.Scene.coverage_map`. Please refer to the documentation of this function
     28     for further details.
     29 
     30     Example
     31     -------
     32     .. code-block:: Python
     33 
     34         import sionna
     35         from sionna.rt import load_scene, PlanarArray, Transmitter, Receiver
     36         scene = load_scene(sionna.rt.scene.munich)
     37 
     38         # Configure antenna array for all transmitters
     39         scene.tx_array = PlanarArray(num_rows=8,
     40                                   num_cols=2,
     41                                   vertical_spacing=0.7,
     42                                   horizontal_spacing=0.5,
     43                                   pattern="tr38901",
     44                                   polarization="VH")
     45 
     46         # Configure antenna array for all receivers
     47         scene.rx_array = PlanarArray(num_rows=1,
     48                                   num_cols=1,
     49                                   vertical_spacing=0.5,
     50                                   horizontal_spacing=0.5,
     51                                   pattern="dipole",
     52                                   polarization="cross")
     53         # Add a transmitters
     54         tx = Transmitter(name="tx",
     55                       position=[8.5,21,30],
     56                       orientation=[0,0,0])
     57         scene.add(tx)
     58         tx.look_at([40,80,1.5])
     59 
     60         # Compute coverage map
     61         cm = scene.coverage_map(max_depth=8)
     62 
     63         # Show coverage map
     64         cm.show();
     65 
     66     .. figure:: ../figures/coverage_map_show.png
     67         :align: center
     68     """
     69 
     70     def __init__(self,
     71                  center,
     72                  orientation,
     73                  size,
     74                  cell_size,
     75                  path_gain,
     76                  scene,
     77                  dtype=tf.complex64):
     78 
     79         self._rdtype = dtype.real_dtype
     80 
     81         if (tf.rank(center) != 1) or (tf.shape(center)[0] != 3):
     82             msg = "`center` must be shaped as [x,y,z] (rank=1 and shape=[3])"
     83             raise ValueError(msg)
     84 
     85         if (tf.rank(orientation) != 1) or (tf.shape(orientation)[0] != 3):
     86             msg = "`orientation` must be shaped as [a,b,c]"\
     87                   " (rank=1 and shape=[3])"
     88             raise ValueError(msg)
     89 
     90         if (tf.rank(size) != 1) or (tf.shape(size)[0] != 2):
     91             msg = "`size` must be shaped as [w,h]"\
     92                   " (rank=1 and shape=[2])"
     93             raise ValueError(msg)
     94 
     95         if (tf.rank(cell_size) != 1) or (tf.shape(cell_size)[0] != 2):
     96             msg = "`cell_size` must be shaped as [w,h]"\
     97                   " (rank=1 and shape=[2])"
     98             raise ValueError(msg)
     99 
    100         num_cells_x = tf.cast(tf.math.ceil(size[0]/cell_size[0]), tf.int32)
    101         num_cells_y = tf.cast(tf.math.ceil(size[1]/cell_size[1]), tf.int32)
    102 
    103         if (tf.rank(path_gain) != 3)\
    104                 or (tf.shape(path_gain)[1] != num_cells_y)\
    105                 or (tf.shape(path_gain)[2] != num_cells_x):
    106             msg = "`path_gain` must have shape"\
    107                   " [num_tx, num_cells_y, num_cells_x]"
    108             raise ValueError(msg)
    109 
    110         self._center = tf.cast(center, self._rdtype)
    111         self._orientation = tf.cast(orientation, self._rdtype)
    112         self._size = tf.cast(size, self._rdtype)
    113         self._cell_size = tf.cast(cell_size, self._rdtype)
    114         self._path_gain = tf.cast(path_gain, self._rdtype)
    115         #self._path_gain = tf.where(tf.math.is_nan(self._path_gain),
    116         #                           0, self._path_gain)
    117         self._transmitters = scene.transmitters
    118         self._scene = scene
    119 
    120         # Dict mapping names to index for transmitters
    121         self._tx_name_2_ind = {}
    122         for tx_ind, tx_name in enumerate(self._transmitters):
    123             self._tx_name_2_ind[tx_name] = tx_ind
    124 
    125         ###############################################################
    126         # Position of the center of the cells in the world
    127         # coordinate system
    128         ###############################################################
    129         # [num_cells_x]
    130         x_positions = tf.range(num_cells_x, dtype=self._rdtype)
    131         x_positions = (x_positions + 0.5)*self._cell_size[0]
    132         # [num_cells_x, num_cells_y]
    133         x_positions = tf.expand_dims(x_positions, axis=1)
    134         x_positions = tf.tile(x_positions, [1, num_cells_y])
    135         # [num_cells_y]
    136         y_positions = tf.range(num_cells_y, dtype=self._rdtype)
    137         y_positions = (y_positions + 0.5)*self._cell_size[1]
    138         # [num_cells_x, num_cells_y]
    139         y_positions = tf.expand_dims(y_positions, axis=0)
    140         y_positions = tf.tile(y_positions, [num_cells_x, 1])
    141         # [num_cells_x, num_cells_y, 2]
    142         cell_pos = tf.stack([x_positions, y_positions], axis=-1)
    143         # Move to global coordinate system
    144         # [1, 1, 2]
    145         size = expand_to_rank(self._size, tf.rank(cell_pos), 0)
    146         # [num_cells_x, num_cells_y, 2]
    147         cell_pos = cell_pos - size*0.5
    148         # [num_cells_x, num_cells_y, 3]
    149         cell_pos = tf.concat([cell_pos,
    150                               tf.zeros([num_cells_x, num_cells_y, 1],
    151                                        dtype=self._rdtype)],
    152                              axis=-1)
    153         # [3, 3]
    154         rot_cm_2_gcs = rotation_matrix(self._orientation)
    155         # [1, 1, 3, 3]
    156         rot_cm_2_gcs_ = expand_to_rank(rot_cm_2_gcs, tf.rank(cell_pos)+1,
    157                                        axis=0)
    158         # [num_cells_x, num_cells_y, 3]
    159         cell_pos = tf.linalg.matvec(rot_cm_2_gcs_, cell_pos)
    160         # [num_cells_x, num_cells_y, 3]
    161         cell_pos = cell_pos + self._center
    162         # [num_cells_y, num_cells_x, 3]
    163         cell_pos = tf.transpose(cell_pos, [1, 0, 2])
    164         self._cell_pos = cell_pos
    165 
    166         ######################################################################
    167         # Position of the transmitters, receivers, and RIS in the coverage map
    168         ######################################################################
    169         # [num_tx/num_rx/num_ris, 3]
    170         tx_pos = [tx.position for tx in scene.transmitters.values()]
    171         tx_pos = tf.stack(tx_pos, axis=0)
    172 
    173         rx_pos = [rx.position for rx in scene.receivers.values()]
    174         rx_pos = tf.stack(rx_pos, axis=0)
    175         if len(rx_pos) == 0:
    176             rx_pos = tf.zeros([0, 3], dtype=self._rdtype)
    177 
    178         ris_pos = [ris.position for ris in scene.ris.values()]
    179         ris_pos = tf.stack(ris_pos, axis=0)
    180         if len(ris_pos) == 0:
    181             ris_pos = tf.zeros([0, 3], dtype=self._rdtype)
    182 
    183         # [num_tx/num_rx/num_ris, 3]
    184         center_ = tf.expand_dims(self._center, axis=0)
    185         tx_pos = tx_pos - center_
    186         rx_pos = rx_pos - center_
    187         ris_pos = ris_pos - center_
    188 
    189         # [3, 3]
    190         rot_gcs_2_cm = tf.transpose(rot_cm_2_gcs)
    191         # [1, 3, 3]
    192         rot_gcs_2_cm_ = tf.expand_dims(rot_gcs_2_cm, axis=0)
    193         # Positions in the coverage map system
    194         # [num_tx/num_rx/num_ris, 3]
    195         tx_pos = tf.linalg.matvec(rot_gcs_2_cm_, tx_pos)
    196         rx_pos = tf.linalg.matvec(rot_gcs_2_cm_, rx_pos)
    197         ris_pos = tf.linalg.matvec(rot_gcs_2_cm_, ris_pos)
    198 
    199         # Keep only x and y
    200         # [num_tx/num_rx/num_ris, 2]
    201         tx_pos = tx_pos[:, :2]
    202         rx_pos = rx_pos[:, :2]
    203         ris_pos = ris_pos[:, :2]
    204 
    205         # Quantizing, using the bottom left corner as origin
    206         # [num_tx/num_rx/num_ris, 2]
    207         tx_pos = self._pos_to_idx_cell(tx_pos)
    208         rx_pos = self._pos_to_idx_cell(rx_pos)
    209         ris_pos = self._pos_to_idx_cell(ris_pos)
    210 
    211         self._tx_pos = tx_pos
    212         self._rx_pos = rx_pos
    213         self._ris_pos = ris_pos
    214 
    @property
    def center(self):
        """
        [3], tf.float : Center point of the coverage map, given in the
            global coordinate system
        """
        map_center = self._center
        return map_center
    222 
    @property
    def orientation(self):
        r"""
        [3], tf.float : Orientation :math:`(\alpha, \beta, \gamma)` of the
            coverage map, given as the three angles of a 3D rotation as
            defined in :eq:`rotation`.
            The orientation :math:`(0,0,0)` describes a coverage map lying
            parallel to the XY plane.
        """
        map_orientation = self._orientation
        return map_orientation
    234 
    @property
    def size(self):
        """
        [2], tf.float : Width and height of the coverage map, measured along
            its local X and Y directions
        """
        map_size = self._size
        return map_size
    241 
    @property
    def cell_size(self):
        """
        [2], tf.float : Resolution of the coverage map, i.e., the width
            (along the local X direction) and the height (along the local Y
            direction) of a single cell
        """
        resolution = self._cell_size
        return resolution
    250 
    @property
    def cell_centers(self):
        """
        [num_cells_y, num_cells_x, 3], tf.float : Global coordinates of the
        center of every cell of the coverage map
        """
        centers = self._cell_pos
        return centers
    258 
    @property
    def num_cells_x(self):
        """
        int : Number of cells along the local X-axis, read from the last
        dimension of :attr:`path_gain`
        """
        gain_shape = self.path_gain.shape
        return gain_shape[2]
    265 
    @property
    def num_cells_y(self):
        """
        int : Number of cells along the local Y-axis, read from the second
        dimension of :attr:`path_gain`
        """
        gain_shape = self.path_gain.shape
        return gain_shape[1]
    272 
    @property
    def num_tx(self):
        """
        int : Number of transmitters, read from the first dimension of
        :attr:`path_gain`
        """
        gain_shape = self.path_gain.shape
        return gain_shape[0]
    279 
    @property
    def tx_pos(self):
        """
        [num_tx, 2], int : (column, row) index of the cell containing the
        projection of each transmitter onto the coverage map plane. Indices
        are not clipped and may lie outside the map.
        """
        cell_indices = self._tx_pos
        return cell_indices
    286 
    @property
    def rx_pos(self):
        """
        [num_rx, 2], int : (column, row) index of the cell containing the
        projection of each receiver onto the coverage map plane. Indices
        are not clipped and may lie outside the map.
        """
        cell_indices = self._rx_pos
        return cell_indices
    293 
    @property
    def ris_pos(self):
        """
        [num_ris, 2], int : (column, row) index of the cell containing the
        projection of each RIS onto the coverage map plane. Indices are not
        clipped and may lie outside the map.
        """
        cell_indices = self._ris_pos
        return cell_indices
    300 
    @property
    def path_gain(self):
        """
        [num_tx, num_cells_y, num_cells_x], tf.float : Path gain of every
        cell of the coverage map, for each transmitter
        """
        gains = self._path_gain
        return gains
    308 
    @property
    def rss(self):
        """
        [num_tx, num_cells_y, num_cells_x], tf.float : Received signal strength
        (RSS) across the coverage map from all transmitters, i.e., the path
        gain of each transmitter scaled by that transmitter's power
        """
        # [num_tx]
        tx_powers = [tx.power for tx in self._scene.transmitters.values()]
        # Cast to the map's real dtype so that the product below does not
        # fail on a dtype mismatch (e.g., float32 powers with a float64 map)
        tx_powers = tf.cast(tf.stack(tx_powers, axis=0), self._rdtype)
        # [num_tx, 1, 1] broadcast against [num_tx, num_cells_y, num_cells_x]
        return tx_powers[:, tf.newaxis, tf.newaxis] * self.path_gain
    318 
    @property
    def sinr(self):
        """
        [num_tx, num_cells_y, num_cells_x], tf.float : SINR
        across the coverage map from all transmitters, where the
        interference for a transmitter is the RSS of all other transmitters
        """
        # Evaluate the RSS property only once; it is rebuilt on every access
        # [num_tx, num_cells_y, num_cells_x]
        rss = self.rss

        # Total received power from all transmitters
        # [num_cells_y, num_cells_x]
        total_pow = tf.reduce_sum(rss, axis=0)

        # Interference for each transmitter
        # [num_tx, num_cells_y, num_cells_x]
        interference = total_pow[tf.newaxis] - rss

        # Thermal noise, cast to the RSS dtype to avoid dtype mismatches
        noise = tf.cast(self._scene.thermal_noise_power, rss.dtype)

        # SINR
        return rss / (interference + noise)
    337 
    def _pos_to_idx_cell(self, pos):
        """
        Map local positions [m] on the coverage map to integer cell indices

        Positions are shifted so that the bottom-left corner of the map
        becomes the origin, then divided by the cell size and floored.

        Input
        -----
        pos : [num_pos, 2], tf.float
            Local (x, y) positions within the coverage map

        Output
        ------
        [num_pos, 2], tf.int32 : (column, row) cell index of each position
        """
        # Local coordinates are centered on the map; offset by half its size
        shifted = pos + self._size * 0.5
        in_cell_units = shifted / self._cell_size
        return tf.cast(tf.math.floor(in_cell_units), tf.int32)
    354 
    def cell_to_tx(self, metric):
        r"""Associates every cell with the transmitter that yields the
        largest value of the chosen metric (path gain, received signal
        strength (RSS), or SINR).

        Cells for which the maximum over all transmitters equals zero are
        treated as uncovered and are marked with -1.

        Input
        -------
        metric : str, one of ["path_gain", "rss", "sinr"]
            Metric used for the association

        Output
        -------
        cell_to_tx : [num_cells_y, num_cells_x], tf.int64
            Index of the associated transmitter for each cell, or -1 for
            cells without coverage
        """
        if metric not in ("path_gain", "rss", "sinr"):
            raise ValueError("Invalid metric")

        # [num_tx, num_cells_y, num_cells_x]
        values = getattr(self, metric)

        # Strongest transmitter per cell
        # [num_cells_y, num_cells_x]
        best_tx = tf.math.argmax(values, axis=0)

        # Cells whose best value is exactly zero receive no coverage
        uncovered = tf.equal(tf.reduce_max(values, axis=0), 0)
        return tf.where(uncovered, -tf.ones_like(best_tx), best_tx)
    385 
    def cdf(self, metric="path_gain", tx=None):
        r"""Computes and visualizes the CDF of a metric of the coverage map

        Input
        -----
        metric : str, one of ["path_gain", "rss", "sinr"]
            Metric to be shown. Defaults to "path_gain".

        tx : int | str | None
            Index or name of the transmitter for which to show the coverage
            map. Python and NumPy integers are accepted; negative indices
            count from the last transmitter. If `None`, the maximum value
            over all transmitters for each cell is shown.
            Defaults to `None`.

        Output
        ------
        : :class:`~matplotlib.pyplot.Figure`
            Figure showing the CDF

        x : tf.float, [num_cells_x * num_cells_y]
            Sorted data points for the chosen metric ([dB] for path gain and
            SINR, [dBm] for RSS)

        cdf : tf.float, [num_cells_x * num_cells_y]
            Cumulative probabilities for the data points

        Raises
        ------
        ValueError
            If `metric` is invalid, or `tx` is an out-of-range index, an
            unknown name, or of an unsupported type
        """
        if metric not in ["path_gain", "rss", "sinr"]:
            raise ValueError("Invalid metric")

        if isinstance(tx, (int, np.integer)):
            if not -self.num_tx <= tx < self.num_tx:
                raise ValueError("Invalid transmitter index")
            # Normalize negative indices so that the title reports the
            # actual transmitter index
            tx = int(tx) % self.num_tx
        elif isinstance(tx, str):
            if tx in self._tx_name_2_ind:
                tx = self._tx_name_2_ind[tx]
            else:
                raise ValueError(f"Unknown transmitter with name '{tx}'")
        elif tx is not None:
            msg = "Invalid type for `tx`: Must be a string, int, or None"
            raise ValueError(msg)

        x = getattr(self, metric)
        if tx is not None:
            x = x[tx]
        else:
            x = tf.reduce_max(x, axis=0)
        x = tf.reshape(x, [-1])
        x = 10 * log10(x)
        # Add 30dB for RSS to account for dBm
        if metric == "rss":
            x += 30
        x = tf.sort(x)
        # Empirical CDF: the i-th smallest value has probability i/N
        cdf = tf.range(1, tf.size(x) + 1, dtype=tf.float32) \
              / tf.cast(tf.size(x), tf.float32)
        fig, _ = plt.subplots()
        plt.plot(x.numpy(), cdf.numpy())
        plt.grid(True, which="both")
        plt.ylabel("Cumulative probability")

        # Set x-label and title
        if metric == "path_gain":
            xlabel = "Path gain [dB]"
            title = "Path gain"
        elif metric == "rss":
            xlabel = "Received signal strength (RSS) [dBm]"
            title = "RSS"
        else:
            xlabel = "Signal-to-interference-plus-noise ratio (SINR) [dB]"
            title = "SINR"
        if tx is None and self.num_tx > 1:
            title = 'Highest ' + title + ' across all TXs'
        elif tx is not None:
            title = title + f' for TX {tx}'

        plt.xlabel(xlabel)
        plt.title(title)

        return fig, x, cdf
    465 
    466 
    def show(self,
             metric="path_gain",
             tx=None,
             vmin=None,
             vmax=None,
             show_tx=True,
             show_rx=False,
             show_ris=False):
        r"""Visualizes a coverage map

        The position of the transmitter is indicated by a red "+" marker.
        The positions of the receivers are indicated by blue "x" markers.
        The positions of the RIS are indicated by black "*" markers.

        Input
        -----
        metric : str, one of ["path_gain", "rss", "sinr"]
            Metric to be shown. Defaults to "path_gain".

        tx : int | str | None
            Index or name of the transmitter for which to show the coverage
            map. Python and NumPy integers are accepted; negative indices
            count from the last transmitter. If `None`, the maximum value
            over all transmitters for each cell is shown.
            Defaults to `None`.

        vmin,vmax : float | `None`
            Define the range of values [dB] that the colormap covers.
            If set to `None`, the complete range is shown.
            Defaults to `None`.

        show_tx : bool
            If set to `True`, then the position of the transmitters are shown.
            Defaults to `True`.

        show_rx : bool
            If set to `True`, then the position of the receivers are shown.
            Defaults to `False`.

        show_ris : bool
            If set to `True`, then the position of the RIS are shown.
            Defaults to `False`.

        Output
        ------
        : :class:`~matplotlib.pyplot.Figure`
            Figure showing the coverage map

        Raises
        ------
        ValueError
            If `metric` is invalid, or `tx` is an out-of-range index, an
            unknown name, or of an unsupported type
        """

        if metric not in ["path_gain", "rss", "sinr"]:
            raise ValueError("Invalid metric")

        if isinstance(tx, (int, np.integer)):
            if not -self.num_tx <= tx < self.num_tx:
                raise ValueError("Invalid transmitter index")
            # Normalize negative indices so that the title and the TX marker
            # refer to the actual transmitter index
            tx = int(tx) % self.num_tx
        elif isinstance(tx, str):
            if tx in self._tx_name_2_ind:
                tx = self._tx_name_2_ind[tx]
            else:
                raise ValueError(f"Unknown transmitter with name '{tx}'")
        elif tx is not None:
            msg = "Invalid type for `tx`: Must be a string, int, or None"
            raise ValueError(msg)

        # Select metric for a specific transmitter or compute max
        # [num_cells_y, num_cells_x]
        cm = getattr(self, metric)
        if tx is not None:
            cm = cm[tx]
        else:
            cm = tf.reduce_max(cm, axis=0)

        # Convert to dB-scale
        if metric in ["path_gain", "sinr"]:
            # Cells with a value of 0 map to -inf dB; silence only NumPy's
            # floating-point warnings for them
            with np.errstate(divide="ignore", invalid="ignore"):
                cm = 10.*np.log10(cm.numpy())
        else:
            with warnings.catch_warnings(record=True) as _:
                # Convert the signal strength to dBm
                cm = watt_to_dbm(cm).numpy()

        # Visualize the coverage map; origin='lower' places cell row 0 at
        # the bottom, matching the cell indices used for the markers below
        fig_cm = plt.figure()
        plt.imshow(cm, origin='lower', vmin=vmin, vmax=vmax)

        # Set label
        if metric == "path_gain":
            label = "Path gain [dB]"
            title = "Path gain"
        elif metric == "rss":
            label = "Received signal strength (RSS) [dBm]"
            title = 'RSS'
        else:
            label = "Signal-to-interference-plus-noise ratio (SINR) [dB]"
            title = 'SINR'
        if tx is None and self.num_tx > 1:
            title = 'Highest ' + title + ' across all TXs'
        elif tx is not None:
            title = title + f' for TX {tx}'
        plt.colorbar(label=label)
        plt.xlabel('Cell index (X-axis)')
        plt.ylabel('Cell index (Y-axis)')
        plt.title(title)

        # Show transmitter, receiver, RIS positions
        if show_tx:
            if tx is not None:
                tx_pos = self._tx_pos[tx]
                fig_cm.axes[0].scatter(*tx_pos, marker='P', c='r')
            else:
                for tx_pos in self._tx_pos:
                    fig_cm.axes[0].scatter(*tx_pos, marker='P', c='r')

        if show_rx:
            for rx_pos in self._rx_pos:
                fig_cm.axes[0].scatter(*rx_pos, marker='x', c='b')

        if show_ris:
            for ris_pos in self._ris_pos:
                fig_cm.axes[0].scatter(*ris_pos, marker='*', c='k')

        return fig_cm
    590 
    def show_association(self,
                         metric="path_gain",
                         show_tx=True,
                         show_rx=False,
                         show_ris=False):
        r"""Visualizes cell-to-tx association for a given metric

        The position of the transmitter is indicated by a red "+" marker.
        The positions of the receivers are indicated by blue "x" markers.
        The positions of the RIS are indicated by black "*" markers.

        Input
        -----
        metric : str, one of ["path_gain", "rss", "sinr"]
            Metric based on which the cell-to-tx association
            is computed.
            Defaults to "path_gain".

        show_tx : bool
            If set to `True`, then the position of the transmitters are shown.
            Defaults to `True`.

        show_rx : bool
            If set to `True`, then the position of the receivers are shown.
            Defaults to `False`.

        show_ris : bool
            If set to `True`, then the position of the RIS are shown.
            Defaults to `False`.

        Output
        ------
        : :class:`~matplotlib.pyplot.Figure`
            Figure showing the cell-to-transmitter association

        Raises
        ------
        ValueError
            If `metric` is invalid
        """
        if metric not in ["path_gain", "rss", "sinr"]:
            raise ValueError("Invalid metric")

        # Create the colormap and normalization with one color per
        # transmitter. The qualitative 'Dark2' palette has only a small,
        # fixed number of colors, so cycle through it: from_levels_and_colors
        # requires exactly num_tx colors for num_tx+1 levels, even when there
        # are more transmitters than palette entries.
        palette = mpl.colormaps['Dark2'].colors
        colors = [palette[i % len(palette)] for i in range(self.num_tx)]
        cmap, norm = from_levels_and_colors(
            list(range(self.num_tx+1)), colors)
        fig_tx = plt.figure()
        plt.imshow(self.cell_to_tx(metric).numpy(),
                   origin='lower', cmap=cmap, norm=norm)
        plt.xlabel('Cell index (X-axis)')
        plt.ylabel('Cell index (Y-axis)')
        plt.title('Cell-to-TX association')
        cbar = plt.colorbar(label="TX")
        # Replace the numeric ticks by the TX index centered in each band
        cbar.ax.get_yaxis().set_ticks([])
        for tx_ in range(self.num_tx):
            cbar.ax.text(.5, tx_ + .5, str(tx_), ha='center', va='center')

        # Visualizing transmitter, receiver, RIS positions
        if show_tx:
            for tx_pos in self._tx_pos:
                fig_tx.axes[0].scatter(*tx_pos, marker='P', c='r')

        if show_rx:
            for rx_pos in self._rx_pos:
                fig_tx.axes[0].scatter(*rx_pos, marker='x', c='b')

        if show_ris:
            for ris_pos in self._ris_pos:
                fig_tx.axes[0].scatter(*ris_pos, marker='*', c='k')

        return fig_tx
    658 
    659 
    660     def sample_positions(self,
    661                          num_pos,
    662                          metric="path_gain",
    663                          min_val_db=None,
    664                          max_val_db=None,
    665                          min_dist=None,
    666                          max_dist=None,
    667                          tx_association=True,
    668                          center_pos=False):
    669         # pylint: disable=line-too-long
    670         r"""Sample random user positions in a scene based on a coverage map
    671 
    672         For a given coverage map, ``num_pos`` random positions are sampled
    673         around each transmitter,
    674         such that the selected metric, e.g., SINR, is larger
    675         than ``min_val_db`` and/or smaller than ``max_val_db``.
    676         Similarly, ``min_dist`` and ``max_dist`` define the minimum and maximum
    677         distance of the random positions to the transmitter under consideration.
    678         By activating the flag ``tx_association``, only positions are sampled
    679         for which the selected metric is the highest across all transmitters.
    680         This is useful if one wants to ensure, e.g., that the sampled positions for
    681         each transmitter provide the highest SINR or RSS.
    682 
    683         Note that due to the quantization of the coverage map into cells it is
    684         not guaranteed that all above parameters are exactly fulfilled for a
    685         returned position. This stems from the fact that every
    686         individual cell of the coverage map describes the expected *average*
    687         behavior of the surface within this cell. For instance, it may happen
    688         that half of the selected cell is shadowed and, thus, no path to the
    689         transmitter exists but the average path gain is still larger than the
    690         given threshold. Please enable the flag ``center_pos`` to sample only
    691         positions from the cell centers.
    692 
    693         .. figure:: ../figures/cm_user_sampling.png
    694             :align: center
    695 
    696         The above figure shows an example for random positions between 220m and
    697         250m from the transmitter and a maximum path gain of -100 dB.
    698         Keep in mind that the transmitter can have a different height than the
    699         coverage map which also contributes to this distance.
    700         For example if the transmitter is located 20m above the surface of the
    701         coverage map and a ``min_dist`` of 20m is selected, also positions
    702         directly below the transmitter are sampled.
    703 
    704         Input
    705         -----
    706         num_pos: int
    707             Number of returned random positions for ech transmitter
    708 
    709         metric : str, one of ["path_gain", "rss", "sinr"]
    710             Metric to be considered for sampling positions. Defaults to
    711             "path_gain".
    712 
    713         min_val_db: float | None
    714             Minimum value for the selected metric ([dB] for path gain and SINR;
    715             [dBm] for RSS). 
    716             Positions are only sampled from cells where the selected metric is
    717             larger than or equal to this value. 
    718             Ignored if `None`.
    719             Defaults to `None`.
    720 
    721         max_val_db: float | None
    722             Maximum value for the selected metric ([dB] for path gain and SINR;
    723             [dBm] for RSS). 
    724             Positions are only sampled from cells where the selected metric is
    725             smaller than or equal to this value. 
    726             Ignored if `None`.
    727             Defaults to `None`.
    728 
    729         min_dist: float | None
    730             Minimum distance [m] from transmitter for all random positions.
    731             Ignored if `None`.
    732             Defaults to `None`.
    733 
    734         max_dist: float | None
    735             Maximum distance [m] from transmitter for all random positions.
    736             Ignored if `None`.
    737             Defaults to `None`.
    738 
    739         tx_association : bool
    740             If `True`, only positions associated with a transmitter are chosen,
    741             i.e., positions where the chosen metric is the highest among all
    742             transmitters. Otherwise, a user located in a sampled position for a
    743             specific transmitter may perceive a higher metric from another TX.
    744             Defaults to `True`.
    745 
    746         center_pos: bool
    747             If `True`, all returned positions are sampled from the cell center
    748             (i.e., the grid of the coverage map). Otherwise, the positions are
    749             randomly drawn from the surface of the cell.
    750             Defaults to `False`.
    751 
    752         Output
    753         ------
    754         : [num_tx, num_pos, 3], tf.float
    755             Random positions :math:`(x,y,z)` [m] that are in cells fulfilling the
    756             configured constraints
    757 
    758         : [num_tx, num_pos, 2], tf.int64
    759             Cell indices, as (column, row) pairs, corresponding to the random positions
    760         """
    761 
    762         if metric not in ["path_gain", "rss", "sinr"]:
    763             raise ValueError("Invalid metric")
    764 
    765         # allow integral float values (e.g. 5.0) for num_pos
    766         if not isinstance(num_pos, (int, float)) or not num_pos % 1 == 0:
    767             raise ValueError("num_pos must be int.")
    768         # cast num_pos to int
    769         num_pos = int(num_pos)
    770 
    771         if min_val_db is None:
    772             min_val_db = -1. * np.infty
    773         min_val_db = tf.constant(min_val_db, self._rdtype)
    774 
    775         if max_val_db is None:
    776             max_val_db = np.infty
    777         max_val_db = tf.constant(max_val_db, self._rdtype)
    778 
    779         if min_val_db > max_val_db:
    780             raise ValueError("min_val_d cannot be larger than max_val_db.")
    781 
    782         if min_dist is None:
    783             min_dist = 0.
    784         min_dist = tf.constant(min_dist, self._rdtype)
    785 
    786         if max_dist is None:
    787             max_dist = np.infty
    788         max_dist = tf.constant(max_dist, self._rdtype)
    789 
    790         if min_dist > max_dist:
    791             raise ValueError("min_dist cannot be larger than max_dist.")
    792 
    793         # Select metric to be used
    794         cm = getattr(self, metric)
    795 
    796         # Convert to dB-scale
    797         if metric in ["path_gain", "sinr"]:
    798             with warnings.catch_warnings(record=True) as _:
    799                 # Convert the path gain or SINR to dB
    800                 cm = 10.*np.log10(cm.numpy())
    801         else:
    802             with warnings.catch_warnings(record=True) as _:
    803                 # Convert the received signal strength to dBm
    804                 cm = watt_to_dbm(cm).numpy()
    805 
    806         # [num_tx, 3]: tx_pos_xyz[i, :] contains the i-th tx (x,y,z) coordinate
    807         # positions
    808         tx_pos_xyz = tf.stack([tx.position for tx
    809                                in self._scene.transmitters.values()])
    810 
    811         # Compute distance from each tx to all cells
    812         # [num_tx, num_cells_y, num_cells_x]
    813         cell_distance_from_tx = tf.math.reduce_euclidean_norm(
    814             self.cell_centers[tf.newaxis] -
    815             insert_dims(tx_pos_xyz, 2, axis=1), axis=-1)
    816 
    817         # [num_tx, num_cells_y, num_cells_x]
    818         distance_mask = tf.logical_and(cell_distance_from_tx >= min_dist,
    819                                        cell_distance_from_tx <= max_dist)
    820 
    821         # Get cells for which metric criterion is valid
    822         # [num_tx, num_cells_y, num_cells_x]
    823         cm_mask = tf.logical_and(cm >= min_val_db,
    824                                  cm <= max_val_db)
    825 
    826         # Get cells for which the tx association is valid
    827         tx_ids = insert_dims(tf.range(self.num_tx, dtype=tf.int64), 2, 1)
    828         association_mask = tx_ids == self.cell_to_tx(metric)[tf.newaxis]
    829 
    830         # Compute combined mask
    831         mask = distance_mask & cm_mask
    832         if tx_association:
    833             mask = mask & association_mask
    834         mask = tf.cast(mask, tf.int64)
    835 
    836         sampled_cell_ids = []
    837         sampled_cell_pos = []
    838         for i, m in enumerate(mask):
    839             valid_ids = tf.where(m)
    840             num_valid_ids = len(valid_ids)
    841             if num_valid_ids == 0:
    842                 msg = f"No valid cells for transmitter {i} to sample from."
    843                 raise RuntimeError(msg)
    844             cell_ids = tf.random.uniform(shape=[num_pos],
    845                                          minval=0, maxval=num_valid_ids,
    846                                          dtype=tf.int64)
    847             sampled_ids = tf.gather(valid_ids, cell_ids, axis=0)
    848             sampled_cell_ids.append(sampled_ids)
    849             sampled_pos = tf.gather_nd(self.cell_centers,
    850                                        sampled_ids)
    851             sampled_cell_pos.append(sampled_pos)
    852         sampled_cell_ids = tf.stack(sampled_cell_ids, axis=0)
    853         # swap cell indexes to produce (column, row) index pairs
    854         sampled_cell_ids = tf.gather(sampled_cell_ids, [1, 0], axis=-1)
    855 
    856         sampled_cell_pos = tf.stack(sampled_cell_pos, axis=0)
    857 
    858         # Add random offset within cell-size, if positions should not be
    859         # centered
    860         if not center_pos:
    861             # cell can be rotated
    862             dir_x = tf.expand_dims(0.5*(self.cell_centers[0, 0] -
    863                                         self.cell_centers[1, 0]), axis=0)
    864             dir_y = tf.expand_dims(0.5*(self.cell_centers[0, 0] -
    865                                         self.cell_centers[0, 1]), axis=0)
    866             rand_x = tf.random.uniform((num_pos, 1),
    867                                        minval=-1.,
    868                                        maxval=1.,
    869                                        dtype=self._rdtype)
    870             rand_y = tf.random.uniform((num_pos, 1),
    871                                        minval=-1.,
    872                                        maxval=1.,
    873                                        dtype=self._rdtype)
    874 
    875             sampled_cell_pos += rand_x * dir_x + rand_y * dir_y
    876 
    877         return sampled_cell_pos, sampled_cell_ids
    878 
    def to_world(self):
        r"""
        Provides the `to_world` transformation that maps a default Mitsuba
        rectangle onto the rectangle defining the coverage map surface

        The transformation is built from the coverage map's center,
        orientation, and size.

        Output
        ------
        to_world : :class:`mitsuba.ScalarTransform4f`
            Rectangle to world transformation
        """
        # Gather the geometric parameters of the coverage map plane
        center, orientation, size = (self._center,
                                     self._orientation,
                                     self._size)
        return mitsuba_rectangle_to_world(center, orientation, size)
    890                                           self._size)