anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

scene_object.py (16110B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """
      6 Class representing objects in the scene
      7 """
      8 import tensorflow as tf
      9 
     10 from .object import Object
     11 from .radio_material import RadioMaterial
     12 import drjit as dr
     13 import mitsuba as mi
     14 from .utils import mi_to_tf_tensor, angles_to_mitsuba_rotation, normalize,\
     15     theta_phi_from_unit_vec
     16 from sionna.constants import PI
     17 
     18 
     19 class SceneObject(Object):
     20     # pylint: disable=line-too-long
     21     r"""
     22     SceneObject()
     23 
     24     Every object in the scene is implemented by an instance of this class
     25     """
     26 
     27     def __init__(self,
     28                  name,
     29                  object_id=None,
     30                  mi_shape=None,
     31                  radio_material=None,
     32                  orientation=(0,0,0),
     33                  dtype=tf.complex64,
     34                  **kwargs):
     35 
     36         if dtype not in (tf.complex64, tf.complex128):
     37             raise ValueError("`dtype` must be tf.complex64 or tf.complex128`")
     38         self._dtype = dtype
     39         self._rdtype = dtype.real_dtype
     40 
     41         # Set initial orientation of the object
     42         self._orientation = tf.cast(orientation, dtype=self._rdtype)
     43         # Set velocity vector
     44         self._velocity = tf.zeros((3,), dtype=self._rdtype)
     45 
     46         # Initialize the base class Object
     47         super().__init__(name, **kwargs)
     48 
     49         # Set the radio material
     50         self.radio_material = radio_material
     51 
     52         # Set the object id
     53         self.object_id = object_id
     54 
     55         # Set the Mitsuba shape
     56         self._mi_shape = mi_shape
     57 
     58 
     59         if self._dtype == tf.complex64:
     60             self._mi_point_t = mi.Point3f
     61             self._mi_vec_t = mi.Vector3f
     62             self._mi_scalar_t = mi.Float
     63             self._mi_transform_t = mi.Transform4f
     64         else:
     65             self._mi_point_t = mi.Point3d
     66             self._mi_vec_t = mi.Vector3d
     67             self._mi_scalar_t = mi.Float64
     68             self._mi_transform_t = mi.Transform4d
     69 
     70     @property
     71     def object_id(self):
     72         r"""
     73         int : Get/set the identifier of this object
     74         """
     75         return self._object_id
     76 
     77     @object_id.setter
     78     def object_id(self, v):
     79         self._object_id = v
     80 
     81     @property
     82     def radio_material(self):
     83         r"""
     84         :class:`~sionna.rt.RadioMaterial` : Get/set the radio material of the
     85         object. Setting can be done by using either an instance of
     86         :class:`~sionna.rt.RadioMaterial` or the material name (`str`).
     87         If the radio material is not part of the scene, it will be added. This
     88         can raise an error if a different radio material with the same name was
     89         already added to the scene.
     90         """
     91         return self._radio_material
     92 
     93     @radio_material.setter
     94     def radio_material(self, mat):
     95         # Note: _radio_material is set at __init__, but pylint doesn't see it.
     96         if mat is None:
     97             mat_obj = None
     98 
     99         elif isinstance(mat, str):
    100             mat_obj = self.scene.get(mat)
    101             if (mat_obj is None) or (not isinstance(mat_obj, RadioMaterial)):
    102                 err_msg = f"Unknown radio material '{mat}'"
    103                 raise TypeError(err_msg)
    104 
    105         elif not isinstance(mat, RadioMaterial):
    106             err_msg = ("The material must be a material name (str) or an "
    107                         "instance of RadioMaterial")
    108             raise TypeError(err_msg)
    109 
    110         else:
    111             mat_obj = mat
    112 
    113         # Remove the object from the set of the currently used material, if any
    114         # pylint: disable=access-member-before-definition
    115         if hasattr(self, '_radio_material') and self._radio_material:
    116             self._radio_material.discard_object_using(self.object_id)
    117         # Assign the new material
    118         # pylint: disable=access-member-before-definition
    119         self._radio_material = mat_obj
    120 
    121         # If the radio material is set to None, we can stop here
    122         # pylint: disable=access-member-before-definition
    123         if not self._radio_material:
    124             return
    125 
    126         # Add the object to the set of the newly used material
    127         # pylint: disable=access-member-before-definition
    128         self._radio_material.add_object_using(self.object_id)
    129 
    130         # Add the RadioMaterial to the scene if not already done
    131         self.scene.add(self._radio_material)
    132 
    @property
    def velocity(self):
        """
        [3], tf.float : Get/set the velocity vector [m/s] of the object
        """
        # Initialized to zeros in __init__ and replaced by the setter
        return self._velocity
    139 
    @velocity.setter
    def velocity(self, v):
        # Convert first so that lists, tuples, arrays, and tensors are all
        # validated through the same static shape
        v = tf.cast(v, self._rdtype)
        # Checking the full shape (rather than comparing `tf.shape(v)` to 3)
        # also rejects scalars and higher-rank inputs such as [3,3] with a
        # clear error message
        if not v.shape.is_compatible_with([3]):
            raise ValueError("`velocity` must have shape [3]")
        self._velocity = v
    145 
    @property
    def position(self):
        """
        [3], tf.float : Get/set the position vector [m] of the center
            of the object. The center is defined as the center of the
            object's axis-aligned bounding box (AABB).
        """
        # Synchronize with Dr.Jit before reading the bounding box
        # (presumably so that pending vertex updates are accounted for)
        dr.sync_thread()
        rdtype = self._scene.dtype.real_dtype
        # Bounding box, queried once and used for both corners
        bbox = self._mi_shape.bbox()
        # [3]
        bbox_min = tf.cast(bbox.min, rdtype)
        # [3]
        bbox_max = tf.cast(bbox.max, rdtype)
        # Use the same real dtype as the corners to avoid a dtype mismatch
        # in the product below
        half = tf.cast(0.5, rdtype)
        # Midpoint of the two corners
        # [3]
        position = half*(bbox_min + bbox_max)
        return position
    164 
    @position.setter
    def position(self, new_position):
        # Translates the object such that the center of its bounding box
        # ends up at `new_position`. The Mitsuba vertices are moved first,
        # then the primitives and wedges stored by the paths solver are
        # brought in sync, and finally the scene is notified.

        scene = self._scene
        rdtype = scene.dtype.real_dtype

        ## Update Mitsuba vertices

        params = scene.mi_scene_params
        key = self._vertex_params_name(self.name)

        # Requested center
        # [3]
        target = tf.cast(new_position, rdtype)
        # Current vertices, stored flattened in the scene parameters
        # [num_vertices, 3]
        verts = tf.reshape(mi_to_tf_tensor(params[key], rdtype), [-1, 3])
        # Displacement from the current center to the requested one
        # [1, 3]
        shift = tf.expand_dims(target - self.position, axis=0)
        # Displaced vertices, flattened and cast to the Mitsuba scalar type
        # [num_vertices*3]
        params[key] = self._mi_scalar_t(tf.reshape(verts + shift, [-1]))
        params.update()

        ## Update Sionna vertices

        oid = self.object_id
        shape = self._mi_shape
        paths = scene.solver_paths
        n_faces = shape.face_count()

        # Index of the first primitive of this object
        first_prim = paths.prim_offsets[oid]

        # Vertex indices of all faces, flattened as required for calling
        # vertex_position
        # [n_prims*3]
        vertex_ids = dr.ravel(shape.face_indices(dr.arange(mi.UInt32,
                                                           n_faces)))
        # Coordinates of these vertices as a TensorFlow tensor
        # [n_prims*3, 3]
        tri = mi_to_tf_tensor(shape.vertex_position(vertex_ids), rdtype)
        # [n_prims, vertices per triangle : 3, 3]
        tri = tf.reshape(tri, [n_faces, 3, 3])
        # Rows of the primitives tensor that belong to this object
        # [n_prims, 1]
        prim_rows = tf.expand_dims(
            tf.range(first_prim, first_prim + n_faces, dtype=tf.int32),
            axis=1)
        paths.primitives.scatter_nd_update(prim_rows, tri)

        ## Update Sionna wedges

        # Indices of the wedges corresponding to this object
        # [num_wedges]
        wedge_rows, _ = tf.unique(tf.where(paths.wedges_objects == oid)[:,0])
        # Origins of these wedges, translated like the vertices
        # [num_wedges, 3]
        moved_origins = tf.gather(paths.wedges_origin, wedge_rows,
                                  axis=0) + shift
        paths.wedges_origin.scatter_nd_update(
            tf.expand_dims(wedge_rows, axis=1), moved_origins)

        # Trigger scene callback
        scene.scene_geometry_updated()
    249 
    @property
    def orientation(self):
        r"""
        [3], tf.float : Get/set the orientation :math:`(\alpha, \beta, \gamma)`
            [rad], given as the three angles of a 3D rotation
            as defined in :eq:`rotation`.
        """
        # Last orientation stored by __init__ or by the setter
        return self._orientation
    258 
    @orientation.setter
    def orientation(self, new_orient):
        # Rotates the object around its current position such that its
        # orientation becomes `new_orient`. The Mitsuba vertices are updated
        # first, then the primitives, normals, and wedges stored by the
        # paths solver are brought in sync, and finally the scene is
        # notified.

        # Real dtype
        new_orient = tf.cast(new_orient, self._rdtype)

        # Build the transformation corresponding to the new rotation
        new_rotation = angles_to_mitsuba_rotation(new_orient)

        # Invert the current orientation
        cur_rotation = angles_to_mitsuba_rotation(self._orientation.numpy())
        inv_cur_rotation = cur_rotation.inverse()

        # Current position, read once and used for both translations
        position = self.position.numpy()

        # Build the transform.
        # The object is first translated to the origin, then rotated, then
        # translated back to its current position
        transform =  (  self._mi_transform_t.translate(position)
                      @ new_rotation
                      @ inv_cur_rotation
                      @ self._mi_transform_t.translate(-position) )

        ## Update Mitsuba vertices

        # Scene parameters
        scene_params = self._scene.mi_scene_params
        vp_key = self._vertex_params_name(self.name)

        # [num_vertices*3]
        vertices = scene_params[vp_key]
        # [num_vertices,3]
        vertices = dr.unravel(self._mi_point_t, vertices)
        # Apply the transform
        vertices = transform.transform_affine(vertices)
        # Flatten back to a Dr.Jit array. `dr.ravel` is the counterpart of
        # the `dr.unravel` above, so the layout and the precision of the
        # Mitsuba type are preserved
        # [num_vertices*3]
        fltn_vertices = dr.ravel(vertices)
        scene_params[vp_key] = fltn_vertices
        scene_params.update()

        ## Update Sionna vertices

        obj_id = self.object_id
        mi_shape = self._mi_shape
        solver_paths = self._scene.solver_paths

        # Index of the first primitive of this object
        prim_offset = solver_paths.prim_offsets[obj_id]

        face_indices3 = mi_shape.face_indices(dr.arange(mi.UInt32,
                                                        mi_shape.face_count()))
        # Flatten. This is required for calling vertex_position
        # [n_prims*3]
        face_indices = dr.ravel(face_indices3)
        # Get vertices coordinates
        # [n_prims*3, 3]
        vertex_coords = mi_shape.vertex_position(face_indices)
        # Cast to TensorFlow type
        # [n_prims*3, 3]
        vertex_coords = mi_to_tf_tensor(vertex_coords, self._rdtype)
        # Unflatten
        # [n_prims, vertices per triangle : 3, 3]
        vertex_coords = tf.reshape(vertex_coords, [mi_shape.face_count(), 3, 3])
        # Update the tensor storing the primitive vertices
        # [n_prims, 1]
        sl = tf.range(prim_offset, prim_offset + mi_shape.face_count(),
                    dtype=tf.int32)
        sl = tf.expand_dims(sl, axis=1)
        solver_paths.primitives.scatter_nd_update(sl, vertex_coords)

        ## Update Sionna normals

        # Get the normals of the primitives of this object
        # [n_prims, 3]
        normals = solver_paths.normals.gather_nd(sl)
        # Cast to Mitsuba Vector
        # [n_prims, 3]
        normals = self._mi_vec_t(normals)
        # Rotate the normals
        normals = transform.transform_affine(normals)
        # Cast to Tensorflow type
        # [n_prims, 3]
        normals = mi_to_tf_tensor(normals, self._rdtype)
        # Update the tensor storing the primitive normals
        solver_paths.normals.scatter_nd_update(sl, normals)

        ## Update Sionna wedges

        wedges_objects = solver_paths.wedges_objects
        wedges_origin = solver_paths.wedges_origin
        wedges_e_hat = solver_paths.wedges_e_hat
        wedges_normals = solver_paths.wedges_normals

        # Indices of the wedges corresponding to this object
        # [num_wedges]
        wedges_ind, _ = tf.unique(tf.where(wedges_objects == obj_id)[:,0])

        # Corresponding origins, e_hat, and normals
        # [num_wedges, 3]
        wedges_origin = tf.gather(wedges_origin, wedges_ind, axis=0)
        # [num_wedges, 3]
        wedges_e_hat = tf.gather(wedges_e_hat, wedges_ind, axis=0)
        # [num_wedges, 2, 3]
        wedges_normals = tf.gather(wedges_normals, wedges_ind, axis=0)
        # [num_wedges*2, 3]
        wedges_normals = tf.reshape(wedges_normals, [-1, 3])

        # Cast to Mitsuba types
        # [num_wedges, 3]
        wedges_origin = self._mi_point_t(wedges_origin)
        # [num_wedges, 3]
        wedges_e_hat = self._mi_vec_t(wedges_e_hat)
        # [num_wedges*2, 3]
        wedges_normals = self._mi_vec_t(wedges_normals)

        # Transform all quantities
        # [num_wedges, 3]
        wedges_origin = transform.transform_affine(wedges_origin)
        # [num_wedges, 3]
        wedges_e_hat = transform.transform_affine(wedges_e_hat)
        # [num_wedges*2, 3]
        wedges_normals = transform.transform_affine(wedges_normals)

        # Cast to Tensorflow type
        # [num_wedges, 3]
        wedges_origin = mi_to_tf_tensor(wedges_origin, self._rdtype)
        # [num_wedges, 3]
        wedges_e_hat = mi_to_tf_tensor(wedges_e_hat, self._rdtype)
        # [num_wedges*2, 3]
        wedges_normals = mi_to_tf_tensor(wedges_normals, self._rdtype)
        # [num_wedges, 2, 3]
        wedges_normals = tf.reshape(wedges_normals, [-1, 2, 3])

        # Updates the wedges
        wedges_ind = tf.expand_dims(wedges_ind, axis=1)
        solver_paths.wedges_origin.scatter_nd_update(wedges_ind, wedges_origin)
        solver_paths.wedges_e_hat.scatter_nd_update(wedges_ind, wedges_e_hat)
        solver_paths.wedges_normals.scatter_nd_update(wedges_ind,
                                                      wedges_normals)

        # Only store the new orientation once all updates succeeded
        self._orientation = new_orient

        # Trigger scene callback
        self._scene.scene_geometry_updated()
    400 
    def look_at(self, target):
        # pylint: disable=line-too-long
        r"""
        Sets the orientation so that the x-axis points toward a
        position or an ``Object``.

        Input
        -----
        target : [3], float | :class:`sionna.rt.Object` | str
            A position or the name or instance of an
            :class:`sionna.rt.Object` in the scene to point toward to

        Raises
        ------
        ValueError
            If ``target`` is a name that does not refer to an
            :class:`sionna.rt.Object` of the scene, or if ``target`` is a
            position that does not have shape [3].
        """
        # Get position to look at
        if isinstance(target, str):
            obj = self.scene.get(target)
            if not isinstance(obj, Object):
                msg = f"No camera, device, or object named '{target}' found."
                raise ValueError(msg)
            target = obj.position
        elif isinstance(target, Object):
            target = target.position
        else:
            target = tf.cast(target, dtype=self._rdtype)
            # Check the full shape, so that scalars and higher-rank inputs
            # such as [3,k] are rejected as well
            if not target.shape.is_compatible_with([3]):
                raise ValueError("`target` must be a three-element vector)")

        # Compute angles relative to LCS
        x = target - self.position
        x, _ = normalize(x)
        theta, phi = theta_phi_from_unit_vec(x)
        alpha = phi # Rotation around z-axis
        beta = theta-PI/2 # Rotation around y-axis
        gamma = 0.0 # Rotation around x-axis
        self.orientation = (alpha, beta, gamma)
    436 
    def _vertex_params_name(self, mesh_id, scene_params=None):
        """
        Returns the key under which the vertex positions of the mesh
        ``mesh_id`` are stored in the scene parameters.

        As any `mesh-` prefix was stripped from `self.name`, the prefixed
        key is tried first and returned if the scene parameters contain it.
        Otherwise, the key without prefix is returned. If ``scene_params``
        is `None`, the parameters of the scene of this object are used.
        """
        params = (self._scene.mi_scene_params if scene_params is None
                  else scene_params)

        plain_key = mesh_id + ".vertex_positions"
        prefixed_key = 'mesh-' + plain_key
        return prefixed_key if prefixed_key in params else plain_key
    450         return key