Add rotating buffer feature to quantize_bench
Summary: On AMD, using a rotating buffer gives benchmarking results closer to E2E runs: cycling through multiple copies of the inputs keeps them from sitting in the L2 cache across iterations, which otherwise makes isolated kernel timings look faster than they are end to end.

Reviewed By: xw285cornell

Differential Revision: D59828276
mxz297 authored and facebook-github-bot committed Jul 16, 2024
1 parent a7ef792 commit baf4045
Showing 2 changed files with 71 additions and 6 deletions.
fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py — 26 changes: 23 additions & 3 deletions
@@ -37,6 +37,7 @@ def benchmark(
     k: int,
     kernels: Optional[List[str]] = None,
     bench_quantize: bool = False,
+    use_rotating_buffer_bench: bool = False,
 ) -> Dict[str, Any]:
     # Create input tensors.
     A = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
@@ -63,10 +64,17 @@
         # Now perform benchmark.
         if bench_quantize:
             # Benchmark both quantize and compute.
-            ms_runtime = quantize_op.benchmark(A, B, bench_quantize=True)
+            ms_runtime = quantize_op.benchmark(
+                A,
+                B,
+                bench_quantize=True,
+                use_rotating_buffer_bench=use_rotating_buffer_bench,
+            )
         else:
             ms_runtime = quantize_op.benchmark(
-                *quantized_vals, bench_quantize=False
+                *quantized_vals,
+                bench_quantize=False,
+                use_rotating_buffer_bench=use_rotating_buffer_bench,
             )
 
         # Print out results for this op.
@@ -137,7 +145,13 @@ def main(args: Any):
     for m, n, k in MNK:
         print(f"Benchmarking M={m}, N={n}, K={k}.")
         quantize_measurements = benchmark(
-            quantize_ops, m, n, k, kernels, args.bench_quantize
+            quantize_ops,
+            m,
+            n,
+            k,
+            kernels,
+            args.bench_quantize,
+            args.use_rotating_buffer_bench,
         )
         benchmark_results.append(quantize_measurements)
     if args.export_csv:
@@ -189,6 +203,12 @@ def invoke_main() -> None:
     parser.add_argument(
         "--K", default=None, help="Comma separated list of K values to benchmark."
     )
+    parser.add_argument(
+        "--use_rotating_buffer_bench",
+        default=False,
+        action="store_true",
+        help="If set, use rotating buffer to benchmark.",
+    )
 
     args = parser.parse_args()
     main(args)
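
With this change, rotating-buffer benchmarking is opt-in per run. As a hypothetical usage sketch (only --K and --use_rotating_buffer_bench appear in this diff; --M, --N, and --bench_quantize are assumed to exist by analogy with the --K argument and args.bench_quantize above):

    python quantize_bench.py --M 4096 --N 4096 --K 4096 --bench_quantize --use_rotating_buffer_bench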
fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py — 51 changes: 48 additions & 3 deletions
@@ -41,12 +41,57 @@ def quantize_and_compute(self, *args):
         """Function which quantizes inputs and performs main compute operation."""
         pass
 
-    def benchmark(self, *args, bench_quantize: bool = False) -> float:
+    def bench_with_rotating_buffer(self, fn, args):
+        import copy
+        import pickle
+
+        # torch.cuda.get_device_properties does not expose the L2 cache size,
+        # so hard-code an overapproximation of it to ensure the L2 cache is flushed.
+        total_buffer_size = 16 * 1024 * 1024
+
+        # Serialize the inputs with pickle to estimate their total size.
+        input_sizes = len(pickle.dumps(args))
+
+        # Make at least one copy of the inputs.
+        copy_cnt = total_buffer_size // input_sizes
+        if copy_cnt == 0:
+            copy_cnt = 1
+
+        args_list = [args]
+        for _ in range(copy_cnt):
+            args_list.append(copy.deepcopy(args))
+
+        def rotating_buffer_fn(fn, args_list, copy_cnt):
+            for i in range(copy_cnt):
+                fn(*(args_list[i]))
+
+        with torch.cuda.stream(torch.cuda.Stream()):
+            # One call to rotating_buffer_fn runs fn once per input copy,
+            # so divide the measured time accordingly.
+            return triton.testing.do_bench_cudagraph(
+                lambda: rotating_buffer_fn(fn, args_list, copy_cnt + 1),
+                rep=500,
+            ) / (copy_cnt + 1)
+
+    def benchmark(
+        self,
+        *args,
+        bench_quantize: bool = False,
+        use_rotating_buffer_bench: bool = False,
+    ) -> float:
         """Benchmark runtime of this operator."""
         if bench_quantize:
-            return triton.testing.do_bench(lambda: self.quantize_and_compute(*args))
+            with torch.cuda.stream(torch.cuda.Stream()):
+                t = triton.testing.do_bench_cudagraph(
+                    lambda: self.quantize_and_compute(*args)
+                )
         else:
-            return triton.testing.do_bench(lambda: self.compute(*args))
+            if use_rotating_buffer_bench:
+                t = self.bench_with_rotating_buffer(self.compute, args)
+            else:
+                with torch.cuda.stream(torch.cuda.Stream()):
+                    t = triton.testing.do_bench_cudagraph(lambda: self.compute(*args))
+        return t
 
     @abc.abstractproperty
     def name(self) -> str:
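
For reference, a minimal standalone sketch of the rotating-buffer technique used above (all names here are illustrative, torch.mm stands in for the real quantized kernels, and the 16 MiB figure is the same L2 overapproximation hard-coded in the diff):

    import copy
    import pickle

    import torch
    import triton.testing


    def bench_rotating(fn, args, l2_bytes=16 * 1024 * 1024, rep=500):
        # Make enough deep copies of the inputs that cycling through them
        # evicts earlier iterations' data from the L2 cache.
        copy_cnt = max(l2_bytes // len(pickle.dumps(args)), 1)
        args_list = [args] + [copy.deepcopy(args) for _ in range(copy_cnt)]

        def run_all():
            for a in args_list:
                fn(*a)

        # do_bench_cudagraph captures a CUDA graph, which needs a side stream.
        with torch.cuda.stream(torch.cuda.Stream()):
            total_ms = triton.testing.do_bench_cudagraph(run_all, rep=rep)
        # run_all calls fn once per copy, so report the per-call time.
        return total_ms / len(args_list)


    a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

    with torch.cuda.stream(torch.cuda.Stream()):
        fixed_ms = triton.testing.do_bench_cudagraph(lambda: torch.mm(a, b), rep=500)
    print(f"fixed: {fixed_ms:.3f} ms  rotating: {bench_rotating(torch.mm, (a, b)):.3f} ms")

Dividing by len(args_list) mirrors the (copy_cnt + 1) division in the diff: each graph replay performs one pass over every copy, so the per-call time is the total divided by the number of copies. When the inputs fit in L2, the rotating figure is typically higher than the fixed-buffer one and, per the summary, closer to what E2E runs observe.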
