anomaly-detection-material-parameters-calibration

Sionna param calibration (research proj)
git clone https://git.ea.contact/anomaly-detection-material-parameters-calibration
Log | Files | Refs | README

previewer.py (20857B)


      1 #
      2 # SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
      3 # SPDX-License-Identifier: Apache-2.0
      4 #
      5 """
      6 3D scene and paths viewer
      7 """
      8 
      9 import drjit as dr
     10 import mitsuba as mi
     11 import numpy as np
     12 from ipywidgets.embed import embed_snippet
     13 import pythreejs as p3s
     14 import matplotlib
     15 
     16 from .utils import paths_to_segments, scene_scale, rotate,\
     17     mitsuba_rectangle_to_world
     18 from .renderer import coverage_map_color_mapping
     19 
     20 
     21 class InteractiveDisplay:
     22     """
     23     Lightweight wrapper around the `pythreejs` library.
     24 
     25     Input
     26     -----
     27     resolution : [2], int
     28         Size of the viewer figure.
     29 
     30     fov : float
     31         Field of view, in degrees.
     32 
     33     background : str
     34         Background color in hex format prefixed by '#'.
     35     """
     36 
     37     def __init__(self, scene, resolution, fov, background):
     38 
     39         self._scene = scene
     40         self._disk_sprite = None
     41 
     42         # List of objects in the scene
     43         self._objects = []
     44         # Bounding box of the scene
     45         self._bbox = mi.ScalarBoundingBox3f()
     46 
     47         ####################################################
     48         # Setup the viewer
     49         ####################################################
     50 
     51         # Lighting
     52         ambient_light = p3s.AmbientLight(intensity=0.80)
     53         camera_light = p3s.DirectionalLight(
     54             position=[0, 0, 0], intensity=0.25
     55         )
     56 
     57         # Camera & controls
     58         self._camera = p3s.PerspectiveCamera(
     59             fov=fov, aspect=resolution[0]/resolution[1],
     60             up=[0, 0, 1], far=10000,
     61             children=[camera_light],
     62         )
     63         self._orbit = p3s.OrbitControls(
     64             controlling = self._camera
     65         )
     66 
     67         # Scene & renderer
     68         self._p3s_scene = p3s.Scene(
     69             background=background, children=[self._camera, ambient_light]
     70         )
     71         self._renderer = p3s.Renderer(
     72             scene=self._p3s_scene, camera=self._camera, controls=[self._orbit],
     73             width=resolution[0], height=resolution[1], antialias=True
     74         )
     75 
     76         ####################################################
     77         # Plot the scene geometry
     78         ####################################################
     79         self.plot_scene()
     80 
     81         # Finally, ensure the camera is looking at the scene
     82         self.center_view()
     83 
     84     def reset(self):
     85         """
     86         Removes objects that are not flagged as persistent, i.e., the paths.
     87         """
     88         remaining = []
     89         for obj, persist in self._objects:
     90             if persist:
     91                 remaining.append((obj, persist))
     92             else:
     93                 self._p3s_scene.remove(obj)
     94         self._objects = remaining
     95 
     96     def redraw_scene_geometry(self):
     97         """
     98         Redraw the scene geometry.
     99         """
    100         remaining = []
    101         for obj, persist in self._objects:
    102             if not persist: # Only scene objects are flagged as persistent
    103                 remaining.append((obj, persist))
    104             else:
    105                 self._p3s_scene.remove(obj)
    106         self._objects = remaining
    107 
    108         # Plot the scene geometry
    109         self.plot_scene()
    110 
    def center_view(self):
        """
        Automatically places the camera based on the bounding box of the
        displayed objects and orients it toward the center of that box.

        The camera is placed at the minimum x of the bounding box, at the
        y coordinate of its center, and at 1.5 times its maximum z. If that
        position is (numerically) the origin, (-1, -1, 1) is used instead.
        """
        scene_bbox = self._bbox
        if not scene_bbox.valid():
            # Nothing was plotted yet: fall back to a box built from `0.`
            scene_bbox = mi.ScalarBoundingBox3f(0.)
        target = scene_bbox.center()

        eye = [scene_bbox.min.x, target.y, 1.5 * scene_bbox.max.z]
        if np.allclose(eye, 0):
            eye = (-1, -1, 1)

        self._camera.position = tuple(eye)
        self._camera.lookAt(target)

        # Propagate the new pose to the front-end objects
        self._orbit.exec_three_obj_method('update')
        self._camera.exec_three_obj_method('updateProjectionMatrix')
    128 
    def plot_radio_devices(self, show_orientations=False):
        """
        Plots the radio devices.

        Transmitters and receivers are drawn as colored points. If
        ``show_orientations`` is set, an arrow (a line with a cone-shaped
        head) is also drawn for every transmitter, receiver and RIS. The
        arrow starts at the device position and points along the x-axis
        rotated by the device orientation.

        All objects added by this method are non-persistent.

        Input
        -----
        show_orientations : bool
            Shows the radio devices' orientations.
            Defaults to `False`.
        """
        scene = self._scene
        sc, tx_positions, rx_positions, _, _ = scene_scale(scene)

        # Radio emitters, shown as points
        positions = np.array(list(tx_positions.values()) +
                             list(rx_positions.values()))
        albedo = np.array(
            [tx.color.numpy() for tx in scene.transmitters.values()] +
            [rx.color.numpy() for rx in scene.receivers.values()]
        )

        if positions.shape[0] > 0:
            # Radio devices are not persistent
            radius = max(0.005 * sc, 1)
            self._plot_points(positions, persist=False, colors=albedo,
                              radius=radius)

        if not show_orientations:
            return

        line_length = 0.0075 * sc
        head_length = 0.15 * line_length
        zeros = np.zeros((1, 3))

        for devices in (scene.transmitters.values(),
                        scene.receivers.values(),
                        scene.ris.values()):
            # Arrow lines are grouped by color, so that every device is drawn
            # with its own color while keeping a single line object per color
            segments = {}
            for rd in devices:
                # Color components are scaled by 255 to build the CSS-style
                # string, as done in `plot_ris()`
                rgb = rd.color.numpy()
                color = f'rgb({", ".join([str(int(v*255)) for v in rgb])})'

                # Arrow line
                endpoint = rd.position + rotate([line_length, 0., 0.],
                                                rd.orientation)
                starts, ends = segments.setdefault(color, ([], []))
                starts.append(rd.position)
                ends.append(endpoint)

                # Arrow head: a cone located at the end of the line
                geo = p3s.CylinderGeometry(
                    radiusTop=0, radiusBottom=0.3 * head_length,
                    height=head_length, radialSegments=8,
                    heightSegments=0, openEnded=False)
                mat = p3s.MeshLambertMaterial(color=color)
                mesh = p3s.Mesh(geo, mat)
                mesh.position = tuple(endpoint)
                angles = rd.orientation.numpy()
                mesh.rotateZ(angles[0] - np.pi/2)
                mesh.rotateY(angles[2])
                mesh.rotateX(-angles[1])
                self._add_child(mesh, zeros, zeros, persist=False)

            for color, (starts, ends) in segments.items():
                self._plot_lines(np.array(starts), np.array(ends),
                                 width=2, color=color)
    192 
    def plot_paths(self, paths):
        """
        Plots the ``paths`` as line segments. Nothing is plotted if the
        paths yield no segment.

        Input
        -----
        paths : :class:`~sionna.rt.Paths`
            Paths to plot
        """
        seg_starts, seg_ends = paths_to_segments(paths)
        if not (seg_starts and seg_ends):
            return
        self._plot_lines(np.vstack(seg_starts), np.vstack(seg_ends))
    205 
    206     def plot_scene(self):
    207         """
    208         Plots the meshes that make the scene.
    209         """
    210         shapes = self._scene.mi_scene.shapes()
    211         n = len(shapes)
    212         if n <= 0:
    213             return
    214 
    215         palette = None
    216         si = dr.zeros(mi.SurfaceInteraction3f)
    217         si.wi = mi.Vector3f(0, 0, 1)
    218 
    219         # Shapes (e.g. buildings)
    220         vertices, faces, albedos = [], [], []
    221         f_offset = 0
    222         for i, s in enumerate(shapes):
    223             null_transmission = s.bsdf().eval_null_transmission(si).numpy()
    224             if np.min(null_transmission) > 0.99:
    225                 # The BSDF for this shape was probably set to `null`, do not
    226                 # include it in the scene preview.
    227                 continue
    228 
    229             n_vertices = s.vertex_count()
    230             v = s.vertex_position(dr.arange(mi.UInt32, n_vertices))
    231             vertices.append(v.numpy())
    232             f = s.face_indices(dr.arange(mi.UInt32, s.face_count()))
    233             faces.append(f.numpy() + f_offset)
    234             f_offset += n_vertices
    235 
    236             albedo = s.bsdf().eval_diffuse_reflectance(si).numpy()
    237             if not np.any(albedo > 0.):
    238                 if palette is None:
    239                     palette = matplotlib.colormaps.get_cmap('Pastel1_r')
    240                 albedo[:] = palette((i % palette.N + 0.5) / palette.N)[:3]
    241 
    242             albedos.append(np.tile(albedo, (n_vertices, 1)))
    243 
    244         # Plot all objects as a single PyThreeJS mesh, which is must faster
    245         # than creating individual mesh objects in large scenes.
    246         self._plot_mesh(np.concatenate(vertices, axis=0),
    247                         np.concatenate(faces, axis=0),
    248                         persist=True, # The scene geometry is persistent
    249                         colors=np.concatenate(albedos, axis=0))
    250 
    def plot_coverage_map(self, coverage_map, tx=0, db_scale=True,
                          vmin=None, vmax=None, metric="path_gain"):
        """
        Plot the coverage map as a textured rectangle in the scene. Regions
        where the coverage map is zero-valued are made transparent.

        Input
        -----
        coverage_map : object
            Coverage map to plot. Its ``to_world()`` transform positions the
            rectangle, and its attribute named ``metric`` provides the
            values.

        tx : int | None
            Index of the entry to select along the first axis of the metric.
            If `None`, the maximum over the first axis is plotted.
            Defaults to 0.

        db_scale : bool
            Forwarded to ``coverage_map_color_mapping()``.
            Defaults to `True`.

        vmin, vmax : float | None
            Forwarded to ``coverage_map_color_mapping()``.
            Defaults to `None`.

        metric : str
            Name of the attribute of ``coverage_map`` to plot.
            Defaults to "path_gain".
        """
        to_world = coverage_map.to_world()
        per_tx = getattr(coverage_map, metric).numpy()
        values = np.max(per_tx, axis=0) if tx is None else per_tx[tx]

        # Corners of the rectangle, which is made of two triangles
        corners = np.array([
            to_world.transform_affine(corner)
            for corner in ([-1, -1, 0], [1, -1, 0], [-1, 1, 0], [1, 1, 0])
        ])
        pmin = np.min(corners, axis=0)
        pmax = np.max(corners, axis=0)

        triangles = np.array([[0, 1, 2], [2, 1, 3]], dtype=np.uint32)
        uvs = np.array([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=np.float32)

        geo = p3s.BufferGeometry(
            attributes={
                'position': p3s.BufferAttribute(corners, normalized=False),
                'index': p3s.BufferAttribute(triangles.ravel(),
                                             normalized=False),
                'uv': p3s.BufferAttribute(uvs, normalized=False),
            }
        )

        to_map, normalizer, color_map = coverage_map_color_mapping(
            values, db_scale=db_scale, vmin=vmin, vmax=vmax)
        rgba = color_map(normalizer(to_map)).astype(np.float32)
        # Cells without coverage are fully transparent
        rgba[:, :, 3] = (values > 0.).astype(np.float32)
        # Pre-multiply alpha
        rgba[:, :, :3] *= rgba[:, :, 3, None]

        texture = p3s.DataTexture(
            data=(255. * rgba).astype(np.uint8),
            format='RGBAFormat',
            type='UnsignedByteType',
            magFilter='NearestFilter',
            minFilter='NearestFilter',
        )

        mat = p3s.MeshLambertMaterial(
            side='DoubleSide',
            map=texture, transparent=True,
        )

        # The coverage map is not persistent
        self._add_child(p3s.Mesh(geo, mat), pmin, pmax, persist=False)
    316 
    def plot_ris(self):
        """
        Plot every RIS of the scene as a monochromatic, double-sided
        rectangle. The rectangles are not persistent.
        """
        # Both triangles of the rectangle, shared by all RIS
        triangles = np.array([[0, 1, 2], [2, 1, 3]], dtype=np.uint32)
        unit_corners = ([-1, -1, 0], [1, -1, 0], [-1, 1, 0], [1, 1, 0])

        for ris in list(self._scene.ris.values()):
            to_world = mitsuba_rectangle_to_world(
                ris.position, ris.orientation, ris.size, ris=True)
            rgb = ris.color.numpy()

            # Corners of the rectangle in world coordinates
            corners = np.array(
                [to_world.transform_affine(c) for c in unit_corners])
            pmin = np.min(corners, axis=0)
            pmax = np.max(corners, axis=0)

            geo = p3s.BufferGeometry(
                attributes={
                    'position': p3s.BufferAttribute(corners,
                                                    normalized=False),
                    'index': p3s.BufferAttribute(triangles.ravel(),
                                                 normalized=False),
                }
            )

            # Color components are scaled by 255 for the CSS-style string
            css_color = f'rgb({", ".join([str(int(v*255)) for v in rgb])})'
            mat = p3s.MeshLambertMaterial(color=css_color, side='DoubleSide')

            self._add_child(p3s.Mesh(geo, mat), pmin, pmax, persist=False)
    359 
    360     def set_clipping_plane(self, offset, orientation):
    361         """
    362         Input
    363         -----
    364         clip_at : float
    365             If not `None`, the scene preview will be clipped (cut) by a plane
    366             with normal orientation ``clip_plane_orientation`` and offset
    367             ``clip_at``. This allows visualizing the interior of meshes such
    368             as buildings.
    369 
    370         clip_plane_orientation : tuple[float, float, float]
    371             Normal vector of the clipping plane.
    372         """
    373         if offset is None:
    374             self._renderer.localClippingEnabled = False
    375             self._renderer.clippingPlanes = []
    376         else:
    377             self._renderer.localClippingEnabled = True
    378             self._renderer.clippingPlanes = [p3s.Plane(orientation, offset)]
    379 
    @property
    def camera(self):
        """
        pythreejs.PerspectiveCamera : Get the camera
        """
        return self._camera
    386 
    @property
    def orbit(self):
        """
        pythreejs.OrbitControls : Get the orbit controls of the camera
        """
        return self._orbit
    393 
    394     def resolution(self):
    395         """
    396         Returns a tuple (width, height) with the rendering resolution.
    397         """
    398         return (self._renderer.width, self._renderer.height)
    399 
    400     ##################################################
    401     # Internal methods
    402     ##################################################
    403 
    404     def _plot_mesh(self, vertices, faces, persist, colors=None):
    405         """
    406         Plots a mesh.
    407 
    408         Input
    409         ------
    410         vertices : [n,3], float
    411             Position of the vertices
    412 
    413         faces : [n,3], int
    414             Indices of the triangles associated with ``vertices``
    415 
    416         persist : bool
    417             Flag indicating if the mesh is persistent, i.e., should not be
    418             erased when ``reset()`` is called.
    419 
    420         colors : [n,3] | [3] | None
    421             Colors of the vertices. If `None`, black is used.
    422             Defaults to `None`.
    423         """
    424         assert vertices.ndim == 2 and vertices.shape[1] == 3
    425         assert faces.ndim == 2 and faces.shape[1] == 3
    426         n_v = vertices.shape[0]
    427         pmin, pmax = np.min(vertices, axis=0), np.max(vertices, axis=0)
    428 
    429         # Assuming per-vertex colors
    430         if colors is None:
    431             # Black is default
    432             colors = np.zeros((n_v, 3), dtype=np.float32)
    433         elif colors.ndim == 1:
    434             colors = np.tile(colors[None, :], (n_v, 1))
    435         colors = colors.astype(np.float32)
    436         assert ( (colors.ndim == 2)
    437              and (colors.shape[1] == 3)
    438              and (colors.shape[0] == n_v) )
    439 
    440         # Closer match to Mitsuba and Blender
    441         colors = np.power(colors, 1/1.8)
    442 
    443         geo = p3s.BufferGeometry(
    444             attributes={
    445                 'index': p3s.BufferAttribute(faces.ravel(), normalized=False),
    446                 'position': p3s.BufferAttribute(vertices, normalized=False),
    447                 'color': p3s.BufferAttribute(colors, normalized=False)
    448             }
    449         )
    450 
    451         mat = p3s.MeshStandardMaterial(
    452             side='DoubleSide', metalness=0., roughness=1.0,
    453             vertexColors='VertexColors', flatShading=True,
    454         )
    455         mesh = p3s.Mesh(geo, mat)
    456         self._add_child(mesh, pmin, pmax, persist=persist)
    457 
    458     def _plot_points(self, points, persist, colors=None, radius=0.05):
    459         """
    460         Plots a set of `n` points.
    461 
    462         Input
    463         -------
    464         points : [n, 3], float
    465             Coordinates of the `n` points.
    466 
    467         persist : bool
    468             Indicates if the points are persistent, i.e., should not be erased
    469             when ``reset()`` is called.
    470 
    471         colors : [n, 3], float | [3], float | None
    472             Colors of the points.
    473 
    474         radius : float
    475             Radius of the points.
    476         """
    477         assert points.ndim == 2 and points.shape[1] == 3
    478         n = points.shape[0]
    479         pmin, pmax = np.min(points, axis=0), np.max(points, axis=0)
    480 
    481         # Assuming per-vertex colors
    482         if colors is None:
    483             colors = np.zeros((n, 3), dtype=np.float32)
    484         elif colors.ndim == 1:
    485             colors = np.tile(colors[None, :], (n, 1))
    486         colors = colors.astype(np.float32)
    487         assert ( (colors.ndim == 2)
    488              and (colors.shape[1] == 3)
    489              and (colors.shape[0] == n) )
    490 
    491         tex = p3s.DataTexture(data=self._get_disk_sprite(), format="RGBAFormat",
    492                               type="FloatType")
    493 
    494         points = points.astype(np.float32)
    495         geo = p3s.BufferGeometry(attributes={
    496             'position': p3s.BufferAttribute(points, normalized=False),
    497             'color': p3s.BufferAttribute(colors, normalized=False),
    498         })
    499         mat = p3s.PointsMaterial(
    500             size=2*radius, sizeAttenuation=True, vertexColors='VertexColors',
    501             map=tex, alphaTest=0.5, transparent=True,
    502         )
    503         mesh = p3s.Points(geo, mat)
    504         self._add_child(mesh, pmin, pmax, persist=persist)
    505 
    506     def _add_child(self, obj, pmin, pmax, persist):
    507         """
    508         Adds an object for display
    509 
    510         Input
    511         ------
    512         obj : :class:`~pythreejs.Mesh`
    513             Mesh to display
    514 
    515         pmin : [3], float
    516             Lowest position for the bounding box
    517 
    518         pmax : [3], float
    519             Highest position for the bounding box
    520 
    521         persist : bool
    522             Flag that indicates if the object is persistent, i.e., if it should
    523             be removed from the display when `reset()` is called.
    524         """
    525         self._objects.append((obj, persist))
    526         self._p3s_scene.add(obj)
    527 
    528         self._bbox.expand(pmin)
    529         self._bbox.expand(pmax)
    530 
    531     def _plot_lines(self, starts, ends, width=0.5, color='black'):
    532         """
    533         Plots a set of `n` lines. This is used to plot the paths.
    534 
    535         Input
    536         ------
    537         starts : [n, 3], float
    538             Coordinates of the lines starting points
    539 
    540         ends : [n, 3], float
    541             Coordinates of the lines ending points
    542 
    543         width : float
    544             Width of the lines.
    545             Defaults to 0.5.
    546 
    547         color : str
    548             Color of the lines.
    549             Defaults to 'black'.
    550         """
    551 
    552         assert starts.ndim == 2 and starts.shape[1] == 3
    553         assert ends.ndim == 2 and ends.shape[1] == 3
    554         assert starts.shape[0] == ends.shape[0]
    555 
    556         segments = np.hstack((starts, ends)).astype(np.float32).reshape(-1,2,3)
    557         pmin = np.min(segments, axis=(0, 1))
    558         pmax = np.max(segments, axis=(0, 1))
    559 
    560         geo = p3s.LineSegmentsGeometry(positions=segments)
    561         mat = p3s.LineMaterial(linewidth=width, color=color)
    562         mesh = p3s.LineSegments2(geo, mat)
    563 
    564         # Lines are not flagged as persistent as they correspond to paths, which
    565         # can changes from one display to the next.
    566         self._add_child(mesh, pmin, pmax, persist=False)
    567 
    568     def _get_disk_sprite(self):
    569         """
    570         Returns the sprite used to represent transmitters and receivers though
    571         ``_plot_points()``.
    572 
    573         Output
    574         ------
    575         : [n,n,4], float
    576             Sprite
    577         """
    578         if self._disk_sprite is not None:
    579             return self._disk_sprite
    580 
    581         n = 128
    582         sprite = np.ones((n, n, 4))
    583         sprite[:, :, 3] = 0.
    584         # Draw a disk with an empty circle close to the edge
    585         ij = np.mgrid[:n, :n]
    586         ij = ij.reshape(2, -1)
    587 
    588         p = (ij + 0.5) / n - 0.5
    589         t = np.linalg.norm(p, axis=0).reshape((n, n))
    590         inside = t < 0.48
    591         in_band = (t < 0.45) & (t > 0.42)
    592         sprite[inside & (~in_band), 3] = 1.0
    593 
    594         sprite = sprite.astype(np.float32)
    595         self._disk_sprite = sprite
    596         return sprite
    597 
    598     ################################################
    599     # The following methods are required for
    600     # integration in Jupyter notebooks
    601     ################################################
    602 
    603     # pylint: disable=unused-argument
    604     def _repr_mimebundle_(self, **kwargs):
    605         # pylint: disable=protected-access,not-callable
    606         bundle = self._renderer._repr_mimebundle_()
    607         assert 'text/html' not in bundle
    608         bundle['text/html'] = self._repr_html_()
    609         return bundle
    610 
    def _repr_html_(self):
        """
        Returns the HTML snippet that embeds the renderer, for standalone
        display, i.e. outside of an interactive Jupyter notebook environment.
        """
        return embed_snippet(self._renderer, requirejs=True)