scene_object.py (16110B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""
Class representing objects in the scene
"""
import tensorflow as tf

from .object import Object
from .radio_material import RadioMaterial
import drjit as dr
import mitsuba as mi
from .utils import (mi_to_tf_tensor, angles_to_mitsuba_rotation, normalize,
                    theta_phi_from_unit_vec)
from sionna.constants import PI


class SceneObject(Object):
    # pylint: disable=line-too-long
    r"""
    SceneObject()

    Every object in the scene is implemented by an instance of this class
    """

    def __init__(self,
                 name,
                 object_id=None,
                 mi_shape=None,
                 radio_material=None,
                 orientation=(0,0,0),
                 dtype=tf.complex64,
                 **kwargs):

        if dtype not in (tf.complex64, tf.complex128):
            raise ValueError("`dtype` must be tf.complex64 or tf.complex128")
        self._dtype = dtype
        self._rdtype = dtype.real_dtype

        # Set initial orientation of the object
        self._orientation = tf.cast(orientation, dtype=self._rdtype)
        # Set velocity vector
        self._velocity = tf.zeros((3,), dtype=self._rdtype)

        # Initialize the base class Object
        super().__init__(name, **kwargs)

        # Set the object id *before* the radio material: the radio material
        # setter passes `self.object_id` to the material, which would raise an
        # AttributeError if `_object_id` had not been assigned yet.
        self.object_id = object_id

        # Set the radio material
        self.radio_material = radio_material

        # Set the Mitsuba shape
        self._mi_shape = mi_shape

        # Mitsuba types matching the requested precision
        if self._dtype == tf.complex64:
            self._mi_point_t = mi.Point3f
            self._mi_vec_t = mi.Vector3f
            self._mi_scalar_t = mi.Float
            self._mi_transform_t = mi.Transform4f
        else:
            self._mi_point_t = mi.Point3d
            self._mi_vec_t = mi.Vector3d
            self._mi_scalar_t = mi.Float64
            self._mi_transform_t = mi.Transform4d

    @property
    def object_id(self):
        r"""
        int : Get/set the identifier of this object
        """
        return self._object_id

    @object_id.setter
    def object_id(self, v):
        self._object_id = v

    @property
    def radio_material(self):
        r"""
        :class:`~sionna.rt.RadioMaterial` : Get/set the radio material of the
        object. Setting can be done by using either an instance of
        :class:`~sionna.rt.RadioMaterial` or the material name (`str`).
        If the radio material is not part of the scene, it will be added. This
        can raise an error if a different radio material with the same name was
        already added to the scene.
        """
        return self._radio_material

    @radio_material.setter
    def radio_material(self, mat):
        # Resolve `mat` into a RadioMaterial instance (or None)
        if mat is None:
            mat_obj = None

        elif isinstance(mat, str):
            mat_obj = self.scene.get(mat)
            # Also rejects None, i.e., an unknown name
            if not isinstance(mat_obj, RadioMaterial):
                err_msg = f"Unknown radio material '{mat}'"
                raise TypeError(err_msg)

        elif not isinstance(mat, RadioMaterial):
            err_msg = ("The material must be a material name (str) or an "
                        "instance of RadioMaterial")
            raise TypeError(err_msg)

        else:
            mat_obj = mat

        # Remove the object from the set of the currently used material, if
        # any. `_radio_material` does not exist yet on the very first call
        # (made from __init__), hence the hasattr() check.
        # pylint: disable=access-member-before-definition
        if hasattr(self, '_radio_material') and self._radio_material:
            self._radio_material.discard_object_using(self.object_id)
        # Assign the new material
        self._radio_material = mat_obj

        # If the radio material is set to None, we can stop here
        if not self._radio_material:
            return

        # Add the object to the set of the newly used material
        self._radio_material.add_object_using(self.object_id)

        # Add the RadioMaterial to the scene if not already done
        self.scene.add(self._radio_material)

    @property
    def velocity(self):
        """
        [3], tf.float : Get/set the velocity vector [m/s]
        """
        return self._velocity

    @velocity.setter
    def velocity(self, v):
        v = tf.cast(v, self._rdtype)
        # Compare the full static shape. Testing `tf.shape(v)==3` produced an
        # ambiguous-truth-value error for multi-dimensional inputs.
        if tuple(v.shape) != (3,):
            raise ValueError("`velocity` must have shape [3]")
        self._velocity = v

    @property
    def position(self):
        """
        [3], tf.float : Get/set the position vector [m] of the center
        of the object. The center is defined as the center of the object's
        axis-aligned bounding box (AABB).
        """
        dr.sync_thread()
        rdtype = self._scene.dtype.real_dtype
        # Query the bounding box once and reuse it for both corners
        bbox = self._mi_shape.bbox()
        # [3]
        bbox_min = tf.cast(bbox.min, rdtype)
        # [3]
        bbox_max = tf.cast(bbox.max, rdtype)
        # Use the same dtype as the corners so that the product cannot fail
        # if the scene's and the object's real dtypes differ
        half = tf.cast(0.5, rdtype)
        # [3]
        return half*(bbox_min + bbox_max)

    @position.setter
    def position(self, new_position):

        ## Update Mitsuba vertices

        # Scene parameters
        scene_params = self._scene.mi_scene_params
        vp_key = self._vertex_params_name(self.name)

        # Real dtype
        rdtype = self._scene.dtype.real_dtype
        new_position = tf.cast(new_position, rdtype)
        # [num_vertices*3]
        vertices = scene_params[vp_key]
        # [num_vertices,3]
        vertices = mi_to_tf_tensor(vertices, rdtype)
        vertices = tf.reshape(vertices, [-1, 3])
        # [3]
        position = self.position
        # [3]
        translation_vector = new_position - position
        # [1,3]
        translation_vector = tf.expand_dims(translation_vector, axis=0)
        # [num_vertices,3]
        translated_vertices = vertices + translation_vector
        # Flatten and cast to the Mitsuba type to update the Mitsuba scene
        # [num_vertices*3]
        fltn_translated_vertices = tf.reshape(translated_vertices, [-1])
        fltn_translated_vertices = self._mi_scalar_t(fltn_translated_vertices)
        scene_params[vp_key] = fltn_translated_vertices
        scene_params.update()

        ## Update Sionna vertices

        self._update_solver_primitives(rdtype)

        ## Update Sionna wedges

        solver_paths = self._scene.solver_paths

        # Indices of the wedges corresponding to this object
        # [num_wedges]
        wedges_ind = self._object_wedge_indices()

        # Corresponding origins
        # [num_wedges, 3]
        wedges_origin = tf.gather(solver_paths.wedges_origin, wedges_ind,
                                  axis=0)

        # Translates the wedges
        # [num_wedges, 3]
        wedges_origin += translation_vector

        # Updates the wedges
        wedges_ind = tf.expand_dims(wedges_ind, axis=1)
        solver_paths.wedges_origin.scatter_nd_update(wedges_ind, wedges_origin)

        # Trigger scene callback
        self._scene.scene_geometry_updated()

    @property
    def orientation(self):
        r"""
        [3], tf.float : Get/set the orientation :math:`(\alpha, \beta, \gamma)`
        [rad] specified through three angles corresponding to a
        3D rotation as defined in :eq:`rotation`.
        """
        return self._orientation

    @orientation.setter
    def orientation(self, new_orient):

        # Real dtype
        new_orient = tf.cast(new_orient, self._rdtype)

        # Build the transformation corresponding to the new rotation
        new_rotation = angles_to_mitsuba_rotation(new_orient)

        # Invert the current orientation
        cur_rotation = angles_to_mitsuba_rotation(self._orientation.numpy())
        inv_cur_rotation = cur_rotation.inverse()

        # Evaluate the center only once: every access to `position` syncs
        # Dr.Jit and queries the bounding box
        center = self.position.numpy()

        # Build the transform.
        # The object is first translated to the origin, the current rotation
        # is undone, the new rotation is applied, and the object is finally
        # translated back to its current position
        transform = ( self._mi_transform_t.translate(center)
                      @ new_rotation
                      @ inv_cur_rotation
                      @ self._mi_transform_t.translate(-center) )

        ## Update Mitsuba vertices

        # Scene parameters
        scene_params = self._scene.mi_scene_params
        vp_key = self._vertex_params_name(self.name)

        # [num_vertices*3]
        vertices = scene_params[vp_key]
        # [num_vertices,3]
        vertices = dr.unravel(self._mi_point_t, vertices)
        # Apply the transform
        vertices = transform.transform_affine(vertices)
        # Flatten and cast to update the Mitsuba scene
        fltn_vertices = tf.reshape(vertices, [-1])
        fltn_vertices = tf.cast(fltn_vertices, tf.float32)
        scene_params[vp_key] = fltn_vertices
        scene_params.update()

        ## Update Sionna vertices

        solver_paths = self._scene.solver_paths
        # Indices of this object's primitives in the solver tensors
        # [n_prims, 1]
        sl = self._update_solver_primitives(self._rdtype)

        ## Update Sionna normals

        # Get the normals of this object's primitives
        # [n_prims, 3]
        normals = solver_paths.normals.gather_nd(sl)
        # Cast to Mitsuba Vector
        # [n_prims, 3]
        normals = self._mi_vec_t(normals)
        # Rotate the normals
        normals = transform.transform_affine(normals)
        # Cast to Tensorflow type
        # [n_prims, 3]
        normals = mi_to_tf_tensor(normals, self._rdtype)
        # Update the tensor storing the primitive normals
        solver_paths.normals.scatter_nd_update(sl, normals)

        ## Update Sionna wedges

        # Indices of the wedges corresponding to this object
        # [num_wedges]
        wedges_ind = self._object_wedge_indices()

        # Corresponding origins, e_hat, and normals
        # [num_wedges, 3]
        wedges_origin = tf.gather(solver_paths.wedges_origin, wedges_ind,
                                  axis=0)
        # [num_wedges, 3]
        wedges_e_hat = tf.gather(solver_paths.wedges_e_hat, wedges_ind, axis=0)
        # [num_wedges, 2, 3]
        wedges_normals = tf.gather(solver_paths.wedges_normals, wedges_ind,
                                   axis=0)
        # [num_wedges*2, 3]
        wedges_normals = tf.reshape(wedges_normals, [-1, 3])

        # Cast to Mitsuba types
        # [num_wedges, 3]
        wedges_origin = self._mi_point_t(wedges_origin)
        # [num_wedges, 3]
        wedges_e_hat = self._mi_vec_t(wedges_e_hat)
        # [num_wedges*2, 3]
        wedges_normals = self._mi_vec_t(wedges_normals)

        # Rotate all quantities
        # [num_wedges, 3]
        wedges_origin = transform.transform_affine(wedges_origin)
        # [num_wedges, 3]
        wedges_e_hat = transform.transform_affine(wedges_e_hat)
        # [num_wedges*2, 3]
        wedges_normals = transform.transform_affine(wedges_normals)

        # Cast to Tensorflow type
        # [num_wedges, 3]
        wedges_origin = mi_to_tf_tensor(wedges_origin, self._rdtype)
        # [num_wedges, 3]
        wedges_e_hat = mi_to_tf_tensor(wedges_e_hat, self._rdtype)
        # [num_wedges*2, 3]
        wedges_normals = mi_to_tf_tensor(wedges_normals, self._rdtype)
        # [num_wedges, 2, 3]
        wedges_normals = tf.reshape(wedges_normals, [-1, 2, 3])

        # Updates the wedges
        wedges_ind = tf.expand_dims(wedges_ind, axis=1)
        solver_paths.wedges_origin.scatter_nd_update(wedges_ind, wedges_origin)
        solver_paths.wedges_e_hat.scatter_nd_update(wedges_ind, wedges_e_hat)
        solver_paths.wedges_normals.scatter_nd_update(wedges_ind,
                                                      wedges_normals)

        self._orientation = new_orient

        # Trigger scene callback
        self._scene.scene_geometry_updated()

    def look_at(self, target):
        # pylint: disable=line-too-long
        r"""
        Sets the orientation so that the x-axis points toward a position or
        an ``Object``.

        Input
        -----
        target : [3], float | :class:`sionna.rt.Object` | str
            A position or the name or instance of an
            :class:`sionna.rt.Object` in the scene to point toward to
        """
        # Get position to look at
        if isinstance(target, str):
            obj = self.scene.get(target)
            if not isinstance(obj, Object):
                msg = f"No camera, device, or object named '{target}' found."
                raise ValueError(msg)
            target = obj.position
        elif isinstance(target, Object):
            target = target.position
        else:
            target = tf.cast(target, dtype=self._rdtype)
            # Check the full static shape; indexing `shape[0]` raised an
            # IndexError for scalars and accepted e.g. [3, 2] inputs
            if tuple(target.shape) != (3,):
                raise ValueError("`target` must be a three-element vector")

        # Compute angles relative to LCS
        x = target - self.position
        x, _ = normalize(x)
        theta, phi = theta_phi_from_unit_vec(x)
        alpha = phi # Rotation around z-axis
        beta = theta-PI/2 # Rotation around y-axis
        gamma = 0.0 # Rotation around x-axis
        self.orientation = (alpha, beta, gamma)

    def _vertex_params_name(self, mesh_id, scene_params=None):
        """
        Returns the key of the vertex positions of mesh `mesh_id` in the
        Mitsuba scene parameters.

        Since any `mesh-` prefix was removed from `self.name`, we may need
        to add it back here before trying to access the vertex positions
        variable in the scene parameters object. The prefixed key is used
        if present, otherwise the unprefixed one is returned.
        """
        if scene_params is None:
            scene_params = self._scene.mi_scene_params

        key = mesh_id + ".vertex_positions"
        with_prefix = 'mesh-' + key
        if with_prefix in scene_params:
            return with_prefix
        return key

    def _update_solver_primitives(self, rdtype):
        """
        Copies the current vertex positions of this object's Mitsuba mesh
        into the solver's primitive tensor.

        Input
        -----
        rdtype : tf.DType
            Real dtype used to convert the Mitsuba vertices to TensorFlow

        Output
        ------
        sl : [n_prims, 1], tf.int32
            Indices of this object's primitives in the solver tensors
        """
        mi_shape = self._mi_shape
        solver_paths = self._scene.solver_paths
        num_faces = mi_shape.face_count()

        prim_offset = solver_paths.prim_offsets[self.object_id]

        face_indices3 = mi_shape.face_indices(dr.arange(mi.UInt32, num_faces))
        # Flatten. This is required for calling vertex_position
        # [n_prims*3]
        face_indices = dr.ravel(face_indices3)
        # Get vertices coordinates
        # [n_prims*3, 3]
        vertex_coords = mi_shape.vertex_position(face_indices)
        # Cast to TensorFlow type
        # [n_prims*3, 3]
        vertex_coords = mi_to_tf_tensor(vertex_coords, rdtype)
        # Unflatten
        # [n_prims, vertices per triangle : 3, 3]
        vertex_coords = tf.reshape(vertex_coords, [num_faces, 3, 3])
        # Update the tensor storing the primitive vertices
        sl = tf.range(prim_offset, prim_offset + num_faces, dtype=tf.int32)
        sl = tf.expand_dims(sl, axis=1)
        solver_paths.primitives.scatter_nd_update(sl, vertex_coords)
        return sl

    def _object_wedge_indices(self):
        """
        Returns the unique indices of the solver wedges that reference this
        object, i.e., the rows of `wedges_objects` containing `object_id`.

        Output
        ------
        : [num_wedges], tf.int64
            Wedge indices
        """
        wedges_objects = self._scene.solver_paths.wedges_objects
        wedges_ind, _ = tf.unique(
            tf.where(wedges_objects == self.object_id)[:,0])
        return wedges_ind