anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

solver_paths.py (224512B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """
      6 Ray tracing algorithm that uses the image method to compute LoS and specular
      7 reflection paths, and that also computes diffracted and scattered paths.
      8 """
      9 
     10 import mitsuba as mi
     11 import drjit as dr
     12 import tensorflow as tf
     13 
     14 from sionna import config
     15 from sionna.constants import SPEED_OF_LIGHT, PI
     16 from sionna.utils.tensors import expand_to_rank, insert_dims, flatten_dims,\
     17     split_dim
     18 from .paths import Paths
     19 from .utils import dot, phi_hat, theta_hat, theta_phi_from_unit_vec,\
     20     normalize, moller_trumbore, component_transform, mi_to_tf_tensor,\
     21         compute_field_unit_vectors, reflection_coefficient, fibonacci_lattice,\
     22             cot, cross, sign, rotation_matrix, acos_diff
     23 from .solver_base import SolverBase
     24 from .scattering_pattern import ScatteringPattern
     25 
     26 
     27 class PathsTmpData:
     28     r"""
     29     Class used to temporarily store values for paths calculation.
     30     """
     31 
     32     def __init__(self, sources, targets, dtype):
     33 
     34         self.sources = sources
     35         self.targets = targets
     36         self.dtype = dtype
     37         num_sources = tf.shape(sources)[0]
     38         num_targets = tf.shape(targets)[0]
     39 
     40         # [max_depth, num_targets, num_sources, max_num_paths, 3] or
     41         # [max_depth, num_targets, num_sources, max_num_paths, 2, 3], tf.float
     42         #     Reflected or scattered paths: Normals to the primitives at the
     43         #     intersection points.
     44         #     Diffracted paths: Normals to the two primitives forming the wedge.
     45         self.normals = tf.zeros([0, num_targets, num_sources, 0, 3],
     46                              dtype.real_dtype)
     47 
     48         # [max_depth + 1, num_targets, num_sources, max_num_paths, 3], tf.float
     49         #   Direction of arrivals.
     50         #   The last item (k_i[max_depth]) correspond to the direction of
     51         #   arrival at the target. Therefore, k_i is a tensor of length
     52         #   `max_depth + 1`, where `max_depth` is the number of maximum
     53         #   interaction (which could be zero if only LoS is requested).
     54         self.k_i = tf.zeros([0, num_targets, num_sources, 0, 3],
     55                              dtype.real_dtype)
     56 
     57         # [max_depth, num_targets, num_sources, max_num_paths, 3], tf.float
     58         #   Direction of departures at interaction points.
     59         #   We do not need the direction of departure at the source, as it
     60         #   is the same as k_i[0].
     61         self.k_r = tf.zeros([0, num_targets, num_sources, 0, 3],
     62                              dtype.real_dtype)
     63 
     64         # [max_depth+1, num_targets, num_sources, max_num_paths] or
     65         # [max_depth, num_targets, num_sources, max_num_paths] for scattering,
     66         # tf.float
     67         #     Lengths in meters of the paths segments
     68         self.total_distance = tf.zeros([num_targets, num_sources, 0],
     69                              dtype.real_dtype)
     70 
     71         # [num_targets, num_sources, max_num_paths, 2, 2] or
     72         # [num_rx, rx_array_size, num_tx, tx_array_size, max_num_paths, 2, 2],
     73         # tf.complex
     74         #     Channel transition matrix
     75         # These are initialized to emtpy tensors to handle cases where no
     76         # paths are found
     77         self.mat_t = tf.zeros([num_targets, num_sources, 0, 2, 2], dtype)
     78 
     79         # [num_targets, num_sources, max_num_paths, 3], tf.float
     80         #   Direction of departure. This vector is normalized and pointing
     81         #   awat from the radio device.
     82         # These are initialized to emtpy tensors to handle cases where no
     83         # paths are found
     84         self.k_tx = tf.zeros([num_targets, num_sources, 0, 3],
     85                              dtype.real_dtype)
     86 
     87         # [num_targets, num_sources, max_num_paths, 3], tf.float
     88         #   Direction of arrival. This vector is normalized and pointing
     89         #   awat from the radio device.
     90         # These are initialized to emtpy tensors to handle cases where no
     91         # paths are found
     92         self.k_rx = tf.zeros([num_targets, num_sources, 0, 3],
     93                              dtype.real_dtype)
     94 
     95         # [max_depth, num_targets, num_sources, max_num_paths], tf.bool
     96         #   This parameter is specific to scattering.
     97         #   For scattering, every path prefix is a potential final path.
     98         #   This tensor is a mask which indicates for every path prefix if it
     99         #   is a valid path.
    100         self.scat_prefix_mask = tf.fill([0, num_targets, num_sources, 0], False)
    101 
    102         # [max_depth, num_targets, num_sources, max_num_paths, 3], tf.float
    103         #   For every intersection point between the paths and the scene,
    104         #   gives the direction of the scattered ray, i.e., points towards the
    105         #   targets.
    106         self.scat_prefix_k_s = tf.zeros([0, num_targets, num_sources, 0, 3],
    107                              dtype.real_dtype)
    108 
    109         # [num_targets, num_sources, max_num_paths]
    110         #   This parameter is specific to scattering.
    111         #   Stores the index of the last hit object for retreiving the
    112         #   scattering properties of the objects
    113         self.scat_last_objects = tf.zeros([num_targets, num_sources, 0],
    114                                           tf.int32)
    115 
    116         # [num_targets, num_sources, max_num_paths, 3]
    117         #   This parametric is specific to scattering.
    118         #   Stores the position of the last intersection, i.e., the points at
    119         #   which the field is scattered
    120         self.scat_last_vertices = tf.zeros([num_targets, num_sources, 0, 3],
    121                              dtype.real_dtype)
    122 
    123         # [num_targets, num_sources, max_num_paths, 3]
    124         #   This parameter is specific to scattering.
    125         #   Stores the incoming vector for the last interaction, i.e., the
    126         #   one that scatters the field
    127         self.scat_last_k_i = tf.zeros([num_targets, num_sources, 0, 3],
    128                              dtype.real_dtype)
    129 
    130         # [num_targets, num_sources, max_num_paths, 3]
    131         #   This parameter is specific to scattering.
    132         #   Stores the outgoing vector for the last interaction, i.e., the
    133         #   direction of the scattered ray.
    134         self.scat_k_s = tf.zeros([num_targets, num_sources, 0, 3],
    135                              dtype.real_dtype)
    136 
    137         # [num_targets, num_sources, max_num_paths, 3]
    138         #   This parameter is specific to scattering.
    139         #   Stores the normals to the last interaction point, i.e., the
    140         #   scattering point
    141         self.scat_last_normals = tf.zeros([num_targets, num_sources, 0, 3],
    142                              dtype.real_dtype)
    143 
    144         # [num_targets, num_sources, max_num_paths]
    145         #   This parameter is specific to scattering.
    146         #   Stores the distance from the sources to the scattering points.
    147         self.scat_src_2_last_int_dist = tf.zeros([num_targets, num_sources, 0],
    148                              dtype.real_dtype)
    149 
    150         # [num_targets, num_sources, max_num_paths]
    151         #   This parameter is specific to scattering.
    152         #   Stores the distance from the scattering points to the targets.
    153         self.scat_2_target_dist = tf.zeros([num_targets, num_sources, 0],
    154                              dtype.real_dtype)
    155 
    156         # Number of samples, i.e., shooted rays
    157         # (), tf.int
    158         self.num_samples = 0
    159 
    160         # Probability with which scattered paths are kept
    161         # (), tf.float
    162         self.scat_keep_prob = 0.0
    163 
    164     def to_dict(self):
    165         # pylint: disable=line-too-long
    166         r"""
    167         Returns the properties of the paths as a dictionary which values are
    168         tensors
    169 
    170         Output
    171         -------
    172         : `dict`
    173             Dictionary defining the paths
    174         """
    175         members_names = dir(self)
    176         members_objects = [getattr(self, attr) for attr in members_names]
    177         data = {attr_name : attr_obj for (attr_obj, attr_name)
    178                 in zip(members_objects,members_names)
    179                 if not callable(attr_obj) and
    180                    not attr_name.startswith("__") and
    181                    not isinstance(attr_obj, tf.DType)}
    182         return data
    183 
    184     def from_dict(self, data_dict):
    185         # pylint: disable=line-too-long
    186         r"""
    187         Set the paths from a dictionary which values are tensors
    188 
    189         The format of the dictionary is expected to be the same as the one
    190         returned by :meth:`~sionna.rt.Paths.to_dict()`.
    191 
    192         Input
    193         ------
    194         data_dict : `dict`
    195             Dict of tensors
    196         """
    197         for attr_name in data_dict:
    198             attr_obj = data_dict[attr_name]
    199             setattr(self, attr_name, attr_obj)
    200 
    201 class SolverPaths(SolverBase):
    202     # pylint: disable=line-too-long
    203     r"""SolverPaths(scene, solver=None, dtype=tf.complex64)
    204 
    205     Generates propagation paths consisting of the line-of-sight (LoS) paths,
    206     specular, and diffracted paths for the currently loaded scene.
    207 
    208     The main inputs of the solver are:
    209 
    210     * A set of sources, from which rays are emitted.
    211 
    212     * A set of targets, at which rays are received.
    213 
    214     * A maximum depth, corresponding to the maximum number of reflections. A
    215     depth of zero corresponds to LoS.
    216 
    217     Generation of paths is carried-out for every link, i.e., for every pair of
    218     source and target.
    219 
    220     The genration of specular paths consists in three steps:
    221 
    222     1. A list of candidate paths is generated. A candidate consists in a
    223     sequence of primitives on which a ray emitted by a source sequentially
    224     reflects until it reaches a target.
    225 
    226     2. The image method is applied to every candidates (in parallel) to discard
    227     candidates that do not correspond to valid paths, either because they
    228     are obstructed by another object in the scene, or because a reflection on
    229     one of the primitive in the sequence is impossible (reflection point outside
    230     of the primitive).
    231 
    232     3. For the valid paths, Fresnel coefficients for reflections are computed,
    233     considering the materials of the intersected objects, to compute transfer
    234     matrices for every paths.
    235 
    236     For diffracted paths, after step 1.:
    237 
    238     2. The wedges of primitives in LoS are selected, i.e., primitives for which
    239     a direct connection with the sources was found at step 1
    240 
    241     3. The intersection point of the diffracted path on the wedge is computed.
    242     This is the point that minimizes the total length of the paths.
    243     Paths for which the diffraction point does not belong to the finite wedge
    244     are discarded.
    245 
    246     4. Obstruction test: Paths that are blocked are discarded.
    247 
    248     5. The transmition matrices are computed, as well as the delays and angles
    249     of arrival and departure.
    250 
    251     The output of the solver consists in, for every valid path that was found:
    252 
    253     * A transfer matrix, which is a 2x2 complex-valued matrix that describes the
    254     linear transformation incurred by the emitted field. The two dimensions
    255     correspond to the two polarization components (S and P).
    256 
    257     * A delay
    258 
    259     * Azimuth and zenith angles of arrival
    260 
    261     * Azimuth and zenith angles of departure
    262 
    263     Concerning the first step, two search methods are available for the
    264     listing of candidates:
    265 
    266     * Exhaustive search, which lists all possible combinations of primitives up
    267     to the requested maximum depth. This method is deterministic and ensures
    268     that all paths are found. However, its complexity increases exponentially
    269     with the number of primitives and with the maximum depth. Therefore, it
    270     only works for scenes of low complexity and/or for small depth values.
    271 
    272     * Fibonacci sampling, which find candidates by shooting and bouncing rays,
    273     and such that initial directions of rays shot from the sources are arranged
    274     in a Fibonacci lattice on the unit sphere. At every intersection with a
    275     primitive, the rays are bounced assuming perfectly specular reflections
    276     until the maximum depth is reached. The intersected primitives makes the
    277     candidate. This method can be applied to very large scenes. However, there
    278     is no guarantee that all possible paths are found.
    279 
    280     Note: Only triangle mesh are supported.
    281 
    282     Parameters
    283     -----------
    284     scene : :class:`~sionna.rt.Scene`
    285         Sionna RT scene
    286 
    287     solver : :class:`~sionna.rt.SolverBase` | None
    288         Another solver from which to re-use some structures to avoid useless
    289         compute and memory use
    290 
    291     dtype : tf.complex64 | tf.complex128
    292         Datatype for all computations, inputs, and outputs.
    293         Defaults to `tf.complex64`.
    294 
    295     Input
    296     ------
    297     max_depth : int
    298         Maximum depth (i.e., number of interaction with objects in the scene)
    299         allowed for tracing the paths.
    300 
    301     sources : [num_sources, 3], tf.float
    302         Coordinates of the sources.
    303 
    304     targets : [num_targets, 3], tf.float
    305         Coordinates of the targets.
    306 
    307     method : str ("exhaustive"|"fibonacci")
    308         Method to be used to list candidate paths.
    309         The "exhaustive" method tests all possible combination of primitives as
    310         paths. This method is not compatible with scattering.
    311         The "fibonacci" method uses a shoot-and-bounce approach to find
    312         candidate chains of primitives. Intial rays direction are arranged
    313         in a Fibonacci lattice on the unit sphere. This method can be
    314         applied to very large scenes. However, there is no guarantee that
    315         all possible paths are found.
    316 
    317     num_samples: int
    318         Number of random rays to trace in order to generate candidates.
    319         A large sample count may exhaust GPU memory.
    320 
    321     los : bool
    322         If set to `True`, then the LoS paths are computed.
    323 
    324     reflection : bool
    325         If set to `True`, then the reflected paths are computed.
    326 
    327     diffraction : bool
    328         If set to `True`, then the diffracted paths are computed.
    329 
    330     scattering : bool
    331         if set to `True`, then the scattered paths are computed.
    332         Only works with the Fibonacci method.
    333 
    334     ris : bool
    335         If set to `True`, then the paths involving RIS are computed.
    336 
    337     scat_keep_prob : float
    338         Probability with which to keep scattered paths.
    339         This is helpful to reduce the number of scattered paths computed,
    340         which might be prohibitively high in some setup.
    341         Must be in the range (0,1).
    342 
    343     edge_diffraction : bool
    344         If set to `False`, only diffraction on wedges, i.e., edges that
    345         connect two primitives, is considered.
    346 
    347     scat_random_phases : bool
    348         If set to `True` and if scattering is enabled, random uniform phase
    349         shifts are added to the scattered paths.
    350 
    351     Output
    352     -------
    353     paths : Paths
    354         The computed paths.
    355     """
    356 
    357 
    358     def trace_paths(self, max_depth, method, num_samples, los, reflection,
    359                     diffraction, scattering, ris, scat_keep_prob,
    360                     edge_diffraction):
    361         # pylint: disable=line-too-long
    362         r"""
    363         Traces the paths.
    364 
    365         Computes the trajectories of the paths by shooting rays.
    366         No EM field computation is performed by this function.
    367 
    368         Input
    369         ------
    370         max_depth : int
    371             Maximum depth (i.e., number of interaction with objects in the scene)
    372             allowed for tracing the paths.
    373 
    374         method : str ("exhaustive"|"fibonacci")
    375             Method to be used to list candidate paths.
    376             The "exhaustive" method tests all possible combination of primitives as
    377             paths. This method is not compatible with scattering.
    378             The "fibonacci" method uses a shoot-and-bounce approach to find
    379             candidate chains of primitives. Intial rays direction are arranged
    380             in a Fibonacci lattice on the unit sphere. This method can be
    381             applied to very large scenes. However, there is no guarantee that
    382             all possible paths are found.
    383 
    384         num_samples: int
    385             Number of random rays to trace in order to generate candidates.
    386             A large sample count may exhaust GPU memory.
    387 
    388         los : bool
    389             If set to `True`, then the LoS paths are computed.
    390 
    391         reflection : bool
    392             If set to `True`, then the reflected paths are computed.
    393 
    394         diffraction : bool
    395             If set to `True`, then the diffracted paths are computed.
    396 
    397         scattering : bool
    398             if set to `True`, then the scattered paths are computed.
    399             Only works with the Fibonacci method.
    400 
    401         ris : bool
    402             If set to `True`, then the paths involving RIS are computed.
    403 
    404         scat_keep_prob : float
    405             Probability with which to keep scattered paths.
    406             This is helpful to reduce the number of scattered paths computed,
    407             which might be prohibitively high in some setup.
    408             Must be in the range (0,1).
    409 
    410         edge_diffraction : bool
    411             If set to `False`, only diffraction on wedges, i.e., edges that
    412             connect two primitives, is considered.
    413 
    414         Output
    415         -------
    416         spec_paths : Paths
    417             The computed specular paths
    418 
    419         diff_paths : Paths
    420             The computed diffracted paths
    421 
    422         scat_paths : Paths
    423             The computed scattered paths
    424 
    425         ris_paths : :class:`~sionna.rt.Paths`
    426             Computed paths involving RIS
    427 
    428         spec_paths_tmp : PathsTmpData
    429             Additional data required to compute the EM fields of the specular
    430             paths
    431 
    432         diff_paths_tmp : PathsTmpData
    433             Additional data required to compute the EM fields of the diffracted
    434             paths
    435 
    436         scat_paths_tmp : PathsTmpData
    437             Additional data required to compute the EM fields of the scattered
    438             paths
    439 
    440         ris_paths_tmp : :class:`~sionna.rt.PathsTmpData`
    441             Additional data required to compute the EM fields of the paths
    442             involving RIS
    443         """
    444         scat_keep_prob = tf.cast(scat_keep_prob, self._rdtype)
    445         # Disable scattering if the probability of keeping a path is 0
    446         scattering = tf.logical_and(scattering,
    447                     tf.greater(scat_keep_prob, tf.zeros_like(scat_keep_prob)))
    448 
    449         # If reflection and scattering are disabled, no need for a max_depth
    450         # higher than 1.
    451         # This clipping can save some compute for the shoot-and-bounce
    452         if (not reflection) and (not scattering):
    453             max_depth = tf.minimum(max_depth, 1)
    454 
    455         # Rotation matrices corresponding to the orientations of the radio
    456         # devices
    457         # rx_rot_mat : [num_rx, 3, 3]
    458         # tx_rot_mat : [num_tx, 3, 3]
    459         rx_rot_mat, tx_rot_mat = self._get_tx_rx_rotation_matrices()
    460 
    461         #################################################
    462         # Prepares the sources (from which rays are shot)
    463         # and targets (which capture the rays)
    464         #################################################
    465 
    466         if not self._scene.synthetic_array:
    467             # Relative positions of the antennas of the transmitters and
    468             # receivers
    469             # rx_rel_ant_pos: [num_rx, rx_array_size, 3], tf.float
    470             #     Relative positions of the receivers antennas
    471             # tx_rel_ant_pos: [num_tx, rx_array_size, 3], tf.float
    472             #     Relative positions of the transmitters antennas
    473             rx_rel_ant_pos, tx_rel_ant_pos =\
    474                 self._get_antennas_relative_positions(rx_rot_mat, tx_rot_mat)
    475 
    476         # Transmitters and receivers positions
    477         # [num_tx, 3]
    478         tx_pos = [tx.position for tx in self._scene.transmitters.values()]
    479         tx_pos = tf.stack(tx_pos, axis=0)
    480         # [num_rx, 3]
    481         rx_pos = [rx.position for rx in self._scene.receivers.values()]
    482         rx_pos = tf.stack(rx_pos, axis=0)
    483 
    484         if self._scene.synthetic_array:
    485             # With synthetic arrays, each radio device corresponds to a single
    486             # endpoint (source or target)
    487             # [num_sources = num_tx, 3]
    488             sources = tx_pos
    489             # [num_targets = num_rx, 3]
    490             targets = rx_pos
    491         else:
    492             # [num_tx, tx_array_size, 3]
    493             sources = tf.expand_dims(tx_pos, axis=1) + tx_rel_ant_pos # pylint: disable=possibly-used-before-assignment
    494             # [num_sources = num_tx*tx_array_size, 3]
    495             sources = tf.reshape(sources, [-1, 3])
    496             # [num_rx, rx_array_size, 3]
    497             targets = tf.expand_dims(rx_pos, axis=1) + rx_rel_ant_pos # pylint: disable=possibly-used-before-assignment
    498             # [num_targets = num_rx*rx_array_size, 3]
    499             targets = tf.reshape(targets, [-1, 3])
    500 
    501         ##############################################
    502         # Builds the Mitsuba scene with RIS for
    503         # testing intersections with RIS
    504         ##############################################
    505         ris_objects, _ = self._build_mi_ris_objects()
    506 
    507         ##############################################
    508         # Generate candidate paths
    509         ##############################################
    510 
    511         # Candidate paths are generated according to the specified `method`.
    512         if method == 'exhaustive':
    513             if scattering:
    514                 msg = "The exhaustive method is not compatible with scattering"
    515                 raise ValueError(msg)
    516             # List all possible sequences of primitives with length up to
    517             # ``max_depth``
    518             # candidates: [max_depth, num_samples], int
    519             #     All possible candidate paths with depth up to ``max_depth``.
    520             # los_candidates: [num_samples], int
    521             #     Primitives in LoS. For the exhaustive method, this is the
    522             #     list of all the primitives in the scene.
    523             candidates, los_prim = self._list_candidates_exhaustive(max_depth,
    524                                                                 los, reflection)
    525             candidates_scat = None
    526             hit_points = None
    527         elif method == 'fibonacci':
    528             # Sample sequences of primitives using shoot-and-bounce
    529             # with length up to ``max_depth`` and by arranging the initial
    530             # rays direction in a Fibonacci lattice on the unit sphere.
    531             # candidates: [max_depth, num paths], int
    532             #     All unique candidate paths found, with depth up to
    533             #       ``max_depth``.
    534             # los_candidates: [num_samples], int
    535             #     Candidate primitives found in LoS.
    536             # candidates_scat : [max_depth, num_sources, num_paths_per_source]
    537             #       Sequence of primitives hit at `hit_points`.
    538             # hit_points : [max_depth, num_sources, num_paths_per_source, 3]
    539             #     Coordinates of the intersection points.
    540             output = self._list_candidates_fibonacci(max_depth,
    541                                         sources, num_samples, los, reflection,
    542                                         scattering, ris_objects)
    543             candidates = output[0]
    544             los_prim = output[1]
    545             candidates_scat = output[2]
    546             hit_points = output[3]
    547 
    548         else:
    549             raise ValueError(f"Unknown method '{method}'")
    550 
    551         ##############################################
    552         # LoS and Specular paths
    553         ##############################################
    554         spec_paths = Paths(sources=sources, targets=targets, scene=self._scene,
    555                            types=Paths.SPECULAR)
    556         spec_paths_tmp = PathsTmpData(sources, targets, self._dtype)
    557         if los or reflection:
    558 
    559             # Using the image method, computes the non-obstructed specular paths
    560             # interacting with the ``candidates`` primitives
    561             self._spec_image_method(candidates, spec_paths, spec_paths_tmp,
    562                                     ris_objects)
    563 
    564             # Compute paths length, delays, angles and directions of arrivals
    565             # and departures for the specular paths
    566             spec_paths, spec_paths_tmp =\
    567                 self._compute_directions_distances_delays_angles(spec_paths,
    568                                                         spec_paths_tmp, False)
    569 
    570         ############################################
    571         # Diffracted paths
    572         ############################################
    573         diff_paths = Paths(sources=sources, targets=targets, scene=self._scene,
    574                            types=Paths.DIFFRACTED)
    575         diff_paths_tmp = PathsTmpData(sources, targets, self._dtype)
    576         if (los_prim is not None) and diffraction:
    577 
    578             # Get the candidate wedges for diffraction
    579             # Note: Only one-order diffraction is supported. Therefore, we
    580             # restrict the candidate wedges to the ones of primitives in
    581             # line-of-sight with the transmitter
    582             # candidate_wedges : [num_candidate_wedges], int
    583             #     Candidate wedges indices
    584             diff_wedges_indices = self._wedges_from_primitives(los_prim,
    585                                                                edge_diffraction)
    586 
    587             # Discard paths for which at least one of the transmitter or
    588             # receiver is inside the wedge.
    589             # diff_wedges_indices : [num_targets, num_sources, max_num_paths]
    590             #   Indices of the intersected wedges
    591             diff_wedges_indices = self._discard_obstructing_wedges_and_corners(
    592                                                 diff_wedges_indices, targets,
    593                                                 sources)
    594 
    595             # Compute the intersection points with the wedges, and discard paths
    596             # for which the intersection point is not on the finite wedge.
    597             # diff_wedges_indices : [num_targets, num_sources, max_num_paths]
    598             #   Indices of the intersected wedges
    599             # diff_vertices : [num_targets, num_sources, max_num_paths, 3]
    600             #   Position of the intersection point on the wedges
    601             diff_wedges_indices, diff_vertices =\
    602                 self._compute_diffraction_points(targets, sources,
    603                                                  diff_wedges_indices)
    604 
    605 
    606             # Discard obstructed diffracted paths
    607             # Only check for wedge visibility if there is at least one candidate
    608             # diffracted path
    609             if diff_wedges_indices.shape[2] > 0: # Number of diff. paths > 0
    610                 # Discard obstructed paths
    611                 diff_wedges_indices, diff_vertices =\
    612                     self._check_wedges_visibility(targets, sources,
    613                                                   diff_wedges_indices,
    614                                                   diff_vertices,
    615                                                   ris_objects)
    616 
    617             diff_paths = Paths(sources=sources, targets=targets,
    618                                scene=self._scene, types=Paths.DIFFRACTED)
    619             diff_paths.objects = tf.expand_dims(diff_wedges_indices, axis=0)
    620             diff_paths.vertices = tf.expand_dims(diff_vertices, axis=0)
    621 
    622             # Select only the valid paths
    623             diff_paths = self._gather_valid_diff_paths(diff_paths)
    624 
    625             # Computes paths length, delays, angles and directions of arrivals
    626             # and departures for the specular paths
    627             diff_paths, diff_paths_tmp =\
    628                 self._compute_directions_distances_delays_angles(diff_paths,
    629                                                         diff_paths_tmp, False)
    630 
    631         ############################################
    632         # Scattered paths
    633         ############################################
    634         scat_paths = Paths(sources=sources, targets=targets, scene=self._scene,
    635                            types=Paths.SCATTERED)
    636         scat_paths_tmp = PathsTmpData(sources, targets, self._dtype)
    637         if scattering and tf.shape(candidates_scat)[0] > 0:
    638 
    639             scat_paths, scat_paths_tmp = self._scat_test_rx_blockage(targets,sources,
    640                                                                 candidates_scat,
    641                                                                 hit_points,
    642                                                                 ris_objects)
    643             scat_paths, scat_paths_tmp =\
    644                 self._compute_directions_distances_delays_angles(scat_paths,
    645                                                                  scat_paths_tmp,
    646                                                                  True)
    647 
    648             scat_paths, scat_paths_tmp =\
    649                 self._scat_discard_crossing_paths(scat_paths, scat_paths_tmp,
    650                                                   scat_keep_prob)
    651 
    652             # Extract the valid prefixes as paths
    653             scat_paths, scat_paths_tmp = self._scat_prefixes_2_paths(scat_paths,
    654                                                                 scat_paths_tmp)
    655         # Additional data required to compute the field
    656         spec_paths_tmp.num_samples = num_samples
    657         spec_paths_tmp.scat_keep_prob = tf.cast(scat_keep_prob, self._rdtype)
    658         diff_paths_tmp.num_samples = num_samples
    659         diff_paths_tmp.scat_keep_prob = tf.cast(scat_keep_prob, self._rdtype)
    660         scat_paths_tmp.num_samples = num_samples
    661         scat_paths_tmp.scat_keep_prob = tf.cast(scat_keep_prob, self._rdtype)
    662 
    663         ##############################################
    664         # RIS paths
    665         ##############################################
    666         ris_paths = Paths(sources=sources, targets=targets, scene=self._scene,
    667                            types=Paths.RIS)
    668         ris_paths_tmp = PathsTmpData(sources, targets, self._dtype)
    669 
    670         if ris and len(self._scene.ris)>0:
    671             ris_paths, ris_paths_tmp = self._ris_paths(ris_paths,
    672                                                        ris_paths_tmp,
    673                                                        ris_objects)
    674 
    675         return spec_paths, diff_paths, scat_paths, ris_paths, spec_paths_tmp,\
    676             diff_paths_tmp, scat_paths_tmp, ris_paths_tmp
    677 
    def compute_fields(self, spec_paths, diff_paths, scat_paths, ris_paths,
                       spec_paths_tmp, diff_paths_tmp, scat_paths_tmp,
                       ris_paths_tmp, scat_random_phases, testing):
        r"""
        Computes the EM fields for a set of traced paths.

        The specular (including LoS), diffracted, scattered, and RIS paths
        that are non-empty are merged into a single set of paths, for which
        the channel coefficients are then computed.

        Input
        ------
        spec_paths : Paths
            Specular paths

        diff_paths : Paths
            Diffracted paths

        scat_paths : Paths
            Scattered paths

        ris_paths : :class:`~sionna.rt.Paths`
            Computed paths involving RIS

        spec_paths_tmp : PathsTmpData
            Additional data required to compute the EM fields of the specular
            paths

        diff_paths_tmp : PathsTmpData
            Additional data required to compute the EM fields of the diffracted
            paths

        scat_paths_tmp : PathsTmpData
            Additional data required to compute the EM fields of the scattered
            paths. Its ``num_samples`` and ``scat_keep_prob`` attributes are
            also used when computing the coefficients of all paths.

        ris_paths_tmp : :class:`~sionna.rt.PathsTmpData`
            Additional data required to compute the EM fields of the paths
            involving RIS

        scat_random_phases : bool
            If set to `True` and if scattering is enabled, random uniform phase
            shifts are added to the scattered paths.

        testing : bool
            If set to `True`, then additional data is returned for testing.

        Output
        -------
        sources : [num_sources, 3], tf.float
            Coordinates of the sources

        targets : [num_targets, 3], tf.float
            Coordinates of the targets

        dict : Paths as a dictionary
            The computed paths as a dictionary of tensors, i.e., the output of
            `Paths.to_dict()`.
            Returning the paths as tensors rather than as a `Paths` object is
            required to enable the execution of this function in graph mode.

        dict : PathsTmpData as a dictionary
            Output of `to_dict()` for the additional data of the specular
            paths.
            Only returned if `testing` is set to `True`.

        dict : PathsTmpData as a dictionary
            Output of `to_dict()` for the additional data of the diffracted
            paths.
            Only returned if `testing` is set to `True`.

        dict : PathsTmpData as a dictionary
            Output of `to_dict()` for the additional data of the scattered
            paths.
            Only returned if `testing` is set to `True`.
            Note that the additional data of the RIS paths is not returned.
        """

        sources = spec_paths.sources
        targets = spec_paths.targets

        # Empty paths object into which the paths of every type are merged
        all_paths = Paths(sources=sources,
                          targets=targets,
                          scene=self._scene)
        # Empty object storing tensors that are required to compute the
        # paths coefficients, but that are not returned to the user
        all_paths_tmp = PathsTmpData(sources, targets, self._dtype)

        def _append_paths(merged, paths, paths_tmp, mat_t):
            # Merges `paths` into `merged` and appends, along the paths
            # dimension, the transition matrices `mat_t` as well as the
            # `k_tx` and `k_rx` vectors of `paths_tmp` to `all_paths_tmp`.
            # Only these quantities (plus Doppler shifts, which are stored in
            # `paths`) are needed to compute the paths coefficients.
            merged = merged.merge(paths)
            all_paths_tmp.mat_t = tf.concat([all_paths_tmp.mat_t, mat_t],
                                            axis=-3)
            all_paths_tmp.k_tx = tf.concat([all_paths_tmp.k_tx,
                                            paths_tmp.k_tx],
                                           axis=-2)
            all_paths_tmp.k_rx = tf.concat([all_paths_tmp.k_rx,
                                            paths_tmp.k_rx],
                                           axis=-2)
            return merged

        # Rotation matrices corresponding to the orientations of the radio
        # devices
        # rx_rot_mat : [num_rx, 3, 3]
        # tx_rot_mat : [num_tx, 3, 3]
        rx_rot_mat, tx_rot_mat = self._get_tx_rx_rotation_matrices()

        # Number of transmit antennas (not counting for dual polarization)
        tx_array_size = self._scene.tx_array.array_size
        # Number of receive antennas (not counting for dual polarization)
        rx_array_size = self._scene.rx_array.array_size

        #################################################
        # Extract the material properties of the scene
        #################################################

        # Returns: relative_permittivities, denoted by `etas`,
        # scattering_coefficients, xpd_coefficients,
        # alpha_r, alpha_i, lambda_, and velocities
        object_properties = self._build_scene_object_properties_tensors()
        etas = object_properties[0]
        scattering_coefficient = object_properties[1]
        xpd_coefficient = object_properties[2]
        alpha_r = object_properties[3]
        alpha_i = object_properties[4]
        lambda_ = object_properties[5]
        velocity = object_properties[6]

        ##############################################
        # LoS and Specular paths
        ##############################################

        if spec_paths.objects.shape[3] > 0:
            # Compute the EM transition matrices and Doppler shifts
            spec_mat_t = self._spec_transition_matrices(etas,
                    scattering_coefficient, spec_paths, spec_paths_tmp, False)
            spec_paths.doppler = self._compute_doppler_shifts(spec_paths,
                                                              spec_paths_tmp,
                                                              velocity)
            all_paths = _append_paths(all_paths, spec_paths, spec_paths_tmp,
                                      spec_mat_t)
            # If testing, the transition matrices are also returned
            if testing:
                spec_paths_tmp.mat_t = spec_mat_t

        ############################################
        # Diffracted paths
        ############################################

        if diff_paths.objects.shape[3] > 0:
            # Compute the transition matrices and Doppler shifts
            diff_mat_t =\
                self._compute_diffraction_transition_matrices(etas,
                            scattering_coefficient, diff_paths, diff_paths_tmp)
            diff_paths.doppler = self._compute_doppler_shifts(diff_paths,
                                                              diff_paths_tmp,
                                                              velocity)
            all_paths = _append_paths(all_paths, diff_paths, diff_paths_tmp,
                                      diff_mat_t)
            # If testing, the transition matrices are also returned
            if testing:
                diff_paths_tmp.mat_t = diff_mat_t

        ############################################
        # Scattered paths
        ############################################

        if scat_paths.objects.shape[3] > 0:
            # Compute transition matrices up to the scattering point
            # as well as Doppler shifts
            scat_mat_t = self._spec_transition_matrices(etas,
                    scattering_coefficient, scat_paths, scat_paths_tmp, True)
            scat_paths.doppler = self._compute_doppler_shifts(scat_paths,
                                                              scat_paths_tmp,
                                                              velocity)
            all_paths = _append_paths(all_paths, scat_paths, scat_paths_tmp,
                                      scat_mat_t)
            # Scattering specific quantities are also required for the
            # computation of the paths coefficients
            for name in ("scat_last_objects", "scat_last_k_i", "scat_k_s",
                         "scat_last_normals", "scat_src_2_last_int_dist",
                         "scat_2_target_dist", "scat_last_vertices"):
                setattr(all_paths_tmp, name, getattr(scat_paths_tmp, name))
            # If testing, the transition matrices are also returned
            if testing:
                scat_paths_tmp.mat_t = scat_mat_t

        ############################################
        # RIS paths
        ############################################

        if ris_paths.objects.shape[3] > 0:
            # Compute the transition matrices and Doppler shifts
            ris_mat_t = self._ris_transition_matrices(ris_paths, ris_paths_tmp)
            ris_paths.doppler = self._compute_doppler_shifts(ris_paths,
                                                             ris_paths_tmp,
                                                             velocity)
            all_paths = _append_paths(all_paths, ris_paths, ris_paths_tmp,
                                      ris_mat_t)
            # If testing, the transition matrices are also stored
            if testing:
                ris_paths_tmp.mat_t = ris_mat_t

        #################################################
        # Splitting the sources (targets) dimension into
        # transmitters (receivers) and antennas, or
        # applying the synthetic arrays
        #################################################

        # Names of the per-path quantities that are reshaped/tiled below
        path_fields = ("mask", "tau", "theta_t", "phi_t", "theta_r", "phi_r",
                       "doppler")

        # If not using synthetic array, then the paths for the different
        # antenna elements were generated and reshaping is needed.
        # Otherwise, expand with the antenna dimensions.
        # [num_targets, num_sources, max_num_paths]
        all_paths.targets_sources_mask = all_paths.mask
        if self._scene.synthetic_array:
            # [num_rx, num_tx, max_num_paths, 2, 2]
            #   -> [num_rx, 1, num_tx, 1, max_num_paths, 2, 2]
            all_paths_tmp.mat_t = tf.expand_dims(
                tf.expand_dims(all_paths_tmp.mat_t, axis=1), axis=3)
        else:
            num_rx = len(self._scene.receivers)
            num_tx = len(self._scene.transmitters)
            max_num_paths = tf.shape(all_paths.vertices)[3]
            batch_dims = [num_rx, rx_array_size, num_tx, tx_array_size,
                          max_num_paths]
            # [num_rx, rx_array_size, num_tx, tx_array_size, max_num_paths]
            for name in path_fields:
                setattr(all_paths, name,
                        tf.reshape(getattr(all_paths, name), batch_dims))
            # [num_rx, rx_array_size, num_tx, tx_array_size, max_num_paths, 2,2]
            all_paths_tmp.mat_t = tf.reshape(all_paths_tmp.mat_t,
                                             batch_dims + [2,2])
            # [num_rx, rx_array_size, num_tx, tx_array_size, max_num_paths, 3]
            all_paths_tmp.k_tx = tf.reshape(all_paths_tmp.k_tx, batch_dims+[3])
            all_paths_tmp.k_rx = tf.reshape(all_paths_tmp.k_rx, batch_dims+[3])

        ####################################################
        # Compute the channel coefficients
        ####################################################
        scat_keep_prob = scat_paths_tmp.scat_keep_prob
        num_samples = scat_paths_tmp.num_samples
        all_paths.a = self._compute_paths_coefficients(rx_rot_mat,
                                                       tx_rot_mat,
                                                       all_paths,
                                                       all_paths_tmp,
                                                       num_samples,
                                                       scattering_coefficient,
                                                       xpd_coefficient,
                                                       etas, alpha_r, alpha_i,
                                                       lambda_, scat_keep_prob,
                                                       scat_random_phases)

        # If using synthetic array, adds the antenna dimensions by applying
        # synthetic phase shifts
        if self._scene.synthetic_array:
            all_paths.a = self._apply_synthetic_array(rx_rot_mat, tx_rot_mat,
                                                      all_paths, all_paths_tmp)

        ##################################################
        # If not using synthetic arrays, tile the AoAs,
        # AoDs, and delays to handle dual-polarization
        ##################################################
        if not self._scene.synthetic_array:
            num_rx_patterns = len(self._scene.rx_array.antenna.patterns)
            num_tx_patterns = len(self._scene.tx_array.antenna.patterns)
            # Repeat the antenna axes once per antenna pattern:
            # [num_rx, rx_array_size, num_tx, tx_array_size, max_num_paths]
            #   -> [num_rx, num_rx_patterns*rx_array_size,
            #       num_tx, num_tx_patterns*tx_array_size, max_num_paths]
            # Tiling an axis of size `array_size` gives the antenna with index
            # p*array_size + i the values of array element i, for pattern p.
            multiples = [1, num_rx_patterns, 1, num_tx_patterns, 1]
            # All fields are read before any is overwritten
            values = {name: getattr(all_paths, name) for name in path_fields}
            for name, value in values.items():
                setattr(all_paths, name, tf.tile(value, multiples))

        # If testing, additional data is returned
        if testing:
            output = (  sources, targets, all_paths.to_dict(),
                        # For testing
                        spec_paths_tmp.to_dict(),
                        diff_paths_tmp.to_dict(),
                        scat_paths_tmp.to_dict() )
        else:
            output = (sources, targets, all_paths.to_dict())
        return output
   1033 
   1034     ##################################################################
   1035     # Methods for finding candidate primitives and edges for reflected
   1036     # and diffracted paths
   1037     ##################################################################
   1038 
   1039     def _list_candidates_exhaustive(self, max_depth, los, reflection):
   1040         r"""
   1041         Generate all possible candidate paths made of reflections only and the
   1042         LoS.
   1043 
   1044         The number of candidate paths equals
   1045 
   1046             num_triangles**max_depth + 1
   1047 
   1048         where the additional path (+1) is the LoS.
   1049 
   1050         This can easily exhaust GPU memory if the number of triangles in the
   1051         scene or the `max_depth` are too large.
   1052 
   1053         Input
   1054         ------
   1055         max_depth: int
   1056             Maximum number of reflections.
   1057             Set to 0 for LoS only.
   1058 
   1059         los : bool
   1060             Set if the LoS paths are computed.
   1061 
   1062         reflection : bool
   1063             Set if the reflected paths are computed.
   1064 
   1065         Output
   1066         -------
   1067         candidates: [max_depth, num_samples], int
   1068             All possible candidate paths with depth up to ``max_depth``.
   1069             Entries correspond to primitives indices.
   1070             For paths with depth lower than ``max_depth``, -1 is used as
   1071             padding value.
   1072             The first path is the LoS one if LoS is requested.
   1073 
   1074         los_candidates: [num_samples], int or `None`
   1075             Candidates in LoS. For the exhaustive method, this is the list of
   1076             all candidates. `None` is returned if ``max_depth`` is 0 or for
   1077             empty scenes.
   1078         """
   1079         # Number of triangles
   1080         n_prims = self._primitives.shape[0]
   1081 
   1082         # List of all triangles
   1083         # [n_prims]
   1084         all_prims = tf.range(n_prims, dtype=tf.int32)
   1085 
   1086         # Empty scene or reflection disabled
   1087         if (not reflection) or (n_prims == 0):
   1088             if los:
   1089                 # Only LoS is added as candidate
   1090                 return tf.fill([0,1], -1), all_prims
   1091             else:
   1092                 # No candidates
   1093                 return tf.fill([0,0], -1), all_prims
   1094 
   1095         # If reflection is disabled,
   1096 
   1097         # Number of candidate paths made of reflections only
   1098         # num_samples = n_prims + n_prims^2 + ... + n_prims^max_depth
   1099         if n_prims == 0:
   1100             num_samples = 0
   1101         elif n_prims == 1:
   1102             num_samples = max_depth
   1103         else:
   1104             num_samples = (n_prims * (n_prims ** max_depth - 1))//(n_prims - 1)
   1105         # Add LoS path
   1106         if los:
   1107             num_samples += 1
   1108         # Tensor of all possible reflections
   1109         # Shape : [max_depth , num_samples]
   1110         # It is transposed to fit the expected output shape at the end of this
   1111         # function.
   1112         # all_candidates[i,j] correspond to the triangle index intersected
   1113         # by the i^th path for at j^th reflection.
   1114         # The first column corresponds to LoS, i.e., no interaction.
   1115         # -1 is used as padding value for path with depth lower than
   1116         # max_depth.
   1117         # Initialized with -1.
   1118         all_candidates = tf.fill([num_samples, max_depth], -1)
   1119         # The next loop fill all_candidates with the list of intersected
   1120         # primitives for all possible paths made of reflections only.
   1121         # It starts from the paths with the 1 reflection, up to max_depth.
   1122         # The variable `offset` corresponds to the index offset for storing the
   1123         # paths in all_candidates.
   1124         if los:
   1125             # `offset` is initialized to 1 as the first path (depth = 0)
   1126             # corresponds to LoS
   1127             offset = 1
   1128         else:
   1129             # No LoS, `offset` is initialized to 0
   1130             offset = 0
   1131         for depth in range(1, max_depth+1):
   1132             # Enumerate all possible interactions for this depth
   1133             # List of `depth` tensors with shape
   1134             # [n_prims, ..., n_prims] and rank `depth`
   1135             candidates = tf.meshgrid(*([all_prims] * depth), indexing='ij')
   1136 
   1137             # Reshape to
   1138             # [n_prims**depth,depth]
   1139             candidates = tf.stack([tf.reshape(c, [-1]) for c in candidates],
   1140                                     axis=1)
   1141 
   1142             # Pad with -1 for paths shorter than max_depth
   1143             # [n_prims**depth,max_depth]
   1144             candidates = tf.pad(candidates, [[0,0],[0,max_depth-depth]],
   1145                                 mode='CONSTANT', constant_values=-1)
   1146 
   1147             # Update all_candidates
   1148             # Number of candidate paths for this depth
   1149             num_candidates = candidates.shape[0]
   1150             # Corresponding row indices in the all_candidates tensor
   1151             indices = tf.range(offset, offset+num_candidates, dtype=tf.int32)
   1152             indices = tf.expand_dims(indices, -1)
   1153             # all_candidates : [max_depth , num_samples]
   1154             all_candidates = tf.tensor_scatter_nd_update(all_candidates,
   1155                                                          indices, candidates)
   1156 
   1157             # Prepare for next iteration
   1158             offset += num_candidates
   1159 
   1160         # Transpose to fit the expected output shape.
   1161         # [max_depth, num_samples]
   1162         all_candidates = tf.transpose(all_candidates)
   1163 
   1164         # Primitives in LoS
   1165         if max_depth > 0:
   1166             los_candidates = all_prims
   1167         else:
   1168             los_candidates = None
   1169 
   1170         return all_candidates, los_candidates
   1171 
   1172     def _list_candidates_fibonacci(self, max_depth, sources, num_samples,
   1173                                    los, reflection, scattering, ris_objects):
   1174         r"""
   1175         Generate potential candidate paths made of reflections only and the
   1176         LoS. Rays direction are arranged in a Fibonacci lattice on the unit
   1177         sphere.
   1178 
   1179         This can be used when the triangle count or maximum depth make the
   1180         exhaustive method impractical.
   1181 
   1182         A budget of ``num_samples`` rays is split equally over the given
   1183         sources. Starting directions are sampled uniformly at random.
   1184         Paths are simulated until the maximum depth is reached.
   1185         We record all sequences of primitives hit and the prefixes of these
   1186         sequences, and return unique sequences.
   1187 
   1188         Input
   1189         ------
   1190         max_depth: int
   1191             Maximum number of reflections.
   1192             Set to 0 for LoS only.
   1193 
   1194         sources : [num_sources, 3], tf.float
   1195             Coordinates of the sources.
   1196 
   1197         num_samples: int
   1198             Number of rays to trace in order to generate candidates.
   1199             A large sample count may exhaust GPU memory.
   1200 
   1201         los : bool
   1202             If set to `True`, then the LoS paths are computed.
   1203 
   1204         reflection : bool
   1205             If set to `True`, then the reflected paths are computed.
   1206 
   1207         scattering : bool
   1208             if set to `True`, then the scattered paths are computed
   1209 
   1210         ris_objects : list(mi.Rectangle)
   1211             List of Mitsuba rectangles implementing the RIS
   1212 
   1213         Output
   1214         -------
   1215         candidates_ref: [max_depth, num paths], int
   1216             Unique sequence of hitted primitives, with depth up to ``max_depth``.
   1217             Entries correspond to primitives indices.
   1218             For paths with depth lower than max_depth, -1 is used as
   1219             padding value.
   1220             The first path is the LoS one if LoS is requested.
   1221 
   1222         los_candidates: [num_samples], int or `None`
   1223             Primitives in LoS. `None` is returned if ``max_depth`` is 0.
   1224 
   1225         candidates_scat : [max_depth, num_sources, num_paths_per_source], int
   1226             Sequence of primitives hit at `hit_points`. Compared to
   1227             `candidates_ref`, it does not need to be unique, as the
   1228             intersection points are different for every sequence, and is
   1229             dependant on the source, as the intersection point are specific to
   1230             the sources positions.
   1231 
   1232         hit_points : [max_depth, num_sources, num_paths_per_source, 3], tf.float
   1233             Intersection points.
   1234         """
   1235         mask_t = dr.mask_t(self._mi_scalar_t)
   1236 
   1237         # Ensure that sample count can be distributed over the emitters
   1238         num_sources = sources.shape[0]
   1239         samples_per_source = int(dr.ceil(num_samples / num_sources))
   1240         num_samples = num_sources * samples_per_source
   1241 
   1242         # List of candidates
   1243         candidates = []
   1244 
   1245         # Hit points
   1246         hit_points = []
   1247 
   1248         # Is the scene empty?
   1249         is_empty = dr.shape(self._shape_indices)[0] == 0
   1250 
   1251         # Only shoot if the scene is not empty
   1252         if not is_empty:
   1253 
   1254             # Keep track of which paths are still active
   1255             active = dr.full(mask_t, True, num_samples)
   1256 
   1257             # Initial ray: Arranged in a Fibonacci lattice on the unit
   1258             # sphere.
   1259             # [samples_per_source, 3]
   1260             lattice = fibonacci_lattice(samples_per_source, self._rdtype)
   1261             sampled_d = tf.tile(lattice, [num_sources, 1])
   1262             sampled_d = self._mi_point2_t(sampled_d)
   1263             sampled_d = mi.warp.square_to_uniform_sphere(sampled_d)
   1264             source_i = dr.linspace(self._mi_scalar_t, 0, num_sources,
   1265                                    num=num_samples, endpoint=False)
   1266             source_i = mi.Int32(source_i)
   1267             sources_dr = self._mi_tensor_t(sources)
   1268             ray = mi.Ray3f(
   1269                 o=dr.gather(self._mi_vec_t, sources_dr.array, source_i),
   1270                 d=sampled_d,
   1271             )
   1272 
   1273             for depth in range(max_depth):
   1274 
   1275                 # Intersect ray against the scene to find the next hitted
   1276                 # primitive
   1277                 si = self._mi_scene.ray_intersect(ray, active)
   1278                 # Intersect with the RIS
   1279                 _, t_ris, _ = self._ris_intersect(ris_objects, ray, active)
   1280 
   1281                 # Intersection valid if not obstructed by RIS
   1282                 valid_int = si.is_valid() & (si.t < t_ris)
   1283 
   1284                 active &= valid_int
   1285 
   1286                 # Record which primitives were hit
   1287                 shape_i = dr.gather(mi.Int32, self._shape_indices,
   1288                                     dr.reinterpret_array_v(mi.UInt32, si.shape),
   1289                                     active)
   1290                 offsets = dr.gather(mi.Int32, self._prim_offsets, shape_i,
   1291                                     active)
   1292                 prims_i = dr.select(active, offsets + si.prim_index, -1)
   1293                 candidates.append(prims_i)
   1294 
   1295                 # Record the hit point
   1296                 hit_p = ray.o + si.t*ray.d
   1297                 hit_points.append(hit_p)
   1298 
   1299                 # Prepare the next interaction, assuming purely specular
   1300                 # reflection
   1301                 ray = si.spawn_ray(si.to_world(mi.reflect(si.wi)))
   1302 
   1303         # For diffraction, we need only primitives in LoS
   1304         # [num_los_primitives]
   1305         if len(candidates) > 0:
   1306             # max_depth > 0 or empty scene
   1307             los_primitives = tf.reshape(tf.cast(candidates[0], tf.int32), [-1])
   1308             los_primitives,_ = tf.unique(los_primitives)
   1309             los_primitives = tf.gather(los_primitives,
   1310                                        tf.where(los_primitives != -1)[:,0])
   1311         else:
   1312             # max_depth == 0
   1313             los_primitives = None
   1314 
   1315         reflection = reflection and (max_depth > 0) and (len(candidates) > 0)
   1316         scattering = scattering and (max_depth > 0) and (len(candidates) > 0)
   1317 
   1318         if scattering or reflection:
   1319             # Stack all found interactions along the depth dimension
   1320             # [max_depth, num_samples]
   1321             candidates = tf.stack([mi_to_tf_tensor(r, tf.int32)
   1322                                 for r in candidates], axis=0)
   1323 
   1324         if reflection:
   1325             # [max_depth, num_samples]
   1326             candidates_ref = candidates
   1327             # Compute the actual max_depth
   1328             # [max_depth]
   1329             useless_step = tf.reduce_all(tf.equal(candidates_ref, -1), axis=1)
   1330             # ()
   1331             max_depth_ref = tf.where(tf.reduce_any(useless_step),
   1332                                      tf.argmax(tf.cast(useless_step, tf.int32),
   1333                                         output_type=tf.int32),
   1334                                      max_depth)
   1335             # [max_depth, num_samples]
   1336             candidates_ref = candidates_ref[:max_depth_ref]
   1337         else:
   1338             # No candidates
   1339             candidates_ref = tf.fill([0, 0], -1)
   1340             max_depth_ref = 0
   1341 
   1342         if scattering:
   1343             # [max_depth, num_samples, 3]
   1344             hit_points = tf.stack([mi_to_tf_tensor(r, self._rdtype)
   1345                                 for r in hit_points])
   1346             # [max_depth, num_sources, samples_per_source, 3]
   1347             hit_points = tf.reshape(hit_points,
   1348                         [max_depth, num_sources, samples_per_source, 3])
   1349             # [max_depth, num_sources, samples_per_source]
   1350             candidates_scat = tf.reshape(candidates,
   1351                                 [max_depth, num_sources, samples_per_source])
   1352             # Flag indicating no hits
   1353             # [max_depth, num_sources, samples_per_source]
   1354             no_hit = tf.equal(candidates_scat, -1)
   1355             # Compute the actual max_depth
   1356             # [max_depth]
   1357             useless_step = tf.reduce_all(no_hit, axis=(1,2))
   1358             # ()
   1359             max_depth_scat = tf.where(tf.reduce_any(useless_step),
   1360                                       tf.argmax(tf.cast(useless_step, tf.int32),
   1361                                         output_type=tf.int32),
   1362                                     max_depth)
   1363             # [max_depth, num_sources, samples_per_source, 3]
   1364             hit_points = hit_points[:max_depth_scat]
   1365             # [max_depth, num_sources, samples_per_source]
   1366             candidates_scat = candidates_scat[:max_depth_scat]
   1367             # [max_depth, num_sources, samples_per_source]
   1368             no_hit = no_hit[:max_depth_scat]
   1369             # Remove useless paths
   1370             # [samples_per_source]
   1371             useful_samples = tf.logical_not(tf.reduce_all(no_hit, axis=(0,1)))
   1372             useful_samples_index = tf.where(useful_samples)[:,0]
   1373             # [max_depth, num_sources, num_paths_per_source, 3]
   1374             hit_points = tf.gather(hit_points, useful_samples_index, axis=2)
   1375             # [max_depth, num_sources, num_paths_per_source]
   1376             candidates_scat = tf.gather(candidates_scat, useful_samples_index,
   1377                                         axis=2)
   1378             # [max_depth, num_sources, num_paths_per_source]
   1379             no_hit = tf.gather(no_hit, useful_samples_index, axis=2)
   1380 
   1381             # Zero the hit masked points
   1382             # [max_depth, num_sources, num_paths, 3]
   1383             hit_points = tf.where(tf.expand_dims(no_hit, axis=-1),
   1384                                 tf.zeros_like(hit_points),
   1385                                 hit_points)
   1386         else:
   1387             # No hit points
   1388             hit_points = tf.fill([0, num_sources, 1, 3],
   1389                                  tf.cast(0., self._rdtype))
   1390             candidates_scat = tf.fill([0, num_sources, 1], False)
   1391             max_depth_scat = 0
   1392 
   1393         if ((not reflection) and (not scattering)):
   1394             max_depth = 0
   1395 
   1396         # Remove duplicates
   1397         candidates_ref, _ = tf.raw_ops.UniqueV2(
   1398             x=candidates_ref,
   1399             axis=[1]
   1400         )
   1401 
   1402         # Add line-of-sight to list of candidates for reflection if
   1403         # required
   1404         if los:
   1405             candidates_ref = tf.concat([tf.fill([max_depth_ref, 1], -1),
   1406                                         candidates_ref],
   1407                                        axis=1)
   1408         else:
   1409             # Ensure there is no LoS by removing all paths corresponding
   1410             # to no hits
   1411             # [num_samples]
   1412             is_nlos = tf.logical_not(tf.reduce_all(candidates_ref == -1,
   1413                                                    axis=0))
   1414             is_nlos_ind = tf.where(is_nlos)[:,0]
   1415             candidates_ref = tf.gather(candidates_ref, is_nlos_ind, axis=1)
   1416 
   1417         # The previous shoot and bounce process does not do next-event
   1418         # estimation, and continues to trace until max_depth reflections occurs
   1419         # or the ray does not intersect any primitive.
   1420         # Therefore, we extend the set of rays with the prefixes of all
   1421         # rays in `results_tf` to ensure we don't miss shorter paths than the
   1422         # ones found.
   1423         candidates_ref_ = [candidates_ref]
   1424         for depth in range(1, max_depth_ref):
   1425             # Extract prefix of length depth
   1426             # [depth, num_samples]
   1427             prefix = candidates_ref[:depth]
   1428             # Pad with -1, i.e., not intersection
   1429             # [max_depth, num_samples]
   1430             prefix = tf.pad(prefix, [[0, max_depth_ref-depth], [0,0]],
   1431                             constant_values=-1)
   1432             # Add to the list of rays
   1433             candidates_ref_.insert(0, prefix)
   1434         # [max_depth, num_samples]
   1435         candidates_ref = tf.concat(candidates_ref_, axis=1)
   1436 
   1437         # Extending the rays with prefixes might have created duplicates.
   1438         # Remove duplicates
   1439         if candidates_ref.shape[0] > 0:
   1440             candidates_ref, _ = tf.raw_ops.UniqueV2(
   1441                 x=candidates_ref,
   1442                 axis=[1]
   1443             )
   1444 
   1445         return candidates_ref, los_primitives, candidates_scat, hit_points
   1446 
   1447     ##################################################################
   1448     # Methods used for computing the specular paths
   1449     ##################################################################
   1450 
   1451     ### The following functions implement the image methods
   1452 
   1453     def _spec_image_method_phase_1(self, candidates, sources):
   1454         r"""
   1455         Implements the first phase of the image method.
   1456 
   1457         Starting from the sources, mirror each point against the
   1458         given candidate primitive. At this stage, we do not carry
   1459         any verification about the visibility of the ray.
   1460         Loop through the max_depth interactions. All candidate paths are
   1461         processed in parallel.
   1462 
   1463         Input
   1464         ------
   1465         candidates: [max_depth, num_samples], tf.int
   1466             Set of candidate paths with depth up to ``max_depth``.
   1467             For paths with depth lower than ``max_depth``, -1 must be used as
   1468             padding value.
   1469             The first path is the LoS one if LoS is requested.
   1470 
   1471         sources : [num_sources, 3], tf.float
   1472             Positions of the sources from which rays (paths) are emitted
   1473 
   1474         Output
   1475         -------
   1476         mirrored_vertices : [max_depth, num_sources, num_samples, 3], tf.float
   1477             Mirrored points coordinates
   1478 
   1479         tri_p0 : [max_depth, num_sources, num_samples, 3], tf.float
   1480             Coordinates of the first vertex of potentially hitted triangles
   1481 
   1482         normals : [max_depth, num_sources, num_samples, 3], tf.float
   1483             Normals to the potentially hitted triangles
   1484         """
   1485 
   1486         # Max depth
   1487         max_depth = candidates.shape[0]
   1488 
   1489         # Number of candidates
   1490         num_samples = tf.shape(candidates)[1]
   1491 
   1492         # Number of sources and number of receivers
   1493         num_sources = len(sources)
   1494 
   1495         # Sturctures are filled by the following loop
   1496         # Indicates if a path is discarded
   1497         # [num_samples]
   1498         valid = tf.fill([num_samples], True)
   1499         # Coordinates of the first vertex of potentially hitted triangles
   1500         # [max_depth, num_sources, num_samples, 3]
   1501         tri_p0 = tf.zeros([max_depth, num_sources, num_samples, 3],
   1502                             dtype=self._rdtype)
   1503         # Coordinates of the mirrored vertices
   1504         # [max_depth, num_sources, num_samples, 3]
   1505         mirrored_vertices = tf.zeros([max_depth, num_sources, num_samples, 3],
   1506                                         dtype=self._rdtype)
   1507         # Normals to the potentially hitted triangles
   1508         # [max_depth, num_sources, num_samples, 3]
   1509         normals = tf.zeros([max_depth, num_sources, num_samples, 3],
   1510                            dtype=self._rdtype)
   1511 
   1512         # Position of the last interaction.
   1513         # It is initialized with the sources position
   1514         # Add an additional dimension for broadcasting with the paths
   1515         # [num_sources, 1, xyz : 1]
   1516         current = tf.expand_dims(sources, axis=1)
   1517         current = tf.tile(current, [1, num_samples, 1])
   1518         # Index of the last hit primitive
   1519         prev_prim_idx = tf.fill([num_samples], -1)
   1520         if max_depth > 0:
   1521             for depth in tf.range(max_depth):
   1522 
   1523                 # Primitive indices with which paths interact at this depth
   1524                 # [num_samples]
   1525                 prim_idx = tf.gather(candidates, depth, axis=0)
   1526 
   1527                 # Flag indicating which paths are still active, i.e., should be
   1528                 # tested.
   1529                 # Paths that are shorter than depth are marked as inactive
   1530                 # [num_samples]
   1531                 active = tf.not_equal(prim_idx, -1)
   1532 
   1533                 # Break the loop if no active paths
   1534                 # Could happen with empty scenes, where we have only LoS
   1535                 if tf.logical_not(tf.reduce_any(active)):
   1536                     break
   1537 
   1538                 # Eliminate paths that go through the same prim twice in a row
   1539                 # [num_samples]
   1540                 valid = tf.logical_and(
   1541                     valid,
   1542                     tf.logical_or(~active, tf.not_equal(prim_idx,prev_prim_idx))
   1543                 )
   1544 
   1545                 # On CPU, indexing with -1 does not work. Hence we replace -1
   1546                 # by 0.
   1547                 # This makes no difference on the resulting paths as such paths
   1548                 # are not flagged as active.
   1549                 # valid_prim_idx = prim_idx
   1550                 valid_prim_idx = tf.where(prim_idx == -1, 0, prim_idx)
   1551 
   1552                 # Mirroring of the current point with respected to the
   1553                 # potentially hitted triangle.
   1554                 # We need the coordinate of the first vertex of the potentially
   1555                 # hitted triangle.
   1556                 # To get this, we build the indexing tensor to gather only the
   1557                 # coordinate of the first index
   1558                 # [[num_samples, 1]]
   1559                 p0_index = tf.expand_dims(valid_prim_idx, axis=1)
   1560                 p0_index = tf.pad(p0_index, [[0,0], [0,1]], mode='CONSTANT',
   1561                                     constant_values=0) # First vertex
   1562                 # [num_samples, xyz : 3]
   1563                 p0 = tf.gather_nd(self._primitives, p0_index)
   1564                 # Expand rank and tile to broadcast with the number of
   1565                 # transmitters
   1566                 # [num_sources, num_samples, xyz : 3]
   1567                 p0 = tf.expand_dims(p0, axis=0)
   1568                 p0 = tf.tile(p0, [num_sources, 1, 1])
   1569                 # Gather normals to potentially intersected triangles
   1570                 # [num_samples, xyz : 3]
   1571                 normal = tf.gather(self._normals, valid_prim_idx)
   1572                 # Expand rank and tile to broadcast with the number of
   1573                 # transmitters
   1574                 # [1, num_samples, xyz : 3]
   1575                 normal = tf.expand_dims(normal, axis=0)
   1576                 normal = tf.tile(normal, [num_sources, 1, 1])
   1577 
   1578                 # Distance between the current intersection point (or sources)
   1579                 # and the plane the triangle is part of.
   1580                 # Note: `dist` is signed to compensate for backfacing normals
   1581                 # whenn needed.
   1582                 # [num_sources, num_samples, 1]
   1583                 dist = dot(current, normal, keepdim=True)\
   1584                             - dot(p0, normal, keepdim=True)
   1585                 # Coordinates of the mirrored point
   1586                 # [num_sources, num_samples, xyz : 3]
   1587                 mirrored = current - 2. * dist * normal
   1588 
   1589                 # Store these results
   1590                 # [max_depth, num_sources, num_samples, 3]
   1591                 mirrored_vertices = tf.tensor_scatter_nd_update(
   1592                                     mirrored_vertices, [[depth]], [mirrored])
   1593                 # [max_depth, num_sources, num_samples, 3]
   1594                 tri_p0 = tf.tensor_scatter_nd_update(tri_p0, [[depth]], [p0])
   1595                 # [max_depth, num_sources, num_samples, 3]
   1596                 normals = tf.tensor_scatter_nd_update(normals,
   1597                                                       [[depth]], [normal])
   1598 
   1599 
   1600                 # Prepare for the next interaction
   1601                 # [num_sources, num_samples, xyz : 3]
   1602                 current = mirrored
   1603                 # [num_samples]
   1604                 prev_prim_idx = prim_idx
   1605 
   1606         return mirrored_vertices, tri_p0, normals
   1607 
   1608     def _spec_image_method_phase_21(self, depth, candidates, valid,
   1609                                     mirrored_vertices, tri_p0, normals, current,
   1610                                     num_targets, num_sources):
   1611         # pylint: disable=line-too-long
   1612         r"""
   1613         Implement the first part of phase 2 of the image method:
   1614 
   1615         For a given ``depth``:
   1616         - Computes the intersection point with the ``depth``th primitive of the
   1617         sequence of candidates for a ray originating from ``current``
   1618         - Checks that the intersection point is within the primitive
   1619         - Ensures the normal points toward the ``current`` point
   1620         - Prepares the ray to test for blockage between ``current`` point and
   1621         thecomputed intersection point
   1622 
   1623         The obstruction test is note performed in this function as it uses
   1624         Mitsuba.
   1625 
   1626         Input
   1627         -----
   1628         depth : int
   1629             Current interaction number
   1630 
   1631         candidates: [max_depth, num_samples], tf.int
   1632             Set of candidate paths with depth up to ``max_depth``.
   1633             For paths with depth lower than ``max_depth``, -1 must be used as
   1634             padding value.
   1635             The first path is the LoS one if LoS is requested.
   1636 
   1637         valid : [num_targets, num_sources, num_samples], tf.bool
   1638             Mask indicating the valid paths
   1639 
   1640         mirrored_vertices : [max_depth, num_sources, num_samples, 3], tf.float
   1641             Mirrored points
   1642 
   1643         tri_p0 : [max_depth, num_sources, num_samples, 3], tf.float
   1644             Coordinates of the first vertex of potentially hitted triangles
   1645 
   1646         normals : [max_depth, num_sources, num_samples, 3], tf.float
   1647             Normals to the potentially hitted triangles
   1648 
   1649         current : [num_targets, 1, 1, xyz : 3], tf.float
   1650             Positions of the last interactions
   1651 
   1652         num_targets : int
   1653             Number of targets
   1654 
   1655         num_sources : int
   1656             Number of sources
   1657 
   1658         Output
   1659         ------
   1660         valid : [num_targets, num_sources, num_samples], tf.bool
   1661             Mask indicating the valid paths
   1662 
   1663         current : [num_targets, num_sources, num_samples, 3], tf.float
   1664             Positions of the last interactions
   1665 
   1666         p : [num_targets, num_sources, num_samples, 3], tf.float
   1667             Intersection point on the ``depth`` primitive
   1668 
   1669         n : [num_targets, num_sources, num_samples, 3], tf.float
   1670             Normals to the primitive at the ``depth`` intersection point
   1671 
   1672         maxt : [num_targets, num_sources, num_samples], tf.float
   1673             Distance from current to intersection point
   1674 
   1675         d : [num_targets, num_sources, num_samples, 3], tf.float
   1676             Ray direction to test for blockage between ``curent`` and the
   1677             intersection point
   1678 
   1679         active : [num_samples], tf.bool
   1680             Mask indicating paths that are not active, i.e., didn't start yet
   1681         """
   1682 
   1683         # Number of candidates at this stage
   1684         num_samples = tf.shape(candidates)[1]
   1685 
   1686         # Primitive indices with which paths interact at this depth
   1687         # [num_samples]
   1688         prim_idx = tf.gather(candidates, depth, axis=0)
   1689 
   1690         # Next mirrored point
   1691         # [num_sources, num_samples, 3]
   1692         next_pos = tf.gather(mirrored_vertices, depth, axis=0)
   1693         # Expand rank for broadcasting
   1694         # [1, num_sources, num_samples, 3]
   1695 
   1696         # Since paths can have different depths, we have to mask out paths
   1697         # that have not started yet.
   1698         # [num_samples]
   1699         active = tf.not_equal(prim_idx, -1)
   1700 
   1701         # Expand rank to broadcast with receivers and transmitters
   1702         # [1, 1, num_samples]
   1703         active = expand_to_rank(active, 3, axis=0)
   1704 
   1705         # Invalid paths are marked as inactive
   1706         # [num_targets, num_sources, num_samples]
   1707         active = tf.logical_and(active, valid)
   1708 
   1709         # On CPU, indexing with -1 does not work. Hence we replace -1 by 0.
   1710         # This makes no difference on the resulting paths as such paths
   1711         # are not flagged as active.
   1712         # valid_prim_idx = prim_idx
   1713         valid_prim_idx = tf.where(prim_idx == -1, 0, prim_idx)
   1714 
   1715         # Trace a direct line from the current position to the next path
   1716         # vertex.
   1717 
   1718         # Ray direction
   1719         # [num_targets, num_sources, num_samples, 3]
   1720         d,_ = normalize(next_pos - current)
   1721 
   1722         # Find where it intersects the primitive that we mirrored against.
   1723         # If that falls out of the primitive, this whole path is invalid.
   1724 
   1725         # Vertices forming the triangle.
   1726         # [num_sources, num_samples, xyz : 3]
   1727         p0 = tf.gather(tri_p0, depth, axis=0)
   1728         # Expand rank to broadcast with the target dimension
   1729         # [1, num_sources, num_samples, xyz : 3]
   1730         p0 = tf.expand_dims(p0, axis=0)
   1731         # Build the indexing tensor to gather only the coordinate of the
   1732         # second index
   1733         # [[num_samples, 1]]
   1734         p1_index = tf.expand_dims(valid_prim_idx, axis=1)
   1735         p1_index = tf.pad(p1_index, [[0,0], [0,1]], mode='CONSTANT',
   1736                             constant_values=1) # Second vertex
   1737         # [num_samples, xyz : 3]
   1738         p1 = tf.gather_nd(self._primitives, p1_index)
   1739         # Expand rank to broadcast with the target and sources
   1740         # dimensions
   1741         # [1, 1, num_samples, xyz : 3]
   1742         p1 = expand_to_rank(p1, 4, axis=0)
   1743         # Build the indexing tensor to gather only the coordinate of the
   1744         # third index
   1745         # [[num_samples, 1]]
   1746         p2_index = tf.expand_dims(valid_prim_idx, axis=1)
   1747         p2_index = tf.pad(p2_index, [[0,0], [0,1]], mode='CONSTANT',
   1748                             constant_values=2) # Third vertex
   1749         # [num_samples, xyz : 3]
   1750         p2 = tf.gather_nd(self._primitives, p2_index)
   1751         # Expand rank to broadcast with the target and sources
   1752         # dimensions
   1753         # [1, 1, num_samples, xyz : 3]
   1754         p2 = expand_to_rank(p2, 4, axis=0)
   1755         # Intersection test.
   1756         # We use the Moeller Trumbore algorithm
   1757         # t : [num_targets, num_sources, num_samples]
   1758         # hit : [num_targets, num_sources, num_samples]
   1759         t, hit = moller_trumbore(current, d, p0, p1, p2, SolverBase.EPSILON)
   1760         # [num_targets, num_sources, num_samples]
   1761         valid = tf.logical_and(valid, tf.logical_or(~active, hit))
   1762 
   1763         # Force normal to point towards our current position
   1764         # [num_sources, num_samples, 3]
   1765         n = tf.gather(normals, depth, axis=0)
   1766         # Add dimension for broadcasting with receivers
   1767         # [1, num_sources, num_samples, 3]
   1768         n = tf.expand_dims(n, axis=0)
   1769         # Force to point towards current position
   1770         # [num_targets, num_sources, num_samples, 3]
   1771         s = tf.sign(dot(n, current-p0, keepdim=True))
   1772         n = n * s
   1773         # Intersection point
   1774         # [num_targets, num_sources, num_samples, 3]
   1775         t = tf.expand_dims(t, axis=3)
   1776         p = current + t*d
   1777 
   1778         # Prepare obstruction test.
   1779         # There should be no obstruction between the actual
   1780         # interaction point and the current point.
   1781         # We use Mitsuba to test for obstruction efficiently.
   1782         # We only compute here the origin and direction of the ray
   1783 
   1784         # Ensure current is already broadcasted
   1785         # [num_targets, num_sources, num_samples, 3]
   1786         current = tf.broadcast_to(current, [num_targets, num_sources,
   1787                                             num_samples, 3])
   1788         # Distance from current to intersection point
   1789         # [num_targets, num_sources, num_samples]
   1790         maxt = tf.norm(current - p, axis=-1)
   1791 
   1792         output = (
   1793             valid,
   1794             current,
   1795             p,
   1796             n,
   1797             maxt,
   1798             d,
   1799             active
   1800         )
   1801         return output
   1802 
   1803     def _spec_image_method_phase_22(self, depth, valid, mirrored_vertices,
   1804                     current, blk, num_targets, num_sources, maxt, p, active):
   1805         r"""
   1806         Implement the second part of phase 2 of the image method:
   1807 
   1808         - Discards paths that are blocked
   1809         - Discards paths for which ``current`` point and the ``next_pos`` are
   1810             not on the same side, as this would mean that the path is going
   1811             through the surface
   1812 
   1813         Input
   1814         -----
   1815         depth : int
   1816             Current interaction number
   1817 
   1818         valid : [num_targets, num_sources, num_samples], tf.bool
   1819             Mask indicating the valid paths
   1820 
   1821         mirrored_vertices : [max_depth, num_sources, num_samples, 3], tf.float
   1822             Mirrored points
   1823 
   1824         current : [num_targets, num_sources, num_samples, 3], tf.float
   1825             Positions of the last interactions
   1826 
   1827         blk : [num_targets*num_sources*num_samples], tf.bool
   1828             Mask indicating which blocked paths
   1829 
   1830         num_targets : int
   1831             Number of targets
   1832 
   1833         num_sources : int
   1834             Number of sources
   1835 
   1836         maxt : [num_targets, num_sources, num_samples], tf.float
   1837             Distance from current to intersection point
   1838 
   1839         p : [num_targets, num_sources, num_samples, 3], tf.float
   1840             Intersection point on the ``depth`` primitive
   1841 
   1842         active : [num_targets, num_sources, num_samples], tf.bool
   1843             Flag indicating which paths are active
   1844 
   1845         Output
   1846         ------
   1847         valid : [num_targets, num_sources, num_samples], tf.bool
   1848             Mask indicating the valid paths
   1849 
   1850         current : [num_targets, num_sources, num_samples, 3], tf.float
   1851             Positions of the last interactions
   1852         """
   1853 
   1854         # Number of candidates at this stage
   1855         num_samples = tf.shape(valid)[2]
   1856 
   1857         # Next mirrored point
   1858         # [num_sources, num_samples, 3]
   1859         next_pos = tf.gather(mirrored_vertices, depth, axis=0)
   1860         # Expand rank for broadcasting
   1861         # [1, num_sources, num_samples, 3]
   1862         next_pos = tf.expand_dims(next_pos, axis=0)
   1863 
   1864         # Discard paths if blocked
   1865         # [num_targets, num_sources, num_samples]
   1866         blk = tf.reshape(blk, [num_targets, num_sources, num_samples])
   1867         valid = tf.logical_and(valid, tf.logical_or(~active, ~blk))
   1868 
   1869         # Discard paths for which the shooted ray has zero-length, i.e.,
   1870         # when two consecutive intersection points have the same location,
   1871         # or when the source and target have the same locations (RADAR).
   1872         # [num_targets, num_sources, num_samples]
   1873         blk = tf.less(maxt, SolverBase.EPSILON)
   1874         # [num_targets, num_sources, num_samples]
   1875         valid = tf.logical_and(valid, tf.logical_or(~active, ~blk))
   1876 
   1877         # We must also ensure that the current point and the next_pos are
   1878         # not on the same side, as this would mean that the path is going
   1879         # through the surface
   1880 
   1881         # Vector from the intersection point to the current point
   1882         # [num_targets, num_sources, num_samples, 3]
   1883         v1 = current - p
   1884         # Vector from the intersection point to the next point
   1885         # [num_targets, num_sources, num_samples, 3]
   1886         v2 = next_pos - p
   1887         # Compute the scalar product. It must be negative, as we are using
   1888         # the image (next_pos)
   1889         # [num_targets, num_sources, num_samples]
   1890         blk = dot(v1, v2)
   1891         blk = tf.greater_equal(blk, tf.zeros_like(blk))
   1892         valid = tf.logical_and(valid, tf.logical_or(~active, ~blk))
   1893 
   1894         # Update active state
   1895         # [num_targets, num_sources, num_samples]
   1896         active = tf.logical_and(active, valid)
   1897         # Prepare for next path segment
   1898         # [num_targets, num_sources, num_samples, 3]
   1899         current = tf.where(tf.expand_dims(active, axis=-1), p, current)
   1900 
   1901         output = (
   1902             valid,
   1903             current
   1904         )
   1905         return output
   1906 
   1907     def _spec_image_method_phase_23(self, current, sources, num_targets):
   1908         r"""
   1909         Implements the third step of phase 2 of the image method.
   1910 
   1911         Prepares the rays for testing blockage between the last interaction
   1912         point and the sources.
   1913 
   1914         Input
   1915         ------
   1916         current : [num_targets, num_sources, num_samples, 3], tf.float
   1917             Positions of the last interactions
   1918 
   1919         sources : [num_sources, 3], tf.float
   1920             Sources from which rays (paths) are emitted
   1921 
   1922         num_targets : int
   1923             Number of targets
   1924 
   1925         Output
   1926         ------
   1927         current : [num_targets, num_sources, num_samples, 3], tf.float
   1928             Positions of the last interactions
   1929 
   1930         d : [num_targets, num_sources, num_samples, 3], tf.float
   1931             Ray direction between the last interaction point and the sources
   1932 
   1933         maxt : [num_targets, num_sources, num_samples], tf.float
   1934             Distances between the last interaction point and the sources
   1935         """
   1936 
   1937         # Number of candidates at this stage
   1938         num_samples = tf.shape(current)[2]
   1939 
   1940         num_sources = tf.shape(sources)[0]
   1941 
   1942         # Check visibility to the transmitters
   1943         # [1, num_sources, 1, 3]
   1944         sources_ = tf.expand_dims(tf.expand_dims(sources, axis=0),
   1945                                         axis=2)
   1946         # Direction vector and distance to the transmitters
   1947         # d : [num_targets, num_sources, num_samples, 3]
   1948         # maxt : [num_targets, num_sources, num_samples]
   1949         d,maxt = normalize(sources_ - current)
   1950         # Ensure current is already broadcasted
   1951         # [num_targets, num_sources, num_samples, 3]
   1952         current = tf.broadcast_to(current, [num_targets, num_sources,
   1953                                             num_samples, 3])
   1954         d = tf.broadcast_to(d, [num_targets, num_sources, num_samples, 3])
   1955         maxt = tf.broadcast_to(maxt, [num_targets, num_sources, num_samples])
   1956 
   1957         return current, d, maxt
   1958 
   1959     def _spec_image_method_phase_3(self, candidates, valid, num_targets,
   1960                                 num_sources, path_vertices, path_normals, blk):
   1961         # pylint: disable=line-too-long
   1962         r"""
   1963         Implements the third phase of the image method.
   1964 
   1965         Post-process the valid paths, from transmitters to receivers, to put
   1966         them in the expected output format.
   1967 
   1968         Input
   1969         -----
   1970         candidates: [max_depth, num_samples], tf.int
   1971             Set of candidate paths with depth up to ``max_depth``.
   1972             For paths with depth lower than ``max_depth``, -1 must be used as
   1973             padding value.
   1974             The first path is the LoS one if LoS is requested.
   1975 
   1976         valid : [num_targets, num_sources, num_samples], tf.bool
   1977             Mask indicating the valid paths
   1978 
   1979         num_targets : int
   1980             Number of targets
   1981 
   1982         num_sources : int
   1983             Number of sources
   1984 
   1985         path_vertices : [max_depth, num_targets, num_sources, num_samples, xyz : 3]
   1986             Positions of the intersection points
   1987 
   1988         path_normals : [max_depth, num_targets, num_sources, num_samples, 3]
   1989             Normals to the surface at the intersection points
   1990 
   1991         blk : [num_targets*num_sources*num_samples], tf.bool
   1992             Mask indicating which blocked paths
   1993 
   1994         Output
   1995         ------
   1996         mask : [num_targets, num_sources, max_num_paths], tf.bool
   1997              Mask indicating if a path is valid
   1998 
   1999         valid_vertices : [max_depth, num_targets, num_sources, max_num_paths, 3], tf.float
   2000             Positions of intersection points.
   2001 
   2002         valid_objects : [max_depth, num_targets, num_sources, max_num_paths], tf.int
   2003             Indices of the intersected scene objects or wedges.
   2004             Paths with depth lower than ``max_depth`` are padded with `-1`.
   2005 
   2006         valid_normals : [max_depth, num_targets, num_sources, max_num_paths, 3], tf.float
   2007             Normals to the primitives at the intersection points.
   2008         """
   2009 
   2010         # Discard blocked paths
   2011         # [num_targets, num_sources, num_samples]
   2012         valid = tf.logical_and(valid, ~blk)
   2013 
   2014         # Max depth
   2015         max_depth = candidates.shape[0]
   2016 
   2017         # If at least one link has the LoS paths flagged as valid,
   2018         # then we keep the LoS paths for all links.
   2019         # This makes the tracking of paths types easier.
   2020         # The LoS paths will be masked for links for which it is obstructed.
   2021         # Note that there is only one entry that correspond to LoS paths
   2022         # [1, num_samples]
   2023         is_los = tf.reduce_all(tf.equal(candidates, -1), axis=0, keepdims=True)
   2024         # Only keep the LoS path if it is valid (i.e., not obstructed) for at
   2025         # least one link
   2026         # [1, 1, num_samples]
   2027         is_los = tf.expand_dims(is_los, 0)
   2028         # [1, 1, num_samples]
   2029         is_los = tf.reduce_any(tf.logical_and(is_los, valid), axis=(0,1),
   2030                                keepdims=True)
   2031 
   2032         # Build indices for keeping only valid path
   2033         # A path is kept if its valid or the LoS and there is at least one link
   2034         # for which LoS is not obstructed
   2035         # [num_targets, num_sources, num_samples]
   2036         keep = tf.logical_or(valid, is_los)
   2037         # [num_targets, num_sources]
   2038         num_paths = tf.reduce_sum(tf.cast(keep, tf.int32), axis=-1)
   2039         # Maximum number of paths
   2040         # ()
   2041         max_num_paths = tf.reduce_max(num_paths)
   2042         # [num_valid, 3]
   2043         gather_indices = tf.where(keep)
   2044         # [num_targets, num_sources, num_samples]
   2045         path_indices = tf.cumsum(tf.cast(keep, tf.int32), axis=-1)
   2046         # [num_valid]
   2047         path_indices = tf.gather_nd(path_indices, gather_indices) - 1
   2048         # [3, num_valid]
   2049         scatter_indices = tf.transpose(gather_indices, [1,0])
   2050         if not tf.size(scatter_indices) == 0:
   2051             # [3, num_valid]
   2052             scatter_indices = tf.tensor_scatter_nd_update(scatter_indices,
   2053                                 [[2]], [path_indices])
   2054         # [num_valid, 3]
   2055         scatter_indices = tf.transpose(scatter_indices, [1,0])
   2056 
   2057         # Mask of valid paths
   2058         # [num_targets, num_sources, max_num_paths]
   2059         mask = tf.fill([num_targets, num_sources, max_num_paths], False)
   2060         # [num_keep_paths]
   2061         mask_ = tf.gather_nd(valid, gather_indices)
   2062         # [num_targets, num_sources, max_num_paths]
   2063         mask = tf.tensor_scatter_nd_update(mask, scatter_indices, mask_)
   2064 
   2065         # Locations of the interactions
   2066         # [max_depth, num_targets, num_sources, max_num_paths, 3]
   2067         valid_vertices = tf.zeros([max_depth, num_targets, num_sources,
   2068                                     max_num_paths, 3], dtype=self._rdtype)
   2069         # Normals at the intersection points
   2070         # [max_depth, num_targets, num_sources, max_num_paths, 3]
   2071         valid_normals = tf.zeros([max_depth, num_targets, num_sources,
   2072                                     max_num_paths, 3], dtype=self._rdtype)
   2073         # [max_depth, num_targets, num_sources, max_num_paths]
   2074         valid_primitives = tf.fill([max_depth, num_targets, num_sources,
   2075                                         max_num_paths], -1)
   2076 
   2077         if max_depth > 0:
   2078 
   2079             for depth in tf.range(max_depth, dtype=tf.int64):
   2080 
   2081                 # Indices for storing the valid vertices/normals/primitives for
   2082                 # this depth
   2083                 scatter_indices_ = tf.pad(scatter_indices, [[0,0], [1,0]],
   2084                                 mode='CONSTANT', constant_values=depth)
   2085 
   2086                 # Loaction of the interactions
   2087                 # Extract only the valid paths
   2088                 # [num_targets, num_sources, num_samples, 3]
   2089                 vertices_ = tf.gather(path_vertices, depth, axis=0)
   2090                 # [total_num_valid_paths, 3]
   2091                 vertices_ = tf.gather_nd(vertices_, gather_indices)
   2092                 # Store the valid intersection points
   2093                 # [max_depth, num_targets, num_sources, max_num_paths, 3]
   2094                 valid_vertices = tf.tensor_scatter_nd_update(valid_vertices,
   2095                                                 scatter_indices_, vertices_)
   2096 
   2097                 # Normals at the interactions
   2098                 # Extract only the valid paths
   2099                 # [num_targets, num_sources, num_samples, 3]
   2100                 normals_ = tf.gather(path_normals, depth, axis=0)
   2101                 # [total_num_valid_paths, 3]
   2102                 normals_ = tf.gather_nd(normals_, gather_indices)
   2103                 # Store the valid normals
   2104                 # [max_depth, num_targets, num_sources, max_num_paths, 3]
   2105                 valid_normals = tf.tensor_scatter_nd_update(valid_normals,
   2106                                         scatter_indices_, normals_)
   2107 
   2108                 # Intersected primitives
   2109                 # Extract only the valid paths
   2110                 # [num_samples]
   2111                 primitives_ = tf.gather(candidates, depth, axis=0)
   2112                 # [total_num_valid_paths]
   2113                 primitives_ = tf.gather(primitives_, gather_indices[:,2])
   2114                 # Store the valid primitives]
   2115                 # [max_depth, num_targets, num_sources, max_num_paths]
   2116                 valid_primitives = tf.tensor_scatter_nd_update(valid_primitives,
   2117                                         scatter_indices_, primitives_)
   2118 
   2119         # Add a dummy entry to primitives_2_objects with value -1 for invalid
   2120         # reflection.
   2121         # Invalid reflection, i.e., corresponding to paths with a depth lower
   2122         # than max_depth, will be assigned -1 as index of the intersected
   2123         # shape.
   2124         # [num_samples + 1]
   2125         primitives_2_objects = tf.pad(self._primitives_2_objects, [[0,1]],
   2126                                         constant_values=-1)
   2127         # Replace all -1 by num_samples
   2128         num_samples = tf.shape(self._primitives_2_objects)[0]
   2129         # [max_depth, num_targets, num_sources, max_num_paths]
   2130         valid_primitives = tf.where(tf.equal(valid_primitives,-1),
   2131                                     num_samples,
   2132                                     valid_primitives)
   2133         # [max_depth, num_targets, num_sources, max_num_paths]
   2134         valid_objects = tf.gather(primitives_2_objects, valid_primitives)
   2135 
   2136         # Actual maximum depth
   2137         if max_depth > 0:
   2138             # Limit the depth to the actual max_depth
   2139             # [max_depth]
   2140             useless_depth = tf.reduce_all(tf.equal(valid_objects, -1),
   2141                                           axis=(1,2,3))
   2142 
   2143             max_depth = tf.where(tf.reduce_any(useless_depth),
   2144                                 tf.argmax(tf.cast(useless_depth, tf.int32),
   2145                                           output_type=tf.int32),
   2146                                 max_depth)
   2147             max_depth = tf.maximum(max_depth, 1)
   2148             # [max_depth, num_targets, num_sources, max_num_paths, 3]
   2149             valid_vertices = valid_vertices[:max_depth]
   2150             # [max_depth, num_targets, num_sources, max_num_paths, 3]
   2151             valid_normals = valid_normals[:max_depth]
   2152             # [max_depth, num_targets, num_sources, max_num_paths]
   2153             valid_objects = valid_objects[:max_depth]
   2154 
   2155         return mask, valid_vertices, valid_objects, valid_normals
   2156 
   2157     def _spec_image_method(self, candidates, paths, spec_paths_tmp,
   2158                            ris_objects):
   2159         # pylint: disable=line-too-long
   2160         r"""
   2161         Evaluates a list of candidate paths ``candidates`` and keep only the
   2162         valid ones, i.e., the non-obstricted ones with valid reflections only,
   2163         using the image method.
   2164 
   2165         Input
   2166         -----
   2167         candidates: [max_depth, num_samples], tf.int
   2168             Set of candidate paths with depth up to ``max_depth``.
   2169             For paths with depth lower than ``max_depth``, -1 must be used as
   2170             padding value.
   2171             The first path is the LoS one if LoS is requested.
   2172 
   2173         paths : :class:`~sionna.rt.Paths`
   2174             Paths to update
   2175 
   2176         ris_objects : list(mi.Rectangle)
   2177             List of Mitsuba rectangles implementing the RIS
   2178         """
   2179 
   2180         sources = paths.sources
   2181         targets = paths.targets
   2182 
   2183         # Max depth
   2184         max_depth = candidates.shape[0]
   2185 
   2186         # Number of sources and number of receivers
   2187         num_sources = len(sources)
   2188         num_targets = len(targets)
   2189 
   2190         # --- Phase 1
   2191         # Starting from the sources, mirror each point against the
   2192         # given candidate primitive. At this stage, we do not carry
   2193         # any verification about the visibility of the ray.
   2194         # Loop through the max_depth interactions. All candidate paths are
   2195         # processed in parallel.
   2196         #
   2197         # mirrored_vertices : [max_depth, num_sources, num_samples, 3], tf.float
   2198         #     Mirrored points coordinates
   2199         #
   2200         # tri_p0 : [max_depth, num_sources, num_samples, 3], tf.float
   2201         #     Coordinates of the first vertex of potentially hitted triangles
   2202         #
   2203         # normals : [max_depth, num_sources, num_samples, 3], tf.float
   2204         #     Normals to the potentially hitted triangles
   2205         mirrored_vertices, tri_p0, normals =\
   2206             self._spec_image_method_phase_1(candidates, sources)
   2207 
   2208         # --- Phase 2
   2209 
   2210         # Number of candidates at this stage
   2211         num_samples = candidates.shape[1]
   2212 
   2213         # Starting from the receivers, go over the vertices in reverse
   2214         # and check that connections are possible.
   2215 
   2216         # Mask indicating which paths are valid
   2217         # [num_targets, num_sources, num_samples]
   2218         valid = tf.fill([num_targets, num_sources, num_samples], True)
   2219         # Positions of the last interactions.
   2220         # Initialized with the positions of the receivers.
   2221         # Add two additional dimensions for broadcasting with transmitters and
   2222         # paths.
   2223         # [num_targets, 1, 1, xyz : 3]
   2224         current = expand_to_rank(targets, 4, axis=1)
   2225         # Positions of the interactions.
   2226         # [max_depth, num_targets, num_sources, num_samples, xyz : 3]
   2227         # path_vertices = tf.zeros([max_depth, num_targets, num_sources,
   2228         #                             num_samples, 3], dtype=self._rdtype)
   2229         path_vertices = []
   2230         # Normals at the interactions.
   2231         # [max_depth, num_targets, num_sources, num_samples, xyz : 3]
   2232         path_normals = []
   2233         for depth in tf.range(max_depth-1, -1, -1):
   2234 
   2235             # The following call:
   2236             # - Computes the intersection point with the ``depth``th primitive
   2237             #   of the sequence of candidates for a ray originating from
   2238             #   ``current``
   2239             # - Checks that the intersection point is within the primitive
   2240             # - Ensures the normal points toward the ``current`` point
   2241             # - Prepares the ray to test for blockage between ``depth-1``th
   2242             # point and the current point
   2243             output = self._spec_image_method_phase_21(depth, candidates, valid,
   2244                 mirrored_vertices, tri_p0, normals, current, num_targets,
   2245                 num_sources)
   2246             # [num_targets, num_sources, num_samples]
   2247             #   Mask indicating the valid paths
   2248             valid = output[0]
   2249             # [num_targets, 1, 1, xyz : 3]
   2250             #   Positions of the last interactions
   2251             current = output[1]
   2252             # : [num_targets, num_sources, num_samples, 3], tf.float
   2253             #     Intersection point on the ``depth`` primitive
   2254             path_vertices_ = output[2]
   2255             # : [num_targets, num_sources, num_samples, 3], tf.float
   2256             #    Normals to the primitive at the ``depth`` intersection point
   2257             path_normals_ = output[3]
   2258             # maxt : [num_targets, num_sources, num_samples], tf.float
   2259             #     Distance from current to intersection point
   2260             maxt = output[4]
   2261             # d : [num_targets, num_sources, num_samples, 3], tf.float
   2262             #     Ray direction to test for blockage between ``curent`` and the
   2263             #     intersection point
   2264             d = output[5]
   2265             # [num_targets, num_sources, num_samples]
   2266             #   Mask indicating paths that are not active, i.e., didn't start
   2267             #   yet
   2268             active = output[6]
   2269 
   2270             # Test for obstruction using Mitsuba
   2271             # As Mitsuba only hanldes a single batch dimension, we flatten the
   2272             # batch dims [num_targets, num_sources, num_samples]
   2273             # [num_targets*num_sources*num_samples]
   2274             blk = self._test_obstruction(tf.reshape(current, [-1, 3]),
   2275                                          tf.reshape(d, [-1, 3]),
   2276                                          tf.reshape(maxt, [-1]),
   2277                                          ris_objects)
   2278 
   2279             # The following call:
   2280             # - Discards paths that are blocked
   2281             # - Discards paths for which ``current`` point and the ``next_pos``
   2282             #   are not on the same side, as this would mean that the path is
   2283             #   going through the surface
   2284             output = self._spec_image_method_phase_22(depth, valid,
   2285                 mirrored_vertices, current, blk, num_targets, num_sources, maxt,
   2286                 path_vertices_, active)
   2287             # [num_targets, num_sources, num_samples]
   2288             #   Mask indicating the valid paths
   2289             valid = output[0]
   2290             # [num_targets, num_sources, num_samples, xyz : 3]
   2291             #   Positions of the last interactions
   2292             current = output[1]
   2293 
   2294             path_vertices.append(path_vertices_)
   2295             path_normals.append(path_normals_)
   2296 
   2297         path_vertices.reverse()
   2298         path_normals.reverse()
   2299         path_vertices = tf.stack(path_vertices, axis=0)
   2300         path_normals = tf.stack(path_normals, axis=0)
   2301 
   2302         # Prepares the rays for testing blockage between the last
   2303         # interaction point and the sources.
   2304         #
   2305         # current : [num_targets, num_sources, num_samples, 3], tf.float
   2306         #     Positions of the last interactions
   2307         #
   2308         # d : [num_targets, num_sources, num_samples, 3], tf.float
   2309         #     Ray direction between the last interaction point and the sources
   2310         #
   2311         # maxt : [num_targets, num_sources, num_samples], tf.float
   2312         #     Distances between the last interaction point and the sources
   2313         current, d, maxt = self._spec_image_method_phase_23(current, sources,
   2314                                                       num_targets)
   2315 
   2316         # Test for obstruction using Mitsuba
   2317         # [num_targets*num_sources*num_samples]
   2318         val = self._test_obstruction(tf.reshape(current, [-1, 3]),
   2319                                      tf.reshape(d, [-1, 3]),
   2320                                      tf.reshape(maxt, [-1]),
   2321                                      ris_objects)
   2322         # [num_targets, num_sources, num_samples, 3]
   2323         blk = tf.reshape(val, tf.shape(maxt))
   2324         # Discard paths for which the shooted ray has zero-length, i.e., when
   2325         # two consecutive intersection points have the same location, or when
   2326         # the source and target have the same locations (RADAR).
   2327         # [num_targets, num_sources, num_samples]
   2328         blk = tf.logical_or(blk, tf.less(maxt, SolverBase.EPSILON))
   2329 
   2330         # --- Phase 3
   2331         # Post-process the valid paths, from transmitters to receivers, to put
   2332         # them in the expected output format.
   2333         #
   2334         # mask : [num_targets, num_sources, max_num_paths], tf.bool
   2335         #      Mask indicating if a path is valid
   2336         #
   2337         # valid_vertices : [max_depth, num_targets, num_sources,
   2338         #                   max_num_paths, 3], tf.float
   2339         #     Positions of intersection points.
   2340         #
   2341         # valid_objects : [max_depth, num_targets, num_sources,
   2342         #                   max_num_paths], tf.int
   2343         #     Indices of the intersected scene objects or wedges.
   2344         #     Paths with depth lower than ``max_depth`` are padded with `-1`.
   2345         #
   2346         # valid_normals : [max_depth, num_targets, num_sources,
   2347         #                   max_num_paths, 3], tf.float
   2348         #     Normals to the primitives at the intersection points.
   2349         mask, valid_vertices, valid_objects, valid_normals =\
   2350             self._spec_image_method_phase_3(candidates, valid, num_targets,
   2351                                 num_sources, path_vertices, path_normals, blk)
   2352 
   2353         # Update the object storing the paths
   2354         paths.mask = mask
   2355         paths.vertices = valid_vertices
   2356         paths.objects = valid_objects
   2357 
   2358         spec_paths_tmp.normals = valid_normals
   2359 
   2360     ### Transition matrices
   2361 
   2362     def _spec_transition_matrices(self, relative_permittivity,
   2363                                   scattering_coefficient,
   2364                                   paths, paths_tmp, scattering):
   2365         # pylint: disable=line-too-long
   2366         """
   2367         Compute the transfer matrices, delays, angles of departures, and angles
   2368         of arrivals, of paths from a set of valid reflection paths and the
   2369         EM properties of the materials.
   2370 
   2371         Input
   2372         ------
   2373         relative_permittivity : [num_shape], tf.complex
   2374             Tensor containing the relative permittivity of all shapes
   2375 
   2376         scattering_coefficient : [num_shape], tf.float
   2377             Tensor containing the scattering coefficients of all shapes
   2378 
   2379         paths : :class:`~sionna.rt.Paths`
   2380             Paths to update
   2381 
   2382         paths_tmp : :class:`~sionna.rt.PathsTmpData`
   2383             Addtional quantities required for paths computation
   2384 
   2385         scattering : bool
   2386             Set to `True` if computing the scattered paths.
   2387 
   2388         Output
   2389         -------
   2390         mat_t : [num_targets, num_sources, max_num_paths, 2, 2], tf.complex
   2391                 Specular transition matrix for every path.
   2392         """
   2393 
   2394         vertices = paths.vertices
   2395         targets = paths.targets
   2396         sources = paths.sources
   2397         objects = paths.objects
   2398         theta_t = paths.theta_t
   2399         phi_t = paths.phi_t
   2400         theta_r = paths.theta_r
   2401         phi_r = paths.phi_r
   2402 
   2403         normals = paths_tmp.normals
   2404         k_i = paths_tmp.k_i
   2405         k_r = paths_tmp.k_r
   2406         if scattering:
   2407             # For scattering, only the distance up to the last intersection
   2408             # point is considered for path loss.
   2409             # [num_targets, num_sources, max_num_paths]
   2410             total_distance = paths_tmp.scat_src_2_last_int_dist
   2411         else:
   2412             # [num_targets, num_sources, max_num_paths]
   2413             total_distance = paths_tmp.total_distance
   2414 
   2415         # Maximum depth
   2416         max_depth = tf.shape(vertices)[0]
   2417         # Number of targets
   2418         num_targets = tf.shape(targets)[0]
   2419         # Number of sources
   2420         num_sources = tf.shape(sources)[0]
   2421         # Maximum number of paths
   2422         max_num_paths = tf.shape(objects)[3]
   2423 
   2424         # Flag that indicates if a ray is valid
   2425         # [max_depth, num_targets, num_sources, max_num_paths]
   2426         valid_ray = tf.not_equal(objects, -1)
   2427         # Pad to enable detection of the last valid reflection for scattering
   2428         # [max_depth+1, num_targets, num_sources, max_num_paths]
   2429         valid_ray = tf.pad(valid_ray, [[0,1], [0,0], [0,0], [0,0]],
   2430                            constant_values=False)
   2431 
   2432         # Relative perimittivities and scattering coefficients.
   2433         # If a callable is defined to compute the radio material properties,
   2434         # it is invoked. Otherwise, the radio materials of objects are used.
   2435         rm_callable = self._scene.radio_material_callable
   2436         if rm_callable is None:
   2437             # On CPU, indexing with -1 does not work. Hence we replace -1 by 0.
   2438             # This makes no difference on the resulting paths as such paths
   2439             # are not flagged as active.
   2440             # [max_depth, num_targets, num_sources, max_num_paths]
   2441             valid_object_idx = tf.where(objects == -1, 0, objects)
   2442             if tf.shape(relative_permittivity)[0] == 0:
   2443                 # [max_depth, num_targets, num_sources, max_num_paths]
   2444                 etas = tf.zeros_like(valid_object_idx, dtype=self._dtype)
   2445                 scattering_coefficient = tf.zeros_like(valid_object_idx,
   2446                                                        dtype=self._rdtype)
   2447 
   2448             else:
   2449                 # [max_depth, num_targets, num_sources, max_num_paths]
   2450                 etas = tf.gather(relative_permittivity, valid_object_idx)
   2451                 scattering_coefficient = tf.gather(scattering_coefficient,
   2452                                                    valid_object_idx)
   2453         else:
   2454             # [max_depth, num_targets, num_sources, max_num_paths]
   2455             etas, scattering_coefficient, _  = rm_callable(objects, vertices)
   2456 
   2457         # Compute cos(theta) at each reflection point
   2458         # [max_depth, num_targets, num_sources, max_num_paths]
   2459         cos_theta = -dot(k_i[:max_depth], normals, clip=True)
   2460 
   2461         # Compute e_i_s, e_i_p, e_r_s, e_r_p at each reflection point
   2462         # all : [max_depth, num_targets, num_sources, max_num_paths,3]
   2463         # pylint: disable=unbalanced-tuple-unpacking
   2464         e_i_s, e_i_p, e_r_s, e_r_p = compute_field_unit_vectors(k_i[:max_depth],
   2465                                             k_r, normals, SolverBase.EPSILON)
   2466 
   2467         # Compute r_s, r_p at each reflection point
   2468         # [max_depth, num_targets, num_sources, max_num_paths]
   2469         r_s, r_p = reflection_coefficient(etas, cos_theta)
   2470 
   2471         # Multiply the reflection coefficients with the
   2472         # reflection reduction factor
   2473         # [max_depth, num_targets, num_sources, max_num_paths]
   2474         reduction_factor = tf.sqrt(1 - scattering_coefficient**2)
   2475         reduction_factor = tf.complex(reduction_factor,
   2476                                       tf.zeros_like(reduction_factor))
   2477 
   2478         # Compute the field transfer matrix.
   2479         # It is initialized with the identity matrix of size 2 (S and P
   2480         # polarization components)
   2481         # [num_targets, num_sources, max_num_paths, 2, 2]
   2482         mat_t = tf.eye(num_rows=2,
   2483                     batch_shape=[num_targets, num_sources, max_num_paths],
   2484                     dtype=self._dtype)
   2485         # Initialize last field unit vector with outgoing ones
   2486         # [num_targets, num_sources, max_num_paths, 3]
   2487         last_e_r_s = theta_hat(theta_t, phi_t)
   2488         last_e_r_p = phi_hat(phi_t)
   2489         for depth in tf.range(0,max_depth):
   2490 
   2491             # Is this a valid reflection?
   2492             # [num_targets, num_sources, max_num_paths]
   2493             valid = valid_ray[depth]
   2494 
   2495             # Is the next reflection valid?
   2496             # [num_targets, num_sources, max_num_paths]
   2497             next_valid = valid_ray[depth+1]
   2498             # Expand for broadcasting
   2499             # [num_targets, num_sources, max_num_paths, 1, 1]
   2500             next_valid = insert_dims(next_valid, 2)
   2501 
   2502             # [num_targets, num_sources, max_num_paths]
   2503             reduction_factor_ = reduction_factor[depth]
   2504             # [num_targets, num_sources, max_num_paths, 1, 1]
   2505             reduction_factor_ = insert_dims(reduction_factor_, 2, -1)
   2506 
   2507             # Early stopping if no active rays
   2508             if not tf.reduce_any(valid):
   2509                 break
   2510 
   2511             # Add dimension for broadcasting with coordinates
   2512             # [num_targets, num_sources, max_num_paths, 1]
   2513             valid_ = tf.expand_dims(valid, axis=-1)
   2514 
   2515             # Change of basis matrix
   2516             # [num_targets, num_sources, max_num_paths, 2, 2]
   2517             mat_cob = component_transform(last_e_r_s, last_e_r_p,
   2518                                           e_i_s[depth], e_i_p[depth])
   2519             mat_cob = tf.complex(mat_cob, tf.zeros_like(mat_cob))
   2520             # Only apply transform if valid reflection
   2521             # [num_targets, num_sources, max_num_paths, 1, 1]
   2522             valid__ = tf.expand_dims(valid_, axis=-1)
   2523             # [num_targets, num_sources, max_num_paths, 2, 2]
   2524             e = tf.where(valid__, tf.linalg.matmul(mat_cob, mat_t), mat_t)
   2525             # Only update ongoing direction for next iteration if this
   2526             # reflection is valid and if this is not the last step
   2527             last_e_r_s = tf.where(valid_, e_r_s[depth], last_e_r_s)
   2528             last_e_r_p = tf.where(valid_, e_r_p[depth], last_e_r_p)
   2529 
   2530             # Fresnel coefficients
   2531             # [num_targets, num_sources, max_num_paths, 2]
   2532             r = tf.stack([r_s[depth], r_p[depth]], -1)
   2533             # Set the coefficients to one if non-valid reflection
   2534             # [num_targets, num_sources, max_num_paths, 2]
   2535             r = tf.where(valid_, r, tf.ones_like(r))
   2536             # Add a dimension to broadcast with mat_t
   2537             # [num_targets, num_sources, max_num_paths, 2, 1]
   2538             r = tf.expand_dims(r, axis=-1)
   2539             # Apply Fresnel coefficient
   2540             # [num_targets, num_sources, max_num_paths, 2, 2]
   2541             mat_t = r*e
   2542 
   2543             # If scattering, then the reduction coefficient is not applied
   2544             # to the last interaction as the outgoing ray is diffusely
   2545             # reflected and not specularly reflected
   2546             if scattering:
   2547                 # [num_targets, num_sources, max_num_paths, 1, 1]
   2548                 apply_reduction = tf.logical_and(valid__, next_valid)
   2549             else:
   2550                 apply_reduction = valid__
   2551 
   2552             # Apply the reduction factor
   2553             # [num_targets, num_sources, max_num_paths, 2]
   2554             reduction_factor_ = tf.where(apply_reduction, reduction_factor_,
   2555                                         tf.ones_like(reduction_factor_))
   2556             # [num_targets, num_sources, max_num_paths, 2, 2]
   2557             mat_t = mat_t*reduction_factor_
   2558 
   2559         # Move to the targets frame
   2560         # This is not done for scattering as we stop the last interaction point
   2561         if not scattering:
   2562             # Transformation matrix
   2563             # [num_targets, num_sources, max_num_paths, 2, 2]
   2564             mat_cob = component_transform(last_e_r_s, last_e_r_p,
   2565                                         theta_hat(theta_r, phi_r),
   2566                                         phi_hat(phi_r))
   2567             mat_cob = tf.complex(mat_cob, tf.zeros_like(mat_cob))
   2568             # Apply transformation
   2569             # [num_targets, num_sources, max_num_paths, 2, 2]
   2570             mat_t = tf.linalg.matmul(mat_cob, mat_t)
   2571 
   2572         # Divide by total distance to account for propagation loss
   2573         # [num_targets, num_sources, max_num_paths, 1, 1]
   2574         total_distance = expand_to_rank(total_distance, tf.rank(mat_t),
   2575                                         axis=3)
   2576         total_distance = tf.complex(total_distance,
   2577                                     tf.zeros_like(total_distance))
   2578         # [num_targets, num_sources, max_num_paths, 2, 2]
   2579         mat_t = tf.math.divide_no_nan(mat_t, total_distance)
   2580 
   2581         # Set invalid paths to 0 and stores the transition matrices
   2582         # Expand masks to broadcast with the field components
   2583         # [num_targets, num_sources, max_num_paths, 1, 1]
   2584         mask_ = expand_to_rank(paths.mask, 5, axis=3)
   2585         # Zeroing coefficients corresponding to non-valid paths
   2586         # [num_targets, num_sources, max_num_paths, 2, 2]
   2587         mat_t = tf.where(mask_, mat_t, tf.zeros_like(mat_t))
   2588 
   2589         return mat_t
   2590 
   2591     ##################################################################
   2592     # Methods used for computing the diffracted paths
   2593     ##################################################################
   2594 
   2595     def _discard_obstructing_wedges_and_corners(self, candidate_wedges, targets,
   2596                                                 sources):
   2597         r"""
   2598         Discard wedges for which at least one of the source or target are
   2599         "inside" the wedge
   2600 
   2601         Inputs
   2602         ------
   2603         candidate_wedges : [num_candidate_wedges], int
   2604             Candidate wedges.
   2605             Entries correspond to wedges indices.
   2606 
   2607         targets : [num_targets, 3], tf.float
   2608             Coordinates of the targets.
   2609 
   2610         sources : [num_sources, 3], tf.float
   2611             Coordinates of the sources.
   2612 
   2613         Output
   2614         -------
   2615         wedges_indices : [num_targets, num_sources, max_num_paths], tf.int
   2616             Indices of the wedges that interacted with the diffracted paths
   2617         """
   2618 
   2619         epsilon = tf.cast(SolverBase.EPSILON, self._rdtype)
   2620 
   2621         # [num_candidate_wedges, 3]
   2622         origins = tf.gather(self._wedges_origin, candidate_wedges)
   2623 
   2624         # Expand to broadcast with sources/targets and 0/n faces
   2625         # [1, num_candidate_wedges, 1, 3]
   2626         origins = tf.expand_dims(origins, axis=0)
   2627         origins = tf.expand_dims(origins, axis=2)
   2628 
   2629         # Normals
   2630         # [num_candidate_wedges, 2, 3]
   2631         # [:,0,:] : 0-face
   2632         # [:,1,:] : n-face
   2633         normals = tf.gather(self._wedges_normals, candidate_wedges)
   2634         # Expand to broadcast with the sources or targets
   2635         # [1, num_candidate_wedges, 2, 3]
   2636         normals = tf.expand_dims(normals, axis=0)
   2637 
   2638         # Expand to broadcast with candidate and 0/n faces wedges
   2639         # [num_sources, 1, 1, 3]
   2640         sources = expand_to_rank(sources, 4, 1)
   2641         # [num_targets, 1, 1, 3]
   2642         targets = expand_to_rank(targets, 4, 1)
   2643         # Sources/Targets vectors
   2644         # [num_sources, num_candidate_wedges, 1, 3]
   2645         u_t = sources - origins
   2646         # [num_targets, num_candidate_wedges, 1, 3]
   2647         u_r = targets - origins
   2648 
   2649         # [num_sources, num_candidate_wedges, 2]
   2650         sources_valid_half_space = dot(u_t, normals)
   2651         sources_valid_half_space = tf.greater(sources_valid_half_space,
   2652                     tf.fill(tf.shape(sources_valid_half_space), epsilon))
   2653         # [num_sources, num_candidate_wedges]
   2654         sources_valid_half_space = tf.reduce_any(sources_valid_half_space,
   2655                                                  axis=2)
   2656         # Expand to broadcast with targets
   2657         # [1, num_sources, num_candidate_wedges]
   2658         sources_valid_half_space = tf.expand_dims(sources_valid_half_space,
   2659                                                   axis=0)
   2660 
   2661         # [num_targets, num_candidate_wedges, 2]
   2662         targets_valid_half_space = dot(u_r, normals)
   2663         targets_valid_half_space = tf.greater(targets_valid_half_space,
   2664                             tf.fill(tf.shape(targets_valid_half_space),epsilon))
   2665         # [num_targets, num_candidate_wedges]
   2666         targets_valid_half_space = tf.reduce_any(targets_valid_half_space,
   2667                                                  axis=2)
   2668         # Expand to broadcast with sources
   2669         # [num_targets, 1, num_candidate_wedges]
   2670         targets_valid_half_space = tf.expand_dims(targets_valid_half_space,
   2671                                                   axis=1)
   2672 
   2673         # [num_targets, num_sources, max_num_paths = num_candidate_wedges]
   2674         mask = tf.logical_and(sources_valid_half_space,
   2675                               targets_valid_half_space)
   2676 
   2677         # Discard paths with no valid link
   2678         # [max_num_paths]
   2679         valid_paths = tf.where(tf.reduce_any(mask, axis=(0,1)))[:,0]
   2680         # [num_targets, num_sources, max_num_paths]
   2681         mask = tf.gather(mask, valid_paths, axis=2)
   2682         # [max_num_paths]
   2683         wedges_indices = tf.gather(candidate_wedges, valid_paths, axis=0)
   2684         # Set invalid wedges to -1
   2685         # [num_targets, num_sources, max_num_paths]
   2686         wedges_indices = tf.where(mask, wedges_indices, -1)
   2687 
   2688         return wedges_indices
   2689 
    def _compute_diffraction_points(self, targets, sources, wedges_indices):
        r"""
        Compute, for every link and candidate wedge, the point on the wedge
        edge that minimizes the length of the path source -> edge -> target,
        and mask the wedges for which this point is not on the finite edge.

        The edge of a wedge is parameterized as ``origin + t*e_hat`` and is
        considered finite for ``t`` in ``[0, wedges_length]``. Assuming
        ``e_hat`` is a unit vector and writing ``u_t = origin - source`` and
        ``u_r = origin - target``, the path length is

            L(t) = sqrt(t^2 + 2*a*t + c) + sqrt(t^2 + 2*b*t + d)

        with ``a = u_t.e_hat``, ``b = u_r.e_hat``, ``c = |u_t|^2`` and
        ``d = |u_r|^2``. Setting ``dL/dt = 0`` and squaring both sides yields
        the quadratic ``alpha*t^2 + beta*t + gamma = 0`` whose coefficients
        are computed below. A stationary point of ``L`` requires ``t+a`` and
        ``t+b`` to have opposite signs, which is used to select one of the
        two roots.

        Wedges that are not valid for any link after masking are removed
        from the last dimension of the outputs.

        Note: This calculation is done in double-precision (64bit). The
        interaction points are cast back to ``self._rdtype`` before being
        returned.

        Input
        ------
        targets : [num_targets, 3], tf.float
            Coordinates of the targets.

        sources : [num_sources, 3], tf.float
            Coordinates of the sources.

        wedges_indices : [num_targets, num_sources, max_num_paths], tf.int
            Indices of the wedges that interacted with the diffracted paths.
            -1 indicates an invalid entry.

        Output
        -------
        wedges_indices : [num_targets, num_sources, max_num_paths], tf.int
            Indices of the wedges that interacted with the diffracted paths.
            Set to -1 if the interaction point is not on the finite wedge.

        vertices : [num_targets, num_sources, max_num_paths, 3], tf.float
            Coordinates of the interaction points on the intersected wedges
        """

        sources = tf.cast(sources, tf.float64)
        targets = tf.cast(targets, tf.float64)

        # On CPU, indexing with -1 does not work. Hence we replace -1 by 0.
        # This makes no difference on the resulting paths as such paths
        # are not flagged as active.
        # [num_targets, num_sources, max_num_paths]
        valid_wedges_idx = tf.where(wedges_indices == -1, 0, wedges_indices)

        # Gathering with the rank-3 wedges_indices documented above yields
        # tensors that are already of the target rank, so the
        # expand_to_rank() calls below should leave them unchanged.
        # [num_targets, num_sources, max_num_paths, 3]
        origins = tf.gather(self._wedges_origin, valid_wedges_idx)
        origins = tf.cast(origins, tf.float64)
        # [num_targets, num_sources, max_num_paths, 3]
        origins = expand_to_rank(origins, 4, 0)
        # [num_targets, num_sources, max_num_paths, 3]
        e_hat = tf.gather(self._wedges_e_hat, valid_wedges_idx)
        e_hat = tf.cast(e_hat, tf.float64)
        # [num_targets, num_sources, max_num_paths, 3]
        e_hat = expand_to_rank(e_hat, 4, 0)
        # [num_targets, num_sources, max_num_paths]
        wedges_length = tf.gather(self._wedges_length, valid_wedges_idx)
        wedges_length = tf.cast(wedges_length, tf.float64)
        # [num_targets, num_sources, max_num_paths]
        wedges_length = expand_to_rank(wedges_length, 3, 0)

        # Expand to broadcast with paths and sources/targets
        # [1, num_sources, 1, 3]
        sources = tf.expand_dims(tf.expand_dims(sources, axis=1), axis=0)
        # [num_targets, 1, 1, 3]
        targets = insert_dims(targets, 2, 1)
        # Vectors from the sources/targets to the wedge origins
        # [num_targets, num_sources, max_num_paths, 3]
        u_t = origins - sources
        # [num_targets, num_sources, max_num_paths, 3]
        u_r = origins - targets

        # Quantities required for the computation of the interaction points
        # (see the docstring for the expression of the path length)
        # [num_targets, num_sources, max_num_paths]
        a = dot(u_t,e_hat)
        # [num_targets, num_sources, max_num_paths]
        b = dot(u_r, e_hat)
        # [num_targets, num_sources, max_num_paths]
        c = dot(u_t, u_t)
        # [num_targets, num_sources, max_num_paths]
        d = dot(u_r, u_r)

        # Coefficients of the quadratic alpha*t^2 + beta*t + gamma = 0
        # obtained from (t+a)^2(t^2+2bt+d) = (t+b)^2(t^2+2at+c)
        # [num_targets, num_sources, max_num_paths]
        alpha = -tf.square(a) + tf.square(b) + c - d
        beta = 2.*(a*tf.square(b) - b*tf.square(a) + b*c - a*d)
        gamma = tf.square(b)*c - tf.square(a)*d

        # Quantities normalized by alpha to improve numerical precision. Only
        # meaningful if alpha != 0; divide_no_nan yields 0 otherwise.
        # [num_targets, num_sources, max_num_paths]
        beta_norm = tf.math.divide_no_nan(beta, alpha)
        gamma_norm = tf.math.divide_no_nan(gamma, alpha)
        # Discriminant of t^2 + beta_norm*t + gamma_norm = 0
        delta = tf.square(beta_norm) - 4.*gamma_norm

        # Because of numerical imprecision, delta could be slightly smaller
        # than 0. It is clamped to 0 to avoid NaNs from the square root.
        # [num_targets, num_sources, max_num_paths]
        delta = tf.where(tf.less(delta, tf.zeros_like(delta)),
                         tf.zeros_like(delta), delta)


        # Four possible outcomes depending on the value of the previous
        # quantities.
        # Values of t that minimize the path length for each outcome.
        # [num_targets, num_sources, max_num_paths]
        # Degenerate case (all coefficients vanish): t at which t+a = 0
        t_min_1 = -a
        # alpha == 0: the equation is linear, beta*t + gamma = 0
        t_min_2 = -tf.math.divide_no_nan(gamma, beta)
        # General case: the two roots of the quadratic
        t_min_3 = (-beta_norm + tf.sqrt(delta))*0.5
        t_min_4 = (-beta_norm - tf.sqrt(delta))*0.5
        # Condition for each outcome to be selected
        # If a == b and c == d, then set to t_min_1
        # [num_targets, num_sources, max_num_paths]
        cond_1 = tf.logical_and(tf.experimental.numpy.isclose(a, b),
                                tf.experimental.numpy.isclose(c, d))
        # If cond_1 does not hold and alpha == 0, then set to t_min_2
        # [num_targets, num_sources, max_num_paths]
        cond_2 = tf.logical_and(tf.logical_not(cond_1),
                                tf.experimental.numpy.isclose(alpha,
                                tf.zeros_like(alpha)))
        # If neither cond_1 nor cond_2 holds, then set to t_min_3 or t_min_4
        # depending on the signs of t+a and t+b
        # [num_targets, num_sources, max_num_paths]
        not_cond_12 = tf.logical_and(tf.logical_not(cond_1),
                                     tf.logical_not(cond_2))
        # [num_targets, num_sources, max_num_paths]
        t_min_3a = t_min_3 + a
        t_min_3b = t_min_3 + b
        # t_min_3 is selected if t+a and t+b do not have the same sign, which
        # is required for dL/dt = 0 (squaring may introduce a spurious root)
        # [num_targets, num_sources, max_num_paths]
        cond_3 = tf.logical_and(not_cond_12,
                tf.less_equal(tf.sign(t_min_3a)*tf.sign(t_min_3b), 0.0))
        # If none of conditions 1, 2, or 3 are satisfied, then all that is
        # left is t_min_4
        # [num_targets, num_sources, max_num_paths]
        cond_4 = tf.logical_and(not_cond_12, tf.logical_not(cond_3))
        # Assign t_min according to the previously computed conditions
        # [num_targets, num_sources, max_num_paths]
        t_min = tf.zeros_like(cond_1, tf.float64)
        t_min = tf.where(cond_1, t_min_1, t_min)
        t_min = tf.where(cond_2, t_min_2, t_min)
        t_min = tf.where(cond_3, t_min_3, t_min)
        t_min = tf.where(cond_4, t_min_4, t_min)

        # Mask paths for which the interaction point is not on the finite
        # wedge, i.e., t_min is outside [0, wedges_length]
        # [num_targets, num_sources, max_num_paths]
        mask_ = tf.logical_and(
            tf.greater_equal(t_min, tf.zeros_like(t_min)),
            tf.less_equal(t_min, wedges_length))
        # [num_targets, num_sources, max_num_paths]
        wedges_indices = tf.where(mask_, wedges_indices, -1)

        # Interaction points
        # Expand to broadcast with coordinates
        # [num_targets, num_sources, max_num_paths, 1]
        t_min = tf.expand_dims(t_min, axis=3)
        # [num_targets, num_sources, max_num_paths, 3]
        inter_point = origins + t_min*e_hat

        # Discard wedges with no valid paths for any link
        # [max_num_paths]
        used_wedges = tf.where(tf.reduce_any(tf.not_equal(wedges_indices, -1),
                                             axis=(0,1)))[:,0]
        # [num_targets, num_sources, max_num_paths, 3]
        inter_point = tf.gather(inter_point, used_wedges, axis=2)
        # [num_targets, num_sources, max_num_paths]
        wedges_indices = tf.gather(wedges_indices, used_wedges, axis=2)

        # Back to the required precision
        inter_point = tf.cast(inter_point, self._rdtype)

        return wedges_indices, inter_point
   2854 
   2855     def _check_wedges_visibility(self, targets, sources, wedges_indices,
   2856                                  vertices, ris_objects):
   2857         r"""
   2858         Discard the wedges that are not valid due to obstruction by updating the
   2859         mask and removing the wedges related to no valid links.
   2860 
   2861         Input
   2862         ------
   2863         targets : [num_targets, 3], tf.float
   2864             Coordinates of the targets.
   2865 
   2866         sources : [num_sources, 3], tf.float
   2867             Coordinates of the sources.
   2868 
   2869         wedges_indices : [num_targets, num_sources, max_num_paths], tf.int
   2870             Indices of the wedges that interacted with the diffracted paths
   2871 
   2872         vertices : [num_targets, num_sources, max_num_paths, 3], tf.float
   2873             Coordinates of the interaction points on the intersected wedges
   2874 
   2875         ris_objects : list(mi.Rectangle)
   2876             List of Mitsuba rectangles implementing the RIS
   2877 
   2878         Output
   2879         -------
   2880         wedges_indices : [num_targets, num_sources, max_num_paths], tf.int
   2881             Indices of the wedges that interacted with the diffracted paths
   2882 
   2883         vertices : [num_targets, num_sources, max_num_paths, 3], tf.float
   2884             Coordinates of the interaction points on the intersected wedges
   2885         """
   2886 
   2887         max_num_paths = vertices.shape[2]
   2888         num_sources = sources.shape[0]
   2889         num_targets = targets.shape[0]
   2890 
   2891         # Broadcast sources and targets with wedges diffraction point
   2892         # [1, num_sources, 1, 3]
   2893         sources = tf.expand_dims(sources, axis=0)
   2894         sources = tf.expand_dims(sources, axis=2)
   2895         # [num_targets, num_sources, max_num_paths, 3]
   2896         sources = tf.broadcast_to(sources, vertices.shape)
   2897         # Flatten
   2898         # batch_size = num_targets*num_sources*max_num_paths
   2899         # [batch_size, 3]
   2900         sources = tf.reshape(sources, [-1, 3])
   2901         # [num_targets, 1, 1, 3]
   2902         targets = expand_to_rank(targets, tf.rank(vertices), 1)
   2903         # [num_targets, num_sources, max_num_paths, 3]
   2904         targets = tf.broadcast_to(targets, vertices.shape)
   2905         # Flatten
   2906         # [batch_size, 3]
   2907         targets = tf.reshape(targets, [-1, 3])
   2908 
   2909         # Flatten interaction points
   2910         # [batch_size, 3]
   2911         wedges_points = tf.reshape(vertices, [-1, 3])
   2912 
   2913         # Check visibility between transmitter and wedge
   2914         # Ray origin
   2915         # d : [batch_size, 3]
   2916         # maxt : [batch_size]
   2917         d,maxt = tf.linalg.normalize(wedges_points - sources, axis=1)
   2918         maxt = tf.squeeze(maxt, axis=1)
   2919         # [batch_size]
   2920         valid_t2w = tf.logical_not(self._test_obstruction(sources, d, maxt,
   2921                                                           ris_objects))
   2922 
   2923         # Check visibility between wedge and receiver
   2924         # Ray origin
   2925         # d : [batch_size, 3]
   2926         # maxt : [batch_size]
   2927         d,maxt = tf.linalg.normalize(wedges_points - targets, axis=1)
   2928         maxt = tf.squeeze(maxt, axis=1)
   2929         # [batch_size]
   2930         valid_w2r = tf.logical_not(self._test_obstruction(targets, d, maxt,
   2931                                                           ris_objects))
   2932 
   2933         # Mask obstructed wedges
   2934         # [batch_size]
   2935         valid = tf.logical_and(valid_t2w, valid_w2r)
   2936         # [num_targets, num_sources, max_num_paths]
   2937         valid = tf.reshape(valid, [num_targets, num_sources, max_num_paths])
   2938         # Set wedge indices of blocked paths to -1
   2939         wedges_indices = tf.where(valid, wedges_indices, -1)
   2940 
   2941         # Discard wedges not involved in any link
   2942         # [max_num_paths]
   2943         used_wedges = tf.where(tf.reduce_any(tf.not_equal(wedges_indices, -1),
   2944                                              axis=(0,1)))[:,0]
   2945         # [num_targets, num_sources, max_num_paths, 3]
   2946         vertices = tf.gather(vertices, used_wedges, axis=2)
   2947         # [num_targets, num_sources, max_num_paths]
   2948         wedges_indices = tf.gather(wedges_indices, used_wedges, axis=2)
   2949 
   2950         return wedges_indices, vertices
   2951 
   2952     def _gather_valid_diff_paths(self, paths):
   2953         r"""
   2954         Extracts only valid diffracted paths to reduce memory consumption when
   2955         having multiple links with different number of valid paths.
   2956 
   2957         Input
   2958         ------
   2959         paths : :class:`~sionna.rt.Paths`
   2960             Paths to update
   2961 
   2962         Output
   2963         ------
   2964         paths : :class:`~sionna.rt.Paths`
   2965             Updated paths
   2966         """
   2967 
   2968         # [num_targets, num_sources, max_num_candidates]
   2969         wedges_indices = paths.objects[0]
   2970         # [num_targets, num_sources, max_num_candidates, 3]
   2971         vertices = paths.vertices[0]
   2972         # [num_sources, 3]
   2973         sources = paths.sources
   2974         # [num_targets, 3]
   2975         targets = paths.targets
   2976 
   2977         num_sources = tf.shape(sources)[0]
   2978         num_targets = tf.shape(targets)[0]
   2979 
   2980         # [num_targets, num_sources, max_num_candidates]
   2981         valid = tf.not_equal(wedges_indices, -1)
   2982 
   2983         # [num_targets, num_sources]
   2984         num_paths = tf.reduce_sum(tf.cast(valid, tf.int32), axis=-1)
   2985         # Maximum number of valid paths
   2986         # ()
   2987         max_num_paths = tf.reduce_max(num_paths)
   2988 
   2989         # Build indices for keeping only valid path
   2990         # [num_valid_paths, 3]
   2991         gather_indices = tf.where(valid)
   2992         # [num_targets, num_sources, max_num_candidates]
   2993         path_indices = tf.cumsum(tf.cast(valid, tf.int32), axis=-1)
   2994         # [num_valid_paths]
   2995         path_indices = tf.gather_nd(path_indices, gather_indices) - 1
   2996         scatter_indices = tf.transpose(gather_indices, [1,0])
   2997         if not tf.size(scatter_indices) == 0:
   2998             scatter_indices = tf.tensor_scatter_nd_update(scatter_indices,
   2999                                 [[2]], [path_indices])
   3000         # [num_valid_paths, 3]
   3001         scatter_indices = tf.transpose(scatter_indices, [1,0])
   3002 
   3003         # Mask of valid paths
   3004         # [num_targets, num_sources, max_num_paths]
   3005         mask = tf.fill([num_targets, num_sources, max_num_paths], False)
   3006         mask = tf.tensor_scatter_nd_update(mask, scatter_indices,
   3007                 tf.fill([tf.shape(scatter_indices)[0]], True))
   3008 
   3009         # Locations of the interactions
   3010         # [num_targets, num_sources, max_num_paths, 3]
   3011         valid_vertices = tf.zeros([num_targets, num_sources, max_num_paths, 3],
   3012                                   dtype=self._rdtype)
   3013         # [total_num_valid_paths, 3]
   3014         vertices = tf.gather_nd(vertices, gather_indices)
   3015         valid_vertices = tf.tensor_scatter_nd_update(valid_vertices,
   3016                                                      scatter_indices, vertices)
   3017 
   3018         # Intersected wedges
   3019         # [num_targets, num_sources, max_num_paths]
   3020         valid_wedges_indices = tf.fill([num_targets, num_sources,
   3021                                         max_num_paths], -1)
   3022         # [total_num_valid_paths]
   3023         wedges_indices = tf.gather_nd(wedges_indices, gather_indices)
   3024         valid_wedges_indices = tf.tensor_scatter_nd_update(valid_wedges_indices,
   3025                                             scatter_indices, wedges_indices)
   3026 
   3027         # [1, num_targets, num_sources, max_num_candidates]
   3028         paths.objects = tf.expand_dims(valid_wedges_indices, axis=0)
   3029         # [1, num_targets, num_sources, max_num_candidates, 3]
   3030         paths.vertices = tf.expand_dims(valid_vertices, axis=0)
   3031         # [num_targets, num_sources, max_num_candidates]
   3032         paths.mask = mask
   3033 
   3034         return paths
   3035 
   3036     def _compute_diffraction_transition_matrices(self,
   3037                                                  relative_permittivity,
   3038                                                  scattering_coefficient,
   3039                                                  paths,
   3040                                                  paths_tmp):
   3041         # pylint: disable=line-too-long
   3042         """
   3043         Compute the transition matrices for diffracted rays.
   3044 
   3045         Input
   3046         ------
   3047         relative_permittivity : [num_shape], tf.complex
   3048             Tensor containing the complex relative permittivity of all shape
   3049             in the scene
   3050 
   3051         scattering_coefficient : [num_shape], tf.float
   3052             Tensor containing the scattering coefficients of all shapes
   3053 
   3054         paths : :class:`~sionna.rt.Paths`
   3055             Paths to update
   3056 
   3057         paths_tmp : :class:`~sionna.rt.PathsTmpData`
   3058             Addtional quantities required for paths computation
   3059 
   3060         Output
   3061         ------
   3062         paths : :class:`~sionna.rt.Paths`
   3063             Updated paths
   3064         """
   3065 
   3066         mask = paths.mask
   3067         targets = paths.targets
   3068         sources = paths.sources
   3069         theta_t = paths.theta_t
   3070         phi_t = paths.phi_t
   3071         theta_r = paths.theta_r
   3072         phi_r = paths.phi_r
   3073 
   3074         normals = paths_tmp.normals
   3075 
   3076         def f(x):
   3077             """F(x) Eq.(88) in [ITUR_P526]
   3078             """
   3079             sqrt_x = tf.sqrt(x)
   3080             sqrt_pi_2 = tf.cast(tf.sqrt(PI/2.), x.dtype)
   3081 
   3082             # Fresnel integral
   3083             arg = sqrt_x/sqrt_pi_2
   3084             s = tf.math.special.fresnel_sin(arg)
   3085             c = tf.math.special.fresnel_cos(arg)
   3086             f = tf.complex(s, c)
   3087 
   3088             zero = tf.cast(0, x.dtype)
   3089             one = tf.cast(1, x.dtype)
   3090             two = tf.cast(2, f.dtype)
   3091             factor = tf.complex(sqrt_pi_2*sqrt_x, zero)
   3092             factor = factor*tf.exp(tf.complex(zero, x))
   3093             res =  tf.complex(one, one) - two*f
   3094 
   3095             return factor* res
   3096 
   3097         wavelength = self._scene.wavelength
   3098         k = 2.*PI/wavelength
   3099 
   3100         # [num_targets, num_sources, max_num_paths, 3]
   3101         diff_points = paths.vertices[0]
   3102         # [num_targets, num_sources, max_num_paths]
   3103         wedges_indices = paths.objects[0]
   3104 
   3105         # On CPU, indexing with -1 does not work. Hence we replace -1 by 0.
   3106         # This makes no difference on the resulting paths as such paths
   3107         # are not flagged as active.
   3108         # [num_targets, num_sources, max_num_paths]
   3109         valid_wedges_idx = tf.where(wedges_indices == -1, 0, wedges_indices)
   3110 
   3111         # Normals
   3112         # [num_targets, num_sources, max_num_paths, 2, 3]
   3113         normals = tf.gather(self._wedges_normals, valid_wedges_idx, axis=0)
   3114 
   3115         # Compute the wedges angle
   3116         # [num_targets, num_sources, max_num_paths]
   3117         cos_wedges_angle = dot(normals[...,0,:],normals[...,1,:], clip=True)
   3118         wedges_angle = PI - tf.math.acos(cos_wedges_angle)
   3119         n = (2.*PI-wedges_angle)/PI
   3120 
   3121         # [num_targets, num_sources, max_num_paths, 3]
   3122         e_hat = tf.gather(self._wedges_e_hat, valid_wedges_idx)
   3123 
   3124         # Reshape sources and targets
   3125         # [1, num_sources, 1, 3]
   3126         sources = tf.reshape(sources, [1, -1, 1, 3])
   3127         # [num_targets, 1, 1, 3]
   3128         targets = tf.reshape(targets, [-1, 1, 1, 3])
   3129 
   3130         # Extract surface normals
   3131         # [num_targets, num_sources, max_num_paths, 3]
   3132         n_0_hat = normals[...,0,:]
   3133         # [num_targets, num_sources, max_num_paths, 3]
   3134         n_n_hat = normals[...,1,:]
   3135 
   3136         # Relative permitivities and scattering coefficients
   3137         # If a callable is defined to compute the radio material properties,
   3138         # it is invoked. Otherwise, the radio materials of objects are used.
   3139         rm_callable = self._scene.radio_material_callable
   3140         # [num_targets, num_sources, max_num_paths, 2]
   3141         objects_indices = tf.gather(self._wedges_objects, valid_wedges_idx,
   3142                                     axis=0)
   3143         if rm_callable is None:
   3144             # [num_targets, num_sources, max_num_paths, 2]
   3145             etas = tf.gather(relative_permittivity, objects_indices)
   3146             scattering_coefficient = tf.gather(scattering_coefficient,
   3147                                                objects_indices)
   3148         else:
   3149             # Harmonize the shapes of the radio material callables
   3150             # [num_targets, num_sources, max_num_paths, 2, 3]
   3151             diff_points_ = tf.tile(tf.expand_dims(diff_points, axis=-2),
   3152                                    [1, 1, 1, 2, 1])
   3153             # scattering_coefficient, etas : [num_targets, num_sources,
   3154             #   max_num_paths, 2]
   3155             etas, scattering_coefficient, _   = rm_callable(objects_indices,
   3156                                                             diff_points_)
   3157         # [num_targets, num_sources, max_num_paths]
   3158         eta_0 = etas[...,0]
   3159         eta_n = etas[...,1]
   3160         # [num_targets, num_sources, max_num_paths]
   3161         scattering_coefficient_0 = scattering_coefficient[...,0]
   3162         scattering_coefficient_n = scattering_coefficient[...,1]
   3163 
   3164         # Compute s_prime_hat, s_hat, s_prime, s
   3165         # s_prime_hat : [num_targets, num_sources, max_num_paths, 3]
   3166         # s_prime : [num_targets, num_sources, max_num_paths]
   3167         s_prime_hat, s_prime = normalize(diff_points-sources)
   3168         # s_hat : [num_targets, num_sources, max_num_paths, 3]
   3169         # s : [num_targets, num_sources, max_num_paths]
   3170         s_hat, s = normalize(targets-diff_points)
   3171 
   3172         # Compute phi_prime_hat, beta_0_prime_hat, phi_hat, beta_0_hat
   3173         # [num_targets, num_sources, max_num_paths, 3]
   3174         phi_prime_hat, _ = normalize(cross(s_prime_hat, e_hat))
   3175         # [num_targets, num_sources, max_num_paths, 3]
   3176         beta_0_prime_hat = cross(phi_prime_hat, s_prime_hat)
   3177 
   3178         # [num_targets, num_sources, max_num_paths, 3]
   3179         phi_hat_, _ = normalize(-cross(s_hat, e_hat))
   3180         beta_0_hat = cross(phi_hat_, s_hat)
   3181 
   3182         # Compute tangent vector t_0_hat
   3183         # [num_targets, num_sources, max_num_paths, 3]
   3184         t_0_hat = cross(n_0_hat, e_hat)
   3185 
   3186         # Compute s_t_prime_hat and s_t_hat
   3187         # [num_targets, num_sources, max_num_paths, 3]
   3188         s_t_prime_hat, _ = normalize(s_prime_hat
   3189                                 - dot(s_prime_hat,e_hat, keepdim=True)*e_hat)
   3190         # [num_targets, num_sources, max_num_paths, 3]
   3191         s_t_hat, _ = normalize(s_hat - dot(s_hat,e_hat, keepdim=True)*e_hat)
   3192 
   3193         # Compute phi_prime and phi
   3194         # [num_targets, num_sources, max_num_paths]
   3195         phi_prime = PI - \
   3196             (PI-acos_diff(-dot(s_t_prime_hat, t_0_hat)))\
   3197                 * sign(-dot(s_t_prime_hat, n_0_hat))
   3198         # [num_targets, num_sources, max_num_paths]
   3199         phi = PI - (PI-acos_diff(dot(s_t_hat, t_0_hat)))\
   3200             * sign(dot(s_t_hat, n_0_hat))
   3201 
   3202         # Compute field component vectors for reflections at both surfaces
   3203         # [num_targets, num_sources, max_num_paths, 3]
   3204         # pylint: disable=unbalanced-tuple-unpacking
   3205         e_i_s_0, e_i_p_0, e_r_s_0, e_r_p_0 = compute_field_unit_vectors(
   3206             s_prime_hat,
   3207             s_hat,
   3208             n_0_hat,#*sign(-dot(s_t_prime_hat, n_0_hat, keepdim=True)),
   3209             SolverBase.EPSILON
   3210             )
   3211         # [num_targets, num_sources, max_num_paths, 3]
   3212         # pylint: disable=unbalanced-tuple-unpacking
   3213         e_i_s_n, e_i_p_n, e_r_s_n, e_r_p_n = compute_field_unit_vectors(
   3214             s_prime_hat,
   3215             s_hat,
   3216             n_n_hat,#*sign(-dot(s_t_prime_hat, n_n_hat, keepdim=True)),
   3217             SolverBase.EPSILON
   3218             )
   3219 
   3220         # Compute Fresnel reflection coefficients for 0- and n-surfaces
   3221         # [num_targets, num_sources, max_num_paths]
   3222         r_s_0, r_p_0 = reflection_coefficient(eta_0, tf.abs(tf.sin(phi_prime)))
   3223         r_s_n, r_p_n = reflection_coefficient(eta_n, tf.abs(tf.sin(n*PI-phi)))
   3224 
   3225         # Multiply the reflection coefficients with the
   3226         # corresponding reflection reduction factor
   3227         reduction_factor_0 = tf.sqrt(1 - scattering_coefficient_0**2)
   3228         reduction_factor_0 = tf.complex(reduction_factor_0,
   3229                                         tf.zeros_like(reduction_factor_0))
   3230         reduction_factor_n = tf.sqrt(1 - scattering_coefficient_n**2)
   3231         reduction_factor_n = tf.complex(reduction_factor_n,
   3232                                         tf.zeros_like(reduction_factor_n))
   3233         r_s_0 *= reduction_factor_0
   3234         r_p_0 *= reduction_factor_0
   3235         r_s_n *= reduction_factor_n
   3236         r_p_n *= reduction_factor_n
   3237 
   3238         # Compute matrices R_0, R_n
   3239         # [num_targets, num_sources, max_num_paths, 2, 2]
   3240         w_i_0 = component_transform(phi_prime_hat,
   3241                                     beta_0_prime_hat,
   3242                                     e_i_s_0,
   3243                                     e_i_p_0)
   3244         w_i_0 = tf.complex(w_i_0, tf.zeros_like(w_i_0))
   3245         # [num_targets, num_sources, max_num_paths, 2, 2]
   3246         w_r_0 = component_transform(e_r_s_0,
   3247                                     e_r_p_0,
   3248                                     phi_hat_,
   3249                                     beta_0_hat)
   3250         w_r_0 = tf.complex(w_r_0, tf.zeros_like(w_r_0))
   3251         # [num_targets, num_sources, max_num_paths, 2, 1]
   3252         r_0 = tf.expand_dims(tf.stack([r_s_0, r_p_0], -1), -1) * w_i_0
   3253         # [num_targets, num_sources, max_num_paths, 2, 1]
   3254         r_0 = -tf.matmul(w_r_0, r_0)
   3255 
   3256         # [num_targets, num_sources, max_num_paths, 2, 2]
   3257         w_i_n = component_transform(phi_prime_hat,
   3258                                     beta_0_prime_hat,
   3259                                     e_i_s_n,
   3260                                     e_i_p_n)
   3261         w_i_n = tf.complex(w_i_n, tf.zeros_like(w_i_n))
   3262         # [num_targets, num_sources, max_num_paths, 2, 2]
   3263         w_r_n = component_transform(e_r_s_n,
   3264                                     e_r_p_n,
   3265                                     phi_hat_,
   3266                                     beta_0_hat)
   3267         w_r_n = tf.complex(w_r_n, tf.zeros_like(w_r_n))
   3268         # [num_targets, num_sources, max_num_paths, 2, 1]
   3269         r_n = tf.expand_dims(tf.stack([r_s_n, r_p_n], -1), -1) * w_i_n
   3270         # [num_targets, num_sources, max_num_paths, 2, 1]
   3271         r_n = -tf.matmul(w_r_n, r_n)
   3272 
   3273         # Compute D_1, D_2, D_3, D_4
   3274         # [num_targets, num_sources, max_num_paths]
   3275         phi_m = phi - phi_prime
   3276         phi_p = phi + phi_prime
   3277 
   3278         # [num_targets, num_sources, max_num_paths]
   3279         cot_1 = cot((PI + phi_m)/(2*n))
   3280         cot_2 = cot((PI - phi_m)/(2*n))
   3281         cot_3 = cot((PI + phi_p)/(2*n))
   3282         cot_4 = cot((PI - phi_p)/(2*n))
   3283 
   3284         def n_p(beta, n):
   3285             return tf.math.round((beta + PI)/(2.*n*PI))
   3286 
   3287         def n_m(beta, n):
   3288             return tf.math.round((beta - PI)/(2.*n*PI))
   3289 
   3290         def a_p(beta, n):
   3291             return 2*tf.cos((2.*n*PI*n_p(beta, n)-beta)/2.)**2
   3292 
   3293         def a_m(beta, n):
   3294             return 2*tf.cos((2.*n*PI*n_m(beta, n)-beta)/2.)**2
   3295 
   3296         d_mul = - tf.cast(tf.exp(-1j*PI/4.), self._dtype)/\
   3297             tf.cast((2*n)*tf.sqrt(2*PI*k), self._dtype)
   3298 
   3299         # [num_targets, num_sources, max_num_paths]
   3300         ell = s_prime*s/(s_prime + s)
   3301 
   3302         # [num_targets, num_sources, max_num_paths]
   3303         cot_1 = tf.complex(cot_1, tf.zeros_like(cot_1))
   3304         cot_2 = tf.complex(cot_2, tf.zeros_like(cot_2))
   3305         cot_3 = tf.complex(cot_3, tf.zeros_like(cot_3))
   3306         cot_4 = tf.complex(cot_4, tf.zeros_like(cot_4))
   3307         d_1 = d_mul*cot_1*f(k*ell*a_p(phi_m, n))
   3308         d_2 = d_mul*cot_2*f(k*ell*a_m(phi_m, n))
   3309         d_3 = d_mul*cot_3*f(k*ell*a_p(phi_p, n))
   3310         d_4 = d_mul*cot_4*f(k*ell*a_m(phi_p, n))
   3311 
   3312         # [num_targets, num_sources, max_num_paths, 1, 1]
   3313         d_1 = tf.reshape(d_1, tf.concat([tf.shape(d_1), [1,1]], axis=0))
   3314         d_2 = tf.reshape(d_2, tf.concat([tf.shape(d_2), [1,1]], axis=0))
   3315         d_3 = tf.reshape(d_3, tf.concat([tf.shape(d_3), [1,1]], axis=0))
   3316         d_4 = tf.reshape(d_4, tf.concat([tf.shape(d_4), [1,1]], axis=0))
   3317 
   3318         # [num_targets, num_sources, max_num_paths]
   3319         spreading_factor = tf.sqrt(1.0 / (s*s_prime*(s_prime + s)))
   3320         spreading_factor = tf.complex(spreading_factor,
   3321                                       tf.zeros_like(spreading_factor))
   3322         # [num_targets, num_sources, max_num_paths, 1, 1]
   3323         spreading_factor = tf.reshape(spreading_factor, tf.shape(d_1))
   3324 
   3325         # [num_targets, num_sources, max_num_paths, 2, 2]
   3326         mat_t = (d_1+d_2)*tf.eye(2,2, batch_shape=tf.shape(r_0)[:3],
   3327                                  dtype=self._dtype)
   3328         # [num_targets, num_sources, max_num_paths, 2, 2]
   3329         mat_t += d_3*r_n + d_4*r_0
   3330         # [num_targets, num_sources, max_num_paths, 2, 2]
   3331         mat_t *= -spreading_factor
   3332 
   3333         # Convert from/to GCS
   3334         theta_t = paths.theta_t
   3335         phi_t = paths.phi_t
   3336         theta_r = paths.theta_r
   3337         phi_r = paths.phi_r
   3338 
   3339         mat_from_gcs = component_transform(
   3340                             theta_hat(theta_t, phi_t), phi_hat(phi_t),
   3341                             phi_prime_hat, beta_0_prime_hat)
   3342         mat_from_gcs = tf.complex(mat_from_gcs, tf.zeros_like(mat_from_gcs))
   3343 
   3344 
   3345         mat_to_gcs = component_transform(phi_hat_, beta_0_hat,
   3346                                       theta_hat(theta_r, phi_r), phi_hat(phi_r))
   3347         mat_to_gcs = tf.complex(mat_to_gcs, tf.zeros_like(mat_to_gcs))
   3348 
   3349         mat_t = tf.linalg.matmul(mat_t, mat_from_gcs)
   3350         mat_t = tf.linalg.matmul(mat_to_gcs, mat_t)
   3351 
   3352         # Set invalid paths to 0
   3353         # Expand masks to broadcast with the field components
   3354         # [num_targets, num_sources, max_num_paths, 1, 1]
   3355         mask_ = expand_to_rank(mask, 5, axis=3)
   3356         # Zeroing coefficients corresponding to non-valid paths
   3357         # [num_targets, num_sources, max_num_paths, 2]
   3358         mat_t = tf.where(mask_, mat_t, tf.zeros_like(mat_t))
   3359 
   3360         return mat_t
   3361 
   3362     ##################################################################
   3363     # Methods used for computing the scattered paths
   3364     ##################################################################
   3365 
    def _scat_test_rx_blockage(self, targets, sources, candidates, hit_points,
                               ris_objects):
        r"""
        Test if the LoS between the hit points and the targets is blocked,
        mask out the blocked paths, and build the structures storing the
        scattered paths.

        For every depth, target, source, and candidate path, a ray is cast
        from the hit point toward the target. A prefix is kept if this ray is
        not obstructed and its candidate primitive is not -1.
        The paths dimension is then compacted: its length becomes the maximum,
        over all links, of the number of paths having at least one valid
        prefix.

        Input
        -----
        targets : [num_targets, 3], tf.float
            Coordinates of the targets

        sources : [num_sources, 3], tf.float
            Coordinates of the sources

        candidates : [max_depth, num_sources, num_paths_per_source], int
            Sequence of primitives hit at `hit_points`.
            Entries equal to -1 are treated as invalid.

        hit_points : [max_depth, num_sources, num_paths_per_source, 3], tf.float
            Intersection points

        ris_objects : list(mi.Rectangle)
            List of Mitsuba rectangles implementing the RIS

        Output
        -------
        paths : :class:`~sionna.rt.Paths`
            Structure storing the scattered paths. ``vertices`` has shape
            [max_depth, num_targets, num_sources, max_num_paths, 3] and
            ``objects`` has shape [max_depth, num_targets, num_sources,
            max_num_paths]. Entries without a candidate primitive map to -1
            in ``objects``.

        paths_tmp : :class:`~sionna.rt.PathsTmpData`
            Additional quantities required for paths computation. This method
            sets ``normals``, ``scat_prefix_mask``, and ``scat_prefix_k_s``
            (unit vectors from the hit points toward the targets).
        """

        # Static shapes are used for num_sources/num_targets.
        # NOTE(review): presumably these are known when this is called.
        num_sources = candidates.shape[1]
        num_targets = targets.shape[0]
        max_depth = tf.shape(candidates)[0]

        # Expand for broadcasting with max_depth, num_sources, and num_paths
        # [1, num_targets, 1, 1, 3]
        targets_ = tf.expand_dims(insert_dims(targets, 2, 1), axis=0)

        # Build the rays for shooting
        # Origins: one copy of the hit points per target
        # [max_depth, num_targets, num_sources, num_paths, 3]
        hit_points = tf.tile(tf.expand_dims(hit_points, axis=1),
                             [1, num_targets, 1, 1, 1])
        # [max_depth * num_targets * num_sources * num_paths, 3]
        ray_origins = tf.reshape(hit_points, [-1, 3])
        # Directions (unit vectors) and distances from hit points to targets
        # [max_depth, num_targets, num_sources, num_paths, 3]
        ray_directions,rays_lengths = normalize(targets_ - hit_points)
        # [max_depth * num_targets * num_sources * num_paths, 3]
        ray_directions = tf.reshape(ray_directions, [-1, 3])
        # [max_depth * num_targets * num_sources * num_paths]
        rays_lengths = tf.reshape(rays_lengths, [-1])

        # Test for blockage: the ray length bounds the search so that only
        # obstacles between the hit point and the target are considered
        # (presumably; see _test_obstruction).
        # [max_depth * num_targets * num_sources * num_paths]
        blocked = self._test_obstruction(ray_origins, ray_directions,
                                         rays_lengths, ris_objects)
        # [max_depth, num_targets, num_sources, num_paths]
        blocked = tf.reshape(blocked,
                             [max_depth, num_targets, num_sources, -1])

        # Mask blocked paths
        # [max_depth, num_targets, num_sources, num_paths]
        candidates = tf.tile(tf.expand_dims(candidates, axis=1),
                             [1, num_targets, 1, 1])
        # A prefix is valid if unobstructed and a primitive was hit
        # [max_depth, num_targets, num_sources, num_paths]
        prefix_mask = tf.logical_and(~blocked, tf.not_equal(candidates, -1))

        # Optimize tensor size by ensuring that the length of the paths
        # dimension corresponds to the maximum number of paths over all links

        # Keep a path if at least one of its prefixes is valid
        # [num_targets, num_sources, num_paths]
        prefix_mask_ = tf.reduce_any(prefix_mask, axis=0)
        prefix_mask_int_ = tf.cast(prefix_mask_, tf.int32)

        # Number of valid paths for each link
        # [num_targets, num_sources]
        num_paths = tf.reduce_sum(prefix_mask_int_, axis=-1)
        # Maximum number of valid paths over all links
        # ()
        max_num_paths = tf.reduce_max(num_paths)

        # (target, source, path) indices of the valid paths
        # [num_valid_paths, 3]
        gather_indices = tf.where(prefix_mask_)
        # To build the indices of the paths in the tensor with optimized size,
        # the path dimension is indexed by counting the valid paths in the
        # order in which they appear
        # [num_targets, num_sources, num_paths]
        path_indices = tf.cumsum(prefix_mask_int_, axis=-1)
        # The cumulative count is 1-based; subtract 1 for 0-based positions
        # [num_valid_paths]
        path_indices = tf.gather_nd(path_indices, gather_indices) - 1
        # The indices used to scatter the valid paths in the tensors with
        # optimized size are built by replacing the path index (row 2 of the
        # transposed indices) by its compacted position, which skips the
        # invalid paths
        # [3, num_valid_paths]
        scatter_indices = tf.transpose(gather_indices, [1,0])
        # Skip the update when no path is valid.
        # NOTE: evaluating a tensor as a Python bool requires eager execution.
        if not tf.size(scatter_indices) == 0:
            scatter_indices = tf.tensor_scatter_nd_update(scatter_indices,
                                [[2]], [path_indices])
        # [num_valid_paths, 3]
        scatter_indices = tf.transpose(scatter_indices, [1,0])

        # Mask of valid prefixes
        # [max_depth, num_targets, num_sources, max_num_paths]
        opt_prefix_mask = tf.fill([max_depth, num_targets, num_sources,
                                   max_num_paths], False)
        # Locations of the interactions
        # [max_depth, num_targets, num_sources, max_num_paths, 3]
        opt_hit_points = tf.zeros([max_depth, num_targets, num_sources,
                                   max_num_paths, 3], dtype=self._rdtype)
        # Intersected primitives (-1 where no path is stored)
        # [max_depth, num_targets, num_sources, max_num_paths]
        opt_candidates = tf.fill([max_depth, num_targets, num_sources,
                                  max_num_paths], -1)

        if max_depth > 0:

            for depth in tf.range(max_depth, dtype=tf.int64):

                # Indices for storing the valid items for this depth:
                # the depth index is prepended as the first column
                # [num_valid_paths, 4]
                scatter_indices_ = tf.pad(scatter_indices, [[0,0], [1,0]],
                                mode='CONSTANT', constant_values=depth)

                # Prefix mask
                # [num_targets, num_sources, num_paths]
                prefix_mask_ = tf.gather(prefix_mask, depth, axis=0)
                # [num_valid_paths]
                prefix_mask_ = tf.gather_nd(prefix_mask_, gather_indices)
                # Store the prefix mask of the valid paths
                # [max_depth, num_targets, num_sources, max_num_paths]
                opt_prefix_mask = tf.tensor_scatter_nd_update(opt_prefix_mask,
                                                scatter_indices_, prefix_mask_)

                # Location of the interactions
                # [num_targets, num_sources, num_paths, 3]
                hit_points_ = tf.gather(hit_points, depth, axis=0)
                # [num_valid_paths, 3]
                hit_points_ = tf.gather_nd(hit_points_, gather_indices)
                # Store the valid intersection points
                # [max_depth, num_targets, num_sources, max_num_paths, 3]
                opt_hit_points = tf.tensor_scatter_nd_update(opt_hit_points,
                                                scatter_indices_, hit_points_)

                # Intersected primitives
                # [num_targets, num_sources, num_paths]
                candidates_ = tf.gather(candidates, depth, axis=0)
                # [num_valid_paths]
                candidates_ = tf.gather_nd(candidates_, gather_indices)
                # Store the intersected primitives of the valid paths
                # [max_depth, num_targets, num_sources, max_num_paths]
                opt_candidates = tf.tensor_scatter_nd_update(opt_candidates,
                                                scatter_indices_, candidates_)

        # Gather normals to the intersected primitives
        # Note: They are not oriented in the direction of the incoming wave.
        # This is done later.
        # On CPU, indexing with -1 does not work. Hence we replace -1 by 0.
        # This makes no difference on the resulting paths as such paths
        # are not flagged as active.
        # [max_depth, num_targets, num_sources, max_num_paths]
        opt_candidates_ = tf.where(opt_candidates == -1, 0, opt_candidates)
        # [max_depth, num_targets, num_sources, max_num_paths, 3]
        normals = tf.gather(self._normals, opt_candidates_)

        # Map primitives to the corresponding objects
        # Add a dummy entry to primitives_2_objects with value -1.
        # [num_primitives + 1]
        primitives_2_objects = tf.pad(self._primitives_2_objects, [[0,1]],
                                      constant_values=-1)
        # Replace all -1 by num_primitives, i.e., the index of the dummy
        # entry, so that they map to object -1
        num_primitives = tf.shape(self._primitives_2_objects)[0]
        # [max_depth, num_targets, num_sources, max_num_paths]
        opt_candidates_ = tf.where(opt_candidates == -1, num_primitives,
                                   opt_candidates)
        # [max_depth, num_targets, num_sources, max_num_paths]
        objects = tf.gather(primitives_2_objects, opt_candidates_)

        # Create and return the objects storing the scattered paths
        paths = Paths(sources=sources,
                      targets=targets,
                      scene=self._scene,
                      types=Paths.SCATTERED)
        paths.vertices = opt_hit_points
        paths.objects = objects

        paths_tmp = PathsTmpData(sources, targets, self._dtype)
        paths_tmp.normals = normals
        paths_tmp.scat_prefix_mask = opt_prefix_mask
        # Unit vectors from the stored hit points toward the targets
        # [max_depth, num_targets, num_sources, max_num_paths, 3]
        paths_tmp.scat_prefix_k_s,_ = normalize(targets_ - opt_hit_points)

        return paths, paths_tmp
   3559 
   3560     def _scat_discard_crossing_paths(self, paths, paths_tmp, scat_keep_prob):
   3561         r"""
   3562         Discards paths:
   3563 
   3564         - for which the scattered ray is crossing the intersected
   3565         primitive, and
   3566 
   3567         - randomly with probability `` 1 - scat_keep_prob``.
   3568 
   3569         Input
   3570         ------
   3571         paths : :class:`~sionna.rt.Paths`
   3572             Structure storing the scattered paths.
   3573 
   3574         paths_tmp : :class:`~sionna.rt.PathsTmpData`
   3575             Addtional quantities required for paths computation
   3576 
   3577         scat_keep_prob : tf.float
   3578             Probablity of keeping a valid scattered paths.
   3579             Must be in )0,1).
   3580 
   3581         Output
   3582         -------
   3583         paths : :class:`~sionna.rt.Paths`
   3584             Updates paths.
   3585 
   3586         paths_tmp : :class:`~sionna.rt.PathsTmpData`
   3587             Updated addtional quantities required for paths computation
   3588         """
   3589 
   3590         theta_t = paths.theta_t
   3591         phi_t = paths.phi_t
   3592         objects = paths.objects
   3593         vertices = paths.vertices
   3594 
   3595         normals = paths_tmp.normals
   3596         mask = paths_tmp.scat_prefix_mask
   3597         k_i = paths_tmp.k_i
   3598         k_r = paths_tmp.k_r
   3599         k_s = paths_tmp.scat_prefix_k_s
   3600         total_distance = paths_tmp.total_distance
   3601 
   3602         max_depth = tf.shape(vertices)[0]
   3603 
   3604         # Ensure the normals point in the same direction as -k_i
   3605         # [max_depth, num_targets, num_sources, max_num_paths, 1]
   3606         s = -tf.math.sign(dot(k_i[:max_depth], normals, keepdim=True))
   3607         # [max_depth, num_targets, num_sources, max_num_paths, 3]
   3608         normals = normals * s
   3609 
   3610         # Mask paths for which k_s does not point in the same direction as the
   3611         # normal
   3612         # [max_depth, num_targets, num_sources, max_num_paths]
   3613         same_side = dot(normals, k_s) > SolverBase.EPSILON
   3614         # [max_depth, num_targets, num_sources, max_num_paths]
   3615         mask = tf.logical_and(mask, same_side)
   3616 
   3617         # Keep valid path with probability `scat_keep_prob`
   3618         # [max_depth, num_targets, num_sources, max_num_paths]
   3619         random_mask = config.tf_rng.uniform(tf.shape(mask), 0., 1.,
   3620                                             self._rdtype)
   3621         # [max_depth, num_targets, num_sources, max_num_paths]
   3622         random_mask = tf.less(random_mask, scat_keep_prob)
   3623         # [max_depth, num_targets, num_sources, max_num_paths]
   3624         mask = tf.logical_and(mask, random_mask)
   3625 
   3626         # Discard paths invalid for all links
   3627         valid_indices = tf.where(tf.reduce_any(mask, axis=(0,1,2)))[:,0]
   3628         # [num_targets, num_sources, max_num_paths]]
   3629         theta_t = tf.gather(theta_t, valid_indices, axis=2)
   3630         phi_t = tf.gather(phi_t, valid_indices, axis=2)
   3631         # [max_depth, num_targets, num_sources, max_num_paths]
   3632         objects = tf.gather(objects, valid_indices, axis=3)
   3633         mask = tf.gather(mask, valid_indices, axis=3)
   3634         total_distance = tf.gather(total_distance, valid_indices, axis=3)
   3635         # [max_depth, num_targets, num_sources, max_num_paths, 3]
   3636         normals = tf.gather(normals, valid_indices, axis=3)
   3637         k_r = tf.gather(k_r, valid_indices, axis=3)
   3638         k_i = tf.gather(k_i, valid_indices, axis=3)
   3639         vertices = tf.gather(vertices, valid_indices, axis=3)
   3640 
   3641         paths.theta_t = theta_t
   3642         paths.phi_t = phi_t
   3643         paths.vertices = vertices
   3644         paths.objects = objects
   3645 
   3646         paths_tmp.scat_prefix_mask = mask
   3647         paths_tmp.k_i = k_i
   3648         paths_tmp.k_r = k_r
   3649         paths_tmp.k_tx = k_i[0]
   3650         paths_tmp.total_distance = total_distance
   3651         paths_tmp.normals = normals
   3652 
   3653         return paths, paths_tmp
   3654 
   3655     def _scat_prefixes_2_paths(self, paths, paths_tmp):
   3656         """
   3657         Extracts valid prefixes as invidual paths.
   3658 
   3659         Input
   3660         ------
   3661         paths : :class:`~sionna.rt.Paths`
   3662             Structure storing the scattered paths.
   3663 
   3664         paths_tmp : :class:`~sionna.rt.PathsTmpData`
   3665             Addtional quantities required for paths computation
   3666 
   3667         Output
   3668         -------
   3669         paths : :class:`~sionna.rt.Paths`
   3670             Updates paths.
   3671 
   3672         paths_tmp : :class:`~sionna.rt.PathsTmpData`
   3673             Updated addtional quantities required for paths computation
   3674         """
   3675 
   3676         # [max_depth, num_targets, num_sources, max_num_paths]
   3677         prefix_mask = paths_tmp.scat_prefix_mask
   3678         prefix_mask_int = tf.cast(prefix_mask, tf.int32)
   3679         # [max_depth, num_targets, num_sources, max_num_paths, 3]
   3680         prefix_vertices = paths.vertices
   3681         # [max_depth, num_targets, num_sources, max_num_paths]
   3682         prefix_objects = paths.objects
   3683         # [num_targets, num_sources, max_num_paths]
   3684         prefix_theta_t = paths.theta_t
   3685         # [num_targets, num_sources, max_num_paths]
   3686         prefix_phi_t = paths.phi_t
   3687         # [max_depth, num_targets, num_sources, num_paths, 3]
   3688         prefix_normals = paths_tmp.normals
   3689         # [max_depth + 1, num_targets, num_sources, num_paths, 3]
   3690         prefix_k_i = paths_tmp.k_i
   3691         # [max_depth, num_targets, num_sources, num_paths, 3]
   3692         prefix_k_r = paths_tmp.k_r
   3693         # [max_depth, num_targets, num_sources, num_paths]
   3694         prefix_distances = paths_tmp.total_distance
   3695         # [num_targets, num_sources, max_num_paths, 3]
   3696         prefix_ktx = paths_tmp.k_tx
   3697 
   3698         max_depth = tf.shape(prefix_mask)[0]
   3699         max_depth64 = tf.cast(max_depth, tf.int64)
   3700         num_targets = tf.shape(prefix_mask)[1]
   3701         num_sources = tf.shape(prefix_mask)[2]
   3702 
   3703         # Number of paths for each link and depth
   3704         # [max_depth, num_targets, num_sources]
   3705         paths_count = tf.reduce_sum(prefix_mask_int, axis=3)
   3706         # Maximum number of paths for each depth over all the links
   3707         # [max_depth]
   3708         path_count_depth = tf.reduce_max(paths_count, axis=(1,2))
   3709         # Upper bound on the total number of paths
   3710         # ()
   3711         max_num_paths = tf.reduce_sum(path_count_depth)
   3712 
   3713         # [num_valid_paths, 4]
   3714         gather_indices = tf.where(prefix_mask)
   3715         # To build the indices of the paths in the tensor with optimized size,
   3716         # the path dimension is indexed by counting the valid path in the order
   3717         # in which they appear
   3718         # [max_depth, num_targets, num_sources, num_paths]
   3719         path_indices = tf.cumsum(prefix_mask_int, axis=-1)
   3720         # [num_valid_paths, 4]
   3721         path_indices = tf.gather_nd(path_indices, gather_indices) - 1
   3722         scatter_indices = tf.transpose(gather_indices, [1,0])
   3723         if not tf.size(scatter_indices) == 0:
   3724             scatter_indices = tf.tensor_scatter_nd_update(scatter_indices,
   3725                                 [[3]], [path_indices])
   3726         # [num_valid_paths, 3]
   3727         scatter_indices = tf.transpose(scatter_indices, [1,0])
   3728 
   3729         # Create the final tensors to update
   3730         # [num_targets, num_sources, max_num_paths]
   3731         mask = tf.fill([num_targets, num_sources, max_num_paths], False)
   3732         # This tensor is created transposed as paths
   3733         # are added with all the objects hit along the paths.
   3734         # [num_targets, num_sources, max_num_paths, max_depth, 3]
   3735         vertices = tf.zeros([num_targets, num_sources, max_num_paths, max_depth,
   3736                              3], self._rdtype)
   3737         # Last vertices that were hit
   3738         # [num_targets, num_sources, max_num_paths, 3]
   3739         last_vertices = tf.zeros([num_targets, num_sources, max_num_paths, 3],
   3740                                  self._rdtype)
   3741         # Objects that were hit. This tensor is created transposed as paths
   3742         # are added with all the objects hit along the paths.
   3743         # [num_targets, num_sources, max_num_paths, max_depth]
   3744         objects = tf.fill([num_targets, num_sources, max_num_paths, max_depth],
   3745                           -1)
   3746         # Last objects that were hit
   3747         # [num_targets, num_sources, max_num_paths]
   3748         last_objects = tf.fill([num_targets, num_sources, max_num_paths], -1)
   3749         # Angles of departure
   3750         # [num_targets, num_sources, max_num_paths]
   3751         theta_t = tf.zeros([num_targets, num_sources, max_num_paths],
   3752                            self._rdtype)
   3753         # [num_targets, num_sources, max_num_paths]
   3754         phi_t = tf.zeros([num_targets, num_sources, max_num_paths],
   3755                          self._rdtype)
   3756         # Normal to the last intersected objects
   3757         # [num_targets, num_sources, max_num_paths, 3]
   3758         last_normals = tf.zeros([num_targets, num_sources, max_num_paths, 3],
   3759                                 self._rdtype)
   3760         # Direction of incidence at the last interaction point
   3761         # [num_targets, num_sources, max_num_paths, 3]
   3762         last_k_i = tf.zeros([num_targets, num_sources, max_num_paths, 3],
   3763                                 self._rdtype)
   3764         # Distance from the sources to the last interaction point
   3765         # [num_targets, num_sources, max_num_paths]
   3766         last_distance = tf.zeros([num_targets, num_sources, max_num_paths],
   3767                                  self._rdtype)
   3768         # [num_targets, num_sources, max_num_paths, 3]
   3769         k_tx = tf.zeros([num_targets, num_sources, max_num_paths, 3],
   3770                         self._rdtype)
   3771         # Normals at the intersection points
   3772         # [num_targets, num_sources, max_num_paths, max_depth, 3]
   3773         normals = tf.zeros([num_targets, num_sources, max_num_paths,
   3774                             max_depth, 3], self._rdtype)
   3775         # Direction of reflection at intersection points
   3776         # [num_targets, num_sources, max_num_paths, max_depth, 3]
   3777         k_r = tf.zeros([num_targets, num_sources, max_num_paths, max_depth, 3],
   3778                        self._rdtype)
   3779         # Direction of incidence at intersection points
   3780         # [num_targets, num_sources, max_num_paths, max_depth+1, 3]
   3781         k_i = tf.zeros([num_targets, num_sources, max_num_paths, max_depth+1,3],
   3782                        self._rdtype)
   3783 
   3784         # Need to transpose these tensors in order to gather paths from them
   3785         # with all the interactions, i.e., extract the entire "max_depth"
   3786         # dimension.
   3787         # [max_depth, num_targets, num_sources, max_num_paths]
   3788         prefix_objects_tp = tf.transpose(prefix_objects, [1,2,3,0])
   3789         # [num_targets, num_sources, max_num_paths, max_depth, 3]
   3790         prefix_vertices_tp = tf.transpose(prefix_vertices, [1,2,3,0,4])
   3791         # [num_targets, num_sources, max_num_paths, max_depth, 3]
   3792         normals_tp = tf.transpose(prefix_normals, [1,2,3,0,4])
   3793         # [num_targets, num_sources, max_num_paths, max_depth, 3]
   3794         k_i_tp = tf.transpose(prefix_k_i, [1,2,3,0,4])
   3795         # [num_targets, num_sources, max_num_paths, max_depth, 3]
   3796         k_r_tp = tf.transpose(prefix_k_r, [1,2,3,0,4])
   3797 
   3798         # We sequentially add the prefixes for each depth value.
   3799         # To avoid overwriting the paths scattered at the previous
   3800         # iterations, we incremdent the path index by the maximum number
   3801         # of paths over all links, cumulated over the iterations.
   3802         path_ind_incr = 0
   3803         for depth in tf.range(max_depth, dtype=tf.int64):
   3804             # Indices of valid paths with depth d
   3805             # [num_valid_paths with depth=depth, 4]
   3806             gather_indices_ = tf.gather(gather_indices,
   3807                                 tf.where(gather_indices[:,0] == depth)[:,0],
   3808                                 axis=0)
   3809             # Depth is not needed for some tensors
   3810             # [num_valid_paths with depth=depth, 3]
   3811             gather_indices_nd_ = gather_indices_[:,1:]
   3812 
   3813             # Indices for scattering the results in the target tensor
   3814             # [num_valid_paths with depth=depth, 4]
   3815             scatter_indices_ = tf.gather(scatter_indices,
   3816                                 tf.where(scatter_indices[:,0] == depth)[:,0],
   3817                                 axis=0)
   3818 
   3819 
   3820             # [1, 4]
   3821             path_ind_incr_ = tf.cast([0, 0, 0, path_ind_incr], tf.int64)
   3822             # [num_valid_paths with depth=depth, 4]
   3823             scatter_indices_ = scatter_indices_ + path_ind_incr_
   3824             # Depth is not needed for some tensors
   3825             # [num_valid_paths with depth=depth, 3]
   3826             scatter_indices_nd_ = scatter_indices_[:,1:]
   3827             # Prepare for next iteration
   3828             path_ind_incr = path_ind_incr + path_count_depth[depth]
   3829 
   3830             # Update the tensors
   3831 
   3832             # Mask
   3833             # [num_valid_paths with depth=depth]
   3834             prefix_mask_ = tf.fill([tf.shape(scatter_indices_nd_)[0]], True)
   3835             # [num_targets, num_sources, max_num_paths]
   3836             mask = tf.tensor_scatter_nd_update(mask, scatter_indices_nd_,
   3837                                                prefix_mask_)
   3838 
   3839             # Vertices
   3840             prefix_vertices_ = tf.gather_nd(prefix_vertices_tp,
   3841                                             gather_indices_nd_)
   3842             # [num_targets, num_sources, max_num_paths, max_depth, 3]
   3843             vertices = tf.tensor_scatter_nd_update(vertices,
   3844                                                    scatter_indices_nd_,
   3845                                                    prefix_vertices_)
   3846 
   3847             # Last vertex
   3848             prefix_vertex = tf.gather_nd(prefix_vertices, gather_indices_)
   3849             # [num_targets, num_sources, max_num_paths, 3]
   3850             last_vertices = tf.tensor_scatter_nd_update(last_vertices,
   3851                                                         scatter_indices_nd_,
   3852                                                         prefix_vertex)
   3853 
   3854             # Objects
   3855             # [num_paths, max_depth]
   3856             objects_ = tf.gather_nd(prefix_objects_tp, gather_indices_nd_)
   3857             # Only keep the prefix of length depth
   3858             # [num_paths, depth]
   3859             objects_ = objects_[:,:depth+1]
   3860             # [num_paths, max_depth]
   3861             objects_ = tf.pad(objects_, [[0,0],[0,max_depth64-depth-1]],
   3862                               constant_values=-1)
   3863             # [num_targets, num_sources, max_num_paths, max_depth]
   3864             objects = tf.tensor_scatter_nd_update(objects, scatter_indices_nd_,
   3865                                                   objects_)
   3866 
   3867             # Normals at intersection points
   3868             normals_ = tf.gather_nd(normals_tp, gather_indices_nd_)
   3869             # [num_targets, num_sources, max_num_paths, max_depth, 3]
   3870             normals = tf.tensor_scatter_nd_update(normals, scatter_indices_nd_,
   3871                                                   normals_)
   3872 
   3873             # Direction of incidence at intersection points
   3874             k_i_ = tf.gather_nd(k_i_tp, gather_indices_nd_)
   3875             # [num_targets, num_sources, max_num_paths, max_depth+1, 3]
   3876             k_i = tf.tensor_scatter_nd_update(k_i, scatter_indices_nd_, k_i_)
   3877 
   3878             # Direction of reflection at intersection points
   3879             k_r_ = tf.gather_nd(k_r_tp, gather_indices_nd_)
   3880             # [num_targets, num_sources, max_num_paths, max_depth, 3]
   3881             k_r = tf.tensor_scatter_nd_update(k_r, scatter_indices_nd_, k_r_)
   3882 
   3883             # Last hit objects
   3884             objects_ = tf.gather_nd(prefix_objects, gather_indices_)
   3885             # [num_targets, num_sources, max_num_paths]
   3886             last_objects = tf.tensor_scatter_nd_update(last_objects,
   3887                                                        scatter_indices_nd_,
   3888                                                        objects_)
   3889 
   3890             # Azimuth of departure
   3891             phi_t_ = tf.gather_nd(prefix_phi_t, gather_indices_nd_)
   3892             # [num_targets, num_sources, max_num_paths]
   3893             phi_t = tf.tensor_scatter_nd_update(phi_t, scatter_indices_nd_,
   3894                                                 phi_t_)
   3895 
   3896             # Elevation of departure
   3897             theta_t_ = tf.gather_nd(prefix_theta_t, gather_indices_nd_)
   3898             # [num_targets, num_sources, max_num_paths]
   3899             theta_t = tf.tensor_scatter_nd_update(theta_t, scatter_indices_nd_,
   3900                                                   theta_t_)
   3901 
   3902             # Normals at the last intersected object
   3903             normals_ = tf.gather_nd(prefix_normals, gather_indices_)
   3904             # [num_targets, num_sources, max_num_paths, 3]
   3905             last_normals = tf.tensor_scatter_nd_update(last_normals,
   3906                                                 scatter_indices_nd_, normals_)
   3907 
   3908             # Direction of incidence at the last interaction point
   3909             k_i_ = tf.gather_nd(prefix_k_i, gather_indices_)
   3910             # [num_targets, num_sources, max_num_paths, 3]
   3911             last_k_i = tf.tensor_scatter_nd_update(last_k_i,
   3912                                                    scatter_indices_nd_,
   3913                                                    k_i_)
   3914 
   3915             # Distance from the sources to the last interaction point
   3916             last_dist_ = tf.gather_nd(prefix_distances, gather_indices_)
   3917             # [num_targets, num_sources, max_num_paths]
   3918             last_distance = tf.tensor_scatter_nd_update(last_distance,
   3919                                                         scatter_indices_nd_,
   3920                                                         last_dist_)
   3921 
   3922             # Direction of tx
   3923             k_tx_ = tf.gather_nd(prefix_ktx, gather_indices_nd_)
   3924             # [num_targets, num_sources, max_num_paths, 3]
   3925             k_tx = tf.tensor_scatter_nd_update(k_tx, scatter_indices_nd_, k_tx_)
   3926 
   3927         # Computes the angles of arrivals, direction of the scattered field,
   3928         # and distance from the scattering point to the targets
   3929         # [num_targets, 3]
   3930         targets = paths.targets
   3931         # [num_targets, 1, 1, 3]
   3932         targets = insert_dims(targets, 2, 1)
   3933         # k_s : [num_targets, num_sources, max_num_paths, 3]
   3934         # scat_2_target_dist : [num_targets, num_sources, max_num_paths]
   3935         k_s,scat_2_target_dist = normalize(targets - last_vertices)
   3936         # Angles of arrivales
   3937         # theta_r, phi_r : [num_targets, num_sources, max_num_paths]
   3938         theta_r, phi_r = theta_phi_from_unit_vec(-k_s)
   3939         # Compute the delays
   3940         # [num_targets, num_sources, max_num_paths]
   3941         tau = (last_distance + scat_2_target_dist)/SPEED_OF_LIGHT
   3942         # [num_targets, num_sources, max_num_paths]
   3943         tau = tf.where(mask, tau, -tf.ones_like(tau))
   3944         # [max_depth, num_targets, num_sources, max_num_paths]
   3945         objects = tf.transpose(objects, [3, 0, 1, 2])
   3946         vertices = tf.transpose(vertices, [3, 0, 1, 2, 4])
   3947         normals = tf.transpose(normals, [3, 0, 1, 2, 4])
   3948         k_i = tf.transpose(k_i, [3, 0, 1, 2, 4])
   3949         k_r = tf.transpose(k_r, [3, 0, 1, 2, 4])
   3950 
   3951         paths.mask = mask
   3952         paths.vertices = vertices
   3953         paths.objects = objects
   3954         paths.tau = tau
   3955         paths.phi_t = phi_t
   3956         paths.theta_t = theta_t
   3957         paths.phi_r = phi_r
   3958         paths.theta_r = theta_r
   3959 
   3960         paths_tmp.scat_last_objects = last_objects
   3961         paths_tmp.scat_last_normals = last_normals
   3962         paths_tmp.scat_last_k_i = last_k_i
   3963         paths_tmp.scat_last_vertices = last_vertices
   3964         paths_tmp.scat_src_2_last_int_dist = last_distance
   3965         paths_tmp.scat_k_s = k_s
   3966         paths_tmp.scat_2_target_dist = scat_2_target_dist
   3967         paths_tmp.k_tx = k_tx
   3968         paths_tmp.k_rx = -k_s
   3969         paths_tmp.normals = normals
   3970         paths_tmp.k_i = k_i
   3971         paths_tmp.k_r = k_r
   3972         paths_tmp.total_distance = scat_2_target_dist + last_distance
   3973 
   3974         return paths, paths_tmp
   3975 
   3976     ##################################################################
   3977     # Methods used for computing paths involving RIS
   3978     ##################################################################
   3979 
   3980     def _ris_paths(self, paths, paths_tmp, ris_objects):
   3981         sources = paths.sources
   3982         targets = paths.targets
   3983         num_sources = tf.shape(sources)[0]
   3984         num_targets = tf.shape(targets)[0]
   3985 
   3986         # Concatenate the cell positions of all RIS
   3987         # [num_ris*num_cells, 3]
   3988         cells = [r.cell_world_positions for r in self._scene.ris.values()]
   3989         cells = tf.concat(cells, axis=0)
   3990 
   3991         # Broadcast cell positions to vertices
   3992         # [max_depths=1, num_targets, num_sources, max_num_paths=num_cells, 3]
   3993         vertices = tf.reshape(cells, [1, 1, 1, -1, 3])
   3994         vertices = tf.repeat(vertices, num_targets, axis=1)
   3995         vertices = tf.repeat(vertices, num_sources, axis=2)
   3996         paths.vertices = vertices
   3997 
   3998         # Create object tensor
   3999         objects = []
   4000         for obj in self._scene.ris.values():
   4001             objects.extend([obj.object_id]*obj.num_cells)
   4002         objects = tf.cast(objects, tf.int32)
   4003         objects = tf.reshape(objects, [1, 1, 1, -1])
   4004         objects = tf.repeat(objects, num_targets, axis=1)
   4005         objects = tf.repeat(objects, num_sources, axis=2)
   4006         paths.objects = objects
   4007 
   4008         # Compute directions, angles, etc
   4009         paths, paths_tmp = self._compute_directions_distances_delays_angles(
   4010                                                     paths, paths_tmp, False)
   4011 
   4012         # Compute TX-RIS and RIS-RX rays
   4013         # Directions
   4014         # [num_targets, num_sources, max_num_paths, 3]
   4015         d_tx_ris = paths_tmp.k_tx
   4016         d_ris_rx = -paths_tmp.k_rx
   4017 
   4018         # Lengths
   4019         # [num_targets, num_sources, max_num_paths]
   4020         maxt_tx_ris = paths_tmp.distances[0]
   4021         maxt_ris_rx = paths_tmp.distances[1]
   4022 
   4023         # Origins
   4024         # [num_targets, num_sources, max_num_paths, 3]
   4025         o_tx_ris = tf.expand_dims(tf.expand_dims(sources, axis=0), axis=2)
   4026         o_tx_ris = tf.broadcast_to(o_tx_ris, d_tx_ris.shape)
   4027         o_ris_rx = vertices[0]
   4028 
   4029         # Test obstruction of rays
   4030         mask_tx_ris = self._test_obstruction(tf.reshape(o_tx_ris, [-1,3]),
   4031                                              tf.reshape(d_tx_ris, [-1,3]),
   4032                                              tf.reshape(maxt_tx_ris, [-1]),
   4033                                              ris_objects)
   4034 
   4035         mask_ris_rx = self._test_obstruction(tf.reshape(o_ris_rx, [-1,3]),
   4036                                              tf.reshape(d_ris_rx, [-1,3]),
   4037                                              tf.reshape(maxt_ris_rx, [-1]),
   4038                                              ris_objects)
   4039 
   4040         mask_ris = tf.logical_or(mask_tx_ris, mask_ris_rx)
   4041         mask_ris = tf.reshape(mask_ris, [num_targets, num_sources, -1])
   4042         mask_ris = tf.logical_not(mask_ris)
   4043 
   4044         # Only consider paths that have a positive angle with the RIS normal
   4045         # Create tensor with RIS normals for all paths
   4046         n_hat = []
   4047         for r in self._scene.ris.values():
   4048             n_hat.append(tf.repeat(r.world_normal[tf.newaxis,...],
   4049                                    r.num_cells, axis=0))
   4050         n_hat = tf.concat(n_hat, axis=0)
   4051         n_hat = n_hat[tf.newaxis, tf.newaxis,...]
   4052 
   4053         # Compute dot products between RIS normals and incoming/outgoing rays
   4054         cos_theta_i = dot(-d_tx_ris, n_hat)
   4055         cos_theta_m = dot(d_ris_rx, n_hat)
   4056 
   4057         # Store dot products for later field computation
   4058         paths_tmp.cos_theta_i = cos_theta_i
   4059         paths_tmp.cos_theta_m = cos_theta_m
   4060 
   4061         # Only keep paths with positive dot products
   4062         mask_ris = tf.logical_and(mask_ris, tf.greater(cos_theta_i, 0.))
   4063         mask_ris = tf.logical_and(mask_ris, tf.greater(cos_theta_m, 0.))
   4064         paths.mask = mask_ris
   4065         paths.targets_sources_mask = paths.mask
   4066 
   4067         # Set delays to -1 for masked paths
   4068         paths.tau = tf.where(mask_ris, paths.tau, tf.cast(-1, self._rdtype))
   4069 
   4070         return paths, paths_tmp
   4071 
   4072     def _ris_transition_matrices(self, ris_paths, ris_paths_tmp):
   4073 
   4074         # Compute spatial modulation coefficients for all RIS
   4075         sc = [tf.reduce_sum(r(), axis=0) for r in self._scene.ris.values()]
   4076         sc = tf.concat(sc, axis=0)
   4077         sc = sc[tf.newaxis, tf.newaxis,...]
   4078         coef = (1+ris_paths_tmp.cos_theta_i)*(1+ris_paths_tmp.cos_theta_m)
   4079         coef *= tf.cast(3*self._scene.wavelength/16/PI, self._rdtype)
   4080         coef /= tf.reduce_prod(ris_paths_tmp.distances, axis=0)
   4081         coef = tf.complex(coef, tf.cast(0, self._rdtype))
   4082         coef *= sc
   4083 
   4084         # Set coefficients of masked paths to zero
   4085         coef = tf.where(ris_paths.mask, coef, tf.cast(0, coef.dtype))
   4086 
   4087         # Create transition matrices from coefficients
   4088         # We assume here that the polarization remains unchanged, i.e.,
   4089         # The incoming field is already decomposed in theta/phi components
   4090         # and the outgoing field is represented in theta/phi components
   4091         coef = coef[...,tf.newaxis,tf.newaxis]
   4092         ris_mat_t = coef*tf.eye(2, batch_shape=[1,1,1], dtype=self._dtype)
   4093 
   4094         return ris_mat_t
   4095 
   4096 
   4097 
   4098     ##################################################################
   4099     # Utilities
   4100     ##################################################################
   4101 
   4102     def _compute_directions_distances_delays_angles(self, paths, paths_tmp,
   4103                                                     scattering):
   4104         # pylint: disable=line-too-long
   4105         r"""
   4106         Computes:
   4107         - The direction of incidence and departure at every interaction points
   4108         ``k_i`` and ``k_r``
   4109         - The length of each path segment ``distances``
   4110         - The delays of each path
   4111         - The angles of departure (``theta_t``, ``phi_t``) and arrival
   4112         (``theta_r``, ``phi_r``)
   4113 
   4114         Input
   4115         ------
   4116         paths : :class:`~sionna.rt.Paths`
   4117             Paths to update
   4118 
   4119         paths_tmp : :class:`~sionna.rt.PathsTmpData`
   4120             Addtional quantities required for paths computation
   4121 
   4122         scattering : bool
   4123             Set to `True` computing the scattered paths.
   4124 
   4125         Output
   4126         -------
   4127         paths : :class:`~sionna.rt.Paths`
   4128             Updated paths
   4129 
   4130         paths_tmp : :class:`~sionna.rt.PathsTmpData`
   4131             Updated addtional quantities required for paths computation
   4132         """
   4133 
   4134         objects = paths.objects
   4135         vertices = paths.vertices
   4136         sources = paths.sources
   4137         targets = paths.targets
   4138         if scattering:
   4139             mask = paths_tmp.scat_prefix_mask
   4140         else:
   4141             mask = paths.mask
   4142 
   4143         # Maximum depth
   4144         max_depth = tf.shape(vertices)[0]
   4145 
   4146         # Flag that indicates if a ray is valid
   4147         # [max_depth, num_targets, num_sources, max_num_paths]
   4148         valid_ray = tf.not_equal(objects, -1)
   4149 
   4150         # Vertices updated with the sources and targets
   4151         # [1, num_sources, 1, 3]
   4152         sources = tf.expand_dims(tf.expand_dims(sources, axis=0), axis=2)
   4153         # [num_targets, num_sources, max_num_paths, 3]
   4154         sources = tf.broadcast_to(sources, tf.shape(vertices)[1:])
   4155         # [1, num_targets, num_sources, max_num_paths, 3]
   4156         sources = tf.expand_dims(sources, axis=0)
   4157         # [1 + max_depth, num_targets, num_sources, max_num_paths, 3]
   4158         vertices = tf.concat([sources, vertices], axis=0)
   4159         # For the targets, we need to account for the paths having different
   4160         # depths.
   4161         # Pad vertices with dummy values to create the required extra depth
   4162         # [1 + max_depth + 1, num_targets, num_sources, max_num_paths, 3]
   4163         vertices = tf.pad(vertices, [[0,1],[0,0],[0,0],[0,0],[0,0]])
   4164         # [num_targets, 1, 1, 3]
   4165         targets = tf.expand_dims(tf.expand_dims(targets, axis=1), axis=2)
   4166         # [num_targets, num_sources, max_num_paths, 3]
   4167         targets = tf.broadcast_to(targets, tf.shape(vertices)[1:])
   4168 
   4169         #  [max_depth, num_targets, num_sources, max_num_paths]
   4170         target_indices = tf.cast(valid_ray, tf.int64)
   4171         #  [num_targets, num_sources, max_num_paths]
   4172         target_indices = tf.reduce_sum(target_indices, axis=0) + 1
   4173         # [num_targets*num_sources*max_num_paths]
   4174         target_indices = tf.reshape(target_indices, [-1,1])
   4175         # Indices of all (target, source,paths) entries
   4176         # [num_targets*num_sources*max_num_paths, 3]
   4177         target_indices_ = tf.where(tf.fill(tf.shape(vertices)[1:4], True))
   4178         # Indices of all entries in vertices
   4179         # [num_targets*num_sources*max_num_paths, 4]
   4180         target_indices = tf.concat([target_indices, target_indices_], axis=1)
   4181         # Reshape targets
   4182         # vertices : [max_depth + 1, num_targets, num_sources, max_num_paths, 3]
   4183         targets = tf.reshape(targets, [-1,3])
   4184         vertices = tf.tensor_scatter_nd_update(vertices, target_indices,
   4185                                                 targets)
   4186         # Direction of arrivals (k_i)
   4187         # The last item (k_i[max_depth]) correspond to the direction of arrival
   4188         # at the target. Therefore, k_i is a tensor of length `max_depth + 1`,
   4189         # where `max_depth` is the number of maximum interaction (which could be
   4190         # zero if only LoS is requested).
   4191         # k_i : [max_depth + 1, num_targets, num_sources, max_num_paths, 3]
   4192         # ray_lengths : [max_depth + 1, num_targets, num_sources, max_num_paths]
   4193         k_i = tf.roll(vertices, -1, axis=0) - vertices
   4194         k_i,ray_lengths = normalize(k_i)
   4195         k_i = k_i[:max_depth+1]
   4196         ray_lengths = ray_lengths[:max_depth+1]
   4197 
   4198         # Direction of departures (k_r) at interaction points.
   4199         # We do not need the direction of departure at the source, as it
   4200         # is the same as k_i[0]. Therefore `k_r` only stores the directions of
   4201         # departures at the `max_depth` interaction points.
   4202         # [max_depth, num_targets, num_sources, max_num_paths, 3]
   4203         k_r = tf.roll(vertices, -2, axis=0) - tf.roll(vertices, -1, axis=0)
   4204         k_r,_ = normalize(k_r)
   4205         k_r = k_r[:max_depth]
   4206 
   4207         # Compute the distances
   4208         # [max_depth, num_targets, num_sources, max_num_paths]
   4209         lengths_mask = tf.cast(valid_ray, self._rdtype)
   4210         # First ray is always valid (LoS)
   4211         # [1 + max_depth, num_targets, num_sources, max_num_paths]
   4212         lengths_mask = tf.pad(lengths_mask, [[1,0],[0,0],[0,0],[0,0]],
   4213                                 constant_values=tf.ones((),self._rdtype))
   4214         # Compute path distance
   4215         # [1 + max_depth, num_targets, num_sources, max_num_paths]
   4216         distances = lengths_mask*ray_lengths
   4217 
   4218         # Propagation delay [s]
   4219         # Total length of the paths
   4220         if scattering:
   4221             # Distances of every path prefix, not including the one connecting
   4222             # to the target
   4223             # [max_depth, num_targets, num_sources, max_num_paths]
   4224             total_distance = tf.cumsum(distances[:max_depth], axis=0)
   4225         else:
   4226             # [num_targets, num_sources, max_num_paths]
   4227             total_distance = tf.reduce_sum(distances, axis=0)
   4228             # [num_targets, num_sources, max_num_paths]
   4229             tau = total_distance / SPEED_OF_LIGHT
   4230 
   4231         # Compute angles of departures and arrival
   4232         # theta_t, phi_t: [num_targets, num_sources, max_num_paths]
   4233         theta_t, phi_t = theta_phi_from_unit_vec(k_i[0])
   4234         # In the case of scattering, the angles of arrival are not computed
   4235         # by this function
   4236         if not scattering:
   4237             # Depth of the rays
   4238             # [num_targets, num_sources, max_num_paths]
   4239             ray_depth = tf.reduce_sum(tf.cast(valid_ray, tf.int32), axis=0)
   4240             k_rx = -tf.gather(tf.transpose(k_i, [1,2,3,0,4]), ray_depth,
   4241                                     batch_dims=3, axis=3)
   4242             # theta_r, phi_r: [num_targets, num_sources, max_num_paths]
   4243             theta_r, phi_r = theta_phi_from_unit_vec(k_rx)
   4244 
   4245             if paths.types is not Paths.RIS:
   4246                 # Remove duplicated paths.
   4247                 # Paths intersecting an edge belonging to two different
   4248                 # triangles can be considered twice.
   4249                 # Note that this is rare, as intersections rarely occur on
   4250                 # edges.
   4251                 # Paths are considered different if they have different
   4252                 # angles of departure, angles of arrival, and total length.
   4253                 # [num_targets, num_sources, max_num_paths, 5]
   4254                 sim = tf.stack([theta_t, phi_t, theta_r, phi_r, total_distance],
   4255                                axis=3)
   4256                 # [num_targets, num_sources, max_num_paths, max_num_paths, 5]
   4257                 sim = tf.expand_dims(sim, axis=2) - tf.expand_dims(sim, axis=3)
   4258                 # [num_targets, num_sources, max_num_paths, max_num_paths]
   4259                 sim = tf.reduce_sum(tf.square(sim), axis=4)
   4260                 sim = tf.equal(sim, tf.zeros_like(sim))
   4261                 # Keep only the paths with no duplicates.
   4262                 # If many paths are identical, keep the one with the highest
   4263                 # index.
   4264                 # [num_targets, num_sources, max_num_paths, max_num_paths]
   4265                 sim = tf.logical_and(tf.linalg.band_part(sim, 0, -1),
   4266                                     ~tf.eye(tf.shape(sim)[-1],
   4267                                             dtype=tf.bool,
   4268                                             batch_shape=tf.shape(sim)[:2]))
   4269                 sim = tf.logical_and(sim, tf.expand_dims(mask, axis=-2))
   4270                 # [num_targets, num_sources, max_num_paths]
   4271                 uniques = tf.reduce_all(~sim, axis=3)
   4272                 # Keep only the unique paths
   4273                 # [num_targets, num_sources, max_num_paths]
   4274                 mask = tf.logical_and(uniques, mask)
   4275 
   4276                 # Setting -1 for delays corresponding to non-valid paths
   4277                 # [num_targets, num_sources, max_num_paths]
   4278                 tau = tf.where(mask, tau, -tf.ones_like(tau))
   4279 
   4280         # Updates the object storing the paths
   4281         if not scattering:
   4282             if paths.types is not Paths.RIS:
   4283                 paths.mask = mask
   4284             if paths.types is not Paths.RIS:
   4285                 paths.mask = mask
   4286             paths.tau = tau
   4287             # In the case of scattering, the angles of arrival are not computed
   4288             # by this function
   4289             paths.theta_r = theta_r
   4290             paths.phi_r = phi_r
   4291             paths_tmp.k_rx = k_rx
   4292         else:
   4293             paths_tmp.scat_prefix_mask = mask
   4294         paths.theta_t = theta_t
   4295         paths.phi_t = phi_t
   4296         paths_tmp.k_i = k_i
   4297         paths_tmp.k_r = k_r
   4298         paths_tmp.k_tx = k_i[0]
   4299         paths_tmp.total_distance = total_distance
   4300         if paths.types is Paths.RIS:
   4301             paths_tmp.distances = distances
   4302         if paths.types is Paths.RIS:
   4303             paths_tmp.distances = distances
   4304 
   4305         return paths, paths_tmp
   4306 
   4307     def _compute_doppler_shifts(self, paths, paths_tmp, velocity):
   4308         # pylint: disable=line-too-long
   4309         """
   4310         Computes the Doppler shift resulting from the movement
   4311         of objects in the scene for every path.
   4312 
   4313         The Doppler shift resulting from the movement of the
   4314         transmitter and receiver are added later when the function
   4315         :method:`~sionna.rt.Paths.apply_doppler` is called.
   4316 
   4317         Input
   4318         ------
   4319         paths : :class:`~sionna.rt.Paths`
   4320             Paths to update
   4321 
   4322         paths_tmp : :class:`~sionna.rt.PathsTmpData`
   4323             Addtional quantities required for paths computation
   4324 
   4325         velocity : [num_shapes, 3]
   4326             Velocity vectors of all objects in the scene
   4327 
   4328         Output
   4329         ------
   4330         doppler : [num_targets, num_sources, max_num_paths]
   4331             Doppler shifts for all paths due to the movement ob objects
   4332         """
   4333 
   4334         # Compute Doppler shift for every path segment
   4335         # Difference of outgoing and incoming direction vectors for every
   4336         # intersection point
   4337         # k_diff : [max_depth, num_targets, num_sources, max_num_paths, 3]
   4338         k_diff = paths_tmp.k_i[1:]-paths_tmp.k_i[:-1]
   4339 
   4340         # Get velocity for all involved objects of each path
   4341         objects_mask = paths.objects==-1
   4342         if paths.types==2:
   4343             # For diffracted paths, path.objects indicates wedge ids
   4344             # that we need to convert to object ids. Since each wedge
   4345             # consists of two objects, we simply pick the first.
   4346             # This assumes that both objects move at the same speed,
   4347             # which is justified as the wedge would otherwise be destroyed
   4348             valid_wedges_idx = tf.where(objects_mask, 0, paths.objects)
   4349             valid_objects = tf.gather(self._wedges_objects[:,0],
   4350                                       valid_wedges_idx, axis=0)
   4351         else:
   4352             valid_objects = tf.where(objects_mask, 0, paths.objects)
   4353         # [max_depth, num_targets, num_sources, max_num_paths, 3]
   4354         velocity = tf.gather(velocity, valid_objects, axis=0)
   4355 
   4356         # Compute Doppler shift per path
   4357         #[num_targets, num_sources, max_num_paths]
   4358         doppler = tf.reduce_sum(velocity*k_diff, axis=-1)
   4359         doppler = tf.where(objects_mask, tf.constant(0, doppler.dtype), doppler)
   4360         doppler = tf.reduce_sum(doppler, axis=0)
   4361         doppler /= self._scene.wavelength
   4362         return doppler
   4363 
   4364     def _get_tx_rx_rotation_matrices(self):
   4365         r"""
   4366         Computes and returns the rotation matrices for rotating according to
   4367         the orientations of the transmitters and receivers rotation matrices,
   4368 
   4369         Output
   4370         -------
   4371         rx_rot_mat : [num_rx, 3, 3], tf.float
   4372             Matrices for rotating according to the receivers orientations
   4373 
   4374         tx_rot_mat : [num_tx, 3, 3], tf.float
   4375             Matrices for rotating according to the receivers orientations
   4376         """
   4377 
   4378         transmitters = self._scene.transmitters.values()
   4379         receivers = self._scene.receivers.values()
   4380 
   4381         # Rotation matrices for transmitters
   4382         # [num_tx, 3]
   4383         tx_orientations = [tx.orientation for tx in transmitters]
   4384         tx_orientations = tf.stack(tx_orientations, axis=0)
   4385         # [num_tx, 3, 3]
   4386         tx_rot_mat = rotation_matrix(tx_orientations)
   4387 
   4388         # Rotation matrices for receivers
   4389         # [num_rx, 3]
   4390         rx_orientations = [rx.orientation for rx in receivers]
   4391         rx_orientations = tf.stack(rx_orientations, axis=0)
   4392         # [num_rx, 3, 3]
   4393         rx_rot_mat = rotation_matrix(rx_orientations)
   4394 
   4395         return rx_rot_mat, tx_rot_mat
   4396 
   4397     def _get_antennas_relative_positions(self, rx_rot_mat, tx_rot_mat):
   4398         r"""
   4399         Returns the positions of the antennas of the transmitters and receivers.
   4400         The positions are relative to the center of the radio devices, but
   4401         rotated to the GCS.
   4402 
   4403         Input
   4404         ------
   4405         rx_rot_mat : [num_rx, 3, 3], tf.float
   4406             Matrices for rotating according to the receivers orientations
   4407 
   4408         tx_rot_mat : [num_tx, 3, 3], tf.float
   4409             Matrices for rotating according to the receivers orientations
   4410 
   4411         Output
   4412         -------
   4413         rx_rel_ant_pos: [num_rx, rx_array_size, 3], tf.float
   4414             Relative positions of the receivers antennas
   4415 
   4416         tx_rel_ant_pos: [num_tx, rx_array_size, 3], tf.float
   4417             Relative positions of the transmitters antennas
   4418         """
   4419 
   4420         # Rotated position of the TX and RX antenna elements
   4421         # [1, tx_array_size, 3]
   4422         tx_rel_ant_pos = tf.expand_dims(self._scene.tx_array.positions, axis=0)
   4423         # [num_tx, 1, 3, 3]
   4424         tx_rot_mat = tf.expand_dims(tx_rot_mat, axis=1)
   4425         # [num_tx, tx_array_size, 3]
   4426         tx_rel_ant_pos = tf.linalg.matvec(tx_rot_mat, tx_rel_ant_pos)
   4427 
   4428         # [1, rx_array_size, 3]
   4429         rx_rel_ant_pos = tf.expand_dims(self._scene.rx_array.positions, axis=0)
   4430         # [num_rx, 1, 3, 3]
   4431         rx_rot_mat = tf.expand_dims(rx_rot_mat, axis=1)
   4432         # [num_tx, tx_array_size, 3]
   4433         rx_rel_ant_pos = tf.linalg.matvec(rx_rot_mat, rx_rel_ant_pos)
   4434 
   4435         return rx_rel_ant_pos, tx_rel_ant_pos
   4436 
   4437     def _apply_synthetic_array(self, rx_rot_mat, tx_rot_mat, paths, paths_tmp):
   4438         # pylint: disable=line-too-long
   4439         r"""
   4440         Applies the phase shifts to simulate the effect of a synthetic array
   4441         on a planar wave
   4442 
   4443         Input
   4444         ------
   4445         rx_rot_mat : [num_rx, 3, 3], tf.float
   4446             Matrices for rotating according to the receivers orientations
   4447 
   4448         tx_rot_mat : [num_tx, 3, 3], tf.float
   4449             Matrices for rotating according to the receivers orientations
   4450 
   4451         paths_tmp : :class:`~sionna.rt.PathsTmpData`
   4452             Addtional quantities required for paths computation
   4453 
   4454         Output
   4455         -------
   4456         paths : :class:`~sionna.rt.PathsTmpData`
   4457             Updated paths
   4458         """
   4459 
   4460         # [num_rx, num_rx_patterns, 1, num_tx, num_tx_patterns, 1,
   4461         #   max_num_paths]
   4462         a = paths.a
   4463         # [num_rx, num_tx, samples_per_tx, 3]
   4464         k_tx = paths_tmp.k_tx
   4465         # [num_tx, num_tx, samples_per_tx, 3]
   4466         k_rx = paths_tmp.k_rx
   4467 
   4468         two_pi = tf.cast(2.*PI, self._rdtype)
   4469 
   4470         # Relative positions of the antennas of the transmitters and receivers
   4471         # rx_rel_ant_pos: [num_rx, rx_array_size, 3], tf.float
   4472         #     Relative positions of the receivers antennas
   4473         # tx_rel_ant_pos: [num_tx, rx_array_size, 3], tf.float
   4474         #     Relative positions of the transmitters antennas
   4475         rx_rel_ant_pos, tx_rel_ant_pos =\
   4476             self._get_antennas_relative_positions(rx_rot_mat, tx_rot_mat)
   4477 
   4478         # Expand dims for broadcasting with antennas
   4479         # The receive vector is flipped as we need vectors that point away
   4480         # from the arrays.
   4481         # [num_rx, 1, 1, num_tx, 1, 1, max_num_paths, 3]
   4482         k_rx = insert_dims(insert_dims(k_rx, 2, 1), 2, 4)
   4483         k_tx = insert_dims(insert_dims(k_tx, 2, 1), 2, 4)
   4484         # Compute the synthetic phase shifts due to the antenna array
   4485         # Transmitter side
   4486         # Expand for broadcasting with receiver, receive antennas,
   4487         # paths
   4488         # [1, 1, 1, num_tx, tx_array_size, 3]
   4489         tx_rel_ant_pos = insert_dims(tx_rel_ant_pos, 3, axis=0)
   4490         # [1, 1, 1, num_tx, 1, tx_array_size, 1, 3]
   4491         tx_rel_ant_pos = tf.expand_dims(tf.expand_dims(tx_rel_ant_pos, axis=4),
   4492                                         axis=6)
   4493         # [num_rx, 1, 1, num_tx, 1, tx_array_size, max_num_paths]
   4494         tx_phase_shifts = dot(tx_rel_ant_pos, k_tx)
   4495         # Receiver side
   4496         # Expand for broadcasting with transmitter, transmit antennas,
   4497         # paths
   4498         # [num_rx, 1, rx_array_size, 1, 1, 1, 1, 3]
   4499         rx_rel_ant_pos = insert_dims(tf.expand_dims(rx_rel_ant_pos, axis=1),
   4500                                      4, axis=3)
   4501         # [num_rx, 1, rx_array_size, num_tx, 1, 1, 1, max_num_paths]
   4502         rx_phase_shifts = dot(rx_rel_ant_pos, k_rx)
   4503         # Total phase shift
   4504         # [num_rx, 1, rx_array_size, num_tx, 1, tx_array_size, max_num_paths]
   4505         phase_shifts = rx_phase_shifts + tx_phase_shifts
   4506         phase_shifts = two_pi*phase_shifts/self._scene.wavelength
   4507         # Apply the phase shifts
   4508         # Broadcast is not supported by TF for such high rank tensors.
   4509         # We therefore do it manually
   4510         # [num_rx, num_rx_patterns, rx_array_size, num_tx, num_tx_patterns,
   4511         #   tx_array_size, max_num_paths]
   4512         a = tf.tile(a, [1, 1, phase_shifts.shape[2], 1, 1,
   4513                         phase_shifts.shape[5], 1])
   4514         # [num_rx, num_rx_patterns, rx_array_size, num_tx, num_tx_patterns,
   4515         #   tx_array_size, max_num_paths]
   4516         a = a*tf.exp(tf.complex(tf.zeros_like(phase_shifts), phase_shifts))
   4517         a = flatten_dims(flatten_dims(a, 2, 1), 2, 3)
   4518 
   4519         return a
   4520 
   4521     def _compute_paths_coefficients(self, rx_rot_mat, tx_rot_mat, paths,
   4522                                     paths_tmp, num_samples,
   4523                                     scattering_coefficient, xpd_coefficient,
   4524                                     etas, alpha_r, alpha_i, lambda_,
   4525                                     scat_keep_prob, scat_random_phases):
   4526         # pylint: disable=line-too-long
   4527         r"""
   4528         Computes the paths coefficients.
   4529 
   4530         Input
   4531         ------
   4532         rx_rot_mat : [num_rx, 3, 3], tf.float
   4533             Matrices for rotating according to the receivers orientations
   4534 
   4535         tx_rot_mat : [num_tx, 3, 3], tf.float
   4536             Matrices for rotating according to the receivers orientations
   4537 
   4538         paths : :class:`~sionna.rt.Paths`
   4539             Paths to update
   4540 
   4541         paths_tmp : :class:`~sionna.rt.PathsTmpData`
   4542             Updated addtional quantities required for paths computation
   4543 
   4544         num_samples : int
   4545             Number of random rays to trace in order to generate candidates.
   4546             A large sample count may exhaust GPU memory.
   4547 
   4548         scattering_coefficient : [num_shapes], tf.float
   4549             Scattering coefficient :math:`S\in[0,1]` as defined in
   4550             :eq:`scattering_coefficient`.
   4551 
   4552         xpd_coefficient: [num_shapes], tf.float
   4553             Cross-polarization discrimination coefficient :math:`K_x\in[0,1]` as
   4554             defined in :eq:`xpd`.
   4555 
   4556         etas : [num_shapes], tf.complex
   4557             Complex relative permittivity :math:`\eta` :eq:`eta`
   4558 
   4559         alpha_r : [num_shapes], tf.int32
   4560             Parameter related to the width of the scattering lobe in the
   4561             direction of the specular reflection.
   4562 
   4563         alpha_i : [num_shapes], tf.int32
   4564             Parameter related to the width of the scattering lobe in the
   4565             incoming direction.
   4566 
   4567         lambda_ : [num_shapes], tf.float
   4568             Parameter determining the percentage of the diffusely
   4569             reflected energy in the lobe around the specular reflection.
   4570 
   4571         scat_keep_prob : float
   4572             Probability with which to keep scattered paths.
   4573             This is helpful to reduce the number of scattered paths computed,
   4574             which might be prohibitively high in some setup.
   4575             Must be in the range (0,1).
   4576 
   4577         scat_random_phases : bool
   4578             If set to `True` and if scattering is enabled, random uniform phase
   4579             shifts are added to the scattered paths.
   4580 
   4581         Output
   4582         ------
   4583         paths : :class:`~sionna.rt.Paths`
   4584             Updated paths
   4585         """
   4586 
   4587         # [num_rx, num_tx, max_num_paths, 2, 2]
   4588         theta_t = paths.theta_t
   4589         phi_t = paths.phi_t
   4590         theta_r = paths.theta_r
   4591         phi_r = paths.phi_r
   4592         types = paths.types
   4593 
   4594         mat_t = paths_tmp.mat_t
   4595         k_tx = paths_tmp.k_tx
   4596         k_rx = paths_tmp.k_rx
   4597 
   4598         # Apply multiplication by wavelength/4pi
   4599         # [num_rx, 1/rx_array_size, num_tx, 1/tx_array_size, max_num_paths,2, 2]
   4600         cst = tf.cast(self._scene.wavelength/(4.*PI), self._dtype)
   4601         a = cst*mat_t
   4602 
   4603         # Get dimensions that are needed later on
   4604         num_rx = a.shape[0]
   4605         rx_array_size = a.shape[1]
   4606         num_tx = a.shape[2]
   4607         tx_array_size = a.shape[3]
   4608 
   4609         # Expand dimension for broadcasting with receivers/transmitters,
   4610         # antenna dimensions, and paths dimensions
   4611         # [1, 1, num_tx, 1, 1, 3, 3]
   4612         tx_rot_mat = insert_dims(insert_dims(tx_rot_mat, 2, 0), 2, 3)
   4613         # [num_rx, 1, 1, 1, 1, 3, 3]
   4614         rx_rot_mat = insert_dims(rx_rot_mat, 4, 1)
   4615 
   4616         if self._scene.synthetic_array:
   4617             # Expand for broadcasting with antenna dimensions
   4618             # [num_rx, 1, num_tx, 1, max_num_paths, 3]
   4619             k_rx = tf.expand_dims(tf.expand_dims(k_rx, axis=1), axis=3)
   4620             k_tx = tf.expand_dims(tf.expand_dims(k_tx, axis=1), axis=3)
   4621             # [num_rx, 1, num_tx, 1, max_num_paths]
   4622             theta_t = tf.expand_dims(tf.expand_dims(theta_t,axis=1), axis=3)
   4623             phi_t = tf.expand_dims(tf.expand_dims(phi_t, axis=1), axis=3)
   4624             theta_r = tf.expand_dims(tf.expand_dims(theta_r,axis=1), axis=3)
   4625             phi_r = tf.expand_dims(tf.expand_dims(phi_r, axis=1), axis=3)
   4626 
   4627         # Normalized wave transmit vector in the local coordinate system of
   4628         # the transmitters
   4629         # [num_rx, 1/rx_array_size, num_tx, 1/tx_array_size, max_num_paths, 3]
   4630         k_prime_t = tf.linalg.matvec(tx_rot_mat, k_tx, transpose_a=True)
   4631 
   4632         # Normalized wave receiver vector in the local coordinate system of
   4633         # the receivers
   4634         # [num_rx, 1/rx_array_size, num_tx, 1/tx_array_size, max_num_paths, 3]
   4635         k_prime_r = tf.linalg.matvec(rx_rot_mat, k_rx, transpose_a=True)
   4636 
   4637         # Angles of departure in the local coordinate system of the
   4638         # transmitter
   4639         # [num_rx, 1/rx_array_size, num_tx, 1/tx_array_size, max_num_paths, 3]
   4640         theta_prime_t, phi_prime_t = theta_phi_from_unit_vec(k_prime_t)
   4641 
   4642         # Angles of arrival in the local coordinate system of the
   4643         # receivers
   4644         # [num_rx, 1/rx_array_size, num_tx, 1/tx_array_size, max_num_paths, 3]
   4645         theta_prime_r, phi_prime_r = theta_phi_from_unit_vec(k_prime_r)
   4646 
   4647         # Spherical global frame vectors for tx and rx
   4648         # [num_rx, 1/rx_array_size, num_tx, 1/tx_array_size, max_num_paths, 3]
   4649         theta_hat_t = theta_hat(theta_t, phi_t)
   4650         phi_hat_t = phi_hat(phi_t)
   4651         theta_hat_r = theta_hat(theta_r, phi_r)
   4652         phi_hat_r = phi_hat(phi_r)
   4653 
   4654         # Spherical local frame vectors for tx and rx
   4655         # [num_rx, 1/rx_array_size, num_tx, 1/tx_array_size, max_num_paths, 3]
   4656         theta_hat_prime_t = theta_hat(theta_prime_t, phi_prime_t)
   4657         phi_hat_prime_t = phi_hat(phi_prime_t)
   4658         theta_hat_prime_r = theta_hat(theta_prime_r, phi_prime_r)
   4659         phi_hat_prime_r = phi_hat(phi_prime_r)
   4660 
   4661         # Rotation matrix for going from the spherical LCS to the spherical GCS
   4662         # For transmitters
   4663         # [num_rx, 1/rx_array_size, num_tx, 1/tx_array_size, max_num_paths]
   4664         tx_lcs2gcs_11 = dot(theta_hat_t,
   4665                             tf.linalg.matvec(tx_rot_mat, theta_hat_prime_t))
   4666         tx_lcs2gcs_12 = dot(theta_hat_t,
   4667                             tf.linalg.matvec(tx_rot_mat, phi_hat_prime_t))
   4668         tx_lcs2gcs_21 = dot(phi_hat_t,
   4669                             tf.linalg.matvec(tx_rot_mat, theta_hat_prime_t))
   4670         tx_lcs2gcs_22 = dot(phi_hat_t,
   4671                             tf.linalg.matvec(tx_rot_mat, phi_hat_prime_t))
   4672         # [num_rx, 1/rx_array_size, num_tx, 1/tx_array_size, max_num_paths,2, 2]
   4673         tx_lcs2gcs = tf.stack(
   4674                     [tf.stack([tx_lcs2gcs_11, tx_lcs2gcs_12], axis=-1),
   4675                      tf.stack([tx_lcs2gcs_21, tx_lcs2gcs_22], axis=-1)],
   4676                     axis=-2)
   4677         tx_lcs2gcs = tf.complex(tx_lcs2gcs, tf.zeros_like(tx_lcs2gcs))
   4678         # For receivers
   4679         # [num_rx, 1/rx_array_size, num_tx, 1/tx_array_size, max_num_paths]
   4680         rx_lcs2gcs_11 = dot(theta_hat_r,
   4681                             tf.linalg.matvec(rx_rot_mat, theta_hat_prime_r))
   4682         rx_lcs2gcs_12 = dot(theta_hat_r,
   4683                             tf.linalg.matvec(rx_rot_mat, phi_hat_prime_r))
   4684         rx_lcs2gcs_21 = dot(phi_hat_r,
   4685                             tf.linalg.matvec(rx_rot_mat, theta_hat_prime_r))
   4686         rx_lcs2gcs_22 = dot(phi_hat_r,
   4687                             tf.linalg.matvec(rx_rot_mat, phi_hat_prime_r))
   4688         # [num_rx, 1/rx_array_size, num_tx, 1/tx_array_size, max_num_paths,2, 2]
   4689         rx_lcs2gcs = tf.stack(
   4690                     [tf.stack([rx_lcs2gcs_11, rx_lcs2gcs_12], axis=-1),
   4691                      tf.stack([rx_lcs2gcs_21, rx_lcs2gcs_22], axis=-1)],
   4692                     axis=-2)
   4693         rx_lcs2gcs = tf.complex(rx_lcs2gcs, tf.zeros_like(rx_lcs2gcs))
   4694 
   4695         # List of antenna patterns (callables)
   4696         tx_patterns = self._scene.tx_array.antenna.patterns
   4697         rx_patterns = self._scene.rx_array.antenna.patterns
   4698 
   4699         tx_ant_fields_hat = []
   4700         for pattern in tx_patterns:
   4701             # [num_rx, 1/rx_array_size, num_tx, 1/tx_array_size,
   4702             #   max_num_paths, 2]
   4703             tx_ant_f = tf.stack(pattern(theta_prime_t, phi_prime_t), axis=-1)
   4704             tx_ant_fields_hat.append(tx_ant_f)
   4705 
   4706         rx_ant_fields_hat = []
   4707         for pattern in rx_patterns:
   4708             # [num_rx, 1/rx_array_size, num_tx, 1/tx_array_size,
   4709             #   max_num_paths, 2]
   4710             rx_ant_f = tf.stack(pattern(theta_prime_r, phi_prime_r), axis=-1)
   4711             rx_ant_fields_hat.append(rx_ant_f)
   4712 
   4713         # Stacking the patterns, corresponding to different polarization
   4714         # directions, as an additional dimension
   4715         # [num_rx, num_rx_patterns, 1/rx_array_size, num_tx, 1/tx_array_size,
   4716         #   max_num_paths, 2]
   4717         rx_ant_fields_hat = tf.stack(rx_ant_fields_hat, axis=1)
   4718         # Expand for broadcasting with tx polarization
   4719         # [num_rx, num_rx_patterns, 1/rx_array_size, num_tx, 1, 1,
   4720         #   1/tx_array_size, max_num_paths, 2]
   4721         rx_ant_fields_hat = tf.expand_dims(rx_ant_fields_hat, axis=4)
   4722 
   4723         # Stacking the patterns, corresponding to different polarization
   4724         # [num_rx, 1/rx_array_size, num_tx, num_tx_patterns, 1/tx_array_size,
   4725         #   max_num_paths, 2]
   4726         tx_ant_fields_hat = tf.stack(tx_ant_fields_hat, axis=3)
   4727         # Expand for broadcasting with rx polarization
   4728         # [num_rx, 1, 1/rx_array_size, num_tx, num_tx_patterns, 1/tx_array_size,
   4729         #   max_num_paths, 2]
   4730         tx_ant_fields_hat = tf.expand_dims(tx_ant_fields_hat, axis=1)
   4731 
   4732         # Antenna patterns to spherical global coordinate system
   4733         # Expand to broadcast with antenna patterns
   4734         # [num_rx, 1, 1/rx_array_size, num_tx, 1, 1/tx_array_size,
   4735         #   max_num_paths, 2, 2]
   4736         rx_lcs2gcs = tf.expand_dims(tf.expand_dims(rx_lcs2gcs, axis=1), axis=4)
   4737         # [num_rx, num_rx_patterns, 1/rx_array_size, num_tx, 1, 1/tx_array_size,
   4738         #   max_num_paths, 2]
   4739         rx_ant_fields = tf.linalg.matvec(rx_lcs2gcs, rx_ant_fields_hat)
   4740         # Expand to broadcast with antenna patterns
   4741         # [num_rx, 1, 1/rx_array_size, num_tx, 1, 1/tx_array_size,
   4742         #   max_num_paths, 2, 2]
   4743         tx_lcs2gcs = tf.expand_dims(tf.expand_dims(tx_lcs2gcs, axis=1), axis=4)
   4744         # [num_rx, 1, 1/rx_array_size, num_tx, num_tx_patterns, 1/tx_array_size,
   4745         #   max_num_paths, 2, 2]
   4746         tx_ant_fields = tf.linalg.matvec(tx_lcs2gcs, tx_ant_fields_hat)
   4747 
   4748         # Expand the field to broadcast with the antenna patterns
   4749         # [num_rx, 1, rx_array_size, num_tx, 1, tx_array_size, max_num_paths,
   4750         #   2, 2]
   4751         a = tf.expand_dims(tf.expand_dims(a, axis=1), axis=4)
   4752 
   4753         # Compute transmitted field
   4754         # [num_rx, 1, 1/rx_array_size, num_tx, num_tx_patterns, 1/tx_array_size,
   4755         #   max_num_paths, 2]
   4756         a = tf.linalg.matvec(a, tx_ant_fields)
   4757 
   4758         ## Scattering: For scattering, a is the field specularly reflected by
   4759         # the last interaction point. We need to compute the scattered field.
   4760         # [num_scat_paths]
   4761         scat_ind = tf.where(types == Paths.SCATTERED)[:,0]
   4762         n_scat = tf.size(scat_ind)
   4763         if n_scat > 0:
   4764             n_other = a.shape[-2] - n_scat
   4765 
   4766             # On CPU, indexing with -1 does not work. Hence we replace -1 by 0.
   4767             # This makes no difference on the resulting paths as such paths are
   4768             # not flagged as active.
   4769             # [max_num_paths]
   4770             valid_object_idx = tf.where(paths_tmp.scat_last_objects == -1,
   4771                                         0, paths_tmp.scat_last_objects)
   4772 
   4773             # Cross-polarization discrimination and scattering coefficients.
   4774             # If a callable is defined to compute the radio material properties,
   4775             # it is invoked. Otherwise, the radio materials of objects are used.
   4776             rm_callable = self._scene.radio_material_callable
   4777             if rm_callable is None:
   4778                 # [num_targets, num_sources, max_num_paths]
   4779                 k_x = tf.gather(xpd_coefficient, valid_object_idx)
   4780                 s = tf.gather(scattering_coefficient, valid_object_idx)
   4781                 etas = tf.gather(etas, valid_object_idx)
   4782             else:
   4783                 # [num_targets, num_sources, max_num_paths]
   4784                 etas, s, k_x = rm_callable(paths_tmp.scat_last_objects,
   4785                                            paths_tmp.scat_last_vertices)
   4786 
   4787             # Generate random phase shifts, and compute field vector
   4788             phase_shape = tf.concat([tf.shape(k_x), [2]], axis=0)
   4789             if scat_random_phases:
   4790                 # [num_targets, num_sources, max_num_paths, 2]
   4791                 phases = config.tf_rng.uniform(phase_shape, maxval=2*PI,
   4792                                                dtype=self._rdtype)
   4793             else:
   4794                 phases = tf.zeros(phase_shape, dtype=self._rdtype)
   4795             # [num_targets, num_sources, max_num_paths, 2]
   4796             field_vec = tf.exp(tf.complex(tf.cast(0, self._rdtype), phases))
   4797             # [num_targets, num_sources, max_num_paths, 2]
   4798             k_x_ = tf.stack([tf.sqrt(1-k_x), tf.sqrt(k_x)], axis=-1)
   4799             k_x_ = tf.complex(k_x_, tf.zeros_like(k_x_))
   4800             field_vec *= k_x_
   4801 
   4802             # Evaluate scattering pattern for all paths.
   4803             # If a callable is defined to compute the scattering pattern,
   4804             # it is invoked. Otherwise, the radio materials of objects are used.
   4805             sp_callable = self._scene.scattering_pattern_callable
   4806             if sp_callable is None:
   4807                 # Get all material properties related to scattering for each
   4808                 # path
   4809                 # [num_targets, num_sources, max_num_paths]
   4810                 alpha_r = tf.gather(alpha_r, valid_object_idx)
   4811                 alpha_i = tf.gather(alpha_i, valid_object_idx)
   4812                 lambda_ = tf.gather(lambda_, valid_object_idx)
   4813                 # Flattening is needed here as the pattern cannot handle it
   4814                 # otherwise
   4815                 f_s = ScatteringPattern.pattern(
   4816                                 tf.reshape(paths_tmp.scat_last_k_i, [-1, 3]),
   4817                                 tf.reshape(paths_tmp.scat_k_s, [-1, 3]),
   4818                                 tf.reshape(paths_tmp.scat_last_normals,[-1, 3]),
   4819                                 tf.reshape(alpha_r, [-1]),
   4820                                 tf.reshape(alpha_i, [-1]),
   4821                                 tf.reshape(lambda_, [-1]))
   4822                 # Reshape f_s to original dimensions
   4823                 # [num_targets, num_sources, max_num_paths]
   4824                 f_s = tf.reshape(f_s, tf.shape(alpha_r))
   4825             else:
   4826                 # [num_targets, num_sources, max_num_paths]
   4827                 f_s = sp_callable(paths_tmp.scat_last_objects,
   4828                                   paths_tmp.scat_last_vertices,
   4829                                   paths_tmp.scat_last_k_i,
   4830                                   paths_tmp.scat_k_s,
   4831                                   paths_tmp.scat_last_normals)
   4832 
   4833             # Complete the computation of the field
   4834             # [num_targets, num_sources, max_num_paths]
   4835             scaling = tf.sqrt(f_s)*s
   4836 
   4837             # The term cos(theta_i)*dA is equal to 4*PI/N*r^2
   4838             # [num_targets, num_sources, max_num_paths]
   4839             num_samples = tf.cast(num_samples, self._rdtype)
   4840             scaling *= tf.sqrt(4*tf.cast(PI, self._rdtype)\
   4841                 /(scat_keep_prob*num_samples))
   4842             scaling *= paths_tmp.scat_src_2_last_int_dist
   4843 
   4844             # Apply path loss due to propagation from scattering point
   4845             # to target
   4846             # [num_targets, num_sources, max_num_paths]
   4847             scaling = tf.math.divide_no_nan(scaling,
   4848                                             paths_tmp.scat_2_target_dist)
   4849 
   4850             # Compute scaled field vector
   4851             # [num_targets, num_sources, max_num_paths, 2]
   4852             field_vec *= tf.expand_dims(tf.complex(scaling,
   4853                                                    tf.zeros_like(scaling)), -1)
   4854 
   4855             # Compute Fresnel reflection coefficients at hit point
   4856             # These will be scaled by the reflection reduction factor
   4857             # [num_targets, num_sources, max_num_paths]
   4858             cos_theta = -dot(paths_tmp.scat_last_k_i,
   4859                              paths_tmp.scat_last_normals, clip=True)
   4860 
   4861             # [num_targets, num_sources, max_num_paths]
   4862             r_s, r_p = reflection_coefficient(etas, cos_theta)
   4863 
   4864             # [num_targets, num_sources, max_num_paths, 3]
   4865             e_i_s, e_i_p = compute_field_unit_vectors(
   4866                                         paths_tmp.scat_last_k_i,
   4867                                         paths_tmp.scat_k_s,
   4868                                         paths_tmp.scat_last_normals,
   4869                                         SolverBase.EPSILON,
   4870                                         return_e_r=False)
   4871 
   4872             # a_scat : [num_rx, 1, rx_array_size, num_tx, num_tx_patterns,
   4873             #   tx_array_size, n_scat, 2]
   4874             # a_other : [num_rx, 1, rx_array_size, num_tx, num_tx_patterns,
   4875             #   tx_array_size, max_num_paths - n_scat, 2]
   4876             a_other, a_scat = tf.split(a, [n_other, n_scat], axis=-2)
   4877             # [num_rx, 1/rx_array_size, num_tx, 1/tx_array_size,
   4878             #   max_num_paths, 3]
   4879             _, scat_theta_hat_r = tf.split(theta_hat_r, [n_other, n_scat],
   4880                                            axis=-2)
   4881             # [num_rx, 1/rx_array_size, num_tx, 1/tx_array_size,
   4882             #   max_num_paths, 3]
   4883             _, scat_phi_hat_r = tf.split(phi_hat_r, [n_other, n_scat],
   4884                                            axis=-2)
   4885 
   4886             # Compute incoming field
   4887             # [num_rx, 1, 1/rx_array_size, num_tx, 1, 1/tx_array_size, n_scat,
   4888             #   (3)]
   4889             scat_k_i = paths_tmp.scat_last_k_i
   4890             if self._scene.synthetic_array:
   4891                 r_s = insert_dims(r_s, 2, axis=1)
   4892                 r_s = insert_dims(r_s, 2, axis=4)
   4893                 r_p = insert_dims(r_p, 2, axis=1)
   4894                 r_p = insert_dims(r_p, 2, axis=4)
   4895                 e_i_s = insert_dims(e_i_s, 2, axis=1)
   4896                 e_i_s = insert_dims(e_i_s, 2, axis=4)
   4897                 e_i_p = insert_dims(e_i_p, 2, axis=1)
   4898                 e_i_p = insert_dims(e_i_p, 2, axis=4)
   4899                 scat_k_i = insert_dims(scat_k_i, 2, axis=1)
   4900                 scat_k_i = insert_dims(scat_k_i, 2, axis=4)
   4901                 field_vec = insert_dims(field_vec, 2, axis=1)
   4902                 field_vec = insert_dims(field_vec, 2, axis=4)
   4903             else:
   4904                 num_rx = len(self._scene.receivers)
   4905                 num_tx = len(self._scene.transmitters)
   4906                 r_s = split_dim(r_s, [num_rx, -1], 0)
   4907                 r_s = tf.expand_dims(r_s, axis=1)
   4908                 r_s = split_dim(r_s, [num_tx, -1], 3)
   4909                 r_s = tf.expand_dims(r_s, axis=4)
   4910                 r_p = split_dim(r_p, [num_rx, -1], 0)
   4911                 r_p = tf.expand_dims(r_p, axis=1)
   4912                 r_p = split_dim(r_p, [num_tx, -1], 3)
   4913                 r_p = tf.expand_dims(r_p, axis=4)
   4914                 e_i_s = split_dim(e_i_s, [num_rx, -1], 0)
   4915                 e_i_s = tf.expand_dims(e_i_s, axis=1)
   4916                 e_i_s = split_dim(e_i_s, [num_tx, -1], 3)
   4917                 e_i_s = tf.expand_dims(e_i_s, axis=4)
   4918                 e_i_p = split_dim(e_i_p, [num_rx, -1], 0)
   4919                 e_i_p = tf.expand_dims(e_i_p, axis=1)
   4920                 e_i_p = split_dim(e_i_p, [num_tx, -1], 3)
   4921                 e_i_p = tf.expand_dims(e_i_p, axis=4)
   4922                 scat_k_i = split_dim(scat_k_i, [num_rx, -1], 0)
   4923                 scat_k_i = tf.expand_dims(scat_k_i, axis=1)
   4924                 scat_k_i = split_dim(scat_k_i, [num_tx, -1], 3)
   4925                 scat_k_i = tf.expand_dims(scat_k_i, axis=4)
   4926                 field_vec = split_dim(field_vec, [num_rx, -1], 0)
   4927                 field_vec = tf.expand_dims(field_vec, axis=1)
   4928                 field_vec = split_dim(field_vec, [num_tx, -1], 3)
   4929                 field_vec = tf.expand_dims(field_vec, axis=4)
   4930 
   4931             # [num_rx, 1, 1/rx_array_size, num_tx, 1, 1/tx_array_size, n_scat,2]
   4932             scat_r = tf.stack([r_s, r_p], axis=-1)
   4933 
   4934             # [num_rx, 1, 1/rx_array_size, num_tx, num_tx_patterns,
   4935             #   1/tx_array_size, n_scat, 2]
   4936             a_in = tf.math.divide_no_nan(a_scat, scat_r)
   4937 
   4938             # Compute polarization field vector
   4939             a_in_s, a_in_p = tf.split(a_in, 2, axis=-1)
   4940             e_i_s = tf.complex(e_i_s, tf.zeros_like(e_i_s))
   4941             e_i_p = tf.complex(e_i_p, tf.zeros_like(e_i_p))
   4942             e_in_pol = a_in_s*e_i_s + a_in_p*e_i_p
   4943             e_pol_hat, _ = normalize(tf.math.real(e_in_pol))
   4944             e_xpol_hat = cross(e_pol_hat, scat_k_i)
   4945 
   4946             # Compute incoming spherical unit vectors in GCS
   4947             scat_theta_i, scat_phi_i = theta_phi_from_unit_vec(-scat_k_i)
   4948             scat_theta_hat_i = theta_hat(scat_theta_i, scat_phi_i)
   4949             scat_phi_hat_i = phi_hat(scat_phi_i)
   4950 
   4951             # Transformation to theta_hat_i, phi_hat_i
   4952             trans_mat = component_transform(e_pol_hat, e_xpol_hat,
   4953                                             scat_theta_hat_i, scat_phi_hat_i)
   4954 
   4955             # Transformation from theta_hat_s, phi_hat_s to theta_hat_r, phi_hat_r
   4956             # [num_targets, num_sources, max_num_paths, 3]
   4957             # = [num_rx*1/rx_array_size, num_tx*1/tx_array_size, max_num_paths, 3]
   4958             scat_theta_s, scat_phi_s = theta_phi_from_unit_vec(paths_tmp.scat_k_s)
   4959             scat_theta_hat_s = theta_hat(scat_theta_s, scat_phi_s)
   4960             scat_phi_hat_s = phi_hat(scat_phi_s)
   4961 
   4962             # [num_rx, 1/rx_array_size, num_sources, max_num_paths, 3]
   4963             scat_theta_hat_s = split_dim(scat_theta_hat_s,
   4964                                          [num_rx, rx_array_size], 0)
   4965             scat_phi_hat_s = split_dim(scat_phi_hat_s,
   4966                                        [num_rx, rx_array_size], 0)
   4967 
   4968             # [num_rx, 1/rx_array_size, num_tx, 1/tx_array_size, max_num_paths, 3]
   4969             scat_theta_hat_s = split_dim(scat_theta_hat_s,
   4970                                          [num_tx, tx_array_size], 2)
   4971             scat_phi_hat_s = split_dim(scat_phi_hat_s,
   4972                                        [num_tx, tx_array_size], 2)
   4973 
   4974             # [num_rx, 1,  1/rx_array_size, num_tx, 1/tx_array_size, max_num_paths, 3]
   4975             scat_theta_hat_s = tf.expand_dims(scat_theta_hat_s, 1)
   4976             scat_phi_hat_s = tf.expand_dims(scat_phi_hat_s, 1)
   4977 
   4978             # [num_rx, 1,  1/rx_array_size, num_tx, 1, 1/tx_array_size, max_num_paths, 3]
   4979             scat_theta_hat_s = tf.expand_dims(scat_theta_hat_s, 4)
   4980             scat_phi_hat_s = tf.expand_dims(scat_phi_hat_s, 4)
   4981 
   4982             # [num_rx, 1, 1/rx_array_size, num_tx, 1, 1/tx_array_size,
   4983             #   max_num_scat_paths, 3]
   4984             scat_theta_hat_r = tf.expand_dims(scat_theta_hat_r, axis=1)
   4985             scat_theta_hat_r = tf.expand_dims(scat_theta_hat_r, axis=4)
   4986             # [num_rx, 1, 1/rx_array_size, num_tx, 1, 1/tx_array_size,
   4987             #   max_num_scat_paths, 3]
   4988             scat_phi_hat_r = tf.expand_dims(scat_phi_hat_r, axis=1)
   4989             scat_phi_hat_r = tf.expand_dims(scat_phi_hat_r, axis=4)
   4990 
   4991             trans_mat2 = component_transform(scat_theta_hat_s, scat_phi_hat_s,
   4992                                              scat_theta_hat_r, scat_phi_hat_r)
   4993 
   4994             trans_mat = tf.matmul(trans_mat2, trans_mat)
   4995 
   4996             # Compute basis transform matrix for GCS
   4997             # [num_rx, 1, rx_array_size, num_tx, num_tx_patterns, tx_array_size,
   4998             #   max_num_scat_paths, 2, 2]
   4999             trans_mat = tf.complex(trans_mat, tf.zeros_like(trans_mat))
   5000 
   5001             # Multiply a_scat by sqrt of reflected energy
   5002             # The splitting along the last dim is done because
   5003             # TF cannot handle reduce_sum for such high-dimensional
   5004             # tensors
   5005             #
   5006             # [num_rx, 1, 1/rx_array_size, num_tx, num_tx_patterns,
   5007             #   1/tx_array_size, max_num_paths-n_scat, 1]
   5008             e_spec = tf.reduce_sum(tf.square(tf.abs(a_scat)), axis=-1,
   5009                                    keepdims=True)
   5010             e_spec = tf.sqrt(e_spec)
   5011 
   5012             # [num_rx, 1, 1/rx_array_size, num_tx, num_tx_patterns,
   5013             #   1/tx_array_size, max_num_paths-n_scat, 2]
   5014             e_spec = tf.complex(e_spec, tf.zeros_like(e_spec))
   5015             a_scat = field_vec*e_spec
   5016 
   5017             # Basis transform
   5018             a_scat = tf.linalg.matvec(trans_mat, a_scat)
   5019 
   5020             # Concat with other paths
   5021             a = tf.concat([a_other, a_scat], axis=-2)
   5022 
   5023         # [num_rx, num_rx_patterns, 1/rx_array_size, num_tx, num_tx_patterns,
   5024         #   1/tx_array_size, max_num_paths]
   5025         a = dot(rx_ant_fields, a)
   5026 
   5027         if not self._scene.synthetic_array:
   5028             # Reshape as expected to merge antenna and antenna patterns into one
   5029             # dimension, as expected by Sionna
   5030             # [ num_rx, num_rx_ant = num_rx_patterns*rx_array_size,
   5031             #   num_tx, num_tx_ant = num_tx_patterns*tx_array_size,
   5032             #   max_num_paths]
   5033             a = flatten_dims(flatten_dims(a, 2, 1), 2, 3)
   5034 
   5035         return a