Skip to content

Commit

Permalink
Change nerf methods mapping and add thermal predicted image only output
Browse files Browse the repository at this point in the history
- change methods mapping in `train_eval_script.py`
- add `thermal` and `thermal_combined` images output
  • Loading branch information
jiaruiyu99 committed Oct 2, 2024
1 parent c6f78c0 commit f7ac977
Show file tree
Hide file tree
Showing 18 changed files with 115 additions and 82 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,21 @@ Install the package by running `pip install -e .` and then thermoNeRF should be
To train and evaluate ThermoNeRF, first download our dataset and then use the following scripts

```bash
python scripts/train_eval_script.py --data-asset-path DATA_PATH --model-type thermal-nerf --max-num-iterations ITERATIONS
python thermo_nerf/scripts/train_eval_script.py --data DATA_PATH --model-type THERMONERF --max-num-iterations ITERATIONS
```

E.g.

```bash
python scripts/train_eval_script.py --data data/ThermoScenes/double_robot/ --model_type thermal-nerf --max_num_iterations 1000
python thermo_nerf/scripts/train_eval_script.py --data data/ThermoScenes/double_robot/ --model_type THERMONERF --max_num_iterations 1000
```

## Evaluate

To evaluate a model, run the following script.

```bash
python scripts/eval_script.py --dataset_path DATA_PATH --model_uri MODEL_PATH --output_folder RESULTS_PATH
python thermo_nerf/scripts/eval_script.py --dataset_path DATA_PATH --model_uri MODEL_PATH --output_folder RESULTS_PATH
```

## Render
Expand All @@ -52,7 +52,7 @@ For more information about it, check [Nerfstudio Documentation](https://docs.ner
To render a path of a specific scene using a pretrained model, use the following script:

```bash
python scripts/render_video_script.py --dataset_path DATA_PATH --model_uri MODEL_PATH --camera_path_filename CAMERA_PATH_JSON --output_dir RENDER_RESULTS_PATH
python thermo_nerf/scripts/render_video_script.py --dataset_path DATA_PATH --model_uri MODEL_PATH --camera_path_filename CAMERA_PATH_JSON --output_dir RENDER_RESULTS_PATH
```

## Contribute
Expand Down
8 changes: 0 additions & 8 deletions scripts/train_script.py

This file was deleted.

9 changes: 5 additions & 4 deletions tests/test_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import numpy as np
from nerfstudio.models.vanilla_nerf import NeRFModel

from thermo_nerf.render.renderer import RenderedImageModality, Renderer
from thermo_nerf.render.renderer import Renderer
from thermo_nerf.rendered_image_modalities import RenderedImageModality


class TestRenderer(unittest.TestCase):
Expand All @@ -24,7 +25,7 @@ def setUpClass(cls) -> None:
array = np.reshape(array, (1024, 720))

cls._renderer_vanilla._rendered_images = {
RenderedImageModality.rgb: [array, array]
RenderedImageModality.RGB: [array, array]
}

def test_load(self) -> None:
Expand All @@ -47,7 +48,7 @@ def test_image_export(self) -> None:
output_dir.mkdir(exist_ok=True)

self._renderer_vanilla.save_images(
[RenderedImageModality.rgb], output_dir=output_dir
[RenderedImageModality.RGB], output_dir=output_dir
)
files = []
for file in output_dir.iterdir():
Expand All @@ -59,7 +60,7 @@ def test_image_export(self) -> None:
def test_gif_export(self) -> None:
output_dir = Path("tests/tmp_output")
output_dir.mkdir(exist_ok=True)
self._renderer_vanilla.save_gif([RenderedImageModality.rgb], 1, output_dir)
self._renderer_vanilla.save_gif([RenderedImageModality.RGB], 1, output_dir)

def test_load_cameras_positions(self) -> None:
cameras = Renderer.load_cameras(
Expand Down
12 changes: 6 additions & 6 deletions thermo_nerf/evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from nerfstudio.pipelines.base_pipeline import Pipeline
from PIL import Image

from thermo_nerf.render.renderer import RenderedImageModality
from thermo_nerf.rendered_image_modalities import RenderedImageModality


class Evaluator:
Expand All @@ -20,7 +20,7 @@ def __init__(
pipeline: Pipeline,
config: TrainerConfig,
job_param_identifier: Optional[str] = None,
modalities_to_save: list[RenderedImageModality] = [RenderedImageModality.rgb],
modalities_to_save: list[RenderedImageModality] = [RenderedImageModality.RGB],
) -> None:
"""
Initializes the parameters which are `output_file` to save the metrics, the
Expand Down Expand Up @@ -76,9 +76,6 @@ def _compute_metrics(
images_dict,
) = self._pipeline.model.get_image_metrics_and_images(outputs, batch)

# Save imgs
images_dict["rgb"] = images_dict.pop("img")

for modality in self.modalities_to_save:
self._evaluation_images[modality].append(
(images_dict[modality.value] * 255).byte().cpu().numpy()
Expand Down Expand Up @@ -166,7 +163,10 @@ def save_metrics(self, output_folder: Path) -> None:
lpips_folder_path.joinpath(self.identifier + ".txt").write_text(
json.dumps(self._metrics["lpips"], indent=2), "utf8"
)
if RenderedImageModality.thermal in self.modalities_to_save:
# Write the thermal PSNR file only when a thermal modality was actually
# requested. Each membership test is spelled out: the previous form
# `THERMAL or THERMAL_COMBINED in ...` parsed as `THERMAL or (...)`, and an
# Enum member is always truthy, so the branch ran for RGB-only models too.
if (
    RenderedImageModality.THERMAL in self.modalities_to_save
    or RenderedImageModality.THERMAL_COMBINED in self.modalities_to_save
):
    psnr_folder_path.joinpath(self.identifier + "_thermal.txt").write_text(
        json.dumps(self._metrics["psnr_thermal"], indent=2), "utf8"
    )
Expand Down
7 changes: 7 additions & 0 deletions thermo_nerf/model_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from enum import Enum


class ModelType(Enum):
    """Closed set of NeRF model variants the training script can select."""

    # Selects `nerfacto_config` in the training script.
    NERFACTO = 1
    # Selects `thermal_nerftrack_config`; the training script's default.
    THERMONERF = 2
    # Selects `concat_nerf_config`.
    CONCATNERF = 3
11 changes: 1 addition & 10 deletions thermo_nerf/render/renderer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
from enum import Enum
from pathlib import Path
from typing import Optional

Expand All @@ -13,15 +12,7 @@
from nerfstudio.models.base_model import Model
from nerfstudio.pipelines.base_pipeline import Pipeline


class RenderedImageModality(Enum):
rgb = "rgb"
depth = "depth"
accumulation = "accumulation"
thermal = "thermal"

def __str__(self):
return str(self.value)
from thermo_nerf.rendered_image_modalities import RenderedImageModality


class Renderer:
Expand Down
9 changes: 9 additions & 0 deletions thermo_nerf/rendered_image_modalities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from enum import Enum


class RenderedImageModality(Enum):
    """
    Image modalities that can be rendered and saved.

    Each member's value is the string key used to look the modality up in the
    dictionaries exchanged with the models (e.g. the `images_dict` returned by
    `get_image_metrics_and_images`, which the evaluator indexes with
    `modality.value`).
    """

    # "img" (not "rgb") is the key under which models store the combined
    # ground-truth/predicted RGB image in `images_dict`.
    RGB = "img"
    DEPTH = "depth"
    ACCUMULATION = "accumulation"
    # Also used as the metadata/batch key for thermal data in the thermal
    # dataparser and dataset.
    THERMAL = "thermal"
    # NOTE(review): presumably the ground-truth/predicted thermal
    # side-by-side image; confirm against the thermal model's `images_dict`.
    THERMAL_COMBINED = "thermal_combined"
19 changes: 11 additions & 8 deletions thermo_nerf/rgb_concat/concat_nerfacto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from thermo_nerf.rendered_image_modalities import RenderedImageModality
from thermo_nerf.rgb_concat.concat_field import ConcatNerfactoTField
from thermo_nerf.rgb_concat.rgbt_renderer import RGBTRenderer

Expand Down Expand Up @@ -187,7 +188,7 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None):

pred_rgb, gt_rgb = self.renderer_rgb.blend_background_for_loss_computation(
pred_image=outputs["rgb"],
pred_accumulation=outputs["accumulation"],
pred_accumulation=outputs[RenderedImageModality.ACCUMULATION.value],
gt_image=image,
)

Expand Down Expand Up @@ -240,10 +241,12 @@ def get_image_metrics_and_images(
"rgb"
] # Blended with background (black if random background)

acc = colormaps.apply_colormap(outputs["accumulation"])
acc = colormaps.apply_colormap(
outputs[RenderedImageModality.ACCUMULATION.value]
)
depth = colormaps.apply_depth_colormap(
outputs["depth"],
accumulation=outputs["accumulation"],
outputs[RenderedImageModality.DEPTH.value],
accumulation=outputs[RenderedImageModality.ACCUMULATION.value],
)

combined_rgb = torch.cat([gt_rgb, predicted_rgb], dim=1)
Expand All @@ -270,16 +273,16 @@ def get_image_metrics_and_images(
metrics_dict["lpips"] = float(lpips)

images_dict = {
"img": combined_rgb,
"accumulation": combined_acc,
"depth": combined_depth,
RenderedImageModality.RGB.value: combined_rgb,
RenderedImageModality.ACCUMULATION.value: combined_acc,
RenderedImageModality.DEPTH.value: combined_depth,
}

for i in range(self.config.num_proposal_iterations):
key = f"prop_depth_{i}"
prop_depth_i = colormaps.apply_depth_colormap(
outputs[key],
accumulation=outputs["accumulation"],
accumulation=outputs[RenderedImageModality.ACCUMULATION.value],
)
images_dict[key] = prop_depth_i

Expand Down
6 changes: 4 additions & 2 deletions scripts/eval_script.py → thermo_nerf/scripts/eval_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from pathlib import Path

import tyro

from thermo_nerf.evaluator.evaluator import Evaluator
from thermo_nerf.render.renderer import RenderedImageModality, Renderer
from thermo_nerf.render.renderer import Renderer
from thermo_nerf.rendered_image_modalities import RenderedImageModality


@dataclass
Expand All @@ -19,7 +21,7 @@ class EvalCLIArgs:
"""Name of the output folder to save metrics"""
modalities_to_save: list[RenderedImageModality] = field(
default_factory=lambda: [
RenderedImageModality.rgb,
RenderedImageModality.RGB,
]
)
"""Name of the renderer outputs to use: rgb, depth, accumulation."""
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from typing import Optional

import tyro
from thermo_nerf.render.renderer import RenderedImageModality, Renderer

from thermo_nerf.render.renderer import Renderer
from thermo_nerf.rendered_image_modalities import RenderedImageModality


@dataclass
Expand All @@ -27,8 +29,8 @@ class RenderTrajectoryCLIArgs:
"""Save images to a folder in `output_dir`"""
rendered_image_modalities: list[RenderedImageModality] = field(
default_factory=lambda: [
RenderedImageModality.rgb,
RenderedImageModality.thermal,
RenderedImageModality.RGB,
RenderedImageModality.THERMAL,
]
)
"""Name of the renderer outputs to use: rgb, depth, accumulation."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,23 @@
from pathlib import Path

import tyro



from nerfstudio.scripts.train import main

from thermo_nerf.evaluator.evaluator import Evaluator
from thermo_nerf.model_type import ModelType
from thermo_nerf.nerfacto_config.config_nerfacto import nerfacto_config
from thermo_nerf.render.renderer import RenderedImageModality, Renderer
from thermo_nerf.render.renderer import Renderer
from thermo_nerf.rendered_image_modalities import RenderedImageModality
from thermo_nerf.rgb_concat.config_concat_nerfacto import concat_nerf_config
from thermo_nerf.thermal_nerf.config_thermal_nerf import thermal_nerftrack_config


@dataclass
class TrainingParameters:
model_type: str = "thermal-nerf"
model_type: ModelType = ModelType.THERMONERF
"""Which NeRF model to train. Defaults to ThermoNeRF."""
experiment_name: str = "nerfacto training"
"""Name of the model to train"""
Expand All @@ -25,23 +30,28 @@ class TrainingParameters:

metrics_output_folder: Path = Path("./outputs/")

modalities_to_save: list[RenderedImageModality] = field(
default_factory=lambda: [
RenderedImageModality.rgb,
]
)
"""Name of the renderer outputs to use: rgb, depth, accumulation."""

seed: int = 0
"""Seed for the random number generator"""

def __post_init__(self) -> None:
mapping_name_to_config = {
"nerfacto": nerfacto_config,
"thermal-nerf": thermal_nerftrack_config,
"concat-nerf": concat_nerf_config,
}
self.model = mapping_name_to_config[self.model_type]
if self.model_type == ModelType.THERMONERF:
self.model = thermal_nerftrack_config
self.modalities_to_save = [
RenderedImageModality.RGB,
RenderedImageModality.THERMAL,
RenderedImageModality.THERMAL_COMBINED,
]

if self.model_type == ModelType.NERFACTO:
self.model = nerfacto_config
self.modalities_to_save = [
RenderedImageModality.RGB,
]
if self.model_type == ModelType.CONCATNERF:
self.model = concat_nerf_config
self.modalities_to_save = [
RenderedImageModality.RGB,
]


if __name__ == "__main__":
Expand All @@ -52,7 +62,6 @@ def __post_init__(self) -> None:
parameters.model.max_num_iterations = parameters.max_num_iterations
parameters.model.data = parameters.data
parameters.model.viewer.quit_on_train_completion = True

main(parameters.model)

pipeline, config = Renderer.extract_pipeline(
Expand Down
6 changes: 5 additions & 1 deletion thermo_nerf/thermal_nerf/thermal_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
)
from nerfstudio.utils.io import load_from_json

from thermo_nerf.rendered_image_modalities import RenderedImageModality


@dataclass
class ThermalDataParserConfig(NerfstudioDataParserConfig):
Expand Down Expand Up @@ -333,7 +335,9 @@ def _generate_dataparser_outputs(self, split: str = "train") -> DataparserOutput
dataparser_scale=scale_factor,
dataparser_transform=transform_matrix,
metadata={
"thermal": thermal_filenames if len(thermal_filenames) > 0 else None,
RenderedImageModality.THERMAL.value: (
thermal_filenames if len(thermal_filenames) > 0 else None
),
},
)
return dataparser_outputs
10 changes: 6 additions & 4 deletions thermo_nerf/thermal_nerf/thermal_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
from nerfstudio.data.datasets.base_dataset import InputDataset

from thermo_nerf.rendered_image_modalities import RenderedImageModality


class ThermalDataset(InputDataset):
"""
Dataset class that returns thermal images.
"""

exclude_batch_keys_from_device = InputDataset.exclude_batch_keys_from_device + [
"thermal"
RenderedImageModality.THERMAL.value
]

def __init__(
Expand All @@ -24,8 +26,8 @@ def __init__(
kernel_size: int = 3,
) -> None:
super().__init__(dataparser_outputs, scale_factor)
assert "thermal" in dataparser_outputs.metadata.keys()
self.thermal_filenames = self.metadata["thermal"]
assert RenderedImageModality.THERMAL.value in dataparser_outputs.metadata.keys()
self.thermal_filenames = self.metadata[RenderedImageModality.THERMAL.value]
self.kernel_size = kernel_size

def get_metadata(self, data: Dict) -> Dict[str, torch.Tensor]:
Expand All @@ -42,7 +44,7 @@ def get_metadata(self, data: Dict) -> Dict[str, torch.Tensor]:
scale_factor=self.scale_factor,
)

return {"thermal": thermal_data}
return {RenderedImageModality.THERMAL.value: thermal_data}

@staticmethod
def get_thermal_tensors_from_path(
Expand Down
Loading

0 comments on commit f7ac977

Please sign in to comment.