From d5b990784723ab5df537f613ce8655c5ab4d9ecb Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Fri, 20 Sep 2024 22:46:23 +0000
Subject: [PATCH] Add benchmark script for sam2

---
 .../models/sam2/benchmark_sam2.py | 497 ++++++++++++++++++
 .../models/sam2/benchmark_sam2.sh | 72 +++
 2 files changed, 569 insertions(+)
 create mode 100644 onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py
 create mode 100644 onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh

diff --git a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py
new file mode 100644
index 0000000000000..fefb7d308af7e
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.py
@@ -0,0 +1,497 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+"""
+Benchmark performance of the SAM2 image encoder and image decoder with ONNX Runtime (ORT) or PyTorch.
+See benchmark_sam2.sh for usage.
+"""
+
+import argparse
+import csv
+import statistics
+import time
+from datetime import datetime
+from typing import List, Mapping
+
+import torch
+from image_decoder import SAM2ImageDecoder
+from image_encoder import SAM2ImageEncoder
+from sam2_utils import decoder_shape_dict, encoder_shape_dict, load_sam2_model
+
+from onnxruntime import InferenceSession, SessionOptions, get_available_providers
+from onnxruntime.transformers.io_binding_helper import CudaSession
+
+
+class TestConfig:
+    def __init__(
+        self,
+        model_type: str,
+        onnx_path: str,
+        sam2_dir: str,
+        device: torch.device,
+        component: str = "image_encoder",
+        provider="CPUExecutionProvider",
+        torch_compile_mode="max-autotune",
+        batch_size: int = 1,
+        height: int = 1024,
+        width: int = 1024,
+        num_labels: int = 1,
+        num_points: int = 1,
+        num_masks: int = 1,
+        multi_mask_output: bool = False,
+        use_tf32: bool = True,
+        enable_cuda_graph: bool = False,
+        dtype=torch.float32,
+        prefer_nhwc: bool = False,
+        warm_up: int = 5,
+        repeats: int = 1000,
+        verbose: bool = False,
+    ):
+        assert model_type in ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"]
+        assert 160 <= height <= 4096
+        assert 160 <= width <= 4096
+
+        self.model_type = model_type
+        self.onnx_path = onnx_path
+        self.sam2_dir = sam2_dir
+        self.component = component
+        self.provider = provider
+        self.torch_compile_mode = torch_compile_mode
+        self.batch_size = batch_size
+        self.height = height
+        self.width = width
+        self.num_labels = num_labels
+        self.num_points = num_points
+        self.num_masks = num_masks
+        self.multi_mask_output = multi_mask_output
+        self.device = device
+        self.use_tf32 = use_tf32
+        self.enable_cuda_graph = enable_cuda_graph
+        self.dtype = dtype
+        self.prefer_nhwc = prefer_nhwc
+        self.warm_up = warm_up
+        self.repeats = repeats
+        self.verbose = verbose
+
+        if self.component == "image_encoder":
+            assert self.height == 1024 and self.width == 1024, "Only image size 1024x1024 is allowed for image encoder."
+
+    def __repr__(self):
+        return f"{vars(self)}"
+
+    def shape_dict(self) -> Mapping[str, List[int]]:
+        if self.component == "image_encoder":
+            return encoder_shape_dict(self.batch_size, self.height, self.width)
+        else:
+            return decoder_shape_dict(self.height, self.width, self.num_labels, self.num_points, self.num_masks)
+
+    def random_inputs(self):
+        if self.component == "image_encoder":
+            return {
+                "image": torch.randn(
+                    self.batch_size, 3, self.height, self.width, dtype=torch.float32, device=self.device
+                )
+            }
+        else:
+            return {
+                "image_features_0": torch.rand(1, 32, 256, 256, dtype=torch.float32, device=self.device),
+                "image_features_1": torch.rand(1, 64, 128, 128, dtype=torch.float32, device=self.device),
+                "image_embeddings": torch.rand(1, 256, 64, 64, dtype=torch.float32, device=self.device),
+                "point_coords": torch.randint(
+                    0, 1024, (self.num_labels, self.num_points, 2), dtype=torch.float32, device=self.device
+                ),
+                "point_labels": torch.randint(
+                    0, 1, (self.num_labels, self.num_points), dtype=torch.int32, device=self.device
+                ),
+                "input_masks": torch.zeros(self.num_labels, 1, 256, 256, dtype=torch.float32, device=self.device),
+                "has_input_masks": torch.ones(self.num_labels, dtype=torch.float32, device=self.device),
+                "original_image_size": torch.tensor([self.height, self.width], dtype=torch.int32, device=self.device),
+            }
+
+
+def create_ort_session(config: TestConfig, session_options=None) -> InferenceSession:
+    if config.verbose:
+        print(f"create session for {vars(config)}")
+
+    if config.provider == "CUDAExecutionProvider":
+        device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index
+        provider_options = CudaSession.get_cuda_provider_options(device_id, config.enable_cuda_graph)
+        provider_options["use_tf32"] = int(config.use_tf32)
+        if config.prefer_nhwc:
+            provider_options["prefer_nhwc"] = 1
+        providers = [(config.provider, provider_options), "CPUExecutionProvider"]
+    else:
+        providers = ["CPUExecutionProvider"]
+
+    ort_session = InferenceSession(config.onnx_path, session_options, providers=providers)
+    return ort_session
+
+
+def create_session(config: TestConfig, session_options=None) -> CudaSession:
+    ort_session = create_ort_session(config, session_options)
+    cuda_session = CudaSession(ort_session, config.device, config.enable_cuda_graph)
+    cuda_session.allocate_buffers(config.shape_dict())
+    return cuda_session
+
+
+class OrtTestSession:
+    """A wrapper of an ORT session to test correctness and performance."""
+
+    def __init__(self, config: TestConfig, session_options=None):
+        self.ort_session = create_session(config, session_options)
+        self.feed_dict = config.random_inputs()
+
+    def infer(self):
+        return self.ort_session.infer(self.feed_dict)
+
+
+def measure_latency(cuda_session: CudaSession, input_dict):
+    start = time.time()
+    _ = cuda_session.infer(input_dict)
+    end = time.time()
+    return end - start
+
+
+def run_torch(config: TestConfig):
+    device_type = config.device.type
+    is_cuda = device_type == "cuda"
+
+    # Turn on TF32 for Ampere GPUs, which can help performance when the data type is float32.
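+    # (Compute capability 8.x covers Ampere and Ada GPUs; Hopper and newer report a
+    # higher major version and also support TF32.)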
+    if is_cuda and torch.cuda.get_device_properties(0).major >= 8 and config.use_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+
+    enabled_auto_cast = is_cuda and config.dtype != torch.float32
+    ort_inputs = config.random_inputs()
+
+    with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=config.dtype, enabled=enabled_auto_cast):
+        sam2_model = load_sam2_model(config.sam2_dir, config.model_type, device=config.device)
+        if config.component == "image_encoder":
+            if is_cuda:
+                sam2_model.image_encoder.forward = torch.compile(
+                    sam2_model.image_encoder.forward,
+                    mode=config.torch_compile_mode,  # Use "reduce-overhead" to reduce the latency of the first run.
+                    fullgraph=True,
+                    dynamic=False,
+                )
+
+            image_shape = config.shape_dict()["image"]
+            img = torch.randn(image_shape).to(device=config.device, dtype=config.dtype)
+            sam2_encoder = SAM2ImageEncoder(sam2_model)
+
+            if is_cuda:
+                print(f"Running warm up. It will take a while since torch compile mode is {config.torch_compile_mode}.")
+            for _ in range(config.warm_up):
+                _image_features_0, _image_features_1, _image_embeddings = sam2_encoder(img)
+
+            print(f"Start {config.repeats} runs of performance tests...")
+            start = time.time()
+            for _ in range(config.repeats):
+                _image_features_0, _image_features_1, _image_embeddings = sam2_encoder(img)
+                if is_cuda:
+                    torch.cuda.synchronize()
+        else:
+            torch_inputs = (
+                ort_inputs["image_features_0"],
+                ort_inputs["image_features_1"],
+                ort_inputs["image_embeddings"],
+                ort_inputs["point_coords"],
+                ort_inputs["point_labels"],
+                ort_inputs["input_masks"],
+                ort_inputs["has_input_masks"],
+                ort_inputs["original_image_size"],
+            )
+
+            sam2_decoder = SAM2ImageDecoder(sam2_model, multimask_output=config.multi_mask_output)
+
+            # warm up
+            for _ in range(config.warm_up):
+                _masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs)
+
+            print(f"Start {config.repeats} runs of performance tests...")
+            start = time.time()
+            for _ in range(config.repeats):
+                _masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs)
+                if is_cuda:
+                    torch.cuda.synchronize()
+        end = time.time()
+        return (end - start) / config.repeats
+
+
+def run_test(
+    csv_writer: csv.DictWriter,
+    args: argparse.Namespace,
+):
+    use_gpu: bool = args.use_gpu
+    enable_cuda_graph: bool = args.use_cuda_graph
+    repeats: int = args.repeats
+
+    if use_gpu:
+        device_id = torch.cuda.current_device()
+        device = torch.device("cuda", device_id)
+        provider = "CUDAExecutionProvider"
+    else:
+        device_id = 0
+        device = torch.device("cpu")
+        enable_cuda_graph = False
+        provider = "CPUExecutionProvider"
+
+    dtypes = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
+    config = TestConfig(
+        model_type=args.model_type,
+        onnx_path=args.onnx_path,
+        sam2_dir=args.sam2_dir,
+        component=args.component,
+        provider=provider,
+        batch_size=args.batch_size,
+        height=args.height,
+        width=args.width,
+        multi_mask_output=args.multimask_output,
+        device=device,
+        use_tf32=True,
+        enable_cuda_graph=enable_cuda_graph,
+        dtype=dtypes[args.dtype],
+        prefer_nhwc=args.prefer_nhwc,
+        repeats=args.repeats,
+        warm_up=args.warm_up,
+        verbose=False,
+    )
+
+    if args.engine == "ort":
+        sess_options = SessionOptions()
+        sess_options.intra_op_num_threads = args.intra_op_num_threads
+
+        session = create_session(config, sess_options)
+        input_dict = config.random_inputs()
+
+        # warm up session
+        try:
+            for _ in range(config.warm_up):
+                _ = measure_latency(session, input_dict)
+        except Exception as e:
+            print(f"Failed to run {config=}. Exception: {e}")
+            return
+
+        latency_list = []
+        for _ in range(repeats):
+            latency = measure_latency(session, input_dict)
+            latency_list.append(latency)
+        average_latency = statistics.mean(latency_list)
+
+        del session
+    else:  # torch
+        with torch.no_grad():
+            try:
+                average_latency = run_torch(config)
+            except Exception as e:
+                print(f"Failed to run {config=}. Exception: {e}")
+                return
+
+    engine = args.engine + ":" + ("cuda" if use_gpu else "cpu")
+    row = {
+        "model_type": args.model_type,
+        "component": args.component,
+        "dtype": args.dtype,
+        "use_gpu": use_gpu,
+        "enable_cuda_graph": enable_cuda_graph,
+        "prefer_nhwc": config.prefer_nhwc,
+        "use_tf32": config.use_tf32,
+        "batch_size": args.batch_size,
+        "height": args.height,
+        "width": args.width,
+        "multi_mask_output": args.multimask_output,
+        "num_labels": config.num_labels,
+        "num_points": config.num_points,
+        "num_masks": config.num_masks,
+        "intra_op_num_threads": args.intra_op_num_threads,
+        "engine": engine,
+        "warm_up": config.warm_up,
+        "repeats": repeats,
+        "average_latency": average_latency,
+    }
+    csv_writer.writerow(row)
+
+    print(f"{vars(config)}")
+    print(f"{row}")
+
+
+def run_tests(args):
+    features = "gpu" if args.use_gpu else "cpu"
+    csv_filename = "benchmark_sam2_{}_{}_{}.csv".format(
+        features,
+        args.engine,
+        datetime.now().strftime("%Y%m%d-%H%M%S"),
+    )
+    with open(csv_filename, mode="a", newline="") as csv_file:
+        column_names = [
+            "model_type",
+            "component",
+            "dtype",
+            "use_gpu",
+            "enable_cuda_graph",
+            "prefer_nhwc",
+            "use_tf32",
+            "batch_size",
+            "height",
+            "width",
+            "multi_mask_output",
+            "num_labels",
+            "num_points",
+            "num_masks",
+            "intra_op_num_threads",
+            "engine",
+            "warm_up",
+            "repeats",
+            "average_latency",
+        ]
+        csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
+        csv_writer.writeheader()
+
+        run_test(csv_writer, args)
+
+
+def _parse_arguments():
+    parser = argparse.ArgumentParser(description="Benchmark SAM2 for ONNX Runtime and PyTorch.")
+
+    parser.add_argument(
+        "--component",
+        required=False,
+        choices=["image_encoder", "image_decoder"],
+        default="image_encoder",
+        help="Component to benchmark: image_encoder or image_decoder.",
+    )
+
+    parser.add_argument(
+        "--dtype", required=False, choices=["fp32", "fp16", "bf16"], default="fp32", help="Data type for inference."
+    )
+
+    parser.add_argument(
+        "--use_gpu",
+        required=False,
+        action="store_true",
+        help="Use GPU for inference.",
+    )
+    parser.set_defaults(use_gpu=False)
+
+    parser.add_argument(
+        "--use_cuda_graph",
+        required=False,
+        action="store_true",
+        help="Use cuda graph in onnxruntime.",
+    )
+    parser.set_defaults(use_cuda_graph=False)
+
+    parser.add_argument(
+        "--intra_op_num_threads",
+        required=False,
+        type=int,
+        choices=[0, 1, 2, 4, 8, 16],
+        default=0,
+        help="intra_op_num_threads for onnxruntime (0 lets onnxruntime choose).",
+    )
+
+    parser.add_argument(
+        "--batch_size",
+        required=False,
+        type=int,
+        default=1,
+        help="Batch size.",
+    )
+
+    parser.add_argument(
+        "--height",
+        required=False,
+        type=int,
+        default=1024,
+        help="Image height.",
+    )
+
+    parser.add_argument(
+        "--width",
+        required=False,
+        type=int,
+        default=1024,
+        help="Image width.",
+    )
+
+    parser.add_argument(
+        "--repeats",
+        required=False,
+        type=int,
+        default=1000,
+        help="Number of repeats for performance test. Default is 1000.",
+    )
+
+    parser.add_argument(
+        "--warm_up",
+        required=False,
+        type=int,
+        default=5,
+        help="Number of warm-up runs. Default is 5.",
+    )
+
+    parser.add_argument(
+        "--engine",
+        required=False,
+        type=str,
+        default="ort",
+        choices=["ort", "torch"],
+        help="Engine for inference.",
+    )
+
+    parser.add_argument(
+        "--multimask_output",
+        required=False,
+        default=False,
+        action="store_true",
+        help="Run the image decoder with multimask_output.",
+    )
+
+    parser.add_argument(
+        "--prefer_nhwc",
+        required=False,
+        default=False,
+        action="store_true",
+        help="Use prefer_nhwc=1 provider option for CUDAExecutionProvider.",
+    )
+
+    parser.add_argument(
+        "--model_type",
+        required=False,
+        type=str,
+        default="sam2_hiera_large",
+        choices=["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"],
+        help="SAM2 model name.",
+    )
+
+    parser.add_argument(
+        "--sam2_dir",
+        required=False,
+        type=str,
+        default="./segment-anything-2",
+        help="The root directory of the segment-anything-2 git repository.",
+    )
+
+    parser.add_argument(
+        "--onnx_path",
+        required=False,
+        type=str,
+        default="./sam2_onnx_models/sam2_hiera_large_image_encoder.onnx",
+        help="Path of the ONNX model.",
+    )
+
+    args = parser.parse_args()
+
+    return args
+
+
+if __name__ == "__main__":
+    args = _parse_arguments()
+    print(f"arguments: {args}")
+
+    if args.use_gpu:
+        assert torch.cuda.is_available()
+        if args.engine == "ort":
+            assert "CUDAExecutionProvider" in get_available_providers()
+
+    run_tests(args)
diff --git a/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh
new file mode 100644
index 0000000000000..74048f90424cd
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/models/sam2/benchmark_sam2.sh
@@ -0,0 +1,72 @@
+#!/bin/sh
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+# Directory of the script
+dir="$( cd "$( dirname "$0" )" && pwd )"
+
+# Directory of the onnx models
+onnx_dir=$dir/sam2_onnx_models
+
+# Directory of the sam2 code, cloned by "git clone https://github.com/facebookresearch/segment-anything-2"
+sam2_dir=~/segment-anything-2
+
+# Model name to benchmark
+model=sam2_hiera_large
+
+run_cpu()
+{
+    repeats=$1
+
+    python3 convert_to_onnx.py --sam2_dir $sam2_dir --optimize --demo
+
+    echo "Benchmarking SAM2 model $model image encoder for PyTorch ..."
+    python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --dtype fp32
+    python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --dtype fp16
+
+    echo "Benchmarking SAM2 model $model image decoder for PyTorch ..."
+    python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --dtype fp32 --component image_decoder
+    python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --dtype fp16 --component image_decoder
+
+    echo "Benchmarking SAM2 model $model image encoder for ORT ..."
+    python3 benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_encoder.onnx --dtype fp32
+    python3 benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_encoder_fp32_cpu.onnx --dtype fp32
+
+    echo "Benchmarking SAM2 model $model image decoder for ORT ..."
+    python3 benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_decoder.onnx --component image_decoder
+    python3 benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_decoder_fp32_cpu.onnx --component image_decoder
+}
+
+run_gpu()
+{
+    repeats=$1
+
+    python3 convert_to_onnx.py --sam2_dir $sam2_dir --optimize --use_gpu --dtype fp32
+    python3 convert_to_onnx.py --sam2_dir $sam2_dir --optimize --use_gpu --dtype fp16 --demo
+
+    echo "Benchmarking SAM2 model $model image encoder for PyTorch ..."
+    python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp16
+    python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype bf16
+    python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp32
+
+    echo "Benchmarking SAM2 model $model image decoder for PyTorch ..."
+    python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp16 --component image_decoder
+    python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype bf16 --component image_decoder
+    python3 benchmark_sam2.py --model_type $model --engine torch --sam2_dir $sam2_dir --repeats $repeats --use_gpu --dtype fp32 --component image_decoder
+
+    echo "Benchmarking SAM2 model $model image encoder for ORT ..."
+    python3 benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_encoder_fp16_gpu.onnx --use_gpu --dtype fp16
+    python3 benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_encoder_fp32_gpu.onnx --use_gpu --dtype fp32
+
+    echo "Benchmarking SAM2 model $model image decoder for ORT ..."
+    python3 benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_decoder_fp16_gpu.onnx --component image_decoder --use_gpu --dtype fp16
+    python3 benchmark_sam2.py --model_type $model --engine ort --sam2_dir $sam2_dir --repeats $repeats --onnx_path ${onnx_dir}/${model}_image_decoder_fp32_gpu.onnx --component image_decoder --use_gpu --dtype fp32
+}
+
+if python3 -c "import torch; assert torch.cuda.is_available()" 2>/dev/null; then
+    run_gpu 1000
+else
+    run_cpu 100
+fi
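+
+# A single configuration can also be benchmarked directly. For example, to measure
+# only the ORT fp16 image encoder on GPU (assuming the fp16 encoder model was
+# already exported by convert_to_onnx.py, as in run_gpu above):
+#   python3 benchmark_sam2.py --model_type sam2_hiera_large --engine ort \
+#       --sam2_dir ~/segment-anything-2 \
+#       --onnx_path sam2_onnx_models/sam2_hiera_large_image_encoder_fp16_gpu.onnx \
+#       --use_gpu --dtype fp16 --repeats 1000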