diff --git a/onnxruntime/python/tools/transformers/models/sam2/README.md b/onnxruntime/python/tools/transformers/models/sam2/README.md
index b0d35ac79f2fa..26385896aa49b 100644
--- a/onnxruntime/python/tools/transformers/models/sam2/README.md
+++ b/onnxruntime/python/tools/transformers/models/sam2/README.md
@@ -90,13 +90,30 @@ python3 convert_to_onnx.py --sam2_dir $sam2_dir --optimize --dtype fp16 --use_g
 ```
 
 ## Benchmark
-To prepare an environment for benchmark, follow [Setup Environment](#setup-environment) and [Download Checkpoints](#download-checkpoints).
-Run the benchmark like the following:
+We can create a conda environment, then run the GPU benchmark like the following:
 ```bash
-sh benchmark_sam2.sh
+conda create -n sam2_gpu python=3.11 -y
+conda activate sam2_gpu
+bash benchmark_sam2.sh $HOME gpu
 ```
-The result is in sam2.csv, which can be loaded into Excel.
+
+or create a new conda environment for the CPU benchmark:
+```bash
+conda create -n sam2_cpu python=3.11 -y
+conda activate sam2_cpu
+bash benchmark_sam2.sh $HOME cpu
+```
+
+The first parameter is the directory used to clone git repositories and to install CUDA/cuDNN for the benchmark.
+The second parameter is either "gpu" or "cpu", and indicates the device to run the benchmark on.
+
+The script will automatically install the required packages in the current conda environment, download checkpoints, export ONNX models,
+and run the demo, benchmark and profiling.
+
+* The performance test results are saved in sam2_gpu.csv or sam2_cpu.csv, which can be loaded into Excel.
+* The demo output is sam2_demo_fp16_gpu.png or sam2_demo_fp32_cpu.png.
+* The profiling results are in *.nsys-rep or *.json files in the current directory.
 
 ## Limitations
 - The exported image_decoder model does not support batch mode for now.
diff --git a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py
index fefb7d308af7e..7e108b1638546 100644
--- a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py
+++ b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py
@@ -12,7 +12,7 @@
 import statistics
 import time
 from datetime import datetime
-from typing import List, Mapping
+from typing import List, Mapping, Optional
 
 import torch
 from image_decoder import SAM2ImageDecoder
@@ -45,6 +45,8 @@ def __init__(
         dtype=torch.float32,
         prefer_nhwc: bool = False,
         warm_up: int = 5,
+        enable_nvtx_profile: bool = False,
+        enable_torch_profile: bool = False,
         repeats: int = 1000,
         verbose: bool = False,
     ):
@@ -70,7 +72,9 @@ def __init__(
         self.enable_cuda_graph = enable_cuda_graph
         self.dtype = dtype
         self.prefer_nhwc = prefer_nhwc
-        self.warm_up = 5
+        self.warm_up = warm_up
+        self.enable_nvtx_profile = enable_nvtx_profile
+        self.enable_torch_profile = enable_torch_profile
         self.repeats = repeats
         self.verbose = verbose
 
@@ -168,7 +172,7 @@ def run_torch(config: TestConfig):
     with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=config.dtype, enabled=enabled_auto_cast):
         sam2_model = load_sam2_model(config.sam2_dir, config.model_type, device=config.device)
         if config.component == "image_encoder":
-            if is_cuda:
+            if is_cuda and config.torch_compile_mode != "none":
                 sam2_model.image_encoder.forward = torch.compile(
                     sam2_model.image_encoder.forward,
                     mode=config.torch_compile_mode,  # Use "reduce-overhead" to reduce latency of the first run.
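The hunk above makes `torch.compile` opt-in: the encoder is compiled only when the new `torch_compile_mode` setting is not `"none"`. A minimal sketch of that gating pattern (the `maybe_compile` helper is hypothetical, not part of this PR):

```python
import torch


def maybe_compile(forward_fn, torch_compile_mode: str, is_cuda: bool):
    """Apply torch.compile only on CUDA and only when a compile mode is requested."""
    if is_cuda and torch_compile_mode != "none":
        # Same flags as the decoder path in this PR: compile the full graph
        # with static shapes so compiled kernels are reused across runs.
        return torch.compile(forward_fn, mode=torch_compile_mode, fullgraph=True, dynamic=False)
    # "none" keeps eager execution.
    return forward_fn
```

Compiled runs still rely on the warm-up loop in the next hunk, since the first calls pay the compilation cost.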
@@ -180,11 +184,36 @@ def run_torch(config: TestConfig): img = torch.randn(image_shape).to(device=config.device, dtype=config.dtype) sam2_encoder = SAM2ImageEncoder(sam2_model) - if is_cuda: + if is_cuda and config.torch_compile_mode != "none": print(f"Running warm up. It will take a while since torch compile mode is {config.torch_compile_mode}.") + for _ in range(config.warm_up): _image_features_0, _image_features_1, _image_embeddings = sam2_encoder(img) + if is_cuda and config.enable_nvtx_profile: + import nvtx + from cuda import cudart + + cudart.cudaProfilerStart() + print("Start nvtx profiling on encoder ...") + with nvtx.annotate("one_run"): + sam2_encoder(img, enable_nvtx_profile=True) + cudart.cudaProfilerStop() + + if is_cuda and config.enable_torch_profile: + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + record_shapes=True, + ) as prof: + print("Start torch profiling on encoder ...") + with torch.profiler.record_function("encoder"): + sam2_encoder(img) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + prof.export_chrome_trace("torch_image_encoder.json") + + if config.repeats == 0: + return + print(f"Start {config.repeats} runs of performance tests...") start = time.time() for _ in range(config.repeats): @@ -203,25 +232,61 @@ def run_torch(config: TestConfig): ort_inputs["original_image_size"], ) - sam2_decoder = SAM2ImageDecoder(sam2_model, multimask_output=config.multi_mask_output) + sam2_decoder = SAM2ImageDecoder( + sam2_model, + multimask_output=config.multi_mask_output, + ) + + if is_cuda and config.torch_compile_mode != "none": + sam2_decoder.forward = torch.compile( + sam2_decoder.forward, + mode=config.torch_compile_mode, + fullgraph=True, + dynamic=False, + ) # warm up - for _ in range(3): + for _ in range(config.warm_up): _masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs) + if is_cuda and config.enable_nvtx_profile: + import nvtx + from cuda import cudart + + cudart.cudaProfilerStart() + print("Start nvtx profiling on decoder...") + with nvtx.annotate("one_run"): + sam2_decoder(*torch_inputs, enable_nvtx_profile=True) + cudart.cudaProfilerStop() + + if is_cuda and config.enable_torch_profile: + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + record_shapes=True, + ) as prof: + print("Start torch profiling on decoder ...") + with torch.profiler.record_function("decoder"): + sam2_decoder(*torch_inputs) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + prof.export_chrome_trace("torch_image_decoder.json") + + if config.repeats == 0: + return + print(f"Start {config.repeats} runs of performance tests...") start = time.time() for _ in range(config.repeats): _masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs) if is_cuda: torch.cuda.synchronize() + end = time.time() return (end - start) / config.repeats def run_test( - csv_writer: csv.DictWriter, args: argparse.Namespace, + csv_writer: Optional[csv.DictWriter] = None, ): use_gpu: bool = args.use_gpu enable_cuda_graph: bool = args.use_cuda_graph @@ -254,12 +319,19 @@ def run_test( prefer_nhwc=args.prefer_nhwc, repeats=args.repeats, warm_up=args.warm_up, + enable_nvtx_profile=args.enable_nvtx_profile, + enable_torch_profile=args.enable_torch_profile, + torch_compile_mode=args.torch_compile_mode, verbose=False, ) if args.engine == "ort": sess_options = SessionOptions() 
         sess_options.intra_op_num_threads = args.intra_op_num_threads
+        if config.enable_nvtx_profile:
+            sess_options.enable_profiling = True
+            sess_options.log_severity_level = 4
+            sess_options.log_verbosity_level = 0
         session = create_session(config, sess_options)
         input_dict = config.random_inputs()
@@ -272,6 +344,19 @@ def run_test(
             print(f"Failed to run {config=}. Exception: {e}")
             return
 
+        if config.enable_nvtx_profile:
+            import nvtx
+            from cuda import cudart
+
+            cudart.cudaProfilerStart()
+            with nvtx.annotate("one_run"):
+                _ = session.infer(input_dict)
+            cudart.cudaProfilerStop()
+            session.ort_session.end_profiling()
+
+        if repeats == 0:
+            return
+
         latency_list = []
         for _ in range(repeats):
             latency = measure_latency(session, input_dict)
@@ -287,6 +372,9 @@ def run_test(
             print(f"Failed to run {config=}. Exception: {e}")
             return
 
+        if repeats == 0:
+            return
+
     engine = args.engine + ":" + ("cuda" if use_gpu else "cpu")
     row = {
         "model_type": args.model_type,
@@ -304,18 +392,22 @@ def run_test(
         "num_points": config.num_points,
         "num_masks": config.num_masks,
         "intra_op_num_threads": args.intra_op_num_threads,
-        "engine": engine,
         "warm_up": config.warm_up,
         "repeats": repeats,
+        "enable_nvtx_profile": args.enable_nvtx_profile,
+        "torch_compile_mode": args.torch_compile_mode,
+        "engine": engine,
         "average_latency": average_latency,
     }
-    csv_writer.writerow(row)
+
+    if csv_writer is not None:
+        csv_writer.writerow(row)
 
     print(f"{vars(config)}")
     print(f"{row}")
 
 
-def run_tests(args):
+def run_perf_test(args):
     features = "gpu" if args.use_gpu else "cpu"
     csv_filename = "benchmark_sam_{}_{}_{}.csv".format(
         features,
@@ -339,15 +431,17 @@ def run_tests(args):
             "num_points",
             "num_masks",
             "intra_op_num_threads",
-            "engine",
             "warm_up",
             "repeats",
+            "enable_nvtx_profile",
+            "torch_compile_mode",
+            "engine",
             "average_latency",
         ]
         csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
         csv_writer.writeheader()
 
-        run_test(csv_writer, args)
+        run_test(args, csv_writer)
 
 
 def _parse_arguments():
@@ -455,6 +549,22 @@ def _parse_arguments():
         help="Use prefer_nhwc=1 provider option for CUDAExecutionProvider",
     )
 
+    parser.add_argument(
+        "--enable_nvtx_profile",
+        required=False,
+        default=False,
+        action="store_true",
+        help="Enable NVTX profiling. It will add an extra run for profiling before the performance test.",
+    )
+
+    parser.add_argument(
+        "--enable_torch_profile",
+        required=False,
+        default=False,
+        action="store_true",
+        help="Enable PyTorch profiling. It will add an extra run for profiling before the performance test.",
+    )
+
     parser.add_argument(
         "--model_type",
         required=False,
@@ -480,6 +590,15 @@ def _parse_arguments():
         help="path of onnx model",
     )
 
+    parser.add_argument(
+        "--torch_compile_mode",
+        required=False,
+        type=str,
+        default=None,
+        choices=["reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs", "none"],
+        help='torch.compile mode. Use "none" to disable torch.compile.',
+    )
+
     args = parser.parse_args()
     return args
 
@@ -489,9 +608,21 @@ def _parse_arguments():
     args = _parse_arguments()
     print(f"arguments:{args}")
 
+    if args.torch_compile_mode is None:
+        # The image decoder will fail with compile modes other than "none".
+        args.torch_compile_mode = "max-autotune" if args.component == "image_encoder" else "none"
+
     if args.use_gpu:
         assert torch.cuda.is_available()
         if args.engine == "ort":
             assert "CUDAExecutionProvider" in get_available_providers()
+            args.enable_torch_profile = False
+    else:
+        # Only CUDA profiling is supported for now.
+        assert not args.enable_nvtx_profile
+        assert not args.enable_torch_profile
 
-    run_tests(args)
+    if args.enable_nvtx_profile or args.enable_torch_profile:
+        run_test(args)
+    else:
+        run_perf_test(args)
diff --git a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh
index 94e57ecb89fc1..f8c5abdb75311 100644
--- a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh
+++ b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh
@@ -1,83 +1,306 @@
-#!/bin/sh
+#!/bin/bash
 # -------------------------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
+# This script assumes that a conda (Anaconda/Miniconda/Miniforge) environment is active.
+# For example, you can create a new conda environment like the following before running this script:
+#   conda create -n sam2_gpu python=3.11 -y
+#   conda activate sam2_gpu
+#   bash benchmark_sam2.sh $HOME gpu
+# Or create a new conda environment for the CPU benchmark:
+#   conda create -n sam2_cpu python=3.11 -y
+#   conda activate sam2_cpu
+#   bash benchmark_sam2.sh $HOME cpu
+
+python=$CONDA_PREFIX/bin/python3
+
 # Directory of the script
 dir="$( cd "$( dirname "$0" )" && pwd )"
 
 # Directory of the onnx models
 onnx_dir=$dir/sam2_onnx_models
 
+# Directory to install CUDA/cuDNN and to clone the sam2 and onnxruntime source code.
+install_dir=$HOME
+if [ $# -ge 1 ]; then
+    install_dir=$1
+fi
+
+if ! [ -d $install_dir ]; then
+    echo "install_dir: $install_dir does not exist."
+    exit 1
+fi
+
 # Directory of the sam2 code by "git clone https://github.com/facebookresearch/segment-anything-2"
-# It reads from the sam2_dir environment variable, or defaults to ~/segment-anything-2.
-sam2_dir=${sam2_dir:-~/segment-anything-2}
+sam2_dir=$install_dir/segment-anything-2
 
 # model name to benchmark
 model=sam2_hiera_large
 
+# Device to run the benchmark on. Defaults to "gpu" unless "cpu" is given as the second argument.
+cpu_or_gpu="gpu"
+if [ $# -ge 2 ] && ([ "$2" = "gpu" ] || [ "$2" = "cpu" ]); then
+    cpu_or_gpu=$2
+fi
+
+echo "install_dir: $install_dir"
+echo "cpu_or_gpu: $cpu_or_gpu"
+
+install_cuda_12()
+{
+    pushd $install_dir
+    wget https://developer.download.nvidia.com/compute/cuda/12.5.1/local_installers/cuda_12.5.1_555.42.06_linux.run
+    sh cuda_12.5.1_555.42.06_linux.run --toolkit --toolkitpath=$install_dir/cuda12.5 --silent --override --no-man-page
+
+    export PATH="$install_dir/cuda12.5/bin:$PATH"
+    export LD_LIBRARY_PATH="$install_dir/cuda12.5/lib64:$LD_LIBRARY_PATH"
+    popd
+}
+
+install_cudnn_9()
+{
+    pushd $install_dir
+    wget https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/cudnn-linux-x86_64-9.4.0.58_cuda12-archive.tar.xz
+    mkdir $install_dir/cudnn9.4
+    tar -Jxvf cudnn-linux-x86_64-9.4.0.58_cuda12-archive.tar.xz -C $install_dir/cudnn9.4 --strip=1 --no-overwrite-dir
+
+    export LD_LIBRARY_PATH="$install_dir/cudnn9.4/lib:$LD_LIBRARY_PATH"
+    popd
+}
+
+install_gpu()
+{
+    if ! [ -d $install_dir/cuda12.5 ]; then
+        install_cuda_12
+    fi
+
+    if ! [ -d $install_dir/cudnn9.4 ]; then
+        install_cudnn_9
+    fi
+
+    pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
+    pip install onnxruntime-gpu onnx opencv-python matplotlib
+}
+
+install_cpu()
+{
+    pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
+    pip install onnxruntime onnx opencv-python matplotlib
+}
+
+install_sam2()
+{
+    pushd $install_dir
+
+    if ! [ -d $install_dir/segment-anything-2 ]; then
+        git clone https://github.com/facebookresearch/segment-anything-2.git
+    fi
+
+    cd segment-anything-2
+
+    if pip show SAM-2 > /dev/null 2>&1; then
+        echo "SAM-2 is already installed."
+    else
+        pip install -e .
+    fi
+
+    if ! [ -f checkpoints/sam2_hiera_large.pt ]; then
+        echo "Downloading checkpoints..."
+        cd checkpoints
+        sh ./download_ckpts.sh
+    fi
+
+    popd
+}
+
+download_test_image()
+{
+    if ! [ -f truck.jpg ]; then
+        curl https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg > truck.jpg
+    fi
+}
+
 run_cpu()
 {
     repeats=$1
-    python3 convert_to_onnx.py --sam2_dir $sam2_dir --optimize --demo
+    $python convert_to_onnx.py --sam2_dir $sam2_dir --optimize --demo
 
     echo "Benchmarking SAM2 model $model image encoder for PyTorch ..."
-    python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --dtype fp32
-    python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --dtype fp16
+    $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --dtype fp32
+    $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --dtype fp16
 
     echo "Benchmarking SAM2 model $model image decoder for PyTorch ..."
-    python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --dtype fp32 --component image_decoder
-    python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --dtype fp16 --component image_decoder
+    $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --dtype fp32 --component image_decoder
+    $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --dtype fp16 --component image_decoder
 
     echo "Benchmarking SAM2 model $model image encoder for ORT ..."
-    python3 benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_encoder.onnx --dtype fp32
-    python3 benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_encoder_fp32_cpu.onnx --dtype fp32
+    $python benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_encoder.onnx --dtype fp32
+    $python benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_encoder_fp32_cpu.onnx --dtype fp32
 
     echo "Benchmarking SAM2 model $model image decoder for ORT ..."
- python3 benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_decoder.onnx --component image_decoder - python3 benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_decoder_fp32_cpu.onnx --component image_decoder + $python benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_decoder.onnx --component image_decoder + $python benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_decoder_fp32_cpu.onnx --component image_decoder } run_gpu() { repeats=$1 - python3 convert_to_onnx.py --sam2_dir $sam2_dir --optimize --use_gpu --dtype fp32 - python3 convert_to_onnx.py --sam2_dir $sam2_dir --optimize --use_gpu --dtype fp16 --demo + $python convert_to_onnx.py --sam2_dir $sam2_dir --optimize --use_gpu --dtype fp32 + $python convert_to_onnx.py --sam2_dir $sam2_dir --optimize --use_gpu --dtype fp16 --demo echo "Benchmarking SAM2 model $model image encoder for PyTorch ..." - python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp16 - python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype bf16 - python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp32 + $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype bf16 + $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp32 + + # Test different torch compile modes on image encoder (none will disable compile and use eager mode). + for torch_compile_mode in none max-autotune reduce-overhead max-autotune-no-cudagraphs + do + $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp16 --component image_encoder --torch_compile_mode $torch_compile_mode + done echo "Benchmarking SAM2 model $model image decoder for PyTorch ..." - python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp16 --component image_decoder - python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype bf16 --component image_decoder - python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp32 --component image_decoder + $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype bf16 --component image_decoder + $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp32 --component image_decoder + $python benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp16 --component image_decoder echo "Benchmarking SAM2 model $model image encoder for ORT ..." 
-    python3 benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_encoder_fp16_gpu.onnx --use_gpu --dtype fp16
-    python3 benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_encoder_fp32_gpu.onnx --use_gpu --dtype fp32
+    $python benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_encoder_fp16_gpu.onnx --use_gpu --dtype fp16
+    $python benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_encoder_fp32_gpu.onnx --use_gpu --dtype fp32
 
     echo "Benchmarking SAM2 model $model image decoder for ORT ..."
-    python3 benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_decoder_fp16_gpu.onnx --component image_decoder --use_gpu --dtype fp16
-    python3 benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_decoder_fp32_gpu.onnx --component image_decoder --use_gpu
+    $python benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_decoder_fp16_gpu.onnx --component image_decoder --use_gpu --dtype fp16
+    $python benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_decoder_fp32_gpu.onnx --component image_decoder --use_gpu
+}
+
+# Build onnxruntime-gpu from source for profiling.
+build_onnxruntime_gpu_for_profiling()
+{
+    pushd $install_dir
+    if ! [ -d onnxruntime ]; then
+        git clone https://github.com/microsoft/onnxruntime
+    fi
+    cd onnxruntime
+
+    # Get the CUDA compute capability of the GPU.
+    CUDA_ARCH=$($python -c "import torch; cc = torch.cuda.get_device_capability(); print(f'{cc[0]}{cc[1]}')")
+
+    if [ -n "$CUDA_ARCH" ]; then
+        pip install --upgrade pip cmake psutil setuptools wheel packaging ninja numpy==1.26.4
+        sh build.sh --config Release --build_dir build/cuda12 --build_shared_lib --parallel \
+            --use_cuda --cuda_version 12.5 --cuda_home $install_dir/cuda12.5 \
+            --cudnn_home $install_dir/cudnn9.4 \
+            --build_wheel --skip_tests \
+            --cmake_generator Ninja \
+            --compile_no_warning_as_error \
+            --enable_cuda_nhwc_ops \
+            --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=$CUDA_ARCH \
+            --cmake_extra_defines onnxruntime_ENABLE_NVTX_PROFILE=ON \
+            --enable_cuda_line_info
+
+        pip install build/cuda12/Release/dist/onnxruntime_gpu-*-linux_x86_64.whl numpy==1.26.4
+    else
+        echo "PyTorch is not installed or no CUDA device was found."
+        exit 1
+    fi
+
+    popd
 }
 
-if ! [ -f truck.jpg ]; then
-    curl https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg > truck.jpg
+# Run profiling with NVTX.
+run_nvtx_profile()
+{
+    pip install nvtx cuda-python==12.5.0
+
+    # Only trace one device to avoid huge output file size.
+    device_id=0
+
+    # Environment variables
+    envs="CUDA_VISIBLE_DEVICES=$device_id,ORT_ENABLE_CUDNN_FLASH_ATTENTION=1,LD_LIBRARY_PATH=$LD_LIBRARY_PATH"
+
+    # With node-level tracing, activities of individual nodes are collected and CUDA graphs are not traced as a whole.
+    # This may cause significant runtime overhead, but it is useful for understanding the performance of individual nodes.
+    cuda_graph_trace=node
+
+    for engine in ort torch
+    do
+        for component in image_encoder image_decoder
+        do
+            sudo $install_dir/cuda12.5/bin/nsys profile --capture-range=nvtx --nvtx-capture='one_run' \
+                --gpu-metrics-device $device_id --force-overwrite true \
+                --sample process-tree --backtrace fp --stats true \
+                -t cuda,cudnn,cublas,osrt,nvtx --cuda-memory-usage true --cudabacktrace all \
+                --cuda-graph-trace $cuda_graph_trace \
+                -e $envs,NSYS_NVTX_PROFILER_REGISTER_ONLY=0 \
+                -o sam2_fp16_profile_${component}_${engine}_${cpu_or_gpu} \
+                $python benchmark_sam2.py --model_type $model --engine $engine \
+                    --sam2_dir $sam2_dir --warm_up 1 --repeats 0 \
+                    --onnx_path ${onnx_dir}/${model}_${component}_fp16_gpu.onnx \
+                    --component $component \
+                    --use_gpu --dtype fp16 --enable_nvtx_profile
+        done
+    done
+}
+
+# Run profiling with PyTorch.
+run_torch_profile()
+{
+    for component in image_encoder image_decoder
+    do
+        $python benchmark_sam2.py --model_type $model --engine torch \
+            --sam2_dir $sam2_dir --warm_up 1 --repeats 0 \
+            --component $component \
+            --use_gpu --dtype fp16 --enable_torch_profile
+    done
+}
+
+if ! [ -v CONDA_PREFIX ]; then
+    echo "Please activate a conda environment before running this script."
+    exit 1
 fi
 
-if python3 -c "import torch; assert torch.cuda.is_available()" 2>/dev/null; then
-    run_gpu 1000
+# Install the GPU or CPU dependencies based on the requested device.
+if [ "$cpu_or_gpu" = "gpu" ]; then
+    install_gpu
 else
-    run_cpu 100
+    install_cpu
 fi
 
-cat benchmark*.csv > combined_csv
-awk '!x[$0]++' combined_csv > sam2.csv
-rm combined_csv
+install_sam2
+
+download_test_image
 
-echo "Benchmarking SAM2 model $model done. Results are saved in sam2.csv"
+if ! [ -f sam2_${cpu_or_gpu}.csv ]; then
+    if [ "$cpu_or_gpu" = "gpu" ]; then
+        echo "Running GPU benchmark..."
+        run_gpu 1000
+    else
+        echo "Running CPU benchmark..."
+        run_cpu 100
+    fi
+
+    cat benchmark*.csv > combined_csv
+    awk '!x[$0]++' combined_csv > sam2_${cpu_or_gpu}.csv
+    rm combined_csv
+
+    echo "Benchmark results of SAM2 model $model are saved in sam2_${cpu_or_gpu}.csv"
+else
+    echo "sam2_${cpu_or_gpu}.csv already exists, skipping benchmarking..."
+fi
+
+if [ "$cpu_or_gpu" = "gpu" ]; then
+    echo "Running GPU profiling..."
+    if ! [ -f sam2_fp16_profile_image_decoder_ort_${cpu_or_gpu}.nsys-rep ]; then
+        rm -f *.nsys-rep
+        rm -f *.sqlite
+        build_onnxruntime_gpu_for_profiling
+        run_nvtx_profile
+    else
+        echo "sam2_fp16_profile_image_decoder_ort_${cpu_or_gpu}.nsys-rep already exists, skipping GPU profiling..."
+ fi + + run_torch_profile +fi diff --git a/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py index 40c408e851638..cacad717faf9c 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py @@ -251,6 +251,7 @@ def main(): print("demo output files for PyTorch:", torch_image_files) show_all_images(ort_image_files, torch_image_files, suffix) + print(f"Combined demo output: sam2_demo{suffix}.png") if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py b/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py index 0f7a1099461bc..5eafb29713126 100644 --- a/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py +++ b/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py @@ -43,6 +43,7 @@ def forward( input_masks: torch.Tensor, has_input_masks: torch.Tensor, original_image_size: torch.Tensor, + enable_nvtx_profile: bool = False, ): """ Decode masks from image features and prompts. Batched images are not supported. H=W=1024. @@ -60,19 +61,38 @@ def forward( Typically coming from a previous iteration. has_input_masks (torch.Tensor): [L]. 1.0 if input_masks is used, 0.0 otherwise. original_image_size(torch.Tensor): [2]. original image size H_o, W_o. + enable_nvtx_profile (bool): enable NVTX profiling. Returns: masks (torch.Tensor): [1, M, H_o, W_o] where M=3 or 1. Masks of original image size. iou_predictions (torch.Tensor): [1, M]. scores for M masks. low_res_masks (torch.Tensor, optional): [1, M, H/4, W/4]. low resolution masks. """ + nvtx_helper = None + if enable_nvtx_profile: + from nvtx_helper import NvtxHelper + + nvtx_helper = NvtxHelper(["prompt_encoder", "mask_decoder", "post_process"]) + + if nvtx_helper is not None: + nvtx_helper.start_profile("prompt_encoder", color="blue") + sparse_embeddings, dense_embeddings, image_pe = self.prompt_encoder( point_coords, point_labels, input_masks, has_input_masks ) + + if nvtx_helper is not None: + nvtx_helper.stop_profile("prompt_encoder") + nvtx_helper.start_profile("mask_decoder", color="red") + low_res_masks, iou_predictions = self.mask_decoder( image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings ) + if nvtx_helper is not None: + nvtx_helper.stop_profile("mask_decoder") + nvtx_helper.start_profile("post_process", color="green") + # Interpolate the low resolution masks back to the original image size. 
         masks = F.interpolate(
             low_res_masks,
@@ -85,6 +105,10 @@ def forward(
         if not self.return_logits:
             masks = masks > self.mask_threshold
 
+        if nvtx_helper is not None:
+            nvtx_helper.stop_profile("post_process")
+            nvtx_helper.print_latency()
+
         return masks, iou_predictions, low_res_masks
 
 
diff --git a/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py b/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py
index ec05e5f5b0f6c..b9f30d0371dbe 100644
--- a/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py
+++ b/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py
@@ -22,7 +22,11 @@ def __init__(self, sam_model: SAM2Base) -> None:
         self.image_encoder = sam_model.image_encoder
         self.no_mem_embed = sam_model.no_mem_embed
 
-    def forward(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def forward(
+        self,
+        image: torch.Tensor,
+        enable_nvtx_profile: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Encodes images into features.
 
@@ -31,14 +35,28 @@ def forward(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torc
         Args:
             image (torch.Tensor): images of shape [B, 3, H, W], B is batch size, H and W are height and width.
+            enable_nvtx_profile (bool): enable NVTX profiling.
 
         Returns:
             image_features_0: image features of shape [B, 32, H/4, W/4] - high resolution features of level 0
             image_features_1: image features of shape [B, 64, H/8, W/8] - high resolution features of level 1
             image_embeddings: image features of shape [B, 256, H/16, W/16] - 16 is the backbone_stride
         """
+        nvtx_helper = None
+        if enable_nvtx_profile:
+            from nvtx_helper import NvtxHelper
+
+            nvtx_helper = NvtxHelper(["image_encoder", "post_process"])
+
+        if nvtx_helper is not None:
+            nvtx_helper.start_profile("image_encoder")
+
         backbone_out = self.image_encoder(image)
 
+        if nvtx_helper is not None:
+            nvtx_helper.stop_profile("image_encoder")
+            nvtx_helper.start_profile("post_process")
+
         # precompute projected level 0 and level 1 features in SAM decoder
         # to avoid running it again on every SAM click
         backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
@@ -60,6 +78,10 @@ def forward(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torc
             for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])
         ][::-1]
 
+        if nvtx_helper is not None:
+            nvtx_helper.stop_profile("post_process")
+            nvtx_helper.print_latency()
+
         return feats[0], feats[1], feats[2]
 
 
diff --git a/onnxruntime/python/tools/transformers/models/sam2/nvtx_helper.py b/onnxruntime/python/tools/transformers/models/sam2/nvtx_helper.py
new file mode 100644
index 0000000000000..63b936978e6f7
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/models/sam2/nvtx_helper.py
@@ -0,0 +1,36 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import nvtx
+from cuda import cudart
+
+
+class NvtxHelper:
+    def __init__(self, stages):
+        self.stages = stages
+        self.events = {}
+        for stage in stages:
+            for marker in ["start", "stop"]:
+                self.events[stage + "-" + marker] = cudart.cudaEventCreate()[1]
+        self.markers = {}
+
+    def start_profile(self, stage, color="blue"):
+        self.markers[stage] = nvtx.start_range(message=stage, color=color)
+        event_name = stage + "-start"
+        if event_name in self.events:
+            cudart.cudaEventRecord(self.events[event_name], 0)
+
+    def stop_profile(self, stage):
+        event_name = stage + "-stop"
+        if event_name in self.events:
+            cudart.cudaEventRecord(self.events[event_name], 0)
+        nvtx.end_range(self.markers[stage])
+
+    def print_latency(self):
+        for stage in self.stages:
+            # Wait for the stop event to complete before reading the elapsed time;
+            # otherwise cudaEventElapsedTime can return cudaErrorNotReady.
+            cudart.cudaEventSynchronize(self.events[f"{stage}-stop"])
+            latency = cudart.cudaEventElapsedTime(self.events[f"{stage}-start"], self.events[f"{stage}-stop"])[1]
+            print(f"{stage}: {latency:.2f} ms")
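For reference, a minimal usage sketch of `NvtxHelper` (the stage names and the toy workload are illustrative only; it assumes a CUDA device and the `nvtx` and `cuda-python` packages installed by `run_nvtx_profile`):

```python
import torch
from nvtx_helper import NvtxHelper

helper = NvtxHelper(["matmul", "relu"])

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

# Each stage opens an NVTX range (visible in the nsys timeline) and
# brackets the work with CUDA events for latency measurement.
helper.start_profile("matmul", color="blue")
c = a @ b
helper.stop_profile("matmul")

helper.start_profile("relu", color="green")
d = torch.relu(c)
helper.stop_profile("relu")

# Reports per-stage GPU time in milliseconds from the recorded CUDA events.
helper.print_latency()
```

The NVTX range named `one_run` in benchmark_sam2.py is what `nsys profile --capture-range=nvtx --nvtx-capture='one_run'` keys on; the ranges opened by `NvtxHelper` then break that captured run down into stages.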