From 94c863016e6e1a386ad7ebd2bd5639a88194937d Mon Sep 17 00:00:00 2001 From: Justin Kerr Date: Tue, 15 Aug 2023 16:41:11 -0700 Subject: [PATCH] update computation of scale, bump nerfstudio version, remove some dead comments --- lerf/data/lerf_datamanager.py | 2 ++ lerf/lerf.py | 39 +++++++---------------------------- pyproject.toml | 2 +- 3 files changed, 11 insertions(+), 32 deletions(-) diff --git a/lerf/data/lerf_datamanager.py b/lerf/data/lerf_datamanager.py index cb57462..ff13b5f 100644 --- a/lerf/data/lerf_datamanager.py +++ b/lerf/data/lerf_datamanager.py @@ -118,4 +118,6 @@ def next_train(self, step: int) -> Tuple[RayBundle, Dict]: # assume all cameras have the same focal length and image width ray_bundle.metadata["fx"] = self.train_dataset.cameras[0].fx.item() ray_bundle.metadata["width"] = self.train_dataset.cameras[0].width.item() + ray_bundle.metadata["fy"] = self.train_dataset.cameras[0].fy.item() + ray_bundle.metadata["height"] = self.train_dataset.cameras[0].height.item() return ray_bundle, batch diff --git a/lerf/lerf.py b/lerf/lerf.py index c427748..1be90bb 100644 --- a/lerf/lerf.py +++ b/lerf/lerf.py @@ -52,32 +52,6 @@ def populate_modules(self): clip_n_dims=self.image_encoder.embedding_dim, ) - # populate some viewer logic - # TODO use the values from this code to select the scale - # def scale_cb(element): - # self.config.n_scales = element.value - - # self.n_scale_slider = ViewerSlider("N Scales", 15, 5, 30, 1, cb_hook=scale_cb) - - # def max_cb(element): - # self.config.max_scale = element.value - - # self.max_scale_slider = ViewerSlider("Max Scale", 1.5, 0, 5, 0.05, cb_hook=max_cb) - - # def hardcode_scale_cb(element): - # self.hardcoded_scale = element.value - - # self.hardcoded_scale_slider = ViewerSlider( - # "Hardcoded Scale", 1.0, 0, 5, 0.05, cb_hook=hardcode_scale_cb, disabled=True - # ) - - # def single_scale_cb(element): - # self.n_scale_slider.set_disabled(element.value) - # self.max_scale_slider.set_disabled(element.value) - # self.hardcoded_scale_slider.set_disabled(not element.value) - - # self.single_scale_box = ViewerCheckbox("Single Scale", False, cb_hook=single_scale_cb) - def get_max_across(self, ray_samples, weights, hashgrid_field, scales_shape, preset_scales=None): # TODO smoothen this out if preset_scales is not None: @@ -120,13 +94,16 @@ def gather_fn(tens): return torch.gather(tens, -2, best_ids.expand(*best_ids.shape[:-1], tens.shape[-1])) dataclass_fn = lambda dc: dc._apply_fn_to_fields(gather_fn, dataclass_fn) - lerf_samples = ray_samples._apply_fn_to_fields(gather_fn, dataclass_fn) + lerf_samples: RaySamples = ray_samples._apply_fn_to_fields(gather_fn, dataclass_fn) if self.training: - clip_scales = ray_bundle.metadata["clip_scales"] - clip_scales = clip_scales[..., None] - dist = lerf_samples.spacing_to_euclidean_fn(lerf_samples.spacing_starts.squeeze(-1)).unsqueeze(-1) - clip_scales = clip_scales * ray_bundle.metadata["width"] * (1 / ray_bundle.metadata["fx"]) * dist + with torch.no_grad(): + clip_scales = ray_bundle.metadata["clip_scales"] + clip_scales = clip_scales[..., None] + dist = (lerf_samples.frustums.get_positions() - ray_bundle.origins[:, None, :]).norm( + dim=-1, keepdim=True + ) + clip_scales = clip_scales * ray_bundle.metadata["height"] * (dist / ray_bundle.metadata["fy"]) else: clip_scales = torch.ones_like(lerf_samples.spacing_starts, device=self.device) diff --git a/pyproject.toml b/pyproject.toml index f09b49d..8f451df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ dependencies=[ "regex", "tqdm", "clip @ git+https://github.com/openai/CLIP.git", - "nerfstudio>=0.3.0" + "nerfstudio>=0.3.1" ] [tool.setuptools.packages.find]