ort 1.15.1 with vision patch
Prathik Rao committed Aug 16, 2023
1 parent baeece4 commit 060bb44
Showing 1 changed file with 52 additions and 20 deletions.
72 changes: 52 additions & 20 deletions onnxruntime/core/optimizer/layer_norm_fusion.cc
@@ -32,12 +32,26 @@ static bool IsSupportedDataType(const Node& node, int first_n_inputs = -1) {
return true;
}

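// Returns true if every output of `node` (or each of the first `first_n_outputs` outputs, when a
// non-negative count is given) has element type tensor(float16).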
static bool IsFP16OutputDataType(const Node& node, int first_n_outputs = -1) {
int output_index = 0;
for (const auto& output_arg : node.OutputDefs()) {
if (first_n_outputs != -1 && output_index >= first_n_outputs) {
return true;
}
if (*(output_arg->Type()) != "tensor(float16)") {
return false;
}
++output_index;
}
return true;
}

/**
Layer Normalization will fuse the following sub-graph into one LayerNormalization node:
+---------------------+
| |
| v
X --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
X --> ReduceMean --> Sub --> Pow --> ReduceMean --> Optional[Add] --> Sqrt --> Div --> Mul --> Add
| ^
| |
+-----------------------------------------------+
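For reference (not part of this commit), a minimal standalone sketch of the arithmetic the pattern above encodes: the first ReduceMean computes the mean, Sub/Pow/ReduceMean compute the variance, the (optional) Add applies epsilon, and Sqrt/Div/Mul/Add normalize, scale, and shift. The function name and vector-based signature are illustrative only.

#include <cmath>
#include <cstddef>
#include <vector>

// Reference: y = (x - mean(x)) / sqrt(mean((x - mean(x))^2) + eps) * scale + bias
std::vector<float> LayerNormReference(const std::vector<float>& x,
                                      const std::vector<float>& scale,
                                      const std::vector<float>& bias,
                                      float eps = 1e-5f) {
  const size_t n = x.size();
  float mean = 0.0f;
  for (float v : x) mean += v;
  mean /= static_cast<float>(n);                      // first ReduceMean

  float var = 0.0f;
  for (float v : x) var += (v - mean) * (v - mean);   // Sub --> Pow(2)
  var /= static_cast<float>(n);                       // second ReduceMean

  const float inv_std = 1.0f / std::sqrt(var + eps);  // Optional[Add] --> Sqrt --> Div
  std::vector<float> y(n);
  for (size_t i = 0; i < n; ++i) {
    y[i] = (x[i] - mean) * inv_std * scale[i] + bias[i];  // Mul (scale) --> Add (bias)
  }
  return y;
}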
@@ -47,7 +61,7 @@ It also handles cases of duplicated sub nodes exported from older version of PyT
| +-------> Sub ---------------------------------------------+
| | |
| | v
X --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
X --> ReduceMean --> Sub --> Pow --> ReduceMean --> Optional[Add] --> Sqrt --> Div --> Mul --> Add
| ^
| |
+---------------------+
@@ -57,14 +71,14 @@ due to restriction in older opsets. Therefore, Layer Normalization will also han
+---------------------+
| |
| v
X --> ReduceMean --> Sub --> Cast --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
X --> ReduceMean --> Sub --> Cast --> Pow --> ReduceMean --> Optional[Add] --> Sqrt --> Div --> Mul --> Add
| ^
| |
+------------------------------------------------+
+---------------------+ Cast
| | |
| v v
X --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
X --> ReduceMean --> Sub --> Pow --> ReduceMean --> Optional[Add] --> Sqrt --> Div --> Mul --> Add
| ^
| |
+------------------------------------------------+
@@ -73,7 +87,7 @@ When using Apex O2, a Cast node may be inserted between Div and Mul, Layer Norma
+---------------------+
| |
| v
X --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Cast --> Mul --> Add
X --> ReduceMean --> Sub --> Pow --> ReduceMean --> Optional[Add] --> Sqrt --> Div --> Cast --> Mul --> Add
| ^
| |
+-----------------------------------------------+
@@ -83,14 +97,18 @@ OR
+---------------------+
| |
| v
X --> Cast --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Cast --> Mul --> Add
X --> Cast --> ReduceMean --> Sub --> Pow --> ReduceMean --> Optional[Add] --> Sqrt --> Div --> Cast --> Mul --> Add
| ^
| |
+-----------------------------------------------+
Since LayerNormalization supports input and scale/bias in different data types, and during kernel execution
the data is cast to float/double for precision, any Cast op in the sub-graph can be removed.
Such a Cast op can be at the input of the sub-graph, or between the Div and Mul nodes.
Optional[Add]: In a mixed precision training setting, if epsilon is less than 1e-5 (the default value), downcasting it from fp32 to fp16
can set the epsilon to 0. Consequently, the Add node is removed during round 1 of optimizations (since adding 0 is a no-op).
Thus, we account for this removal by making the Add node check optional when performing mixed precision training.
*/
Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
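To illustrate the Optional[Add] note above (not part of this commit): a minimal sketch, assuming a C++23 toolchain that provides std::float16_t, showing how a sufficiently small fp32 epsilon flushes to 0 when downcast to fp16: values below fp16's smallest subnormal (about 6e-8) round to zero, which is why the add-epsilon node can vanish in earlier optimization rounds.

#include <cstdio>
#include <stdfloat>

int main() {
  const float eps_fp32 = 1e-12f;  // epsilon as stored in the fp32 graph
  const std::float16_t eps_fp16 = static_cast<std::float16_t>(eps_fp32);  // mixed-precision downcast
  std::printf("fp32 epsilon: %g\n", static_cast<double>(eps_fp32));
  std::printf("fp16 epsilon: %g\n", static_cast<double>(eps_fp16));  // prints 0
  return 0;
}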
@@ -254,17 +272,26 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,

// Traceback the sqrt node to find add --> sqrt
Node& add2_node = *graph.GetNode(sqrt_node.InputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7, 13, 14}) ||
add2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, add2_node, 1) ||
!IsSupportedDataType(add2_node)) {
const Node* p_reduce_mean2 = nullptr;
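// The producer of Sqrt is normally Add(variance, epsilon). In mixed precision training the
// downcast epsilon may be 0 and the Add may already have been optimized away, in which case
// the producer is the second ReduceMean itself (handled in the else-if branch below).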
if (add2_node.OpType() == "Add") {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add2_node, "Add", {7, 13, 14}) ||
add2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, add2_node, 1) ||
!IsSupportedDataType(add2_node)) {
continue;
}
nodes_to_remove.push_back(add2_node);
// Traceback the add node to find reduceMean --> add

p_reduce_mean2 = graph_utils::FirstParentByType(add2_node, "ReduceMean");
} else if (add2_node.OpType() == "ReduceMean" && IsFP16OutputDataType(add2_node)) {
// The Add node was removed by CommonSubexpressionElimination during mixed precision training;
// allow the fusion to proceed without it.
p_reduce_mean2 = &add2_node;
} else {
continue;
}
nodes_to_remove.push_back(add2_node);
// Traceback the add node to find reduceMean --> add
const Node* p_reduce_mean2 = nullptr;

p_reduce_mean2 = graph_utils::FirstParentByType(add2_node, "ReduceMean");
if (p_reduce_mean2 == nullptr) {
continue;
}
@@ -414,13 +441,18 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
{}, {}, kOnnxDomain);

// Get constant "epsilon" from "Add2" node if available. Else, default value will be used.
const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, add2_node.MutableInputDefs()[1]->Name());
if (tensor_proto != nullptr &&
tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
Initializer initializer{*tensor_proto, graph.ModelPath()};
layer_norm_node.AddAttribute("epsilon", initializer.data<float>()[0]);
} else {
// If add2_node was removed, use default value of epsilon
if (add2_node.OpType() == "ReduceMean" && IsFP16OutputDataType(add2_node)) {
layer_norm_node.AddAttribute("epsilon", DEFAULT_LAYERNORM_EPSILON);
} else {
const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, add2_node.MutableInputDefs()[1]->Name());
if (tensor_proto != nullptr &&
tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
Initializer initializer{*tensor_proto, graph.ModelPath()};
layer_norm_node.AddAttribute("epsilon", initializer.data<float>()[0]);
} else {
layer_norm_node.AddAttribute("epsilon", DEFAULT_LAYERNORM_EPSILON);
}
}

// Set stash_type to double if any input is double, default value if float.
