diff --git a/wisp/accelstructs/octree_as.py b/wisp/accelstructs/octree_as.py index fe310b8..d6051db 100644 --- a/wisp/accelstructs/octree_as.py +++ b/wisp/accelstructs/octree_as.py @@ -255,7 +255,7 @@ def raymarch(self, rays, num_samples, level=None, raymarch_type='voxel'): 'ray' - samples num_samples along each ray, and then filters out samples which falls outside of occupied cells. In this scheme, num_hit_samples <= num_rays * num_samples - + Returns: (torch.LongTensor, torch.LongTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.BoolTensor): @@ -271,16 +271,16 @@ def raymarch(self, rays, num_samples, level=None, raymarch_type='voxel'): # Samples points along the rays by first tracing it against the SPC object. # Then, given each SPC voxel hit, will sample some number of samples in each voxel. - # This setting is pretty nice for getting decent outputs from outside-looking-in scenes, + # This setting is pretty nice for getting decent outputs from outside-looking-in scenes, # but in general it's not very robust or proper since the ray samples will be weirdly distributed - # and or aliased. + # and or aliased. if raymarch_type == 'voxel': ridx, samples, depth_samples, deltas, boundary = self._raymarch_voxel(rays=rays, level=level, num_samples=num_samples) # Samples points along the rays, and then uses the SPC object the filter out samples that don't hit - # the SPC objects. This is a much more well-spaced-out sampling scheme and will work well for + # the SPC objects. This is a much more well-spaced-out sampling scheme and will work well for # inside-looking-out scenes. The camera near and far planes will have to be adjusted carefully, however. elif raymarch_type == 'ray': ridx, samples, depth_samples, deltas, boundary = self._raymarch_ray(rays=rays, diff --git a/wisp/gfx/datalayers/__init__.py b/wisp/gfx/datalayers/__init__.py index 3e042c5..660c6ad 100644 --- a/wisp/gfx/datalayers/__init__.py +++ b/wisp/gfx/datalayers/__init__.py @@ -9,3 +9,4 @@ from .datalayers import Datalayers from .camera_datalayers import CameraDatalayers from .octree_datalayers import OctreeDatalayers +from .aabb_datalayers import AABBDatalayers diff --git a/wisp/gfx/datalayers/aabb_datalayers.py b/wisp/gfx/datalayers/aabb_datalayers.py new file mode 100644 index 0000000..32a63b3 --- /dev/null +++ b/wisp/gfx/datalayers/aabb_datalayers.py @@ -0,0 +1,45 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +from typing import Dict +import kaolin.ops.spc as spc_ops +from wisp.core import PrimitivesPack +from wisp.accelstructs import AxisAlignedBBoxAS +from wisp.gfx.datalayers import Datalayers +from wisp.core.colors import soft_blue, soft_red, lime_green, purple, gold + + +class AABBDatalayers(Datalayers): + + def __init__(self): + self._last_state = dict() + + def needs_redraw(self, blas: AxisAlignedBBoxAS) -> True: + return True + + def regenerate_data_layers(self, blas: AxisAlignedBBoxAS) -> Dict[str, PrimitivesPack]: + data_layers = dict() + color_tensor = torch.tensor((*soft_blue, 1.0)) + + cells = PrimitivesPack() + lod = 0 + level_points = spc_ops.unbatched_get_level_points(blas.points, blas.pyramid, 0) + corners = spc_ops.points_to_corners(level_points) / (2 ** lod) + corners = corners * 2.0 - 1.0 + grid_lines = corners[:, [(0, 1), (1, 3), (3, 2), (2, 0), + (4, 5), (5, 7), (7, 6), (6, 4), + (0, 4), (1, 5), (2, 6), (3, 7)]] + + grid_lines_start = grid_lines[:, :, 0].reshape(-1, 3) + grid_lines_end = grid_lines[:, :, 1].reshape(-1, 3) + grid_lines_color = color_tensor.repeat(grid_lines_start.shape[0], 1) + cells.add_lines(grid_lines_start, grid_lines_end, grid_lines_color) + + data_layers[f'AABB'] = cells + return data_layers diff --git a/wisp/gfx/datalayers/octree_datalayers.py b/wisp/gfx/datalayers/octree_datalayers.py index 993aa30..8fe1574 100644 --- a/wisp/gfx/datalayers/octree_datalayers.py +++ b/wisp/gfx/datalayers/octree_datalayers.py @@ -10,7 +10,7 @@ from typing import Dict import kaolin.ops.spc as spc_ops from wisp.core import PrimitivesPack -from wisp.models.grids.octree_grid import OctreeGrid +from wisp.accelstructs import OctreeAS from wisp.gfx.datalayers import Datalayers from wisp.core.colors import soft_blue, soft_red, lime_green, purple, gold @@ -20,13 +20,13 @@ class OctreeDatalayers(Datalayers): def __init__(self): self._last_state = dict() - def needs_redraw(self, grid: OctreeGrid) -> True: + def needs_redraw(self, blas: OctreeAS) -> True: # Pyramids contain information about the number of cells per level, # it's a plausible heuristic to determine whether the frame should be redrawn return not ('pyramids' in self._last_state and - torch.equal(self._last_state['pyramids'], grid.blas.pyramid)) + torch.equal(self._last_state['pyramids'], blas.pyramid)) - def regenerate_data_layers(self, grid: OctreeGrid) -> Dict[str, PrimitivesPack]: + def regenerate_data_layers(self, blas: OctreeAS) -> Dict[str, PrimitivesPack]: data_layers = dict() lod_colors = [ torch.tensor((*soft_blue, 1.0)), @@ -36,11 +36,10 @@ def regenerate_data_layers(self, grid: OctreeGrid) -> Dict[str, PrimitivesPack]: torch.tensor((*gold, 1.0)), ] - for lod in range(grid.blas.max_level): + for lod in range(blas.max_level): cells = PrimitivesPack() - level_points = spc_ops.unbatched_get_level_points(grid.blas.points, - grid.blas.pyramid, lod) + level_points = spc_ops.unbatched_get_level_points(blas.points, blas.pyramid, lod) corners = spc_ops.points_to_corners(level_points) / (2 ** lod) @@ -58,5 +57,5 @@ def regenerate_data_layers(self, grid: OctreeGrid) -> Dict[str, PrimitivesPack]: data_layers[f'Octree LOD{lod}'] = cells - self._last_state['pyramids'] = grid.blas.pyramid + self._last_state['pyramids'] = blas.pyramid return data_layers diff --git a/wisp/models/grids/octree_grid.py b/wisp/models/grids/octree_grid.py index 36c625d..618525f 100644 --- a/wisp/models/grids/octree_grid.py +++ b/wisp/models/grids/octree_grid.py @@ -25,7 +25,7 @@ def __init__( accelstruct, feature_dim : int, base_lod : int, - num_lods : int = 1, + num_lods : int = 1, interpolation_type : str = 'linear', multiscale_type : str = 'cat', feature_std : float = 0.0, diff --git a/wisp/models/nefs/nerf.py b/wisp/models/nefs/nerf.py index 81671e5..4c7ee87 100644 --- a/wisp/models/nefs/nerf.py +++ b/wisp/models/nefs/nerf.py @@ -76,7 +76,7 @@ def init_embedder(self, embedder_type, frequencies=None): return embedder, embed_dim def init_decoders(self, activation_type, layer_type, num_layers, hidden_dim): - """Initializes the decoder object. + """Initializes the decoder object. """ decoder_density = BasicDecoder(input_dim=self.density_net_input_dim, output_dim=16, diff --git a/wisp/renderer/core/api/base_renderer.py b/wisp/renderer/core/api/base_renderer.py index e15d8a9..afd5c94 100644 --- a/wisp/renderer/core/api/base_renderer.py +++ b/wisp/renderer/core/api/base_renderer.py @@ -123,16 +123,10 @@ def create_layers_painter(cls, nef: BaseNeuralField) -> Optional[Datalayers]: return None def needs_redraw(self) -> bool: - if self.layers_painter is not None: - return self.layers_painter.needs_redraw(self.nef.grid) - else: - return True + return True def regenerate_data_layers(self) -> Dict[str, PrimitivesPack]: - if self.layers_painter is not None: - return self.layers_painter.regenerate_data_layers(self.nef.grid) - else: - return dict() + return dict() def pre_render(self, payload: FramePayload, *args, **kwargs) -> None: """ Prepare primary rays to render """ diff --git a/wisp/renderer/core/renderers/radiance_pipeline_renderer.py b/wisp/renderer/core/renderers/radiance_pipeline_renderer.py index ec25fb4..19bbd48 100644 --- a/wisp/renderer/core/renderers/radiance_pipeline_renderer.py +++ b/wisp/renderer/core/renderers/radiance_pipeline_renderer.py @@ -7,14 +7,14 @@ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. from __future__ import annotations -from typing import Optional +from typing import Optional, Dict import torch -from wisp.core import RenderBuffer, Rays +from wisp.core import RenderBuffer, Rays, PrimitivesPack from wisp.renderer.core.api import RayTracedRenderer, FramePayload, field_renderer from wisp.models.nefs.nerf import NeuralRadianceField, BaseNeuralField from wisp.tracers import PackedRFTracer -from wisp.accelstructs import OctreeAS -from wisp.gfx.datalayers import Datalayers, OctreeDatalayers +from wisp.accelstructs import OctreeAS, AxisAlignedBBoxAS +from wisp.gfx.datalayers import Datalayers, OctreeDatalayers, AABBDatalayers @field_renderer(BaseNeuralField, PackedRFTracer) @@ -50,11 +50,30 @@ def __init__(self, nef: NeuralRadianceField, tracer_type=None, batch_size=2**14, @classmethod def create_layers_painter(cls, nef: BaseNeuralField) -> Optional[Datalayers]: - if nef.grid.__class__.__name__ in ('OctreeGrid', 'CodebookOctreeGrid', 'HashGrid'): + """ NeuralRadianceFieldPackedRenderer can draw datalayers showing the occupancy status. + These depend on the bottom level acceleration structure. + """ + if not hasattr(nef.grid, 'blas'): + return None + elif isinstance(nef.grid.blas, AxisAlignedBBoxAS): + return AABBDatalayers() + elif isinstance(nef.grid.blas, OctreeAS): return OctreeDatalayers() else: return None + def needs_redraw(self) -> bool: + if self.layers_painter is not None: + return self.layers_painter.needs_redraw(self.nef.grid.blas) + else: + return True + + def regenerate_data_layers(self) -> Dict[str, PrimitivesPack]: + if self.layers_painter is not None: + return self.layers_painter.regenerate_data_layers(self.nef.grid.blas) + else: + return dict() + def pre_render(self, payload: FramePayload, *args, **kwargs) -> None: super().pre_render(payload) self.render_res_x = payload.render_res_x @@ -104,24 +123,20 @@ def aabb(self) -> torch.Tensor: # (center_x, center_y, center_z, width, height, depth) return torch.tensor((0.0, 0.0, 0.0, 2.0, 2.0, 2.0), device=self.device) - def acceleration_structure(self): - if isinstance(self.nef.grid.blas, OctreeAS): - return "Octree" - else: + def acceleration_structure(self) -> str: + """ Returns a human readable name of the bottom level acceleration structure used by this renderer """ + if getattr(self.nef, 'grid') is None or getattr(self.nef.grid, 'blas') is None: return "None" - - def features_structure(self): - grid_type = self.nef.grid.__class__.__name__ - if grid_type == "OctreeGrid": - return "Octree Grid" - elif grid_type == "CodebookOctreeGrid": - return "Codebook Grid" - elif grid_type == "TriplanarGrid": - return "Triplanar Grid" - elif grid_type == "HashGrid": - return "Hash Grid" + elif hasattr(self.nef.grid.blas, 'name'): + return self.nef.grid.blas.name() else: return "Unknown" - - + def features_structure(self) -> str: + """ Returns a human readable name of the feature structure used by this renderer """ + if getattr(self.nef, 'grid') is None: + return "None" + elif hasattr(self.nef.grid, 'name'): + return self.nef.grid.name() + else: + return "Unknown" diff --git a/wisp/renderer/core/renderers/sdf_pipeline_renderer.py b/wisp/renderer/core/renderers/sdf_pipeline_renderer.py index a3d491c..4193534 100644 --- a/wisp/renderer/core/renderers/sdf_pipeline_renderer.py +++ b/wisp/renderer/core/renderers/sdf_pipeline_renderer.py @@ -9,13 +9,12 @@ from __future__ import annotations from typing import Optional, Dict import torch -from wisp.core import RenderBuffer +from wisp.core import RenderBuffer, Rays, PrimitivesPack from wisp.renderer.core.api import RayTracedRenderer, FramePayload, field_renderer -from wisp.core import Rays from wisp.models.nefs.neural_sdf import NeuralSDF, BaseNeuralField from wisp.tracers import PackedSDFTracer -from wisp.accelstructs import OctreeAS -from wisp.gfx.datalayers import Datalayers, OctreeDatalayers +from wisp.accelstructs import OctreeAS, AxisAlignedBBoxAS +from wisp.gfx.datalayers import Datalayers, OctreeDatalayers, AABBDatalayers @field_renderer(BaseNeuralField, PackedSDFTracer) @@ -49,11 +48,30 @@ def __init__(self, nef: NeuralSDF, tracer_type=None, @classmethod def create_layers_painter(cls, nef: BaseNeuralField) -> Optional[Datalayers]: - if nef.grid.__class__.__name__ in ('OctreeGrid', 'CodebookOctreeGrid', 'HashGrid'): + """ NeuralSDFPackedRenderer can draw datalayers showing the occupancy status. + These depend on the bottom level acceleration structure. + """ + if not hasattr(nef.grid, 'blas'): + return None + elif isinstance(nef.grid.blas, AxisAlignedBBoxAS): + return AABBDatalayers() + elif isinstance(nef.grid.blas, OctreeAS): return OctreeDatalayers() else: return None + def needs_redraw(self) -> bool: + if self.layers_painter is not None: + return self.layers_painter.needs_redraw(self.nef.grid.blas) + else: + return True + + def regenerate_data_layers(self) -> Dict[str, PrimitivesPack]: + if self.layers_painter is not None: + return self.layers_painter.regenerate_data_layers(self.nef.grid.blas) + else: + return dict() + def pre_render(self, payload: FramePayload, *args, **kwargs) -> None: super().pre_render(payload) self.render_res_x = payload.render_res_x @@ -99,24 +117,20 @@ def aabb(self) -> torch.Tensor: # (center_x, center_y, center_z, width, height, depth) return torch.tensor((0.0, 0.0, 0.0, 2.0, 2.0, 2.0), device=self.device) - def acceleration_structure(self): - if isinstance(self.nef.grid.blas, OctreeAS): - return "Octree" - else: + def acceleration_structure(self) -> str: + """ Returns a human readable name of the bottom level acceleration structure used by this renderer """ + if getattr(self.nef, 'grid') is None or getattr(self.nef.grid, 'blas') is None: return "None" - - def features_structure(self): - grid_type = self.nef.grid.__class__.__name__ - if grid_type == "OctreeGrid": - return "Octree Grid" - elif grid_type == "CodebookOctreeGrid": - return "Codebook Grid" - elif grid_type == "TriplanarGrid": - return "Triplanar Grid" - elif grid_type == "HashGrid": - return "Hash Grid" + elif hasattr(self.nef.grid.blas, 'name'): + return self.nef.grid.blas.name() else: return "Unknown" - - + def features_structure(self) -> str: + """ Returns a human readable name of the feature structure used by this renderer """ + if getattr(self.nef, 'grid') is None: + return "None" + elif hasattr(self.nef.grid, 'name'): + return self.nef.grid.name() + else: + return "Unknown" diff --git a/wisp/renderer/core/renderers/spc_pipeline_renderer.py b/wisp/renderer/core/renderers/spc_pipeline_renderer.py index 1164e4b..773195f 100644 --- a/wisp/renderer/core/renderers/spc_pipeline_renderer.py +++ b/wisp/renderer/core/renderers/spc_pipeline_renderer.py @@ -64,8 +64,10 @@ def aabb(self) -> torch.Tensor: # (center_x, center_y, center_z, width, height, depth) return torch.tensor((0.0, 0.0, 0.0, 2.0, 2.0, 2.0), device=self.device) - def acceleration_structure(self): + def acceleration_structure(self) -> str: + """ Returns a human readable name of the bottom level acceleration structure used by this renderer """ return "Octree" # Assumes to always use OctreeAS - def features_structure(self): + def features_structure(self) -> str: + """ Returns a human readable name of the feature structure used by this renderer """ return "Octree Grid" # Assumes to always use OctreeGrid for storing features