# Origin: commit 24199eceaf282896fb5dd9af6236f689e2f91d0c
# Author: Yang Zhang
# Date: Sun, 24 Dec 2023 07:37:07 +0000
# Subject: [PATCH] Add simple util for CUDA timings
# File: megatron/devutil.py (new file)

import functools

import torch.cuda


class Metric:
    """
    Dumb utility to collect timing measurements and report their quantiles.

    Measurements are kept in insertion order in ``self.measurements``;
    ``report`` prints the label followed by the 0.0, 0.1, ..., 0.9 quantiles
    of everything collected so far.
    """

    def __init__(self, label):
        """Create an empty metric whose reports are prefixed with *label*."""
        self.label = label
        self.measurements = []

    def collect(self, measurement):
        """Append a single numeric measurement."""
        self.measurements.append(measurement)

    def get_measurements(self):
        """Return a shallow copy of the collected measurements."""
        return self.measurements[:]

    def report(self):
        """
        Print the label and the 0.0-0.9 quantiles of the measurements.

        With nothing collected yet, a placeholder line is printed instead,
        because ``torch.quantile`` rejects empty input.
        """
        if not self.measurements:
            print(self.label, "no measurements")
            return
        # torch.quantile only accepts floating-point input, so force the
        # dtype; otherwise integer measurements would raise.
        samples = torch.tensor(self.measurements, dtype=torch.float32)
        fractions = torch.arange(10, dtype=torch.float32) / 10.0
        print(
            self.label,
            torch.quantile(samples, fractions),
        )


def monitor_method_cuda_wall_times(metric, obj, methodname):
    """
    Measure timings for a method on an object or class.

    The attribute ``methodname`` on ``obj`` is replaced by a wrapper that
    brackets every call with CUDA events, synchronizes, stores the elapsed
    time reported by ``Event.elapsed_time`` in ``metric`` and then calls
    ``metric.report()``. The wrapped method's return value or exception is
    passed through unchanged; a measurement is recorded in both cases.

    For instance:

    >>> metric = Metric('!LNORM')
    >>> monitor_method_cuda_wall_times(metric, LayerNorm, 'forward')
    """
    oldmeth = getattr(obj, methodname)

    @functools.wraps(oldmeth)
    def newmeth(*args, **kw):
        # Fresh events for every invocation: a shared pair would be
        # re-recorded by nested or recursive calls of the monitored method,
        # corrupting the outer call's timing.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        try:
            return oldmeth(*args, **kw)
        finally:
            end_event.record()
            # Both events must have completed before elapsed_time is valid.
            torch.cuda.synchronize()
            elapsed = start_event.elapsed_time(end_event)
            metric.collect(elapsed)
            metric.report()

    setattr(obj, methodname, newmeth)