anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

utils.py (25552B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """
      6 Ray tracer utilities
      7 """
      8 
      9 import tensorflow as tf
     10 import mitsuba as mi
     11 import drjit as dr
     12 
     13 from sionna import config
     14 from sionna.utils import expand_to_rank, log10
     15 from sionna import PI
     16 
     17 def rotation_matrix(angles):
     18     r"""
     19     Computes rotation matrices as defined in :eq:`rotation`
     20 
     21     The closed-form expression in (7.1-4) [TR38901]_ is used.
     22 
     23     Input
     24     ------
     25     angles : [...,3], tf.float
     26         Angles for the rotations [rad].
     27         The last dimension corresponds to the angles
     28         :math:`(\alpha,\beta,\gamma)` that define
     29         rotations about the axes :math:`(z, y, x)`,
     30         respectively.
     31 
     32     Output
     33     -------
     34     : [...,3,3], tf.float
     35         Rotation matrices
     36     """
     37 
     38     a = angles[...,0]
     39     b = angles[...,1]
     40     c = angles[...,2]
     41     cos_a = tf.cos(a)
     42     cos_b = tf.cos(b)
     43     cos_c = tf.cos(c)
     44     sin_a = tf.sin(a)
     45     sin_b = tf.sin(b)
     46     sin_c = tf.sin(c)
     47 
     48     r_11 = cos_a*cos_b
     49     r_12 = cos_a*sin_b*sin_c - sin_a*cos_c
     50     r_13 = cos_a*sin_b*cos_c + sin_a*sin_c
     51     r_1 = tf.stack([r_11, r_12, r_13], axis=-1)
     52 
     53     r_21 = sin_a*cos_b
     54     r_22 = sin_a*sin_b*sin_c + cos_a*cos_c
     55     r_23 = sin_a*sin_b*cos_c - cos_a*sin_c
     56     r_2 = tf.stack([r_21, r_22, r_23], axis=-1)
     57 
     58     r_31 = -sin_b
     59     r_32 = cos_b*sin_c
     60     r_33 = cos_b*cos_c
     61     r_3 = tf.stack([r_31, r_32, r_33], axis=-1)
     62 
     63     rot_mat = tf.stack([r_1, r_2, r_3], axis=-2)
     64     return rot_mat
     65 
     66 def rotate(p, angles, inverse=False):
     67     r"""
     68     Rotates points ``p`` by the ``angles`` according
     69     to the 3D rotation defined in :eq:`rotation`
     70 
     71     Input
     72     -----
     73     p : [...,3], tf.float
     74         Points to rotate
     75 
     76     angles : [..., 3]
     77         Angles for the rotations [rad].
     78         The last dimension corresponds to the angles
     79         :math:`(\alpha,\beta,\gamma)` that define
     80         rotations about the axes :math:`(z, y, x)`,
     81         respectively.
     82 
     83     inverse : bool
     84         If `True`, the inverse rotation is applied,
     85         i.e., the transpose of the rotation matrix is used.
     86         Defaults to `False`
     87 
     88     Output
     89     ------
     90     : [...,3]
     91         Rotated points ``p``
     92     """
     93 
     94     # Rotation matrix
     95     # [..., 3, 3]
     96     rot_mat = rotation_matrix(angles)
     97     rot_mat = expand_to_rank(rot_mat, tf.rank(p)+1, 0)
     98 
     99     # Rotation around ``center``
    100     # [..., 3]
    101     rot_p = tf.linalg.matvec(rot_mat, p, transpose_a=inverse)
    102 
    103     return rot_p
    104 
    105 def theta_phi_from_unit_vec(v):
    106     r"""
    107     Computes zenith and azimuth angles (:math:`\theta,\varphi`)
    108     from unit-norm vectors as described in :eq:`theta_phi`
    109 
    110     Input
    111     ------
    112     v : [...,3], tf.float
    113         Tensor with unit-norm vectors in the last dimension
    114 
    115     Output
    116     -------
    117     theta : [...], tf.float
    118         Zenith angles :math:`\theta`
    119 
    120     phi : [...], tf.float
    121         Azimuth angles :math:`\varphi`
    122     """
    123     x = v[...,0]
    124     y = v[...,1]
    125     z = v[...,2]
    126 
    127     # If v = z, then x = 0 and y = 0. In this case, atan2 is not differentiable,
    128     # leading to NaN when computing the gradients.
    129     # The following lines force x to one this case. Note that this does not
    130     # impact the output meaningfully, as in that case theta = 0 and phi can
    131     # take any value.
    132     zero = tf.zeros_like(x)
    133     is_unit_z = tf.logical_and(tf.equal(x, zero), tf.equal(y, zero))
    134     is_unit_z = tf.cast(is_unit_z, x.dtype)
    135     x += is_unit_z
    136 
    137     theta = acos_diff(z)
    138     phi = tf.math.atan2(y, x)
    139     return theta, phi
    140 
    141 def r_hat(theta, phi):
    142     r"""
    143     Computes the spherical unit vetor :math:`\hat{\mathbf{r}}(\theta, \phi)`
    144     as defined in :eq:`spherical_vecs`
    145 
    146     Input
    147     -------
    148     theta : arbitrary shape, tf.float
    149         Zenith angles :math:`\theta` [rad]
    150 
    151     phi : same shape as ``theta``, tf.float
    152         Azimuth angles :math:`\varphi` [rad]
    153 
    154     Output
    155     --------
    156     rho_hat : ``phi.shape`` + [3], tf.float
    157         Vector :math:`\hat{\mathbf{r}}(\theta, \phi)`  on unit sphere
    158     """
    159     rho_hat = tf.stack([tf.sin(theta)*tf.cos(phi),
    160                         tf.sin(theta)*tf.sin(phi),
    161                         tf.cos(theta)], axis=-1)
    162     return rho_hat
    163 
    164 def theta_hat(theta, phi):
    165     r"""
    166     Computes the spherical unit vector
    167     :math:`\hat{\boldsymbol{\theta}}(\theta, \varphi)`
    168     as defined in :eq:`spherical_vecs`
    169 
    170     Input
    171     -------
    172     theta : arbitrary shape, tf.float
    173         Zenith angles :math:`\theta` [rad]
    174 
    175     phi : same shape as ``theta``, tf.float
    176         Azimuth angles :math:`\varphi` [rad]
    177 
    178     Output
    179     --------
    180     theta_hat : ``phi.shape`` + [3], tf.float
    181         Vector :math:`\hat{\boldsymbol{\theta}}(\theta, \varphi)`
    182     """
    183     x = tf.cos(theta)*tf.cos(phi)
    184     y = tf.cos(theta)*tf.sin(phi)
    185     z = -tf.sin(theta)
    186     return tf.stack([x,y,z], -1)
    187 
    188 def phi_hat(phi):
    189     r"""
    190     Computes the spherical unit vector
    191     :math:`\hat{\boldsymbol{\varphi}}(\theta, \varphi)`
    192     as defined in :eq:`spherical_vecs`
    193 
    194     Input
    195     -------
    196     phi : same shape as ``theta``, tf.float
    197         Azimuth angles :math:`\varphi` [rad]
    198 
    199     Output
    200     --------
    201     theta_hat : ``phi.shape`` + [3], tf.float
    202         Vector :math:`\hat{\boldsymbol{\varphi}}(\theta, \varphi)`
    203     """
    204     x = -tf.sin(phi)
    205     y = tf.cos(phi)
    206     z = tf.zeros_like(x)
    207     return tf.stack([x,y,z], -1)
    208 
    209 def cross(u, v):
    210     r"""
    211     Computes the cross (or vector) product between u and v
    212 
    213     Input
    214     ------
    215     u : [...,3]
    216         First vector
    217 
    218     v : [...,3]
    219         Second vector
    220 
    221     Output
    222     -------
    223     : [...,3]
    224         Cross product between ``u`` and ``v``
    225     """
    226     u_x = u[...,0]
    227     u_y = u[...,1]
    228     u_z = u[...,2]
    229 
    230     v_x = v[...,0]
    231     v_y = v[...,1]
    232     v_z = v[...,2]
    233 
    234     w = tf.stack([u_y*v_z - u_z*v_y,
    235                   u_z*v_x - u_x*v_z,
    236                   u_x*v_y - u_y*v_x], axis=-1)
    237 
    238     return w
    239 
    240 def dot(u, v, keepdim=False, clip=False):
    241     r"""
    242     Computes and the dot (or scalar) product between u and v
    243 
    244     Input
    245     ------
    246     u : [...,3]
    247         First vector
    248 
    249     v : [...,3]
    250         Second vector
    251 
    252     keepdim : bool
    253         If `True`, keep the last dimension.
    254         Defaults to `False`.
    255 
    256     clip : bool
    257         If `True`, clip output to [-1,1].
    258         Defaults to `False`.
    259 
    260     Output
    261     -------
    262     : [...,1] or [...]
    263         Dot product between ``u`` and ``v``.
    264         The last dimension is removed if ``keepdim``
    265         is set to `False`.
    266     """
    267     res = tf.reduce_sum(u*v, axis=-1, keepdims=keepdim)
    268     if clip:
    269         one = tf.ones((), u.dtype)
    270         res = tf.clip_by_value(res, -one, one)
    271     return res
    272 
    273 def outer(u,v):
    274     r"""
    275     Computes the outer product between u and v
    276 
    277     Input
    278     ------
    279     u : [...,3]
    280         First vector
    281 
    282     v : [...,3]
    283         Second vector
    284 
    285     Output
    286     -------
    287     : [...,3,3]
    288         Outer product between ``u`` and ``v``
    289     """
    290     return u[...,tf.newaxis] * v[...,tf.newaxis,:]
    291 
    292 def normalize(v):
    293     r"""
    294     Normalizes ``v`` to unit norm
    295 
    296     Input
    297     ------
    298     v : [...,3], tf.float
    299         Vector
    300 
    301     Output
    302     -------
    303     : [...,3], tf.float
    304         Normalized vector
    305 
    306     : [...], tf.float
    307         Norm of the unnormalized vector
    308     """
    309     norm = tf.norm(v, axis=-1, keepdims=True)
    310     n_v = tf.math.divide_no_nan(v, norm)
    311     norm = tf.squeeze(norm, axis=-1)
    312     return n_v, norm
    313 
    314 def moller_trumbore(o, d, p0, p1, p2, epsilon):
    315     r"""
    316     Computes the intersection between a ray ``ray`` and a triangle defined
    317     by its vertices ``p0``, ``p1``, and ``p2`` using the Moller–Trumbore
    318     intersection algorithm.
    319 
    320     Input
    321     -----
    322     o, d: [..., 3], tf.float
    323         Ray origin and direction.
    324         The direction `d` must be a unit vector.
    325 
    326     p0, p1, p2: [..., 3], tf.float
    327         Vertices defining the triangle
    328 
    329     epsilon : (), tf.float
    330         Small value used to avoid errors due to numerical precision
    331 
    332     Output
    333     -------
    334     t : [...], tf.float
    335         Position along the ray from the origin at which the intersection
    336         occurs (if any)
    337 
    338     hit : [...], bool
    339         `True` if the ray intersects the triangle. `False` otherwise.
    340     """
    341 
    342     rdtype = o.dtype
    343     zero = tf.cast(0.0, rdtype)
    344     one = tf.ones((), rdtype)
    345 
    346     # [..., 3]
    347     e1 = p1 - p0
    348     e2 = p2 - p0
    349 
    350     # [...,3]
    351     pvec = cross(d, e2)
    352     # [...,1]
    353     det = dot(e1, pvec, keepdim=True)
    354 
    355     # If the ray is parallel to the triangle, then det = 0.
    356     hit = tf.greater(tf.abs(det), zero)
    357 
    358     # [...,3]
    359     tvec = o - p0
    360     # [...,1]
    361     u = tf.math.divide_no_nan(dot(tvec, pvec, keepdim=True), det)
    362     # [...,1]
    363     hit = tf.logical_and(hit,
    364         tf.logical_and(tf.greater_equal(u, -epsilon),
    365                        tf.less_equal(u, one + epsilon)))
    366 
    367     # [..., 3]
    368     qvec = cross(tvec, e1)
    369     # [...,1]
    370     v = tf.math.divide_no_nan(dot(d, qvec, keepdim=True), det)
    371     # [..., 1]
    372     hit = tf.logical_and(hit,
    373                             tf.logical_and(tf.greater_equal(v, -epsilon),
    374                                         tf.less_equal(u + v, one + epsilon)))
    375     # [..., 1]
    376     t = tf.math.divide_no_nan(dot(e2, qvec, keepdim=True), det)
    377     # [..., 1]
    378     hit = tf.logical_and(hit, tf.greater_equal(t, epsilon))
    379 
    380     # [...]
    381     t = tf.squeeze(t, axis=-1)
    382     hit = tf.squeeze(hit, axis=-1)
    383 
    384     return t, hit
    385 
    386 def component_transform(e_s, e_p, e_i_s, e_i_p):
    387     """
    388     Compute basis change matrix for reflections
    389 
    390     Input
    391     -----
    392     e_s : [..., 3], tf.float
    393         Source unit vector for S polarization
    394 
    395     e_p : [..., 3], tf.float
    396         Source unit vector for P polarization
    397 
    398     e_i_s : [..., 3], tf.float
    399         Target unit vector for S polarization
    400 
    401     e_i_p : [..., 3], tf.float
    402         Target unit vector for P polarization
    403 
    404     Output
    405     -------
    406     r : [..., 2, 2], tf.float
    407         Change of basis matrix for going from (e_s, e_p) to (e_i_s, e_i_p)
    408     """
    409     r_11 = dot(e_i_s, e_s)
    410     r_12 = dot(e_i_s, e_p)
    411     r_21 = dot(e_i_p, e_s)
    412     r_22 = dot(e_i_p, e_p)
    413     r1 = tf.stack([r_11, r_12], axis=-1)
    414     r2 = tf.stack([r_21, r_22], axis=-1)
    415     r = tf.stack([r1, r2], axis=-2)
    416     return r
    417 
    418 def mi_to_tf_tensor(mi_tensor, dtype):
    419     """
    420     Get a TensorFlow eager tensor from a Mitsuba/DrJIT tensor
    421     """
    422     dr.eval(mi_tensor)
    423     dr.sync_thread()
    424     # When there is only one input, the .tf() methods crashes.
    425     # The following hack takes care of this corner case
    426     if dr.shape(mi_tensor)[-1] == 1:
    427         mi_tensor = dr.repeat(mi_tensor, 2)
    428         tf_tensor = tf.cast(mi_tensor.tf(), dtype)[:1]
    429     else:
    430         tf_tensor = tf.cast(mi_tensor.tf(), dtype)
    431     return tf_tensor
    432 
    433 def gen_orthogonal_vector(k, epsilon):
    434     """
    435     Generate an arbitrary vector that is orthogonal to ``k``.
    436 
    437     Input
    438     ------
    439     k : [..., 3], tf.float
    440         Vector
    441 
    442     epsilon : (), tf.float
    443         Small value used to avoid errors due to numerical precision
    444 
    445     Output
    446     -------
    447     : [..., 3], tf.float
    448         Vector orthogonal to ``k``
    449     """
    450     rdtype = k.dtype
    451     ex = tf.cast([1.0, 0.0, 0.0], rdtype)
    452     ex = expand_to_rank(ex, tf.rank(k), 0)
    453 
    454     ey = tf.cast([0.0, 1.0, 0.0], rdtype)
    455     ey = expand_to_rank(ey, tf.rank(k), 0)
    456 
    457     n1 = cross(k, ex)
    458     n1_norm = tf.norm(n1, axis=-1, keepdims=True)
    459     n2 = cross(k, ey)
    460     return tf.where(tf.greater(n1_norm, epsilon), n1, n2)
    461 
    462 def compute_field_unit_vectors(k_i, k_r, n, epsilon, return_e_r=True):
    463     """
    464     Compute unit vector parallel and orthogonal to incident plane
    465 
    466     Input
    467     ------
    468     k_i : [..., 3], tf.float
    469         Direction of arrival
    470 
    471     k_r : [..., 3], tf.float
    472         Direction of reflection
    473 
    474     n : [..., 3], tf.float
    475         Surface normal
    476 
    477     epsilon : (), tf.float
    478         Small value used to avoid errors due to numerical precision
    479 
    480     return_e_r : bool
    481         If `False`, only ``e_i_s`` and ``e_i_p`` are returned.
    482 
    483     Output
    484     ------
    485     e_i_s : [..., 3], tf.float
    486         Incident unit field vector for S polarization
    487 
    488     e_i_p : [..., 3], tf.float
    489         Incident unit field vector for P polarization
    490 
    491     e_r_s : [..., 3], tf.float
    492         Reflection unit field vector for S polarization.
    493         Only returned if ``return_e_r`` is `True`.
    494 
    495     e_r_p : [..., 3], tf.float
    496         Reflection unit field vector for P polarization
    497         Only returned if ``return_e_r`` is `True`.
    498     """
    499     e_i_s = cross(k_i, n)
    500     e_i_s_norm = tf.norm(e_i_s, axis=-1, keepdims=True)
    501     # In case of normal incidence, the incidence plan is not uniquely
    502     # define and the Fresnel coefficent is the same for both polarization
    503     # (up to a sign flip for the parallel component due to the definition of
    504     # polarization).
    505     # It is required to detect such scenarios and define an arbitrary valid
    506     # e_i_s to fix an incidence plane, as the result from previous
    507     # computation leads to e_i_s = 0.
    508     e_i_s = tf.where(tf.greater(e_i_s_norm, epsilon), e_i_s,
    509                      gen_orthogonal_vector(n, epsilon))
    510 
    511     e_i_s,_ = normalize(e_i_s)
    512     e_i_p,_ = normalize(cross(e_i_s, k_i))
    513     if not return_e_r:
    514         return e_i_s, e_i_p
    515     else:
    516         e_r_s = e_i_s
    517         e_r_p,_ = normalize(cross(e_r_s, k_r))
    518         return e_i_s, e_i_p, e_r_s, e_r_p
    519 
    520 def reflection_coefficient(eta, cos_theta):
    521     """
    522     Compute simplified reflection coefficients
    523 
    524     Input
    525     ------
    526     eta : Any shape, tf.complex
    527         Complex relative permittivity
    528 
    529     cos_theta : Same as ``eta``, tf.float
    530         Cosine of the incident angle
    531 
    532     Output
    533     -------
    534     r_te : Same as input, tf.complex
    535         Fresnel reflection coefficient for S direction
    536 
    537     r_tm : Same as input, tf.complex
    538         Fresnel reflection coefficient for P direction
    539     """
    540     cos_theta = tf.complex(cos_theta, tf.zeros_like(cos_theta))
    541 
    542     # Fresnel equations
    543     a = cos_theta
    544     b = tf.sqrt(eta-1.+cos_theta**2)
    545     r_te = tf.math.divide_no_nan(a-b, a+b)
    546 
    547     c = eta*a
    548     d = b
    549     r_tm = tf.math.divide_no_nan(c-d, c+d)
    550     return r_te, r_tm
    551 
    552 def paths_to_segments(paths):
    553     """
    554     Extract the segments corresponding to a set of ``paths``
    555 
    556     Input
    557     -----
    558     paths : :class:`~sionna.rt.Paths`
    559         A set of paths
    560 
    561     Output
    562     -------
    563     starts, ends : [n,3], float
    564         Endpoints of the segments making the paths.
    565     """
    566 
    567     vertices = paths.vertices.numpy()
    568     objects = paths.objects.numpy()
    569     mask = paths.targets_sources_mask
    570     sources, targets = paths.sources.numpy(), paths.targets.numpy()
    571 
    572     # Emit directly two lists of the beginnings and endings of line segments
    573     starts = []
    574     ends = []
    575     for rx in range(vertices.shape[1]): # For each receiver
    576         for tx in range(vertices.shape[2]): # For each transmitter
    577             for p in range(vertices.shape[3]): # For each path depth
    578                 if not mask[rx, tx, p]:
    579                     continue
    580 
    581                 start = sources[tx]
    582                 i = 0
    583                 while ( (i < objects.shape[0])
    584                     and (objects[i, rx, tx, p] != -1) ):
    585                     end = vertices[i, rx, tx, p]
    586                     starts.append(start)
    587                     ends.append(end)
    588                     start = end
    589                     i += 1
    590                 # Explicitly add the path endpoint
    591                 starts.append(start)
    592                 ends.append(targets[rx])
    593     return starts, ends
    594 
    595 def scene_scale(scene):
    596     bbox = scene.mi_scene.bbox()
    597     tx_positions, rx_positions, ris_positions = {}, {}, {}
    598     devices = ((scene.transmitters, tx_positions),
    599                (scene.receivers, rx_positions),
    600                (scene.ris, ris_positions)
    601               )
    602     for source, destination in devices:
    603         for k, rd in source.items():
    604             p = rd.position.numpy()
    605             bbox.expand(p)
    606             destination[k] = p
    607 
    608     sc = 2. * bbox.bounding_sphere().radius
    609     return sc, tx_positions, rx_positions, ris_positions, bbox
    610 
    611 def fibonacci_lattice(num_points, dtype=tf.float32):
    612     """
    613     Generates a Fibonacci lattice for the unit square
    614 
    615     Input
    616     -----
    617     num_points : int
    618         Number of points
    619 
    620     type : tf.DType
    621         Datatype to use for the output
    622 
    623     Output
    624     -------
    625     points : [num_points, 2]
    626         Generated rectangular coordinates of the lattice points
    627     """
    628 
    629     golden_ratio = (1.+tf.sqrt(tf.cast(5., tf.float64)))/2.
    630     ns = tf.range(0, num_points, dtype=tf.float64)
    631 
    632     x = ns/golden_ratio
    633     x = x - tf.floor(x)
    634     y = ns/(num_points-1)
    635     points = tf.stack([x,y], axis=1)
    636 
    637     points = tf.cast(points, dtype)
    638 
    639     return points
    640 
    641 def cot(x):
    642     """
    643     Cotangens function
    644 
    645     Input
    646     ------
    647     x : [...], tf.float
    648 
    649     Output
    650     -------
    651     : [...], tf.float
    652         Cotangent of x
    653     """
    654     return tf.math.divide_no_nan(tf.ones_like(x), tf.math.tan(x))
    655 
    656 def sign(x):
    657     """
    658     Returns +1 if ``x`` is non-negative, -1 otherwise
    659 
    660     Input
    661     ------
    662     x : [...], tf.float
    663         A real-valued number
    664 
    665     Output
    666     -------
    667     : [...], tf.float
    668         +1 if ``x`` is non-negative, -1 otherwise
    669     """
    670     two = tf.cast(2, x.dtype)
    671     one = tf.cast(1, x.dtype)
    672     return two*tf.cast(tf.greater_equal(x, 0), x.dtype)-one
    673 
    674 def rot_mat_from_unit_vecs(a, b):
    675     r"""
    676     Computes Rodrigues` rotation formula :eq:`rodrigues_matrix`
    677 
    678     Input
    679     ------
    680     a : [...,3], tf.float
    681         First unit vector
    682 
    683     b : [...,3], tf.float
    684         Second unit vector
    685 
    686     Output
    687     -------
    688     : [...,3,3], tf.float
    689         Rodrigues' rotation matrix
    690     """
    691 
    692     rdtype = a.dtype
    693 
    694     # Compute rotation axis vector
    695     k, _ = normalize(cross(a, b))
    696 
    697     # Deal with special case where a and b are parallel
    698     o = gen_orthogonal_vector(a, 1e-6)
    699     k = tf.where(tf.reduce_sum(tf.abs(k), axis=-1, keepdims=True)==0, o, k)
    700 
    701     # Compute K matrix
    702     shape = tf.concat([tf.shape(k)[:-1],[1]], axis=-1)
    703     zeros = tf.zeros(shape, rdtype)
    704     kx, ky, kz = tf.split(k, 3, axis=-1)
    705     l1 = tf.concat([zeros, -kz, ky], axis=-1)
    706     l2 = tf.concat([kz, zeros, -kx], axis=-1)
    707     l3 = tf.concat([-ky, kx, zeros], axis=-1)
    708     k_mat = tf.stack([l1, l2, l3], axis=-2)
    709 
    710     # Assemble full rotation matrix
    711     eye = tf.eye(3, batch_shape=tf.shape(k)[:-1], dtype=rdtype)
    712     cos_theta = dot(a, b, clip=True)
    713     sin_theta = tf.sin(acos_diff(cos_theta))
    714     cos_theta = expand_to_rank(cos_theta, tf.rank(eye), axis=-1)
    715     sin_theta = expand_to_rank(sin_theta, tf.rank(eye), axis=-1)
    716     rot_mat = eye + k_mat*sin_theta + \
    717                       tf.linalg.matmul(k_mat, k_mat) * (1-cos_theta)
    718     return rot_mat
    719 
    720 def sample_points_on_hemisphere(normals, num_samples=1):
    721     # pylint: disable=line-too-long
    722     r"""
    723     Randomly sample points on hemispheres defined by their normal vectors
    724 
    725     Input
    726     -----
    727     normals : [batch_size, 3], tf.float
    728         Normal vectors defining hemispheres
    729 
    730     num_samples : int
    731         Number of random samples to draw for each hemisphere
    732         defined by its normal vector.
    733         Defaults to 1.
    734 
    735     Output
    736     ------
    737     points : [batch_size, num_samples, 3], tf.float or [batch_size, 3], tf.float if num_samples=1.
    738         Random points on the hemispheres
    739     """
    740     dtype = normals.dtype
    741     batch_size = tf.shape(normals)[0]
    742     shape = [batch_size, num_samples]
    743 
    744     # Sample phi uniformly distributed on [0,2*PI]
    745     phi = config.tf_rng.uniform(shape, maxval=2*PI, dtype=dtype)
    746 
    747     # Generate samples of theta for uniform distribution on the hemisphere
    748     u = config.tf_rng.uniform(shape, maxval=1, dtype=dtype)
    749     theta = tf.acos(u)
    750 
    751     # Transform spherical to Cartesian coordinates
    752     points = r_hat(theta, phi)
    753 
    754     # Compute rotation matrices
    755     z_hat = tf.constant([[0,0,1]], dtype=dtype)
    756     z_hat = tf.broadcast_to(z_hat, tf.shape(normals))
    757     rot_mat = rot_mat_from_unit_vecs(z_hat, normals)
    758     rot_mat = tf.expand_dims(rot_mat, axis=1)
    759 
    760     # Compute rotated points
    761     points = tf.linalg.matvec(rot_mat, points)
    762 
    763     # Numerical errors can cause sampling from the other hemisphere.
    764     # Correct the sampled vector to avoid sampling in the wrong hemisphere.
    765     normals = tf.expand_dims(normals, axis=1)
    766     s = dot(points, normals, keepdim=True)
    767     s = tf.where(s < 0., s, 0.)
    768     points = points - 2.*s*normals
    769 
    770     if num_samples==1:
    771         points = tf.squeeze(points, axis=1)
    772 
    773     return points
    774 
    775 def acos_diff(x, epsilon=1e-7):
    776     r"""
    777     Implementation of arccos(x) that avoids evaluating the gradient at x
    778     -1 or 1 by using straight through estimation, i.e., in the
    779     forward pass, x is clipped to (-1, 1), but in the backward pass, x is
    780     clipped to (-1 + epsilon, 1 - epsilon).
    781 
    782     Input
    783     ------
    784     x : any shape, tf.float
    785         Value at which to evaluate arccos
    786 
    787     epsilon : tf.float
    788         Small backoff to avoid evaluating the gradient at -1 or 1.
    789         Defaults to 1e-7.
    790 
    791     Output
    792     -------
    793      : same shape as x, tf.float
    794         arccos(x)
    795     """
    796 
    797     x_clip_1 = tf.clip_by_value(x, -1., 1.)
    798     x_clip_2 = tf.clip_by_value(x, -1. + epsilon, 1. - epsilon)
    799     eps = tf.stop_gradient(x - x_clip_2)
    800     x_1 =  x - eps
    801     acos_x_1 =  tf.acos(x_1)
    802     y = acos_x_1 + tf.stop_gradient(tf.acos(x_clip_1)-acos_x_1)
    803     return y
    804 
    805 def angles_to_mitsuba_rotation(angles):
    806     """
    807     Build a Mitsuba transform from angles in radian
    808 
    809     Input
    810     ------
    811     angles : [3], tf.float
    812         Angles [rad]
    813 
    814     Output
    815     -------
    816     : :class:`mitsuba.ScalarTransform4f`
    817         Mitsuba rotation
    818     """
    819 
    820     angles = 180. * angles / PI
    821 
    822     if angles.dtype == tf.float32:
    823         mi_transform_t = mi.Transform4f
    824         angles = mi.Float(angles)
    825     else:
    826         mi_transform_t = mi.Transform4d
    827         angles = mi.Float64(angles)
    828 
    829     return (
    830           mi_transform_t.rotate(axis=[0., 0., 1.], angle=angles[0])
    831         @ mi_transform_t.rotate(axis=[0., 1., 0.], angle=angles[1])
    832         @ mi_transform_t.rotate(axis=[1., 0., 0.], angle=angles[2])
    833     )
    834 
    835 def gen_basis_from_z(z, epsilon):
    836     """
    837     Generate a pair of vectors (x,y) such that (x,y,z) is an orthonormal basis.
    838 
    839     Input
    840     ------
    841     z : [..., 3], tf.float
    842         Unit vector
    843 
    844     epsilon : (), tf.float
    845         Small value used to avoid errors due to numerical precision
    846 
    847     Output
    848     -------
    849     x : [..., 3], tf.float
    850         Unit vector
    851 
    852     y : [..., 3], tf.float
    853         Unit vector
    854     """
    855     x = gen_orthogonal_vector(z, epsilon)
    856     x,_ = normalize(x)
    857     y = cross(z, x)
    858     return x,y
    859 
    860 def compute_spreading_factor(rho_1, rho_2, s):
    861     r"""
    862     Computes the spreading factor
    863     :math:`\sqrt{\frac{\rho_1 \rho_2}{(\rho_1 + s)(\rho_2 + s)}}`
    864 
    865     Input
    866     ------
    867     rho_1, rho_2 : [...], tf.float
    868         Principal radii of curvature
    869 
    870     s : [...], tf.float
    871         Position along the axial ray at which to evaluate the squared
    872         spreading factor
    873 
    874     Output
    875     -------
    876     : float
    877         Squared spreading factor
    878     """
    879 
    880     # In the case of a spherical wave, when the origin (s = 0) is set to unique
    881     # caustic point, then both principal radii of curvature are set to zero.
    882     # The spreading factor is then equal to 1/s.
    883     spherical = tf.logical_and(tf.equal(rho_1, 0.), tf.equal(rho_2, 0.))
    884     a2_spherical = tf.math.reciprocal_no_nan(s)
    885 
    886     # General formula for the spreading factor
    887     a2 = tf.sqrt(rho_1*rho_2/((rho_1+s)*(rho_2+s)))
    888 
    889     a2 = tf.where(spherical, a2_spherical, a2)
    890     return a2
    891 
    892 def mitsuba_rectangle_to_world(center, orientation, size, ris=False):
    893     """
    894     Build the `to_world` transformation that maps a default Mitsuba rectangle
    895     to the rectangle that defines the coverage map surface.
    896 
    897     Input
    898     ------
    899     center : [3], tf.float
    900         Center of the rectangle
    901 
    902     orientation : [3], tf.float
    903         Orientation of the rectangle.
    904         An orientation of `(0,0,0)` correspond to a rectangle oriented such that
    905         z+ is its normal.
    906 
    907     size : [2], tf.float
    908         Scale of the rectangle.
    909         The width of the rectangle (in the local X direction) is scale[0]
    910         and its height (in the local Y direction) scale[1].
    911 
    912     Output
    913     -------
    914     to_world : :class:`mitsuba.ScalarTransform4f`
    915         Rectangle to world transformation.
    916     """
    917     orientation = 180. * orientation / PI
    918 
    919     trans = \
    920         mi.ScalarTransform4f.translate(center.numpy())\
    921         @ mi.ScalarTransform4f.rotate(axis=[0, 0, 1], angle=orientation[0])\
    922         @ mi.ScalarTransform4f.rotate(axis=[0, 1, 0], angle=orientation[1])\
    923         @ mi.ScalarTransform4f.rotate(axis=[1, 0, 0], angle=orientation[2])
    924 
    925     if ris:
    926         # The RIS normal points at [1,0,0].
    927         # We hence rotate the normal of the rectangle which points
    928         # at [0,0,1] by 90 degrees around the [0,1,0] axis.
    929         trans = trans\
    930             @mi.ScalarTransform4f.rotate(axis=[0, 1, 0], angle=90)
    931 
    932         # size = [width (=y), height (=z)]
    933         # Since the RIS is rotated w.r.t to rectangle,
    934         # The z axis corresponds to the x axis
    935         size = [size[1], size[0]]
    936 
    937     return (trans
    938             @mi.ScalarTransform4f.scale([0.5 * size[0], 0.5 * size[1], 1])
    939     )
    940 
    941 def watt_to_dbm(power):
    942     r""" Converts :math:`P_{W}` [W] to :math:`P_{dBm}` [dBm] via the formula:
    943     :math:`P_{dBm} = 30 + 10 \log_{10}(P_W)`
    944 
    945     Input
    946     ------
    947     power : float
    948         Power [W]
    949 
    950     Output
    951     -------
    952      : float
    953         Power [dBm]
    954     """
    955     return 30 + 10 * log10(power)
    956 
    957 def dbm_to_watt(dbm):
    958     r""" Converts dBm to Watt via the formula:
    959     :math:`P_W = 10^{\frac{P_{dBm}-30}{10}}`
    960 
    961     Input
    962     ------
    963     dbm : float
    964         Power [dBm]
    965 
    966     Output
    967     -------
    968      : float
    969         Power [W]
    970     """
    971     return tf.pow(10, (dbm - 30) / 10)