Skip to content

Commit

Permalink
torch compile modes
Browse files Browse the repository at this point in the history
  • Loading branch information
tianleiwu committed Sep 30, 2024
1 parent fffc87e commit 65fad4c
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def run_torch(config: TestConfig):
with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=config.dtype, enabled=enabled_auto_cast):
sam2_model = load_sam2_model(config.sam2_dir, config.model_type, device=config.device)
if config.component == "image_encoder":
if is_cuda:
if is_cuda and config.torch_compile_mode != "none":
sam2_model.image_encoder.forward = torch.compile(
sam2_model.image_encoder.forward,
mode=config.torch_compile_mode, # "reduce-overhead" if you want to reduce latency of first run.
Expand All @@ -184,8 +184,9 @@ def run_torch(config: TestConfig):
img = torch.randn(image_shape).to(device=config.device, dtype=config.dtype)
sam2_encoder = SAM2ImageEncoder(sam2_model)

if is_cuda:
if is_cuda and config.torch_compile_mode != "none":
print(f"Running warm up. It will take a while since torch compile mode is {config.torch_compile_mode}.")

for _ in range(config.warm_up):
_image_features_0, _image_features_1, _image_embeddings = sam2_encoder(img)

Expand All @@ -208,6 +209,10 @@ def run_torch(config: TestConfig):
with torch.profiler.record_function("encoder"):
sam2_encoder(img)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
prof.export_chrome_trace("torch_image_encoder.json")

if config.repeats == 0:
return

print(f"Start {config.repeats} runs of performance tests...")
start = time.time()
Expand All @@ -232,6 +237,14 @@ def run_torch(config: TestConfig):
multimask_output=config.multi_mask_output,
)

if is_cuda and config.torch_compile_mode != "none":
sam2_decoder.forward = torch.compile(
sam2_decoder.forward,
mode=config.torch_compile_mode,
fullgraph=True,
dynamic=False,
)

# warm up
for _ in range(config.warm_up):
_masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs)
Expand All @@ -255,6 +268,10 @@ def run_torch(config: TestConfig):
with torch.profiler.record_function("decoder"):
sam2_decoder(*torch_inputs)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
prof.export_chrome_trace("torch_image_decoder.json")

if config.repeats == 0:
return

print(f"Start {config.repeats} runs of performance tests...")
start = time.time()
Expand Down Expand Up @@ -304,6 +321,7 @@ def run_test(
warm_up=args.warm_up,
enable_nvtx_profile=args.enable_nvtx_profile,
enable_torch_profile=args.enable_torch_profile,
torch_compile_mode=args.torch_compile_mode,
verbose=False,
)

Expand Down Expand Up @@ -336,6 +354,9 @@ def run_test(
cudart.cudaProfilerStop()
session.ort_session.end_profiling()

if repeats == 0:
return

latency_list = []
for _ in range(repeats):
latency = measure_latency(session, input_dict)
Expand All @@ -351,6 +372,9 @@ def run_test(
print(f"Failed to run {config=}. Exception: {e}")
return

if repeats == 0:
return

engine = args.engine + ":" + ("cuda" if use_gpu else "cpu")
row = {
"model_type": args.model_type,
Expand All @@ -371,6 +395,7 @@ def run_test(
"warm_up": config.warm_up,
"repeats": repeats,
"enable_nvtx_profile": args.enable_nvtx_profile,
"torch_compile_mode": args.torch_compile_mode,
"engine": engine,
"average_latency": average_latency,
}
Expand Down Expand Up @@ -409,6 +434,7 @@ def run_perf_test(args):
"warm_up",
"repeats",
"enable_nvtx_profile",
"torch_compile_mode",
"engine",
"average_latency",
]
Expand Down Expand Up @@ -564,6 +590,15 @@ def _parse_arguments():
help="path of onnx model",
)

parser.add_argument(
"--torch_compile_mode",
required=False,
type=str,
default=None,
choices=["reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs", "none"],
help="torch compile mode. none will disable torch compile.",
)

args = parser.parse_args()

return args
Expand All @@ -573,6 +608,10 @@ def _parse_arguments():
args = _parse_arguments()
print(f"arguments:{args}")

if args.torch_compile_mode is None:
# image decoder will fail with compile modes other than "none".
args.torch_compile_mode = "max-autotune" if args.component == "image_encoder" else "none"

if args.use_gpu:
assert torch.cuda.is_available()
if args.engine == "ort":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ install_sam2()

cd segment-anything-2

if ! [ pip show SAM-2 2>/dev/null ]; then
if pip show SAM-2 > /dev/null 2>&1; then
echo "SAM-2 is already installed."
else
pip install -e .
fi

Expand Down Expand Up @@ -151,14 +153,19 @@ run_gpu()
$python convert_to_onnx.py --sam2_dir $sam2_dir --optimize --use_gpu --dtype fp16 --demo

echo "Benchmarking SAM2 model $model image encoder for PyTorch ..."
$python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp16
$python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype bf16
$python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp32

# Test different torch compile modes on image encoder (none will disable compile and use eager mode).
for torch_compile_mode in none max-autotune reduce-overhead max-autotune-no-cudagraphs
do
$python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp16 --component image_encoder --torch_compile_mode $torch_compile_mode
done

echo "Benchmarking SAM2 model $model image decoder for PyTorch ..."
$python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp16 --component image_decoder
$python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype bf16 --component image_decoder
$python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp32 --component image_decoder
$python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp16 --component image_decoder

echo "Benchmarking SAM2 model $model image encoder for ORT ..."
$python benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_encoder_fp16_gpu.onnx --use_gpu --dtype fp16
Expand All @@ -173,7 +180,7 @@ run_gpu()
build_onnxruntime_gpu_for_profiling()
{
pushd $install_dir
if ! [ -d $install_dir/onnxruntime ]; then
if ! [ -d onnxruntime ]; then
git clone https://github.com/microsoft/onnxruntime
fi
cd onnxruntime
Expand Down Expand Up @@ -204,7 +211,7 @@ build_onnxruntime_gpu_for_profiling()
}

# Run profiling with NVTX.
run_gpu_profile()
run_nvtx_profile()
{
pip install nvtx cuda-python==12.5.0

Expand Down Expand Up @@ -238,6 +245,18 @@ run_gpu_profile()
done
}

# Run profiling with PyTorch
run_torch_profile()
{
for component in image_encoder image_decoder
do
$python benchmark_sam2.py --model_type $model --engine torch \
--sam2_dir $sam2_dir --warm_up 1 --repeats 0 \
--component $component \
--use_gpu --dtype fp16 --enable_torch_profile
done
}

if ! [ -v CONDA_PREFIX ]; then
echo "Please activate conda environment before running this script."
exit 1
Expand Down Expand Up @@ -278,8 +297,10 @@ if [ "$cpu_or_gpu" = "gpu" ]; then
rm -f *.nsys-rep
rm -f *.sqlite
build_onnxruntime_gpu_for_profiling
run_gpu_profile
run_nvtx_profile
else
echo "sam2_fp16_profile_ort.nsys-rep already exists, skipping GPU profiling..."
fi

run_torch_profile
fi

0 comments on commit 65fad4c

Please sign in to comment.