utils.py (25552B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""
Ray tracer utilities
"""

import tensorflow as tf
import mitsuba as mi
import drjit as dr

from sionna import config
from sionna.utils import expand_to_rank, log10
from sionna import PI

def rotation_matrix(angles):
    r"""
    Computes rotation matrices as defined in :eq:`rotation`

    The closed-form expression in (7.1-4) [TR38901]_ is used.

    Input
    ------
    angles : [...,3], tf.float
        Angles for the rotations [rad].
        The last dimension corresponds to the angles
        :math:`(\alpha,\beta,\gamma)` that define
        rotations about the axes :math:`(z, y, x)`,
        respectively.

    Output
    -------
    : [...,3,3], tf.float
        Rotation matrices
    """

    a = angles[...,0]
    b = angles[...,1]
    c = angles[...,2]
    cos_a = tf.cos(a)
    cos_b = tf.cos(b)
    cos_c = tf.cos(c)
    sin_a = tf.sin(a)
    sin_b = tf.sin(b)
    sin_c = tf.sin(c)

    r_11 = cos_a*cos_b
    r_12 = cos_a*sin_b*sin_c - sin_a*cos_c
    r_13 = cos_a*sin_b*cos_c + sin_a*sin_c
    r_1 = tf.stack([r_11, r_12, r_13], axis=-1)

    r_21 = sin_a*cos_b
    r_22 = sin_a*sin_b*sin_c + cos_a*cos_c
    r_23 = sin_a*sin_b*cos_c - cos_a*sin_c
    r_2 = tf.stack([r_21, r_22, r_23], axis=-1)

    r_31 = -sin_b
    r_32 = cos_b*sin_c
    r_33 = cos_b*cos_c
    r_3 = tf.stack([r_31, r_32, r_33], axis=-1)

    # Rows are stacked along the second-to-last axis
    rot_mat = tf.stack([r_1, r_2, r_3], axis=-2)
    return rot_mat

def rotate(p, angles, inverse=False):
    r"""
    Rotates points ``p`` by the ``angles`` according
    to the 3D rotation defined in :eq:`rotation`

    Input
    -----
    p : [...,3], tf.float
        Points to rotate

    angles : [..., 3], tf.float
        Angles for the rotations [rad].
        The last dimension corresponds to the angles
        :math:`(\alpha,\beta,\gamma)` that define
        rotations about the axes :math:`(z, y, x)`,
        respectively.

    inverse : bool
        If `True`, the inverse rotation is applied,
        i.e., the transpose of the rotation matrix is used.
        Defaults to `False`

    Output
    ------
    : [...,3], tf.float
        Rotated points ``p``
    """

    # Rotation matrix
    # [..., 3, 3]
    rot_mat = rotation_matrix(angles)
    # Expand to rank(p)+1 by inserting dimensions at axis 0 so that the
    # rotation matrices broadcast against ``p``
    rot_mat = expand_to_rank(rot_mat, tf.rank(p)+1, 0)

    # Rotation about the origin. For the inverse rotation, the transpose of
    # the rotation matrix is used.
    # [..., 3]
    rot_p = tf.linalg.matvec(rot_mat, p, transpose_a=inverse)

    return rot_p

def theta_phi_from_unit_vec(v):
    r"""
    Computes zenith and azimuth angles (:math:`\theta,\varphi`)
    from unit-norm vectors as described in :eq:`theta_phi`

    Input
    ------
    v : [...,3], tf.float
        Tensor with unit-norm vectors in the last dimension

    Output
    -------
    theta : [...], tf.float
        Zenith angles :math:`\theta`

    phi : [...], tf.float
        Azimuth angles :math:`\varphi`
    """
    x = v[...,0]
    y = v[...,1]
    z = v[...,2]

    # If v = +/-z, then x = 0 and y = 0. In this case, atan2 is not
    # differentiable, leading to NaN when computing the gradients.
    # The following lines force x to one in this case. Note that this does
    # not impact the output meaningfully, as in that case theta is 0 or pi
    # and phi can take any value.
    zero = tf.zeros_like(x)
    is_unit_z = tf.logical_and(tf.equal(x, zero), tf.equal(y, zero))
    is_unit_z = tf.cast(is_unit_z, x.dtype)
    x += is_unit_z

    theta = acos_diff(z)
    phi = tf.math.atan2(y, x)
    return theta, phi

def r_hat(theta, phi):
    r"""
    Computes the spherical unit vector :math:`\hat{\mathbf{r}}(\theta, \phi)`
    as defined in :eq:`spherical_vecs`

    Input
    -------
    theta : arbitrary shape, tf.float
        Zenith angles :math:`\theta` [rad]

    phi : same shape as ``theta``, tf.float
        Azimuth angles :math:`\varphi` [rad]

    Output
    --------
    rho_hat : ``phi.shape`` + [3], tf.float
        Vector :math:`\hat{\mathbf{r}}(\theta, \phi)` on unit sphere
    """
    rho_hat = tf.stack([tf.sin(theta)*tf.cos(phi),
                        tf.sin(theta)*tf.sin(phi),
                        tf.cos(theta)], axis=-1)
    return rho_hat

def theta_hat(theta, phi):
    r"""
    Computes the spherical unit vector
    :math:`\hat{\boldsymbol{\theta}}(\theta, \varphi)`
    as defined in :eq:`spherical_vecs`

    Input
    -------
    theta : arbitrary shape, tf.float
        Zenith angles :math:`\theta` [rad]

    phi : same shape as ``theta``, tf.float
        Azimuth angles :math:`\varphi` [rad]

    Output
    --------
    theta_hat : ``phi.shape`` + [3], tf.float
        Vector :math:`\hat{\boldsymbol{\theta}}(\theta, \varphi)`
    """
    x = tf.cos(theta)*tf.cos(phi)
    y = tf.cos(theta)*tf.sin(phi)
    z = -tf.sin(theta)
    return tf.stack([x,y,z], -1)

def phi_hat(phi):
    r"""
    Computes the spherical unit vector
    :math:`\hat{\boldsymbol{\varphi}}(\theta, \varphi)`
    as defined in :eq:`spherical_vecs`

    This vector does not depend on the zenith angle :math:`\theta`,
    which is hence not an input.

    Input
    -------
    phi : arbitrary shape, tf.float
        Azimuth angles :math:`\varphi` [rad]

    Output
    --------
    phi_hat : ``phi.shape`` + [3], tf.float
        Vector :math:`\hat{\boldsymbol{\varphi}}(\theta, \varphi)`
    """
    x = -tf.sin(phi)
    y = tf.cos(phi)
    z = tf.zeros_like(x)
    return tf.stack([x,y,z], -1)

def cross(u, v):
    r"""
    Computes the cross (or vector) product between u and v

    Input
    ------
    u : [...,3]
        First vector

    v : [...,3]
        Second vector

    Output
    -------
    : [...,3]
        Cross product between ``u`` and ``v``
    """
    u_x = u[...,0]
    u_y = u[...,1]
    u_z = u[...,2]

    v_x = v[...,0]
    v_y = v[...,1]
    v_z = v[...,2]

    w = tf.stack([u_y*v_z - u_z*v_y,
                  u_z*v_x - u_x*v_z,
                  u_x*v_y - u_y*v_x], axis=-1)

    return w

def dot(u, v, keepdim=False, clip=False):
    r"""
    Computes the dot (or scalar) product between u and v

    Input
    ------
    u : [...,3]
        First vector

    v : [...,3]
        Second vector

    keepdim : bool
        If `True`, keep the last dimension.
        Defaults to `False`.

    clip : bool
        If `True`, clip output to [-1,1].
        Defaults to `False`.

    Output
    -------
    : [...,1] or [...]
        Dot product between ``u`` and ``v``.
        The last dimension is removed if ``keepdim``
        is set to `False`.
    """
    res = tf.reduce_sum(u*v, axis=-1, keepdims=keepdim)
    if clip:
        one = tf.ones((), u.dtype)
        res = tf.clip_by_value(res, -one, one)
    return res

def outer(u,v):
    r"""
    Computes the outer product between u and v

    Input
    ------
    u : [...,3]
        First vector

    v : [...,3]
        Second vector

    Output
    -------
    : [...,3,3]
        Outer product between ``u`` and ``v``
    """
    return u[...,tf.newaxis] * v[...,tf.newaxis,:]

def normalize(v):
    r"""
    Normalizes ``v`` to unit norm

    Vectors with zero norm are mapped to zero vectors
    (no division by zero occurs).

    Input
    ------
    v : [...,3], tf.float
        Vector

    Output
    -------
    : [...,3], tf.float
        Normalized vector

    : [...], tf.float
        Norm of the unnormalized vector
    """
    norm = tf.norm(v, axis=-1, keepdims=True)
    n_v = tf.math.divide_no_nan(v, norm)
    norm = tf.squeeze(norm, axis=-1)
    return n_v, norm

def moller_trumbore(o, d, p0, p1, p2, epsilon):
    r"""
    Computes the intersection between a ray with origin ``o`` and direction
    ``d`` and a triangle defined by its vertices ``p0``, ``p1``, and ``p2``
    using the Moller–Trumbore intersection algorithm.

    Input
    -----
    o, d: [..., 3], tf.float
        Ray origin and direction.
        The direction `d` must be a unit vector.

    p0, p1, p2: [..., 3], tf.float
        Vertices defining the triangle

    epsilon : (), tf.float
        Small value used to avoid errors due to numerical precision

    Output
    -------
    t : [...], tf.float
        Position along the ray from the origin at which the intersection
        occurs (if any)

    hit : [...], bool
        `True` if the ray intersects the triangle. `False` otherwise.
    """

    rdtype = o.dtype
    zero = tf.cast(0.0, rdtype)
    one = tf.ones((), rdtype)

    # [..., 3]
    e1 = p1 - p0
    e2 = p2 - p0

    # [...,3]
    pvec = cross(d, e2)
    # [...,1]
    det = dot(e1, pvec, keepdim=True)

    # If the ray is parallel to the triangle, then det = 0.
    hit = tf.greater(tf.abs(det), zero)

    # [...,3]
    tvec = o - p0
    # First barycentric coordinate
    # [...,1]
    u = tf.math.divide_no_nan(dot(tvec, pvec, keepdim=True), det)
    # [...,1]
    hit = tf.logical_and(hit,
        tf.logical_and(tf.greater_equal(u, -epsilon),
                       tf.less_equal(u, one + epsilon)))

    # [..., 3]
    qvec = cross(tvec, e1)
    # Second barycentric coordinate
    # [...,1]
    v = tf.math.divide_no_nan(dot(d, qvec, keepdim=True), det)
    # [..., 1]
    hit = tf.logical_and(hit,
        tf.logical_and(tf.greater_equal(v, -epsilon),
                       tf.less_equal(u + v, one + epsilon)))
    # [..., 1]
    t = tf.math.divide_no_nan(dot(e2, qvec, keepdim=True), det)
    # Only intersections in front of the origin count as hits
    # [..., 1]
    hit = tf.logical_and(hit, tf.greater_equal(t, epsilon))

    # [...]
    t = tf.squeeze(t, axis=-1)
    hit = tf.squeeze(hit, axis=-1)

    return t, hit

def component_transform(e_s, e_p, e_i_s, e_i_p):
    """
    Compute basis change matrix for reflections

    Input
    -----
    e_s : [..., 3], tf.float
        Source unit vector for S polarization

    e_p : [..., 3], tf.float
        Source unit vector for P polarization

    e_i_s : [..., 3], tf.float
        Target unit vector for S polarization

    e_i_p : [..., 3], tf.float
        Target unit vector for P polarization

    Output
    -------
    r : [..., 2, 2], tf.float
        Change of basis matrix for going from (e_s, e_p) to (e_i_s, e_i_p)
    """
    r_11 = dot(e_i_s, e_s)
    r_12 = dot(e_i_s, e_p)
    r_21 = dot(e_i_p, e_s)
    r_22 = dot(e_i_p, e_p)
    r1 = tf.stack([r_11, r_12], axis=-1)
    r2 = tf.stack([r_21, r_22], axis=-1)
    r = tf.stack([r1, r2], axis=-2)
    return r

def mi_to_tf_tensor(mi_tensor, dtype):
    """
    Get a TensorFlow eager tensor from a Mitsuba/DrJIT tensor

    The input is evaluated and the thread is synchronized before
    the conversion.

    Input
    ------
    mi_tensor : Mitsuba/DrJIT array
        Array to convert

    dtype : tf.DType
        Datatype of the output tensor

    Output
    -------
    : tf.Tensor
        Converted tensor, cast to ``dtype``
    """
    dr.eval(mi_tensor)
    dr.sync_thread()
    # When the array holds a single element, the .tf() method crashes.
    # The following hack takes care of this corner case: the element is
    # repeated twice before the conversion and only the first entry is kept.
    if dr.shape(mi_tensor)[-1] == 1:
        mi_tensor = dr.repeat(mi_tensor, 2)
        tf_tensor = tf.cast(mi_tensor.tf(), dtype)[:1]
    else:
        tf_tensor = tf.cast(mi_tensor.tf(), dtype)
    return tf_tensor

def gen_orthogonal_vector(k, epsilon):
    """
    Generate an arbitrary vector that is orthogonal to ``k``.

    The output is the cross product of ``k`` with the x-axis, or with the
    y-axis if the former has a norm not greater than ``epsilon``. It is not
    normalized.

    Input
    ------
    k : [..., 3], tf.float
        Vector

    epsilon : (), tf.float
        Small value used to avoid errors due to numerical precision

    Output
    -------
    : [..., 3], tf.float
        Vector orthogonal to ``k``
    """
    rdtype = k.dtype
    ex = tf.cast([1.0, 0.0, 0.0], rdtype)
    ex = expand_to_rank(ex, tf.rank(k), 0)

    ey = tf.cast([0.0, 1.0, 0.0], rdtype)
    ey = expand_to_rank(ey, tf.rank(k), 0)

    n1 = cross(k, ex)
    n1_norm = tf.norm(n1, axis=-1, keepdims=True)
    n2 = cross(k, ey)
    # Fall back to the y-axis when k is (nearly) parallel to the x-axis
    return tf.where(tf.greater(n1_norm, epsilon), n1, n2)

def compute_field_unit_vectors(k_i, k_r, n, epsilon, return_e_r=True):
    """
    Compute unit vector parallel and orthogonal to incident plane

    Input
    ------
    k_i : [..., 3], tf.float
        Direction of arrival

    k_r : [..., 3], tf.float
        Direction of reflection

    n : [..., 3], tf.float
        Surface normal

    epsilon : (), tf.float
        Small value used to avoid errors due to numerical precision

    return_e_r : bool
        If `False`, only ``e_i_s`` and ``e_i_p`` are returned.

    Output
    ------
    e_i_s : [..., 3], tf.float
        Incident unit field vector for S polarization

    e_i_p : [..., 3], tf.float
        Incident unit field vector for P polarization

    e_r_s : [..., 3], tf.float
        Reflection unit field vector for S polarization.
        Only returned if ``return_e_r`` is `True`.

    e_r_p : [..., 3], tf.float
        Reflection unit field vector for P polarization
        Only returned if ``return_e_r`` is `True`.
    """
    e_i_s = cross(k_i, n)
    e_i_s_norm = tf.norm(e_i_s, axis=-1, keepdims=True)
    # In case of normal incidence, the incidence plane is not uniquely
    # defined and the Fresnel coefficient is the same for both polarizations
    # (up to a sign flip for the parallel component due to the definition of
    # polarization).
    # It is required to detect such scenarios and define an arbitrary valid
    # e_i_s to fix an incidence plane, as the result from previous
    # computation leads to e_i_s = 0.
    e_i_s = tf.where(tf.greater(e_i_s_norm, epsilon), e_i_s,
                     gen_orthogonal_vector(n, epsilon))

    e_i_s,_ = normalize(e_i_s)
    e_i_p,_ = normalize(cross(e_i_s, k_i))
    if not return_e_r:
        return e_i_s, e_i_p
    else:
        # The S polarization vector is shared by the incident and
        # reflected fields
        e_r_s = e_i_s
        e_r_p,_ = normalize(cross(e_r_s, k_r))
        return e_i_s, e_i_p, e_r_s, e_r_p

def reflection_coefficient(eta, cos_theta):
    """
    Compute simplified reflection coefficients

    Input
    ------
    eta : Any shape, tf.complex
        Complex relative permittivity

    cos_theta : Same as ``eta``, tf.float
        Cosine of the incident angle

    Output
    -------
    r_te : Same as input, tf.complex
        Fresnel reflection coefficient for S direction

    r_tm : Same as input, tf.complex
        Fresnel reflection coefficient for P direction
    """
    cos_theta = tf.complex(cos_theta, tf.zeros_like(cos_theta))

    # Fresnel equations
    a = cos_theta
    b = tf.sqrt(eta-1.+cos_theta**2)
    r_te = tf.math.divide_no_nan(a-b, a+b)

    c = eta*a
    d = b
    r_tm = tf.math.divide_no_nan(c-d, c+d)
    return r_te, r_tm

def paths_to_segments(paths):
    """
    Extract the segments corresponding to a set of ``paths``

    Input
    -----
    paths : :class:`~sionna.rt.Paths`
        A set of paths

    Output
    -------
    starts, ends : list, list
        Start and end points of the line segments making the paths.
        Both lists have the same length ``n`` and their entries are
        3D points taken from the sources, vertices, and targets of
        ``paths``.
    """

    vertices = paths.vertices.numpy()
    objects = paths.objects.numpy()
    mask = paths.targets_sources_mask
    sources, targets = paths.sources.numpy(), paths.targets.numpy()

    # Emit directly two lists of the beginnings and endings of line segments
    starts = []
    ends = []
    for rx in range(vertices.shape[1]): # For each receiver
        for tx in range(vertices.shape[2]): # For each transmitter
            for p in range(vertices.shape[3]): # For each path
                if not mask[rx, tx, p]:
                    continue

                start = sources[tx]
                # Walk along the interactions of the path; an object index
                # of -1 marks the end of the path
                i = 0
                while ( (i < objects.shape[0])
                        and (objects[i, rx, tx, p] != -1) ):
                    end = vertices[i, rx, tx, p]
                    starts.append(start)
                    ends.append(end)
                    start = end
                    i += 1
                # Explicitly add the path endpoint
                starts.append(start)
                ends.append(targets[rx])
    return starts, ends

def scene_scale(scene):
    """
    Computes a characteristic size of a scene

    The bounding box of the Mitsuba scene is expanded to include the
    positions of all transmitters, receivers, and RIS.

    Input
    ------
    scene : Scene
        Scene providing ``mi_scene``, ``transmitters``, ``receivers``,
        and ``ris``

    Output
    -------
    sc : float
        Diameter of the bounding sphere of the expanded bounding box

    tx_positions : dict
        Positions of the transmitters (as NumPy arrays), with the same keys
        as ``scene.transmitters``

    rx_positions : dict
        Positions of the receivers (as NumPy arrays), with the same keys
        as ``scene.receivers``

    ris_positions : dict
        Positions of the RIS (as NumPy arrays), with the same keys
        as ``scene.ris``

    bbox : Mitsuba bounding box
        Bounding box of the scene expanded to include all devices
    """
    bbox = scene.mi_scene.bbox()
    tx_positions, rx_positions, ris_positions = {}, {}, {}
    devices = ((scene.transmitters, tx_positions),
               (scene.receivers, rx_positions),
               (scene.ris, ris_positions)
              )
    for source, destination in devices:
        for k, rd in source.items():
            p = rd.position.numpy()
            bbox.expand(p)
            destination[k] = p

    sc = 2. * bbox.bounding_sphere().radius
    return sc, tx_positions, rx_positions, ris_positions, bbox

def fibonacci_lattice(num_points, dtype=tf.float32):
    """
    Generates a Fibonacci lattice for the unit square

    Input
    -----
    num_points : int
        Number of points

    dtype : tf.DType
        Datatype to use for the output.
        Defaults to `tf.float32`.

    Output
    -------
    points : [num_points, 2]
        Generated rectangular coordinates of the lattice points.
        If ``num_points`` is one, the single point is located at
        the origin.
    """

    # Computations are carried out in double precision and cast at the end
    golden_ratio = (1.+tf.sqrt(tf.cast(5., tf.float64)))/2.
    ns = tf.range(0, num_points, dtype=tf.float64)

    # First coordinate: fractional part of n / golden_ratio
    x = ns/golden_ratio
    x = x - tf.floor(x)

    # Second coordinate: points uniformly spread over [0, 1].
    # The denominator is lower-bounded by one so that a single point
    # yields y = 0 instead of 0/0 = NaN.
    y_den = tf.maximum(tf.cast(num_points - 1, tf.float64),
                       tf.constant(1., tf.float64))
    y = ns/y_den
    points = tf.stack([x,y], axis=1)

    points = tf.cast(points, dtype)

    return points

def cot(x):
    """
    Cotangent function

    Where :math:`\\tan(x) = 0`, zero is returned instead of infinity.

    Input
    ------
    x : [...], tf.float

    Output
    -------
    : [...], tf.float
        Cotangent of x
    """
    return tf.math.divide_no_nan(tf.ones_like(x), tf.math.tan(x))

def sign(x):
    """
    Returns +1 if ``x`` is non-negative, -1 otherwise

    Input
    ------
    x : [...], tf.float
        A real-valued number

    Output
    -------
    : [...], tf.float
        +1 if ``x`` is non-negative, -1 otherwise
    """
    two = tf.cast(2, x.dtype)
    one = tf.cast(1, x.dtype)
    return two*tf.cast(tf.greater_equal(x, 0), x.dtype)-one

def rot_mat_from_unit_vecs(a, b):
    r"""
    Computes Rodrigues' rotation formula :eq:`rodrigues_matrix`

    The returned matrix rotates ``a`` onto ``b``.

    Input
    ------
    a : [...,3], tf.float
        First unit vector

    b : [...,3], tf.float
        Second unit vector

    Output
    -------
    : [...,3,3], tf.float
        Rodrigues' rotation matrix
    """

    rdtype = a.dtype

    # Compute rotation axis vector
    k, _ = normalize(cross(a, b))

    # Deal with special case where a and b are parallel: the cross product
    # vanishes and any unit vector orthogonal to ``a`` is a valid rotation
    # axis. The fallback axis must be normalized, as Rodrigues' formula
    # assumes a unit-norm rotation axis (otherwise, the result is not a
    # rotation matrix when a and b are antiparallel).
    o, _ = normalize(gen_orthogonal_vector(a, 1e-6))
    k = tf.where(tf.reduce_sum(tf.abs(k), axis=-1, keepdims=True)==0, o, k)

    # Compute K matrix, i.e., the cross-product matrix of k
    shape = tf.concat([tf.shape(k)[:-1],[1]], axis=-1)
    zeros = tf.zeros(shape, rdtype)
    kx, ky, kz = tf.split(k, 3, axis=-1)
    l1 = tf.concat([zeros, -kz, ky], axis=-1)
    l2 = tf.concat([kz, zeros, -kx], axis=-1)
    l3 = tf.concat([-ky, kx, zeros], axis=-1)
    k_mat = tf.stack([l1, l2, l3], axis=-2)

    # Assemble full rotation matrix
    eye = tf.eye(3, batch_shape=tf.shape(k)[:-1], dtype=rdtype)
    cos_theta = dot(a, b, clip=True)
    sin_theta = tf.sin(acos_diff(cos_theta))
    cos_theta = expand_to_rank(cos_theta, tf.rank(eye), axis=-1)
    sin_theta = expand_to_rank(sin_theta, tf.rank(eye), axis=-1)
    rot_mat = eye + k_mat*sin_theta + \
                tf.linalg.matmul(k_mat, k_mat) * (1-cos_theta)
    return rot_mat

def sample_points_on_hemisphere(normals, num_samples=1):
    # pylint: disable=line-too-long
    r"""
    Randomly sample points on hemispheres defined by their normal vectors

    Input
    -----
    normals : [batch_size, 3], tf.float
        Normal vectors defining hemispheres

    num_samples : int
        Number of random samples to draw for each hemisphere
        defined by its normal vector.
        Defaults to 1.

    Output
    ------
    points : [batch_size, num_samples, 3], tf.float or [batch_size, 3], tf.float if num_samples=1.
        Random points on the hemispheres
    """
    dtype = normals.dtype
    batch_size = tf.shape(normals)[0]
    shape = [batch_size, num_samples]

    # Sample phi uniformly distributed on [0,2*PI]
    phi = config.tf_rng.uniform(shape, maxval=2*PI, dtype=dtype)

    # Generate samples of theta for uniform distribution on the hemisphere.
    # For a uniform distribution on the hemisphere, cos(theta) is uniformly
    # distributed on [0,1].
    u = config.tf_rng.uniform(shape, maxval=1, dtype=dtype)
    theta = tf.acos(u)

    # Transform spherical to Cartesian coordinates
    points = r_hat(theta, phi)

    # Compute rotation matrices mapping the z-axis onto the normals
    z_hat = tf.constant([[0,0,1]], dtype=dtype)
    z_hat = tf.broadcast_to(z_hat, tf.shape(normals))
    rot_mat = rot_mat_from_unit_vecs(z_hat, normals)
    rot_mat = tf.expand_dims(rot_mat, axis=1)

    # Compute rotated points
    points = tf.linalg.matvec(rot_mat, points)

    # Numerical errors can cause sampling from the other hemisphere.
    # Correct the sampled vector to avoid sampling in the wrong hemisphere
    # by reflecting it across the plane orthogonal to the normal.
    normals = tf.expand_dims(normals, axis=1)
    s = dot(points, normals, keepdim=True)
    s = tf.where(s < 0., s, 0.)
    points = points - 2.*s*normals

    if num_samples==1:
        points = tf.squeeze(points, axis=1)

    return points

def acos_diff(x, epsilon=1e-7):
    r"""
    Implementation of arccos(x) that avoids evaluating the gradient at x
    -1 or 1 by using straight through estimation, i.e., in the
    forward pass, x is clipped to [-1, 1], but in the backward pass, x is
    clipped to [-1 + epsilon, 1 - epsilon].

    Input
    ------
    x : any shape, tf.float
        Value at which to evaluate arccos

    epsilon : tf.float
        Small backoff to avoid evaluating the gradient at -1 or 1.
        Defaults to 1e-7.

    Output
    -------
    : same shape as x, tf.float
        arccos(x)
    """

    x_clip_1 = tf.clip_by_value(x, -1., 1.)
    x_clip_2 = tf.clip_by_value(x, -1. + epsilon, 1. - epsilon)
    # x_1 equals x_clip_2 in value, but its gradient w.r.t. x is one
    eps = tf.stop_gradient(x - x_clip_2)
    x_1 = x - eps
    acos_x_1 = tf.acos(x_1)
    # Forward value: acos(x_clip_1); gradient: that of acos(x_1)
    y = acos_x_1 + tf.stop_gradient(tf.acos(x_clip_1)-acos_x_1)
    return y

def angles_to_mitsuba_rotation(angles):
    """
    Build a Mitsuba transform from angles in radian

    Input
    ------
    angles : [3], tf.float
        Angles [rad]

    Output
    -------
    : :class:`mitsuba.Transform4f` or :class:`mitsuba.Transform4d`
        Mitsuba rotation. A :class:`mitsuba.Transform4f` is returned if
        ``angles`` has datatype `tf.float32`, a
        :class:`mitsuba.Transform4d` otherwise.
    """

    # Mitsuba expects angles in degrees
    angles = 180. * angles / PI

    if angles.dtype == tf.float32:
        mi_transform_t = mi.Transform4f
        angles = mi.Float(angles)
    else:
        mi_transform_t = mi.Transform4d
        angles = mi.Float64(angles)

    return (
        mi_transform_t.rotate(axis=[0., 0., 1.], angle=angles[0])
        @ mi_transform_t.rotate(axis=[0., 1., 0.], angle=angles[1])
        @ mi_transform_t.rotate(axis=[1., 0., 0.], angle=angles[2])
    )

def gen_basis_from_z(z, epsilon):
    """
    Generate a pair of vectors (x,y) such that (x,y,z) is an orthonormal basis.

    Input
    ------
    z : [..., 3], tf.float
        Unit vector

    epsilon : (), tf.float
        Small value used to avoid errors due to numerical precision

    Output
    -------
    x : [..., 3], tf.float
        Unit vector

    y : [..., 3], tf.float
        Unit vector
    """
    x = gen_orthogonal_vector(z, epsilon)
    x,_ = normalize(x)
    y = cross(z, x)
    return x,y

def compute_spreading_factor(rho_1, rho_2, s):
    r"""
    Computes the spreading factor
    :math:`\sqrt{\frac{\rho_1 \rho_2}{(\rho_1 + s)(\rho_2 + s)}}`

    Input
    ------
    rho_1, rho_2 : [...], tf.float
        Principal radii of curvature

    s : [...], tf.float
        Position along the axial ray at which to evaluate the
        spreading factor

    Output
    -------
    : [...], tf.float
        Spreading factor
    """

    # In the case of a spherical wave, when the origin (s = 0) is set to the
    # unique caustic point, both principal radii of curvature are zero.
    # The spreading factor is then equal to 1/s.
    spherical = tf.logical_and(tf.equal(rho_1, 0.), tf.equal(rho_2, 0.))
    a_spherical = tf.math.reciprocal_no_nan(s)

    # General formula for the spreading factor
    a = tf.sqrt(rho_1*rho_2/((rho_1+s)*(rho_2+s)))

    a = tf.where(spherical, a_spherical, a)
    return a

def mitsuba_rectangle_to_world(center, orientation, size, ris=False):
    """
    Build the `to_world` transformation that maps a default Mitsuba rectangle
    to the rectangle that defines the coverage map surface.

    Input
    ------
    center : [3], tf.float
        Center of the rectangle

    orientation : [3], tf.float
        Orientation of the rectangle [rad].
        An orientation of `(0,0,0)` corresponds to a rectangle oriented
        such that z+ is its normal.

    size : [2], tf.float
        Size of the rectangle.
        The width of the rectangle (in the local X direction) is size[0]
        and its height (in the local Y direction) size[1].

    ris : bool
        If `True`, the rectangle is additionally rotated by 90 degrees
        around the local Y axis so that its normal points at [1,0,0],
        and the two entries of ``size`` are swapped.
        Defaults to `False`.

    Output
    -------
    to_world : :class:`mitsuba.ScalarTransform4f`
        Rectangle to world transformation.
    """
    # Mitsuba expects angles in degrees
    orientation = 180. * orientation / PI

    trans = \
        mi.ScalarTransform4f.translate(center.numpy())\
        @ mi.ScalarTransform4f.rotate(axis=[0, 0, 1], angle=orientation[0])\
        @ mi.ScalarTransform4f.rotate(axis=[0, 1, 0], angle=orientation[1])\
        @ mi.ScalarTransform4f.rotate(axis=[1, 0, 0], angle=orientation[2])

    if ris:
        # The RIS normal points at [1,0,0].
        # We hence rotate the normal of the rectangle which points
        # at [0,0,1] by 90 degrees around the [0,1,0] axis.
        trans = trans\
            @mi.ScalarTransform4f.rotate(axis=[0, 1, 0], angle=90)

        # size = [width (=y), height (=z)]
        # Since the RIS is rotated w.r.t to rectangle,
        # The z axis corresponds to the x axis
        size = [size[1], size[0]]

    # The default Mitsuba rectangle spans [-1,1] in x and y, hence the
    # factor 0.5
    return (trans
        @mi.ScalarTransform4f.scale([0.5 * size[0], 0.5 * size[1], 1])
    )

def watt_to_dbm(power):
    r""" Converts :math:`P_{W}` [W] to :math:`P_{dBm}` [dBm] via the formula:
    :math:`P_{dBm} = 30 + 10 \log_{10}(P_W)`

    Input
    ------
    power : float
        Power [W]

    Output
    -------
    : float
        Power [dBm]
    """
    # Accept Python scalars and arrays as well as tensors; tensors are
    # returned unchanged by convert_to_tensor
    power = tf.convert_to_tensor(power)
    return 30 + 10 * log10(power)

def dbm_to_watt(dbm):
    r""" Converts dBm to Watt via the formula:
    :math:`P_W = 10^{\frac{P_{dBm}-30}{10}}`

    Input
    ------
    dbm : float
        Power [dBm]

    Output
    -------
    : float
        Power [W]
    """
    # tf.pow requires both operands to share the same datatype. The base is
    # hence explicitly cast to the datatype of the exponent instead of
    # relying on the implicit conversion of the Python int 10, which cannot
    # be inferred when the exponent is not a tensor (e.g., a Python float).
    exponent = tf.convert_to_tensor((dbm - 30) / 10)
    return tf.pow(tf.cast(10, exponent.dtype), exponent)