Skip to content

Commit

Permalink
Update SAM2 benchmark for testing torch compile modes and profiling (m…
Browse files Browse the repository at this point in the history
…icrosoft#22279)

This pull request introduces several enhancements to the benchmarking
process for the SAM2 model, including:
(1) Add profiling capabilities.
(2) test torch compile modes (none will disable compile and fallback to
eager mode)
(3) Update README for setting up the environment.

### Documentation Updates:
* README.md: Updated instructions to create separate conda environments
for GPU and CPU benchmarking, and detailed the parameters and outputs of
the benchmark script.

### Benchmark Script Enhancements:
* benchmark_sam2.py: Added optional parameters for enabling NVTX and
PyTorch profiling, and adjusted the initialization and execution flow to
incorporate these profiling options.

These changes enhance the flexibility and functionality of the
benchmarking process, making it easier to profile and benchmark the SAM2
model on different hardware configurations.
  • Loading branch information
tianleiwu authored and Ishwar Raut committed Nov 19, 2024
1 parent 7b45d8d commit c972eb4
Show file tree
Hide file tree
Showing 7 changed files with 502 additions and 51 deletions.
25 changes: 21 additions & 4 deletions onnxruntime/python/tools/transformers/models/sam2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,30 @@ python3 convert_to_onnx.py --sam2_dir $sam2_dir --optimize --dtype fp16 --use_g
```

## Benchmark
To prepare an environment for benchmark, follow [Setup Environment](#setup-environment) and [Download Checkpoints](#download-checkpoints).

Run the benchmark like the following:
We can create a conda environment then run GPU benchmark like the following:
```bash
sh benchmark_sam2.sh
conda create -n sam2_gpu python=3.11 -y
conda activate sam2_gpu
bash benchmark_sam2.sh $HOME gpu
```
The result is in sam2.csv, which can be loaded into Excel.

or create a new conda environment for CPU benchmark:
```bash
conda create -n sam2_cpu python=3.11 -y
conda activate sam2_cpu
bash benchmark_sam2.sh $HOME cpu
```

The first parameter is a directory to clone git repositories or install CUDA/cuDNN for benchmark.
The second parameter can be either "gpu" or "cpu", which indicates the device to run benchmark.

The script will automatically install required packages in current conda environment, download checkpoints, export onnx,
and run demo, benchmark and profiling.

* The performance test result is in sam2_gpu.csv or sam2_cpu.csv, which can be loaded into Excel.
* The demo output is sam2_demo_fp16_gpu.png or sam2_demo_fp32_cpu.png.
* The profiling results are in *.nsys-rep or *.json files in current directory.

## Limitations
- The exported image_decoder model does not support batch mode for now.
157 changes: 144 additions & 13 deletions onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import statistics
import time
from datetime import datetime
from typing import List, Mapping
from typing import List, Mapping, Optional

import torch
from image_decoder import SAM2ImageDecoder
Expand Down Expand Up @@ -45,6 +45,8 @@ def __init__(
dtype=torch.float32,
prefer_nhwc: bool = False,
warm_up: int = 5,
enable_nvtx_profile: bool = False,
enable_torch_profile: bool = False,
repeats: int = 1000,
verbose: bool = False,
):
Expand All @@ -70,7 +72,9 @@ def __init__(
self.enable_cuda_graph = enable_cuda_graph
self.dtype = dtype
self.prefer_nhwc = prefer_nhwc
self.warm_up = 5
self.warm_up = warm_up
self.enable_nvtx_profile = enable_nvtx_profile
self.enable_torch_profile = enable_torch_profile
self.repeats = repeats
self.verbose = verbose

Expand Down Expand Up @@ -168,7 +172,7 @@ def run_torch(config: TestConfig):
with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=config.dtype, enabled=enabled_auto_cast):
sam2_model = load_sam2_model(config.sam2_dir, config.model_type, device=config.device)
if config.component == "image_encoder":
if is_cuda:
if is_cuda and config.torch_compile_mode != "none":
sam2_model.image_encoder.forward = torch.compile(
sam2_model.image_encoder.forward,
mode=config.torch_compile_mode, # "reduce-overhead" if you want to reduce latency of first run.
Expand All @@ -180,11 +184,36 @@ def run_torch(config: TestConfig):
img = torch.randn(image_shape).to(device=config.device, dtype=config.dtype)
sam2_encoder = SAM2ImageEncoder(sam2_model)

if is_cuda:
if is_cuda and config.torch_compile_mode != "none":
print(f"Running warm up. It will take a while since torch compile mode is {config.torch_compile_mode}.")

for _ in range(config.warm_up):
_image_features_0, _image_features_1, _image_embeddings = sam2_encoder(img)

if is_cuda and config.enable_nvtx_profile:
import nvtx
from cuda import cudart

cudart.cudaProfilerStart()
print("Start nvtx profiling on encoder ...")
with nvtx.annotate("one_run"):
sam2_encoder(img, enable_nvtx_profile=True)
cudart.cudaProfilerStop()

if is_cuda and config.enable_torch_profile:
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
record_shapes=True,
) as prof:
print("Start torch profiling on encoder ...")
with torch.profiler.record_function("encoder"):
sam2_encoder(img)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
prof.export_chrome_trace("torch_image_encoder.json")

if config.repeats == 0:
return

print(f"Start {config.repeats} runs of performance tests...")
start = time.time()
for _ in range(config.repeats):
Expand All @@ -203,25 +232,61 @@ def run_torch(config: TestConfig):
ort_inputs["original_image_size"],
)

sam2_decoder = SAM2ImageDecoder(sam2_model, multimask_output=config.multi_mask_output)
sam2_decoder = SAM2ImageDecoder(
sam2_model,
multimask_output=config.multi_mask_output,
)

if is_cuda and config.torch_compile_mode != "none":
sam2_decoder.forward = torch.compile(
sam2_decoder.forward,
mode=config.torch_compile_mode,
fullgraph=True,
dynamic=False,
)

# warm up
for _ in range(3):
for _ in range(config.warm_up):
_masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs)

if is_cuda and config.enable_nvtx_profile:
import nvtx
from cuda import cudart

cudart.cudaProfilerStart()
print("Start nvtx profiling on decoder...")
with nvtx.annotate("one_run"):
sam2_decoder(*torch_inputs, enable_nvtx_profile=True)
cudart.cudaProfilerStop()

if is_cuda and config.enable_torch_profile:
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
record_shapes=True,
) as prof:
print("Start torch profiling on decoder ...")
with torch.profiler.record_function("decoder"):
sam2_decoder(*torch_inputs)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
prof.export_chrome_trace("torch_image_decoder.json")

if config.repeats == 0:
return

print(f"Start {config.repeats} runs of performance tests...")
start = time.time()
for _ in range(config.repeats):
_masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs)
if is_cuda:
torch.cuda.synchronize()

end = time.time()
return (end - start) / config.repeats


def run_test(
csv_writer: csv.DictWriter,
args: argparse.Namespace,
csv_writer: Optional[csv.DictWriter] = None,
):
use_gpu: bool = args.use_gpu
enable_cuda_graph: bool = args.use_cuda_graph
Expand Down Expand Up @@ -254,12 +319,19 @@ def run_test(
prefer_nhwc=args.prefer_nhwc,
repeats=args.repeats,
warm_up=args.warm_up,
enable_nvtx_profile=args.enable_nvtx_profile,
enable_torch_profile=args.enable_torch_profile,
torch_compile_mode=args.torch_compile_mode,
verbose=False,
)

if args.engine == "ort":
sess_options = SessionOptions()
sess_options.intra_op_num_threads = args.intra_op_num_threads
if config.enable_nvtx_profile:
sess_options.enable_profiling = True
sess_options.log_severity_level = 4
sess_options.log_verbosity_level = 0

session = create_session(config, sess_options)
input_dict = config.random_inputs()
Expand All @@ -272,6 +344,19 @@ def run_test(
print(f"Failed to run {config=}. Exception: {e}")
return

if config.enable_nvtx_profile:
import nvtx
from cuda import cudart

cudart.cudaProfilerStart()
with nvtx.annotate("one_run"):
_ = session.infer(input_dict)
cudart.cudaProfilerStop()
session.ort_session.end_profiling()

if repeats == 0:
return

latency_list = []
for _ in range(repeats):
latency = measure_latency(session, input_dict)
Expand All @@ -287,6 +372,9 @@ def run_test(
print(f"Failed to run {config=}. Exception: {e}")
return

if repeats == 0:
return

engine = args.engine + ":" + ("cuda" if use_gpu else "cpu")
row = {
"model_type": args.model_type,
Expand All @@ -304,18 +392,22 @@ def run_test(
"num_points": config.num_points,
"num_masks": config.num_masks,
"intra_op_num_threads": args.intra_op_num_threads,
"engine": engine,
"warm_up": config.warm_up,
"repeats": repeats,
"enable_nvtx_profile": args.enable_nvtx_profile,
"torch_compile_mode": args.torch_compile_mode,
"engine": engine,
"average_latency": average_latency,
}
csv_writer.writerow(row)

if csv_writer is not None:
csv_writer.writerow(row)

print(f"{vars(config)}")
print(f"{row}")


def run_tests(args):
def run_perf_test(args):
features = "gpu" if args.use_gpu else "cpu"
csv_filename = "benchmark_sam_{}_{}_{}.csv".format(
features,
Expand All @@ -339,15 +431,17 @@ def run_tests(args):
"num_points",
"num_masks",
"intra_op_num_threads",
"engine",
"warm_up",
"repeats",
"enable_nvtx_profile",
"torch_compile_mode",
"engine",
"average_latency",
]
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
csv_writer.writeheader()

run_test(csv_writer, args)
run_test(args, csv_writer)


def _parse_arguments():
Expand Down Expand Up @@ -455,6 +549,22 @@ def _parse_arguments():
help="Use prefer_nhwc=1 provider option for CUDAExecutionProvider",
)

parser.add_argument(
"--enable_nvtx_profile",
required=False,
default=False,
action="store_true",
help="Enable nvtx profiling. It will add an extra run for profiling before performance test.",
)

parser.add_argument(
"--enable_torch_profile",
required=False,
default=False,
action="store_true",
help="Enable PyTorch profiling. It will add an extra run for profiling before performance test.",
)

parser.add_argument(
"--model_type",
required=False,
Expand All @@ -480,6 +590,15 @@ def _parse_arguments():
help="path of onnx model",
)

parser.add_argument(
"--torch_compile_mode",
required=False,
type=str,
default=None,
choices=["reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs", "none"],
help="torch compile mode. none will disable torch compile.",
)

args = parser.parse_args()

return args
Expand All @@ -489,9 +608,21 @@ def _parse_arguments():
args = _parse_arguments()
print(f"arguments:{args}")

if args.torch_compile_mode is None:
# image decoder will fail with compile modes other than "none".
args.torch_compile_mode = "max-autotune" if args.component == "image_encoder" else "none"

if args.use_gpu:
assert torch.cuda.is_available()
if args.engine == "ort":
assert "CUDAExecutionProvider" in get_available_providers()
args.enable_torch_profile = False
else:
# Only support cuda profiling for now.
assert not args.enable_nvtx_profile
assert not args.enable_torch_profile

run_tests(args)
if args.enable_nvtx_profile or args.enable_torch_profile:
run_test(args)
else:
run_perf_test(args)
Loading

0 comments on commit c972eb4

Please sign in to comment.