Skip to content

Commit

Permalink
Add NVTX markers for profiling (#110)
Browse files (browse the repository at this point in the history)
Signed-off-by: Towaki Takikawa <[email protected]>

---------

Signed-off-by: Towaki Takikawa <[email protected]>
  • Loading branch information
tovacinni authored Jan 31, 2023
1 parent 750a632 commit 170b84e
Show file tree
Hide file tree
Showing 12 changed files with 28 additions and 29 deletions.
4 changes: 3 additions & 1 deletion app/nerf/main_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def parse_args():
parser = argparse.ArgumentParser(description='A script for training simple NeRF variants.')
parser.add_argument('--config', type=str,
help='Path to config file to replace defaults.')
parser.add_argument('--profile', action='store_true',
help='Enable NVTX profiling')

log_group = parser.add_argument_group('logging')
log_group.add_argument('--exp-name', type=str,
Expand Down Expand Up @@ -489,4 +491,4 @@ def is_interactive() -> bool:
if args.valid_only:
trainer.validate()
else:
trainer.train() # Run in headless mode
trainer.train()
3 changes: 2 additions & 1 deletion app/nglod/main_nglod.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def parse_args():
'Implicit 3D Shapes".')
parser.add_argument('--config', type=str,
help='Path to config file to replace defaults.')

parser.add_argument('--profile', action='store_true',
help='Enable NVTX profiling')
log_group = parser.add_argument_group('logging')
log_group.add_argument('--exp-name', type=str,
help='Experiment name, unique id for trainers, logs.')
Expand Down
3 changes: 2 additions & 1 deletion examples/latent_nerf/main_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@
valid_every=-1,
save_as_new=False,
model_format='full',
mip=0
mip=0,
profile=False
),
render_tb_every=100,
save_every=100,
Expand Down
1 change: 1 addition & 0 deletions wisp/datasets/transforms/ray_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class SampleRays:
def __init__(self, num_samples):
self.num_samples = num_samples

@torch.cuda.nvtx.range("SampleRays")
def __call__(self, inputs):
ray_idx = torch.randint(0, inputs['imgs'].shape[0], [self.num_samples],
device=inputs['imgs'].device)
Expand Down
4 changes: 3 additions & 1 deletion wisp/models/nefs/base_nef.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import inspect
from abc import abstractmethod
from typing import Dict, Any
import torch
from wisp.core import WispModule


Expand Down Expand Up @@ -135,7 +136,7 @@ def forward(self, channels=None, **kwargs):

return_dict = {}
for fn in self._forward_functions:

torch.cuda.nvtx.range_push(f"{fn.__name__}")
output_channels = self._forward_functions[fn]
# Filter the set of channels supported by the current forward function
supported_channels = output_channels & requested_channels
Expand All @@ -161,6 +162,7 @@ def forward(self, channels=None, **kwargs):

for channel in supported_channels:
return_dict[channel] = output[channel]
torch.cuda.nvtx.range_pop()

if isinstance(channels, str):
if channels in return_dict:
Expand Down
2 changes: 0 additions & 2 deletions wisp/offline_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import torch
import torch.nn.functional as F
from wisp.core import RenderBuffer, Rays
from wisp.utils import PsDebugger, PerfTimer
from wisp.ops.shaders import matcap_shader, pointlight_shadow_shader
from wisp.ops.differential import finitediff_gradient
from wisp.ops.geometric import normalized_grid, normalized_slice
Expand Down Expand Up @@ -174,7 +173,6 @@ def render(self, pipeline, rays, lod_idx=None):
(wisp.core.RenderBuffer): The renderer image.
"""
# Differentiable Renderer
timer = PerfTimer(activate=self.perf)
if self.perf:
_time = time.time()

Expand Down
6 changes: 4 additions & 2 deletions wisp/tracers/base_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from abc import abstractmethod, ABC
from typing import Dict, Any
import inspect
import torch
import torch.nn as nn
from wisp.core import Rays
from wisp.core import WispModule


class BaseTracer(WispModule, ABC):
"""Base class for all tracers within Wisp.
Tracers drive the mapping process which takes an input "Neural Field", and outputs a RenderBuffer of pixels.
Expand Down Expand Up @@ -150,7 +150,9 @@ def forward(self, nef, rays: Rays, channels=None, **kwargs):
default_arg = getattr(self, _arg, None)
if default_arg is not None:
input_args[_arg] = default_arg
return self.trace(nef, rays, requested_channels, requested_extra_channels, **input_args)
with torch.cuda.nvtx.range("Tracer.trace"):
rb = self.trace(nef, rays, requested_channels, requested_extra_channels, **input_args)
return rb

def public_properties(self) -> Dict[str, Any]:
""" Wisp modules expose their public properties in a dictionary.
Expand Down
6 changes: 0 additions & 6 deletions wisp/tracers/packed_sdf_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import torch.nn as nn
import kaolin.render.spc as spc_render
from wisp.core import RenderBuffer
from wisp.utils import PsDebugger, PerfTimer
from wisp.ops.differential import finitediff_gradient
from wisp.ops.geometric import find_depth_bound
from wisp.tracers import BaseTracer
Expand Down Expand Up @@ -77,7 +76,6 @@ def trace(self, nef, rays, channels, extra_channels, lod_idx=None, num_steps=64,
if lod_idx is None:
lod_idx = nef.grid.num_lods - 1

timer = PerfTimer(activate=False)
invres = 1.0

# Trace SPC
Expand All @@ -104,15 +102,13 @@ def trace(self, nef, rays, channels, extra_channels, lod_idx=None, num_steps=64,

curr_pidx = pidx[first_hit].long()

timer.check("initial")
# Doing things with where is not super efficient, but we have to make do with what we have...
with torch.no_grad():

# Calculate SDF for current set of query points
dist[mask] = nef(coords=x[mask], lod_idx=lod_idx, pidx=curr_pidx[mask], channels="sdf") * invres * step_size
dist[~mask] = 20
dist_prev = dist.clone()
timer.check("first")

for i in range(num_steps):
# Two-stage Ray Marching
Expand Down Expand Up @@ -142,7 +138,6 @@ def trace(self, nef, rays, channels, extra_channels, lod_idx=None, num_steps=64,
if not mask.any():
break
dist[mask] = nef(coords=x[mask], lod_idx=lod_idx, pidx=curr_pidx[mask], channels="sdf") * invres * step_size
timer.check("step done")

x_buffer = torch.zeros_like(rays.origins)
depth_buffer = torch.zeros_like(rays.origins[...,0:1])
Expand All @@ -169,6 +164,5 @@ def trace(self, nef, rays, channels, extra_channels, lod_idx=None, num_steps=64,
rgb_buffer[..., :3] = (normal_buffer + 1.0) / 2.0

alpha_buffer[hit_buffer] = 1.0
timer.check("populate buffers")
return RenderBuffer(xyz=x_buffer, depth=depth_buffer, hit=hit_buffer, normal=normal_buffer,
rgb=rgb_buffer, alpha=alpha_buffer, **extra_outputs)
6 changes: 0 additions & 6 deletions wisp/tracers/packed_spc_tracer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
import kaolin.render.spc as spc_render
from wisp.utils import PerfTimer
from wisp.tracers import BaseTracer
from wisp.core import RenderBuffer

Expand Down Expand Up @@ -48,7 +47,6 @@ def trace(self, nef, rays, channels, extra_channels, lod_idx=None):
Returns:
(wisp.RenderBuffer): A dataclass which holds the output buffers from the render.
"""
timer = PerfTimer(activate=False, show_memory=False)
N = rays.origins.shape[0]

# By default, SPCRFTracer will use the highest level of detail for the ray sampling.
Expand All @@ -60,7 +58,6 @@ def trace(self, nef, rays, channels, extra_channels, lod_idx=None):
pidx = raytrace_results.pidx
depths = raytrace_results.depth

timer.check("Raytrace")

# Get the indices of the ray tensor which correspond to hits
first_hits_mask = spc_render.mark_pack_boundaries(ridx)
Expand All @@ -71,7 +68,6 @@ def trace(self, nef, rays, channels, extra_channels, lod_idx=None):
# Get the color for each ray
color = nef(ridx_hit=first_hits_point.long(), channels="rgb")

timer.check("RGBA")
del ridx, pidx, rays

# Fetch colors and depth for closest hits
Expand All @@ -91,6 +87,4 @@ def trace(self, nef, rays, channels, extra_channels, lod_idx=None):
rgb[first_hits_ray.long(), :3] = color
out_alpha[first_hits_ray.long()] = alpha

timer.check("Composit")

return RenderBuffer(depth=depth, hit=hit, rgb=rgb, alpha=out_alpha)
7 changes: 4 additions & 3 deletions wisp/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,10 @@ def train(self):
"""
Override this if some very specific training procedure is needed.
"""
self.is_optimization_running = True
while self.is_optimization_running:
self.iterate()
with torch.autograd.profiler.emit_nvtx(enabled=self.extra_args["profile"]):
self.is_optimization_running = True
while self.is_optimization_running:
self.iterate()

#######################
# Training Events
Expand Down
8 changes: 5 additions & 3 deletions wisp/trainers/multiview_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def init_log_dict(self):
super().init_log_dict()
self.log_dict['rgb_loss'] = 0.0

@torch.cuda.nvtx.range("MultiviewTrainer.step")
def step(self, data):
"""Implement the optimization over image-space loss.
"""
Expand Down Expand Up @@ -74,9 +75,10 @@ def step(self, data):

self.log_dict['total_loss'] += loss.item()

self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()
with torch.cuda.nvtx.range("MultiviewTrainer.backward"):
self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()

def log_cli(self):
log_text = 'EPOCH {}/{}'.format(self.epoch, self.max_epochs)
Expand Down
7 changes: 4 additions & 3 deletions wisp/trainers/sdf_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from wisp.trainers import BaseTrainer, log_metric_to_wandb, log_images_to_wandb
from torch.utils.data import DataLoader
from wisp.utils import PerfTimer
from wisp.datasets import SDFDataset
from wisp.ops.sdf import compute_sdf_iou
from wisp.ops.image import hwc_to_chw
Expand All @@ -35,6 +34,7 @@ def init_log_dict(self):
self.log_dict['rgb_loss'] = 0
self.log_dict['l2_loss'] = 0

@torch.cuda.nvtx.range("SDFTrainer.step")
def step(self, data):
"""Implement training from ground truth TSDF.
"""
Expand Down Expand Up @@ -88,8 +88,9 @@ def step(self, data):
self.log_dict['total_loss'] += loss.item()

# Backpropagate
loss.backward()
self.optimizer.step()
with torch.cuda.nvtx.range("SDFTrainer.backward"):
loss.backward()
self.optimizer.step()

def log_cli(self):
"""Override logging.
Expand Down

0 comments on commit 170b84e

Please sign in to comment.