From 6e6cedd438348d03f390ba0d9b8afddaea4163c1 Mon Sep 17 00:00:00 2001 From: trinh31201 <141480383+trinh31201@users.noreply.github.com> Date: Wed, 6 Mar 2024 06:00:39 +0900 Subject: [PATCH] Handle alpha images properly for instant-ngp and tensorf models (#2979) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Đinh Trinh --- nerfstudio/models/instant_ngp.py | 2 +- nerfstudio/models/tensorf.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nerfstudio/models/instant_ngp.py b/nerfstudio/models/instant_ngp.py index 1b0304d5bd9..2503e7e228c 100644 --- a/nerfstudio/models/instant_ngp.py +++ b/nerfstudio/models/instant_ngp.py @@ -223,7 +223,7 @@ def get_metrics_dict(self, outputs, batch): return metrics_dict def get_loss_dict(self, outputs, batch, metrics_dict=None): - image = batch["image"][..., :3].to(self.device) + image = batch["image"].to(self.device) pred_rgb, image = self.renderer_rgb.blend_background_for_loss_computation( pred_image=outputs["rgb"], pred_accumulation=outputs["accumulation"], diff --git a/nerfstudio/models/tensorf.py b/nerfstudio/models/tensorf.py index d19b0e2fb32..6c149b4b615 100644 --- a/nerfstudio/models/tensorf.py +++ b/nerfstudio/models/tensorf.py @@ -308,7 +308,7 @@ def get_outputs(self, ray_bundle: RayBundle): def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Tensor]: # Scaling metrics by coefficients to create the losses. device = outputs["rgb"].device - image = batch["image"][..., :3].to(device) + image = batch["image"].to(device) pred_image, image = self.renderer_rgb.blend_background_for_loss_computation( pred_image=outputs["rgb"], pred_accumulation=outputs["accumulation"],