Skip to content

Commit

Permalink
Add fusion script for segment anything v2 (#22167)
Browse files Browse the repository at this point in the history
### Description
* Add MultiHeadAttention fusion for SAM2.
* Add LayerNormalization fusion for NCHW format by inserting Transpose
from NCHW to NHWC before layer normalization, and add another Transpose
after layer norm to convert NHWC back to NCHW. Hopefully, those extra
Transpose nodes will be removed when prefer_nhwc is enabled later.
* Add a condition that the input shall be 3D when fuse SkipLayerNorm.
* Update convert_to_onnx.py to add `--optimize` and `--use_gpu` options
to output optimized onnx model for CPU/CUDA eps.
* Add an option `--dtype fp16|fp32` in convert_to_onnx.py to support
converting optimized model to float16.
* Update the demo to use the optimized onnx models.

### Motivation and Context
To support optimization of SAM2 for CPU/CUDA eps that is exported in
#22119
  • Loading branch information
tianleiwu authored Sep 21, 2024
1 parent fe8a10c commit 1431215
Show file tree
Hide file tree
Showing 10 changed files with 1,085 additions and 106 deletions.
534 changes: 534 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_attention_sam2.py

Large diffs are not rendered by default.

160 changes: 158 additions & 2 deletions onnxruntime/python/tools/transformers/fusion_layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict
from typing import Dict, List

from fusion_base import Fusion
from onnx import helper
from onnx import TensorProto, helper
from onnx_model import OnnxModel

logger = getLogger(__name__)
Expand Down Expand Up @@ -143,6 +143,162 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name


class FusionLayerNormalizationNCHW(Fusion):
def __init__(self, model: OnnxModel):
super().__init__(model, "LayerNormalization", "ReduceMean")

def get_weight_or_bias(self, output_name, description):
value = self.model.get_constant_value(output_name)
if value is None:
logger.debug(f"{description} {output_name} is not initializer.")
return None

if len(value.shape) != 3 or value.shape[1] != 1 or value.shape[2] != 1:
logger.debug(f"{description} {output_name} shall have 3 dimensions Cx1x1. Got shape {value.shape}")
return None

return value.reshape([value.shape[0]])

def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
"""Append a Transpose node after an input"""
node_name = self.model.create_node_name("Transpose")

if output_name is None:
output_name = node_name + "_out" + "-" + input_name

transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
transpose_node.attribute.extend([helper.make_attribute("perm", perm)])

return transpose_node

def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
"""
Fuse Layer Normalization subgraph into one node LayerNormalization:
+----------------------+
| NxCxHxW |
| v (Cx1x1) (Cx1x1)
[Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add -->
(axes=1) | (Y=2) (axes=1) (E-6) ^
| |
+-----------------------------------------------+
Fused subgraph:
(0,2,3,1) (0,3,1,2)
[Root] --> Transpose --> LayerNormalization --> Transpose -->
"""
axes = OnnxModel.get_node_attribute(node, "axes")
if (not isinstance(axes, list)) or axes != [1]:
return

subgraph_nodes = []
children = self.model.get_children(node, input_name_to_nodes)
if len(children) != 1:
return

root_input = node.input[0]

if children[0].op_type != "Sub" or children[0].input[0] != root_input:
return
sub = children[0]

div_node = self.model.find_first_child_by_type(sub, "Div", input_name_to_nodes, recursive=False)
if div_node is None:
return

parent_nodes = self.model.match_parent_path(
div_node,
["Sqrt", "Add", "ReduceMean", "Pow", "Sub"],
[1, 0, 0, 0, 0],
output_name_to_node,
)
if parent_nodes is None:
return

_sqrt_node, second_add_node, reduce_mean_node, pow_node, sub_node = parent_nodes
if sub != sub_node:
return

i, add_weight = self.model.get_constant_input(second_add_node)
if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
logger.debug(f"skip SkipLayerNormalization fusion since epsilon value is not expected: {add_weight}")
return

axes = OnnxModel.get_node_attribute(reduce_mean_node, "axes")
assert isinstance(axes, list)
if axes != [1]:
return

if self.model.find_constant_input(pow_node, 2.0) != 1:
return

temp_node = input_name_to_nodes[div_node.output[0]][0]
mul_node = temp_node
if mul_node.op_type != "Mul":
return

last_add_node = input_name_to_nodes[mul_node.output[0]][0]
if last_add_node.op_type != "Add":
return

subgraph_nodes.append(node)
subgraph_nodes.extend(parent_nodes)
subgraph_nodes.extend([last_add_node, mul_node, div_node])

if not self.model.is_safe_to_fuse_nodes(
subgraph_nodes,
last_add_node.output,
input_name_to_nodes,
output_name_to_node,
):
logger.debug("It is not safe to fuse LayerNormalization node. Skip")
return

node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node
weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)]
weight = self.get_weight_or_bias(weight_input, "layernorm weight")
if weight is None:
return

bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)]
bias = self.get_weight_or_bias(bias_input, "layernorm bias")
if bias is None:
return

weight_nhwc = helper.make_tensor(weight_input + "_NHWC", TensorProto.FLOAT, weight.shape, weight)

bias_nhwc = helper.make_tensor(bias_input + "_NHWC", TensorProto.FLOAT, weight.shape, weight)
self.model.add_initializer(weight_nhwc, self.this_graph_name)
self.model.add_initializer(bias_nhwc, self.this_graph_name)

self.nodes_to_remove.extend(subgraph_nodes)

transpose_input = self.create_transpose_node(node.input[0], [0, 2, 3, 1])

layernorm_node_name = self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm")

transpose_output = self.create_transpose_node(
layernorm_node_name + "_out_nhwc", [0, 3, 1, 2], last_add_node.output[0]
)

normalize_node = helper.make_node(
"LayerNormalization",
inputs=[transpose_input.output[0], weight_input + "_NHWC", bias_input + "_NHWC"],
outputs=[layernorm_node_name + "_out_nhwc"],
name=layernorm_node_name,
)
normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))])

self.nodes_to_add.append(transpose_input)
self.nodes_to_add.append(normalize_node)
self.nodes_to_add.append(transpose_output)
self.node_name_to_graph_name[transpose_input.name] = self.this_graph_name
self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
self.node_name_to_graph_name[transpose_output.name] = self.this_graph_name

counter_name = "LayerNormalization(NHWC)"
self.increase_counter(counter_name)


class FusionLayerNormalizationTF(Fusion):
def __init__(self, model: OnnxModel):
super().__init__(model, "LayerNormalization", "Add", "TF")
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_skiplayernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):

if hasattr(self, "shape_infer_helper"):
if self.shape_infer_helper is not None:
if (
self.shape_infer_helper.get_edge_shape(add.input[0])
and len(self.shape_infer_helper.get_edge_shape(add.input[0])) != 3
):
logger.debug("skip SkipLayerNormalization fusion since shape of input %s is not 3D", add.input[0])
return

# TODO(tianleiwu): support broadcasting Skip shape (1, sequence_length, hidden_size) or (sequence_length, hidden_size)
if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]):
logger.debug(
Expand Down
27 changes: 27 additions & 0 deletions onnxruntime/python/tools/transformers/models/sam2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,39 @@ To see all parameters, run the following command:
python3 convert_to_onnx.py -h
```

## Optimize ONNX

To optimize the onnx models for CPU with float32 data type:
```bash
python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --optimize --dtype fp32
```

To optimize the onnx models for GPU with float16 data type:
```bash
python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --optimize --dtype fp16 --use_gpu
```

Another option is to use optimizer.py like the following:
```
cd ../..
python optimizer.py --input models/sam2/sam2_onnx_models/sam2_hiera_large_image_encoder.onnx \
--output models/sam2/sam2_onnx_models/sam2_hiera_large_image_encoder_fp16_gpu.onnx \
--use_gpu --model_type sam2 --float16
```
The optimizer.py could be helpful when you have SAM2 onnx models that is exported by other tools.

## Run Demo

The exported ONNX models can run on a CPU. The demo will output sam2_demo.png.
```bash
curl https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg > truck.jpg
python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --demo
```

It is able to run demo on optimized model as well. For example,
```bash
python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --optimize --dtype fp16 --use_gpu --demo
```

## Limitations
- The exported image_decoder model does not support batch mode for now.
Loading

0 comments on commit 1431215

Please sign in to comment.