anomaly-detection-material-parameters-calibration

Sionna material parameter calibration (research project)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

solver_base.py (38948B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """
      6 Ray tracing algorithm that uses the image method to compute all pure reflection
      7 paths.
      8 """
      9 
     10 import mitsuba as mi
     11 import drjit as dr
     12 import tensorflow as tf
     13 
     14 from .utils import normalize, dot, theta_phi_from_unit_vec, cross,\
     15     mi_to_tf_tensor, mitsuba_rectangle_to_world
     16 from sionna import PI
     17 
     18 
     19 class SolverBase:
     20     # pylint: disable=line-too-long
     21     r"""SolverBase(scene, solver=None, dtype=tf.complex64)
     22 
     23     Base class for implementing a solver. If another ``solver`` is specified at
     24     instantiation, then it re-uses the structure to avoid useless compute and
     25     memory use.
     26 
     27     Note: Only triangle mesh are supported.
     28 
     29     Parameters
     30     -----------
     31     scene : :class:`~sionna.rt.Scene`
     32         Sionna RT scene
     33 
     34     solver : :class:`~sionna.rt.SolverBase` | None
     35         Another solver from which to re-use some structures to avoid useless
     36         compute and memory use
     37 
     38     dtype : tf.complex64 | tf.complex128
     39         Datatype for all computations, inputs, and outputs.
     40         Defaults to `tf.complex64`.
     41     """
     42 
     43     # Small value used to discard intersection with edges, avoid
     44     # self-intersection, etc.
     45     # Resolution of float32 1e-6:
    # np.finfo(np.float32) -> resolution=1e-6
     47     # We add on top a 10x factor for caution
     48     EPSILON = 1e-5
     49 
     50     # Threshold for extracting wedges from the scene [rad]
     51     WEDGES_ANGLE_THRESHOLD = 1.*PI/180.
     52 
     53     # Small value used to avoid false positive when testing for obstruction
     54     # Resolution of float32 1e-6:
    # np.finfo(np.float32) -> resolution=1e-6
     56     # We add on top a 100x factor for caution
     57     EPSILON_OBSTRUCTION = 1e-4
     58 
     59     def __init__(self, scene, solver=None, dtype=tf.complex64):
     60 
     61         # Computes the quantities required for generating the paths.
     62         # More pricisely:
     63 
     64         # _primitives : [num triangles, 3, 3], float
     65         #     The triangles: x-y-z coordinates of the 3 vertices for every
     66         #     triangle
     67 
     68         # _normals : [num triangles, 3], float
     69         #     The normals of the triangles
     70 
     71         # _primitives_2_objects : [num_triangles], int
     72         #     Index of the shape containing the triangle
     73 
     74         # _prim_offsets : [num_objects], int
     75         #     Indices offsets for accessing the triangles making each shape.
     76 
     77         # _shape_indices : [num_objects], int
     78         #     Map Mitsuba shape indices to indices that can be used to access
     79         #    _prim_offsets
     80 
     81         assert dtype in (tf.complex64, tf.complex128),\
     82             "`dtype` must be tf.complex64 or tf.complex128`"
     83         self._dtype = dtype
     84         self._rdtype = dtype.real_dtype
     85 
     86         # Mitsuba types depend on the used precision
     87         if dtype == tf.complex64:
     88             self._mi_point_t = mi.Point3f
     89             self._mi_point2_t = mi.Point2f
     90             self._mi_vec_t = mi.Vector3f
     91             self._mi_scalar_t = mi.Float
     92             self._mi_tensor_t = mi.TensorXf
     93         else:
     94             self._mi_point_t = mi.Point3d
     95             self._mi_point2_t = mi.Point2d
     96             self._mi_vec_t = mi.Vector3d
     97             self._mi_scalar_t = mi.Float64
     98             self._mi_tensor_t = mi.TensorXd
     99 
    100         self._scene = scene
    101         mi_scene = scene.mi_scene
    102         self._mi_scene = mi_scene
    103 
    104         # If a solver is provided, then link to the same structures to avoid
    105         # useless compute and memory use
    106         if solver is not None:
    107             self._primitives = solver._primitives
    108             self._normals = solver._normals
    109             self._primitives_2_objects = solver._primitives_2_objects
    110             self._prim_offsets = solver._prim_offsets
    111             self._shape_indices = solver._shape_indices
    112             #
    113             self._wedges_origin = solver._wedges_origin
    114             self._wedges_e_hat = solver._wedges_e_hat
    115             self._wedges_length = solver._wedges_length
    116             self._wedges_normals = solver._wedges_normals
    117             self._primitives_2_wedges = solver._primitives_2_wedges
    118             self._wedges_objects = solver._wedges_objects
    119             self._is_edge = solver._is_edge
    120             return
    121 
    122         ###################################################
    123         # Extract triangles, their normals, and a
    124         # look-up-table to map primitives to the scene
    125         # object they belong to.
    126         ###################################################
    127 
    128         # Tensor mapping primitives to corresponding objects
    129         # [num_triangles]
    130         primitives_2_objects = []
    131 
    132         # Number of triangles
    133         n_prims = 0
    134         # Triangles of each object (shape) in the scene are stacked.
    135         # This list tracks the indices offsets for accessing the triangles
    136         # making each shape.
    137         prim_offsets = []
    138         mi_shapes = scene.mi_shapes
    139         for i,s in enumerate(mi_shapes):
    140             if not isinstance(s, mi.Mesh):
    141                 raise ValueError('Only triangle meshes are supported')
    142             prim_offsets.append(n_prims)
    143             n_prims += s.face_count()
    144             primitives_2_objects += [i]*s.face_count()
    145         # [num_objects]
    146         prim_offsets = tf.cast(prim_offsets, tf.int32)
    147 
    148         # Tensor of triangles vertices
    149         # [n_prims, number of vertices : 3, coordinates : 3]
    150         prims = tf.zeros([n_prims, 3, 3], self._rdtype)
    151         # Normals to the triangles
    152         normals = tf.zeros([n_prims, 3], self._rdtype)
    153         # Loop through the objects in the scene
    154         for prim_offset, s in zip(prim_offsets, mi_shapes):
    155             # Extract the vertices of the shape.
    156             # Dr.JIT/Mitsuba is used here.
    157             # Indices of the vertices
    158             # [n_prims, num of vertices per triangle : 3]
    159             face_indices3 = s.face_indices(dr.arange(mi.UInt32, s.face_count()))
    160             # Flatten. This is required for calling vertex_position
    161             # [n_prims*3]
    162             face_indices = dr.ravel(face_indices3)
    163             # Get vertices coordinates
    164             # [n_prims*3, 3]
    165             vertex_coords = s.vertex_position(face_indices)
    166             # Move to TensorFlow
    167             # [n_prims*3, 3]
    168             vertex_coords = mi_to_tf_tensor(vertex_coords, self._rdtype)
    169             # Unflatten
    170             # [n_prims, vertices per triangle : 3, 3]
    171             vertex_coords = tf.reshape(vertex_coords, [s.face_count(), 3, 3])
    172             # Update the `prims` tensor
    173             sl = tf.range(prim_offset, prim_offset + s.face_count(),
    174                           dtype=tf.int32)
    175             sl = tf.expand_dims(sl, axis=1)
    176             prims = tf.tensor_scatter_nd_update(prims, sl, vertex_coords)
    177             # Compute the normals to the triangles
    178             # Coordinate of the first vertices of every triangle making the
    179             # shape
    180             # [n_prims, xyz : 3]
    181             v0 = s.vertex_position(face_indices3.x)
    182             # Coordinate of the second vertices of every triangle making the
    183             # shape
    184             # [n_prims, xyz : 3]
    185             v1 = s.vertex_position(face_indices3.y)
    186             # Coordinate of the third vertices of every triangle making the
    187             # shape
    188             # [n_prims, xyz : 3]
    189             v2 = s.vertex_position(face_indices3.z)
    190             # Compute the normals
    191             # [n_prims, xyz : 3]
    192             mi_n = dr.normalize(dr.cross(
    193                 v1 - v0,
    194                 v2 - v0,
    195             ))
    196             # Move to TensorFlow
    197             # [n_prims, 3]
    198             n = mi_to_tf_tensor(mi_n, self._rdtype)
    199             # Update the 'normals' tensor
    200             normals = tf.tensor_scatter_nd_update(normals, sl, n)
    201 
    202         self._primitives = tf.Variable(prims, trainable=False)
    203         self._normals = tf.Variable(normals, trainable=False)
    204         primitives_2_objects = tf.cast(primitives_2_objects, tf.int32)
    205         self._primitives_2_objects = tf.Variable(primitives_2_objects,
    206                                                  trainable=False)
    207 
    208         ####################################################
    209         # Used by the shoot & bounce method to map from
    210         # (shape, local primitive index) to the
    211         # corresponding global primitive index.
    212         ####################################################
    213 
    214         # [num_objects]
    215         self._prim_offsets = mi.Int32(prim_offsets.numpy())
    216         mi_shapes_ptr = [mi.ShapePtr(s) for s in mi_shapes]
    217         ptr_as_int = [dr.reinterpret_array_v(mi.UInt32, p)[0] for p in mi_shapes_ptr]
    218         ptr_as_int = mi.UInt(ptr_as_int)
    219         if dr.width(ptr_as_int) == 0:
    220             self._shape_indices = mi.Int32([])
    221         else:
    222             # [num_objects]
    223             shape_indices = dr.full(mi.Int32, -1, dr.max(ptr_as_int)[0] + 1)
    224             dr.scatter(shape_indices, dr.arange(mi.Int32, 0,
    225                        dr.width(ptr_as_int)), ptr_as_int)
    226             dr.eval(shape_indices)
    227             # [num_objects]
    228             self._shape_indices = shape_indices
    229 
    230         #################################################
    231         # Extract the wedges
    232         #################################################
    233         # _wedges_origin : [num_wedges, 3], float
    234         #   Starting point of the wedges
    235 
    236         # _wedges_e_hat : [num_wedges, 3], float
    237         #   Normalized edge vector
    238 
    239         # _wedges_length : [num_wedges], float
    240         #   Length of the wedges
    241 
    242         # _wedges_normals : [num_wedges, 2, 3], float
    243         #   Normals to the wedges sides
    244 
    245         # _primitives_2_wedges : [num_primitives, 3], int
    246         #   Maps primitives to their wedges
    247 
    248         # _wedges_objects : [num_wedges, 2], int
    249         #   Indices of the two objects making the wedge (the two sides of the
    250         #   wedge could belong to different objects)
    251 
    252         # is_edge : [num_wedges], bool
    253         #     Set to `True` if a wedge is an edge, i.e., the edge of a single
    254         #     primitive.
    255 
    256         edges = self._extract_wedges()
    257         self._wedges_origin = tf.Variable(edges[0], trainable=False)
    258         self._wedges_e_hat = tf.Variable(edges[1], trainable=False)
    259         self._wedges_length = tf.Variable(edges[2], trainable=False)
    260         self._wedges_normals = tf.Variable(edges[3], trainable=False)
    261         self._primitives_2_wedges = tf.Variable(edges[4], trainable=False)
    262         self._wedges_objects = tf.Variable(edges[5], trainable=False)
    263         self._is_edge = tf.Variable(edges[6], trainable=False)
    264 
    265     ##################################################################
    266     # Internal utility methods
    267     ##################################################################
    268 
    269     @property
    270     def primitives(self):
    271         """
    272         [num triangles, 3, 3], tf.float : The triangles: [x,y,z] coordinates of
    273         the 3 vertices of every triangle
    274         """
    275         return self._primitives
    276 
    277     @property
    278     def normals(self):
    279         """
    280         [num triangles, 3], tf.float : The normals of the triangles
    281         """
    282         return self._normals
    283 
    284     @property
    285     def prim_offsets(self):
    286         """
    287         [num_objects], tf.int : Indices offsets for accessing the triangles
    288         each shape in `primitives`
    289         """
    290         return self._prim_offsets
    291 
    292     @property
    293     def shape_indices(self):
    294         """
    295         [num_objects], tf.int :  Map object ids to indices that can be used to
    296         access `prim_offsets`
    297         """
    298         return self._shape_indices
    299 
    300     @property
    301     def wedges_origin(self):
    302         """
    303         [num_wedges, 3], tf.float : Origin of the wedges
    304         """
    305         return self._wedges_origin
    306 
    307     @property
    308     def wedges_e_hat(self):
    309         """
    310         [num_wedges, 3], tf.float : Normalized edge vector
    311         """
    312         return self._wedges_e_hat
    313 
    314     @property
    315     def wedges_length(self):
    316         """
    317         [num_wedges], tf.float : Length of the wedges
    318         """
    319         return self._wedges_length
    320 
    321     @property
    322     def wedges_normals(self):
    323         """
    324         [num_wedges, 2, 3], tf.float : Normals to the wedges sides
    325         """
    326         return self._wedges_normals
    327 
    328     @property
    329     def primitives_2_wedges(self):
    330         """
    331         [num_primitives, 3], tf.int : Maps primitives to their wedges
    332         """
    333         return self._primitives_2_wedges
    334 
    335     @property
    336     def wedges_objects(self):
    337         """
    338         [num_wedges, 2], tf.int : Indices of the two objects making the
    339         wedge (the two sides of the wedge could belong to different objects)
    340         """
    341         return self._wedges_objects
    342 
    343     @property
    344     def is_edge(self):
    345         """
    346         [num_wedges], tf.bool : Set to `True` if a wedge is an edge, i.e.,
    347         the edge of a single primitive.
    348         """
    349         return self._is_edge
    350 
    def _build_scene_object_properties_tensors(self):
        r"""
        Build tensors containing the shape properties, indexed by object id

        Input
        ------
        None

        Output
        -------
        relative_permittivity : [num_shape], tf.complex
            Tensor containing the complex relative permittivities of all shapes

        scattering_coefficient : [num_shape], tf.float
            Tensor containing the scattering coefficients of all shapes

        xpd_coefficient : [num_shape], tf.float
            Tensor containing the cross-polarization discrimination
            coefficients of all shapes

        alpha_r : [num_shape], tf.int
            Tensor containing the alpha_r scattering parameters of all shapes

        alpha_i : [num_shape], tf.int
            Tensor containing the alpha_i scattering parameters of all shapes

        lambda_ : [num_shape], tf.float
            Tensor containing the lambda_ scattering parameters of all shapes

        velocity : [num_shape, 3], tf.float
            Tensor containing the velocity vectors of all shapes
        """

        # Compute the size of the tensors that store the properties of all
        # objects and RIS. The tensors are sized by the largest object id
        # (+1) plus one entry per RIS, so that object ids can be used
        # directly as indices even if they are not contiguous.
        objects_id = [obj.object_id for obj in self._scene.objects.values()]
        max_id = 0
        if len(objects_id) > 0:
            max_id = tf.reduce_max(objects_id)
        array_size = max_id + 1 + len(self._scene.ris)

        # If a callable is set to obtain radio material properties, then there
        # is no need to build the tensors of material properties
        rm_callable_set = self._scene.radio_material_callable is not None

        # If a callable is set to obtain scattering patterns, then there
        # is no need to build the tensors for scattering properties
        sp_callable_set = self._scene.scattering_pattern_callable is not None

        if rm_callable_set:
            # Empty placeholders: the material tensors are not needed
            relative_permittivity = tf.zeros([0], self._dtype)
            scattering_coefficient = tf.zeros([0], self._rdtype)
            xpd_coefficient = tf.zeros([0], self._rdtype)
        else:
            relative_permittivity = tf.zeros([array_size], self._dtype)
            scattering_coefficient = tf.zeros([array_size], self._rdtype)
            xpd_coefficient = tf.zeros([array_size], self._rdtype)

        if sp_callable_set:
            # Empty placeholders: the scattering tensors are not needed
            alpha_r = tf.zeros([0], tf.int32)
            alpha_i = tf.zeros([0], tf.int32)
            lambda_ = tf.zeros([0], self._rdtype)
        else:
            alpha_r = tf.zeros([array_size], tf.int32)
            alpha_i = tf.zeros([array_size], tf.int32)
            lambda_ = tf.zeros([array_size], self._rdtype)

        # Fill the tensors, unless both callables make this unnecessary
        if (not sp_callable_set) or (not rm_callable_set):

            for rm in self._scene.radio_materials.values():
                # Ids of the objects using this radio material
                using_objects = rm.using_objects
                num_using_objects = tf.shape(using_objects)[0]
                if num_using_objects == 0:
                    # Material not in use by any object
                    continue

                if not rm_callable_set:
                    # Scatter the material properties at the indices of the
                    # objects using this material
                    relative_permittivity = tf.tensor_scatter_nd_update(
                        relative_permittivity,
                        tf.reshape(using_objects, [-1,1]),
                        tf.fill([num_using_objects],
                                rm.complex_relative_permittivity))

                    scattering_coefficient = tf.tensor_scatter_nd_update(
                        scattering_coefficient,
                        tf.reshape(using_objects, [-1,1]),
                        tf.fill([num_using_objects], rm.scattering_coefficient))

                    xpd_coefficient = tf.tensor_scatter_nd_update(
                        xpd_coefficient,
                        tf.reshape(using_objects, [-1,1]),
                        tf.fill([num_using_objects], rm.xpd_coefficient))

                if not sp_callable_set:
                    # Scatter the scattering-pattern parameters at the indices
                    # of the objects using this material
                    alpha_r = tf.tensor_scatter_nd_update(
                        alpha_r,
                        tf.reshape(using_objects, [-1,1]),
                        tf.fill([num_using_objects],
                                rm.scattering_pattern.alpha_r))

                    alpha_i = tf.tensor_scatter_nd_update(
                        alpha_i,
                        tf.reshape(using_objects, [-1,1]),
                        tf.fill([num_using_objects],
                                rm.scattering_pattern.alpha_i))

                    lambda_ = tf.tensor_scatter_nd_update(
                        lambda_,
                        tf.reshape(using_objects, [-1,1]),
                        tf.fill([num_using_objects],
                                rm.scattering_pattern.lambda_))

        # Velocity of all objects and RIS
        # [array_size, 3]
        velocity = tf.zeros([array_size, 3], self._rdtype)
        for obj in self._scene.objects.values():
            velocity = tf.tensor_scatter_nd_update(velocity,
                                                    [[obj.object_id]],
                                                     [obj.velocity])
        for obj in self._scene.ris.values():
            velocity = tf.tensor_scatter_nd_update(velocity,
                                                    [[obj.object_id]],
                                                     [obj.velocity])

        return (relative_permittivity,
               scattering_coefficient,
               xpd_coefficient,
               alpha_r,
               alpha_i,
               lambda_,
               velocity)
    481 
    def _test_obstruction(self, o, d, maxt, additional_blockers=None):
        r"""
        Test a batch of rays for obstruction using Mitsuba ray casting.

        Input
        -----
        o: [batch_size, 3], tf.float
            Origins of the rays

        d: [batch_size, 3], tf.float
            Directions of the rays.
            Must be unit vectors.

        maxt: [batch_size], tf.float
            Lengths of the rays

        additional_blockers : list(mi.Shape) | None
            Optional list of Mitsuba shapes acting as additional blockers.
            Defaults to `None`.

        Output
        -------
        val: [batch_size], tf.bool
            `True` if the ray is obstructed, i.e., hits a primitive.
            `False` otherwise.
        """
        # Nudge the origins forward along the ray directions so that a ray
        # cannot immediately re-intersect the primitive it starts on
        origins = self._mi_point_t(o + SolverBase.EPSILON_OBSTRUCTION*d)
        # [batch_size, 3]
        directions = self._mi_vec_t(d)
        # Shorten the rays slightly to avoid false positives caused by
        # hitting the very primitive we are testing visibility to
        # [batch_size]
        lengths = self._mi_scalar_t(maxt - 2.*SolverBase.EPSILON_OBSTRUCTION)
        # Batch of Mitsuba rays
        rays = mi.Ray3f(o=origins, d=directions, maxt=lengths, time=0.,
                        wavelengths=mi.Color0f(0.))
        # Obstruction by the scene geometry
        # [batch_size]
        blocked = mi_to_tf_tensor(self._mi_scene.ray_test(rays), tf.bool)
        # Obstruction by the additional blockers, if any
        if additional_blockers:
            for blocker in additional_blockers:
                hit = mi_to_tf_tensor(blocker.ray_test(rays), tf.bool)
                blocked = tf.logical_or(blocked, hit)
        return blocked
    536 
    537     def _extract_wedges(self):
    538         r"""
    539         Extract the wedges and, optionally, the edges, from the scene geometry
    540 
    541         Output
    542         ------
    543         # _wedges_origin : [num_wedges, 3], float
    544         #   Starting point of the wedges
    545 
    546         # _wedges_e_hat : [num_wedges, 3], float
    547         #   Normalized edge vector
    548 
    549         # _wedges_length : [num_wedges], float
    550         #   Length of the wedges
    551 
    552         # _wedges_normals : [num_wedges, 2, 3], float
    553         #   Normals to the wedges sides
    554 
    555         # _primitives_2_wedges : [num_primitives, 3], int
    556         #   Maps primitives to their wedges
    557 
    558         # _wedges_objects : [num_wedges, 2], int
    559         #   Indices of the two objects making the wedge (the two sides of the
    560         #   wedge could belong to different objects)
    561 
    562         # is_edge : [num_wedges], bool
    563         #     Set to `True` if a wedge is an edge, i.e., the edge of a single
    564         #     primitive.
    565         """
    566 
    567         angle_threshold = SolverBase.WEDGES_ANGLE_THRESHOLD
    568 
    569         # Extract vertices of every triangle
    570         # [num_prim, 3]
    571         v0 = self._primitives[:,0,:]
    572         v1 = self._primitives[:,1,:]
    573         v2 = self._primitives[:,2,:]
    574 
    575         # List all edges
    576         # [num_prim, 2 + 2 + 2, 3]
    577         all_edges_undirected = tf.concat([
    578             v0, v1,
    579             v1, v2,
    580             v2, v0
    581         ], axis=1)
    582         # [num_edges = num_prim*3, 2, 3]
    583         all_edges_undirected = tf.reshape(all_edges_undirected,
    584                                           shape=(3 * v0.shape[0], 2, 3))
    585         # Edges are oriented such that identical edges have same orientation
    586         # [num_edges, 2, 3]
    587         all_edges = self._swap_edges(all_edges_undirected)
    588 
    589         # Remaining point in the triangle for each edge.
    590         # This will be used to compute the normals later
    591         # [num_prim, 3, 3]
    592         remaining_vertex = tf.concat([v2,
    593                                       v0,
    594                                       v1], axis=1)
    595         # [num_edges, 3]
    596         remaining_vertex = tf.reshape(remaining_vertex, [-1, 3])
    597 
    598         # Get unique edges, i.e., wihout duplicates
    599         # unique_edges : [num_unique_edges, 2, 3]
    600         # indices_of_unique : [num_edges], index of the edge in ``unique_edges``
    601         unique_edges, indices_of_unique = tf.raw_ops.UniqueV2(x=all_edges,
    602                                                               axis=[0])
    603 
    604         # Number of occurences of every unique edge
    605         # [num_unique_edges]
    606         _, _, unique_indices_count = tf.unique_with_counts(indices_of_unique)
    607 
    608         # Flag indicating which edges shared by exactly one or two primitives,
    609         # i.e., edges or wedges
    610         # [num_unique_edges]
    611         is_selected = tf.logical_or(tf.equal(unique_indices_count, 1),
    612                                     tf.equal(unique_indices_count, 2))
    613 
    614         # The following tensor lists the index of the first
    615         # edge in ``all_edges`` that makes the wedge.
    616         # Note: In the presence of duplicate values in `indices_of_unique`
    617         # (i.e. our case), it is not deterministic which of them
    618         # `tensor_scatter_nd_update` will store into the result. That's okay.
    619         # [num_edges]
    620         seq = tf.cast(tf.range(indices_of_unique.shape[0]), dtype=tf.int32)
    621         # [num_unique_edges]
    622         default = tf.cast(tf.fill(dims=unique_edges.shape[0], value=-1),
    623                           dtype=tf.int32)
    624         # [num_unique_edges]
    625         all_edges_index_1 = tf.tensor_scatter_nd_update(default,
    626                                             indices_of_unique[:, None], seq)
    627         # Next, list the second primitive that the wedge is connected to.
    628         # -1 is used for edges defined by a single primitive (screens)
    629         # [num_unique_edges]
    630         false_value = tf.fill(dims=all_edges_index_1.shape[0], value=False)
    631         # [num_edges]
    632         missing = tf.fill(dims=indices_of_unique.shape[0], value=True)
    633         missing = tf.tensor_scatter_nd_update(missing,
    634                                               all_edges_index_1[:, None],
    635                                               false_value)
    636         # [num_unique_edges]
    637         all_edges_index_2 = tf.tensor_scatter_nd_update(default,
    638                             indices_of_unique[missing][:, None], seq[missing])
    639         # Flag set to True if an edge is not a wedge, i.e., is attached to only
    640         # one primitive. This is the case for unique edges for which
    641         # ``all_edges_index_2`` is not set to -1
    642         # [num_unique_edges]
    643         is_edge = tf.equal(all_edges_index_2, -1)
    644         # For edges, the unique edge primitive is set to the same value for both
    645         # ``all_edges_index_1`` and ``all_edges_index_2``. This will lead to
    646         # mapping to the same primitive
    647         # [num_unique_edges]
    648         all_edges_index_2 = tf.where(is_edge, all_edges_index_1,
    649                                      all_edges_index_2)
    650 
    651         # Normals to the faces
    652         # We compute the normals such that they point in the same direction
    653         # of the space for both primitives making a wedge.
    654         # To that aim, the normal for the face 0(n) is defined as the
    655         # cross-product between:
    656         # - The edge vector
    657         # - The vector connecting a vertex of the edge to the third
    658         #   point of the triangle corresponding to the 0(n) face to whom this
    659         #   edge belongs.
    660         # Edge vertices
    661         # [num_unique_edges, 2, 3]
    662         vs = tf.gather(all_edges, all_edges_index_1)
    663         # [num_unique_edges, 3]
    664         v1 = vs[:,0]
    665         v2 = vs[:,1]
    666         # Edge vector
    667         # [num_unique_edges, 3]
    668         e = v2 - v1
    669         # Vertex on the 0 and n faces
    670         # [num_unique_edges, 3]
    671         vf1 = tf.gather(remaining_vertex, all_edges_index_1)
    672         vf2 = tf.gather(remaining_vertex, all_edges_index_2)
    673         # [num_unique_edges, 3]
    674         u_1,_ = normalize(vf1 - v1)
    675         u_2,_ = normalize(vf2 - v1)
    676         # [num_unique_edges, 3]
    677         n1,_ = normalize(cross(e, u_1))
    678         n2,_ = normalize(cross(u_2, e))
    679 
    680         # We flip the normals if necessary to ensure that they point towards
    681         # the "exterior" of the wedge, i.e., that the exterior angle is at least
    682         # pi.
    683         # To ensure that, we orient the normals such that u2 does not point
    684         # towards the half-space defined by n1.
    685         # [num_unique_edges]
    686         cos_angle = dot(u_2, n1)
    687         # Three cases:
    688         # * cos_angle > 0: u2 points towards the same half space as n1.
    689         #   We must flip n1 and n2.
    690         # * cos_angle < 0: u2 points towards the opposite half space of n1. Do nothing
    691         # * cos_angle = 0: u2 is orthogonal to the n1, i.e., the two primitives are
    692         # parallel. In that cases, either the wedge is an edge, or it is a flat
    693         # surface that must be discarded
    694         # [num_unique_edges]
    695         flip = tf.where(tf.greater(cos_angle, tf.zeros_like(cos_angle)),
    696                         -tf.ones_like(cos_angle),
    697                         tf.ones_like(cos_angle))
    698         # [num_unique_edges, 1]
    699         flip = tf.expand_dims(flip, axis=1)
    700         # [num_unique_edges, 3]
    701         n1 = n1*flip
    702         n2 = n2*flip
    703         # Discard the wedges considered as flat, i.e., with an opening angle
    704         # close to PI up to `angle_threshold`
    705         # We use this observation to discard close-to-flat wedges.
    706         # cos_angle = dot(u_2, n1) = cos(theta)
    707         # where theta= angle(u_2, n1).
    708         # Then, we want:
    709         # pi/2 - angle_threshold < theta < pi/2 + angle_threshold
    710         # => cos(pi/2 + angle_threshold) < cos(theta)
    711         #                                       < cos(pi/2 - angle_threshold)
    712         # => -sin(angle_threshold) < cos_angle < sin(angle_threshold)
    713         # ()
    714         theshold = tf.abs(tf.math.sin(tf.cast(angle_threshold, self._rdtype)))
    715         # [num_unique_edges]
    716         is_selected_ = tf.greater(tf.abs(cos_angle),theshold)
    717         # Don't discard edges
    718         # [num_unique_edges]
    719         is_selected_ = tf.logical_or(is_edge, is_selected_)
    720         # [num_unique_edges]
    721         is_selected = tf.logical_and(is_selected, is_selected_)
    722 
    723         # Extract only the selected lanes
    724         # [num_selected_edges]
    725         selected_indices = tf.where(is_selected)[:, 0]
    726         # [num_selected_edges, 2, 3]
    727         selected_edges = unique_edges[is_selected]
    728         # [num_selected_edges, 3]
    729         selected_wedges_start = selected_edges[:,0]
    730         selected_wedges_end = selected_edges[:,1]
    731         # [num_selected_edges, 3]
    732         n1 = n1[is_selected]
    733         n2 = n2[is_selected]
    734         # [num_selected_edges, 2, 3]
    735         # n1: 0-face
    736         # n2: n-face
    737         normals = tf.stack([n1, n2], axis=1)
    738 
    739         # Pre-compute a mapping from primitive index to (up to) three wedges.
    740         # Recall that by construction, `all_edges` is ordered by
    741         # primitive (3 rows per primitive).
    742         # Then, `indices_of_unique` gives the mapping from the row indices
    743         # of `all_edges` to the row indices of `unique_edges`.
    744         # Finally, a subset of `unique_edges` via the `is_double` mask.
    745         #
    746         # In the end, all we need to do is renumber the values of
    747         # `indices_of_unique` to refer to rows of `selected_edges` instead of
    748         # rows of `unique_edges`.
    749         seq = tf.cast(tf.range(selected_edges.shape[0]), dtype=tf.int32)
    750         unique_edge_index_to_double_edge_index = \
    751             tf.tensor_scatter_nd_update(default, selected_indices[:, None], seq)
    752         # [num_prim, 3]
    753         prim_to_wedges = tf.reshape(
    754             tf.gather(unique_edge_index_to_double_edge_index,indices_of_unique),
    755             (-1, 3)
    756         )
    757 
    758         # Indices of the objects to which each edge belongs
    759         # First, indices (in all_edges) of edges
    760         # [num_unique_edges, 2]
    761         wedges_indices = tf.stack([all_edges_index_1,
    762                                    all_edges_index_2], axis=1)
    763         # Keep only the selected wedges
    764         # [num_selected_edges, 2]
    765         wedges_indices = wedges_indices[is_selected]
    766         # [num_selected_edges]
    767         is_edge = is_edge[is_selected]
    768         # Wedges index 2 primitive index
    769         # [num_selected_edges, 2]
    770         wedges_2_prim = wedges_indices//3
    771         # Primitive index 2 object index
    772         # [num_selected_edges, 2]
    773         wedges_2_object = tf.gather(self._primitives_2_objects, wedges_2_prim)
    774 
    775         # Edges length and edge vector
    776         # The edge vector e_hat must be such that:
    777         #   normalize(n_0 x n_n) = e_hat,
    778         # where n_0 is the normal to the 0-face and n_n the normal to the
    779         # n-face
    780         # [num_selected_edges, 3]
    781         e_hat,_ = normalize(cross(normals[...,0,:],normals[...,1,:]))
    782         # Select the wedges' origin according to the normals
    783         # e_hat_ind: [num_selected_edges, 3]
    784         # length : [num_selected_edges]
    785         e_hat_ind,length = normalize(selected_wedges_end-selected_wedges_start)
    786         # [num_selected_edges]
    787         origin_indicator = dot(e_hat, e_hat_ind)
    788         # [num_selected_edges, 3]
    789         origin_indicator = tf.expand_dims(origin_indicator, axis=1)
    790         origin = tf.where(origin_indicator < 0,
    791                           selected_wedges_end,
    792                           selected_wedges_start)
    793         # Arbitrarily set the vector for the edges
    794         # [num_selected_edges, 3]
    795         e_hat = tf.where(tf.expand_dims(is_edge, axis=1), e_hat_ind, e_hat)
    796 
    797         # Output
    798         output = (
    799                     origin,                 # wedges origins
    800                     e_hat,                  # Wedge vector
    801                     length,                 # Wedge length
    802                     normals,                # wedges_normals
    803                     prim_to_wedges,         # primitives_2_wedges
    804                     wedges_2_object,        # wedges_objects
    805                     is_edge                 # is_edge
    806                  )
    807 
    808         return output
    809 
    def _swap_edges(self, edges):
        """Reorient edges such that identical edges always list their two
        extremities in the same order.

        The ordering compares the spherical coordinates of the extremities
        (p0, p1). A criterion is only used to break ties on the previous ones:

        1. norm(p1) >= norm(p0)
        2. azimuth(p1) >= azimuth(p0)
        3. elevation(p1) >= elevation(p0)

        The degenerate case where all three coordinates are equal is not
        handled.

        Parameters
        ----------
        edges : [...,2,3], float
            Batch of edges extremities

        Returns
        -------
        [..., 2, 3], float
            Reoriented edges
        """
        start = edges[:,0,:]
        end = edges[:,1,:]
        start_hat, start_r = normalize(start)
        end_hat, end_r = normalize(end)
        start_theta, start_phi = theta_phi_from_unit_vec(start_hat)
        end_theta, end_phi = theta_phi_from_unit_vec(end_hat)

        # Ties on a criterion are detected with `isclose` to be robust
        # against floating point inaccuracies
        same_r = tf.experimental.numpy.isclose(start_r, end_r)
        same_phi = tf.experimental.numpy.isclose(start_phi, end_phi)

        # Criterion 1: the radii differ and the start one is larger
        swap_on_r = (~same_r) & (start_r > end_r)
        # Criterion 2: equal radii, the azimuths differ and the start one
        # is larger
        swap_on_phi = same_r & (~same_phi) & (start_phi > end_phi)
        # Criterion 3: equal radii and azimuths, the start elevation is larger
        swap_on_theta = same_r & same_phi & (start_theta > end_theta)

        # [batch_size, 1]
        needs_swap = tf.expand_dims(swap_on_r | swap_on_phi | swap_on_theta,
                                    axis=1)

        # Swap the extremities where required
        # [batch_size, 2, 3]
        return tf.stack([tf.where(needs_swap, end, start),
                         tf.where(needs_swap, start, end)], axis=1)
    864 
    def _wedges_from_primitives(self, candidates, edge_diffraction):
        r"""
        Returns the candidate wedges from the candidate primitives.

        As only first-order diffraction is considered, only the wedges of the
        primitives in line-of-sight of the transmitter are considered.

        Input
        ------
        candidates: [max_depth, num_samples], int
            Candidate paths with depth up to ``max_depth``.
            Entries correspond to primitives indices, with -1 used as padding
            value for paths with depth lower than ``max_depth``.
            The first path is the LoS one.

        edge_diffraction : bool
            If set to `False`, only diffraction on wedges, i.e., edges that
            connect two primitives, is considered.

        Output
        -------
        candidate_wedges : [num_candidate_wedges], int
            Candidate wedges.
            Entries correspond to wedges indices.
        """

        # Empty scenes have no candidates and therefore no wedges
        if candidates.shape[0] == 0:
            return tf.constant([], tf.int32)

        # Drop the -1 padding entries
        valid = tf.where(tf.not_equal(candidates, -1))[:,0]
        candidates = tf.gather(candidates, valid)

        # Drop duplicated primitives
        candidates = tf.unique(candidates)[0]

        # Wedges of the candidate primitives
        # [num_samples, 3]
        candidate_wedges = tf.gather(self._primitives_2_wedges, candidates,
                                     axis=0)
        # [num_samples*3]
        candidate_wedges = tf.reshape(candidate_wedges, [-1])

        # Drop the -1 entries
        # [<= num_samples*3]
        valid = tf.where(tf.not_equal(candidate_wedges, -1))[:,0]
        candidate_wedges = tf.gather(candidate_wedges, valid)

        # Drop duplicated wedges
        # [num_candidate_wedges]
        candidate_wedges = tf.unique(candidate_wedges)[0]

        # If requested, discard edges, i.e., keep only wedges connecting
        # two primitives
        if not edge_diffraction:
            # [num_candidate_wedges]
            keep = tf.where(~tf.gather(self._is_edge, candidate_wedges))[:,0]
            # [num_candidate_wedges]
            candidate_wedges = tf.gather(candidate_wedges, keep)

        return candidate_wedges
    928 
    def _build_mi_ris_objects(self):
        r"""
        Builds a Mitsuba scene containing all RIS as rectangles with position,
        orientation, and size matching the RIS properties.

        Output
        ------
        : list(mi.Rectangle)
            List of Mitsuba rectangles implementing the RIS

        : mi.UInt
            RIS indices
        """
        # List of all the RIS objects in the scene
        all_ris = list(self._scene.ris.values())
        num_ris = len(all_ris)

        # Creates a scene containing RIS as rectangles
        mi_ris_objects = []
        # Shape pointers of the RIS rectangles. They are reinterpreted as
        # integer indices once all rectangles have been instantiated.
        ris_indices = dr.zeros(mi.ShapePtr, num_ris)
        for i, ris in enumerate(all_ris):
            center = ris.position
            orientation = ris.orientation
            size = ris.size
            # Transform mapping the unit rectangle to the RIS position,
            # orientation, and size
            mi_to_world = mitsuba_rectangle_to_world(center, orientation, size,
                                                     ris=True)
            ris_rect = mi.load_dict({   "type"     : "rectangle",
                                        "to_world" : mi_to_world
                                    })
            mi_ris_objects.append(ris_rect)
            dr.scatter(ris_indices, mi.ShapePtr(ris_rect), i)
        ris_indices = dr.reinterpret_array_v(mi.UInt32, ris_indices)

        return mi_ris_objects, ris_indices
    963 
    def _ris_intersect(self, ris_objects, ray, active):
        r"""
        Test the intersection with the RIS and return the closest hit

        Input
        ------
        ris_objects : list(mi.Rectangle)
            List of Mitsuba rectangles implementing the RIS

        ray : mi.Ray3f
            Rays for which to test the intersection

        active : mi.Bool
            Mask indicating the active lanes

        Output
        -------
        valid : mi.Bool
            Mask indicating if the intersection is valid

        t : mi.Float
            Position of the intersection on the ray

        indices : mi.UInt32
            Indices of the intersected RIS
        """

        num_rays = dr.shape(ray.d)[1]
        # Running closest hit, initialized to "no intersection"
        t = dr.full(mi.Float, float('inf'), num_rays)
        valid = dr.full(mi.Bool, False, num_rays)
        indices = dr.full(mi.UInt, 0, num_rays)
        # Sequentially test every RIS rectangle and keep, for each lane,
        # the closest intersection found so far
        for rect in ris_objects:
            si = rect.ray_intersect(ray, active=active)
            hit = si.is_valid()
            valid |= hit
            # Lanes for which this RIS is closer than the current best hit
            closer = hit & (si.t < t)
            t = dr.select(closer, si.t, t)
            # The shape pointer is reinterpreted as an integer RIS index
            rect_index = dr.reinterpret_array_v(mi.UInt32, si.shape)
            indices = dr.select(closer, rect_index, indices)

        return valid, t, indices