Skip to content

Commit

Permalink
Eager mode: binary ops more complete behavior and testing. (#12293)
Browse files Browse the repository at this point in the history
* Remove hand written add_.Tensor as it can now be generated.

* Generate .out for tensor version of basic math ops. Add.out testing added too.

* Remove sin tests as they are covered by parameterized tests. Also, moved all parameterized tests to the end in their own section.

* Add binary ops tests for tensors. Scalar tests are calling the aten .out which is for tensor.

* Add support for scalar input to add, div, mul, and sub.
  • Loading branch information
WilBrady authored Jul 26, 2022
1 parent 3e014a5 commit de57daa
Show file tree
Hide file tree
Showing 5 changed files with 395 additions and 337 deletions.
28 changes: 14 additions & 14 deletions orttraining/orttraining/eager/opgen/opgen/atenops.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,6 @@ def __init__(self, dY, X):
"selu",
]

for binary_op, onnx_op in {
"add": Add("self", Mul("alpha", "other")),
"sub": Sub("self", Mul("alpha", "other")),
"mul": Mul("self", "other"),
"div": Div("self", "other"),
}.items():
for dtype in ["Tensor", "Scalar"]:
for variant in ["", "_"]:
name = f"aten::{binary_op}{variant}.{dtype}"
if name not in ops:
ops[f"aten::{binary_op}{variant}.{dtype}"] = deepcopy(onnx_op)
type_promotion_ops.append(f"aten::{binary_op}{variant}.{dtype}")

for unary_op in unary_ops_with_out:
ops[f"aten::{unary_op}.out"] = onnx_ops[unary_op]("self")

Expand All @@ -115,6 +102,20 @@ def __init__(self, dY, X):
for unary_op in unary_ops:
ops[f"aten::{unary_op}"] = onnx_ops[unary_op]("self")

for binary_op, onnx_op in {
"add": Add("self", Mul("alpha", "other")),
"sub": Sub("self", Mul("alpha", "other")),
"mul": Mul("self", "other"),
"div": Div("self", "other"),
}.items():
# For Tensor, binary_op.out is used by both binary_op and binary_op_, so we only generate .out.
# From testing and call stacks, it also appears scalar ops fall back to the (Tensor) binary_op.out,
# so this is all we need.
name = f"aten::{binary_op}.out"
if name not in ops:
ops[f"aten::{binary_op}.out"] = deepcopy(onnx_op)
type_promotion_ops.append(f"aten::{binary_op}.out")

# Notes on Onnx op mapping
#
# Equal - Onnx spec has the return as a bool tensor, but aten will keep the tensor
Expand All @@ -136,7 +137,6 @@ def __init__(self, dY, X):
# manually implement Slice using stride and offset.
"aten::slice.Tensor": SignatureOnly(),
"aten::addmm": Gemm("mat1", "mat2", "self", alpha="alpha", beta="beta"),
"aten::add_.Tensor": SignatureOnly(),
"aten::t": Transpose("self"),
# MatMul("self", "mat2"), fails since it resizes based on self but should be based on result shape of the mult
"aten::mm.out": SignatureOnly(),
Expand Down
139 changes: 83 additions & 56 deletions orttraining/orttraining/eager/opgen/opgen/generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

# pylint: disable=missing-docstring, too-many-public-methods, too-many-nested-blocks

import json
import sys
from typing import Dict, List, Optional, Union
Expand Down Expand Up @@ -264,53 +266,7 @@ def _write_function_body(self, writer: writer.SourceWriter, mapped_func: MappedO
writer.write(first_param.identifier.value)
writer.writeline(".size()>0);")

# generate the type check
need_type_check = False
if not self._custom_ops:
for onnx_op_index, onnx_op in enumerate(ctx.ops):
for op_input in onnx_op.inputs:
if not isinstance(op_input, Outputs):
need_type_check = True
break
if need_type_check:
writer.write("if (")
i = 0
for onnx_op_index, onnx_op in enumerate(ctx.ops):
for idx, op_input in enumerate(onnx_op.inputs):
if isinstance(op_input, Outputs):
continue
writer.writeline(" || " if i > 0 else "")
if i == 0:
writer.push_indent()
cpp_param = cpp_func.get_parameter(op_input)
supported_types = ",".join(sorted([type for type in onnx_op.input_types[idx]]))
writer.write(f"!IsSupportedType({cpp_param.identifier.value}, {{{supported_types}}})")
i += 1
writer.writeline(") {")
self._write_cpu_fall_back(writer, mapped_func)
writer.pop_indent()
writer.writeline("}")

if (
not isinstance(first_param.parameter_type.desugar(), ast.ConcreteType)
or "Tensor" not in first_param.parameter_type.desugar().identifier_tokens[0].value
):
raise FunctionGenerationError(cpp_func, "First parameter must be an at::Tensor")

writer.write("auto& invoker = GetORTInvoker(")
writer.write(first_param.identifier.value)
if first_param.parameter_type.desugar().identifier_tokens[0].value == "TensorList":
writer.write("[0]")
writer.writeline(".device());")
writer.writeline()

# FIXME: warn if we have not consumed all torch parameters (either as
# an ORT input or ORT attribute).

# Perform kernel fission on the ATen op to yield a chain of ORT Invokes
# e.g. aten::add(x, y, α) -> onnx::Add(x, onnx::Mul(α, y))

# whether need type promotion
# check whether type promotion is needed; if it is, we use this later to confirm the cast to out is supported.
need_type_promotion = False
if mapped_func.mapped_op_name in self.type_promotion_ops:
types_from_tensor = []
Expand Down Expand Up @@ -338,7 +294,7 @@ def _write_function_body(self, writer: writer.SourceWriter, mapped_func: MappedO
return_info = cpp_func.torch_func.return_type if cpp_func.torch_func else None

# if the torch func has a return ref tensor, out is the last param, and self param is the first input
# then we need to update and return out.
# then we need to update and return out. Record this need in set_out_tensor.
# TODO: make this more general to handle cases where the first param is not self such as
# - cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
# - complex.out(Tensor real, Tensor imag, *, Tensor(a!) out) -> Tensor(a!)
Expand All @@ -350,7 +306,6 @@ def _write_function_body(self, writer: writer.SourceWriter, mapped_func: MappedO
and last_param.identifier.value == "out"
and first_param.identifier.value == "self"
):

# output_alias is how the return tensor is marked, normally (a! -> a)
output_alias = self._get_alias_info(return_info)

Expand All @@ -359,12 +314,66 @@ def _write_function_body(self, writer: writer.SourceWriter, mapped_func: MappedO
# and the current param (torch_p/last_param) is marked with the output alias
if output_alias and output_alias.is_writable and self._get_alias_info(torch_p) == output_alias:
set_out_tensor = True
writer.writeline("// resize the output and then create output ort value to be updated.")
writer.writeline(
"resize_output(invoker, dynamic_cast<ORTTensorImpl*>(out.unsafeGetTensorImpl()), self.sizes());"
)
writer.writeline("auto ort_input_out = create_ort_value(invoker, out);")
writer.writeline()

# generate the type check
need_type_check = False
cast_op_found = False
if not self._custom_ops:
for onnx_op_index, onnx_op in enumerate(ctx.ops):
for op_input in onnx_op.inputs:
if not isinstance(op_input, Outputs):
need_type_check = True
break
if need_type_check:
writer.write("if (")
i = 0
for onnx_op_index, onnx_op in enumerate(ctx.ops):
# track if the CAST op was explicitly used
if onnx_op.name == "Cast":
cast_op_found = True
for idx, op_input in enumerate(onnx_op.inputs):
if isinstance(op_input, Outputs):
continue
writer.writeline(" || " if i > 0 else "")
if i == 0:
writer.push_indent()
cpp_param = cpp_func.get_parameter(op_input)
supported_types = ",".join(sorted(list(onnx_op.input_types[idx])))
writer.write(f"!IsSupportedType({cpp_param.identifier.value}, {{{supported_types}}})")
i += 1
# if we have type promotion, need to set the out tensor, and the CAST op is not explicitly listed,
# then we confirm the promotion type is castable to the out type.
if need_type_promotion and set_out_tensor and not cast_op_found:
writer.writeline(" || ")
writer.write("!c10::canCast(*promoted_type, out.scalar_type())")
writer.writeline(") {")
self._write_cpu_fall_back(writer, mapped_func)
writer.pop_indent()
writer.writeline("}")

if (
not isinstance(first_param.parameter_type.desugar(), ast.ConcreteType)
or "Tensor" not in first_param.parameter_type.desugar().identifier_tokens[0].value
):
raise FunctionGenerationError(cpp_func, "First parameter must be an at::Tensor")

writer.write("auto& invoker = GetORTInvoker(")
writer.write(first_param.identifier.value)
if first_param.parameter_type.desugar().identifier_tokens[0].value == "TensorList":
writer.write("[0]")
writer.writeline(".device());")
writer.writeline()

# FIXME: warn if we have not consumed all torch parameters (either as
# an ORT input or ORT attribute).

if set_out_tensor:
writer.writeline("// resize the output and then create output ort value to be updated.")
writer.writeline(
"resize_output(invoker, dynamic_cast<ORTTensorImpl*>(out.unsafeGetTensorImpl()), self.sizes());"
)
writer.writeline("auto ort_input_out = create_ort_value(invoker, out);")
writer.writeline()

for onnx_op_index, onnx_op in enumerate(ctx.ops):
# Torch -> ORT inputs
Expand Down Expand Up @@ -467,7 +476,16 @@ def _write_function_body(self, writer: writer.SourceWriter, mapped_func: MappedO
# if no in_place_params found and there is an out input to set
# and this is the last onnx op, we set the out to be written to
if len(in_place_params) == 0 and set_out_tensor and onnx_op_index == (len(ctx.ops) - 1):
writer.writeline(f"{onnx_op.outputs}[0] = ort_input_out;")
# if we have type promotion, need to set the out tensor, and the CAST op is not explicitly listed,
# check if we need to do a cast
if need_type_promotion and not cast_op_found:
writer.writeline("if (*promoted_type == out.scalar_type()) {")
writer.push_indent()
writer.writeline(f"{onnx_op.outputs}[0] = ort_input_out;")
writer.pop_indent()
writer.writeline("}")
else:
writer.writeline(f"{onnx_op.outputs}[0] = ort_input_out;")

if len(in_place_params) != 0 and len(in_place_params) != (
len(return_info.elements) if isinstance(return_info, ast.TupleType) else 1
Expand Down Expand Up @@ -511,6 +529,15 @@ def _write_function_body(self, writer: writer.SourceWriter, mapped_func: MappedO
elif len(in_place_params) == 0:
# tensor options
if set_out_tensor:
if need_type_promotion and not cast_op_found:
writer.writeline("if (*promoted_type != out.scalar_type()) {")
writer.push_indent()
writer.writeline(
f"CastToType_out(invoker, {onnx_op.outputs}[0], ort_input_out, out.scalar_type());"
)
writer.pop_indent()
writer.writeline("}")

writer.writeline(f"return {last_param.identifier.value};")
return

Expand Down
96 changes: 23 additions & 73 deletions orttraining/orttraining/eager/ort_aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,20 +367,32 @@ c10::optional<at::ScalarType> PromoteScalarTypesWithCategory(

OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at::ScalarType type) {
std::vector<OrtValue> output(1);
NodeAttributes attrs(2);
NodeAttributes attrs(1);
attrs["to"] = create_ort_attribute(
"to", GetONNXTensorProtoDataType(type), at::ScalarType::Long);

auto status = invoker.Invoke("Cast",
{std::move(input)},
output, &attrs);

if (!status.IsOK())
throw std::runtime_error(
"ORT return failure status:" + status.ErrorMessage());
CHECK_STATUS(status);
return output[0];
}

/*
 * Runs the ONNX "Cast" op on `input`, supplying `output` to the invoker as the
 * pre-populated result slot.
 *
 * NOTE(review): presumably Invoke writes the cast result into the tensor held by
 * `output` because the slot is seeded with a copy of that OrtValue -- confirm
 * against ORTInvoker::Invoke.
 *
 * @param invoker ORT invoker used to execute the Cast op.
 * @param input   value to cast; taken by const reference and not modified here.
 * @param output  value placed in the single result slot passed to Invoke.
 * @param type    target scalar type, encoded into the "to" attribute.
 *
 * Failure of the Invoke call is handled by CHECK_STATUS.
 */
void CastToType_out(onnxruntime::ORTInvoker& invoker, const OrtValue& input, OrtValue& output, at::ScalarType type) {
  // Seed the single result slot with the caller-provided output value.
  std::vector<OrtValue> output_result(1);
  output_result[0] = output;

  NodeAttributes attrs(1);
  attrs["to"] = create_ort_attribute(
      "to", GetONNXTensorProtoDataType(type), at::ScalarType::Long);

  // `input` is a const reference, so std::move on it could only ever copy;
  // pass it directly to avoid implying ownership is transferred.
  auto status = invoker.Invoke("Cast",
                               {input},
                               output_result, &attrs);

  CHECK_STATUS(status);
}

/*
* Utility method to calculate the resulting shape of tensor after a reduction operation.
*
Expand Down Expand Up @@ -662,9 +674,7 @@ at::Tensor& copy_(
{std::move(ort_src)},
ort_cast_output, &attrs);

if (!status.IsOK())
throw std::runtime_error(
"ORT return failure status:" + status.ErrorMessage());
CHECK_STATUS(status);

copy(invoker, ort_cast_output[0], ort_self);
} else {
Expand Down Expand Up @@ -709,57 +719,7 @@ at::Tensor& zero_(at::Tensor& self) {
{std::move(ort_in_self), std::move(flag_val)},
ort_out, nullptr, onnxruntime::kMSDomain, 1);

if (!status.IsOK())
throw std::runtime_error(
"ORT return failure status:" + status.ErrorMessage());

return self;
}

// TODO(unknown): enhance opgen.py to support inplace binary operations.
// aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
at::Tensor& add__Tensor(
at::Tensor& self,
const at::Tensor& other,
const at::Scalar& alpha) {
ORT_LOG_FN(self, other, alpha);

auto st = {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16};
if (
!IsSupportedType(alpha, st) ||
!IsSupportedType(other, st) ||
!IsSupportedType(self, st)) {
return at::native::call_fallback_fn<
&at::native::cpu_fallback,
ATEN_OP(add__Tensor)>::call(self, other, alpha);
}
auto& invoker = GetORTInvoker(self.device());

auto ort_input_alpha = create_ort_value(invoker, alpha, other.scalar_type());
auto ort_input_other = create_ort_value(invoker, other);

std::vector<OrtValue> ort_outputs_0_Mul(1);

auto status = invoker.Invoke("Mul",
{std::move(ort_input_alpha), std::move(ort_input_other)},
ort_outputs_0_Mul, nullptr);

if (!status.IsOK())
throw std::runtime_error(
"ORT return failure status:" + status.ErrorMessage());

auto ort_input_self = create_ort_value(invoker, self);

std::vector<OrtValue> ort_outputs_1_Add(1);
ort_outputs_1_Add[0] = ort_input_self;

status = invoker.Invoke("Add",
{std::move(ort_input_self), std::move(ort_outputs_0_Mul[0])},
ort_outputs_1_Add, nullptr);

if (!status.IsOK())
throw std::runtime_error(
"ORT return failure status:" + status.ErrorMessage());
CHECK_STATUS(status);

return self;
}
Expand Down Expand Up @@ -864,9 +824,7 @@ at::Tensor& argmax_out(
{std::move(ort_input_self)},
ort_outputs_0_ArgMax, &attrs);

if (!status.IsOK())
throw std::runtime_error(
"ORT return failure status:" + status.ErrorMessage());
CHECK_STATUS(status);

return out;
}
Expand Down Expand Up @@ -909,9 +867,7 @@ bool equal(
{std::move(ort_input_self), std::move(ort_input_other)},
ort_outputs_0_Equal, nullptr);

if (!equalStatus.IsOK())
throw std::runtime_error(
"ORT Equal return failure status:" + equalStatus.ErrorMessage());
CHECK_STATUS(equalStatus);

// now reduce the resulting tensor of bool values to its minimum value (any false)
NodeAttributes attrs(1);
Expand All @@ -928,9 +884,7 @@ bool equal(
{std::move(equalAsInt)},
ort_outputs_0_ReduceMin, &attrs);

if (!reduceStatus.IsOK())
throw std::runtime_error(
"ORT ReduceMin return failure reduceStatus:" + reduceStatus.ErrorMessage());
CHECK_STATUS(reduceStatus);

auto* ort_tensor = ort_outputs_0_ReduceMin[0].GetMutable<onnxruntime::Tensor>();
// the first (and only) value of the tensor will be 0 for false else true
Expand Down Expand Up @@ -983,9 +937,7 @@ at::Tensor& fill__Scalar(
{std::move(ort_input_self)},
ort_outputs_0_Shape, nullptr);

if (!status.IsOK())
throw std::runtime_error(
"ORT return failure status:" + status.ErrorMessage());
CHECK_STATUS(status);

std::vector<OrtValue> ort_outputs_1_ConstantOfShape(1);
ort_outputs_1_ConstantOfShape[0] = ort_input_self;
Expand All @@ -998,9 +950,7 @@ at::Tensor& fill__Scalar(
{std::move(ort_outputs_0_Shape[0])},
ort_outputs_1_ConstantOfShape, &attrs);

if (!status.IsOK())
throw std::runtime_error(
"ORT return failure status:" + status.ErrorMessage());
CHECK_STATUS(status);

return self;
}
Expand Down
Loading

0 comments on commit de57daa

Please sign in to comment.