This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

fp8 benchmark testing #148

Draft
wants to merge 1 commit into base: main
Changes from all commits
47 changes: 29 additions & 18 deletions benchmarks/profile_linear_float8.py
@@ -10,6 +10,7 @@
 from typing import Callable, Optional

 import fire
+import functools

 import torch
 from float8_experimental.float8_linear_utils import (
@@ -19,6 +20,7 @@
     sync_float8_amax_and_scale_history,
 )
 from torch.profiler import profile, ProfilerActivity, record_function
+from torch._inductor.utils import do_bench_using_profiling


 @dataclass
@@ -73,6 +75,10 @@ def profile_function(
     if config.file_path is None:
         print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

+    full_func = functools.partial(func, *args, **kwargs)
+    latency = do_bench_using_profiling(full_func)
+    print(f"{func=}, {latency=}")
+
     return prof

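Aside (not part of the diff): the new timing call above needs a zero-argument callable, which is why the hunk binds `*args, **kwargs` with `functools.partial` before handing the function to `do_bench_using_profiling`. A minimal standalone sketch of the same pattern, assuming a CUDA device and a PyTorch build that exposes `torch._inductor.utils.do_bench_using_profiling`; the layer, shapes, and dtype below are illustrative:

```python
# Sketch (not part of the PR): time a forward+backward pass the way the hunk above does.
import functools

import torch
from torch._inductor.utils import do_bench_using_profiling  # assumed available in this build

linear = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

def fwd_bwd(inp):
    out = linear(inp)
    out.sum().backward()

# Bind the argument so the benchmarked callable takes no arguments, then measure.
full_func = functools.partial(fwd_bwd, x)
latency_ms = do_bench_using_profiling(full_func)  # latency estimate in milliseconds
print(f"{latency_ms=}")
```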
@@ -134,24 +140,24 @@ def main(profile_path: Path, compile: bool, linear_type: str):

     def ref_forw_backward(x):
         if params.layer_norm:
-            with record_function("layer_norm"):
-                x = ln(x)
-        with record_function("forward"):
-            out = linear_ref(x)
-        with record_function("backward"):
-            out.sum().backward()
+            #with record_function("layer_norm"):
+            x = ln(x)
+        #with record_function("forward"):
+        out = linear_ref(x)
+        #with record_function("backward"):
+        out.sum().backward()

     def float8_forw_backward(x):
         if linear_requires_sync(linear_type):
-            with record_function("scale_amax_and_scales"):
-                sync_float8_amax_and_scale_history(linear_float8)
+            # with record_function("scale_amax_and_scales"):
+            sync_float8_amax_and_scale_history(linear_float8)
         if params.layer_norm:
-            with record_function("layer_norm"):
-                x = ln(x)
-        with record_function("forward"):
-            out = linear_float8(x)
-        with record_function("backward"):
-            out.sum().backward()
+            # with record_function("layer_norm"):
+            x = ln(x)
+        # with record_function("forward"):
+        out = linear_float8(x)
+        # with record_function("backward"):
+        out.sum().backward()

     if transformer_engine_installed:
         # Create an FP8 recipe. Note: All input args are optional.
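Aside (not part of the diff): the context lines above lead into the TransformerEngine baseline, which builds an FP8 recipe and runs its linear layer under fp8_autocast. A rough sketch of that pattern using TransformerEngine's documented API; the recipe arguments, sizes, and dtypes are illustrative, not the values used by this benchmark:

```python
# Sketch (not part of the PR): a TransformerEngine FP8 recipe plus an autocast region.
# Requires a CUDA device with FP8 support and the transformer_engine package.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)
te_linear = te.Linear(4096, 4096, bias=True).to(device="cuda", dtype=torch.bfloat16)
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = te_linear(x)
out.sum().backward()  # backward runs outside the autocast region
```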
@@ -170,15 +176,20 @@ def te_forw_backward(x):
                 out.sum().backward()

     if params.torch_compile:
-        ref_forw_backward = torch.compile(ref_forw_backward)
+        #ref_forw_backward = torch.compile(ref_forw_backward)
         float8_forw_backward = torch.compile(float8_forw_backward)
         # Compiling TE_linear fails but they are already compiling under the hood
         # if transformer_engine_installed:
         # te_forw_backward = torch.compile(te_forw_backward)

+    def wrapper_float8(x):
+        if linear_requires_sync(linear_type):
+            sync_float8_amax_and_scale_history(linear_float8)
+        float8_forw_backward(x)
+
     for _ in range(5):
         ref_forw_backward(input_tensor)
-        float8_forw_backward(input_tensor)
+        wrapper_float8(input_tensor)
         if transformer_engine_installed:
             te_forw_backward(input_tensor)

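Aside (not part of the diff): the new `wrapper_float8` runs the amax/scale-history sync and then the (optionally torch.compile'd) forward/backward, and the short loop above warms each path up before any profiling. A minimal sketch of that compile-then-warm-up shape with a stand-in module; the names, shapes, and the eager pre-step are placeholders, not the PR's float8 layer:

```python
# Sketch (not part of the PR): compile a forward+backward step, keep a small eager
# pre-step in a wrapper, and run a few warmup iterations before measuring. Requires CUDA.
import torch

model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

def forw_backward(inp):
    out = model(inp)
    out.sum().backward()

compiled_forw_backward = torch.compile(forw_backward)

def wrapper(inp):
    # Placeholder for the eager bookkeeping step (the PR calls
    # sync_float8_amax_and_scale_history here).
    _ = inp.abs().max()
    compiled_forw_backward(inp)

for _ in range(5):  # warmup before timing or profiling
    wrapper(x)
torch.cuda.synchronize()
```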
@@ -189,7 +200,7 @@ def te_forw_backward(x):
     )
     profile_function(profile_config, ref_forw_backward, input_tensor)

-    # # Profile Float8 Linear
+    # Profile Float8 Linear
     float8_string = f"linear_float8_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}_compile_{params.torch_compile}_{linear_type}.json"
     profile_config = ProfileConfig(
         str(profile_path / float8_string),
@@ -198,7 +209,7 @@
         warmup_iters=5,
         sync=True,
     )
-    profile_function(profile_config, float8_forw_backward, input_tensor)
+    profile_function(profile_config, wrapper_float8, input_tensor)

     te_string = f"linear_transformer_engine_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}.json"
     if transformer_engine_installed:
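Aside (not part of the diff): `ProfileConfig` and `profile_function` in this script wrap `torch.profiler`. A minimal sketch of that kind of run, labeling regions with `record_function` and exporting a Chrome trace; the module, shapes, iteration counts, and output path are illustrative:

```python
# Sketch (not part of the PR): profile a few forward+backward steps and export a
# Chrome trace, roughly the shape of what profile_function does. Requires CUDA.
import torch
from torch.profiler import profile, ProfilerActivity, record_function

model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

def step(inp):
    with record_function("forward"):
        out = model(inp)
    with record_function("backward"):
        out.sum().backward()

for _ in range(5):  # warmup
    step(x)
torch.cuda.synchronize()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(5):
        step(x)
    torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
prof.export_chrome_trace("linear_profile.json")  # illustrative output path
```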