Skip to content

Commit

Permalink
Merge pull request #43 from kerrj/justin/updates
Browse files Browse the repository at this point in the history
Update scale, bump ns version, remove some comments
  • Loading branch information
kerrj authored Aug 16, 2023
2 parents ef934c1 + 94c8630 commit 3b2cb90
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 32 deletions.
2 changes: 2 additions & 0 deletions lerf/data/lerf_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,6 @@ def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
# assume all cameras have the same focal length and image width
ray_bundle.metadata["fx"] = self.train_dataset.cameras[0].fx.item()
ray_bundle.metadata["width"] = self.train_dataset.cameras[0].width.item()
ray_bundle.metadata["fy"] = self.train_dataset.cameras[0].fy.item()
ray_bundle.metadata["height"] = self.train_dataset.cameras[0].height.item()
return ray_bundle, batch
39 changes: 8 additions & 31 deletions lerf/lerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,32 +52,6 @@ def populate_modules(self):
clip_n_dims=self.image_encoder.embedding_dim,
)

# populate some viewer logic
# TODO use the values from this code to select the scale
# def scale_cb(element):
# self.config.n_scales = element.value

# self.n_scale_slider = ViewerSlider("N Scales", 15, 5, 30, 1, cb_hook=scale_cb)

# def max_cb(element):
# self.config.max_scale = element.value

# self.max_scale_slider = ViewerSlider("Max Scale", 1.5, 0, 5, 0.05, cb_hook=max_cb)

# def hardcode_scale_cb(element):
# self.hardcoded_scale = element.value

# self.hardcoded_scale_slider = ViewerSlider(
# "Hardcoded Scale", 1.0, 0, 5, 0.05, cb_hook=hardcode_scale_cb, disabled=True
# )

# def single_scale_cb(element):
# self.n_scale_slider.set_disabled(element.value)
# self.max_scale_slider.set_disabled(element.value)
# self.hardcoded_scale_slider.set_disabled(not element.value)

# self.single_scale_box = ViewerCheckbox("Single Scale", False, cb_hook=single_scale_cb)

def get_max_across(self, ray_samples, weights, hashgrid_field, scales_shape, preset_scales=None):
# TODO smoothen this out
if preset_scales is not None:
Expand Down Expand Up @@ -120,13 +94,16 @@ def gather_fn(tens):
return torch.gather(tens, -2, best_ids.expand(*best_ids.shape[:-1], tens.shape[-1]))

dataclass_fn = lambda dc: dc._apply_fn_to_fields(gather_fn, dataclass_fn)
lerf_samples = ray_samples._apply_fn_to_fields(gather_fn, dataclass_fn)
lerf_samples: RaySamples = ray_samples._apply_fn_to_fields(gather_fn, dataclass_fn)

if self.training:
clip_scales = ray_bundle.metadata["clip_scales"]
clip_scales = clip_scales[..., None]
dist = lerf_samples.spacing_to_euclidean_fn(lerf_samples.spacing_starts.squeeze(-1)).unsqueeze(-1)
clip_scales = clip_scales * ray_bundle.metadata["width"] * (1 / ray_bundle.metadata["fx"]) * dist
with torch.no_grad():
clip_scales = ray_bundle.metadata["clip_scales"]
clip_scales = clip_scales[..., None]
dist = (lerf_samples.frustums.get_positions() - ray_bundle.origins[:, None, :]).norm(
dim=-1, keepdim=True
)
clip_scales = clip_scales * ray_bundle.metadata["height"] * (dist / ray_bundle.metadata["fy"])
else:
clip_scales = torch.ones_like(lerf_samples.spacing_starts, device=self.device)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dependencies=[
"regex",
"tqdm",
"clip @ git+https://github.com/openai/CLIP.git",
"nerfstudio>=0.3.0"
"nerfstudio>=0.3.1"
]

[tool.setuptools.packages.find]
Expand Down

0 comments on commit 3b2cb90

Please sign in to comment.