anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

solver_cm.py (190139B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """
      6 Ray tracing algorithm that uses the image method to compute all pure reflection
      7 paths.
      8 """
      9 
     10 import mitsuba as mi
     11 import drjit as dr
     12 import tensorflow as tf
     13 from sionna.constants import PI
     14 from sionna import config
     15 from sionna.utils.tensors import expand_to_rank, insert_dims, flatten_dims
     16 from .utils import dot, outer, phi_hat, theta_hat, theta_phi_from_unit_vec,\
     17     normalize, rotation_matrix, mi_to_tf_tensor, compute_field_unit_vectors,\
     18     reflection_coefficient, component_transform, fibonacci_lattice, r_hat,\
     19     cross, cot, sign, sample_points_on_hemisphere, acos_diff, gen_basis_from_z,\
     20     compute_spreading_factor, mitsuba_rectangle_to_world,\
     21         angles_to_mitsuba_rotation
     22 from .solver_base import SolverBase
     23 from .coverage_map import CoverageMap
     24 from .scattering_pattern import ScatteringPattern
     25 
     26 
     27 class SolverCoverageMap(SolverBase):
     28     # pylint: disable=line-too-long
     29     r"""SolverCoverageMap(scene, solver=None, dtype=tf.complex64)
     30 
     31     Generates a coverage map consisting of the squared amplitudes of the channel
     32     impulse response considering the LoS and reflection paths.
     33 
     34     The main inputs of the solver are:
     35 
     36     * The properties of the rectangle defining the coverage map, i.e., its
     37     position, scale, and orientation, and the resolution of the coverage map
     38 
     39     * The receiver orientation
     40 
     41     * A maximum depth, corresponding to the maximum number of reflections. A
     42     depth of zero corresponds to LoS only.
     43 
     44     Generation of a coverage map is carried-out for every transmitter in the
     45     scene. The antenna arrays of the transmitter and receiver are used.
     46 
     47     The generation of a coverage map consists of two steps:
     48
     49     1. Shoot-and-bounce ray tracing where rays are generated from the
     50     transmitters and the intersections with the rectangle defining the coverage
     51     map are recorded.
     52     The initial ray directions are arranged in a Fibonacci lattice on the unit
     53     sphere.
     54 
     55     2. The transfer matrices of every ray that intersects the coverage map are
     56     computed considering the materials of the objects that make up the scene.
     57     The antenna patterns, synthetic phase shifts due to the array geometry, and
     58     combining and precoding vectors are then applied to obtain channel
     59     coefficients. The squared amplitudes of the channel coefficients are then
     60     added to the value of the output corresponding to the cell of the coverage
     61     map within which the intersection between the ray and the coverage map
     62     occurred.
     63 
     64     Note: Only triangle meshes are supported.
     65 
     66     Parameters
     67     -----------
     68     scene : :class:`~sionna.rt.Scene`
     69         Sionna RT scene
     70 
     71     solver : :class:`~sionna.rt.BaseSolver` | None
     72         Another solver from which to re-use some structures to avoid useless
     73         compute and memory use
     74 
     75     dtype : tf.complex64 | tf.complex128
     76         Datatype for all computations, inputs, and outputs.
     77         Defaults to `tf.complex64`.
     78 
     79     Input
     80     ------
     81     max_depth : int
     82         Maximum depth (i.e., number of bounces) allowed for tracing the
     83         paths
     84 
     85     rx_orientation : [3], tf.float
     86         Orientation of the receiver.
     87         This is used to compute the antenna response and antenna pattern
     88         for an imaginary receiver located on the coverage map.
     89 
     90     cm_center : [3], tf.float
     91         Center of the coverage map
     92 
     93     cm_orientation : [3], tf.float
     94         Orientation of the coverage map
     95 
     96     cm_size : [2], tf.float
     97         Scale of the coverage map.
     98         The width of the map (in the local X direction) is ``cm_size[0]``
     99         and its height (in the local Y direction) ``cm_size[1]``.
    100 
    101     cm_cell_size : [2], tf.float
    102         Resolution of the coverage map, i.e., width
    103         (in the local X direction) and height (in the local Y direction) in
    104         meters of a cell of the coverage map
    105 
    106     combining_vec : [num_rx_ant], tf.complex | None
    107         Combining vector.
    108         This is used to combine the signal from the receive antennas for
    109         an imaginary receiver located on the coverage map.
    110         If set to `None`, then no combining is applied, and
    111         the energy received by all antennas is summed.
    112 
    113     precoding_vec : [num_tx, num_tx_ant], tf.complex
    114         Precoding vectors of the transmitters
    115 
    116     num_samples : int
    117         Number of rays initially shot from the transmitters.
    118         This number is shared by all transmitters, i.e.,
    119         ``num_samples/num_tx`` are shot for each transmitter.
    120 
    121     los : bool
    122         If set to `True`, then the LoS paths are computed.
    123 
    124     reflection : bool
    125         If set to `True`, then the reflected paths are computed.
    126 
    127     diffraction : bool
    128         If set to `True`, then the diffracted paths are computed.
    129 
    130     scattering : bool
    131         If set to `True`, then the scattered paths are computed.
    132 
    133     ris : bool
    134         If set to `True`, then paths involving RIS are computed.
    135 
    136     edge_diffraction : bool
    137         If set to `False`, only diffraction on wedges, i.e., edges that
    138         connect two primitives, is considered.
    139 
    140     num_runs : int, >= 1
    141         Number of runs of the coverage map solver executed. The returned
    142         coverage map is the average of all runs.
    143         If greater than one, then a random rotation is applied to the Fibonacci
    144         lattice at each run.
    145 
    146     Output
    147     -------
    148     :cm : :class:`~sionna.rt.CoverageMap`
    149         The coverage maps
    150     """
    151 
    152     DISCARD_THRES = 1e-15 # -150 dB
    153 
    def __call__(self, max_depth, rx_orientation,
                 cm_center, cm_orientation, cm_size, cm_cell_size,
                 combining_vec, precoding_vec, num_samples,
                 los, reflection, diffraction, scattering, ris,
                 edge_diffraction, num_runs):
        r"""
        Runs the coverage map solver and returns a
        :class:`~sionna.rt.CoverageMap`.

        The shoot-and-bounce pass, followed by the diffraction pass whenever
        the former reports LoS primitives, is executed ``num_runs`` times.
        The resulting per-run maps are averaged. See the class docstring for
        the meaning of every input.
        """
        if num_runs < 1:
            raise ValueError("The number of runs must be greater or equal to 1")
        # The Fibonacci lattice is only randomly rotated when several runs
        # are averaged
        random_lattice = num_runs > 1

        # Without reflection and scattering, no ray needs more than one
        # interaction: clipping the depth saves shoot-and-bounce compute
        if not (reflection or scattering):
            max_depth = tf.minimum(max_depth, 1)

        # Positions and orientations of all transmitters
        # sources_positions : [num_tx, 3]
        # sources_orientations : [num_tx, 3]
        transmitters = list(self._scene.transmitters.values())
        sources_positions = tf.stack([tx.position for tx in transmitters],
                                     axis=0)
        sources_orientations = tf.stack([tx.orientation
                                         for tx in transmitters], axis=0)

        # EM properties of the materials: relative permittivities (`etas`),
        # scattering coefficients, XPD coefficients, alpha_r, alpha_i and
        # lambda_
        props = self._build_scene_object_properties_tensors()
        etas, scattering_coefficient, xpd_coefficient = \
            props[0], props[1], props[2]
        alpha_r, alpha_i, lambda_ = props[3], props[4], props[5]

        # Mitsuba rectangle on which the coverage is measured
        meas_plane = self._build_mi_measurement_plane(cm_center,
                                                      cm_orientation,
                                                      cm_size)

        # Mitsuba scene holding the RIS, used to test intersections with RIS
        mi_ris_objects, mi_ris_indices = self._build_mi_ris_objects()

        run_maps = []
        for _ in range(num_runs):
            # Shoot-and-bounce: coverage map for LoS, reflection and
            # scattering, plus the primitives in LoS of the transmitters,
            # which are used to shoot diffracted rays
            cm, los_primitives = self._shoot_and_bounce(
                meas_plane, mi_ris_objects, mi_ris_indices, rx_orientation,
                sources_positions, sources_orientations, max_depth,
                num_samples, combining_vec, precoding_vec, cm_center,
                cm_orientation, cm_size, cm_cell_size, los, reflection,
                diffraction, scattering, ris, etas, scattering_coefficient,
                xpd_coefficient, alpha_r, alpha_i, lambda_, random_lattice)

            # Diffracted contribution, added only when LoS primitives exist
            if los_primitives is not None:
                cm = cm + self._diff_samples_2_coverage_map(
                    los_primitives, edge_diffraction, num_samples,
                    sources_positions, meas_plane, cm_center, cm_orientation,
                    cm_size, cm_cell_size, sources_orientations,
                    rx_orientation, combining_vec, precoding_vec, etas,
                    scattering_coefficient)

            run_maps.append(cm)

        # Non-coherent combination: average of the maps of all runs
        cm = tf.reduce_mean(tf.stack(run_maps, axis=0), axis=0)

        return CoverageMap(cm_center, cm_orientation, cm_size, cm_cell_size,
                           cm, scene=self._scene, dtype=self._dtype)
    283 
    284     ##################################################################
    285     # Internal methods
    286     ##################################################################
    287 
    def _build_mi_measurement_plane(self, cm_center, cm_orientation, cm_size):
        r"""
        Creates the Mitsuba ``rectangle`` shape on which the coverage map is
        measured.

        Input
        ------
        cm_center : [3], tf.float
            Center of the rectangle

        cm_orientation : [3], tf.float
            Orientation of the rectangle

        cm_size : [2], tf.float
            Size of the rectangle: ``cm_size[0]`` is its width along the
            local X direction and ``cm_size[1]`` its height along the local
            Y direction.

        Output
        ------
        mi_meas_plane : mi.Shape
            Mitsuba rectangle defining the measurement plane
        """
        # Transform placing the canonical Mitsuba rectangle in the scene
        to_world = mitsuba_rectangle_to_world(cm_center, cm_orientation,
                                              cm_size)
        return mi.load_dict({'type': 'rectangle', 'to_world': to_world})
    320 
    def _mp_hit_point_2_cell_ind(self, rot_gcs_2_mp, cm_center, cm_size,
                                 cm_cell_size, num_cells, hit_point):
        r"""
        Computes the indices of the coverage map cells to which the points
        ``hit_point`` on the measurement plane belong.

        If the cell index of a point along an axis falls outside
        ``[0, num_cells[axis])``, the point is assigned the index
        ``num_cells[axis]`` along that axis, i.e., it is routed to the
        out-of-map bin.

        Input
        ------
        rot_gcs_2_mp : [3, 3], tf.float
            Rotation matrix for going from the GCS to the local coordinate
            system (LCS) of the measurement plane

        cm_center : [3], tf.float
            Center of the coverage map

        cm_size : [2], tf.float
            Size of the coverage map along the local X and Y directions

        cm_cell_size : [2], tf.float
            Size of the cells of the coverage map along the local X and Y
            directions

        num_cells : [2], tf.int32
            Number of cells in the coverage map along the local X and Y
            directions

        hit_point : [..., 3], tf.float
            Intersection points in the GCS

        Output
        -------
        cell_ind : [..., 2], tf.int32
            Indices of the cells, ordered as ``[y, x]``
        """
        rank = tf.rank(hit_point)
        # Expand for broadcasting with the hit points
        # [..., 3, 3]
        rot_gcs_2_mp = expand_to_rank(rot_gcs_2_mp, rank+1, 0)
        # [..., 3]
        cm_center = expand_to_rank(cm_center, rank, 0)

        # Coordinates of the hit points in the measurement plane LCS.
        # The coverage map lies in the local XY plane, so the local z
        # coordinate is (close to) zero and is not used.
        # [..., 3]
        local_point = tf.linalg.matvec(rot_gcs_2_mp, hit_point - cm_center)

        def _axis_index(axis):
            # Cell index along one local axis, with out-of-map points
            # routed to index num_cells[axis]
            # Shift the origin to the corner of the map, then bin by cell size
            # [...]
            coord = local_point[..., axis] + cm_size[axis]*0.5
            ind = tf.cast(tf.math.floor(coord/cm_cell_size[axis]), tf.int32)
            in_map = tf.logical_and(tf.greater_equal(ind, 0),
                                    tf.less(ind, num_cells[axis]))
            return tf.where(in_map, ind, num_cells[axis])

        # [...]
        cell_x = _axis_index(0)
        cell_y = _axis_index(1)

        # [..., 2]
        cell_ind = tf.stack([cell_y, cell_x], axis=-1)
        return cell_ind
    385 
    def _compute_antenna_patterns(self, rot_mat, patterns, k):
        r"""
        Evaluates the antenna ``patterns`` of a radio device whose
        orientation is described by the rotation matrix ``rot_mat``, for the
        directions ``k``. The resulting fields are expressed in the spherical
        basis of the GCS.

        Input
        ------
        rot_mat : [..., 3, 3] or [3,3], tf.float
            Rotation matrix built from the orientation of the radio device.
            Its transpose takes GCS directions into the device LCS, and the
            matrix itself rotates LCS vectors back into the GCS.

        patterns : [f(theta, phi)], list of callable
            List of antenna patterns. Each callable is evaluated at the local
            angles, and its two returned components are stacked as the
            (theta, phi) field components.

        k : [..., 3], tf.float
            Direction of departure/arrival in the GCS.
            Must point away from the radio device

        Output
        -------
        fields_hat : [..., num_patterns, 2], tf.complex
            Antenna fields theta_hat and phi_hat components in the GCS

        theta_hat : [..., 3], tf.float
            Theta hat direction in the GCS

        phi_hat : [..., 3], tf.float
            Phi hat direction in the GCS
        """
        # Broadcast the rotation matrix against the directions
        # [..., 3, 3]
        rot_mat = expand_to_rank(rot_mat, tf.rank(k)+1, 0)

        # Global angles and the corresponding spherical unit vectors
        # [...]
        theta, phi = theta_phi_from_unit_vec(k)
        # [..., 3]
        glob_theta_hat = theta_hat(theta, phi)
        glob_phi_hat = phi_hat(phi)

        # Direction in the device LCS (inverse rotation) and its local angles
        # [..., 3]
        k_local = tf.linalg.matvec(rot_mat, k, transpose_a=True)
        # [...]
        theta_local, phi_local = theta_phi_from_unit_vec(k_local)

        # Local spherical unit vectors, rotated into the GCS
        # [..., 3]
        dev_theta_hat = tf.linalg.matvec(rot_mat,
                                         theta_hat(theta_local, phi_local))
        dev_phi_hat = tf.linalg.matvec(rot_mat, phi_hat(phi_local))

        # 2x2 change of basis from the device spherical frame to the global
        # spherical frame, made complex to multiply the complex fields
        # [..., 2, 2]
        basis_change = component_transform(dev_theta_hat, dev_phi_hat,
                                           glob_theta_hat, glob_phi_hat)
        basis_change = tf.complex(basis_change, tf.zeros_like(basis_change))

        # Field of every pattern in the device frame; patterns (e.g.,
        # polarization directions) are stacked along a new axis
        # [..., num_patterns, 2]
        local_fields = tf.stack(
            [tf.stack(pattern(theta_local, phi_local), axis=-1)
             for pattern in patterns],
            axis=-2)

        # Apply the same change of basis to every pattern
        # [..., 1, 2, 2]
        basis_change = tf.expand_dims(basis_change, axis=-3)
        # [..., num_patterns, 2]
        fields_hat = tf.linalg.matvec(basis_change, local_fields)

        return fields_hat, glob_theta_hat, glob_phi_hat
    473 
    def _apply_synthetic_array(self, tx_rot_mat, rx_rot_mat, k_rx,
                               k_tx, a):
        # pylint: disable=line-too-long
        r"""
        Synthetically applies the transmitter and receiver antenna arrays to
        the per-pattern channel coefficients ``a``.

        Every (receive element, transmit element) pair receives the phase shift
        ``2*pi/wavelength * (dot(p_rx, k_rx) + dot(p_tx, k_tx))``, where ``p_rx`` and ``p_tx``
        are the element positions of the receive and transmit arrays rotated by
        ``rx_rot_mat`` and ``tx_rot_mat``. The pattern and element axes are then
        merged into antenna axes.

        Input
        ------
        tx_rot_mat : [..., 3, 3], tf.float
            Rotation matrix built from the orientation of the transmitters

        rx_rot_mat : [3, 3], tf.float
            Rotation matrix built from the orientation of the receivers

        k_rx : [..., 3], tf.float
            Directions of arrival of the rays. The sign convention is set by
            the caller and directly determines the sign of the receive phase
            shifts.

        k_tx : [..., 3], tf.float
            Directions of departure of the rays

        a : [..., num_rx_patterns, num_tx_patterns], tf.complex
            Channel coefficients

        Output
        -------
        a : [..., num_rx_ant, num_tx_ant], tf.complex
            Channel coefficients with the antenna array applied, where
            ``num_rx_ant = num_rx_patterns*rx_array_size`` and
            ``num_tx_ant = num_tx_patterns*tx_array_size``
        """
        two_pi = tf.cast(2.*PI, self._rdtype)

        # TX element positions rotated by each transmitter orientation
        # [..., tx_array_size, 3]
        tx_pos = expand_to_rank(self._scene.tx_array.positions,
                                tf.rank(tx_rot_mat), 0)
        tx_pos = tf.linalg.matvec(tf.expand_dims(tx_rot_mat, axis=-3), tx_pos)

        # RX element positions rotated by the receiver orientation
        # [rx_array_size, 3]
        rx_pos = tf.linalg.matvec(tf.expand_dims(rx_rot_mat, axis=0),
                                  self._scene.rx_array.positions)
        # [..., rx_array_size, 3]
        rx_pos = expand_to_rank(rx_pos, tf.rank(tx_pos), 0)

        # Projection of every element position onto the ray direction
        # [..., tx_array_size]
        tx_proj = dot(tx_pos, tf.expand_dims(k_tx, axis=-2))
        # [..., rx_array_size]
        rx_proj = dot(rx_pos, tf.expand_dims(k_rx, axis=-2))

        # Total phase shift for every (rx element, tx element) pair
        # [..., rx_array_size, tx_array_size]
        phase_shifts = rx_proj[..., :, tf.newaxis] + tx_proj[..., tf.newaxis, :]
        phase_shifts = two_pi*phase_shifts/self._scene.wavelength
        # [..., 1, rx_array_size, 1, tx_array_size]
        phase_shifts = phase_shifts[..., tf.newaxis, :, tf.newaxis, :]

        # [..., num_rx_patterns, rx_array_size, num_tx_patterns, tx_array_size]
        a = a[..., :, tf.newaxis, :, tf.newaxis] \
            * tf.exp(tf.complex(tf.zeros_like(phase_shifts), phase_shifts))

        # Merge the pattern and element axes into antenna axes
        # [..., num_rx_ant, num_tx_patterns, tx_array_size]
        a = flatten_dims(a, 2, len(a.shape)-4)
        # [..., num_rx_ant, num_tx_ant]
        a = flatten_dims(a, 2, len(a.shape)-2)

        return a
    562 
    563     def _update_coverage_map(self, cm_center, cm_size, cm_cell_size, num_cells,
    564                              rot_gcs_2_mp, cm_normal, tx_rot_mat,
    565                              rx_rot_mat, precoding_vec, combining_vec,
    566                              samples_tx_indices, e_field, field_es, field_ep,
    567                              mp_hit_point, hit_mp, k_tx, previous_int_point, cm,
    568                              ris, radii_curv, angular_opening):
    569         r"""
    570         Updates the coverage map with the power of the paths that hit it.
    571 
    572         Input
    573         ------
    574         cm_center : [3], tf.float
    575             Center of the coverage map
    576 
    577         cm_size : [2], tf.float
    578             Scale of the coverage map.
    579             The width of the map (in the local X direction) is ``cm_size[0]``
    580             and its height (in the local Y direction) ``cm_size[1]``.
    581 
    582         cm_cell_size : [2], tf.float
    583             Resolution of the coverage map, i.e., width
    584             (in the local X direction) and height (in the local Y direction) in
    585             meters of a cell of the coverage map
    586 
    587         num_cells : [2], tf.int
    588             Number of cells in the coverage map
    589 
    590         rot_gcs_2_mp : [3, 3], tf.float
    591             Rotation matrix for going from the GCS to the measurement plane LCS
    592 
    593         cm_normal : [3], tf.float
    594             Normal to the measurement plane
    595 
    596         tx_rot_mat : [num_tx, 3, 3], tf.float
    597             Rotation matrix built from the orientation of the transmitters
    598 
    599         rx_rot_mat : [3, 3], tf.float
    600             Rotation matrix built from the orientation of the receivers
    601 
    602         precoding_vec : [num_tx, num_tx_ant], tf.complex
    603             Vector used for transmit-precoding
    604 
    605         combining_vec : [num_rx_ant], tf.complex | None
    606             Vector used for receive-combining.
    607             If set to `None`, then no combining is applied, and
    608             the energy received by all antennas is summed.
    609 
    610         samples_tx_indices : [num_samples], tf.int
    611             Transmitter indices that correspond to every sample, i.e., from
    612             which the ray was shot.
    613 
    614         e_field : [num_samples, num_tx_patterns, 2], tf.complex
    615             Incoming electric field. These are the e_s and e_p components.
    616             The e_s and e_p directions are given thereafter.
    617 
    618         field_es : [num_samples, 3], tf.float
    619             S direction for the incident field
    620 
    621         field_ep : [num_samples, 3], tf.float
    622             P direction for the incident field
    623 
    624         mp_hit_point : [num_samples, 3], tf.float
    625             Positions of the hit points with the measurement plane.
    626 
    627         hit_mp : [num_samples], tf.bool
    628             Set to `True` for samples that hit the measurement plane
    629 
    630         k_tx : [num_samples, 3], tf.float
    631             Direction of departure from the transmitters
    632 
    633         previous_int_point : [num_samples, 3], tf.float
    634             Position of the previous interaction with the scene
    635 
    636         cm : [num_tx, num_cells_y+1, num_cells_x+1], tf.float
    637             Coverage map
    638 
    639         ris : bool
    640             Set to `True` if RIS are enabled
    641 
    642         radii_curv : [num_active_samples, 2], tf.float
    643             Principal radii of curvature
    644 
    645         angular_opening : [num_active_samples], tf.float
    646             Angular opening of the ray tube
    647 
    648         Output
    649         -------
    650         cm : [num_tx, num_cells_y+1, num_cells_x+1], tf.float
    651             Updated coverage map
    652         """
    653 
    654         # Extract the samples that hit the coverage map.
    655         # This is to avoid computing the channel coefficients for all the
    656         # samples.
    657         # Indices of the samples that hit the coverage map
    658         # [num_hits]
    659         hit_mp_ind = tf.where(hit_mp)[:,0]
    660         # Indices of the transmitters corresponding to the rays that hit
    661         # [num_hits]
    662         hit_mp_tx_ind = tf.gather(samples_tx_indices, hit_mp_ind)
    663         # the coverage map
    664         # [num_hits, 3]
    665         mp_hit_point = tf.gather(mp_hit_point, hit_mp_ind, axis=0)
    666         # [num_hits, 3]
    667         previous_int_point = tf.gather(previous_int_point, hit_mp_ind, axis=0)
    668         # [num_hits, 3]
    669         k_tx = tf.gather(k_tx, hit_mp_ind, axis=0)
    670         # [num_hits, 3]
    671         precoding_vec = tf.gather(precoding_vec, hit_mp_tx_ind, axis=0)
    672         # [num_hits, 3, 3]
    673         tx_rot_mat = tf.gather(tx_rot_mat, hit_mp_tx_ind, axis=0)
    674         # [num_hits, num_tx_patterns, 2]
    675         e_field = tf.gather(e_field, hit_mp_ind, axis=0)
    676         # [num_hits, 3]
    677         field_es = tf.gather(field_es, hit_mp_ind, axis=0)
    678         # [num_hits, 3]
    679         field_ep = tf.gather(field_ep, hit_mp_ind, axis=0)
    680         if ris:
    681             # [num_hits, 2]
    682             radii_curv = tf.gather(radii_curv, hit_mp_ind, axis=0)
    683             # [num_hits]
    684             angular_opening = tf.gather(angular_opening, hit_mp_ind, axis=0)
    685 
    686         # Cell indices
    687         # [num_hits, 2]
    688         hit_cells = self._mp_hit_point_2_cell_ind(rot_gcs_2_mp, cm_center,
    689                                                   cm_size, cm_cell_size,
    690                                                   num_cells, mp_hit_point)
    691         # Receive direction
    692         # k_rx : [num_hits, 3]
    693         # length : [num_hits]
    694         k_rx,length = normalize(mp_hit_point - previous_int_point)
    695 
    696         if ris:
    697             # Apply spreading factor
    698             # [num_active_samples]
    699             sf = compute_spreading_factor(radii_curv[:,0], radii_curv[:,1],
    700                                           length)
    701             # [num_active_samples, 1, 1]
    702             sf = expand_to_rank(sf, tf.rank(e_field), -1)
    703             sf = tf.complex(sf, tf.zeros_like(sf))
    704             # [num_active_samples, num_tx_patterns, 2]
    705             e_field *= sf
    706 
    707         # Compute the receive field in the GCS
    708         # rx_field : [num_hits, num_rx_patterns, 2]
    709         # rx_es_hat, rx_ep_hat : [num_hits, 3]
    710         rx_field, rx_es_hat, rx_ep_hat = self._compute_antenna_patterns(
    711             rx_rot_mat, self._scene.rx_array.antenna.patterns, -k_rx)
    712         # Move the incident field to the receiver basis
    713         # Change of basis of the field
    714         # [num_hits, 2, 2]
    715         to_rx_mat = component_transform(field_es, field_ep,
    716                                         rx_es_hat, rx_ep_hat)
    717         # [num_hits, 1, 2, 2]
    718         to_rx_mat = tf.expand_dims(to_rx_mat, axis=1)
    719         to_rx_mat = tf.complex(to_rx_mat, tf.zeros_like(to_rx_mat))
    720         # [num_hits, num_tx_patterns, 2]
    721         e_field = tf.linalg.matvec(to_rx_mat, e_field)
    722         # Apply the receiver antenna field to compute the channel coefficient
    723         # [num_hits num_rx_patterns, 1, 2]
    724         rx_field = tf.expand_dims(rx_field, axis=2)
    725         # [num_hits, 1, num_tx_patterns, 2]
    726         e_field = tf.expand_dims(e_field, axis=1)
    727 
    728         # [num_hits, num_rx_patterns, num_tx_patterns]
    729         a = tf.reduce_sum(tf.math.conj(rx_field)*e_field, axis=-1)
    730 
    731         # Apply synthetic array
    732         # [num_hits, num_rx_ant, num_tx_ant]
    733         a = self._apply_synthetic_array(tx_rot_mat, rx_rot_mat,
    734                                         k_rx, k_tx, a)
    735 
    736         # Apply precoding
    737         # [num_hits, 1, num_tx_ant]
    738         precoding_vec = tf.expand_dims(precoding_vec, 1)
    739         # [num_hits, num_rx_ant]
    740         a = tf.reduce_sum(a*precoding_vec, axis=-1)
    741         # Apply combining
    742         # If no combining vector is provided, then sum the energy received by
    743         # the antennas
    744         if combining_vec is None:
    745             # [num_hits]
    746             a = tf.reduce_sum(tf.square(tf.abs(a)), axis=-1)
    747         else:
    748             # [1, num_rx_ant]
    749             combining_vec = tf.expand_dims(combining_vec, 0)
    750             # [num_hits]
    751             a = tf.reduce_sum(tf.math.conj(combining_vec)*a, axis=-1)
    752             # Compute the amplitude of the path
    753             # [num_hits]
    754             a = tf.square(tf.abs(a))
    755 
    756         # Add the rays contribution to the coverage map
    757         # We just divide by cos(aoa) instead of dividing by the square distance
    758         # to apply the propagation loss, to then multiply by the square distance
    759         # over cos(aoa) to compute the ray weight.
    760         # Ray weighting
    761         # Cosine of the angle of arrival with respect to the normal of
    762         # the plan
    763         # [num_hits]
    764         cos_aoa = tf.abs(dot(k_rx, cm_normal, clip=True))
    765 
    766         if ris:
    767             # Radii of curvature at the interaction point with the measurement
    768             # plane
    769             # [num_hits, 2]
    770             radii_curv += tf.expand_dims(length, axis=1)
    771             # [num_hits]
    772             ray_weights = tf.math.divide_no_nan(radii_curv[:,0]*radii_curv[:,1],
    773                                                 cos_aoa)
    774             ray_weights *= angular_opening
    775         else:
    776             # [num_hits]
    777             ray_weights = tf.math.divide_no_nan(tf.ones_like(cos_aoa), cos_aoa)
    778 
    779         # Add the contribution to the coverage map
    780         # [num_hits, 3]
    781         hit_cells = tf.concat([tf.expand_dims(hit_mp_tx_ind, axis=-1),
    782                                 hit_cells], axis=-1)
    783         # [num_tx, num_cells_y+1, num_cells_x+1]
    784         cm = tf.tensor_scatter_nd_add(cm, hit_cells, ray_weights*a)
    785 
    786         return cm
    787 
    def _compute_reflected_field(self, normals, etas, scattering_coefficient,
            k_i, e_field, field_es, field_ep, scattering, ris, length,
            radii_curv, dirs_curv):
        r"""
        Computes the field specularly reflected at the intersection points.

        The incoming field is re-expressed in the S/P basis of the plane of
        incidence and weighted by the Fresnel reflection coefficients. These
        coefficients are scaled by the reduction factor
        :math:`\sqrt{1-S^2}` (with :math:`S` the scattering coefficient) to
        account for the energy diverted to diffuse scattering. When RIS are
        enabled, the spreading factor of the incoming ray tube is also applied,
        and the principal radii and directions of curvature of the reflected
        ray tube are computed.

        Input
        ------
        normals : [num_active_samples, 3], tf.float
            Normals to the intersected primitives

        etas : [num_active_samples], tf.complex
            Relative permittivities of the intersected primitives

        scattering_coefficient : [num_active_samples], tf.float
            Scattering coefficients of the intersected primitives

        k_i : [num_active_samples, 3], tf.float
            Direction of arrival of the ray

        e_field : [num_active_samples, num_tx_patterns, 2], tf.complex
            S and P components of the incident field

        field_es : [num_active_samples, 3], tf.float
            Direction of the S component of the incident field

        field_ep : [num_active_samples, 3], tf.float
            Direction of the P component of the incident field

        scattering : bool
            Set to `True` if scattering is enabled. In that case, rays are
            stochastically split between reflection and scattering, and the
            reduction factor only contributes its gradient.

        ris : bool
            Set to `True` if RIS are enabled

        length : [num_active_samples], tf.float
            Length of the last path segment

        radii_curv : [num_active_samples, 2], tf.float
            Principal radii of curvature of the incoming ray tube

        dirs_curv : [num_active_samples, 2, 3], tf.float
            Principal directions of curvature of the incoming ray tube

        Output
        -------
        e_field : [num_active_samples, num_tx_patterns, 2], tf.complex
            S and P components of the reflected field

        field_es : [num_active_samples, 3], tf.float
            Direction of the S component of the reflected field

        field_ep : [num_active_samples, 3], tf.float
            Direction of the P component of the reflected field

        k_r : [num_active_samples, 3], tf.float
            Direction of the reflected ray

        radii_curv : [num_active_samples, 2], tf.float | `None`
            Principal radii of curvature of the reflected ray tube.
            `None` if ``ris`` is `False`.

        dirs_curv : [num_active_samples, 2, 3], tf.float | `None`
            Principal directions of curvature of the reflected ray tube.
            `None` if ``ris`` is `False`.
        """
        def _to_complex(x):
            # Promote a real tensor to a complex one with zero imaginary part
            return tf.complex(x, tf.zeros_like(x))

        # Specular direction: mirror the incident direction about the normal
        # [num_active_samples, 3]
        k_r = k_i - 2.*dot(k_i, normals, keepdim=True, clip=True)*normals

        # Unit vectors of the S/P components before and after reflection
        # [num_active_samples, 3] each
        # pylint: disable=unbalanced-tuple-unpacking
        e_i_s, e_i_p, e_r_s, e_r_p = compute_field_unit_vectors(
            k_i, k_r, normals, SolverBase.EPSILON)

        # Change of basis from the current S/P directions to the incident ones
        # [num_active_samples, 1, 2, 2]
        basis_change = _to_complex(tf.expand_dims(
            component_transform(field_es, field_ep, e_i_s, e_i_p), axis=1))
        # [num_active_samples, num_tx_patterns, 2]
        e_field = tf.linalg.matvec(basis_change, e_field)

        # Fresnel coefficients, driven by the cosine of the incidence angle
        # [num_active_samples]
        r_s, r_p = reflection_coefficient(etas, -dot(k_i, normals, clip=True))

        # Reduction factor sqrt(1 - S^2)
        # [num_active_samples]
        reduction = _to_complex(tf.sqrt(1. - tf.square(scattering_coefficient)))
        if scattering:
            # Rays are randomly allocated to reflection or scattering, and the
            # selection probability for reflection is the reduction factor.
            # This ratio is one wherever the reduction factor is non-zero
            # (zero otherwise). It only carries the gradient with respect to
            # the scattering coefficient.
            weight = tf.math.divide_no_nan(reduction,
                                           tf.stop_gradient(reduction))
        else:
            # Every ray is reflected, so the reduction factor must be applied
            # explicitly to its contribution
            weight = reduction

        # Weighted Fresnel coefficients, broadcast over the tx patterns
        # [num_active_samples, 1, 2]
        coefficients = tf.expand_dims(tf.stack([r_s*weight, r_p*weight], -1),
                                      axis=-2)
        # [num_active_samples, num_tx_patterns, 2]
        e_field *= coefficients

        if not ris:
            return e_field, e_r_s, e_r_p, k_r, None, None

        # Spreading factor of the incoming ray tube along the last segment
        # [num_active_samples, 1, 1]
        spreading = expand_to_rank(
            compute_spreading_factor(radii_curv[:,0], radii_curv[:,1], length),
            tf.rank(e_field), -1)
        # [num_active_samples, num_tx_patterns, 2]
        e_field *= _to_complex(spreading)

        # Principal radii of curvature at the intersection point
        # [num_active_samples, 2]
        radii_at_hit = radii_curv + tf.expand_dims(length, axis=1)
        # Corresponding principal curvatures (inverse radii)
        # [num_active_samples, 2]
        curvatures = tf.math.reciprocal_no_nan(radii_at_hit)
        # [num_active_samples]
        curv_sum = curvatures[:,0] + curvatures[:,1]
        curv_gap = tf.abs(curvatures[:,0] - curvatures[:,1])
        # Principal radii of curvature of the reflected ray tube
        # [num_active_samples, 2]
        new_radii_curv = tf.math.reciprocal_no_nan(tf.stack(
            [0.5*(curv_sum + curv_gap), 0.5*(curv_sum - curv_gap)], axis=1))

        # Principal directions of curvature of the reflected ray tube. The
        # first one is mirrored about the normal. The second one is obtained
        # from its cross product with the reflected direction.
        # [num_active_samples, 3]
        first_dir = dirs_curv[:,0] \
            - 2.*dot(dirs_curv[:,0], normals, keepdim=True, clip=True)*normals
        second_dir = -cross(k_r, first_dir)
        # [num_active_samples, 2, 3]
        new_dirs_curv = tf.stack([first_dir, second_dir], axis=1)

        return e_field, e_r_s, e_r_p, k_r, new_radii_curv, new_dirs_curv
    960 
    def _compute_scattered_field(self, int_point, objects, normals, etas,
            scattering_coefficient, xpd_coefficient, alpha_r, alpha_i, lambda_,
            k_i, e_field, field_es, field_ep, reflection, ris, length,
            radii_curv, angular_opening):
        r"""
        Computes the diffusely scattered field at the intersections.

        For every active sample, a random scattering direction is drawn on the
        hemisphere around the surface normal. The amplitude of the scattered
        field is the amplitude of the field that would be specularly reflected
        (Fresnel coefficients), weighted by the square root of the scattering
        pattern. Its polarization is randomized with uniformly distributed
        phases, and it is split between the co- and cross-polarized
        components according to the XPD coefficient.

        Input
        ------
        int_point : [num_active_samples, 3], tf.float
            Positions at which the rays intersect with the scene

        objects : [num_active_samples], tf.int
            Indices of the intersected objects

        normals : [num_active_samples, 3], tf.float
            Normals to the intersected primitives

        etas : [num_active_samples], tf.complex
            Relative permittivities of the intersected primitives

        scattering_coefficient : [num_active_samples], tf.float
            Scattering coefficients of the intersected primitives

        xpd_coefficient : [num_active_samples], tf.float
            Tensor containing the cross-polarization discrimination
            coefficients of all shapes

        alpha_r : [num_active_samples], tf.float
            Tensor containing the alpha_r scattering parameters of all shapes

        alpha_i : [num_active_samples], tf.float
            Tensor containing the alpha_i scattering parameters of all shapes

        lambda_ : [num_shape], tf.float
            Tensor containing the lambda_ scattering parameters of all shapes.
            NOTE(review): it is passed to ``ScatteringPattern.pattern``
            together with the per-sample ``alpha_r`` and ``alpha_i``, so it
            is presumably gathered per sample as well -- confirm the shape.

        k_i : [num_active_samples, 3], tf.float
            Direction of arrival of the ray

        e_field : [num_active_samples, num_tx_patterns, 2], tf.complex
            S and P components of the incident field

        field_es : [num_active_samples, 3], tf.float
            Direction of the S component of the incident field

        field_ep : [num_active_samples, 3], tf.float
            Direction of the P component of the incident field

        reflection : bool
            Set to `True` if reflection is enabled

        ris : bool
            Set to `True` if RIS is enabled

        length : [num_active_samples], tf.float
            Length of the last path segment

        radii_curv : [num_active_samples, 2], tf.float
            Principal radii of curvature

        angular_opening : [num_active_samples], tf.float
            Angular opening

        Output
        -------
        e_field : [num_active_samples, num_tx_patterns, 2], tf.complex
            S and P components of the scattered field

        field_es : [num_active_samples, 3], tf.float
            Direction of the S component of the scattered field

        field_ep : [num_active_samples, 3], tf.float
            Direction of the P component of the scattered field

        k_s : [num_active_samples, 3], tf.float
            Direction of the scattered ray

        radii_curv : [num_active_samples, 2], tf.float | `None`
            Principal radii of curvature of the scattered ray tube.
            `None` if ``ris`` is `False`.

        dirs_curv : [num_active_samples, 2, 3], tf.float | `None`
            Principal direction of curvature of the scattered ray tube.
            `None` if ``ris`` is `False`.

        angular_opening : [num_active_samples], tf.float | `None`
            Angular opening of the scattered field.
            `None` if ``ris`` is `False`.
        """

        if ris:
            # Compute and apply the spreading factor to the incident field
            # [num_active_samples]
            sf = compute_spreading_factor(radii_curv[:,0], radii_curv[:,1],
                                          length)
            # [num_active_samples, 1, 1]
            sf = expand_to_rank(sf, tf.rank(e_field), -1)
            sf = tf.complex(sf, tf.zeros_like(sf))
            # [num_active_samples, num_tx_patterns, 2]
            e_field *= sf

        # Represent the incident field in the S/P basis used for reflection
        # [num_active_samples, 3]
        e_i_s, e_i_p = compute_field_unit_vectors(k_i, None,
                                            normals, SolverBase.EPSILON,
                                            return_e_r=False)

        # [num_active_samples, 2, 2]
        to_incident = component_transform(field_es, field_ep,
                                        e_i_s, e_i_p)
        # [num_active_samples, 1, 2, 2]
        to_incident = tf.expand_dims(to_incident, axis=1)
        to_incident = tf.complex(to_incident, tf.zeros_like(to_incident))
        # [num_active_samples, num_tx_patterns, 2]
        e_field_ref = tf.linalg.matvec(to_incident, e_field)

        # Compute Fresnel reflection coefficients
        # [num_active_samples]
        cos_theta = -dot(k_i, normals, clip=True)

        # [num_active_samples]
        r_s, r_p = reflection_coefficient(etas, cos_theta)

        # [num_active_samples, 2]
        r = tf.stack([r_s, r_p], axis=-1)
        # [num_active_samples, 1, 2]
        r = tf.expand_dims(r, axis=-2)

        # Amplitude of the field that would be specularly reflected. It sets
        # the amplitude of the scattered field.
        # [num_active_samples, num_tx_patterns]
        ref_amp = tf.sqrt(tf.reduce_sum(tf.abs(r*e_field_ref)**2, axis=-1))

        # Split the incoming field into its S and P components
        # [num_active_samples, num_tx_patterns, 1]
        e_field_s, e_field_p = tf.split(e_field, 2, axis=-1)

        # [num_active_samples, 1, 3]
        field_es = tf.expand_dims(field_es, axis=1)
        field_es = tf.complex(field_es, tf.zeros_like(field_es))
        field_ep = tf.expand_dims(field_ep, axis=1)
        field_ep = tf.complex(field_ep, tf.zeros_like(field_ep))

        # Incoming field vector in the GCS
        # [num_active_samples, num_tx_patterns, 3]
        e_in = e_field_s*field_es + e_field_p*field_ep

        # Co- and cross-polarization unit vectors. They are built from the
        # real part of the incoming field vector.
        # [num_active_samples, num_tx_patterns, 3]
        e_pol_hat, _ = normalize(tf.math.real(e_in))
        e_xpol_hat = cross(e_pol_hat, tf.expand_dims(k_i, 1))

        # Compute incoming spherical unit vectors in GCS
        theta_i, phi_i = theta_phi_from_unit_vec(-k_i)
        # [num_active_samples, 1, 3]
        theta_hat_i = tf.expand_dims(theta_hat(theta_i, phi_i), axis=1)
        phi_hat_i = tf.expand_dims(phi_hat(phi_i), axis=1)

        # Transformation from e_pol_hat, e_xpol_hat to theta_hat_i,phi_hat_i
        # [num_active_samples, num_tx_patterns, 2, 2]
        trans_mat = component_transform(e_pol_hat, e_xpol_hat,
                                        theta_hat_i, phi_hat_i)
        trans_mat = tf.complex(trans_mat, tf.zeros_like(trans_mat))

        # Generate random phases
        # All tx_patterns get the same phases
        num_active_samples = tf.shape(e_field)[0]
        phase_shape = [num_active_samples, 1, 2]
        # [num_active_samples, 1, 2]
        phases = config.tf_rng.uniform(phase_shape, maxval=2*PI,
                                       dtype=self._rdtype)

        # XPD weighting: sqrt(1 - xpd) on the first (e_pol_hat) component and
        # sqrt(xpd) on the second (e_xpol_hat) component
        # [num_active_samples, 2]
        xpd_weights = tf.stack([tf.sqrt(1-xpd_coefficient),
                                tf.sqrt(xpd_coefficient)],
                               axis=-1)
        xpd_weights = tf.complex(xpd_weights, tf.zeros_like(xpd_weights))
        # [num_active_samples, 1, 2]
        xpd_weights = tf.expand_dims(xpd_weights, axis=1)

        # Create scattered field components from phases and xpd_weights
        # [num_active_samples, 1, 2]
        e_field = tf.exp(tf.complex(tf.zeros_like(phases), phases))
        e_field *= xpd_weights

        # Apply transformation to field vector
        # [num_active_samples, num_tx_patterns, 2]
        e_field = tf.linalg.matvec(trans_mat, e_field)

        # Draw random directions for scattered paths
        # [num_active_samples, 3]
        k_s = sample_points_on_hemisphere(normals)

        # Evaluate the scattering pattern for all samples.
        # If a callable is defined to compute the scattering pattern,
        # it is invoked. Otherwise, the radio materials of objects are used.
        sp_callable = self._scene.scattering_pattern_callable
        if sp_callable is None:
            # [num_active_samples]
            f_s = ScatteringPattern.pattern(k_i,
                                            k_s,
                                            normals,
                                            alpha_r,
                                            alpha_i,
                                            lambda_)
        else:
            # One value per active sample is expected, as it is reshaped to
            # [num_active_samples, 1, 1] below
            # [num_active_samples]
            f_s = sp_callable(objects,
                              int_point,
                              k_i,
                              k_s,
                              normals)

        # Compute scaled scattered field
        # [num_active_samples, num_tx_patterns, 1]
        ref_amp = tf.expand_dims(ref_amp, -1)
        # [num_active_samples, num_tx_patterns, 2]
        e_field *= tf.complex(ref_amp, tf.zeros_like(ref_amp))
        # [num_active_samples, 1, 1]
        f_s = tf.reshape(tf.sqrt(f_s), [-1, 1, 1])
        e_field *= tf.complex(f_s, tf.zeros_like(f_s))

        if ris:
            # Weight due to angular domain
            # [num_active_samples, 2]
            radii_curv += tf.expand_dims(length, axis=1)
            # [num_active_samples, 1, 1]
            w = angular_opening*radii_curv[:,0]*radii_curv[:,1]
            w = expand_to_rank(w, tf.rank(e_field))
            # [num_active_samples, num_tx_patterns, 2]
            e_field *= tf.cast(tf.sqrt(w), self._dtype)
        else:
            # Evaluate sqrt(2*pi) in the solver's real precision.
            # tf.sqrt(2*PI) on a Python float would always be evaluated in
            # float32, which loses precision when the solver runs in
            # complex128.
            two_pi = tf.cast(2.*PI, self._rdtype)
            e_field *= tf.cast(tf.sqrt(two_pi), self._dtype)

        # If reflection is enabled, then the rays are randomly
        # allocated to reflection or scattering by sampling according to the
        # scattering coefficient. An oversampling factor is applied to keep
        # differentiability with respect to the scattering coefficient.
        # This oversampling factor is the ratio between the scattering factor
        # and the (non-differentiable) probability with which the
        # scattering phenomenon is selected. In our case, this probability is
        # the scattering factor.
        # If reflection is disabled, all samples are allocated to scattering to
        # maximize sample-efficiency. However, this requires correcting the
        # contribution of the scattered rays by applying the scattering factor.
        # [num_active_samples]
        scattering_factor = tf.complex(scattering_coefficient,
                                       tf.zeros_like(scattering_coefficient))
        # [num_active_samples, 1, 1]
        scattering_factor = tf.reshape(scattering_factor, [-1, 1, 1])
        if reflection:
            # One wherever the scattering factor is non-zero (zero otherwise)
            # [num_active_samples, 1, 1]
            ovs_factor = tf.math.divide_no_nan(scattering_factor,
                                            tf.stop_gradient(scattering_factor))
            # [num_active_samples, num_tx_patterns, 2]
            e_field *= ovs_factor
        else:
            # [num_active_samples, num_tx_patterns, 2]
            e_field *= scattering_factor

        # Compute outgoing spherical unit vectors in GCS
        theta_s, phi_s = theta_phi_from_unit_vec(k_s)
        # [num_active_samples, 3]
        field_es = theta_hat(theta_s, phi_s)
        field_ep = phi_hat(phi_s)

        if ris:
            # Radii of curvature are reset to zero for the scattered ray tube
            # [num_active_samples, 2]
            new_radii_curv = tf.zeros_like(radii_curv)

            # Update the principal direction of curvature. The basis is
            # generated from k_s (presumably orthogonal to it -- see
            # gen_basis_from_z).
            # [num_active_samples, 3]
            new_dir_curv_1, new_dir_curv_2 = gen_basis_from_z(k_s,
                                                            SolverBase.EPSILON)
            # [num_active_samples, 2, 3]
            new_dirs_curv = tf.stack([new_dir_curv_1, new_dir_curv_2], axis=1)

            # New angular opening, set to 2*pi for all samples
            # [num_active_samples]
            new_angular_opening = tf.fill(tf.shape(angular_opening),
                                        tf.cast(2.*PI, self._rdtype))
        else:
            new_radii_curv = None
            new_dirs_curv = None
            new_angular_opening = None

        return e_field, field_es, field_ep, k_s, new_radii_curv, new_dirs_curv,\
            new_angular_opening
   1244 
   1245     def _compute_ris_reflected_field(self, int_point, ris_ind, k_i, e_field,
   1246                             field_es, field_ep, length, radii_curv, dirs_curv):
   1247         r"""
   1248         Computes the field reflected by the RIS at the intersections.
   1249 
   1250         Input
   1251         ------
   1252         int_point : [num_active_samples, 3], tf.float
   1253             Positions at which the rays intersect with the RIS
   1254 
   1255         ris_ind : [num_active_samples], tf.int
   1256             Indices of the intersected RIS
   1257 
   1258         k_i : [num_active_samples, 3], tf.float
   1259             Direction of arrival of the ray
   1260 
   1261         e_field : [num_active_samples, num_tx_patterns, 2], tf.complex
   1262             S and P components of the incident field
   1263 
   1264         field_es : [num_active_samples, 3], tf.float
   1265             Direction of the S component of the incident field
   1266 
   1267         field_ep : [num_active_samples, 3], tf.float
   1268             Direction of the P component of the incident field
   1269 
   1270         length : [num_active_samples], tf.float
   1271             Length of the last path segment
   1272 
   1273         radii_curv : [num_active_samples, 2], tf.float
   1274             Principal radii of curvature of the incident ray tube
   1275 
   1276         dirs_curv : [num_active_samples, 2, 3], tf.float
   1277             Principal direction of curvature of the incident ray tube
   1278 
   1279         Output
   1280         -------
   1281         e_field : [num_active_samples, num_tx_patterns, 2], tf.complex
   1282             S and P components of the reflected field
   1283 
   1284         field_es : [num_active_samples, 3], tf.float
   1285             Direction of the S component of the reflected field
   1286 
   1287         field_ep : [num_active_samples, 3], tf.float
   1288             Direction of the P component of the reflected field
   1289 
   1290         k_s : [num_active_samples, 3], tf.float
   1291             Direction of the reflected ray
   1292 
   1293         normals : [num_active_samples, 3], tf.float
   1294             Normals to the intersected RIS
   1295 
   1296         radii_curv : [num_active_samples, 2], tf.float
   1297             Principal radii of curvature of the reflected ray tube
   1298 
   1299         dirs_curv : [num_active_samples, 2, 3], tf.float
   1300             Principal direction of curvature of the reflected ray tube
   1301         """
   1302         # Compute and apply the spreading factor
   1303         # [num_active_samples]
   1304         sf = compute_spreading_factor(radii_curv[:,0], radii_curv[:,1], length)
   1305         # [num_active_samples, 1, 1]
   1306         sf = expand_to_rank(sf, tf.rank(e_field), -1)
   1307         sf = tf.complex(sf, tf.zeros_like(sf))
   1308         # [num_active_samples, num_tx_patterns, 2]
   1309         e_field *= sf
   1310         # Update radii of curvature
   1311         # [num_active_samples, 2]
   1312         radii_curv += tf.expand_dims(length, axis=1)
   1313 
   1314         all_int_point = int_point
   1315         all_k_i = k_i
   1316         all_e_field = e_field
   1317         all_field_es = field_es
   1318         all_field_ep = field_ep
   1319         all_radii_curv = radii_curv
   1320         all_dirs_curv = dirs_curv
   1321 
   1322         # Outputs
   1323         output_e_field = tf.zeros([0, e_field.shape[1], 2], self._dtype)
   1324         output_field_es = tf.zeros([0, 3], self._rdtype)
   1325         output_field_ep = tf.zeros([0, 3], self._rdtype)
   1326         output_k_r = tf.zeros([0, 3], self._rdtype)
   1327         output_radii_curv = tf.zeros([0, 2], self._rdtype)
   1328         output_dirs_curv = tf.zeros([0, 2, 3], self._rdtype)
   1329         output_normals = tf.zeros([0, 3], self._rdtype)
   1330 
   1331         # Iterate over the RIS
   1332         for ris in self._scene.ris.values():
   1333 
   1334             # Get ID of this RIS
   1335             this_ris_id = ris.object_id
   1336 
   1337             # Get normal of this RIS
   1338             # [3]
   1339             normal = ris.world_normal
   1340             # [1,3]
   1341             normal = tf.expand_dims(normal, axis=0)
   1342 
   1343             # Indices of rays hitting this RIS
   1344             # [num_active_samples]
   1345             this_ris_sample_ind = tf.where(tf.equal(ris_ind, this_ris_id))[:,0]
   1346             num_active_samples = tf.shape(this_ris_sample_ind)[0]
   1347 
   1348             # Gather incident ray directions for this RIS
   1349             # [num_active_samples, 3]
   1350             k_i = tf.gather(all_k_i, this_ris_sample_ind, axis=0)
   1351 
   1352             # Boolean indicating the RIS side
   1353             # True means it's the front, False means it's the back.
   1354             # [num_active_samples]
   1355             hit_front = -tf.math.sign(dot(k_i, normal))
   1356             hit_front = tf.greater(hit_front, 0.0)
   1357 
   1358             # Gather indices of rays that hit this RIS from the front
   1359             this_ris_sample_ind = tf.gather(this_ris_sample_ind,
   1360                                             tf.where(hit_front)[:,0])
   1361             # Number of samples corresponding to this RIS
   1362             num_ris_sample = tf.shape(this_ris_sample_ind)[0]
   1363 
   1364             # Extract data relevant to this RIS
   1365             # [this_ris_num_samples, 3]
   1366             int_point = tf.gather(all_int_point, this_ris_sample_ind, axis=0)
   1367             # [this_ris_num_samples, 3]
   1368             k_i = tf.gather(all_k_i, this_ris_sample_ind, axis=0)
   1369             # [this_ris_num_samples, num_tx_patterns, 2]
   1370             e_field = tf.gather(all_e_field, this_ris_sample_ind, axis=0)
   1371             # [this_ris_num_samples, 3]
   1372             field_es = tf.gather(all_field_es, this_ris_sample_ind, axis=0)
   1373             # [this_ris_num_samples, 3]
   1374             field_ep = tf.gather(all_field_ep, this_ris_sample_ind, axis=0)
   1375             # [this_ris_num_samples, 2]
   1376             radii_curv = tf.gather(all_radii_curv, this_ris_sample_ind, axis=0)
   1377             # [this_ris_num_samples, 2, 3]
   1378             dirs_curv = tf.gather(all_dirs_curv, this_ris_sample_ind, axis=0)
   1379 
   1380             # Number of rays hitting the RIS from the front
   1381             this_ris_num_samples = tf.shape(k_i)[0]
   1382 
   1383             # Incidence phase gradient - Eq.(9)
   1384             # [this_ris_num_samples, 3]
   1385             grad_i = k_i-normal*dot(normal, k_i)[:,tf.newaxis]
   1386             grad_i *= -self._scene.wavenumber
   1387 
   1388             # Transform interaction points to LCS of the corresponding RIS
   1389             # Store the rotation matrix for later
   1390             # [1, 3, 3]
   1391             rot_mat = rotation_matrix(ris.orientation)[tf.newaxis]
   1392             # [this_ris_num_samples, 3]
   1393             int_point_lcs = int_point - ris.position[tf.newaxis]
   1394             int_point_lcs = tf.linalg.matvec(rot_mat,
   1395                                             int_point_lcs,
   1396                                             transpose_a=True)
   1397 
   1398             # As the LCS assumes x=0, we can remove the first dimension
   1399             # [this_ris_num_samples, 2]
   1400             int_point_lcs = int_point_lcs[:,1:]
   1401 
   1402             # Compute spatial modulation coefficient for all reradiation modes
   1403             # gamma_m: [num_modes, this_ris_num_samples]
   1404             # grad_m: [num_modes, this_ris_num_samples, 3]
   1405             # hessian_m: [num_modes, this_ris_num_samples, 3, 3]
   1406             gamma_m, grad_m, hessian_m = ris(int_point_lcs, return_grads=True)
   1407             # Sample a single mode for each ray
   1408             # [this_ris_num_samples]
   1409             mode_powers = ris.amplitude_profile.mode_powers
   1410             mode = tf.random.categorical(logits=[tf.math.log(mode_powers)],
   1411                                  num_samples=this_ris_num_samples,
   1412                                  dtype=tf.int32)[0]
   1413             # gamma_m: [this_ris_num_samples]
   1414             # grad_m: [this_ris_num_samples, 3]
   1415             # hessian_m: [this_ris_num_samples, 3, 3]
   1416             gamma_m = tf.gather(tf.transpose(gamma_m, perm=[1,0]),
   1417                                 mode, batch_dims=1)
   1418             grad_m = tf.gather(tf.transpose(grad_m, perm=[1, 0, 2]),
   1419                                mode, batch_dims=1)
   1420             hessian_m = tf.gather(tf.transpose(hessian_m, perm=[1, 0, 2, 3]),
   1421                                   mode, batch_dims=1)
   1422             # Bring RIS phase gradient to GCS
   1423             # [this_ris_num_samples, 3]
   1424             grad_m = tf.linalg.matvec(rot_mat, grad_m)
   1425 
   1426             # Bring RIS phase Hessian to GCS
   1427             # [this_ris_num_samples, 3, 3]
   1428             hessian_m = tf.matmul(rot_mat,
   1429                                   tf.matmul(hessian_m,
   1430                                             rot_mat, transpose_b=True))
   1431 
   1432 
   1433             # Compute total phase gradient - Eq.(11)
   1434             # [this_ris_num_samples, 3]
   1435             grad = grad_i + grad_m
   1436 
   1437             # Compute direction of reflected ray - Eq.(13)
   1438             # [this_ris_num_samples, 3]
   1439             k_r = -grad/self._scene.wavenumber
   1440             k_r += tf.sqrt(1 - tf.reduce_sum(k_r**2, axis=-1,
   1441                                              keepdims=True)) * normal
   1442             # Compute linear transformation operator - Eq.(22)
   1443             # [this_ris_num_samples, 3, 3]
   1444             l = -outer(k_r, normal)
   1445             l /= tf.reduce_sum(k_r*normal, axis=-1,
   1446                                keepdims=True)[...,tf.newaxis]
   1447             l += tf.eye(3, batch_shape=tf.shape(l)[:1], dtype=l.dtype)
   1448 
   1449             # Compute incident curvature matrix - Eq.(4)
   1450             # [this_ris_num_samples, 3, 3]
   1451             q_i = 1/expand_to_rank(radii_curv[:,0], 3, -1) * \
   1452                    outer(dirs_curv[:,0], dirs_curv[:,0])
   1453             q_i += 1/expand_to_rank(radii_curv[:,1], 3, -1) * \
   1454                    outer(dirs_curv[:,1], dirs_curv[:,1])
   1455 
   1456             # Compute reflected curvature matrix - Eq.(21)
   1457             # [this_ris_num_samples, 3, 3]
   1458             q_r = tf.matmul(q_i - 1/self._scene.wavenumber*hessian_m, l)
   1459             q_r = tf.matmul(l, q_r, transpose_a=True)
   1460 
   1461             # Extract principal axes of curvature and associated radii - Eq.(4)
   1462             e, v,_ = tf.linalg.svd(q_r)
   1463             # [this_ris_num_samples, 2]
   1464             radii_curv = 1/e[:,:2]
   1465             # [this_ris_num_samples, 2, 3]
   1466             dirs_curv = tf.transpose(v[...,:2], perm=[0, 2, 1])
   1467 
   1468             # Basis vectors for incoming field
   1469             # [this_ris_num_samples, 3]
   1470             theta_i, phi_i = theta_phi_from_unit_vec(k_i)
   1471             e_i_s = theta_hat(theta_i, phi_i)
   1472             e_i_p = phi_hat(phi_i)
   1473 
   1474             # Component transform
   1475             # [this_ris_num_samples, 1, 2, 2]
   1476             mat_comp = component_transform(field_es, field_ep, e_i_s, e_i_p)
   1477             mat_comp = tf.complex(mat_comp, tf.zeros_like(mat_comp))
   1478             mat_comp = mat_comp[:,tf.newaxis]
   1479 
   1480             # Outgoing field - Eq.(14)
   1481             # [this_ris_num_samples, num_tx_patterns, 2]
   1482             e_field = tf.linalg.matvec(mat_comp, e_field)
   1483             e_field *= expand_to_rank(gamma_m, 3, -1)
   1484 
   1485             # Basis vectors for reflected field
   1486             # [this_ris_num_samples, 3]
   1487             theta_r, phi_r = theta_phi_from_unit_vec(k_r)
   1488             field_es = theta_hat(theta_r, phi_r)
   1489             field_ep = phi_hat(phi_r)
   1490 
   1491             # Concatenate rays from reflection by all RIS
   1492             # and create all-zeros samples for the inactive rays
   1493             # which will be dropped in a later stage.
   1494             n_p = num_active_samples - this_ris_num_samples
   1495 
   1496             def pad(x, n_p):
   1497                 """Pad input tensor with n-p zero samples"""
   1498                 paddings = tf.concat([[[0, n_p]],
   1499                                      tf.zeros([tf.rank(x)-1,2], tf.int32)],
   1500                                      axis=0)
   1501                 return tf.pad(x, paddings)
   1502 
   1503             output_e_field = tf.concat([output_e_field,
   1504                                         pad(e_field, n_p)],
   1505                                         axis=0)
   1506             output_field_es = tf.concat([output_field_es,
   1507                                          pad(field_es, n_p)],
   1508                                          axis=0)
   1509             output_field_ep = tf.concat([output_field_ep,
   1510                                          pad(field_ep, n_p)],
   1511                                          axis=0)
   1512             output_k_r = tf.concat([output_k_r, pad(k_r, n_p)], axis=0)
   1513             output_radii_curv = tf.concat([output_radii_curv,
   1514                                           pad(radii_curv, n_p)],
   1515                                           axis=0)
   1516             output_dirs_curv = tf.concat([output_dirs_curv,
   1517                                           pad(dirs_curv, n_p)],
   1518                                           axis=0)
   1519             normal = tf.tile(normal, [num_ris_sample, 1])
   1520             output_normals = tf.concat([output_normals,
   1521                                         pad(normal, n_p)], axis=0)
   1522 
   1523         output = (output_e_field, output_field_es, output_field_ep, output_k_r,\
   1524             output_normals, output_radii_curv, output_dirs_curv)
   1525         return output
   1526 
   1527     def _init_e_field(self, valid_ray, samples_tx_indices, k_tx, tx_rot_mat):
   1528         r"""
   1529         Initialize the electric field for the rays flagged as valid.
   1530 
   1531         Input
   1532         -----
   1533         valid_ray : [num_samples], tf.bool
   1534             Flag set to `True` if the ray is valid
   1535 
   1536         samples_tx_indices : [num_samples]. tf.int
   1537             Index of the transmitter from which the ray originated
   1538 
   1539         k_tx : [num_samples, 3]. tf.float
   1540             Direction of departure
   1541 
   1542         tx_rot_mat : [num_tx, 3, 3], tf.float
   1543             Matrix to go transmitter LCS to the GCS
   1544 
   1545         Output
   1546         -------
   1547         e_field : [num_valid_samples, num_tx_patterns, 2], tf.complex
   1548             Emitted electric field S and P components
   1549 
   1550         field_es : [num_valid_samples, 3], tf.float
   1551             Direction of the S component of the electric field
   1552 
   1553         field_ep : [num_valid_samples, 3], tf.float
   1554             Direction of the P component of the electric field
   1555         """
   1556 
   1557         num_samples = tf.shape(valid_ray)[0]
   1558         # [num_valid_samples]
   1559         valid_ind = tf.where(valid_ray)[:,0]
   1560         # [num_valid_samples]
   1561         valid_tx_ind = tf.gather(samples_tx_indices, valid_ind, axis=0)
   1562         # [num_valid_samples, 3]
   1563         k_tx = tf.gather(k_tx, valid_ind, axis=0)
   1564         # [num_valid_samples, 3, 3]
   1565         tx_rot_mat = tf.gather(tx_rot_mat, valid_tx_ind, axis=0)
   1566 
   1567         # val_e_field : [num_valid_samples, num_tx_patterns, 2]
   1568         # val_field_es, val_field_ep : [num_valid_samples, 3]
   1569         val_e_field, val_field_es, val_field_ep =\
   1570             self._compute_antenna_patterns(tx_rot_mat,
   1571                             self._scene.tx_array.antenna.patterns, k_tx)
   1572         valid_ind = tf.expand_dims(valid_ind, axis=-1)
   1573         # [num_samples, num_tx_patterns, 2]
   1574         e_field = tf.scatter_nd(valid_ind, val_e_field,
   1575                                 [num_samples, val_e_field.shape[1], 2])
   1576         # [num_samples, 3]
   1577         field_es = tf.scatter_nd(valid_ind, val_field_es,
   1578                                  [num_samples, 3])
   1579         field_ep = tf.scatter_nd(valid_ind, val_field_ep,
   1580                                  [num_samples, 3])
   1581         return e_field, field_es, field_ep
   1582 
   1583     def _extract_active_ris_rays(self, active_ind, int_point,
   1584         previous_int_point, primitives, e_field, field_es, field_ep,
   1585         samples_tx_indices, k_tx, radii_curv, dirs_curv, angular_opening):
   1586         r"""
   1587         Extracts the active rays hitting a RIS.
   1588 
   1589         Input
   1590         ------
   1591         active_ind : [num_active_samples], tf.int
   1592             Indices of the active rays
   1593 
   1594         int_point : [num_samples, 3], tf.float
   1595             Positions at which the rays intersect with the scene. For the rays
   1596             that did not intersect the scene, the corresponding position should
   1597             be ignored.
   1598 
   1599         previous_int_point : [num_samples, 3], tf.float
   1600             Positions of the previous intersection points of the rays with
   1601             the scene
   1602 
   1603         primitives : [num_samples], tf.int
   1604             Indices of the intersected primitives
   1605 
   1606         e_field : [num_samples, num_tx_patterns, 2], tf.complex
   1607             S and P components of the electric field
   1608 
   1609         field_es : [num_samples, 3], tf.float
   1610             Direction of the S component of the field
   1611 
   1612         field_ep : [num_samples, 3], tf.float
   1613             Direction of the P component of the field
   1614 
   1615         samples_tx_indices : [num_samples], tf.int
   1616             Index of the source from which the path originates
   1617 
   1618         k_tx : [num_samples, 3], tf.float
   1619             Direction of departure from the source
   1620 
   1621         radii_curv : [num_samples, 2], tf.float
   1622             Principal radii of curvature of the ray tubes
   1623 
   1624         dirs_curv : [num_samples, 2, 3], tf.float
   1625             Principal directions of curvature of the ray tubes
   1626 
   1627         angular_opening : [num_active_samples], tf.float
   1628             Angular opening
   1629 
   1630         Output
   1631         -------
   1632         act_e_field : [num_active_samples, num_tx_patterns, 2], tf.complex
   1633             S and P components of the electric field of the active rays
   1634 
   1635         act_field_es : [num_active_samples, 3], tf.float
   1636             Direction of the S component of the field of the active rays
   1637 
   1638         act_field_ep : [num_active_samples, 3], tf.float
   1639             Direction of the P component of the field of the active rays
   1640 
   1641         act_point : [num_active_samples, 3], tf.float
   1642             Positions at which the rays intersect with the scene
   1643 
   1644         act_k_i : [num_active_samples, 3], tf.float
   1645             Direction of the active incident ray
   1646 
   1647         act_dist : [num_active_samples], tf.float
   1648             Length of the last path segment, i.e., distance between `int_point`
   1649             and `previous_int_point`
   1650 
   1651         samples_tx_indices : [num_active_samples], tf.int
   1652             Index of the source from which the path originates
   1653 
   1654         k_tx : [num_active_samples, 3], tf.float
   1655             Direction of departure from the source
   1656 
   1657         act_radii_curv : [num_active_samples, 2], tf.float
   1658             Principal radii of curvature of the ray tubes
   1659 
   1660         act_dirs_curv : [num_active_samples, 2, 3], tf.float
   1661             Principal directions of curvature of the ray tubes
   1662 
   1663         act_angular_opening : [num_active_samples], tf.float
   1664             Angular opening
   1665         """
   1666 
   1667         # Extract the rays that interact the scene
   1668         # [num_active_samples, num_tx_patterns, 2]
   1669         act_e_field = tf.gather(e_field, active_ind, axis=0)
   1670         # [num_active_samples, 3]
   1671         act_field_es = tf.gather(field_es, active_ind, axis=0)
   1672         # [num_active_samples, 3]
   1673         act_field_ep = tf.gather(field_ep, active_ind, axis=0)
   1674         # [num_active_samples, 2]
   1675         act_radii_curv = tf.gather(radii_curv, active_ind, axis=0)
   1676         # [num_active_samples, 2, 3]
   1677         act_dirs_curv = tf.gather(dirs_curv, active_ind, axis=0)
   1678         # [num_active_samples, 3]
   1679         act_previous_int_point = tf.gather(previous_int_point, active_ind,
   1680                                             axis=0)
   1681         # Current intersection point
   1682         # [num_active_samples, 3]
   1683         int_point = tf.gather(int_point, active_ind, axis=0)
   1684         # [num_active_samples]
   1685         act_primitives = tf.gather(primitives, active_ind, axis=0)
   1686 
   1687         # [num_active_samples]
   1688         act_samples_tx_indices = tf.gather(samples_tx_indices, active_ind,
   1689                                            axis=0)
   1690         # [num_active_samples, 3]
   1691         act_k_tx = tf.gather(k_tx, active_ind, axis=0)
   1692 
   1693         # Direction of arrival
   1694         # [num_active_samples, 3]
   1695         act_k_i,act_dist = normalize(int_point - act_previous_int_point)
   1696 
   1697         # Extract angular openings
   1698         act_angular_opening = tf.gather(angular_opening, active_ind, axis=0)
   1699 
   1700         output = (act_e_field, act_field_es, act_field_ep, int_point, act_k_i,
   1701                   act_dist, act_samples_tx_indices, act_k_tx, act_radii_curv,
   1702                   act_dirs_curv, act_primitives, act_angular_opening)
   1703 
   1704         return output
   1705 
   1706     def _extract_active_rays(self, active_ind, int_point, previous_int_point,
   1707         primitives, e_field, field_es, field_ep, samples_tx_indices, k_tx,
   1708         etas, scattering_coefficient, xpd_coefficient, alpha_r, alpha_i,
   1709         lambda_, ris, radii_curv, dirs_curv, angular_opening):
   1710         r"""
   1711         Extracts the active rays.
   1712 
   1713         Input
   1714         ------
   1715         active_ind : [num_active_samples], tf.int
   1716             Indices of the active rays
   1717 
   1718         int_point : [num_samples, 3], tf.float
   1719             Positions at which the rays intersect with the scene. For the rays
   1720             that did not intersect the scene, the corresponding position should
   1721             be ignored.
   1722 
   1723         previous_int_point : [num_samples, 3], tf.float
   1724             Positions of the previous intersection points of the rays with
   1725             the scene
   1726 
   1727         primitives : [num_samples], tf.int
   1728             Indices of the intersected primitives
   1729 
   1730         e_field : [num_samples, num_tx_patterns, 2], tf.complex
   1731             S and P components of the electric field
   1732 
   1733         field_es : [num_samples, 3], tf.float
   1734             Direction of the S component of the field
   1735 
   1736         field_ep : [num_samples, 3], tf.float
   1737             Direction of the P component of the field
   1738 
   1739         samples_tx_indices : [num_samples], tf.int
   1740             Index of the source from which the path originates
   1741 
   1742         k_tx : [num_samples, 3], tf.float
   1743             Direction of departure from the source
   1744 
   1745         etas : [num_shape], tf.complex | `None`
   1746             Tensor containing the complex relative permittivities of all shapes
   1747 
   1748         scattering_coefficient : [num_shape], tf.float | `None`
   1749             Tensor containing the scattering coefficients of all shapes
   1750 
   1751         xpd_coefficient : [num_shape], tf.float | `None`
   1752             Tensor containing the cross-polarization discrimination
   1753             coefficients of all shapes
   1754 
   1755         alpha_r : [num_shape], tf.float | `None`
   1756             Tensor containing the alpha_r scattering parameters of all shapes
   1757 
   1758         alpha_i : [num_shape], tf.float | `None`
   1759             Tensor containing the alpha_i scattering parameters of all shapes
   1760 
   1761         lambda_ : [num_shape], tf.float | `None`
   1762             Tensor containing the lambda_ scattering parameters of all shapes
   1763 
   1764         ris : bool
   1765             Set to `True` if RIS are enabled
   1766 
   1767         radii_curv : [num_samples, 2], tf.float
   1768             Principal radii of curvature of the ray tubes
   1769 
   1770         dirs_curv : [num_samples, 2, 3], tf.float
   1771             Principal directions of curvature of the ray tubes
   1772 
   1773         angular_opening : [num_active_samples], tf.float
   1774             Angular opening
   1775 
   1776         Output
   1777         -------
   1778         act_e_field : [num_active_samples, num_tx_patterns, 2], tf.complex
   1779             S and P components of the electric field of the active rays
   1780 
   1781         act_field_es : [num_active_samples, 3], tf.float
   1782             Direction of the S component of the field of the active rays
   1783 
   1784         act_field_ep : [num_active_samples, 3], tf.float
   1785             Direction of the P component of the field of the active rays
   1786 
   1787         act_point : [num_active_samples, 3], tf.float
   1788             Positions at which the rays intersect with the scene
   1789 
   1790         act_normals : [num_active_samples, 3], tf.float
   1791             Normals at the intersection point. The normals are oriented to match
   1792             the direction opposite to the incident ray
   1793 
   1794         act_etas : [num_active_samples], tf.complex
   1795             Relative permittivity of the intersected primitives
   1796 
   1797         act_scat_coeff : [num_active_samples], tf.float
   1798             Scattering coefficient of the intersected primitives
   1799 
   1800         act_k_i : [num_active_samples, 3], tf.float
   1801             Direction of the active incident ray
   1802 
   1803         act_xpd_coefficient : [num_active_samples], tf.float | `None`
   1804             Tensor containing the cross-polarization discrimination
   1805             coefficients of all shapes.
   1806             Only returned if ``xpd_coefficient`` is not `None`.
   1807 
   1808         act_alpha_r : [num_active_samples], tf.float
   1809             Tensor containing the alpha_r scattering parameters of all shapes.
   1810             Only returned if ``alpha_r`` is not `None`.
   1811 
   1812         act_alpha_i : [num_active_samples], tf.float
   1813             Tensor containing the alpha_i scattering parameters of all shapes
   1814             Only returned if ``alpha_i`` is not `None`.
   1815 
   1816         act_lambda_ : [num_active_samples], tf.float
   1817             Tensor containing the lambda_ scattering parameters of all shapes
   1818             Only returned if ``lambda_`` is not `None`.
   1819 
   1820         act_objects : [num_active_samples], tf.int
   1821             Indices of the intersected objects
   1822 
   1823         act_dist : [num_active_samples], tf.float
   1824             Length of the last path segment, i.e., distance between `int_point`
   1825             and `previous_int_point`
   1826 
   1827         samples_tx_indices : [num_active_samples], tf.int
   1828             Index of the source from which the path originates
   1829 
   1830         k_tx : [num_active_samples, 3], tf.float
   1831             Direction of departure from the source
   1832 
   1833         act_radii_curv : [num_active_samples, 2], tf.float
   1834             Principal radii of curvature of the ray tubes
   1835 
   1836         act_dirs_curv : [num_active_samples, 2, 3], tf.float
   1837             Principal directions of curvature of the ray tubes
   1838 
   1839         act_primitives : [num_active_samples], tf.int
   1840             Indices of the intersected primitives
   1841 
   1842         act_angular_opening : [num_active_samples], tf.float
   1843             Angular opening
   1844         """
   1845 
   1846         # Extract the rays that interact the scene
   1847         # [num_active_samples, num_tx_patterns, 2]
   1848         act_e_field = tf.gather(e_field, active_ind, axis=0)
   1849         # [num_active_samples, 3]
   1850         act_field_es = tf.gather(field_es, active_ind, axis=0)
   1851         # [num_active_samples, 3]
   1852         act_field_ep = tf.gather(field_ep, active_ind, axis=0)
   1853         if ris:
   1854             # [num_active_samples, 2]
   1855             act_radii_curv = tf.gather(radii_curv, active_ind, axis=0)
   1856             # [num_active_samples, 2, 3]
   1857             act_dirs_curv = tf.gather(dirs_curv, active_ind, axis=0)
   1858         else:
   1859             act_radii_curv = None
   1860             act_dirs_curv = None
   1861         # [num_active_samples, 3]
   1862         act_previous_int_point = tf.gather(previous_int_point, active_ind,
   1863                                             axis=0)
   1864         # Current intersection point
   1865         # [num_active_samples, 3]
   1866         int_point = tf.gather(int_point, active_ind, axis=0)
   1867         # [num_active_samples]
   1868         act_primitives = tf.gather(primitives, active_ind, axis=0)
   1869         # [num_active_samples]
   1870         act_objects = tf.gather(self._primitives_2_objects, act_primitives,
   1871                                 axis=0)
   1872         # [num_active_samples]
   1873         act_samples_tx_indices = tf.gather(samples_tx_indices, active_ind,
   1874                                            axis=0)
   1875         # [num_active_samples, 3]
   1876         act_k_tx = tf.gather(k_tx, active_ind, axis=0)
   1877 
   1878         # Extract the normals to the intersected primitives
   1879         # [num_active_samples, 3]
   1880         if self._normals.shape[0] > 0:
   1881             act_normals = tf.gather(self._normals, act_primitives, axis=0)
   1882         else:
   1883             act_normals = None
   1884 
   1885         # If a callable is defined to compute the radio material properties,
   1886         # it is invoked. Otherwise, the radio materials of objects are used.
   1887         rm_callable = self._scene.radio_material_callable
   1888         if rm_callable is None:
   1889             # Extract the material properties of the intersected objects
   1890             if etas is not None:
   1891                 # [num_active_samples]
   1892                 act_etas = tf.gather(etas, act_objects)
   1893             else:
   1894                 act_etas = None
   1895             if scattering_coefficient is not None:
   1896                 # [num_active_samples]
   1897                 act_scat_coeff = tf.gather(scattering_coefficient, act_objects)
   1898             else:
   1899                 act_scat_coeff = None
   1900             if xpd_coefficient is not None:
   1901                 # [num_active_samples]
   1902                 act_xpd_coefficient = tf.gather(xpd_coefficient, act_objects)
   1903             else:
   1904                 act_xpd_coefficient = None
   1905         else:
   1906             # [num_active_samples]
   1907             act_etas, act_scat_coeff, act_xpd_coefficient\
   1908                                         = rm_callable(act_objects, int_point)
   1909 
   1910         # If no callable is defined for the scattering pattern, we need to
   1911         # extract the properties of the scattering patterns built-in Sionna
   1912         if (self._scene.scattering_pattern_callable is None) and\
   1913                                                     (alpha_r is not None) :
   1914             # [num_active_samples]
   1915             act_alpha_r = tf.gather(alpha_r, act_objects)
   1916             act_alpha_i = tf.gather(alpha_i, act_objects)
   1917             act_lambda_ = tf.gather(lambda_, act_objects)
   1918         else:
   1919             act_alpha_r = act_alpha_i = act_lambda_ = None
   1920 
   1921         # Direction of arrival
   1922         # [num_active_samples, 3]
   1923         act_k_i,act_dist = normalize(int_point - act_previous_int_point)
   1924 
   1925         # Ensure the normal points in the direction -k_i
   1926         if act_normals is not None:
   1927             # [num_active_samples, 1]
   1928             flip_normal = -tf.math.sign(dot(act_k_i, act_normals, keepdim=True))
   1929             # [num_active_samples, 3]
   1930             act_normals = flip_normal*act_normals
   1931 
   1932         # Extract angular openings
   1933         if ris:
   1934             act_angular_opening = tf.gather(angular_opening, active_ind, axis=0)
   1935         else:
   1936             act_angular_opening = None
   1937 
   1938         output = (act_e_field, act_field_es, act_field_ep, int_point,
   1939                 act_normals, act_etas, act_scat_coeff, act_k_i,
   1940                 act_xpd_coefficient, act_alpha_r, act_alpha_i, act_lambda_,
   1941                 act_objects, act_dist, act_samples_tx_indices, act_k_tx,
   1942                 act_radii_curv, act_dirs_curv, act_primitives,
   1943                 act_angular_opening)
   1944 
   1945         return output
   1946 
   1947     def _sample_interaction_phenomena(self, active, int_point, primitives,
   1948                             scattering_coefficient, reflection, scattering):
   1949         r"""
   1950         Samples the interaction phenomena to apply to each active ray, among
   1951         scattering or reflection.
   1952 
   1953         This is done by sampling a Bernouilli distribution with probability p
   1954         equal to the square of the scattering coefficient amplitude, as it
   1955         corresponds to the ratio of the reflected energy that goes to
   1956         scattering. With probability p, the ray is scattered. Otherwise, it is
   1957         reflected.
   1958 
   1959         Input
   1960         ------
   1961         active : [num_samples], tf.bool
   1962             Flag indicating if a ray is active
   1963 
   1964         int_point : [num_samples, 3], tf.float
   1965             Positions at which the rays intersect with the scene. For the rays
   1966             that did not intersect the scene, the corresponding position should
   1967             be ignored.
   1968 
   1969         scattering_coefficient : [num_shape], tf.complex
   1970             Scattering coefficients of all shapes
   1971 
   1972         reflection : bool
   1973             Set to `True` if reflection is enabled
   1974 
   1975         scattering : bool
   1976             Set to `True` if scattering is enabled
   1977 
   1978         Output
   1979         -------
   1980         reflect_ind : [num_reflected_samples], tf.int
   1981             Indices of the rays that are reflected
   1982 
   1983         scatter_ind : [num_scattered_samples], tf.int
   1984             Indices of the rays that are scattered
   1985         """
   1986 
   1987         # Indices of the active samples
   1988         # [num_active_samples]
   1989         active_ind = tf.where(active)[:,0]
   1990 
   1991         # If only one of reflection or scattering is enabled, then all the
   1992         # samples are used for the enabled phenomena to avoid wasting samples
   1993         # by allocating them to a phenomena that is not requested by the users.
   1994         # This approach, however, requires to correct later the contribution
   1995         # of the rays by weighting them by the square of the scattering or
   1996         # reduction factor, depending on the selected phenomena.
   1997         # This is done in the functions that compute the reflected and scattered
   1998         # field.
   1999         if not (reflection or scattering):
   2000             reflect_ind = tf.zeros([0], tf.int32)
   2001             scatter_ind = tf.zeros([0], tf.int32)
   2002         elif not reflection:
   2003             reflect_ind = tf.zeros([0], tf.int32)
   2004             scatter_ind = active_ind
   2005         elif not scattering:
   2006             reflect_ind = active_ind
   2007             scatter_ind = tf.zeros([0], tf.int32)
   2008         else:
   2009             # Scattering coefficients of the intersected objects
   2010             # [num_active_samples]
   2011             act_primitives = tf.gather(primitives, active_ind, axis=0)
   2012             act_objects = tf.gather(self._primitives_2_objects, act_primitives,
   2013                                     axis=0)
   2014             # Current intersection point
   2015             # [num_active_samples, 3]
   2016             int_point = tf.gather(int_point, active_ind, axis=0)
   2017 
   2018             # If a callable is defined to compute the radio material properties,
   2019             # it is invoked. Otherwise, the radio materials of objects are used.
   2020             rm_callable = self._scene.radio_material_callable
   2021             if rm_callable is None:
   2022                 # [num_active_samples]
   2023                 act_scat_coeff = tf.gather(scattering_coefficient, act_objects)
   2024             else:
   2025                 # [num_active_samples]
   2026                 _, act_scat_coeff, _ = rm_callable(act_objects, int_point)
   2027 
   2028             # Probability of scattering
   2029             # [num_active_samples]
   2030             prob_scatter = tf.square(tf.abs(act_scat_coeff))
   2031 
   2032             # Sampling a Bernoulli distribution
   2033             # [num_active_samples]
   2034             scatter = config.tf_rng.uniform(tf.shape(prob_scatter),
   2035                                             tf.zeros((), self._rdtype),
   2036                                             tf.ones((), self._rdtype),
   2037                                             dtype=self._rdtype)
   2038             scatter = tf.less(scatter, prob_scatter)
   2039 
   2040             # Extract indices of the reflected and scattered rays
   2041             # [num_reflected_samples]
   2042             reflect_ind = tf.gather(active_ind, tf.where(~scatter)[:,0])
   2043             # [num_scattered_samples]
   2044             scatter_ind = tf.gather(active_ind, tf.where(scatter)[:,0])
   2045 
   2046         return reflect_ind, scatter_ind
   2047 
   2048     def _apply_reflection(self, active_ind, int_point, previous_int_point,
   2049         primitives, e_field, field_es, field_ep, samples_tx_indices, k_tx,
   2050         etas, scattering_coefficient, scattering, ris, radii_curv, dirs_curv,
   2051         angular_opening):
   2052         r"""
   2053         Apply reflection.
   2054 
   2055         Input
   2056         ------
   2057         active_ind : [num_reflected_samples], tf.int
   2058             Indices of the *active* rays to which reflection must be applied.
   2059 
   2060         int_point : [num_samples, 3], tf.float
   2061             Locations of the intersection point
   2062 
   2063         previous_int_point : [num_samples, 3], tf.float
   2064             Locations of the intersection points of the previous interaction.
   2065 
   2066         primitives : [num_samples], tf.int
   2067             Indices of the intersected primitives
   2068 
   2069         e_field : [num_samples, num_tx_patterns, 2], tf.complex
   2070             S and P components of the electric field
   2071 
   2072         field_es : [num_samples, 3], tf.float
   2073             Direction of the S component of the field
   2074 
   2075         field_ep : [num_samples, 3], tf.float
   2076             Direction of the P component of the field
   2077 
   2078         samples_tx_indices : [num_samples], tf.int
   2079             Index of the source from which the path originates
   2080 
   2081         k_tx : [num_samples, 3], tf.float
   2082             Direction of departure from the source
   2083 
   2084         etas : [num_shape], tf.complex
   2085             Complex relative permittivities of all shapes
   2086 
   2087         scattering_coefficient : [num_shape], tf.float
   2088             Scattering coefficients of all shapes
   2089 
   2090         scattering : bool
   2091             Set to `True` if scattering is enabled
   2092 
   2093         ris : bool
   2094             Set to `True` if scattering is enabled
   2095 
   2096         radii_curv : [num_active_samples, 2], tf.float
   2097             Principal radii of curvature
   2098 
   2099         dirs_curv : [num_active_samples, 2, 3], tf.float
   2100             Principal direction of curvature
   2101 
   2102         angular_opening : [num_active_samples], tf.float
   2103             Angular opening
   2104 
   2105         Output
   2106         -------
   2107         e_field : [num_reflected_samples, num_tx_patterns, 2], tf.complex
   2108             S and P components of the reflected electric field
   2109 
   2110         field_es : [num_reflected_samples, 3], tf.float
   2111             Direction of the S component of the reflected field
   2112 
   2113         field_ep : [num_reflected_samples, 3], tf.float
   2114             Direction of the P component of the reflected field
   2115 
   2116         int_point : [num_reflected_samples, 3], tf.float
   2117             Locations of the intersection point
   2118 
   2119         k_r : [num_reflected_samples, 3], tf.float
   2120             Direction of the reflected ray
   2121 
   2122         samples_tx_indices : [num_reflected_samples], tf.int
   2123             Index of the source from which the path originates
   2124 
   2125         k_tx : [num_reflected_samples, 3], tf.float
   2126             Direction of departure from the source
   2127 
   2128         normals : [num_reflected_samples, 3], tf.float
   2129             Normals at the intersection points
   2130 
   2131         radii_curv : [num_reflected_samples, 2], tf.float
   2132             Principal radii of curvature of the reflected field
   2133 
   2134         dirs_curv : [num_reflected_samples, 2, 3], tf.float
   2135             Principal direction of curvature of the reflected field
   2136 
   2137         angular_opening : [num_reflected_samples], tf.float
   2138             Angular opening of the reflected ray
   2139         """
   2140 
   2141         # Prepare field computation
   2142         # This function extract the data for the rays to which reflection
   2143         # must be applied, and ensures that the normals are correctly oriented.
   2144         act_data = self._extract_active_rays(active_ind, int_point,
   2145             previous_int_point, primitives, e_field, field_es, field_ep,
   2146             samples_tx_indices, k_tx, etas, scattering_coefficient, None, None,
   2147             None, None, ris, radii_curv, dirs_curv, angular_opening)
   2148         # [num_reflected_samples, num_tx_patterns, 2]
   2149         e_field = act_data[0]
   2150         # [num_reflected_samples, 3]
   2151         field_es = act_data[1]
   2152         field_ep = act_data[2]
   2153         int_point = act_data[3]
   2154         # [num_reflected_samples, 3]
   2155         act_normals = act_data[4]
   2156         # [num_reflected_samples]
   2157         act_etas = act_data[5]
   2158         act_scat_coeff = act_data[6]
   2159         # [num_reflected_samples, 3]
   2160         k_i = act_data[7]
   2161         # Length of the last path segment
   2162         # [num_reflected_samples]
   2163         length = act_data[13]
   2164         # Index of the intersected source
   2165         samples_tx_indices = act_data[14]
   2166         # Direction of departure form the source
   2167         k_tx = act_data[15]
   2168         if ris:
   2169             # Principal radii and directions of curvatures
   2170             # [num_reflected_samples, 2]
   2171             radii_curv = act_data[16]
   2172             # [num_reflected_samples, 2, 3]
   2173             dirs_curv = act_data[17]
   2174             # [num_reflected_samples]
   2175             angular_opening = act_data[19]
   2176 
   2177         # Compute the reflected field
   2178         e_field, field_es, field_ep, k_r, radii_curv, dirs_curv\
   2179             = self._compute_reflected_field(act_normals,
   2180                 act_etas, act_scat_coeff, k_i, e_field, field_es, field_ep,
   2181                 scattering, ris, length, radii_curv, dirs_curv)
   2182 
   2183         output = (e_field, field_es, field_ep, int_point, k_r, act_normals,
   2184                   samples_tx_indices, k_tx, radii_curv, dirs_curv,
   2185                   angular_opening)
   2186         return output
   2187 
   2188     def _apply_scattering(self, active_ind, int_point, previous_int_point,
   2189         primitives, e_field, field_es, field_ep,  samples_tx_indices, k_tx,
   2190         etas, scattering_coefficient, xpd_coefficient, alpha_r, alpha_i,
   2191         lambda_, reflection, ris, radii_curv, dirs_curv, angular_opening):
   2192         r"""
   2193         Apply scattering.
   2194 
   2195         Input
   2196         ------
   2197         active_ind : [num_scattered_samples], tf.int
   2198             Indices of the *active* rays to which scattering must be applied.
   2199 
   2200         int_point : [num_samples, 3], tf.float
   2201             Locations of the intersection point
   2202 
   2203         previous_int_point : [num_samples, 3], tf.float
   2204             Locations of the intersection points of the previous interaction.
   2205 
   2206         primitives : [num_samples], tf.int
   2207             Indices of the intersected primitives
   2208 
   2209         e_field : [num_samples, num_tx_patterns, 2], tf.complex
   2210             S and P components of the electric field
   2211 
   2212         field_es : [num_samples, 3], tf.float
   2213             Direction of the S component of the field
   2214 
   2215         field_ep : [num_samples, 3], tf.float
   2216             Direction of the P component of the field
   2217 
   2218         samples_tx_indices : [num_samples], tf.int
   2219             Index of the source from which the path originates
   2220 
   2221         k_tx : [num_samples, 3], tf.float
   2222             Direction of departure from the source
   2223 
   2224         etas : [num_shape], tf.complex
   2225             Complex relative permittivities of all shapes
   2226 
   2227         scattering_coefficient : [num_shape], tf.float
   2228             Scattering coefficients of all shapes
   2229 
   2230         xpd_coefficient : [num_shape], tf.float | `None`
   2231             Cross-polarization discrimination coefficients of all shapes
   2232 
   2233         alpha_r : [num_shape], tf.float | `None`
   2234             alpha_r scattering parameters of all shapes
   2235 
   2236         alpha_i : [num_shape], tf.float | `None`
   2237             alpha_i scattering parameters of all shapes
   2238 
   2239         lambda_ : [num_shape], tf.float | `None`
   2240             lambda_ scattering parameters of all shapes
   2241 
   2242         reflection : bool
   2243             Set to `True` if reflection is enabled
   2244 
   2245         ris : bool
   2246             Set to `True` if RIS is enabled
   2247 
   2248         radii_curv : [num_active_samples, 2], tf.float
   2249             Principal radii of curvature
   2250 
   2251         dirs_curv : [num_active_samples, 2, 3], tf.float
   2252             Principal direction of curvature
   2253 
   2254         angular_opening : [num_active_samples], tf.float
   2255             Angular opening
   2256 
   2257         Output
   2258         -------
   2259         e_field : [num_scattered_samples, num_tx_patterns, 2], tf.complex
   2260             S and P components of the scattered electric field
   2261 
   2262         field_es : [num_scattered_samples, 3], tf.float
   2263             Direction of the S component of the scattered field
   2264 
   2265         field_ep : [num_scattered_samples, 3], tf.float
   2266             Direction of the P component of the scattered field
   2267 
   2268         int_point : [num_scattered_samples, 3], tf.float
   2269             Locations of the intersection point
   2270 
   2271         k_r : [num_scattered_samples, 3], tf.float
   2272             Direction of the scattered ray
   2273 
   2274         samples_tx_indices : [num_scattered_samples], tf.int
   2275             Index of the source from which the path originates
   2276 
   2277         k_tx : [num_scattered_samples, 3], tf.float
   2278             Direction of departure from the source
   2279 
   2280         normals : [num_scattered_samples, 3], tf.float
   2281             Normals at the intersection points
   2282 
   2283         radii_curv : [num_scattered_samples, 2], tf.float
   2284             Principal radii of curvature of the scattered field
   2285 
   2286         dirs_curv : [num_scattered_samples, 2, 3], tf.float
   2287             Principal direction of curvature of the scattered field
   2288 
   2289         angular_opening : [num_scattered_samples], tf.float
   2290             Angular opening of the scattered field
   2291         """
   2292 
   2293         # Prepare field computation
   2294         # This function extract the data for the rays to which scattering
   2295         # must be applied, and ensures that the normals are correcly oriented.
   2296         act_data = self._extract_active_rays(active_ind, int_point,
   2297             previous_int_point, primitives, e_field, field_es, field_ep,
   2298             samples_tx_indices, k_tx, etas, scattering_coefficient,
   2299             xpd_coefficient, alpha_r, alpha_i, lambda_, ris, radii_curv,
   2300             dirs_curv, angular_opening)
   2301         # [num_scattered_samples, num_tx_patterns, 2]
   2302         e_field = act_data[0]
   2303         # [num_scattered_samples, 3]
   2304         field_es = act_data[1]
   2305         field_ep = act_data[2]
   2306         int_point = act_data[3]
   2307         # [num_scattered_samples, 3]
   2308         act_normals = act_data[4]
   2309         # [num_scattered_samples]
   2310         act_etas = act_data[5]
   2311         act_scat_coeff = act_data[6]
   2312         # [num_scattered_samples, 3]
   2313         k_i = act_data[7]
   2314         # [num_scattered_samples]
   2315         act_xpd_coefficient = act_data[8]
   2316         act_alpha_r = act_data[9]
   2317         act_alpha_i = act_data[10]
   2318         act_lambda_ = act_data[11]
   2319         act_objects = act_data[12]
   2320         # Length of the last path segment
   2321         # [num_scattered_samples]
   2322         length = act_data[13]
   2323         # Index of the intersected source
   2324         samples_tx_indices = act_data[14]
   2325         # Direction of departure form the source
   2326         k_tx = act_data[15]
   2327         if ris:
   2328             # Principal radii and directions of curvatures
   2329             # [num_scattered_samples, 2]
   2330             radii_curv = act_data[16]
   2331             # [num_scattered_samples, 2, 3]
   2332             dirs_curv = act_data[17]
   2333             # [num_scattered_samples]
   2334             angular_opening = act_data[19]
   2335 
   2336         # Compute the scattered field
   2337         e_field, field_es, field_ep, k_r, radii_curv, dirs_curv,\
   2338             angular_opening = self._compute_scattered_field(int_point,
   2339                 act_objects, act_normals, act_etas, act_scat_coeff,
   2340                 act_xpd_coefficient, act_alpha_r, act_alpha_i, act_lambda_,
   2341                 k_i, e_field, field_es, field_ep, reflection, ris, length,
   2342                 radii_curv, angular_opening)
   2343 
   2344         output = (e_field, field_es, field_ep, int_point, k_r, act_normals,
   2345                   samples_tx_indices, k_tx,  radii_curv, dirs_curv,
   2346                   angular_opening)
   2347         return output
   2348 
   2349     def _apply_ris_reflection(self, active_ind, int_point, previous_int_point,
   2350         primitives, e_field, field_es, field_ep, samples_tx_indices, k_tx,
   2351         radii_curv, dirs_curv, angular_opening):
   2352         r"""
   2353         Apply scattering.
   2354 
   2355         Input
   2356         ------
   2357         active_ind : [num_ris_reflected_samples], tf.int
   2358             Indices of the *active* rays to which scattering must be applied.
   2359 
   2360         int_point : [num_samples, 3], tf.float
   2361             Locations of the intersection point
   2362 
   2363         previous_int_point : [num_samples, 3], tf.float
   2364             Locations of the intersection points of the previous interaction.
   2365 
   2366         primitives : [num_samples], tf.int
   2367             Indices of the intersected primitives
   2368 
   2369         e_field : [num_samples, num_tx_patterns, 2], tf.complex
   2370             S and P components of the electric field
   2371 
   2372         field_es : [num_samples, 3], tf.float
   2373             Direction of the S component of the field
   2374 
   2375         field_ep : [num_samples, 3], tf.float
   2376             Direction of the P component of the field
   2377 
   2378         samples_tx_indices : [num_samples], tf.int
   2379             Index of the source from which the path originates
   2380 
   2381         k_tx : [num_samples, 3], tf.float
   2382             Direction of departure from the source
   2383 
   2384         radii_curv : [num_active_samples, 2], tf.float
   2385             Principal radii of curvature
   2386 
   2387         dirs_curv : [num_active_samples, 2, 3], tf.float
   2388             Principal direction of curvature
   2389 
   2390         angular_opening : [num_active_samples], tf.float
   2391             Angular opening
   2392 
   2393         Output
   2394         -------
   2395         e_field : [num_ris_reflected_samples, num_tx_patterns, 2], tf.complex
   2396             S and P components of the reflected electric field
   2397 
   2398         field_es : [num_ris_reflected_samples, 3], tf.float
   2399             Direction of the S component of the reflected field
   2400 
   2401         field_ep : [num_ris_reflected_samples, 3], tf.float
   2402             Direction of the P component of the reflected field
   2403 
   2404         int_point : [num_ris_reflected_samples, 3], tf.float
   2405             Locations of the intersection point
   2406 
   2407         k_r : [num_ris_reflected_samples, 3], tf.float
   2408             Direction of the reflected ray
   2409 
   2410         samples_tx_indices : [num_ris_reflected_samples], tf.int
   2411             Index of the intersected transmitter
   2412 
   2413         k_tx : [num_ris_reflected_samples, 3], tf.float
   2414             Direction of departure from the source
   2415 
   2416         normals : [num_ris_reflected_samples, 3], tf.float
   2417             Normals at the intersection points
   2418 
   2419         radii_curv : [num_ris_reflected_samples, 2], tf.float
   2420             Principal radii of curvature of the reflected field
   2421 
   2422         dirs_curv : [num_ris_reflected_samples, 2, 3], tf.float
   2423             Principal direction of curvature of the reflected field
   2424 
   2425         angular_opening : [num_ris_reflected_samples], tf.float
   2426             Angular opening of the reflected field
   2427         """
   2428         # Prepare field computation
   2429         # This function extract the data for the rays to which scattering
   2430         # must be applied, and ensures that the normals are correctly oriented.
   2431         act_data = self._extract_active_ris_rays(active_ind, int_point,
   2432             previous_int_point, primitives, e_field, field_es, field_ep,
   2433             samples_tx_indices, k_tx, radii_curv, dirs_curv, angular_opening)
   2434         # [num_ris_reflected_samples, num_tx_patterns, 2]
   2435         e_field = act_data[0]
   2436         # [num_ris_reflected_samples, 3]
   2437         field_es = act_data[1]
   2438         field_ep = act_data[2]
   2439         int_point = act_data[3]
   2440         # [num_ris_reflected_samples, 3]
   2441         k_i = act_data[4]
   2442         # Length of the last path segment
   2443         # [num_ris_reflected_samples]
   2444         length = act_data[5]
   2445         # Index of the intersected source
   2446         samples_tx_indices = act_data[6]
   2447         # Direction of departure form the source
   2448         k_tx = act_data[7]
   2449         # Principal radii and directions of curvatures
   2450         # [num_ris_reflected_samples, 2]
   2451         radii_curv = act_data[8]
   2452         # [num_ris_reflected_samples, 2, 3]
   2453         dirs_curv = act_data[9]
   2454         # [num_ris_reflected_samples]
   2455         ris_ind = act_data[10]
   2456         # [num_ris_reflected_samples]
   2457         angular_opening = act_data[11]
   2458 
   2459         # Compute the reflected field
   2460         e_field, field_es, field_ep, k_r, normals, radii_curv, dirs_curv,\
   2461             = self._compute_ris_reflected_field(int_point, ris_ind, k_i,
   2462                 e_field, field_es, field_ep, length, radii_curv, dirs_curv)
   2463 
   2464         output = (e_field, field_es, field_ep, int_point, k_r, normals,
   2465                   samples_tx_indices, k_tx, radii_curv, dirs_curv,
   2466                   angular_opening)
   2467         return output
   2468 
   2469     def _shoot_and_bounce(self,
   2470                           meas_plane,
   2471                           ris_objects,
   2472                           ris_indices,
   2473                           rx_orientation,
   2474                           sources_positions,
   2475                           sources_orientations,
   2476                           max_depth,
   2477                           num_samples,
   2478                           combining_vec,
   2479                           precoding_vec,
   2480                           cm_center, cm_orientation, cm_size, cm_cell_size,
   2481                           los,
   2482                           reflection,
   2483                           diffraction,
   2484                           scattering,
   2485                           ris,
   2486                           etas,
   2487                           scattering_coefficient,
   2488                           xpd_coefficient,
   2489                           alpha_r,
   2490                           alpha_i,
   2491                           lambda_,
   2492                           random_lattice):
   2493         r"""
   2494         Runs shoot-and-bounce to build the coverage map for LoS, reflection,
   2495         and scattering.
   2496 
   2497         If ``diffraction`` is set to `True`, this function also returns the
   2498         primitives in LoS with at least one transmitter.
   2499 
   2500         Input
   2501         ------
   2502         meas_plane : mi.Shape
   2503             Mitsuba rectangle defining the measurement plane
   2504 
   2505         ris_objects : list(mi.Rectangle)
   2506             List of Mitsuba rectangles implementing the RIS
   2507 
   2508         ris_indices : mi.UInt
   2509             RIS indices
   2510 
   2511         rx_orientation : [3], tf.float
   2512             Orientation of the receiver.
   2513 
   2514         sources_positions : [num_tx, 3], tf.float
   2515             Coordinates of the sources.
   2516 
   2517         max_depth : int
   2518             Maximum number of reflections
   2519 
   2520         num_samples : int
   2521             Number of rays initially shooted from the transmitters.
   2522             This number is shared by all transmitters, i.e.,
   2523             ``num_samples/num_tx`` are shooted for each transmitter.
   2524 
   2525         combining_vec : [num_rx_ant], tf.complex
   2526             Combining vector.
   2527             If set to `None`, then no combining is applied, and
   2528             the energy received by all antennas is summed.
   2529 
   2530         precoding_vec : [num_tx or 1, num_tx_ant], tf.complex
   2531             Precoding vectors of the transmitters
   2532 
   2533         cm_center : [3], tf.float
   2534             Center of the coverage map
   2535 
   2536         cm_orientation : [3], tf.float
   2537             Orientation of the coverage map
   2538 
   2539         cm_size : [2], tf.float
   2540             Scale of the coverage map.
   2541             The width of the map (in the local X direction) is scale[0]
   2542             and its map (in the local Y direction) scale[1].
   2543 
   2544         cm_cell_size : [2], tf.float
   2545             Resolution of the coverage map, i.e., width
   2546             (in the local X direction) and height (in the local Y direction) in
   2547             meters of a cell of the coverage map
   2548 
   2549         los : bool
   2550             If set to `True`, then the LoS paths are computed.
   2551 
   2552         reflection : bool
   2553             If set to `True`, then the reflected paths are computed.
   2554 
   2555         diffraction : bool
   2556             If set to `True`, then the diffracted paths are computed.
   2557 
   2558         scattering : bool
   2559             If set to `True`, then the scattered paths are computed.
   2560 
   2561         ris : bool
   2562             If set to `True`, then paths involving RIS are computed.
   2563 
   2564         etas : [num_shape], tf.complex
   2565             Tensor containing the complex relative permittivities of all shapes
   2566 
   2567         scattering_coefficient : [num_shape], tf.float
   2568             Tensor containing the scattering coefficients of all shapes
   2569 
   2570         xpd_coefficient : [num_shape], tf.float
   2571             Tensor containing the cross-polarization discrimination
   2572             coefficients of all shapes
   2573 
   2574         alpha_r : [num_shape], tf.float
   2575             Tensor containing the alpha_r scattering parameters of all shapes
   2576 
   2577         alpha_i : [num_shape], tf.float
   2578             Tensor containing the alpha_i scattering parameters of all shapes
   2579 
   2580         lambda_ : [num_shape], tf.float
   2581             Tensor containing the lambda_ scattering parameters of all shapes
   2582 
   2583         random_lattice : bool
   2584             If set to `True`, a random rotation is applied to the Fibonacci
   2585             lattice
   2586 
   2587         Output
   2588         ------
   2589         cm : [num_tx, num_cells_y, num_cells_x], tf.float
   2590             Coverage map for every transmitter.
   2591             Includes LoS, reflection, and scattering.
   2592 
   2593         los_primitives: [num_los_primitives], int | `None`
   2594             Primitives in LoS.
   2595             `None` is returned if ``diffraction`` is set to `False`.
   2596         """
   2597 
   2598         ris = ris and (len(self._scene.ris) > 0)
   2599 
   2600         # Ensure that sample count can be distributed over the emitters
   2601         num_tx = sources_positions.shape[0]
   2602         samples_per_tx_float = tf.math.ceil(num_samples / num_tx)
   2603         samples_per_tx = int(samples_per_tx_float)
   2604         num_samples = num_tx * samples_per_tx
   2605 
   2606         # Transmitters and receivers rotation matrices
   2607         # [3, 3]
   2608         rx_rot_mat = rotation_matrix(rx_orientation)
   2609         # [num_tx, 3, 3]
   2610         tx_rot_mat = rotation_matrix(sources_orientations)
   2611 
   2612         # Rotation matrix to go from the measurement plane LCS to the GCS, and
   2613         # the othwer way around
   2614         # [3,3]
   2615         rot_mp_2_gcs = rotation_matrix(cm_orientation)
   2616         rot_gcs_2_mp = tf.transpose(rot_mp_2_gcs)
   2617         # Normal to the CM
   2618         # Add a dimension for broadcasting
   2619         # [1, 3]
   2620         cm_normal = tf.expand_dims(rot_mp_2_gcs[:,2], axis=0)
   2621 
   2622         # Number of cells in the coverage map
   2623         # [2,2]
   2624         num_cells_x = tf.cast(tf.math.ceil(cm_size[0]/cm_cell_size[0]),
   2625                               tf.int32)
   2626         num_cells_y = tf.cast(tf.math.ceil(cm_size[1]/cm_cell_size[1]),
   2627                               tf.int32)
   2628         num_cells = tf.stack([num_cells_x, num_cells_y], axis=-1)
   2629 
   2630         # Primitives in LoS are required for diffraction
   2631         los_primitives = None
   2632 
   2633         # Initialize rays.
   2634         # Direction arranged in a Fibonacci lattice on the unit
   2635         # sphere.
   2636         # [num_samples, 3]
   2637         ps = fibonacci_lattice(samples_per_tx, self._rdtype)
   2638         ps_dr = self._mi_point2_t(ps)
   2639         ps_dr = dr.tile(ps_dr, num_tx)
   2640         k_tx_dr = mi.warp.square_to_uniform_sphere(ps_dr)
   2641         if random_lattice:
   2642             # Generate a random 3D rotation and apply it to the rays directions
   2643             angles = config.tf_rng.uniform([3],
   2644                                      minval=tf.cast(0.0, self._rdtype),
   2645                                      maxval=tf.cast(PI, self._rdtype),
   2646                                      dtype=self._rdtype)
   2647             rnd_rotation = angles_to_mitsuba_rotation(angles)
   2648             k_tx_dr = rnd_rotation@k_tx_dr
   2649         k_tx = mi_to_tf_tensor(k_tx_dr, self._rdtype)
   2650         # Origin placed on the given transmitters
   2651         # [num_samples]
   2652         samples_tx_indices_dr = dr.linspace(self._mi_scalar_t, 0, num_tx-1e-7,
   2653                                             num=num_samples, endpoint=False)
   2654         samples_tx_indices_dr = mi.Int32(samples_tx_indices_dr)
   2655         samples_tx_indices = mi_to_tf_tensor(samples_tx_indices_dr, tf.int32)
   2656         # [num_samples, 3]
   2657         rays_origin_dr = dr.gather(self._mi_vec_t,
   2658                                    self._mi_tensor_t(sources_positions).array,
   2659                                    samples_tx_indices_dr)
   2660         rays_origin = mi_to_tf_tensor(rays_origin_dr, self._rdtype)
   2661         # Rays
   2662         ray = mi.Ray3f(o=rays_origin_dr, d=k_tx_dr)
   2663 
   2664         # Previous intersection point. Initialized to the transmitter position
   2665         # [num_samples, 3]
   2666         previous_int_point = rays_origin
   2667 
   2668         # Initializing the coverage map
   2669         # Add dummy row and columns to store the items that are out of the
   2670         # coverage map
   2671         # [num_tx, num_cells_y+1, num_cells_x+1]
   2672         cm = tf.zeros([num_tx, num_cells_y+1, num_cells_x+1],dtype=self._rdtype)
   2673 
   2674         if ris:
   2675             # Radii of curvatures are initialized to 0
   2676             # [num_samples, 2]
   2677             radii_curv = tf.zeros([num_samples, 2], dtype=self._rdtype)
   2678             # Principal directions of curvatures are represented in the GCS.
   2679             # Waves radiated by the transmitter are spherical, and therefore any
   2680             # vectors u,v such that (u,v,k) is an orthonormal basis and where k
   2681             # is the direction of propagation are principal directions of
   2682             # curvature.
   2683             # [num_samples, 3]
   2684             dir_curv_1, dir_curv_2 = gen_basis_from_z(k_tx, SolverBase.EPSILON)
   2685             dirs_curv = tf.stack([dir_curv_1, dir_curv_2], axis=1)
   2686             # Angular opening of the ray tube
   2687             # [num_samples]
   2688             angular_opening = tf.fill([num_samples],
   2689                                 tf.cast(4.*PI/samples_per_tx_float,
   2690                                         self._rdtype))
   2691         else:
   2692             # The following quantities are not used if RIS are disabled
   2693             radii_curv = None
   2694             dirs_curv = None
   2695             angular_opening = None
   2696 
   2697         # Offset to apply to the Mitsuba shape modeling RIS to get the
   2698         # corresponding objects ids
   2699         if len(self._scene.objects)>0:
   2700             ris_ind_offset = max(obj.object_id for obj in
   2701                                  self._scene.objects.values())
   2702         else:
   2703             ris_ind_offset = 0
   2704         # Because Mitsuba does not necessarily assign IDs starting from 1,
   2705         # we need to account for this offset
   2706         ris_mi_ids = mi_to_tf_tensor(ris_indices, tf.int32)
   2707         ris_ind_offset -= (tf.reduce_min(ris_mi_ids).numpy() - 1)
   2708 
   2709         for depth in tf.range(max_depth+1):
   2710 
   2711             ################################################
   2712             # Intersection test
   2713             ################################################
   2714 
   2715             # Intersect with scene
   2716             si_scene = self._mi_scene.ray_intersect(ray)
   2717 
   2718             # Intersect with the measurement plane
   2719             si_mp = meas_plane.ray_intersect(ray)
   2720 
   2721             # Intersect with RIS
   2722             # It is required to split the kernel as intersections are
   2723             # tested with another Mitsuba scene containing the RIS
   2724             if ris:
   2725                 si_ris_val, si_ris_t, ris_ind = self._ris_intersect(ris_objects,
   2726                                                                     ray, True)
   2727             else:
   2728                 si_ris_t = float("inf")
   2729                 si_ris_val = False
   2730 
   2731             hit_scene_dr = si_scene.is_valid() & (si_scene.t < si_ris_t)
   2732             hit_ris_dr = si_ris_val & (si_ris_t <= si_scene.t)
   2733 
   2734             # A ray is active if it interacted with the scene or a RIS
   2735             # [num_samples]
   2736             active_dr = hit_scene_dr | hit_ris_dr
   2737             # [num_samples]
   2738             hit_scene = mi_to_tf_tensor(hit_scene_dr, tf.bool)
   2739             hit_ris = mi_to_tf_tensor(hit_ris_dr, tf.bool)
   2740             active = mi_to_tf_tensor(active_dr, tf.bool)
   2741 
   2742             # Hit the measurement plane?
   2743             # An intersection with the coverage map is only valid if it was
   2744             # not obstructed
   2745             # [num_samples]
   2746             hit_mp_dr =  si_mp.is_valid()\
   2747                         & (si_mp.t < si_scene.t)\
   2748                         & (si_mp.t < si_ris_t)
   2749             # [num_samples]
   2750             hit_mp = mi_to_tf_tensor(hit_mp_dr, tf.bool)
   2751 
   2752             # Discard LoS if requested
   2753             # [num_samples]
   2754             hit_mp &= (los or (depth > 0))
   2755 
   2756             ################################################
   2757             # Initialize the electric field
   2758             ################################################
   2759 
   2760             # The field is initialized with the transmit field in the GCS
   2761             # at the first iteration for rays that either hit the coverage map
   2762             # or are active
   2763             if depth == 0:
   2764                 init_ray_dr = active_dr | si_mp.is_valid()
   2765                 init_ray = mi_to_tf_tensor(init_ray_dr, tf.bool)
   2766                 e_field, field_es, field_ep = self._init_e_field(init_ray,
   2767                 samples_tx_indices, k_tx, tx_rot_mat)
   2768 
   2769             ################################################
   2770             # Update the coverage map
   2771             ################################################
   2772             # Intersection point with the measurement plane
   2773             # [num_samples, 3]
   2774             mp_hit_point = ray.o + si_mp.t*ray.d
   2775             mp_hit_point = mi_to_tf_tensor(mp_hit_point, self._rdtype)
   2776 
   2777             cm = self._update_coverage_map(cm_center, cm_size,
   2778                 cm_cell_size, num_cells, rot_gcs_2_mp, cm_normal, tx_rot_mat,
   2779                 rx_rot_mat, precoding_vec, combining_vec, samples_tx_indices,
   2780                 e_field, field_es, field_ep, mp_hit_point, hit_mp, k_tx,
   2781                 previous_int_point, cm, ris, radii_curv, angular_opening)
   2782 
   2783             # If the maximum requested depth is reached, we stop, as we just
   2784             # updated the coverage map with the last requested contribution from
   2785             # the rays.
   2786             # We also stop if there is no remaining active ray.
   2787             if (depth == max_depth) or (not tf.reduce_any(active)):
   2788                 break
   2789 
   2790             #############################################
   2791             # Extract primitives and RIS that were hit by
   2792             # active rays.
   2793             #############################################
   2794 
   2795             # Extract the scene primitives that were hit
   2796             if dr.shape(self._shape_indices)[0] > 0: # Scene is not empty
   2797                 shape_i = dr.gather(mi.Int32, self._shape_indices,
   2798                             dr.reinterpret_array_v(mi.UInt32, si_scene.shape),
   2799                             hit_scene_dr)
   2800                 offsets = dr.gather(mi.Int32, self._prim_offsets, shape_i,
   2801                                     hit_scene_dr)
   2802                 scene_primitives = offsets + si_scene.prim_index
   2803             else: # Scene is empty
   2804                 scene_primitives = dr.zeros(mi.Int32, dr.shape(hit_scene_dr)[0])
   2805 
   2806             # Extract indices of RIS that were hit
   2807             if ris:
   2808                 ris_ind = ris_ind + ris_ind_offset
   2809             else:
   2810                 ris_ind = dr.zeros(mi.Int32, dr.shape(hit_scene_dr)[0])
   2811 
   2812             # Combine into a single array
   2813             # [num_samples]
   2814             primitives = dr.select(hit_scene_dr, scene_primitives, ris_ind)
   2815             primitives = dr.select(active_dr, primitives, -1)
   2816             primitives = mi_to_tf_tensor(primitives, tf.int32)
   2817 
   2818             # If diffraction is enabled, stores the primitives in LoS
   2819             # for sampling their wedges. These are needed to compute the
   2820             # coverage map for diffraction (not in this function).
   2821             if diffraction and (depth == 0):
   2822                 # [num_samples]
   2823                 los_primitives = dr.select(hit_scene_dr, scene_primitives, -1)
   2824                 los_primitives = mi_to_tf_tensor(los_primitives, tf.int32)
   2825 
   2826             # At this point, max_depth > 0 and there are still active rays.
   2827             # However, we can stop if neither reflection, scattering or
   2828             # reflection from RIS is enabled, as only these phenomena require to
   2829             # go further.
   2830             if not (reflection or scattering or ris):
   2831                 break
   2832 
   2833             #############################################
   2834             # Update the field.
   2835             # Only active rays are updated.
   2836             #############################################
   2837 
   2838             # Intersection point
   2839             # [num_samples, 3]
   2840             int_point = dr.select(hit_scene_dr,
   2841                                   ray.o + si_scene.t*ray.d,
   2842                                   ray.o + si_ris_t*ray.d)
   2843             int_point = mi_to_tf_tensor(int_point, self._rdtype)
   2844 
   2845             # Sample scattering/reflection phenomena.
   2846             # reflect_ind : [num_reflected_samples]
   2847             #   Indices of the rays that are reflected
   2848             #  scatter_ind : [num_scattered_samples]
   2849             #   Indices of the rays that are scattered
   2850             reflect_ind, scatter_ind = self._sample_interaction_phenomena(
   2851                                 hit_scene, int_point, primitives,
   2852                                 scattering_coefficient, reflection,
   2853                                 scattering)
   2854 
   2855             # Indices of the rays that hit RIS
   2856             # [num_ris_reflected_samples]
   2857             ris_reflect_ind = tf.where(hit_ris)[:,0]
   2858             updated_e_field = tf.zeros([0, e_field.shape[1], 2], self._dtype)
   2859             updated_field_es = tf.zeros([0, 3], self._rdtype)
   2860             updated_field_ep = tf.zeros([0, 3], self._rdtype)
   2861             updated_int_point = tf.zeros([0, 3], self._rdtype)
   2862             updated_k_r = tf.zeros([0, 3], self._rdtype)
   2863             normals = tf.zeros([0, 3], self._rdtype)
   2864             updated_samples_tx_indices = tf.zeros([0], tf.int32)
   2865             updated_k_tx = tf.zeros([0, 3], self._rdtype)
   2866             if ris:
   2867                 updated_radii_curv = tf.zeros([0, 2], self._rdtype)
   2868                 updated_dirs_curv = tf.zeros([0, 2, 3], self._rdtype)
   2869                 updated_ang_opening = tf.zeros([0], self._rdtype)
   2870 
   2871             if tf.shape(reflect_ind)[0] > 0:
   2872                 # ref_e_field : [num_reflected_samples, num_tx_patterns, 2]
   2873                 # ref_field_es : [num_reflected_samples, 3]
   2874                 # ref_field_ep : [num_reflected_samples, 3]
   2875                 # ref_int_point : [num_reflected_samples, 3]
   2876                 # ref_k_r : [num_reflected_samples, 3]
   2877                 # ref_n : [num_reflected_samples, 3]
   2878                 # ref_radii_curv : [num_reflected_samples, 2]
   2879                 # ref_dirs_curv : [num_reflected_samples, 2, 3]
   2880                 # ref_ang_opening : [num_reflected_samples]
   2881                 # ref_samples_tx_indices : [num_reflected_samples]
   2882                 # ref_k_tx : [num_reflected_samples, 3]
   2883                 ref_e_field, ref_field_es, ref_field_ep, ref_int_point,\
   2884                     ref_k_r, ref_n, ref_samples_tx_indices, ref_k_tx,\
   2885                     ref_radii_curv, ref_dirs_curv, ref_ang_opening\
   2886                         = self._apply_reflection(reflect_ind,
   2887                         int_point, previous_int_point, primitives, e_field,
   2888                         field_es, field_ep, samples_tx_indices, k_tx,
   2889                         etas, scattering_coefficient, scattering, ris,
   2890                         radii_curv, dirs_curv, angular_opening)
   2891 
   2892                 updated_e_field = tf.concat([updated_e_field, ref_e_field],
   2893                                             axis=0)
   2894                 updated_field_es = tf.concat([updated_field_es, ref_field_es],
   2895                                                 axis=0)
   2896                 updated_field_ep = tf.concat([updated_field_ep, ref_field_ep],
   2897                                                 axis=0)
   2898                 updated_int_point = tf.concat([updated_int_point,ref_int_point],
   2899                                                 axis=0)
   2900                 updated_k_r = tf.concat([updated_k_r, ref_k_r], axis=0)
   2901                 normals = tf.concat([normals, ref_n], axis=0)
   2902                 updated_samples_tx_indices =\
   2903                     tf.concat([updated_samples_tx_indices,
   2904                                ref_samples_tx_indices], axis=0)
   2905                 updated_k_tx = tf.concat([updated_k_tx, ref_k_tx], axis=0)
   2906                 if ris:
   2907                     updated_radii_curv = tf.concat([updated_radii_curv,
   2908                                                     ref_radii_curv], axis=0)
   2909                     updated_dirs_curv = tf.concat([updated_dirs_curv,
   2910                                                 ref_dirs_curv], axis=0)
   2911                     updated_ang_opening = tf.concat([updated_ang_opening,
   2912                                                     ref_ang_opening], axis=0)
   2913 
   2914             if tf.shape(scatter_ind)[0] > 0:
   2915                 # scat_e_field : [num_scattered_samples, num_tx_patterns, 2]
   2916                 # scat_field_es : [num_scattered_samples, 3]
   2917                 # scat_field_ep : [num_scattered_samples, 3]
   2918                 # scat_int_point : [num_scattered_samples, 3]
   2919                 # scat_k_r : [num_scattered_samples, 3]
   2920                 # scat_n : [num_scattered_samples, 3]
   2921                 # scat_radii_curv : [num_scattered_samples, 2]
   2922                 # scat_dirs_curv : [num_scattered_samples, 2, 3]
   2923                 # scat_ang_opening : [num_scattered_samples]
   2924                 # scat_samples_tx_indices : [num_scattered_samples]
   2925                 # scat_k_tx : [num_scattered_samples, 3]
   2926                 scat_e_field, scat_field_es, scat_field_ep, scat_int_point,\
   2927                     scat_k_r, scat_n, scat_samples_tx_indices, scat_k_tx,\
   2928                     scat_radii_curv, scat_dirs_curv, scat_ang_opening\
   2929                         = self._apply_scattering(scatter_ind,
   2930                         int_point, previous_int_point, primitives, e_field,
   2931                         field_es, field_ep, samples_tx_indices, k_tx,
   2932                         etas, scattering_coefficient, xpd_coefficient, alpha_r,
   2933                         alpha_i, lambda_, reflection, ris, radii_curv,
   2934                         dirs_curv, angular_opening)
   2935 
   2936                 updated_e_field = tf.concat([updated_e_field, scat_e_field],
   2937                                             axis=0)
   2938                 updated_field_es = tf.concat([updated_field_es, scat_field_es],
   2939                                                 axis=0)
   2940                 updated_field_ep = tf.concat([updated_field_ep, scat_field_ep],
   2941                                                 axis=0)
   2942                 updated_int_point = tf.concat([updated_int_point,
   2943                                                 scat_int_point], axis=0)
   2944                 updated_k_r = tf.concat([updated_k_r, scat_k_r], axis=0)
   2945                 normals = tf.concat([normals, scat_n], axis=0)
   2946                 updated_samples_tx_indices =\
   2947                     tf.concat([updated_samples_tx_indices,
   2948                                scat_samples_tx_indices], axis=0)
   2949                 updated_k_tx = tf.concat([updated_k_tx, scat_k_tx], axis=0)
   2950                 if ris:
   2951                     updated_radii_curv = tf.concat([updated_radii_curv,
   2952                                                     scat_radii_curv], axis=0)
   2953                     updated_dirs_curv = tf.concat([updated_dirs_curv,
   2954                                                 scat_dirs_curv], axis=0)
   2955                     updated_ang_opening = tf.concat([updated_ang_opening,
   2956                                                     scat_ang_opening], axis=0)
   2957 
   2958             if tf.shape(ris_reflect_ind)[0] > 0:
   2959                 # ris_e_field : [num_ris_reflected_samples, num_tx_patterns, 2]
   2960                 # ris_field_es : [num_ris_reflected_samples, 3]
   2961                 # ris_field_ep : [num_ris_reflected_samples, 3]
   2962                 # ris_int_point : [num_ris_reflected_samples, 3]
   2963                 # ris_k_r : [num_ris_reflected_samples, 3]
   2964                 # ris_n : [num_ris_reflected_samples, 3]
   2965                 # ris_samples_tx_indices : [num_ris_reflected_samples]
   2966                 # ris_k_tx : [num_ris_reflected_samples, 3]
   2967                 # ris_radii_curv : [num_ris_reflected_samples, 2]
   2968                 # ris_dirs_curv : [num_ris_reflected_samples, 2, 3]
   2969                 # ris_ang_opening : [num_ris_reflected_samples]
   2970                 # ris_samples_tx_indices : [num_ris_reflected_samples]
   2971                 # ris_k_tx : [num_ris_reflected_samples, 3]
   2972                 ris_e_field, ris_field_es, ris_field_ep, ris_int_point,\
   2973                  ris_k_r, ris_n, ris_samples_tx_indices, ris_k_tx,\
   2974                  ris_radii_curv, ris_dirs_curv, ris_ang_opening\
   2975                      = self._apply_ris_reflection(ris_reflect_ind,
   2976                         int_point, previous_int_point, primitives, e_field,
   2977                         field_es, field_ep, samples_tx_indices, k_tx,
   2978                         radii_curv, dirs_curv, angular_opening)
   2979                 updated_e_field = tf.concat([updated_e_field, ris_e_field],
   2980                                             axis=0)
   2981                 updated_field_es = tf.concat([updated_field_es, ris_field_es],
   2982                                                 axis=0)
   2983                 updated_field_ep = tf.concat([updated_field_ep, ris_field_ep],
   2984                                                 axis=0)
   2985                 updated_int_point = tf.concat([updated_int_point,
   2986                                                 ris_int_point], axis=0)
   2987                 updated_k_r = tf.concat([updated_k_r, ris_k_r], axis=0)
   2988                 normals = tf.concat([normals, ris_n], axis=0)
   2989                 updated_radii_curv = tf.concat([updated_radii_curv,
   2990                                                 ris_radii_curv], axis=0)
   2991                 updated_dirs_curv = tf.concat([updated_dirs_curv,
   2992                                             ris_dirs_curv], axis=0)
   2993                 updated_ang_opening = tf.concat([updated_ang_opening,
   2994                                                 ris_ang_opening], axis=0)
   2995                 updated_samples_tx_indices =\
   2996                         tf.concat([updated_samples_tx_indices,
   2997                                 ris_samples_tx_indices], axis=0)
   2998                 updated_k_tx = tf.concat([updated_k_tx, ris_k_tx], axis=0)
   2999 
   3000             e_field = updated_e_field
   3001             field_es = updated_field_es
   3002             field_ep = updated_field_ep
   3003             k_r = updated_k_r
   3004             int_point = updated_int_point
   3005             samples_tx_indices = updated_samples_tx_indices
   3006             k_tx = updated_k_tx
   3007             if ris:
   3008                 radii_curv = updated_radii_curv
   3009                 dirs_curv = updated_dirs_curv
   3010                 angular_opening = updated_ang_opening
   3011 
   3012             ###############################################
   3013             # Discard paths which path loss is below a
   3014             # threshold
   3015             ###############################################
   3016             # [num_samples]
   3017             e_field_en = tf.reduce_sum(tf.square(tf.abs(e_field)), axis=(1,2))
   3018             active = tf.greater(e_field_en, SolverCoverageMap.DISCARD_THRES)
   3019             if not tf.reduce_any(active):
   3020                 break
   3021             # [num_active_samples]
   3022             active_ind = tf.where(active)[:,0]
   3023             # [num_active_samples, ...]
   3024             e_field = tf.gather(e_field, active_ind, axis=0)
   3025             field_es = tf.gather(field_es, active_ind, axis=0)
   3026             field_ep = tf.gather(field_ep, active_ind, axis=0)
   3027             k_r = tf.gather(k_r, active_ind, axis=0)
   3028             int_point = tf.gather(int_point, active_ind, axis=0)
   3029             normals = tf.gather(normals, active_ind, axis=0)
   3030             samples_tx_indices = tf.gather(samples_tx_indices, active_ind,
   3031                                            axis=0)
   3032             k_tx = tf.gather(k_tx, active_ind, axis=0)
   3033             if ris:
   3034                 radii_curv = tf.gather(radii_curv, active_ind, axis=0)
   3035                 dirs_curv = tf.gather(dirs_curv, active_ind, axis=0)
   3036                 angular_opening = tf.gather(angular_opening, active_ind, axis=0)
   3037 
   3038             ###############################################
   3039             # Reflect or scatter the current ray
   3040             ###############################################
   3041 
   3042             # Spawn a new rays
   3043             # [num_active_samples, 3]
   3044             k_r_dr = self._mi_vec_t(k_r)
   3045             rays_origin_dr = self._mi_vec_t(int_point)
   3046             normals_dr = self._mi_vec_t(normals)
   3047             rays_origin_dr += SolverBase.EPSILON_OBSTRUCTION*normals_dr
   3048             ray = mi.Ray3f(o=rays_origin_dr, d=k_r_dr)
   3049             # Update previous intersection point
   3050             # [num_active_samples, 3]
   3051             previous_int_point = int_point
   3052 
   3053         #################################################
   3054         # Finalize the computation of the coverage map
   3055         #################################################
   3056 
   3057         # Scaling factor
   3058         cell_area = cm_cell_size[0]*cm_cell_size[1]
   3059         if ris:
   3060             cm_scaling = tf.square(self._scene.wavelength/(4.*PI))/cell_area
   3061         else:
   3062             cst = tf.cast(4.*PI*cell_area*samples_per_tx_float, self._rdtype)
   3063             cm_scaling = tf.square(self._scene.wavelength)/cst
   3064         cm_scaling = tf.cast(cm_scaling, self._rdtype)
   3065 
   3066         # Dump the dummy line and row and apply the scaling factor
   3067         # [num_tx, num_cells_y, num_cells_x]
   3068         cm = cm_scaling*cm[:,:num_cells_y,:num_cells_x]
   3069 
   3070         # For diffraction, we need only primitives in LoS
   3071         # [num_los_primitives]
   3072         if los_primitives is not None:
   3073             los_primitives,_ = tf.unique(los_primitives)
   3074 
   3075         return cm, los_primitives
   3076 
   3077     def _discard_obstructing_wedges(self, candidate_wedges, sources_positions):
   3078         r"""
   3079         Discard wedges for which the source is "inside" the wedge
   3080 
   3081         Input
   3082         ------
   3083         candidate_wedges : [num_candidate_wedges], int
   3084             Candidate wedges.
   3085             Entries correspond to wedges indices.
   3086 
   3087         sources_positions : [num_tx, 3], tf.float
   3088             Coordinates of the sources.
   3089 
   3090         Output
   3091         -------
   3092         diff_mask : [num_tx, num_candidate_wedges], tf.bool
   3093             Mask set to False for invalid wedges
   3094 
   3095         diff_wedges_ind : [num_candidate_wedges], tf.int
   3096             Indices of the wedges that interacted with the diffracted paths
   3097         """
   3098 
   3099         epsilon = tf.cast(SolverBase.EPSILON, self._rdtype)
   3100 
   3101         # [num_candidate_wedges, 3]
   3102         origins = tf.gather(self._wedges_origin, candidate_wedges)
   3103 
   3104         # Expand to broadcast with sources/targets and 0/n faces
   3105         # [1, num_candidate_wedges, 1, 3]
   3106         origins = tf.expand_dims(origins, axis=0)
   3107         origins = tf.expand_dims(origins, axis=2)
   3108 
   3109         # Normals
   3110         # [num_candidate_wedges, 2, 3]
   3111         # [:,0,:] : 0-face
   3112         # [:,1,:] : n-face
   3113         normals = tf.gather(self._wedges_normals, candidate_wedges)
   3114         # Expand to broadcast with the sources or targets
   3115         # [1, num_candidate_wedges, 2, 3]
   3116         normals = tf.expand_dims(normals, axis=0)
   3117 
   3118         # Expand to broadcast with candidate and 0/n faces wedges
   3119         # [num_tx, 1, 1, 3]
   3120         sources_positions = expand_to_rank(sources_positions, 4, 1)
   3121         # Sources vectors
   3122         # [num_tx, num_candidate_wedges, 1, 3]
   3123         u_t = sources_positions - origins
   3124 
   3125         # [num_tx, num_candidate_wedges, 2]
   3126         mask = dot(u_t, normals)
   3127         mask = tf.greater(mask, tf.fill(tf.shape(mask), epsilon))
   3128         # [num_tx, num_candidate_wedges]
   3129         mask = tf.reduce_any(mask, axis=2)
   3130 
   3131         # Discard wedges with no valid link
   3132         # [num_candidate_wedges]
   3133         valid_wedges = tf.where(tf.reduce_any(mask, axis=0))[:,0]
   3134         # [num_tx, num_candidate_wedges]
   3135         mask = tf.gather(mask, valid_wedges, axis=1)
   3136         # [num_candidate_wedges]
   3137         diff_wedges_ind = tf.gather(candidate_wedges, valid_wedges, axis=0)
   3138 
   3139         return mask, diff_wedges_ind
   3140 
   3141     def _sample_wedge_points(self, diff_mask, diff_wedges_ind, num_samples):
   3142         r"""
   3143 
   3144         Samples equally spaced candidate diffraction points on the candidate
   3145         wedges.
   3146 
   3147         The distance between two points is the cumulative length of the
   3148         candidate wedges divided by ``num_samples``, i.e., the density of
   3149         samples is the same for all wedges.
   3150 
   3151         The `num_samples` dimension of the output tensors is in general slighly
   3152         smaller than input `num_samples` because of roundings.
   3153 
   3154         Input
   3155         ------
   3156         diff_mask : [num_tx, num_samples], tf.bool
   3157             Mask set to False for invalid samples
   3158 
   3159         diff_wedges_ind : [num_candidate_wedges], int
   3160             Candidate wedges indices
   3161 
   3162         num_samples : int
   3163             Number of samples to shoot
   3164 
   3165         Output
   3166         ------
   3167         diff_mask : [num_tx, num_samples], tf.bool
   3168             Mask set to False for invalid wedges
   3169 
   3170         diff_wedges_ind : [num_samples], tf.int
   3171             Indices of the wedges that interacted with the diffracted paths
   3172 
   3173         diff_ells : [num_samples], tf.float
   3174             Positions of the diffraction points on the wedges.
   3175             These positions are given as an offset from the wedges origins.
   3176 
   3177         diff_vertex : [num_samples, 3], tf.float
   3178             Positions of the diffracted points in the GCS
   3179 
   3180         diff_num_samples_per_wedge : [num_samples], tf.int
   3181             For each sample, total mumber of samples that were sampled on the
   3182             same wedge
   3183         """
   3184 
   3185         zero_dot_five = tf.cast(0.5, self._rdtype)
   3186 
   3187         # [num_candidate_wedges]
   3188         wedges_length = tf.gather(self._wedges_length, diff_wedges_ind)
   3189         # Total length of the wedges
   3190         # ()
   3191         wedges_total_length = tf.reduce_sum(wedges_length)
   3192         # Spacing between the samples
   3193         # ()
   3194         delta_ell = wedges_total_length/tf.cast(num_samples,self._rdtype)
   3195         # Number of samples for each wedge
   3196         # [num_candidate_wedges]
   3197         samples_per_wedge = tf.math.divide_no_nan(wedges_length, delta_ell)
   3198         samples_per_wedge = tf.cast(tf.floor(samples_per_wedge), tf.int32)
   3199         # Maximum number of samples for a wedge
   3200         # tf.maximum() required for the case where samples_per_wedge is empty
   3201         # ()
   3202         max_samples_per_wedge = tf.maximum(tf.reduce_max(samples_per_wedge), 0)
   3203         # Sequence used to build the equally spaced samples on the wedges
   3204         # [max_samples_per_wedge]
   3205         cseq = tf.cumsum(tf.ones([max_samples_per_wedge], dtype=tf.int32)) - 1
   3206         # [1, max_samples_per_wedge]
   3207         cseq = tf.expand_dims(cseq, axis=0)
   3208         # [num_candidate_wedges, 1]
   3209         samples_per_wedge_ = tf.expand_dims(samples_per_wedge, axis=1)
   3210         # [num_candidate_wedges, max_samples_per_wedge]
   3211         ells_i = tf.where(cseq < samples_per_wedge_, cseq,
   3212                           max_samples_per_wedge)
   3213         # Compute the relative offset of the diffraction point on the wedge
   3214         # [num_candidate_wedges, max_samples_per_wedge]
   3215         ells = (tf.cast(ells_i, self._rdtype) + zero_dot_five)*delta_ell
   3216         # [num_candidate_wedges x max_samples_per_wedge]
   3217         ells_i = tf.reshape(ells_i, [-1])
   3218         ells = tf.reshape(ells, [-1])
   3219         # Extract only relevant indices
   3220         # [num_samples]. Smaller but close than input num_samples in general
   3221         # because of previous floor() op
   3222         ells = tf.gather(ells, tf.where(ells_i < max_samples_per_wedge))[:,0]
   3223 
   3224         # Compute the corresponding points coordinates in the GCS
   3225         # Wedges origin
   3226         # [num_candidate_wedges, 3]
   3227         origins = tf.gather(self._wedges_origin, diff_wedges_ind)
   3228         # Wedges directions
   3229         # [num_candidate_wedges, 3]
   3230         e_hat = tf.gather(self._wedges_e_hat, diff_wedges_ind)
   3231         # Match each sample to the corresponding wedge origin and vector
   3232         # First, generate the indices for the gather op
   3233         # ()
   3234         num_candidate_wedges = diff_wedges_ind.shape[0]
   3235         # [num_candidate_wedges]
   3236         gather_ind = tf.range(num_candidate_wedges)
   3237         gather_ind = tf.expand_dims(gather_ind, axis=1)
   3238         # [num_candidate_wedges, max_samples_per_wedge]
   3239         gather_ind = tf.where(cseq < samples_per_wedge_, gather_ind,
   3240                               num_candidate_wedges)
   3241         # [num_candidate_wedges x max_samples_per_wedge]
   3242         gather_ind = tf.reshape(gather_ind, [-1])
   3243         # [num_samples]
   3244         gather_ind = tf.gather(gather_ind,
   3245                                tf.where(ells_i < max_samples_per_wedge))[:,0]
   3246         # [num_samples, 3]
   3247         origins = tf.gather(origins, gather_ind, axis=0)
   3248         e_hat = tf.gather(e_hat, gather_ind, axis=0)
   3249         # [num_samples]
   3250         diff_wedges_ind = tf.gather(diff_wedges_ind, gather_ind, axis=0)
   3251         # [num_tx, num_samples]
   3252         diff_mask = tf.gather(diff_mask, gather_ind, axis=1)
   3253         # Positions of the diffracted points in the GCS
   3254         # [num_samples, 3]
   3255         diff_points = origins + tf.expand_dims(ells, axis=1)*e_hat
   3256         # Number of samples per wedge
   3257         # [num_samples]
   3258         samples_per_wedge = tf.gather(samples_per_wedge, gather_ind, axis=0)
   3259 
   3260         return diff_mask, diff_wedges_ind, ells, diff_points, samples_per_wedge
   3261 
   3262     def _test_tx_visibility(self, diff_mask, diff_wedges_ind, diff_ells,
   3263                             diff_vertex, diff_num_samples_per_wedge,
   3264                             sources_positions):
   3265         r"""
   3266         Test for blockage between the diffraction points and the transmitters.
   3267         Blocked samples are discarded.
   3268 
   3269         Input
   3270         ------
   3271         diff_mask : [num_tx, num_samples], tf.bool
   3272             Mask set to False for invalid samples
   3273 
   3274         diff_wedges_ind : [num_samples], tf.int
   3275             Indices of the wedges that interacted with the diffracted paths
   3276 
   3277         diff_ells : [num_samples], tf.float
   3278             Positions of the diffraction points on the wedges.
   3279             These positions are given as an offset from the wedges origins.
   3280 
   3281         diff_vertex : [num_samples, 3], tf.float
   3282             Positions of the diffracted points in the GCS
   3283 
   3284         diff_num_samples_per_wedge : [num_samples], tf.int
   3285                 For each sample, total mumber of samples that were sampled on
   3286                 the same wedge
   3287 
   3288         sources_positions : [num_tx, 3], tf.float
   3289             Positions of the transmitters.
   3290 
   3291         Output
   3292         -------
   3293         diff_mask : [num_tx, num_samples], tf.bool
   3294             Mask set to False for invalid wedges
   3295 
   3296         diff_wedges_ind : [num_samples], tf.int
   3297             Indices of the wedges that interacted with the diffracted paths
   3298 
   3299         diff_ells : [num_samples], tf.float
   3300             Positions of the diffraction points on the wedges.
   3301             These positions are given as an offset from the wedges origins.
   3302 
   3303         diff_vertex : [num_samples, 3], tf.float
   3304             Positions of the diffracted points in the GCS
   3305 
   3306         diff_num_samples_per_wedge : [num_samples], tf.int
   3307                 For each sample, total mumber of samples that were sampled on
   3308                 the same wedge
   3309         """
   3310 
   3311         num_tx = sources_positions.shape[0]
   3312         num_samples = diff_vertex.shape[0]
   3313 
   3314         # [num_tx, 1, 3]
   3315         sources_positions = tf.expand_dims(sources_positions, axis=1)
   3316         # [1, num_samples, 3]
   3317         wedges_diff_points_ = tf.expand_dims(diff_vertex, axis=0)
   3318         # Ray directions and maximum distance for obstruction test
   3319         # ray_dir : [num_tx, num_samples, 3]
   3320         # maxt : [num_tx, num_samples]
   3321         ray_dir,maxt = normalize(sources_positions - wedges_diff_points_)
   3322         # Ray origins
   3323         # [num_tx, num_samples, 3]
   3324         ray_org = tf.tile(wedges_diff_points_, [num_tx, 1, 1])
   3325 
   3326         # Test for obstruction
   3327         # [num_tx, num_samples]
   3328         ray_org = tf.reshape(ray_org, [-1,3])
   3329         ray_dir = tf.reshape(ray_dir, [-1,3])
   3330         maxt = tf.reshape(maxt, [-1])
   3331         invalid = self._test_obstruction(ray_org, ray_dir, maxt)
   3332         invalid = tf.reshape(invalid, [num_tx, num_samples])
   3333 
   3334         # Remove discarded paths
   3335         # [num_tx, num_samples]
   3336         diff_mask = tf.logical_and(diff_mask, ~invalid)
   3337         # Discard samples with no valid link
   3338         # [num_candidate_wedges]
   3339         valid_samples = tf.where(tf.reduce_any(diff_mask, axis=0))[:,0]
   3340         # [num_tx, num_samples]
   3341         diff_mask = tf.gather(diff_mask, valid_samples, axis=1)
   3342         # [num_samples]
   3343         diff_wedges_ind = tf.gather(diff_wedges_ind, valid_samples, axis=0)
   3344         # [num_samples]
   3345         diff_vertex = tf.gather(diff_vertex, valid_samples,
   3346                                        axis=0)
   3347         # [num_samples]
   3348         diff_ells = tf.gather(diff_ells, valid_samples, axis=0)
   3349         # [num_samples]
   3350         diff_num_samples_per_wedge = tf.gather(diff_num_samples_per_wedge,
   3351                                                valid_samples, axis=0)
   3352 
   3353         return diff_mask, diff_wedges_ind, diff_ells, diff_vertex,\
   3354             diff_num_samples_per_wedge
   3355 
   3356     def _sample_diff_angles(self, diff_wedges_ind):
   3357         r"""
   3358         Samples angles of diffracted ray on the diffraction cone
   3359 
   3360         Input
   3361         ------
   3362         diff_wedges_ind : [num_samples], tf.int
   3363             Indices of the wedges that interacted with the diffracted paths
   3364 
   3365         Output
   3366         -------
   3367         diff_phi : [num_samples], tf.float
   3368             Sampled angles of diffracted rays on the diffraction cone
   3369         """
   3370 
   3371         num_samples = diff_wedges_ind.shape[0]
   3372 
   3373         # [num_samples, 2, 3]
   3374         normals = tf.gather(self._wedges_normals, diff_wedges_ind,  axis=0)
   3375 
   3376         # Compute the wedges angle
   3377         # [num_samples]
   3378         cos_wedges_angle = dot(normals[:,0,:],normals[:,1,:], clip=True)
   3379         wedges_angle = PI + tf.math.acos(cos_wedges_angle)
   3380 
   3381         # Uniformly sample angles for shooting rays on the diffraction cone
   3382         # [num_samples]
   3383         phis = config.tf_rng.uniform([num_samples],
   3384                                      minval=tf.zeros_like(wedges_angle),
   3385                                      maxval=wedges_angle,
   3386                                      dtype=self._rdtype)
   3387 
   3388         return phis
   3389 
    def _shoot_diffracted_rays(self, diff_mask, diff_wedges_ind, diff_ells,
                               diff_vertex, diff_num_samples_per_wedge,
                               diff_phi, sources_positions, meas_plane):
        r"""
        Shoots the diffracted rays and computes their intersection with the
        coverage map, if any.

        Each diffracted ray leaves its diffraction point at the same angle,
        with respect to the edge vector ``e_hat``, as the incident ray from
        the transmitter. Its azimuth on the diffraction cone is ``diff_phi``.
        A (transmitter, sample) link is invalidated in any of these cases:

        * the incident ray is (nearly) aligned with, or perpendicular to, the
          edge;
        * the diffracted ray misses the measurement plane;
        * the diffracted ray is obstructed before reaching the plane.

        Samples with no valid link left are discarded.

        Input
        ------
        diff_mask : [num_tx, num_samples], tf.bool
            Mask set to False for invalid samples

        diff_wedges_ind : [num_samples], tf.int
            Indices of the wedges that interacted with the diffracted paths

        diff_ells : [num_samples], tf.float
            Positions of the diffraction points on the wedges.
            These positions are given as an offset from the wedges origins.

        diff_vertex : [num_samples, 3], tf.float
            Positions of the diffracted points in the GCS

        diff_num_samples_per_wedge : [num_samples], tf.int
            For each sample, total number of samples that were sampled on the
            same wedge

        diff_phi : [num_samples], tf.float
            Sampled angles of diffracted rays on the diffraction cone

        sources_positions : [num_tx, 3], tf.float
            Positions of the transmitters.

        meas_plane : mi.Shape
            Mitsuba rectangle defining the measurement plane

        Output
        -------
        In all outputs, ``num_samples`` denotes the number of samples kept,
        i.e., samples with at least one valid link.

        diff_mask : [num_tx, num_samples], tf.bool
            Mask set to False for invalid samples

        diff_wedges_ind : [num_samples], tf.int
            Indices of the wedges that interacted with the diffracted paths

        diff_ells : [num_samples], tf.float
            Positions of the diffraction points on the wedges.
            These positions are given as an offset from the wedges origins.

        diff_phi : [num_samples], tf.float
            Sampled angles of diffracted rays on the diffraction cone

        diff_vertex : [num_samples, 3], tf.float
            Positions of the diffracted points in the GCS

        diff_num_samples_per_wedge : [num_samples], tf.int
            For each sample, total number of samples that were sampled on the
            same wedge

        diff_hit_points : [num_tx, num_samples, 3], tf.float
            Positions of the intersection of the diffracted rays and coverage
            map. Entries of invalid links equal the diffraction point itself,
            because their distance is zeroed.

        diff_cone_angle : [num_tx, num_samples], tf.float
            Angle between e_hat and the diffracted ray direction.
            Takes value in (0,pi).
        """

        # [num_tx, 1, 3]
        sources_positions = tf.expand_dims(sources_positions, axis=1)
        # [1, num_samples, 3]
        diff_points_ = tf.expand_dims(diff_vertex, axis=0)
        # Unit directions of the incident rays, from each transmitter to each
        # diffraction point. The distance returned by normalize() is unused.
        # ray_dir : [num_tx, num_samples, 3]
        ray_dir,_ = normalize(diff_points_ - sources_positions)

        # Edge vector
        # [num_samples, 3]
        e_hat = tf.gather(self._wedges_e_hat, diff_wedges_ind)
        # [1, num_samples, 3]
        e_hat_ = tf.expand_dims(e_hat, axis=0)
        # Angles between the incident ray and wedge.
        # This angle is not beta_0. It takes values in (0,pi), and is the angle
        # with respect to e_hat in which to shoot the diffracted ray.
        # [num_tx, num_samples]
        theta_shoot_dir = acos_diff(dot(ray_dir, e_hat_))

        # Discard paths for which the incident ray is aligned or perpendicular
        # to the edge (within SolverBase.EPSILON)
        # [num_tx, num_samples, 3], one entry per condition
        invalid_angle = tf.stack([
            theta_shoot_dir < SolverBase.EPSILON,
            theta_shoot_dir > PI - SolverBase.EPSILON,
            tf.abs(theta_shoot_dir - 0.5*PI) < SolverBase.EPSILON],
                                 axis=-1)
        # [num_tx, num_samples]
        invalid_angle = tf.reduce_any(invalid_angle, axis=-1)

        num_tx = diff_mask.shape[0]

        # Build the direction of the diffracted ray in the LCS
        # The LCS is defined by (t_0_hat, n0_hat, e_hat)

        # Direction of the diffracted ray: polar angle theta_shoot_dir
        # measured from e_hat, azimuth diff_phi (assumes r_hat(theta, phi) is
        # the usual spherical unit vector)
        # [1, num_samples]
        phis = tf.expand_dims(diff_phi, axis=0)
        # [num_tx, num_samples, 3]
        diff_dir = r_hat(theta_shoot_dir, phis)

        # Matrix for going from the LCS to the GCS

        # Normals to face 0
        # [num_samples, 2, 3]
        normals = tf.gather(self._wedges_normals, diff_wedges_ind, axis=0)
        # [num_samples, 3]
        normals = normals[:,0,:]
        # Tangent vector t_hat
        # [num_samples, 3]
        t_hat = cross(normals, e_hat)
        # Matrix for going from LCS to GCS; the LCS axes are its columns
        # [num_samples, 3, 3]
        lcs2gcs = tf.stack([t_hat, normals, e_hat], axis=-1)
        # [1, num_samples, 3, 3]
        lcs2gcs = tf.expand_dims(lcs2gcs, axis=0)

        # Direction of diffracted rays in GCS

        # [num_tx, num_samples, 3]
        diff_dir = tf.linalg.matvec(lcs2gcs, diff_dir)

        # Origin of the diffracted rays

        # [num_tx, num_samples, 3]
        diff_points_ = tf.tile(diff_points_, [num_tx, 1, 1])

        # Test of intersection of the diffracted rays with the measurement
        # plane. Rays are flattened to [num_tx x num_samples] and converted to
        # Mitsuba vectors.
        mi_diff_dir = self._mi_vec_t(tf.reshape(diff_dir, [-1, 3]))
        mi_diff_points = self._mi_vec_t(tf.reshape(diff_points_, [-1, 3]))
        rays = mi.Ray3f(o=mi_diff_points, d=mi_diff_dir)
        # Intersect with the coverage map
        si_mp = meas_plane.ray_intersect(rays)

        # Check for obstruction between the diffraction point and the hit
        # point on the measurement plane (si_mp.t bounds the test)
        # [num_tx x num_samples]
        obstructed = self._test_obstruction(mi_diff_points, mi_diff_dir,
                                            si_mp.t)

        # Mask invalid rays, i.e., rays that are obstructed or do not hit
        # the measurement plane, and discard rays that are invalid for all TXs

        # Distance to the measurement plane; checked with is_inf() below to
        # detect rays that miss the plane
        # [num_tx x num_samples]
        maxt = mi_to_tf_tensor(si_mp.t, dtype=self._rdtype)
        # [num_tx x num_samples]
        invalid = tf.logical_or(tf.math.is_inf(maxt), obstructed)
        # [num_tx, num_samples]
        invalid = tf.reshape(invalid, [num_tx, -1])
        # [num_tx, num_samples]
        invalid = tf.logical_or(invalid, invalid_angle)
        # [num_tx, num_samples]
        diff_mask = tf.logical_and(diff_mask, ~invalid)
        # Discard samples with no valid link
        # [num_valid_samples]
        valid_samples = tf.where(tf.reduce_any(diff_mask, axis=0))[:,0]
        # [num_tx, num_samples]
        diff_mask = tf.gather(diff_mask, valid_samples, axis=1)
        # [num_samples]
        diff_wedges_ind = tf.gather(diff_wedges_ind, valid_samples, axis=0)
        # [num_samples]
        diff_ells = tf.gather(diff_ells, valid_samples, axis=0)
        # [num_samples]
        diff_phi = tf.gather(diff_phi, valid_samples, axis=0)
        # [num_tx, num_samples]
        theta_shoot_dir = tf.gather(theta_shoot_dir, valid_samples, axis=1)
        # [num_samples]
        diff_num_samples_per_wedge = tf.gather(diff_num_samples_per_wedge,
                                               valid_samples, axis=0)

        # Compute intersection point with the coverage map
        # [num_tx, num_samples]
        maxt = tf.reshape(maxt, [num_tx, -1])
        # [num_tx, num_samples]
        maxt = tf.gather(maxt, valid_samples, axis=1)
        # Zeros invalid samples to avoid numeric issues: their distance may
        # be inf, and inf*diff_dir would yield inf/nan hit points
        # [num_tx, num_samples]
        maxt = tf.where(diff_mask, maxt, tf.zeros_like(maxt))
        # [num_tx, num_samples, 1]
        maxt = tf.expand_dims(maxt, -1)
        # [num_tx, num_samples, 3]
        diff_dir = tf.gather(diff_dir, valid_samples, axis=1)
        # [num_samples, 3]
        diff_vertex = tf.gather(diff_vertex, valid_samples, axis=0)
        # [num_tx, num_samples, 3]
        diff_hit_points = tf.expand_dims(diff_vertex, axis=0) + maxt*diff_dir

        # theta_shoot_dir is returned as diff_cone_angle
        return diff_mask, diff_wedges_ind, diff_ells, diff_phi,\
            diff_vertex, diff_num_samples_per_wedge, diff_hit_points,\
                theta_shoot_dir
   3588 
    def _compute_samples_weights(self, cm_center, cm_orientation,
        sources_positions, diff_wedges_ind, diff_ells, diff_phi,
        diff_cone_angle):
        r"""
        Computes the weights for averaging the field powers of the samples to
        compute the Monte Carlo estimate of the integral of the diffracted field
        power over the measurement plane.

        These weights are required as the measurement plane is parametrized by
        the angle on the diffraction cones (phi) and position on the wedges
        (ell). Each weight is the norm of the cross product of the partial
        derivatives of that parametrization with respect to ell and phi.

        Input
        ------
        cm_center : [3], tf.float
            Center of the coverage map

        cm_orientation : [3], tf.float
            Orientation of the coverage map, given as rotation angles
            (z, y, x) in this order

        sources_positions : [num_tx, 3], tf.float
            Coordinates of the sources

        diff_wedges_ind : [num_samples], tf.int
            Indices of the wedges that interacted with the diffracted paths

        diff_ells : [num_samples], tf.float
            Positions of the diffraction points on the wedges.
            These positions are given as an offset from the wedges origins

        diff_phi : [num_samples], tf.float
            Sampled angles of diffracted rays on the diffraction cone

        diff_cone_angle : [num_tx, num_samples], tf.float
            Angle between e_hat and the diffracted ray direction.
            Takes value in (0,pi).

        Output
        -------
        diff_samples_weights : [num_tx, num_samples], tf.float
            Weights for averaging the field powers of the samples.
            Infinite weights are replaced by zero.
        """
        cos = tf.math.cos
        sin = tf.math.sin

        # [1, 1, 3]
        cm_center = expand_to_rank(cm_center, 3, 0)
        # [num_tx, 1, 3]
        sources_positions = tf.expand_dims(sources_positions, axis=1)

        # Normal to the coverage map: third column of the rotation matrix
        # Rz(cmo_z) Ry(cmo_y) Rx(cmo_x), i.e., the rotated z-axis
        # [3]
        cmo_z = cm_orientation[0]
        cmo_y = cm_orientation[1]
        cmo_x = cm_orientation[2]
        cm_normal = tf.stack([
            cos(cmo_z)*sin(cmo_y)*cos(cmo_x) + sin(cmo_z)*sin(cmo_x),
            sin(cmo_z)*sin(cmo_y)*cos(cmo_x) - cos(cmo_z)*sin(cmo_x),
            cos(cmo_y)*cos(cmo_x)],
                             axis=0)
        # [1, 1, 3]
        cm_normal = expand_to_rank(cm_normal, 3, 0)


        # Origins
        # [num_samples, 3]
        origins = tf.gather(self._wedges_origin, diff_wedges_ind)
        # [1, num_samples, 3]
        origins = tf.expand_dims(origins, axis=0)

        # Signed distance of the wedge origin to the measurement plane,
        # measured along cm_normal
        # [1, num_samples]
        wedge_cm_dist = dot(cm_center - origins, cm_normal)

        # Edges vectors
        # [num_samples, 3]
        e_hat = tf.gather(self._wedges_e_hat, diff_wedges_ind)

        # Normals to face 0
        # [num_samples, 2, 3]
        normals = tf.gather(self._wedges_normals, diff_wedges_ind, axis=0)
        # [num_samples, 3]
        normals = normals[:,0,:]
        # Tangent vector t_hat
        # [num_samples, 3]
        t_hat = cross(normals, e_hat)
        # Matrix for going from GCS to LCS (the LCS axes are its rows)
        # [num_samples, 3, 3]
        gcs2lcs = tf.stack([t_hat, normals, e_hat], axis=-2)
        # [1, num_samples, 3, 3]
        gcs2lcs = tf.expand_dims(gcs2lcs, axis=0)
        # Normal in LCS
        # [1, num_samples, 3]
        cm_normal = tf.linalg.matvec(gcs2lcs, cm_normal)

        # Projections of the transmitters on the wedges
        # [1, num_samples, 3]
        e_hat = tf.expand_dims(e_hat, axis=0)
        # Signed offset, along e_hat, of each transmitter's projection on the
        # edge line relative to the wedge origin (e_hat assumed unit-norm)
        # [num_tx, num_samples]
        tx_proj_org_dist = dot(sources_positions - origins, e_hat)
        # [num_tx, num_samples, 1]
        tx_proj_org_dist_ = tf.expand_dims(tx_proj_org_dist, axis=2)

        # Position of the sources projections on the wedges
        # [num_tx, num_samples, 3]
        tx_proj_pos = origins + tx_proj_org_dist_*e_hat
        # Distance of transmitters to wedges
        # [num_tx, num_samples]
        tx_wedge_dist = tf.linalg.norm(tx_proj_pos - sources_positions, axis=-1)

        # Building the derivatives of the parametrization of the intersection
        # of the diffraction cone and measurement plane.
        # The expressions below build v1/q as a function of (ell, phi). By
        # derivation, this is the hit point on the measurement plane, in
        # the wedge LCS and relative to the wedge origin
        # (NOTE(review): confirm).
        # ds_dl and ds_dphi are its partial derivatives via the quotient
        # rule d(v1/q) = dv1/q - v1*dq/q^2, with
        #   dq/dell = u*n_z, dv1/dell = v2,
        #   dq/dphi = s*tx_wedge_dist*w, dv1/dphi = v3.
        # [1, num_samples]
        diff_phi = tf.expand_dims(diff_phi, axis=0)
        # [1, num_samples]
        diff_ells = tf.expand_dims(diff_ells, axis=0)

        # [1, num_samples]
        cos_phi = cos(diff_phi)
        # [1, num_samples]
        sin_phi = sin(diff_phi)
        # Projection of the LCS plane normal on the azimuth direction phi
        # [1, num_samples]
        xy_dot = cm_normal[...,0]*cos_phi + cm_normal[...,1]*sin_phi
        # Offset along the edge between the sample and the TX projection
        # [num_tx, num_samples]
        ell_min_d = diff_ells - tx_proj_org_dist
        # Sign of that offset
        # [num_tx, num_samples]
        u = tf.math.sign(ell_min_d)
        # [num_tx, num_samples]
        ell_min_d = tf.math.abs(ell_min_d)
        # +1 if the diffracted ray points towards +e_hat (cone angle < pi/2),
        # -1 otherwise
        # [num_tx, num_samples]
        s = tf.where(diff_cone_angle < 0.5*PI,
                     tf.ones_like(diff_cone_angle),
                     -tf.ones_like(diff_cone_angle))
        # Common denominator of the parametrization; divide_no_nan maps
        # q == 0 to a zero contribution instead of inf/nan
        # [num_tx, num_samples]
        q = s*tx_wedge_dist*xy_dot + cm_normal[...,2]*ell_min_d
        q_square = tf.square(q)
        inv_q = tf.math.divide_no_nan(tf.ones_like(q), q)
        # [1, num_samples]
        big_d_min_lz = wedge_cm_dist - diff_ells*cm_normal[...,2]

        # Numerator of the parametrization
        # [num_tx, num_samples, 3]
        v1 = tf.stack([
                s*big_d_min_lz*tx_wedge_dist*cos_phi,
                s*big_d_min_lz*tx_wedge_dist*sin_phi,
                wedge_cm_dist*ell_min_d + s*diff_ells*tx_wedge_dist*xy_dot],
                      axis=-1)
        # Derivative of v1 with respect to ell
        # [num_tx, num_samples, 3]
        v2 = tf.stack([-s*cm_normal[...,2]*tx_wedge_dist*cos_phi,
                       -s*cm_normal[...,2]*tx_wedge_dist*sin_phi,
                       u*wedge_cm_dist + s*tx_wedge_dist*xy_dot],
                      axis=-1)
        # Derivative with respect to ell
        # [num_tx, num_samples, 3]
        ds_dl = tf.expand_dims(tf.math.divide_no_nan(-u*cm_normal[...,2],
                                                      q_square), axis=-1)*v1
        ds_dl = ds_dl + tf.expand_dims(inv_q, axis=-1)*v2

        # Derivative with respect to phi
        # Derivative of xy_dot with respect to phi
        # [1, num_samples]
        w = -cm_normal[...,0]*sin_phi + cm_normal[...,1]*cos_phi
        # Derivative of v1 with respect to phi
        # [num_tx, num_samples, 3]
        v3 = tf.stack([-s*big_d_min_lz*tx_wedge_dist*sin_phi,
                       s*big_d_min_lz*tx_wedge_dist*cos_phi,
                       s*diff_ells*tx_wedge_dist*w],
                      axis=-1)
        # [num_tx, num_samples, 3]
        ds_dphi = tf.expand_dims(tf.math.divide_no_nan(
            -s*tx_wedge_dist*w, q_square), axis=-1)*v1
        ds_dphi = ds_dphi + tf.expand_dims(inv_q, axis=-1)*v3

        # Weighting: area element |ds_dl x ds_dphi| of the parametrization.
        # Infinite values are replaced by zero.
        # [num_tx, num_samples]
        diff_samples_weights = tf.linalg.norm(cross(ds_dl, ds_dphi), axis=-1)
        diff_samples_weights = tf.where(tf.math.is_inf(diff_samples_weights),
                                        tf.zeros((), self._rdtype),
                                        diff_samples_weights)

        return diff_samples_weights
   3767 
   3768     def _compute_diffracted_path_power(self,
   3769                                        sources_positions,
   3770                                        sources_orientations,
   3771                                        rx_orientation,
   3772                                        combining_vec,
   3773                                        precoding_vec,
   3774                                        diff_mask,
   3775                                        diff_wedges_ind,
   3776                                        diff_vertex,
   3777                                        diff_hit_points,
   3778                                        relative_permittivity,
   3779                                        scattering_coefficient):
   3780         """
   3781         Computes the power of the diffracted paths.
   3782 
   3783         Input
   3784         ------
   3785         sources_positions : [num_tx, 3], tf.float
   3786             Positions of the transmitters.
   3787 
   3788         sources_orientations : [num_tx, 3], tf.float
   3789             Orientations of the sources.
   3790 
   3791         rx_orientation : [3], tf.float
   3792             Orientation of the receiver.
   3793             This is used to compute the antenna response and antenna pattern
   3794             for an imaginary receiver located on the coverage map.
   3795 
   3796         combining_vec : [num_rx_ant], tf.complex
   3797             Combining vector.
   3798             If set to `None`, then no combining is applied, and
   3799             the energy received by all antennas is summed.
   3800 
   3801         precoding_vec : [num_tx or 1, num_tx_ant], tf.complex
   3802             Precoding vectors of the transmitters
   3803 
   3804         diff_mask : [num_tx, num_samples], tf.bool
   3805             Mask set to False for invalid samples
   3806 
   3807         diff_wedges_ind : [num_samples], tf.int
   3808             Indices of the wedges that interacted with the diffracted paths
   3809 
   3810         diff_vertex : [num_samples, 3], tf.float
   3811             Positions of the diffracted points in the GCS
   3812 
   3813         diff_hit_points : [num_tx, num_samples, 3], tf.float
   3814             Positions of the intersection of the diffracted rays and coverage
   3815             map
   3816 
   3817         relative_permittivity : [num_shape], tf.complex
   3818             Tensor containing the complex relative permittivity of all objects
   3819 
   3820         scattering_coefficient : [num_shape], tf.float
   3821             Tensor containing the scattering coefficients of all objects
   3822 
   3823         Output
   3824         ------
   3825         diff_samples_power : [num_tx, num_samples], tf.float
   3826             Powers of the samples of diffracted rays.
   3827         """
   3828 
   3829         def f(x):
   3830             """F(x) Eq.(88) in [ITUR_P526]
   3831             """
   3832             sqrt_x = tf.sqrt(x)
   3833             sqrt_pi_2 = tf.cast(tf.sqrt(PI/2.), x.dtype)
   3834 
   3835             # Fresnel integral
   3836             arg = sqrt_x/sqrt_pi_2
   3837             s = tf.math.special.fresnel_sin(arg)
   3838             c = tf.math.special.fresnel_cos(arg)
   3839             f = tf.complex(s, c)
   3840 
   3841             zero = tf.cast(0, x.dtype)
   3842             one = tf.cast(1, x.dtype)
   3843             two = tf.cast(2, f.dtype)
   3844             factor = tf.complex(sqrt_pi_2*sqrt_x, zero)
   3845             factor = factor*tf.exp(tf.complex(zero, x))
   3846             res =  tf.complex(one, one) - two*f
   3847 
   3848             return factor* res
   3849 
   3850         wavelength = self._scene.wavelength
   3851         k = 2.*PI/wavelength
   3852 
   3853         # On CPU, indexing with -1 does not work. Hence we replace -1 by 0.
   3854         # This makes no difference on the resulting paths as such paths
   3855         # are not flaged as active.
   3856         # [num_samples]
   3857         valid_wedges_idx = tf.where(diff_wedges_ind == -1, 0, diff_wedges_ind)
   3858 
   3859         # [num_tx, 1, 3]
   3860         sources_positions = tf.expand_dims(sources_positions, axis=1)
   3861 
   3862         # Normals
   3863         # [num_samples, 2, 3]
   3864         normals = tf.gather(self._wedges_normals, valid_wedges_idx, axis=0)
   3865 
   3866         # Compute the wedges angle
   3867         # [num_samples]
   3868         cos_wedges_angle = dot(normals[...,0,:],normals[...,1,:], clip=True)
   3869         wedges_angle = PI - tf.math.acos(cos_wedges_angle)
   3870         n = (2.*PI-wedges_angle)/PI
   3871         # [1, num_samples]
   3872         n = tf.expand_dims(n, axis=0)
   3873 
   3874         # [num_samples, 3]
   3875         e_hat = tf.gather(self._wedges_e_hat, valid_wedges_idx)
   3876         # [1, num_samples, 3]
   3877         e_hat = tf.expand_dims(e_hat, axis=0)
   3878 
   3879         # Extract surface normals
   3880         # [num_samples, 3]
   3881         n_0_hat = normals[:,0,:]
   3882         # [1, num_samples, 3]
   3883         n_0_hat = tf.expand_dims(n_0_hat, axis=0)
   3884         # [num_samples, 3]
   3885         n_n_hat = normals[:,1,:]
   3886         # [1, num_samples, 3]
   3887         n_n_hat = tf.expand_dims(n_n_hat, axis=0)
   3888 
   3889         # Relative permitivities
   3890         # [num_samples, 2]
   3891         objects_indices = tf.gather(self._wedges_objects, valid_wedges_idx,
   3892                                     axis=0)
   3893 
   3894         # Relative permitivities and scattering coefficients
   3895         # If a callable is defined to compute the radio material properties,
   3896         # it is invoked. Otherwise, the radio materials of objects are used.
   3897         rm_callable = self._scene.radio_material_callable
   3898         if rm_callable is None:
   3899             # [num_samples, 2]
   3900             etas = tf.gather(relative_permittivity, objects_indices)
   3901             scattering_coefficient = tf.gather(scattering_coefficient,
   3902                                                objects_indices)
   3903         else:
   3904             # Harmonize the shapes of the radio material callables
   3905             # [num_samples, 2, 3]
   3906             diff_vertex_ = tf.tile(tf.expand_dims(diff_vertex, axis=-2),
   3907                                    [1, 2, 1])
   3908             # scattering_coefficient, etas : [num_samples, 2]
   3909             etas, scattering_coefficient, _  = rm_callable(objects_indices,
   3910                                                            diff_vertex_)
   3911 
   3912         # [num_samples]
   3913         eta_0 = etas[:,0]
   3914         eta_n = etas[:,1]
   3915         # [1, num_samples]
   3916         eta_0 = tf.expand_dims(eta_0, axis=0)
   3917         eta_n = tf.expand_dims(eta_n, axis=0)
   3918         # [num_samples]
   3919         scattering_coefficient_0 = scattering_coefficient[...,0]
   3920         scattering_coefficient_n = scattering_coefficient[...,1]
   3921         # [1, num_samples]
   3922         scattering_coefficient_0 = tf.expand_dims(scattering_coefficient_0,
   3923                                                   axis=0)
   3924         scattering_coefficient_n = tf.expand_dims(scattering_coefficient_n,
   3925                                                   axis=0)
   3926 
   3927         # Compute s_prime_hat, s_hat, s_prime, s
   3928         # [1, num_samples, 3]
   3929         diff_vertex_ = tf.expand_dims(diff_vertex, axis=0)
   3930         # s_prime_hat : [num_tx, num_samples, 3]
   3931         # s_prime : [num_tx, num_samples]
   3932         s_prime_hat, s_prime = normalize(diff_vertex_-sources_positions)
   3933         # s_hat : [num_tx, num_samples, 3]
   3934         # s : [num_tx, num_samples]
   3935         s_hat, s = normalize(diff_hit_points-diff_vertex_)
   3936 
   3937         # Compute phi_prime_hat, beta_0_prime_hat, phi_hat, beta_0_hat
   3938         # [num_tx, num_samples, 3]
   3939         phi_prime_hat, _ = normalize(cross(s_prime_hat, e_hat))
   3940         # [num_tx, num_samples, 3]
   3941         beta_0_prime_hat = cross(phi_prime_hat, s_prime_hat)
   3942 
   3943         # [num_tx, num_samples, 3]
   3944         phi_hat_, _ = normalize(-cross(s_hat, e_hat))
   3945         beta_0_hat = cross(phi_hat_, s_hat)
   3946 
   3947         # Compute tangent vector t_0_hat
   3948         # [1, num_samples, 3]
   3949         t_0_hat = cross(n_0_hat, e_hat)
   3950 
   3951         # Compute s_t_prime_hat and s_t_hat
   3952         # [num_tx, num_samples, 3]
   3953         s_t_prime_hat, _ = normalize(s_prime_hat
   3954                                 - dot(s_prime_hat,e_hat, keepdim=True)*e_hat)
   3955         # [num_tx, num_samples, 3]
   3956         s_t_hat, _ = normalize(s_hat - dot(s_hat,e_hat, keepdim=True)*e_hat)
   3957 
   3958         # Compute phi_prime and phi
   3959         # [num_tx, num_samples]
   3960         phi_prime = PI -\
   3961             (PI-acos_diff(-dot(s_t_prime_hat, t_0_hat)))*\
   3962                 sign(-dot(s_t_prime_hat, n_0_hat))
   3963         # [num_tx, num_samples]
   3964         phi = PI - (PI-acos_diff(dot(s_t_hat, t_0_hat)))\
   3965             *sign(dot(s_t_hat, n_0_hat))
   3966 
   3967         # Compute field component vectors for reflections at both surfaces
   3968         # [num_tx, num_samples, 3]
   3969         # pylint: disable=unbalanced-tuple-unpacking
   3970         e_i_s_0, e_i_p_0, e_r_s_0, e_r_p_0 = compute_field_unit_vectors(
   3971             s_prime_hat,
   3972             s_hat,
   3973             n_0_hat,#*sign(-dot(s_t_prime_hat, n_0_hat, keepdim=True)),
   3974             SolverBase.EPSILON
   3975             )
   3976         # [num_tx, num_samples, 3]
   3977         # pylint: disable=unbalanced-tuple-unpacking
   3978         e_i_s_n, e_i_p_n, e_r_s_n, e_r_p_n = compute_field_unit_vectors(
   3979             s_prime_hat,
   3980             s_hat,
   3981             n_n_hat,#*sign(-dot(s_t_prime_hat, n_n_hat, keepdim=True)),
   3982             SolverBase.EPSILON
   3983             )
   3984 
   3985         # Compute Fresnel reflection coefficients for 0- and n-surfaces
   3986         # [num_tx, num_samples]
   3987         r_s_0, r_p_0 = reflection_coefficient(eta_0, tf.abs(tf.sin(phi_prime)))
   3988         r_s_n, r_p_n = reflection_coefficient(eta_n, tf.abs(tf.sin(n*PI-phi)))
   3989 
   3990         # Multiply the reflection coefficients with the
   3991         # corresponding reflection reduction factor
   3992         reduction_factor_0 = tf.sqrt(1 - scattering_coefficient_0**2)
   3993         reduction_factor_0 = tf.complex(reduction_factor_0,
   3994                                         tf.zeros_like(reduction_factor_0))
   3995         reduction_factor_n = tf.sqrt(1 - scattering_coefficient_n**2)
   3996         reduction_factor_n = tf.complex(reduction_factor_n,
   3997                                         tf.zeros_like(reduction_factor_n))
   3998         r_s_0 *= reduction_factor_0
   3999         r_p_0 *= reduction_factor_0
   4000         r_s_n *= reduction_factor_n
   4001         r_p_n *= reduction_factor_n
   4002 
   4003         # Compute matrices R_0, R_n
   4004         # [num_tx, num_samples, 2, 2]
   4005         w_i_0  = component_transform(phi_prime_hat,
   4006                                      beta_0_prime_hat,
   4007                                      e_i_s_0,
   4008                                      e_i_p_0)
   4009         w_i_0 = tf.complex(w_i_0, tf.zeros_like(w_i_0))
   4010         # [num_tx, num_samples, 2, 2]
   4011         w_r_0 = component_transform(e_r_s_0,
   4012                                     e_r_p_0,
   4013                                     phi_hat_,
   4014                                     beta_0_hat)
   4015         w_r_0 = tf.complex(w_r_0, tf.zeros_like(w_r_0))
   4016         # [num_tx, num_samples, 2, 2]
   4017         r_0 = tf.expand_dims(tf.stack([r_s_0, r_p_0], -1), -1) * w_i_0
   4018         # [num_tx, num_samples, 2, 2]
   4019         r_0 = -tf.matmul(w_r_0, r_0)
   4020 
   4021         # [num_tx, num_samples, 2, 2]
   4022         w_i_n = component_transform(phi_prime_hat,
   4023                                     beta_0_prime_hat,
   4024                                     e_i_s_n,
   4025                                     e_i_p_n)
   4026         w_i_n = tf.complex(w_i_n, tf.zeros_like(w_i_n))
   4027         # [num_tx, num_samples, 2, 2]
   4028         w_r_n = component_transform(e_r_s_n,
   4029                                     e_r_p_n,
   4030                                     phi_hat_,
   4031                                     beta_0_hat)
   4032         w_r_n = tf.complex(w_r_n, tf.zeros_like(w_r_n))
   4033         # [num_tx, num_samples, 2, 2]
   4034         r_n = tf.expand_dims(tf.stack([r_s_n, r_p_n], -1), -1) * w_i_n
   4035         # [num_tx, num_samples, 2, 2]
   4036         r_n = -tf.matmul(w_r_n, r_n)
   4037 
   4038         # Compute D_1, D_2, D_3, D_4
   4039         # [num_tx, num_samples]
   4040         phi_m = phi - phi_prime
   4041         phi_p = phi + phi_prime
   4042 
   4043         # [num_tx, num_samples]
   4044         cot_1 = cot((PI + phi_m)/(2*n))
   4045         cot_2 = cot((PI - phi_m)/(2*n))
   4046         cot_3 = cot((PI + phi_p)/(2*n))
   4047         cot_4 = cot((PI - phi_p)/(2*n))
   4048 
   4049         def n_p(beta, n):
   4050             return tf.math.round((beta + PI)/(2.*n*PI))
   4051 
   4052         def n_m(beta, n):
   4053             return tf.math.round((beta - PI)/(2.*n*PI))
   4054 
   4055         def a_p(beta, n):
   4056             return 2*tf.cos((2.*n*PI*n_p(beta, n)-beta)/2.)**2
   4057 
   4058         def a_m(beta, n):
   4059             return 2*tf.cos((2.*n*PI*n_m(beta, n)-beta)/2.)**2
   4060 
   4061         # [1, num_samples]
   4062         d_mul = - tf.cast(tf.exp(-1j*PI/4.), self._dtype)/\
   4063             tf.cast((2*n)*tf.sqrt(2*PI*k), self._dtype)
   4064 
   4065         # [num_tx, num_samples]
   4066         ell = s_prime*s/(s_prime + s)
   4067 
   4068         # [num_tx, num_samples]
   4069         cot_1 = tf.complex(cot_1, tf.zeros_like(cot_1))
   4070         cot_2 = tf.complex(cot_2, tf.zeros_like(cot_2))
   4071         cot_3 = tf.complex(cot_3, tf.zeros_like(cot_3))
   4072         cot_4 = tf.complex(cot_4, tf.zeros_like(cot_4))
   4073         d_1 = d_mul*cot_1*f(k*ell*a_p(phi_m, n))
   4074         d_2 = d_mul*cot_2*f(k*ell*a_m(phi_m, n))
   4075         d_3 = d_mul*cot_3*f(k*ell*a_p(phi_p, n))
   4076         d_4 = d_mul*cot_4*f(k*ell*a_m(phi_p, n))
   4077 
   4078         # [num_tx, num_samples, 1, 1]
   4079         d_1 = tf.reshape(d_1, tf.concat([tf.shape(d_1), [1, 1]], axis=0))
   4080         d_2 = tf.reshape(d_2, tf.concat([tf.shape(d_2), [1, 1]], axis=0))
   4081         d_3 = tf.reshape(d_3, tf.concat([tf.shape(d_3), [1, 1]], axis=0))
   4082         d_4 = tf.reshape(d_4, tf.concat([tf.shape(d_4), [1, 1]], axis=0))
   4083 
   4084         # [num_tx, num_samples]
   4085         spreading_factor = tf.math.divide_no_nan(tf.cast(1.0, self._rdtype),
   4086                                                  s*s_prime*(s_prime + s))
   4087         spreading_factor = tf.sqrt(spreading_factor)
   4088         spreading_factor = tf.complex(spreading_factor,
   4089                                       tf.zeros_like(spreading_factor))
   4090         # [num_tx, num_samples, 1, 1]
   4091         spreading_factor = tf.reshape(spreading_factor, tf.shape(d_1))
   4092 
   4093         # [num_tx, num_samples, 2, 2]
   4094         mat_t = (d_1+d_2)*tf.eye(2,2, batch_shape=tf.shape(r_0)[:2],
   4095                                  dtype=self._dtype)
   4096         # [num_tx, num_samples, 2, 2]
   4097         mat_t += d_3*r_n + d_4*r_0
   4098         # [num_tx, num_samples, 2, 2]
   4099         mat_t *= -spreading_factor
   4100 
   4101         # Convert from/to GCS
   4102         # [num_tx, num_samples]
   4103         theta_t, phi_t = theta_phi_from_unit_vec(s_prime_hat)
   4104         theta_r, phi_r = theta_phi_from_unit_vec(-s_hat)
   4105 
   4106         # [num_tx, num_samples, 2, 2]
   4107         mat_from_gcs = component_transform(
   4108                             theta_hat(theta_t, phi_t), phi_hat(phi_t),
   4109                             phi_prime_hat, beta_0_prime_hat)
   4110         mat_from_gcs = tf.complex(mat_from_gcs,
   4111                                   tf.zeros_like(mat_from_gcs))
   4112 
   4113         # [num_tx, num_samples, 2, 2]
   4114         mat_to_gcs = component_transform(phi_hat_, beta_0_hat,
   4115                                          theta_hat(theta_r, phi_r),
   4116                                          phi_hat(phi_r))
   4117         mat_to_gcs = tf.complex(mat_to_gcs,
   4118                                 tf.zeros_like(mat_to_gcs))
   4119 
   4120         # [num_tx, num_samples, 2, 2]
   4121         mat_t = tf.linalg.matmul(mat_t, mat_from_gcs)
   4122         mat_t = tf.linalg.matmul(mat_to_gcs, mat_t)
   4123 
   4124         # Set invalid paths to 0
   4125         # Expand masks to broadcast with the field components
   4126         # [num_tx, num_samples, 1, 1]
   4127         mask_ = expand_to_rank(diff_mask, 4, axis=-1)
   4128         # Zeroing coefficients corresponding to non-valid paths
   4129         # [num_tx, num_samples, 2, 2]
   4130         mat_t = tf.where(mask_, mat_t, tf.zeros_like(mat_t))
   4131 
   4132         # Compute transmitters antenna pattern in the GCS
   4133         # [num_tx, 3, 3]
   4134         tx_rot_mat = rotation_matrix(sources_orientations)
   4135         # [num_tx, 1, 3, 3]
   4136         tx_rot_mat = tf.expand_dims(tx_rot_mat, axis=1)
   4137         # tx_field : [num_tx, num_samples, num_tx_patterns, 2]
   4138         # tx_es, ex_ep : [num_tx, num_samples, 3]
   4139         tx_field, _, _ = self._compute_antenna_patterns(tx_rot_mat,
   4140                             self._scene.tx_array.antenna.patterns, s_prime_hat)
   4141 
   4142         # Compute receiver antenna pattern in the GCS
   4143         # [3, 3]
   4144         rx_rot_mat = rotation_matrix(rx_orientation)
   4145         # tx_field : [num_tx, num_samples, num_rx_patterns, 2]
   4146         # tx_es, ex_ep : [num_tx, num_samples, 3]
   4147         rx_field, _, _ = self._compute_antenna_patterns(rx_rot_mat,
   4148                             self._scene.rx_array.antenna.patterns, -s_hat)
   4149 
   4150         # Compute the channel coefficients for every transmitter-receiver
   4151         # pattern pairs
   4152         # [num_tx, num_samples, 1, 1, 2, 2]
   4153         mat_t = insert_dims(mat_t, 2, 2)
   4154         # [num_tx, num_samples, 1, num_tx_patterns, 1, 2]
   4155         tx_field = tf.expand_dims(tf.expand_dims(tx_field, axis=2), axis=4)
   4156         # [num_tx, num_samples, num_rx_patterns, 1, 2]
   4157         rx_field = tf.expand_dims(rx_field, axis=3)
   4158         # [num_tx, num_samples, 1, num_tx_patterns, 2]
   4159         a = tf.reduce_sum(mat_t*tx_field, axis=-1)
   4160         # [num_tx, num_samples, num_rx_patterns, num_tx_patterns]
   4161         a = tf.reduce_sum(tf.math.conj(rx_field)*a, axis=-1)
   4162 
   4163         # Apply synthetic array
   4164         # [num_tx, num_samples, num_rx_antenna, num_tx_antenna]
   4165         a = self._apply_synthetic_array(tx_rot_mat, rx_rot_mat, -s_hat,
   4166                                         s_prime_hat, a)
   4167 
   4168         # Apply precoding
   4169         # Precoding and combing
   4170         # [num_tx/1, 1, 1, num_tx_ant]
   4171         precoding_vec = insert_dims(precoding_vec, 2, 1)
   4172         # [num_tx, samples_per_tx, num_rx_ant]
   4173         a = tf.reduce_sum(a*precoding_vec, axis=-1)
   4174         # Apply combining
   4175         # If no combining vector is set, then the energy of all antennas is
   4176         # summed
   4177         if combining_vec is None:
   4178             # [num_tx, samples_per_tx]
   4179             a = tf.reduce_sum(tf.square(tf.abs(a)), axis=-1)
   4180         else:
   4181             # [1, 1, num_rx_ant]
   4182             combining_vec = insert_dims(combining_vec, 2, 0)
   4183             # [num_tx, samples_per_tx]
   4184             a = tf.reduce_sum(tf.math.conj(combining_vec)*a, axis=-1)
   4185             # [num_tx, samples_per_tx]
   4186             a = tf.square(tf.abs(a))
   4187 
   4188         # [num_tx, samples_per_tx]
   4189         cst = tf.square(self._scene.wavelength/(4.*PI))
   4190         a = a*cst
   4191 
   4192         return a
   4193 
   4194     def _build_diff_coverage_map(self, cm_center, cm_orientation, cm_size,
   4195                                  cm_cell_size, diff_wedges_ind, diff_hit_points,
   4196                                  diff_samples_power, diff_samples_weights,
   4197                                  diff_num_samples_per_wedge):
   4198         r"""
   4199         Builds the coverage map for diffraction
   4200 
   4201         Input
   4202         ------
   4203         cm_center : [3], tf.float
   4204             Center of the coverage map
   4205 
   4206         cm_orientation : [3], tf.float
   4207             Orientation of the coverage map
   4208 
   4209         cm_size : [2], tf.float
   4210             Scale of the coverage map.
   4211             The width of the map (in the local X direction) is ``cm_size[0]``
   4212             and its map (in the local Y direction) ``cm_size[1]``.
   4213 
   4214         cm_cell_size : [2], tf.float
   4215             Resolution of the coverage map, i.e., width
   4216             (in the local X direction) and height (in the local Y direction) in
   4217             meters of a cell of the coverage map
   4218 
   4219         diff_wedges_ind : [num_samples], tf.int
   4220             Indices of the wedges that interacted with the diffracted paths
   4221 
   4222         diff_hit_points : [num_tx, num_samples, 3], tf.float
   4223             Positions of the intersection of the diffracted rays and coverage
   4224             map
   4225 
   4226         diff_samples_power : [num_tx, num_samples], tf.float
   4227             Powers of the samples of diffracted rays.
   4228 
   4229         diff_samples_weights : [num_tx, num_samples], tf.float
   4230             Weights for averaging the field powers of the samples.
   4231 
   4232         diff_num_samples_per_wedge : [num_samples], tf.int
   4233             For each sample, total mumber of samples that were sampled on the
   4234             same wedge
   4235 
   4236         Output
   4237         ------
   4238         :cm : :class:`~sionna.rt.CoverageMap`
   4239             The coverage maps
   4240         """
   4241         num_tx = diff_hit_points.shape[0]
   4242         num_samples = diff_hit_points.shape[1]
   4243         cell_area = cm_cell_size[0]*cm_cell_size[1]
   4244 
   4245         # [num_tx, num_samples]
   4246         diff_wedges_ind = tf.tile(tf.expand_dims(diff_wedges_ind, axis=0),
   4247                                   [num_tx, 1])
   4248 
   4249         # Transformation matrix required for computing the cell
   4250         # indices of the intersection points
   4251         # [3,3]
   4252         rot_cm_2_gcs = rotation_matrix(cm_orientation)
   4253         # [3,3]
   4254         rot_gcs_2_cm = tf.transpose(rot_cm_2_gcs)
   4255 
   4256         # Initializing the coverage map
   4257         num_cells_x = tf.cast(tf.math.ceil(cm_size[0]/cm_cell_size[0]),
   4258                               tf.int32)
   4259         num_cells_y = tf.cast(tf.math.ceil(cm_size[1]/cm_cell_size[1]),
   4260                               tf.int32)
   4261         num_cells = tf.stack([num_cells_x, num_cells_y], axis=-1)
   4262         # [num_tx, num_cells_y+1, num_cells_x+1]
   4263         # Add dummy row and columns to store the items that are out of the
   4264         # coverage map
   4265         cm = tf.zeros([num_tx, num_cells_y+1, num_cells_x+1],
   4266                       dtype=self._rdtype)
   4267 
   4268         # Coverage map cells' indices
   4269         # [num_tx, num_samples, 2 : xy]
   4270         cell_ind = self._mp_hit_point_2_cell_ind(rot_gcs_2_cm, cm_center,
   4271                             cm_size, cm_cell_size, num_cells, diff_hit_points)
   4272         # Add the transmitter index to the coverage map
   4273         # [num_tx]
   4274         tx_ind = tf.range(num_tx, dtype=tf.int32)
   4275         # [num_tx, 1, 1]
   4276         tx_ind = expand_to_rank(tx_ind, 3)
   4277         # [num_tx, num_samples, 1]
   4278         tx_ind = tf.tile(tx_ind, [1, num_samples, 1])
   4279         # [num_tx, num_samples, 3]
   4280         cm_ind = tf.concat([tx_ind, cell_ind], axis=-1)
   4281 
   4282         # Wedges lengths
   4283         # [num_tx, num_samples]
   4284         lengths = tf.gather(self._wedges_length, diff_wedges_ind)
   4285 
   4286         # Wedges opening angles
   4287         # [num_tx, num_samples, 2, 3]
   4288         normals = tf.gather(self._wedges_normals, diff_wedges_ind)
   4289         # [num_tx, num_samples]
   4290         cos_op_angle = dot(normals[...,0,:],normals[...,1,:], clip=True)
   4291         op_angles = PI + tf.math.acos(cos_op_angle)
   4292 
   4293         # Update the weights of each ray power
   4294         # [1, num_samples]
   4295         diff_num_samples_per_wedge = tf.expand_dims(diff_num_samples_per_wedge,
   4296                                                     axis=0)
   4297         diff_num_samples_per_wedge = tf.cast(diff_num_samples_per_wedge,
   4298                                              self._rdtype)
   4299         # [num_tx, num_samples]
   4300         diff_samples_weights = tf.math.divide_no_nan(diff_samples_weights,
   4301                                                      diff_num_samples_per_wedge)
   4302         diff_samples_weights = diff_samples_weights*lengths*op_angles
   4303 
   4304         # Add the weighted powers to the coverage map
   4305         # [num_tx, num_samples]
   4306         weighted_sample_power = diff_samples_power*diff_samples_weights
   4307         # [num_tx, num_cells_y+1, num_cells_x+1]
   4308         cm = tf.tensor_scatter_nd_add(cm, cm_ind, weighted_sample_power)
   4309 
   4310         # Dump the dummy line and row
   4311         # [num_tx, num_cells_y, num_cells_x]
   4312         cm = cm[:,:num_cells_y,:num_cells_x]
   4313 
   4314         # Scaling by area of a cell
   4315         # [num_tx, num_cells_y, num_cells_x]
   4316         cm = cm / cell_area
   4317 
   4318         return cm
   4319 
   4320     def _diff_samples_2_coverage_map(self, los_primitives, edge_diffraction,
   4321                                      num_samples, sources_positions, meas_plane,
   4322                                      cm_center, cm_orientation, cm_size,
   4323                                      cm_cell_size, sources_orientations,
   4324                                      rx_orientation, combining_vec,
   4325                                      precoding_vec, etas,
   4326                                      scattering_coefficient):
   4327         r"""
   4328         Computes the coverage map for diffraction.
   4329 
   4330         Input
   4331         ------
   4332         los_primitives: [num_los_primitives], int
   4333             Primitives in LoS.
   4334 
   4335         edge_diffraction : bool
   4336             If set to `False`, only diffraction on wedges, i.e., edges that
   4337             connect two primitives, is considered.
   4338 
   4339         num_samples : int
   4340             Number of rays initially shooted from the wedges.
   4341 
   4342         sources_positions : [num_tx, 3], tf.float
   4343             Coordinates of the sources.
   4344 
   4345         meas_plane : mi.Shape
   4346             Mitsuba rectangle defining the measurement plane
   4347 
   4348         cm_center : [3], tf.float
   4349             Center of the coverage map
   4350 
   4351         cm_orientation : [3], tf.float
   4352             Orientation of the coverage map
   4353 
   4354         cm_size : [2], tf.float
   4355             Scale of the coverage map.
   4356             The width of the map (in the local X direction) is ``cm_size[0]``
   4357             and its map (in the local Y direction) ``cm_size[1]``.
   4358 
   4359         cm_cell_size : [2], tf.float
   4360             Resolution of the coverage map, i.e., width
   4361             (in the local X direction) and height (in the local Y direction) in
   4362             meters of a cell of the coverage map
   4363 
   4364         sources_orientations : [num_tx, 3], tf.float
   4365             Orientations of the sources.
   4366 
   4367         rx_orientation : [3], tf.float
   4368             Orientation of the receiver.
   4369 
   4370         combining_vec : [num_rx_ant], tf.complex
   4371             Combining vector.
   4372             If set to `None`, then no combining is applied, and
   4373             the energy received by all antennas is summed.
   4374 
   4375         precoding_vec : [num_tx or 1, num_tx_ant], tf.complex
   4376             Precoding vectors of the transmitters
   4377 
   4378         etas : [num_shape], tf.complex
   4379             Tensor containing the complex relative permittivities of all shapes
   4380 
   4381         scattering_coefficient : [num_shape], tf.float
   4382             Tensor containing the scattering coefficients of all shapes
   4383 
   4384         Output
   4385         -------
   4386         :cm : :class:`~sionna.rt.CoverageMap`
   4387             The coverage maps
   4388         """
   4389 
   4390         # Build empty coverage map
   4391         num_cells_x = tf.cast(tf.math.ceil(cm_size[0]/cm_cell_size[0]),
   4392                               tf.int32)
   4393         num_cells_y = tf.cast(tf.math.ceil(cm_size[1]/cm_cell_size[1]),
   4394                               tf.int32)
   4395         # [num_tx, num_cells_y, num_cells_x]
   4396         cm_null = tf.zeros([sources_positions.shape[0], num_cells_y,
   4397                             num_cells_x], dtype=self._rdtype)
   4398 
   4399         # Get the candidate wedges for diffraction
   4400         # diff_wedges_ind : [num_candidate_wedges], int
   4401         #     Candidate wedges indices
   4402         diff_wedges_ind = self._wedges_from_primitives(los_primitives,
   4403                                                         edge_diffraction)
   4404         # Early stop if there are no wedges
   4405         if diff_wedges_ind.shape[0] == 0:
   4406             return cm_null
   4407 
   4408         # Discard wedges for which the tx is inside the wedge
   4409         # diff_mask : [num_tx, num_candidate_wedges], bool
   4410         #   Mask set to False if the wedge is invalid
   4411         # wedges : [num_candidate_wedges], int
   4412         #     Candidate wedges indices
   4413         output = self._discard_obstructing_wedges(diff_wedges_ind,
   4414                                                     sources_positions)
   4415         diff_mask = output[0]
   4416         diff_wedges_ind = output[1]
   4417         # Early stop if there are no wedges
   4418         if diff_wedges_ind.shape[0] == 0:
   4419             return cm_null
   4420 
   4421         # Sample diffraction points on the wedges
   4422         # diff_mask : [num_tx, num_candidate_wedges], bool
   4423         #   Mask set to False if the wedge is invalid
   4424         # diff_wedges_ind : [num_candidate_wedges], int
   4425         #     Candidate wedges indices
   4426         # diff_ells : [num_samples], float
   4427         #   Positions of the diffraction points on the wedges.
   4428         #   These positionsare given as an offset from the wedges origins.
   4429         #   The size of this tensor is in general slighly smaller than
   4430         #   `num_samples` because of roundings.
   4431         # diff_vertex : [num_samples, 3], tf.float
   4432         #   Positions of the diffracted points in the GCS
   4433         # diff_num_samples_per_wedge : [num_samples], tf.int
   4434         #         For each sample, total mumber of samples that were sampled
   4435         #         on the same wedge
   4436         output = self._sample_wedge_points(diff_mask, diff_wedges_ind,
   4437                                             num_samples)
   4438         diff_mask = output[0]
   4439         diff_wedges_ind = output[1]
   4440         diff_ells = output[2]
   4441         diff_vertex = output[3]
   4442         diff_num_samples_per_wedge = output[4]
   4443 
   4444         # Test for blockage between the transmitters and diffraction points.
   4445         # Discarted blocked samples.
   4446         # diff_mask : [num_tx, num_candidate_wedges], bool
   4447         #   Mask set to False if the wedge is invalid
   4448         # diff_wedges_ind : [num_samples], int
   4449         #     Candidate wedges indices
   4450         # diff_ells : [num_samples], float
   4451         #   Positions of the diffraction points on the wedges.
   4452         #   These positionsare given as an offset from the wedges origins.
   4453         #   The size of this tensor is in general slighly smaller than
   4454         #   `num_samples` because of roundings.
   4455         # diff_vertex : [num_samples, 3], float
   4456         #   Positions of the diffracted points in the GCS
   4457         # diff_num_samples_per_wedge : [num_samples], tf.int
   4458         #         For each sample, total mumber of samples that were sampled
   4459         #         on the same wedge
   4460         output = self._test_tx_visibility(diff_mask, diff_wedges_ind,
   4461                                             diff_ells,
   4462                                             diff_vertex,
   4463                                             diff_num_samples_per_wedge,
   4464                                             sources_positions)
   4465         diff_mask = output[0]
   4466         diff_wedges_ind = output[1]
   4467         diff_ells = output[2]
   4468         diff_vertex = output[3]
   4469         diff_num_samples_per_wedge = output[4]
   4470         # Early stop if there are no wedges
   4471         if diff_wedges_ind.shape[0] == 0:
   4472             return cm_null
   4473 
   4474         # Samples angles for departure on the diffraction cone
   4475         # diff_phi : [num_samples, 3], tf.float
   4476         #   Sampled angles on the diffraction cone used for shooting rays
   4477         diff_phi = self._sample_diff_angles(diff_wedges_ind)
   4478 
   4479         # Shoot rays in the sampled directions and test for intersection
   4480         # with the coverage map.
   4481         # Discard rays that miss it.
   4482         # diff_mask : [num_tx, num_samples], tf.bool
   4483         #     Mask set to False for invalid samples
   4484         # diff_wedges_ind : [num_samples], tf.int
   4485         #     Indices of the wedges that interacted with the diffracted
   4486         #     paths
   4487         # diff_ells : [num_samples], tf.float
   4488         #     Positions of the diffraction points on the wedges.
   4489         #     These positions are given as an offset from the wedges
   4490         #     origins.
   4491         # diff_phi : [num_samples], tf.float
   4492         #     Sampled angles of diffracted rays on the diffraction cone
   4493         # diff_vertex : [num_samples, 3], tf.float
   4494         #     Positions of the diffracted points in the GCS
   4495         # diff_num_samples_per_wedge : [num_samples], tf.int
   4496         #         For each sample, total mumber of samples that were sampled
   4497         #         on the same wedge
   4498         # diff_hit_points : [num_tx, num_samples, 3], tf.float
   4499         #     Positions of the intersection of the diffracted rays and
   4500         #     coverage map
   4501         # diff_cone_angle : [num_tx, num_samples], tf.float
   4502         #     Angle between e_hat and the diffracted ray direction.
   4503         #     Takes value in (0,pi).
   4504         output = self._shoot_diffracted_rays(diff_mask, diff_wedges_ind,
   4505                                              diff_ells,
   4506                                              diff_vertex,
   4507                                              diff_num_samples_per_wedge,
   4508                                              diff_phi,
   4509                                              sources_positions,
   4510                                              meas_plane)
   4511         diff_mask = output[0]
   4512         diff_wedges_ind = output[1]
   4513         diff_ells = output[2]
   4514         diff_phi = output[3]
   4515         diff_vertex = output[4]
   4516         diff_num_samples_per_wedge = output[5]
   4517         diff_hit_points = output[6]
   4518         diff_cone_angle = output[7]
   4519 
   4520         # Computes the weights for averaging the field powers of the samples
   4521         # to compute the Monte Carlo estimate of the integral of the
   4522         # diffracted field power over the measurement plane.
   4523         # These weights are required as the measurement plane is
   4524         # parametrized by the angle on the diffraction cones (phi) and
   4525         # position on the wedges (ell).
   4526         #
   4527         # diff_samples_weights : [num_tx, num_samples], tf.float
   4528         #     Weights for averaging the field powers of the samples.
   4529         output = self._compute_samples_weights(cm_center,
   4530                                                cm_orientation,
   4531                                                sources_positions,
   4532                                                diff_wedges_ind,
   4533                                                diff_ells,
   4534                                                diff_phi,
   4535                                                diff_cone_angle)
   4536         diff_samples_weights = output
   4537 
   4538         # Computes the power of the diffracted paths.
   4539         #
   4540         # diff_samples_power : [num_tx, num_samples], tf.float
   4541         #   Powers of the samples of diffracted rays.
   4542         output = self._compute_diffracted_path_power(sources_positions,
   4543                                                      sources_orientations,
   4544                                                      rx_orientation,
   4545                                                      combining_vec,
   4546                                                      precoding_vec,
   4547                                                      diff_mask,
   4548                                                      diff_wedges_ind,
   4549                                                      diff_vertex,
   4550                                                      diff_hit_points,
   4551                                                      etas,
   4552                                                      scattering_coefficient)
   4553         diff_samples_power = output
   4554 
   4555         # Builds the coverage map for the diffracted field
   4556         cm_diff = self._build_diff_coverage_map(cm_center,
   4557                                                 cm_orientation,
   4558                                                 cm_size,
   4559                                                 cm_cell_size,
   4560                                                 diff_wedges_ind,
   4561                                                 diff_hit_points,
   4562                                                 diff_samples_power,
   4563                                                 diff_samples_weights,
   4564                                                 diff_num_samples_per_wedge)
   4565 
   4566         return cm_diff