previewer.py (20857B)
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""
3D scene and paths viewer
"""

import drjit as dr
import mitsuba as mi
import numpy as np
from ipywidgets.embed import embed_snippet
import pythreejs as p3s
import matplotlib

from .utils import paths_to_segments, scene_scale, rotate,\
    mitsuba_rectangle_to_world
from .renderer import coverage_map_color_mapping


class InteractiveDisplay:
    """
    Lightweight wrapper around the `pythreejs` library.

    Input
    -----
    scene : object
        Scene to display. Its ``mi_scene`` attribute provides the shapes
        drawn by :meth:`plot_scene`; its ``transmitters``, ``receivers`` and
        ``ris`` mappings are read by the radio device plotting methods.

    resolution : [2], int
        Size of the viewer figure.

    fov : float
        Field of view, in degrees.

    background : str
        Background color in hex format prefixed by '#'.
    """

    def __init__(self, scene, resolution, fov, background):

        self._scene = scene
        # Lazily built by `_get_disk_sprite()`
        self._disk_sprite = None

        # List of (object, persist) pairs currently in the scene
        self._objects = []
        # Bounding box of the scene
        self._bbox = mi.ScalarBoundingBox3f()

        ####################################################
        # Setup the viewer
        ####################################################

        # Lighting
        ambient_light = p3s.AmbientLight(intensity=0.80)
        camera_light = p3s.DirectionalLight(
            position=[0, 0, 0], intensity=0.25
        )

        # Camera & controls
        self._camera = p3s.PerspectiveCamera(
            fov=fov, aspect=resolution[0]/resolution[1],
            up=[0, 0, 1], far=10000,
            children=[camera_light],
        )
        self._orbit = p3s.OrbitControls(
            controlling = self._camera
        )

        # Scene & renderer
        self._p3s_scene = p3s.Scene(
            background=background, children=[self._camera, ambient_light]
        )
        self._renderer = p3s.Renderer(
            scene=self._p3s_scene, camera=self._camera, controls=[self._orbit],
            width=resolution[0], height=resolution[1], antialias=True
        )

        ####################################################
        # Plot the scene geometry
        ####################################################
        self.plot_scene()

        # Finally, ensure the camera is looking at the scene
        self.center_view()

    def reset(self):
        """
        Removes objects that are not flagged as persistent, i.e., the paths.
        """
        self._remove_objects(remove_persistent=False)

    def redraw_scene_geometry(self):
        """
        Redraw the scene geometry.

        Objects flagged as persistent are removed (only the scene geometry is
        flagged as persistent), then the scene geometry is plotted again.
        """
        self._remove_objects(remove_persistent=True)

        # Plot the scene geometry
        self.plot_scene()

    def center_view(self):
        """
        Automatically place the camera based on the scene's bounding box and
        orient it toward the center of the scene.

        The camera is positioned at the minimum `x` of the bounding box, at
        the `y` coordinate of its center, and at 1.5 times its maximum `z`.
        If that position is (numerically) the origin, (-1, -1, 1) is used
        instead. If the bounding box is not valid, a bounding box reduced to
        the origin is used.
        """
        bbox = self._bbox if self._bbox.valid() else mi.ScalarBoundingBox3f(0.)
        center = bbox.center()

        corner = [bbox.min.x, center.y, 1.5 * bbox.max.z]
        if np.allclose(corner, 0):
            corner = (-1, -1, 1)
        self._camera.position = tuple(corner)

        self._camera.lookAt(center)
        self._orbit.exec_three_obj_method('update')
        self._camera.exec_three_obj_method('updateProjectionMatrix')

    def plot_radio_devices(self, show_orientations=False):
        """
        Plots the radio devices.

        Transmitters and receivers are shown as points. If
        ``show_orientations`` is set, an arrow is additionally drawn for
        every transmitter, receiver and RIS.

        Input
        -----
        show_orientations : bool
            Shows the radio devices' orientations.
            Defaults to `False`.
        """
        scene = self._scene
        sc, tx_positions, rx_positions, _, _ = scene_scale(scene)
        transmitter_colors = [transmitter.color.numpy() for
                              transmitter in scene.transmitters.values()]
        receiver_colors = [receiver.color.numpy() for
                           receiver in scene.receivers.values()]

        # Radio emitters, shown as points
        p = np.array(list(tx_positions.values()) +
                     list(rx_positions.values())
                     )
        albedo = np.array(transmitter_colors +
                          receiver_colors)

        if p.shape[0] > 0:
            # Radio devices are not persistent
            radius = max(0.005 * sc, 1)
            self._plot_points(p, persist=False, colors=albedo,
                              radius=radius)

        if show_orientations:
            line_length = 0.0075 * sc
            head_length = 0.15 * line_length
            zeros = np.zeros((1, 3))

            for devices in [scene.transmitters.values(),
                            scene.receivers.values(),
                            scene.ris.values()]:
                if len(devices) == 0:
                    continue
                starts, ends = [], []
                for rd in devices:
                    # Arrow line. The color components are scaled to 0-255
                    # exactly as done for the RIS rectangles in `plot_ris()`.
                    color = self._css_rgb(rd.color)
                    starts.append(rd.position)
                    endpoint = rd.position + rotate([line_length, 0., 0.],
                                                    rd.orientation)
                    ends.append(endpoint)

                    # Arrow head
                    geo = p3s.CylinderGeometry(
                        radiusTop=0, radiusBottom=0.3 * head_length,
                        height=head_length, radialSegments=8,
                        heightSegments=0, openEnded=False)
                    mat = p3s.MeshLambertMaterial(color=color)
                    mesh = p3s.Mesh(geo, mat)
                    mesh.position = tuple(endpoint)
                    angles = rd.orientation.numpy()
                    mesh.rotateZ(angles[0] - np.pi/2)
                    mesh.rotateY(angles[2])
                    mesh.rotateX(-angles[1])
                    self._add_child(mesh, zeros, zeros, persist=False)

                # All the lines of a group of devices share a single
                # material, which uses the color of the last device.
                self._plot_lines(np.array(starts), np.array(ends),
                                 width=2, color=color)

    def plot_paths(self, paths):
        """
        Plot the ``paths``.

        Input
        -----
        paths : :class:`~sionna.rt.Paths`
            Paths to plot
        """
        starts, ends = paths_to_segments(paths)
        if starts and ends:
            self._plot_lines(np.vstack(starts), np.vstack(ends))

    def plot_scene(self):
        """
        Plots the meshes that make the scene.

        Shapes for which the null transmission of the BSDF is above 0.99 are
        not displayed. Nothing is plotted if no shape remains.
        """
        shapes = self._scene.mi_scene.shapes()
        if len(shapes) <= 0:
            return

        palette = None
        si = dr.zeros(mi.SurfaceInteraction3f)
        si.wi = mi.Vector3f(0, 0, 1)

        # Shapes (e.g. buildings)
        vertices, faces, albedos = [], [], []
        f_offset = 0
        for i, s in enumerate(shapes):
            null_transmission = s.bsdf().eval_null_transmission(si).numpy()
            if np.min(null_transmission) > 0.99:
                # The BSDF for this shape was probably set to `null`, do not
                # include it in the scene preview.
                continue

            n_vertices = s.vertex_count()
            v = s.vertex_position(dr.arange(mi.UInt32, n_vertices))
            vertices.append(v.numpy())
            f = s.face_indices(dr.arange(mi.UInt32, s.face_count()))
            # Faces index into the concatenated vertex buffer
            faces.append(f.numpy() + f_offset)
            f_offset += n_vertices

            albedo = s.bsdf().eval_diffuse_reflectance(si).numpy()
            if not np.any(albedo > 0.):
                # No usable albedo: pick a color from a palette instead
                if palette is None:
                    palette = matplotlib.colormaps.get_cmap('Pastel1_r')
                albedo[:] = palette((i % palette.N + 0.5) / palette.N)[:3]

            albedos.append(np.tile(albedo, (n_vertices, 1)))

        if not vertices:
            # Every shape was skipped. `np.concatenate()` raises a
            # `ValueError` on an empty list, so there is nothing to plot.
            return

        # Plot all objects as a single PyThreeJS mesh, which is much faster
        # than creating individual mesh objects in large scenes.
        self._plot_mesh(np.concatenate(vertices, axis=0),
                        np.concatenate(faces, axis=0),
                        persist=True, # The scene geometry is persistent
                        colors=np.concatenate(albedos, axis=0))

    def plot_coverage_map(self, coverage_map, tx=0, db_scale=True,
                          vmin=None, vmax=None, metric="path_gain"):
        """
        Plot the coverage map as a textured rectangle in the scene. Regions
        where the coverage map is zero-valued are made transparent.

        Input
        -----
        coverage_map : object
            Coverage map to plot. Must provide a ``to_world()`` method and an
            attribute named ``metric`` holding the values to display.

        tx : int | None
            Index, along the first axis of the metric, of the entry to
            display. If `None`, the maximum over the first axis is displayed.
            Defaults to 0.

        db_scale : bool
            Forwarded to ``coverage_map_color_mapping()``.
            Defaults to `True`.

        vmin, vmax : float | None
            Forwarded to ``coverage_map_color_mapping()``.
            Defaults to `None`.

        metric : str
            Name of the attribute of ``coverage_map`` to display.
            Defaults to "path_gain".
        """
        to_world = coverage_map.to_world()
        cm = getattr(coverage_map, metric).numpy()
        if tx is None:
            cm_values = np.max(cm, axis=0)
        else:
            cm_values = cm[tx]

        # Create a rectangle from two triangles
        vertices, faces, pmin, pmax = self._rectangle_from_transform(to_world)

        vertex_uvs = np.array([
            [0, 0], [1, 0], [0, 1], [1, 1]
        ], dtype=np.float32)

        geo = p3s.BufferGeometry(
            attributes={
                'position': p3s.BufferAttribute(vertices, normalized=False),
                'index': p3s.BufferAttribute(faces.ravel(), normalized=False),
                'uv': p3s.BufferAttribute(vertex_uvs, normalized=False),
            }
        )

        to_map, normalizer, color_map = coverage_map_color_mapping(
            cm_values, db_scale=db_scale, vmin=vmin, vmax=vmax)
        texture = color_map(normalizer(to_map)).astype(np.float32)
        # Zero-valued cells are fully transparent
        texture[:, :, 3] = (cm_values > 0.).astype(np.float32)
        # Pre-multiply alpha
        texture[:, :, :3] *= texture[:, :, 3, None]

        texture = p3s.DataTexture(
            data=(255. * texture).astype(np.uint8),
            format='RGBAFormat',
            type='UnsignedByteType',
            magFilter='NearestFilter',
            minFilter='NearestFilter',
        )

        mat = p3s.MeshLambertMaterial(
            side='DoubleSide',
            map=texture, transparent=True,
        )
        mesh = p3s.Mesh(geo, mat)

        self._add_child(mesh, pmin, pmax, persist=False)

    def plot_ris(self):
        """
        Plot all RIS as a monochromatic rectangle in the scene
        """
        all_ris = list(self._scene.ris.values())

        for ris in all_ris:
            orientation = ris.orientation
            to_world =\
                mitsuba_rectangle_to_world(ris.position, orientation, ris.size,
                                           ris=True)

            # Create a rectangle from two triangles
            vertices, faces, pmin, pmax =\
                self._rectangle_from_transform(to_world)

            geo = p3s.BufferGeometry(
                attributes={
                    'position': p3s.BufferAttribute(vertices,
                                                    normalized=False),
                    'index': p3s.BufferAttribute(faces.ravel(),
                                                 normalized=False),
                }
            )

            color = self._css_rgb(ris.color.numpy())
            mat = p3s.MeshLambertMaterial(color=color, side='DoubleSide')
            mesh = p3s.Mesh(geo, mat)

            self._add_child(mesh, pmin, pmax, persist=False)

    def set_clipping_plane(self, offset, orientation):
        """
        Enables or disables the clipping (cutting) of the scene preview by a
        plane. This allows visualizing the interior of meshes such as
        buildings.

        Input
        -----
        offset : float | None
            Offset of the clipping plane. If `None`, clipping is disabled and
            ``orientation`` is ignored.

        orientation : tuple[float, float, float]
            Normal vector of the clipping plane.
        """
        if offset is None:
            self._renderer.localClippingEnabled = False
            self._renderer.clippingPlanes = []
        else:
            self._renderer.localClippingEnabled = True
            self._renderer.clippingPlanes = [p3s.Plane(orientation, offset)]

    @property
    def camera(self):
        """
        pythreejs.PerspectiveCamera : Get the camera
        """
        return self._camera

    @property
    def orbit(self):
        """
        pythreejs.OrbitControls : Get the orbit
        """
        return self._orbit

    def resolution(self):
        """
        Returns a tuple (width, height) with the rendering resolution.
        """
        return (self._renderer.width, self._renderer.height)

    ##################################################
    # Internal methods
    ##################################################

    @staticmethod
    def _css_rgb(color):
        """
        Formats a color as an 'rgb(r, g, b)' string.

        Input
        ------
        color : [3], float
            Color components. They are multiplied by 255 and truncated to
            integers, i.e., they are assumed to lie in [0, 1].

        Output
        ------
        : str
            Color string of the form 'rgb(r, g, b)'
        """
        channels = ", ".join(str(int(v * 255)) for v in color)
        return f"rgb({channels})"

    @staticmethod
    def _rectangle_from_transform(to_world):
        """
        Builds a rectangle made of two triangles by transforming the corners
        (-1, -1, 0), (1, -1, 0), (-1, 1, 0) and (1, 1, 0) with ``to_world``.

        Input
        ------
        to_world : object
            Transform providing a ``transform_affine()`` method

        Output
        ------
        vertices : [4, 3], float
            Transformed corners

        faces : [2, 3], uint32
            Indices of the two triangles

        pmin : [3], float
            Component-wise minimum of the vertices

        pmax : [3], float
            Component-wise maximum of the vertices
        """
        corners = ([-1, -1, 0], [1, -1, 0], [-1, 1, 0], [1, 1, 0])
        vertices = np.array([to_world.transform_affine(c) for c in corners])
        faces = np.array([
            [0, 1, 2],
            [2, 1, 3],
        ], dtype=np.uint32)
        pmin = np.min(vertices, axis=0)
        pmax = np.max(vertices, axis=0)
        return vertices, faces, pmin, pmax

    @staticmethod
    def _per_vertex_colors(colors, n):
        """
        Expands ``colors`` to one `float32` RGB color per vertex.

        Input
        ------
        colors : [n,3] | [3] | None
            Colors. If `None`, black is used. A single color is repeated for
            all the vertices.

        n : int
            Number of vertices

        Output
        ------
        : [n,3], float32
            Per-vertex colors
        """
        if colors is None:
            # Black is default
            colors = np.zeros((n, 3), dtype=np.float32)
        elif colors.ndim == 1:
            colors = np.tile(colors[None, :], (n, 1))
        colors = colors.astype(np.float32)
        assert ( (colors.ndim == 2)
                and (colors.shape[1] == 3)
                and (colors.shape[0] == n) )
        return colors

    def _remove_objects(self, remove_persistent):
        """
        Removes from the display the objects whose persistence flag matches
        ``remove_persistent``, and keeps the other ones.

        Input
        ------
        remove_persistent : bool
            If `True`, the persistent objects are removed. Otherwise, the
            objects that are not persistent are removed.
        """
        remaining = []
        for obj, persist in self._objects:
            if bool(persist) == remove_persistent:
                self._p3s_scene.remove(obj)
            else:
                remaining.append((obj, persist))
        self._objects = remaining

    def _plot_mesh(self, vertices, faces, persist, colors=None):
        """
        Plots a mesh.

        Input
        ------
        vertices : [n,3], float
            Position of the vertices

        faces : [n,3], int
            Indices of the triangles associated with ``vertices``

        persist : bool
            Flag indicating if the mesh is persistent, i.e., should not be
            erased when ``reset()`` is called.

        colors : [n,3] | [3] | None
            Colors of the vertices. If `None`, black is used.
            Defaults to `None`.
        """
        assert vertices.ndim == 2 and vertices.shape[1] == 3
        assert faces.ndim == 2 and faces.shape[1] == 3
        n_v = vertices.shape[0]
        pmin, pmax = np.min(vertices, axis=0), np.max(vertices, axis=0)

        # Assuming per-vertex colors
        colors = self._per_vertex_colors(colors, n_v)

        # Closer match to Mitsuba and Blender
        colors = np.power(colors, 1/1.8)

        geo = p3s.BufferGeometry(
            attributes={
                'index': p3s.BufferAttribute(faces.ravel(), normalized=False),
                'position': p3s.BufferAttribute(vertices, normalized=False),
                'color': p3s.BufferAttribute(colors, normalized=False)
            }
        )

        mat = p3s.MeshStandardMaterial(
            side='DoubleSide', metalness=0., roughness=1.0,
            vertexColors='VertexColors', flatShading=True,
        )
        mesh = p3s.Mesh(geo, mat)
        self._add_child(mesh, pmin, pmax, persist=persist)

    def _plot_points(self, points, persist, colors=None, radius=0.05):
        """
        Plots a set of `n` points.

        Input
        -------
        points : [n, 3], float
            Coordinates of the `n` points.

        persist : bool
            Indicates if the points are persistent, i.e., should not be erased
            when ``reset()`` is called.

        colors : [n, 3], float | [3], float | None
            Colors of the points. If `None`, black is used.
            Defaults to `None`.

        radius : float
            Radius of the points.
            Defaults to 0.05.
        """
        assert points.ndim == 2 and points.shape[1] == 3
        n = points.shape[0]
        pmin, pmax = np.min(points, axis=0), np.max(points, axis=0)

        # Assuming per-vertex colors
        colors = self._per_vertex_colors(colors, n)

        tex = p3s.DataTexture(data=self._get_disk_sprite(), format="RGBAFormat",
                              type="FloatType")

        points = points.astype(np.float32)
        geo = p3s.BufferGeometry(attributes={
            'position': p3s.BufferAttribute(points, normalized=False),
            'color': p3s.BufferAttribute(colors, normalized=False),
        })
        mat = p3s.PointsMaterial(
            size=2*radius, sizeAttenuation=True, vertexColors='VertexColors',
            map=tex, alphaTest=0.5, transparent=True,
        )
        mesh = p3s.Points(geo, mat)
        self._add_child(mesh, pmin, pmax, persist=persist)

    def _add_child(self, obj, pmin, pmax, persist):
        """
        Adds an object for display and expands the scene bounding box
        accordingly.

        Input
        ------
        obj : :class:`~pythreejs.Mesh`
            Mesh to display

        pmin : [3], float
            Lowest position for the bounding box

        pmax : [3], float
            Highest position for the bounding box

        persist : bool
            Flag that indicates if the object is persistent, i.e., if it
            should be kept on the display when `reset()` is called.
        """
        self._objects.append((obj, persist))
        self._p3s_scene.add(obj)

        self._bbox.expand(pmin)
        self._bbox.expand(pmax)

    def _plot_lines(self, starts, ends, width=0.5, color='black'):
        """
        Plots a set of `n` lines. This is used to plot the paths.

        Input
        ------
        starts : [n, 3], float
            Coordinates of the lines starting points

        ends : [n, 3], float
            Coordinates of the lines ending points

        width : float
            Width of the lines.
            Defaults to 0.5.

        color : str
            Color of the lines.
            Defaults to 'black'.
        """

        assert starts.ndim == 2 and starts.shape[1] == 3
        assert ends.ndim == 2 and ends.shape[1] == 3
        assert starts.shape[0] == ends.shape[0]

        # One (start, end) pair of 3D points per segment
        segments = np.hstack((starts, ends)).astype(np.float32).reshape(-1,2,3)
        pmin = np.min(segments, axis=(0, 1))
        pmax = np.max(segments, axis=(0, 1))

        geo = p3s.LineSegmentsGeometry(positions=segments)
        mat = p3s.LineMaterial(linewidth=width, color=color)
        mesh = p3s.LineSegments2(geo, mat)

        # Lines are not flagged as persistent as they correspond to paths,
        # which can change from one display to the next.
        self._add_child(mesh, pmin, pmax, persist=False)

    def _get_disk_sprite(self):
        """
        Returns the sprite used to represent transmitters and receivers
        through ``_plot_points()``. The sprite is built on the first call and
        cached for the following ones.

        Output
        ------
        : [n,n,4], float
            Sprite
        """
        if self._disk_sprite is not None:
            return self._disk_sprite

        n = 128
        # White and fully transparent by default
        sprite = np.ones((n, n, 4))
        sprite[:, :, 3] = 0.
        # Draw a disk with an empty circle close to the edge
        ij = np.mgrid[:n, :n]
        ij = ij.reshape(2, -1)

        # Pixel centers, in [-0.5, 0.5], and their distance to the center
        p = (ij + 0.5) / n - 0.5
        t = np.linalg.norm(p, axis=0).reshape((n, n))
        inside = t < 0.48
        in_band = (t < 0.45) & (t > 0.42)
        sprite[inside & (~in_band), 3] = 1.0

        sprite = sprite.astype(np.float32)
        self._disk_sprite = sprite
        return sprite

    ################################################
    # The following methods are required for
    # integration in Jupyter notebooks
    ################################################

    # pylint: disable=unused-argument
    def _repr_mimebundle_(self, **kwargs):
        """
        Returns the MIME bundle of the renderer, extended with a standalone
        'text/html' representation.
        """
        # pylint: disable=protected-access,not-callable
        bundle = self._renderer._repr_mimebundle_()
        assert 'text/html' not in bundle
        bundle['text/html'] = self._repr_html_()
        return bundle

    def _repr_html_(self):
        """
        Standalone HTML display, i.e. outside of an interactive Jupyter
        notebook environment.
        """

        html = embed_snippet(self._renderer, requirejs=True)
        return html