From f748738fcd59272bb116dc89e682bb7c6d867856 Mon Sep 17 00:00:00 2001 From: Towaki Takikawa Date: Tue, 31 Jan 2023 10:12:16 -0500 Subject: [PATCH] Remove unneeded dependencies, make some others optional Signed-off-by: Towaki Takikawa --- INSTALL.md | 43 +++---- requirements.txt | 6 +- tools/linux/Dockerfile | 1 + wisp/datasets/formats/nerf_standard.py | 8 +- wisp/datasets/formats/rtmv.py | 140 ++++++++++++++++++++++- wisp/ops/image/io.py | 150 +++++-------------------- wisp/ops/image/metrics.py | 6 +- wisp/trainers/multiview_trainer.py | 49 +++++--- wisp/trainers/sdf_trainer.py | 3 + 9 files changed, 232 insertions(+), 174 deletions(-) diff --git a/INSTALL.md b/INSTALL.md index 959e834..85d4e64 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -4,22 +4,6 @@ NVIDIA Kaolin Wisp can be installed either manually or using Docker. ## Manual Installation -### Prerequisite - -Install OpenEXR on Ubuntu: - -``` -sudo apt-get update -sudo apt-get install libopenexr-dev -``` - -Install OpenEXR on Windows: - -``` -pip install pipwin -pipwin install openexr -``` - ### Quick Start Full installation with interactive visualizer, for torch 1.12.1, cuda 11.3 and kaolin 0.12.0: ``` @@ -48,11 +32,32 @@ conda activate wisp pip install --upgrade pip ``` -#### 2. Install PyTorch +#### 2. (Optional) Install OpenEXR + +Some features of our library, like support for the [RTMV dataset](http://www.cs.umd.edu/~mmeshry/projects/rtmv/) +and logging of multi-layer EXR files (which you can visualize with +awesome tools like [tev](https://github.com/Tom94/tev)) will only work if you install OpenEXR. +These steps are optional, and these features will only be enabled if you follow these steps. + +Install OpenEXR on Ubuntu: + +``` +sudo apt-get update +sudo apt-get install libopenexr-dev +``` + +Install OpenEXR on Windows: + +``` +pip install pipwin +pipwin install openexr +``` + +#### 3. Install PyTorch You should first install PyTorch by following the [official instructions](https://pytorch.org/). The code has been tested with `1.9.1` to `1.12.0` on Ubuntu 20.04. -#### 3. Install Kaolin +#### 4. Install Kaolin kaolin can be installed with pip (use the correct torch + cuda version): ``` @@ -68,7 +73,7 @@ See the [Kaolin Installation Doc](https://kaolin.readthedocs.io/en/latest/notes/ _The minimum required version of Kaolin is `0.12.0`._ -#### 4. Installing Wisp +#### 5. Installing Wisp Install the rest of the dependencies from [requirements](requirements.txt). diff --git a/requirements.txt b/requirements.txt index d63438f..116188c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,22 +1,18 @@ mkl tensorboard matplotlib -lpips git+https://github.com/tinyobjloader/tinyobjloader.git@v2.0.0rc8#subdirectory=python -pyexr pybind11 pyyaml -trimesh>=3.0 pandas tqdm Pillow +numpy~=1.23.0 scipy>=1.7.2 scikit-image -scikit-learn six>=1.12.0 moviepy opencv-python -plyfile protobuf>=3.20.0 polyscope more_itertools diff --git a/tools/linux/Dockerfile b/tools/linux/Dockerfile index b7ce46d..ffc1e5d 100644 --- a/tools/linux/Dockerfile +++ b/tools/linux/Dockerfile @@ -17,6 +17,7 @@ RUN apt-get -y update \ && rm -rf /var/lib/apt/lists/* RUN pip install -r requirements.txt +RUN pip install pyexr RUN python setup.py develop RUN if [ -z "${INSTALL_RENDERER}" ]; then \ diff --git a/wisp/datasets/formats/nerf_standard.py b/wisp/datasets/formats/nerf_standard.py index 8f7fb43..8e69034 100644 --- a/wisp/datasets/formats/nerf_standard.py +++ b/wisp/datasets/formats/nerf_standard.py @@ -10,11 +10,8 @@ import glob import time import cv2 -import skimage -import imageio import json from tqdm import tqdm -import skimage.metrics import logging as log import numpy as np import torch @@ -22,7 +19,7 @@ from kaolin.render.camera import Camera, blender_coords from wisp.core import Rays from wisp.ops.raygen import generate_pinhole_rays, generate_ortho_rays, generate_centered_pixel_coords -from wisp.ops.image import resize_mip +from wisp.ops.image import resize_mip, load_rgb """ A module for loading data files in the standard NeRF format, including extensions to the format supported by Instant Neural Graphics Primitives. @@ -51,8 +48,7 @@ def _load_standard_imgs(frame, root, mip=None): # For some reason instant-ngp allows missing images that exist in the transform but not in the data. # Handle this... also handles the above case well too. if os.path.exists(fpath): - img = imageio.imread(fpath) - img = skimage.img_as_float32(img) + img = load_rgb(fpath) if mip is not None: img = resize_mip(img, mip, interpolation=cv2.INTER_AREA) return dict(basename=basename, diff --git a/wisp/datasets/formats/rtmv.py b/wisp/datasets/formats/rtmv.py index 2a96613..45ddbcf 100644 --- a/wisp/datasets/formats/rtmv.py +++ b/wisp/datasets/formats/rtmv.py @@ -17,15 +17,143 @@ from torch.multiprocessing import Pool, cpu_count from kaolin.render.camera import Camera, blender_coords from wisp.core import Rays -from wisp.ops.image import load_exr +import wisp.ops.image as img_ops from wisp.ops.raygen import generate_pinhole_rays, generate_ortho_rays, generate_centered_pixel_coords from wisp.ops.pointcloud import create_pointcloud_from_images, normalize_pointcloud - +import cv2 """ A module for loading data files in the RTMV format. See: http://www.cs.umd.edu/~mmeshry/projects/rtmv/ """ +def load_rtmv_images(root, basename, use_depth=False, mip=None, srgb=False, bg_color='white'): + """Loads a set of RTMV images by path and basename. + + Args: + root (str): Path to the root of the dataset. + basename (str): Basename of the RTMV image set to load. + use_depth (bool): if True, loads the depth data + by default, this assumes the depth is stored in the "depth" buffer + mip (int): if not None, then each image will be resized by 2^mip + srgb (bool): if True, convert to SRGB + + Returns: + (dictionary of torch.Tensors) + + Keys: + image : torch.FloatTensor of size [H,W,3] + alpha : torch.FloatTensor of size [H,W,1] + depth : torch.FloatTensor of size [H,W,1] + """ + # TODO(ttakikawa): There is a lot that this function does... break this up + from wisp.ops.image import resize_mip, linear_to_srgb + + image_exts = ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'] + exr_exts = ['.exr', '.EXR'] + npz_exts = ['.npz', '.NPZ'] + + img = None + depth = None + + # Try to load RGB first + found_image_path = None + for ext in image_exts + exr_exts: + image_path = os.path.join(root, basename + ext) + if os.path.exists(image_path): + found_image_path = image_path + break + if found_image_path is None: + raise Exception("No images found! Check if your dataset path contains an actual RTMV dataset.") + + img_ext = os.path.splitext(found_image_path)[1] + if img_ext in image_exts: + img = img_ops.load_rgb(found_image_path) + elif img_ext in exr_exts: + try: + import pyexr + except: + raise Exception( + "The RTMV dataset provided uses EXR, but module pyexr is not available. " + "To install, run `pip install pyexr`. " + "You will likely also need `libopenexr`, which through apt you can install with " + "`apt-get install libopenexr-dev` and on Windows you can install with " + "`pipwin install openexr`.") + f = pyexr.open(found_image_path) + img = f.get("default") + else: + raise Exception(f"Invalid image extension for the image path {found_image_path}") + + found_depth_path = None + for ext in [".depth.npz", ".depth.exr"] + exr_exts: + depth_path = os.path.join(root, basename + ext) + if os.path.exists(depth_path): + found_depth_path = depth_path + break + if found_depth_path is None: + raise Exception("No depth found! Check if your dataset path contains an actual RTMV dataset.") + + depth_ext = os.path.splitext(found_depth_path)[1] + # Load depth + if depth_ext == ".npz": + depth = np.load(found_depth_path)['arr_0'][..., 0] + elif depth_ext == ".exr": + try: + import pyexr + except: + raise Exception( + "The RTMV dataset provided uses EXR, but module pyexr is not available. " + "To install, run `pip install pyexr`. " + "You will likely also need `libopenexr`, which through apt you can install with " + "`apt-get install libopenexr-dev` and on Windows you can install with " + "`pipwin install openexr`.") + + f = pyexr.open(found_depth_path) + + components = os.path.basename(found_depth_path).split('.') + if len(components) > 2 and components[-1] == "exr" and components[-2] == "depth": + depth = f.get('default')[:, :, 0] + else: + if len(f.channel_map['depth']) > 0: + depth = f.get("depth") + else: + raise Exception("Depth channel not found in the EXR file provided!") + else: + raise Exception(f"Invalid depth extension for the depth path {found_depth_path}") + + alpha = img[..., 3:4] + + if bg_color == 'black': + img[..., :3] -= (1 - alpha) + img = np.clip(img, 0.0, 1.0) + else: + img[..., :3] *= alpha + img[..., :3] += (1 - alpha) + img = np.clip(img, 0.0, 1.0) + + if mip is not None: + # TODO(ttakikawa): resize_mip causes the mask to be squuezed... why? + img = resize_mip(img, mip, interpolation=cv2.INTER_AREA) + if use_depth: + depth = resize_mip(depth, mip, interpolation=cv2.INTER_NEAREST) + # mask_depth = resize_mip(mask_depth[...,None].astype(np.float), mip, interpolation=cv2.INTER_NEAREST) + + img = torch.from_numpy(img) + if use_depth: + depth = torch.from_numpy(depth) + + if use_depth: + mask_depth = torch.logical_and(depth > -1000, depth < 1000) + depth[~mask_depth] = -1.0 + depth = depth[:, :, np.newaxis] + + if srgb: + img = linear_to_srgb(img) + + alpha = mask_depth + + return img, alpha, depth + + def rescale_rtmv_intrinsics(camera, target_size, original_width, original_height): """ Rescale the intrinsics. """ @@ -87,7 +215,7 @@ def _parallel_load_rtmv_data(args): """ torch.set_num_threads(1) with torch.no_grad(): - image, alpha, depth = load_exr(**args['exr_args']) + image, alpha, depth = load_rtmv_images(**args['exr_args']) camera = load_rtmv_camera(args['camera_args']['path']) transformed_camera = transform_rtmv_camera(copy.deepcopy(camera), mip=args['camera_args']['mip']) return dict( @@ -149,7 +277,8 @@ def load_rtmv_data(root, split, mip=None, normalize=True, return_pointcloud=Fals dict( task_basename=basename, exr_args=dict( - path=os.path.join(root, basename + '.exr'), + root=root, + basename=basename, use_depth=True, mip=mip, srgb=True, @@ -170,8 +299,7 @@ def load_rtmv_data(root, split, mip=None, normalize=True, return_pointcloud=Fals for img_index, json_file in tqdm(enumerate(json_files), desc='loading data'): with torch.no_grad(): basename = os.path.splitext(os.path.basename(json_file))[0] - exr_path = os.path.join(root, basename + ".exr") - image, alpha, depth = load_exr(exr_path, use_depth=True, mip=mip, srgb=True, bg_color=bg_color) + image, alpha, depth = load_rtmv_images(root, basename, use_depth=True, mip=mip, srgb=True, bg_color=bg_color) json_path = os.path.join(root, basename + ".json") camera = load_rtmv_camera(path=json_path) transformed_camera = transform_rtmv_camera(copy.deepcopy(camera), mip=mip) diff --git a/wisp/ops/image/io.py b/wisp/ops/image/io.py index bb11391..30064f2 100644 --- a/wisp/ops/image/io.py +++ b/wisp/ops/image/io.py @@ -8,13 +8,9 @@ import os import glob -import pyexr -import cv2 -import skimage -import imageio -from PIL import Image import numpy as np import torch +import torchvision """ A module for reading / writing various image formats. """ @@ -31,13 +27,20 @@ def write_exr(path, data): Returns: (void): Writes to path. """ + try: + import pyexr + except: + raise Exception( + "Module pyexr is not available. To install, run `pip install pyexr`. " + "You will likely also need `libopenexr`, which through apt you can install with " + "`apt-get install libopenexr-dev` and on Windows you can install with " + "`pipwin install openexr`.") pyexr.write(path, data, channel_names={'normal': ['X', 'Y', 'Z'], 'x': ['X', 'Y', 'Z'], 'view': ['X', 'Y', 'Z']}, precision=pyexr.HALF) - def write_png(path, data): """Writes an PNG image to some path. @@ -48,8 +51,7 @@ def write_png(path, data): Returns: (void): Writes to path. """ - Image.fromarray(data).save(path) - + torchvision.io.write_png(hwc_to_chw(data), path) def glob_imgs(path, exts=['*.png', '*.PNG', '*.jpg', '*.jpeg', '*.JPG', '*.JPEG']): """Utility to find images in some path. @@ -66,137 +68,39 @@ def glob_imgs(path, exts=['*.png', '*.PNG', '*.jpg', '*.jpeg', '*.JPG', '*.JPEG' imgs.extend(glob.glob(os.path.join(path, ext))) return imgs - -def load_rgb(path): +def load_rgb(path, normalize=True): """Loads an image. - TODO(ttakikawa): Currently ignores the alpha channel. - Args: path (str): Path to the image. + noramlize (bool): If True, will return [0,1] floating point values. Otherwise returns [0,255] ints. Returns: - (np.array): Image as an array. + (np.array): Image as an array of shape [H,W,C] """ - img = imageio.imread(path) - img = skimage.img_as_float32(img) - img = img[:, :, :3] - return img - - -def load_mask(path): - """Loads an alpha mask. - - Args: - path (str): Path to the image. + img = torchvision.io.read_image(path) + if normalize: + img = img.float() / 255.0 + return np.array(chw_to_hwc(img)) - Returns: - (np.array): Image as an array. - """ - alpha = imageio.imread(path, as_gray=True) - alpha = skimage.img_as_float32(alpha) - object_mask = alpha > 127.5 - object_mask = object_mask.transpose(1, 0) - - return object_mask - - -def load_exr(path, use_depth=False, mip=None, srgb=False, bg_color='white', - loader_type='pyexr'): - """Loads a EXR by path. +def hwc_to_chw(img): + """Converts [H,W,C] to [C,H,W] for TensorBoard output. Args: - path (str): path to the .exr file - use_depth (bool): if True, loads the depth data - by default, this assumes the depth is stored in the "depth" buffer - mip (int): if not None, then each image will be resized by 2^mip - srgb (bool): if True, convert to SRGB - loader_type (str): options [cv2, pyexr, imageio]. - TODO(ttakikawa): Not sure quite yet what options should be supported here + img (torch.Tensor): [H,W,C] image. Returns: - (dictionary of torch.Tensors) - - Keys: - image : torch.FloatTensor of size [H,W,3] - alpha : torch.FloatTensor of size [H,W,1] - depth : torch.FloatTensor of size [H,W,1] - ray_o : torch.FloatTensor of size [H,W,3] - ray_d : torch.FloatTensor of size [H,W,3] + (torch.Tensor): [C,H,W] image. """ - # TODO(ttakikawa): There is a lot that this function does... break this up - from wisp.ops.image import resize_mip, linear_to_srgb - - # Load RGB and Depth - if loader_type == 'cv2': - img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + return img.permute(2, 0, 1) - if use_depth: - depth = cv2.imread(path.replace(".exr", ".depth.exr"), cv2.IMREAD_UNCHANGED)[:, :, 0] - - elif loader_type == 'pyexr': - f = pyexr.open(path) - img = f.get("default") - - if use_depth: - if len(f.channel_map['depth']) > 0: - depth = f.get("depth") - else: - f = pyexr.open(path.replace(".exr", ".depth.exr")) - depth = f.get('default')[:, :, 0] - - elif loader_type == 'imageio': - img = imageio.imread(path) - if use_depth: - depth = imageio.imread(path.replace(".exr", ".depth.exr"))[:, :, 0] - else: - raise ValueError(f'Invalid loader_type: {loader_type}') - - alpha = img[..., 3:4] - - if bg_color == 'black': - img[..., :3] -= (1 - alpha) - img = np.clip(img, 0.0, 1.0) - else: - img[..., :3] *= alpha - img[..., :3] += (1 - alpha) - img = np.clip(img, 0.0, 1.0) - - if mip is not None: - # TODO(ttakikawa): resize_mip causes the mask to be squuezed... why? - img = resize_mip(img, mip, interpolation=cv2.INTER_AREA) - if use_depth: - depth = resize_mip(depth, mip, interpolation=cv2.INTER_NEAREST) - # mask_depth = resize_mip(mask_depth[...,None].astype(np.float), mip, interpolation=cv2.INTER_NEAREST) - - img = torch.from_numpy(img) - if use_depth: - depth = torch.from_numpy(depth) - - if use_depth: - mask_depth = torch.logical_and(depth > -1000, depth < 1000) - depth[~mask_depth] = -1.0 - depth = depth[:, :, np.newaxis] - - if loader_type == 'cv2' or loader_type == 'imageio': - # BGR to RGB - img[..., :3] = img[..., :3][..., ::-1] - - if srgb: - img = linear_to_srgb(img) - - alpha = mask_depth - - return img, alpha, depth - - -def hwc_to_chw(img): - """Converts [H,W,C] to [C,H,W] for TensorBoard output. +def chw_to_hwc(img): + """Converts [C,H,W] to [H,W,C]. Args: - img (np.array): [H,W,C] image. + img (torch.Tensor): [C,H,W] image. Returns: - (np.array): [C,H,W] image. + (torch.Tensor): [H,W,C] image. """ - return np.array(img).transpose(2, 0, 1) + return img.permute(1, 2, 0) diff --git a/wisp/ops/image/metrics.py b/wisp/ops/image/metrics.py index 879b86a..0f971c7 100644 --- a/wisp/ops/image/metrics.py +++ b/wisp/ops/image/metrics.py @@ -11,7 +11,6 @@ import skimage.metrics import numpy as np import torch -from lpips import LPIPS """ A module for image based metrics """ @@ -50,6 +49,11 @@ def lpips(rgb, gts, lpips_model=None): Returns: (float): The LPIPS score """ + try: + from lpips import LPIPS + except: + raise Exception( + "Module lpips not available. To install, run `pip install lpips`") assert (rgb.max() <= 1.05 and rgb.min() >= -0.05) assert (gts.max() <= 1.05 and gts.min() >= -0.05) assert (rgb.shape[-1] == 3) diff --git a/wisp/trainers/multiview_trainer.py b/wisp/trainers/multiview_trainer.py index 5b4d6a2..d6f9a03 100644 --- a/wisp/trainers/multiview_trainer.py +++ b/wisp/trainers/multiview_trainer.py @@ -13,7 +13,7 @@ import random import pandas as pd import torch -from lpips import LPIPS +from torch.utils.tensorboard import SummaryWriter from wisp.trainers import BaseTrainer, log_metric_to_wandb, log_images_to_wandb from wisp.ops.image import write_png, write_exr from wisp.ops.image.metrics import psnr, lpips, ssim @@ -49,7 +49,6 @@ def step(self, data): img_gts = data['imgs'].to(self.device).squeeze(0) self.optimizer.zero_grad() - loss = 0 if self.extra_args["random_lod"]: @@ -86,11 +85,10 @@ def log_cli(self): log.info(log_text) - def evaluate_metrics(self, rays, imgs, lod_idx, name=None): + def evaluate_metrics(self, rays, imgs, lod_idx, name=None, lpips_model=None): ray_os = list(rays.origins) ray_ds = list(rays.dirs) - lpips_model = LPIPS(net='vgg').cuda() psnr_total = 0.0 lpips_total = 0.0 @@ -106,7 +104,8 @@ def evaluate_metrics(self, rays, imgs, lod_idx, name=None): gts = img.cuda() psnr_total += psnr(rb.rgb[...,:3], gts[...,:3]) - lpips_total += lpips(rb.rgb[...,:3], gts[...,:3], lpips_model) + if lpips_model: + lpips_total += lpips(rb.rgb[...,:3], gts[...,:3], lpips_model) ssim_total += ssim(rb.rgb[...,:3], gts[...,:3]) out_rb = RenderBuffer(rgb=rb.rgb, depth=rb.depth, alpha=rb.alpha, @@ -117,20 +116,32 @@ def evaluate_metrics(self, rays, imgs, lod_idx, name=None): if name is not None: out_name += "-" + name - write_exr(os.path.join(self.valid_log_dir, out_name + ".exr"), exrdict) - write_png(os.path.join(self.valid_log_dir, out_name + ".png"), rb.cpu().image().byte().rgb.numpy()) + try: + write_exr(os.path.join(self.valid_log_dir, out_name + ".exr"), exrdict) + except: + if hasattr(self, "exr_exception"): + pass + else: + self.exr_exception = True + log.info("Skipping EXR logging since pyexr is not found.") + write_png(os.path.join(self.valid_log_dir, out_name + ".png"), rb.cpu().image().byte().rgb) psnr_total /= len(imgs) lpips_total /= len(imgs) ssim_total /= len(imgs) - + + metrics_dict = {"psnr": psnr_total, "ssim": ssim_total} + log_text = 'EPOCH {}/{}'.format(self.epoch, self.max_epochs) log_text += ' | {}: {:.2f}'.format(f"{name} PSNR", psnr_total) log_text += ' | {}: {:.6f}'.format(f"{name} SSIM", ssim_total) - log_text += ' | {}: {:.6f}'.format(f"{name} LPIPS", lpips_total) + + if lpips_model: + log_text += ' | {}: {:.6f}'.format(f"{name} LPIPS", lpips_total) + metrics_dict["lpips"] = lpips_total log.info(log_text) - return {"psnr" : psnr_total, "lpips": lpips_total, "ssim": ssim_total} + return {"psnr" : psnr_total, "ssim": ssim_total} def render_final_view(self, num_angles, camera_distance): angles = np.pi * 0.1 * np.array(list(range(num_angles + 1))) @@ -196,12 +207,22 @@ def validate(self): os.makedirs(self.valid_log_dir) lods = list(range(self.pipeline.nef.grid.num_lods)) - evaluation_results = self.evaluate_metrics(data["rays"], imgs, lods[-1], f"lod{lods[-1]}") + try: + from lpips import LPIPS + lpips_model = LPIPS(net='vgg').cuda() + except: + lpips_model = None + if hasattr(self, "lpips_exception"): + pass + else: + self.lpips_exception = True + log.info("Skipping LPIPS since lpips is not found.") + evaluation_results = self.evaluate_metrics(data["rays"], imgs, lods[-1], + f"lod{lods[-1]}", lpips_model=lpips_model) record_dict.update(evaluation_results) if self.using_wandb: - log_metric_to_wandb("Validation/psnr", evaluation_results['psnr'], self.epoch) - log_metric_to_wandb("Validation/lpips", evaluation_results['lpips'], self.epoch) - log_metric_to_wandb("Validation/ssim", evaluation_results['ssim'], self.epoch) + for key in evaluation_results: + log_metric_to_wandb(f"Validation/{key}", evaluation_results[key], self.epoch) df = pd.DataFrame.from_records([record_dict]) df['lod'] = lods[-1] diff --git a/wisp/trainers/sdf_trainer.py b/wisp/trainers/sdf_trainer.py index ef7c929..1b6ee2d 100644 --- a/wisp/trainers/sdf_trainer.py +++ b/wisp/trainers/sdf_trainer.py @@ -109,6 +109,9 @@ def render_tb(self): out_x = self.renderer.sdf_slice(self.pipeline.nef.get_forward_function("sdf"), dim=0) out_y = self.renderer.sdf_slice(self.pipeline.nef.get_forward_function("sdf"), dim=1) out_z = self.renderer.sdf_slice(self.pipeline.nef.get_forward_function("sdf"), dim=2) + out_x = torch.FloatTensor(out_x) + out_y = torch.FloatTensor(out_y) + out_z = torch.FloatTensor(out_z) self.writer.add_image(f'Cross-section/X/{d}', hwc_to_chw(out_x), self.epoch) self.writer.add_image(f'Cross-section/Y/{d}', hwc_to_chw(out_y), self.epoch) self.writer.add_image(f'Cross-section/Z/{d}', hwc_to_chw(out_z), self.epoch)