solver_base.py (38948B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""
Base class of the ray tracing solvers. It extracts from the scene the geometry
(triangles, normals, and wedges) that the solvers share.
"""

import mitsuba as mi
import drjit as dr
import tensorflow as tf

from .utils import normalize, dot, theta_phi_from_unit_vec, cross,\
    mi_to_tf_tensor, mitsuba_rectangle_to_world
from sionna import PI


class SolverBase:
    # pylint: disable=line-too-long
    r"""SolverBase(scene, solver=None, dtype=tf.complex64)

    Base class for implementing a solver. If another ``solver`` is specified at
    instantiation, then it re-uses its structures to avoid useless compute and
    memory use.

    Note: Only triangle meshes are supported.

    Parameters
    -----------
    scene : :class:`~sionna.rt.Scene`
        Sionna RT scene

    solver : :class:`~sionna.rt.SolverBase` | None
        Another solver from which to re-use some structures to avoid useless
        compute and memory use

    dtype : tf.complex64 | tf.complex128
        Datatype for all computations, inputs, and outputs.
        Defaults to `tf.complex64`.
    """

    # Small value used to discard intersection with edges, avoid
    # self-intersection, etc.
    # Resolution of float32 1e-6:
    # np.finfo(np.float32) -> resolution=1e-6
    # We add on top a 10x factor for caution
    EPSILON = 1e-5

    # Threshold for extracting wedges from the scene [rad]
    WEDGES_ANGLE_THRESHOLD = 1.*PI/180.

    # Small value used to avoid false positive when testing for obstruction
    # Resolution of float32 1e-6:
    # np.finfo(np.float32) -> resolution=1e-6
    # We add on top a 100x factor for caution
    EPSILON_OBSTRUCTION = 1e-4

    def __init__(self, scene, solver=None, dtype=tf.complex64):

        # Computes the quantities required for generating the paths.
        # More precisely:

        # _primitives : [num triangles, 3, 3], float
        #     The triangles: x-y-z coordinates of the 3 vertices for every
        #     triangle

        # _normals : [num triangles, 3], float
        #     The normals of the triangles

        # _primitives_2_objects : [num_triangles], int
        #     Index of the shape containing the triangle

        # _prim_offsets : [num_objects], mi.Int32
        #     Indices offsets for accessing the triangles making each shape.

        # _shape_indices : [max shape pointer + 1], mi.Int32
        #     Maps Mitsuba shape pointers (reinterpreted as integers) to
        #     indices that can be used to access _prim_offsets

        assert dtype in (tf.complex64, tf.complex128),\
            "`dtype` must be tf.complex64 or tf.complex128"
        self._dtype = dtype
        self._rdtype = dtype.real_dtype

        # Mitsuba types depend on the used precision
        if dtype == tf.complex64:
            self._mi_point_t = mi.Point3f
            self._mi_point2_t = mi.Point2f
            self._mi_vec_t = mi.Vector3f
            self._mi_scalar_t = mi.Float
            self._mi_tensor_t = mi.TensorXf
        else:
            self._mi_point_t = mi.Point3d
            self._mi_point2_t = mi.Point2d
            self._mi_vec_t = mi.Vector3d
            self._mi_scalar_t = mi.Float64
            self._mi_tensor_t = mi.TensorXd

        self._scene = scene
        mi_scene = scene.mi_scene
        self._mi_scene = mi_scene

        # If a solver is provided, then link to the same structures to avoid
        # useless compute and memory use
        if solver is not None:
            self._primitives = solver._primitives
            self._normals = solver._normals
            self._primitives_2_objects = solver._primitives_2_objects
            self._prim_offsets = solver._prim_offsets
            self._shape_indices = solver._shape_indices
            #
            self._wedges_origin = solver._wedges_origin
            self._wedges_e_hat = solver._wedges_e_hat
            self._wedges_length = solver._wedges_length
            self._wedges_normals = solver._wedges_normals
            self._primitives_2_wedges = solver._primitives_2_wedges
            self._wedges_objects = solver._wedges_objects
            self._is_edge = solver._is_edge
            return

        ###################################################
        # Extract triangles, their normals, and a
        # look-up-table to map primitives to the scene
        # object they belong to.
        ###################################################

        # Tensor mapping primitives to corresponding objects
        # [num_triangles]
        primitives_2_objects = []

        # Number of triangles
        n_prims = 0
        # Triangles of each object (shape) in the scene are stacked.
        # This list tracks the indices offsets for accessing the triangles
        # making each shape.
        prim_offsets = []
        mi_shapes = scene.mi_shapes
        for i,s in enumerate(mi_shapes):
            if not isinstance(s, mi.Mesh):
                raise ValueError('Only triangle meshes are supported')
            prim_offsets.append(n_prims)
            n_prims += s.face_count()
            primitives_2_objects += [i]*s.face_count()
        # [num_objects]
        prim_offsets = tf.cast(prim_offsets, tf.int32)

        # Tensor of triangles vertices
        # [n_prims, number of vertices : 3, coordinates : 3]
        prims = tf.zeros([n_prims, 3, 3], self._rdtype)
        # Normals to the triangles
        # [n_prims, 3]
        normals = tf.zeros([n_prims, 3], self._rdtype)
        # Loop through the objects in the scene
        for prim_offset, s in zip(prim_offsets, mi_shapes):
            # Extract the vertices of the shape.
            # Dr.JIT/Mitsuba is used here.
            # Indices of the vertices
            # [face_count, num of vertices per triangle : 3]
            face_indices3 = s.face_indices(dr.arange(mi.UInt32, s.face_count()))
            # Flatten. This is required for calling vertex_position
            # [face_count*3]
            face_indices = dr.ravel(face_indices3)
            # Get vertices coordinates
            # [face_count*3, 3]
            vertex_coords = s.vertex_position(face_indices)
            # Move to TensorFlow
            # [face_count*3, 3]
            vertex_coords = mi_to_tf_tensor(vertex_coords, self._rdtype)
            # Unflatten
            # [face_count, vertices per triangle : 3, 3]
            vertex_coords = tf.reshape(vertex_coords, [s.face_count(), 3, 3])
            # Update the `prims` tensor at the rows owned by this shape
            sl = tf.range(prim_offset, prim_offset + s.face_count(),
                          dtype=tf.int32)
            sl = tf.expand_dims(sl, axis=1)
            prims = tf.tensor_scatter_nd_update(prims, sl, vertex_coords)
            # Compute the normals to the triangles
            # Coordinate of the first vertices of every triangle making the
            # shape
            # [face_count, xyz : 3]
            v0 = s.vertex_position(face_indices3.x)
            # Coordinate of the second vertices of every triangle making the
            # shape
            # [face_count, xyz : 3]
            v1 = s.vertex_position(face_indices3.y)
            # Coordinate of the third vertices of every triangle making the
            # shape
            # [face_count, xyz : 3]
            v2 = s.vertex_position(face_indices3.z)
            # Compute the normals
            # [face_count, xyz : 3]
            mi_n = dr.normalize(dr.cross(
                v1 - v0,
                v2 - v0,
            ))
            # Move to TensorFlow
            # [face_count, 3]
            n = mi_to_tf_tensor(mi_n, self._rdtype)
            # Update the 'normals' tensor
            normals = tf.tensor_scatter_nd_update(normals, sl, n)

        self._primitives = tf.Variable(prims, trainable=False)
        self._normals = tf.Variable(normals, trainable=False)
        primitives_2_objects = tf.cast(primitives_2_objects, tf.int32)
        self._primitives_2_objects = tf.Variable(primitives_2_objects,
                                                 trainable=False)

        ####################################################
        # Used by the shoot & bounce method to map from
        # (shape, local primitive index) to the
        # corresponding global primitive index.
        ####################################################

        # [num_objects]
        self._prim_offsets = mi.Int32(prim_offsets.numpy())
        mi_shapes_ptr = [mi.ShapePtr(s) for s in mi_shapes]
        # Shape pointers reinterpreted as 32-bit unsigned integers
        ptr_as_int = [dr.reinterpret_array_v(mi.UInt32, p)[0] for p in mi_shapes_ptr]
        ptr_as_int = mi.UInt(ptr_as_int)
        if dr.width(ptr_as_int) == 0:
            self._shape_indices = mi.Int32([])
        else:
            # Look-up table indexed by the shape pointer value. Entries that
            # do not correspond to any shape keep the value -1.
            # [max shape pointer + 1]
            shape_indices = dr.full(mi.Int32, -1, dr.max(ptr_as_int)[0] + 1)
            dr.scatter(shape_indices, dr.arange(mi.Int32, 0,
                        dr.width(ptr_as_int)), ptr_as_int)
            dr.eval(shape_indices)
            # [max shape pointer + 1]
            self._shape_indices = shape_indices

        #################################################
        # Extract the wedges
        #################################################
        # _wedges_origin : [num_wedges, 3], float
        #   Starting point of the wedges

        # _wedges_e_hat : [num_wedges, 3], float
        #   Normalized edge vector

        # _wedges_length : [num_wedges], float
        #   Length of the wedges

        # _wedges_normals : [num_wedges, 2, 3], float
        #   Normals to the wedges sides

        # _primitives_2_wedges : [num_primitives, 3], int
        #   Maps primitives to their wedges

        # _wedges_objects : [num_wedges, 2], int
        #   Indices of the two objects making the wedge (the two sides of the
        #   wedge could belong to different objects)

        # is_edge : [num_wedges], bool
        #   Set to `True` if a wedge is an edge, i.e., the edge of a single
        #   primitive.

        edges = self._extract_wedges()
        self._wedges_origin = tf.Variable(edges[0], trainable=False)
        self._wedges_e_hat = tf.Variable(edges[1], trainable=False)
        self._wedges_length = tf.Variable(edges[2], trainable=False)
        self._wedges_normals = tf.Variable(edges[3], trainable=False)
        self._primitives_2_wedges = tf.Variable(edges[4], trainable=False)
        self._wedges_objects = tf.Variable(edges[5], trainable=False)
        self._is_edge = tf.Variable(edges[6], trainable=False)

    ##################################################################
    # Internal utility methods
    ##################################################################

    @property
    def primitives(self):
        """
        [num triangles, 3, 3], tf.float : The triangles: [x,y,z] coordinates of
        the 3 vertices of every triangle
        """
        return self._primitives

    @property
    def normals(self):
        """
        [num triangles, 3], tf.float : The normals of the triangles
        """
        return self._normals

    @property
    def prim_offsets(self):
        """
        [num_objects], mi.Int32 : Indices offsets for accessing the triangles
        making each shape in `primitives`
        """
        return self._prim_offsets

    @property
    def shape_indices(self):
        """
        [max shape pointer + 1], mi.Int32 : Maps Mitsuba shape pointers
        (reinterpreted as ``mi.UInt32``) to indices that can be used to
        access `prim_offsets`. Entries not corresponding to a shape are -1.
        """
        return self._shape_indices

    @property
    def wedges_origin(self):
        """
        [num_wedges, 3], tf.float : Origin of the wedges
        """
        return self._wedges_origin

    @property
    def wedges_e_hat(self):
        """
        [num_wedges, 3], tf.float : Normalized edge vector
        """
        return self._wedges_e_hat

    @property
    def wedges_length(self):
        """
        [num_wedges], tf.float : Length of the wedges
        """
        return self._wedges_length

    @property
    def wedges_normals(self):
        """
        [num_wedges, 2, 3], tf.float : Normals to the wedges sides
        """
        return self._wedges_normals

    @property
    def primitives_2_wedges(self):
        """
        [num_primitives, 3], tf.int : Maps primitives to their wedges
        """
        return self._primitives_2_wedges

    @property
    def wedges_objects(self):
        """
        [num_wedges, 2], tf.int : Indices of the two objects making the
        wedge (the two sides of the wedge could belong to different objects)
        """
        return self._wedges_objects

    @property
    def is_edge(self):
        """
        [num_wedges], tf.bool : Set to `True` if a wedge is an edge, i.e.,
        the edge of a single primitive.
        """
        return self._is_edge

    def _build_scene_object_properties_tensors(self):
        r"""
        Build tensors containing the shape properties

        The tensors are indexed by object id. Their size ``num_shape`` is
        the largest object id plus one plus the number of RIS.

        Input
        ------
        None

        Output
        -------
        relative_permittivity : [num_shape] or [0], tf.complex
            Tensor containing the complex relative permittivities of all
            shapes. Empty if a radio material callable is set.

        scattering_coefficient : [num_shape] or [0], tf.float
            Tensor containing the scattering coefficients of all shapes.
            Empty if a radio material callable is set.

        xpd_coefficient : [num_shape] or [0], tf.float
            Tensor containing the cross-polarization discrimination
            coefficients of all shapes. Empty if a radio material callable
            is set.

        alpha_r : [num_shape] or [0], tf.int32
            Tensor containing the alpha_r scattering parameters of all shapes.
            Empty if a scattering pattern callable is set.

        alpha_i : [num_shape] or [0], tf.int32
            Tensor containing the alpha_i scattering parameters of all shapes.
            Empty if a scattering pattern callable is set.

        lambda_ : [num_shape] or [0], tf.float
            Tensor containing the lambda_ scattering parameters of all shapes.
            Empty if a scattering pattern callable is set.

        velocity : [num_shape, 3], tf.float
            Tensor containing the velocity vectors of all shapes
        """

        # Compute the size of the tensors that store the properties of all
        # objects and RIS
        objects_id = [obj.object_id for obj in self._scene.objects.values()]
        max_id = 0
        if len(objects_id) > 0:
            max_id = tf.reduce_max(objects_id)
        array_size = max_id + 1 + len(self._scene.ris)

        # If a callable is set to obtain radio material properties, then there
        # is no need to build the tensors of material properties
        rm_callable_set = self._scene.radio_material_callable is not None

        # If a callable is set to obtain scattering patterns, then there
        # is no need to build the tensors for scattering properties
        sp_callable_set = self._scene.scattering_pattern_callable is not None

        if rm_callable_set:
            relative_permittivity = tf.zeros([0], self._dtype)
            scattering_coefficient = tf.zeros([0], self._rdtype)
            xpd_coefficient = tf.zeros([0], self._rdtype)
        else:
            relative_permittivity = tf.zeros([array_size], self._dtype)
            scattering_coefficient = tf.zeros([array_size], self._rdtype)
            xpd_coefficient = tf.zeros([array_size], self._rdtype)

        if sp_callable_set:
            alpha_r = tf.zeros([0], tf.int32)
            alpha_i = tf.zeros([0], tf.int32)
            lambda_ = tf.zeros([0], self._rdtype)
        else:
            alpha_r = tf.zeros([array_size], tf.int32)
            alpha_i = tf.zeros([array_size], tf.int32)
            lambda_ = tf.zeros([array_size], self._rdtype)

        if (not sp_callable_set) or (not rm_callable_set):

            for rm in self._scene.radio_materials.values():
                using_objects = rm.using_objects
                num_using_objects = tf.shape(using_objects)[0]
                if num_using_objects == 0:
                    continue

                if not rm_callable_set:
                    relative_permittivity = tf.tensor_scatter_nd_update(
                        relative_permittivity,
                        tf.reshape(using_objects, [-1,1]),
                        tf.fill([num_using_objects],
                                rm.complex_relative_permittivity))

                    scattering_coefficient = tf.tensor_scatter_nd_update(
                        scattering_coefficient,
                        tf.reshape(using_objects, [-1,1]),
                        tf.fill([num_using_objects], rm.scattering_coefficient))

                    xpd_coefficient = tf.tensor_scatter_nd_update(
                        xpd_coefficient,
                        tf.reshape(using_objects, [-1,1]),
                        tf.fill([num_using_objects], rm.xpd_coefficient))

                if not sp_callable_set:
                    alpha_r = tf.tensor_scatter_nd_update(
                        alpha_r,
                        tf.reshape(using_objects, [-1,1]),
                        tf.fill([num_using_objects],
                                rm.scattering_pattern.alpha_r))

                    alpha_i = tf.tensor_scatter_nd_update(
                        alpha_i,
                        tf.reshape(using_objects, [-1,1]),
                        tf.fill([num_using_objects],
                                rm.scattering_pattern.alpha_i))

                    lambda_ = tf.tensor_scatter_nd_update(
                        lambda_,
                        tf.reshape(using_objects, [-1,1]),
                        tf.fill([num_using_objects],
                                rm.scattering_pattern.lambda_))

        # Velocity of all objects
        # [array_size, 3]
        velocity = tf.zeros([array_size, 3], self._rdtype)
        for obj in self._scene.objects.values():
            velocity = tf.tensor_scatter_nd_update(velocity,
                                                   [[obj.object_id]],
                                                   [obj.velocity])
        for obj in self._scene.ris.values():
            velocity = tf.tensor_scatter_nd_update(velocity,
                                                   [[obj.object_id]],
                                                   [obj.velocity])

        return (relative_permittivity,
                scattering_coefficient,
                xpd_coefficient,
                alpha_r,
                alpha_i,
                lambda_,
                velocity)

    def _test_obstruction(self, o, d, maxt, additional_blockers=None):
        r"""
        Test obstruction of a batch of rays using Mitsuba.

        Input
        -----
        o: [batch_size, 3], tf.float
            Origin of the rays

        d: [batch_size, 3], tf.float
            Direction of the rays.
            Must be unit vectors.

        maxt: [batch_size], tf.float
            Length of the ray

        additional_blockers : list(mi.Shape) | None
            Optional list of Mitsuba shapes containing additional blockers.
            Defaults to `None`.

        Output
        -------
        val: [batch_size], tf.bool
            `True` if the ray is obstructed, i.e., hits a primitive.
            `False` otherwise.
        """
        # Translate the origin a bit along the ray direction to avoid
        # consecutive intersection with the same primitive
        o = o + SolverBase.EPSILON_OBSTRUCTION*d
        # [batch_size, 3]
        mi_o = self._mi_point_t(o)
        # Ray direction
        # [batch_size, 3]
        mi_d = self._mi_vec_t(d)
        # [batch_size]
        # Reduce the ray length by a small value to avoid false positive when
        # testing for LoS to a primitive due to hitting the primitive we are
        # testing visibility to. The factor 2 compensates for the origin
        # shift applied above.
        maxt = maxt - 2.*SolverBase.EPSILON_OBSTRUCTION
        mi_maxt = self._mi_scalar_t(maxt)
        # Mitsuba ray
        mi_ray = mi.Ray3f(o=mi_o, d=mi_d, maxt=mi_maxt, time=0.,
                          wavelengths=mi.Color0f(0.))
        # Test for obstruction using Mitsuba
        # With the scene
        # [batch_size]
        mi_val = self._mi_scene.ray_test(mi_ray)
        val = mi_to_tf_tensor(mi_val, tf.bool)
        # With additional blockers
        if additional_blockers:
            for shape in additional_blockers:
                mi_val = shape.ray_test(mi_ray)
                val = tf.logical_or(val, mi_to_tf_tensor(mi_val, tf.bool))
        return val

    def _extract_wedges(self):
        r"""
        Extract the wedges and, optionally, the edges, from the scene geometry

        Output
        ------
        _wedges_origin : [num_wedges, 3], float
            Starting point of the wedges

        _wedges_e_hat : [num_wedges, 3], float
            Normalized edge vector

        _wedges_length : [num_wedges], float
            Length of the wedges

        _wedges_normals : [num_wedges, 2, 3], float
            Normals to the wedges sides

        _primitives_2_wedges : [num_primitives, 3], int
            Maps primitives to their wedges (-1 if the edge of the primitive
            is not a selected wedge)

        _wedges_objects : [num_wedges, 2], int
            Indices of the two objects making the wedge (the two sides of the
            wedge could belong to different objects)

        is_edge : [num_wedges], bool
            Set to `True` if a wedge is an edge, i.e., the edge of a single
            primitive.
        """

        angle_threshold = SolverBase.WEDGES_ANGLE_THRESHOLD

        # Extract vertices of every triangle
        # [num_prim, 3]
        v0 = self._primitives[:,0,:]
        v1 = self._primitives[:,1,:]
        v2 = self._primitives[:,2,:]

        # List all edges: (v0,v1), (v1,v2), (v2,v0) for every triangle
        # [num_prim, (2 + 2 + 2)*3]
        all_edges_undirected = tf.concat([
            v0, v1,
            v1, v2,
            v2, v0
        ], axis=1)
        # [num_edges = num_prim*3, 2, 3]
        all_edges_undirected = tf.reshape(all_edges_undirected,
                                          shape=(3 * v0.shape[0], 2, 3))
        # Edges are oriented such that identical edges have same orientation
        # [num_edges, 2, 3]
        all_edges = self._swap_edges(all_edges_undirected)

        # Remaining point in the triangle for each edge.
        # This will be used to compute the normals later
        # [num_prim, 3*3]
        remaining_vertex = tf.concat([v2,
                                      v0,
                                      v1], axis=1)
        # [num_edges, 3]
        remaining_vertex = tf.reshape(remaining_vertex, [-1, 3])

        # Get unique edges, i.e., without duplicates
        # unique_edges : [num_unique_edges, 2, 3]
        # indices_of_unique : [num_edges], index of the edge in ``unique_edges``
        unique_edges, indices_of_unique = tf.raw_ops.UniqueV2(x=all_edges,
                                                              axis=[0])

        # Number of occurrences of every unique edge
        # [num_unique_edges]
        _, _, unique_indices_count = tf.unique_with_counts(indices_of_unique)

        # Flag indicating which edges are shared by exactly one or two
        # primitives, i.e., edges or wedges
        # [num_unique_edges]
        is_selected = tf.logical_or(tf.equal(unique_indices_count, 1),
                                    tf.equal(unique_indices_count, 2))

        # The following tensor lists the index of the first
        # edge in ``all_edges`` that makes the wedge.
        # Note: In the presence of duplicate values in `indices_of_unique`
        # (i.e. our case), it is not deterministic which of them
        # `tensor_scatter_nd_update` will store into the result. That's okay.
        # [num_edges]
        seq = tf.cast(tf.range(indices_of_unique.shape[0]), dtype=tf.int32)
        # [num_unique_edges]
        default = tf.cast(tf.fill(dims=unique_edges.shape[0], value=-1),
                          dtype=tf.int32)
        # [num_unique_edges]
        all_edges_index_1 = tf.tensor_scatter_nd_update(default,
            indices_of_unique[:, None], seq)
        # Next, list the second primitive that the wedge is connected to.
        # -1 is used for edges defined by a single primitive (screens)
        # [num_unique_edges]
        false_value = tf.fill(dims=all_edges_index_1.shape[0], value=False)
        # Mask of the rows of ``all_edges`` not already stored in
        # ``all_edges_index_1``
        # [num_edges]
        missing = tf.fill(dims=indices_of_unique.shape[0], value=True)
        missing = tf.tensor_scatter_nd_update(missing,
                                              all_edges_index_1[:, None],
                                              false_value)
        # [num_unique_edges]
        all_edges_index_2 = tf.tensor_scatter_nd_update(default,
            indices_of_unique[missing][:, None], seq[missing])
        # Flag set to True if an edge is not a wedge, i.e., is attached to only
        # one primitive. This is the case for unique edges for which
        # ``all_edges_index_2`` is still set to -1
        # [num_unique_edges]
        is_edge = tf.equal(all_edges_index_2, -1)
        # For edges, the unique edge primitive is set to the same value for both
        # ``all_edges_index_1`` and ``all_edges_index_2``. This will lead to
        # mapping to the same primitive
        # [num_unique_edges]
        all_edges_index_2 = tf.where(is_edge, all_edges_index_1,
                                     all_edges_index_2)

        # Normals to the faces
        # We compute the normals such that they point in the same direction
        # of the space for both primitives making a wedge.
        # To that aim, the normal for the face 0(n) is defined as the
        # cross-product between:
        # - The edge vector
        # - The vector connecting a vertex of the edge to the third
        # point of the triangle corresponding to the 0(n) face to whom this
        # edge belongs.
        # Edge vertices
        # [num_unique_edges, 2, 3]
        vs = tf.gather(all_edges, all_edges_index_1)
        # [num_unique_edges, 3]
        v1 = vs[:,0]
        v2 = vs[:,1]
        # Edge vector
        # [num_unique_edges, 3]
        e = v2 - v1
        # Vertex on the 0 and n faces
        # [num_unique_edges, 3]
        vf1 = tf.gather(remaining_vertex, all_edges_index_1)
        vf2 = tf.gather(remaining_vertex, all_edges_index_2)
        # [num_unique_edges, 3]
        u_1,_ = normalize(vf1 - v1)
        u_2,_ = normalize(vf2 - v1)
        # [num_unique_edges, 3]
        n1,_ = normalize(cross(e, u_1))
        n2,_ = normalize(cross(u_2, e))

        # We flip the normals if necessary to ensure that they point towards
        # the "exterior" of the wedge, i.e., that the exterior angle is at least
        # pi.
        # To ensure that, we orient the normals such that u2 does not point
        # towards the half-space defined by n1.
        # [num_unique_edges]
        cos_angle = dot(u_2, n1)
        # Three cases:
        # * cos_angle > 0: u2 points towards the same half space as n1.
        #   We must flip n1 and n2.
        # * cos_angle < 0: u2 points towards the opposite half space of n1.
        #   Do nothing
        # * cos_angle = 0: u2 is orthogonal to n1, i.e., the two primitives are
        #   parallel. In that case, either the wedge is an edge, or it is a flat
        #   surface that must be discarded
        # [num_unique_edges]
        flip = tf.where(tf.greater(cos_angle, tf.zeros_like(cos_angle)),
                        -tf.ones_like(cos_angle),
                        tf.ones_like(cos_angle))
        # [num_unique_edges, 1]
        flip = tf.expand_dims(flip, axis=1)
        # [num_unique_edges, 3]
        n1 = n1*flip
        n2 = n2*flip
        # Discard the wedges considered as flat, i.e., with an opening angle
        # close to PI up to `angle_threshold`
        # We use this observation to discard close-to-flat wedges.
        # cos_angle = dot(u_2, n1) = cos(theta)
        # where theta= angle(u_2, n1).
        # Then, a wedge is considered flat if:
        #   pi/2 - angle_threshold < theta < pi/2 + angle_threshold
        # => cos(pi/2 + angle_threshold) < cos(theta)
        #                                < cos(pi/2 - angle_threshold)
        # => -sin(angle_threshold) < cos_angle < sin(angle_threshold)
        # ()
        threshold = tf.abs(tf.math.sin(tf.cast(angle_threshold, self._rdtype)))
        # [num_unique_edges]
        is_selected_ = tf.greater(tf.abs(cos_angle),threshold)
        # Don't discard edges
        # [num_unique_edges]
        is_selected_ = tf.logical_or(is_edge, is_selected_)
        # [num_unique_edges]
        is_selected = tf.logical_and(is_selected, is_selected_)

        # Extract only the selected lanes
        # [num_selected_edges]
        selected_indices = tf.where(is_selected)[:, 0]
        # [num_selected_edges, 2, 3]
        selected_edges = unique_edges[is_selected]
        # [num_selected_edges, 3]
        selected_wedges_start = selected_edges[:,0]
        selected_wedges_end = selected_edges[:,1]
        # [num_selected_edges, 3]
        n1 = n1[is_selected]
        n2 = n2[is_selected]
        # [num_selected_edges, 2, 3]
        # n1: 0-face
        # n2: n-face
        normals = tf.stack([n1, n2], axis=1)

        # Pre-compute a mapping from primitive index to (up to) three wedges.
        # Recall that by construction, `all_edges` is ordered by
        # primitive (3 rows per primitive).
        # Then, `indices_of_unique` gives the mapping from the row indices
        # of `all_edges` to the row indices of `unique_edges`.
        # Finally, a subset of `unique_edges` is selected via the
        # `is_selected` mask.
        #
        # In the end, all we need to do is renumber the values of
        # `indices_of_unique` to refer to rows of `selected_edges` instead of
        # rows of `unique_edges`. Non-selected edges map to -1.
        seq = tf.cast(tf.range(selected_edges.shape[0]), dtype=tf.int32)
        unique_edge_index_to_double_edge_index = \
            tf.tensor_scatter_nd_update(default, selected_indices[:, None], seq)
        # [num_prim, 3]
        prim_to_wedges = tf.reshape(
            tf.gather(unique_edge_index_to_double_edge_index,indices_of_unique),
            (-1, 3)
        )

        # Indices of the objects to which each edge belongs
        # First, indices (in all_edges) of edges
        # [num_unique_edges, 2]
        wedges_indices = tf.stack([all_edges_index_1,
                                   all_edges_index_2], axis=1)
        # Keep only the selected wedges
        # [num_selected_edges, 2]
        wedges_indices = wedges_indices[is_selected]
        # [num_selected_edges]
        is_edge = is_edge[is_selected]
        # Wedges index 2 primitive index (3 edges per primitive)
        # [num_selected_edges, 2]
        wedges_2_prim = wedges_indices//3
        # Primitive index 2 object index
        # [num_selected_edges, 2]
        wedges_2_object = tf.gather(self._primitives_2_objects, wedges_2_prim)

        # Edges length and edge vector
        # The edge vector e_hat must be such that:
        #   normalize(n_0 x n_n) = e_hat,
        # where n_0 is the normal to the 0-face and n_n the normal to the
        # n-face
        # [num_selected_edges, 3]
        e_hat,_ = normalize(cross(normals[...,0,:],normals[...,1,:]))
        # Select the wedges' origin according to the normals
        # e_hat_ind: [num_selected_edges, 3]
        # length : [num_selected_edges]
        e_hat_ind,length = normalize(selected_wedges_end-selected_wedges_start)
        # [num_selected_edges]
        origin_indicator = dot(e_hat, e_hat_ind)
        # [num_selected_edges, 1]
        origin_indicator = tf.expand_dims(origin_indicator, axis=1)
        # [num_selected_edges, 3]
        origin = tf.where(origin_indicator < 0,
                          selected_wedges_end,
                          selected_wedges_start)
        # For edges, both normals are opposite and their cross-product is
        # degenerate: set the edge vector arbitrarily to the direction from
        # start to end
        # [num_selected_edges, 3]
        e_hat = tf.where(tf.expand_dims(is_edge, axis=1), e_hat_ind, e_hat)

        # Output
        output = (
            origin, # wedges origins
            e_hat, # Wedge vector
            length, # Wedge length
            normals, # wedges_normals
            prim_to_wedges, # primitives_2_wedges
            wedges_2_object, # wedges_objects
            is_edge # is_edge
        )

        return output

    def _swap_edges(self, edges):
        """Swap edges extremities such that identical edges are oriented in
        the same way.

        Parameters
        ----------
        edges : [num_edges, 2, 3], float
            Batch of edges extremities

        Returns
        -------
        [num_edges, 2, 3], float
            Reoriented edges
        """
        p0 = edges[:,0,:]
        p1 = edges[:,1,:]
        p0_hat, r0 = normalize(p0)
        p1_hat, r1 = normalize(p1)
        theta0, phi0 = theta_phi_from_unit_vec(p0_hat)
        theta1, phi1 = theta_phi_from_unit_vec(p1_hat)

        # Three conditions are used to orientate the edges
        # by swapping the extremities (p0,p1).
        # Condition n+1 is used only if none of the previous n conditions
        # enabled to separate the edges.
        # 1. norm(p1) >= norm(p0)
        # 2. phi(p1) >= phi(p0)
        # 3. theta(p1) >= theta(p0)

        # More details of the algorithm:
        # needs_swap 1: !r_equal and r0 > r1
        # needs_swap 2: r_equal and !phi_equal and phi0 > phi1
        # needs_swap 3: r_equal and phi_equal and theta0 > theta1
        # Note: case when all three coordinates are equal is not considered

        r_equal = tf.experimental.numpy.isclose(r0, r1)
        phi_equal = tf.experimental.numpy.isclose(phi0, phi1)
        case_2 = tf.logical_and(r_equal, tf.logical_not(phi_equal))
        case_3 = tf.logical_and(r_equal, phi_equal)

        needs_swap_1 = tf.logical_and(tf.logical_not(r_equal), r0 > r1)
        needs_swap_2 = tf.logical_and(case_2, phi0 > phi1)
        needs_swap_3 = tf.logical_and(case_3, theta0 > theta1)

        # [num_edges, 1]
        needs_swap = tf.reduce_any(tf.stack([needs_swap_1,
                                             needs_swap_2,
                                             needs_swap_3], axis=1),
                                   keepdims=True, axis=1)

        result = tf.concat([
            tf.expand_dims(tf.where(needs_swap, p1, p0), axis=1),
            tf.expand_dims(tf.where(needs_swap, p0, p1), axis=1),
        ], axis=1)
        return result

    def _wedges_from_primitives(self, candidates, edge_diffraction):
        r"""
        Returns the candidate wedges from the candidate primitives.

        As only first-order diffraction is considered, only the wedges of the
        primitives in line-of-sight of the transmitter are considered.

        Input
        ------
        candidates: [max_depth, num_samples] | [num_samples], int
            Candidate primitives. Any rank is accepted: the tensor is
            flattened before processing.
            Entries correspond to primitives indices.
            For paths with depth lower than ``max_depth``, -1 is used as
            padding value.

        edge_diffraction : bool
            If set to `False`, only diffraction on wedges, i.e., edges that
            connect two primitives, is considered.

        Output
        -------
        candidate_wedges : [num_candidate_wedges], int
            Candidate wedges.
            Entries correspond to wedges indices.
        """

        # Flatten the candidates. `tf.unique` (used below for removing
        # duplicates) only accepts 1-D tensors, so 2-D inputs would
        # otherwise fail. This is a no-op for 1-D inputs.
        # [num_candidates]
        candidates = tf.reshape(candidates, [-1])

        # If no candidates, return an empty list
        # Useful to manage empty scenes
        if candidates.shape[0] == 0:
            return tf.constant([], tf.int32)

        # Remove -1 (padding)
        candidates = tf.boolean_mask(candidates, tf.not_equal(candidates, -1))

        # Remove duplicates
        candidates,_ = tf.unique(candidates)

        # [num_unique_candidates, 3]
        candidate_wedges = tf.gather(self._primitives_2_wedges, candidates,
                                     axis=0)
        # [num_unique_candidates*3]
        candidate_wedges = tf.reshape(candidate_wedges, [-1])

        # Remove -1, i.e., primitive edges that are not selected wedges
        # [<= num_unique_candidates*3]
        candidate_wedges = tf.boolean_mask(candidate_wedges,
                                           tf.not_equal(candidate_wedges, -1))

        # Remove duplicates
        # [num_candidate_wedges]
        candidate_wedges,_ = tf.unique(candidate_wedges)

        # Remove edges if required
        if not edge_diffraction:
            # [num_candidate_wedges]
            is_wedge = tf.logical_not(tf.gather(self._is_edge,
                                                candidate_wedges))
            wedge_indices = tf.where(is_wedge)[:,0]
            # [num_candidate_wedges]
            candidate_wedges = tf.gather(candidate_wedges, wedge_indices)

        return candidate_wedges

    def _build_mi_ris_objects(self):
        r"""
        Builds a list of Mitsuba rectangles implementing all RIS, with
        position, orientation, and size matching the RIS properties.

        Output
        ------
        : list(mi.Rectangle)
            List of Mitsuba rectangles implementing the RIS, in the order of
            ``scene.ris``

        : [num_ris], mi.UInt32
            Shape pointers of the rectangles reinterpreted as integers, in the
            same order. These are comparable to the indices returned by
            :meth:`_ris_intersect`.
        """
        # List of all the RIS objects in the scene
        all_ris = list(self._scene.ris.values())
        num_ris = len(all_ris)

        # Creates the RIS as rectangles
        mi_ris_objects = []
        ris_indices = dr.zeros(mi.ShapePtr, num_ris)
        for i, ris in enumerate(all_ris):
            center = ris.position
            orientation = ris.orientation
            size = ris.size
            mi_to_world = mitsuba_rectangle_to_world(center, orientation, size,
                                                     ris=True)
            ris_rect = mi.load_dict({ "type" : "rectangle",
                                      "to_world" : mi_to_world
                                    })
            mi_ris_objects.append(ris_rect)
            dr.scatter(ris_indices, mi.ShapePtr(ris_rect), i)
        ris_indices = dr.reinterpret_array_v(mi.UInt32, ris_indices)

        return mi_ris_objects, ris_indices

    def _ris_intersect(self, ris_objects, ray, active):
        r"""
        Test the intersection with the RIS and keep, for every ray, the
        closest one.

        Input
        ------
        ris_objects : list(mi.Rectangle)
            List of Mitsuba rectangles implementing the RIS

        ray : mi.Ray3f
            Batch of Mitsuba rays

        active : mi.Bool
            Mask of active rays, forwarded to ``ray_intersect``

        Output
        -------
        valid : mi.Bool
            Mask indicating if the intersection is valid

        t : mi.Float
            Position of the intersection on the ray (`inf` if no valid
            intersection)

        indices : mi.UInt32
            Indices of the intersected RIS, i.e., the shape pointers
            reinterpreted as integers (0 if no valid intersection)
        """

        num_rays = dr.shape(ray.d)[1]
        t = dr.full(mi.Float, float('inf'), num_rays)
        valid = dr.full(mi.Bool, False, num_rays)
        indices = dr.full(mi.UInt, 0, num_rays)
        for ris in ris_objects:
            si_ris = ris.ray_intersect(ray, active=active)
            v_ = si_ris.is_valid()
            t_ = si_ris.t
            indices_ = dr.reinterpret_array_v(mi.UInt32, si_ris.shape)

            valid |= v_
            # Keep the closest intersection found so far
            new_closest = v_ & (t_ < t)

            t = dr.select(new_closest, t_, t)
            indices = dr.select(new_closest, indices_, indices)

        return valid, t, indices