Skip to content

Commit

Permalink
Add simple util for CUDA timings
Browse files Browse the repository at this point in the history
  • Loading branch information
yang committed Jan 13, 2024
1 parent 90f70ff commit 24199ec
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions megatron/devutil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torch.cuda


class Metric:
"""
Dumb utility to collect and report average wall-time metrics.
"""

def __init__(self, label):
self.label = label
self.measurements = []

def collect(self, measurement):
self.measurements.append(measurement)

def get_measurements(self):
return self.measurements[:]

def report(self):
print(
self.label,
torch.quantile(torch.tensor(self.measurements), torch.arange(10) / 10.0),
)


def monitor_method_cuda_wall_times(metric, obj, methodname):
"""
Measure timings for a method on an object or class.
For instance:
>>> metric = Metric('!LNORM')
>>> monitor_method_wall_times(metric, LayerNorm, 'forward')
"""
oldmeth = getattr(obj, methodname)

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

def newmeth(*args, **kw):
start_event.record()
try:
return oldmeth(*args, **kw)
finally:
end_event.record()
torch.cuda.synchronize()
elapsed = start_event.elapsed_time(end_event)
metric.collect(elapsed)
metric.report()

setattr(obj, methodname, newmeth)

0 comments on commit 24199ec

Please sign in to comment.