Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[not4land] repro dynamo performance accuracy problem #2519

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 45 additions & 6 deletions torchbenchmark/util/backends/torchdynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,11 @@ def apply_torchdynamo_args(
),
)

print("{args.quantization=}")
if args.quantization:
import torchao
from torchao.quantization import (
change_linear_weights_to_int4_woqtensors,
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
quantize_, int8_weight_only, int4_weight_only, int8_dynamic_activation_int8_weight
)

torch._dynamo.config.automatic_dynamic_shapes = False
Expand All @@ -196,12 +195,52 @@ def apply_torchdynamo_args(
module, example_inputs = model.get_module()
if args.quantization == "int8dynamic":
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(module)
quantize_(module, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
elif args.quantization == "int8weightonly":
torch._inductor.config.use_mixed_mm = True
change_linear_weights_to_int8_woqtensors(module)
quantize_(module, int8_weight_only(), set_inductor_config=False)
elif args.quantization == "int4weightonly":
change_linear_weights_to_int4_woqtensors(module)
quantize_(module, int4_weight_only(), set_inductor_config=False)
if args.quantization == "autoquant":
print("module:", type(module))

torchao.autoquant(module, example_input=example_inputs, manual=True, error_on_unseen=False, set_inductor_config=False)
# torchao.autoquant(module, error_on_unseen=False, set_inductor_config=False)
if isinstance(example_inputs, dict):
module(**example_inputs)
else:
module(*example_inputs)

module.finalize_autoquant()

# for n, m in model.named_modules():
# if isinstance(m, torch.nn.Linear):
# print(f"name {n}, weight type:, {type(m.weight.data)}")

from torchao.quantization.autoquant import AUTOQUANT_CACHE
assert len(AUTOQUANT_CACHE)>0, f"Err: found no autoquantizable layers in model {type(module)}, stopping autoquantization"

# print("autoquant profile")
# from torchao.utils import benchmark_model, profiler_runner
# model = torch.compile(module, mode="max-autotune")
# inputs = example_inputs
# benchmark_model(model, 20, inputs)
# print("elapsed_time: ", benchmark_model(model, 100, inputs), " milliseconds")
# profiler_runner("quant.json.gz", benchmark_model, model, 5, inputs)

else:
unwrap_tensor_subclass(module)
# else:
# module, example_inputs = model.get_module()
# # noquant profile
# print("noquant profile")
# from torchao.utils import benchmark_model, profiler_runner
# model = torch.compile(module, mode="max-autotune")
# inputs = example_inputs
# benchmark_model(model, 20, inputs)
# print("elapsed_time: ", benchmark_model(model, 100, inputs), " milliseconds")
# profiler_runner("noquant.json.gz", benchmark_model, model, 5, inputs)
>>>>>>> Stashed changes

if args.freeze_prepack_weights:
torch._inductor.config.freezing = True
Expand Down
8 changes: 7 additions & 1 deletion torchbenchmark/util/experiment/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torchbenchmark.util.experiment.instantiator import TorchBenchModelConfig
from torchbenchmark.util.model import BenchmarkModel

WARMUP_ROUNDS = 10
WARMUP_ROUNDS = 20
BENCHMARK_ITERS = 15
MEMPROF_ITER = 2
NANOSECONDS_PER_MILLISECONDS = 1_000_000.0
Expand Down Expand Up @@ -53,6 +53,12 @@ def get_latencies(
func()
t1 = time.time_ns()
result_summary.append((t1 - t0) / NANOSECONDS_PER_MILLISECONDS)

# from torchao.utils import benchmark_model, profiler_runner
# print("device:", device)
# print("elpased:", benchmark_model(func, 100, (), device_type="cuda"))
# profiler_runner("quant.json.gz", benchmark_model, func, 5, (), device_type="cuda")

return result_summary


Expand Down
3 changes: 2 additions & 1 deletion userbenchmark/dynamo/dynamobench/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3018,7 +3018,8 @@ def warmup(fn, model, example_inputs, mode, niters=10):
total = psutil.virtual_memory().total
percentage = psutil.Process(os.getpid()).memory_percent()
peak_mem = percentage * total / 10**9
except Exception:
except Exception as e:
print("exception:", e)
log.exception("Backend %s failed in warmup()", mode)
write_csv_when_exception(
self.args, current_name, "warmup_failed", current_device
Expand Down
71 changes: 69 additions & 2 deletions userbenchmark/dynamo/dynamobench/torchao_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,103 @@ def torchao_optimize_ctx(quantization: str):
quantize_,
)
from torchao.utils import unwrap_tensor_subclass
import torchao

def inner(model_iter_fn: Callable):
def _torchao_apply(module: torch.nn.Module, example_inputs: Any):
if getattr(module, "_quantized", None) is None:
if quantization == "noquant":
if isinstance(example_inputs, dict):
args = ()
kwargs = example_inputs
else:
args = example_inputs
kwargs = {}

print("noquant run")
from torchao.utils import benchmark_model, profiler_runner
model = torch.compile(module, mode="max-autotune")
benchmark_model(model, 20, args, kwargs)
print("elapsed_time: ", benchmark_model(model, 100, args, kwargs), " milliseconds")
# profiler_runner("noquant.json.gz", benchmark_model, model, 5, inputs)

if quantization == "int8dynamic":
quantize_(
module,
int8_dynamic_activation_int8_weight(),
set_inductor_config=False,
)

print("int8dynamic run")
from torchao.utils import benchmark_model, profiler_runner
torchao.quantization.utils.recommended_inductor_config_setter()
model = torch.compile(module, mode="max-autotune")
if isinstance(example_inputs, dict):
args = ()
kwargs = example_inputs
else:
args = example_inputs
kwargs = {}
benchmark_model(model, 20, args, kwargs)
print("elapsed_time: ", benchmark_model(model, 100, args, kwargs), " milliseconds")

elif quantization == "int8weightonly":
quantize_(module, int8_weight_only(), set_inductor_config=False)

print("int8weightonly run")
from torchao.utils import benchmark_model, profiler_runner
torchao.quantization.utils.recommended_inductor_config_setter()
model = torch.compile(module, mode="max-autotune")
if isinstance(example_inputs, dict):
args = ()
kwargs = example_inputs
else:
args = example_inputs
kwargs = {}
benchmark_model(model, 20, args, kwargs)
print("elapsed_time: ", benchmark_model(model, 100, args, kwargs), " milliseconds")

elif quantization == "int4weightonly":
quantize_(module, int4_weight_only(), set_inductor_config=False)
if quantization == "autoquant":
autoquant(module, error_on_unseen=False, set_inductor_config=False)
from torchao.quantization import autoquant_v2

# autoquant(module, example_input=example_inputs, manual=True, error_on_unseen=False, set_inductor_config=True)
print("calling autoquant v2")
autoquant_v2(module, example_input=example_inputs, manual=True, error_on_unseen=False, set_inductor_config=True)
if isinstance(example_inputs, dict):
module(**example_inputs)
else:
module(*example_inputs)
from torchao.quantization.autoquant import AUTOQUANT_CACHE
module.finalize_autoquant()

# from torchao.quantization.autoquant import AUTOQUANT_CACHE
from torchao.quantization.autoquant_v2 import AUTOQUANT_CACHE

if len(AUTOQUANT_CACHE) == 0:
raise Exception( # noqa: TRY002`
"NotAutoquantizable"
f"Found no autoquantizable layers in model {type(module)}, stopping autoquantized run"
)

print("autoquant run")
from torchao.utils import benchmark_model, profiler_runner
torchao.quantization.utils.recommended_inductor_config_setter()
model = torch.compile(module, mode="max-autotune")
if isinstance(example_inputs, dict):
args = ()
kwargs = example_inputs
else:
args = example_inputs
kwargs = {}
benchmark_model(model, 20, args, kwargs)
print("elapsed_time: ", benchmark_model(model, 100, args, kwargs), " milliseconds")
# profiler_runner("quant.json.gz", benchmark_model, model, 5, inputs)
else:
unwrap_tensor_subclass(module)
setattr(module, "_quantized", True) # noqa: B010


model_iter_fn(module, example_inputs)

return _torchao_apply
Expand Down
13 changes: 5 additions & 8 deletions userbenchmark/group_bench/configs/torch_ao.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
model: "*"
extended_models:
- huggingface
- timm
model: "resnet18"
test: eval
device: cuda
extra_args: --precision bf16 --torchdynamo inductor --inductor-compile-mode max-autotune
Expand All @@ -10,7 +7,7 @@ metrics:
test_group:
test_batch_size_default:
subgroup:
- extra_args:
- extra_args: --quantization int8dynamic
- extra_args: --quantization int8weightonly
- extra_args: --quantization int4weightonly
- extra_args: --quantization autoquant
- extra_args: --quantization noquant


6 changes: 4 additions & 2 deletions userbenchmark/torchao/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def _get_ci_args(


def _get_full_ci_args(modelset: str) -> List[List[str]]:
backends = ["autoquant", "int8dynamic", "int8weightonly", "noquant"]
backends = ["autoquant", "noquant"]
modelset = [modelset]
dtype = ["bfloat16"]
mode = ["inference"]
device = ["cuda"]
experiment = ["performance", "accuracy"]
experiment = ["performance"]
cfgs = itertools.product(*[backends, modelset, dtype, mode, device, experiment])
return [_get_ci_args(*cfg) for cfg in cfgs]

Expand Down Expand Up @@ -92,6 +92,8 @@ def run(args: List[str]):
else:
benchmark_args = [pt2_args]

print("benchmark args:", benchmark_args)

output_files = [_run_pt2_args(args) for args in benchmark_args]
# Post-processing
if args.dashboard:
Expand Down
Loading