cpplint & Eager mode: refactor and add comments to empty_* functions, general lint cleanup in ort_aten #12238

Merged: 6 commits, Jul 20, 2022
7 changes: 6 additions & 1 deletion .vscode/settings.json
@@ -34,5 +34,10 @@
"python.linting.pydocstyleArgs": [
"--convention=google"
],
"python.linting.banditEnabled": true
"python.linting.banditEnabled": true,
"cpplint.lineLength": 120,
Member: Is there documentation for the cpplint settings? Does this apply to a particular extension?
"cpplint.filters": [
"-build/include_subdir",
"-runtime/references"
]
}
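
For context on the cpplint settings: these keys appear to be read by the cpplint extension for VS Code, with the line limit raised to 120 columns. The two disabled filters correspond to standard cpplint checks; the sketch below is illustrative (not from this PR) and shows what each filter would otherwise flag.

#include <vector>

// -runtime/references: cpplint normally warns on mutable reference
// parameters ("make const or use a pointer"). ort_aten.cpp passes
// onnxruntime::ORTInvoker& mutably throughout, so the check is disabled.
void Accumulate(std::vector<int>& totals, int value) {  // OK with the filter off
  totals.push_back(value);
}

// -build/include_subdir: cpplint normally requires directory-qualified
// project includes (e.g. #include "eager/ort_aten.h" rather than
// #include "ort_aten.h"); disabling it permits the bare include form.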
143 changes: 71 additions & 72 deletions orttraining/orttraining/eager/ort_aten.cpp
@@ -10,11 +10,14 @@
#include <c10/util/irange.h>
#include <ATen/WrapDimUtils.h>

#include <algorithm>
#include <vector>
#include <utility>

namespace torch_ort {
namespace eager {

//#pragma region Helpers
// #pragma region Helpers
using NodeAttributes = onnxruntime::NodeAttributes;
namespace {
inline bool is_device_supported(at::DeviceType type) {
@@ -34,7 +37,7 @@ namespace {
throw std::runtime_error("ORT copy: device not supported");
}
}
}
} // namespace

at::Tensor aten_tensor_from_ort(
OrtValue&& ot,
@@ -59,7 +62,7 @@ const std::vector<at::Tensor> aten_tensor_from_ort(

onnxruntime::MLDataType ort_scalar_type_from_aten(
at::ScalarType dtype) {
switch (dtype){
switch (dtype) {
case at::kFloat:
return onnxruntime::DataTypeImpl::GetType<float>();
case at::kDouble:
@@ -107,7 +110,7 @@ OrtValue create_ort_value(
break;
}
default:
// TODO: support more types
// TODO(unknown): support more types
// For most at::ScalarType, it should be safe to just call value.to<>
// on it, but for now we want to explicitly know when we've encountered
// a new scalar type while bringing up ORT eager mode.
@@ -131,13 +134,17 @@ OrtValue create_ort_value(
auto element_type = ort_scalar_type_from_aten(tensor.scalar_type());

OrtValue ort_tensor;
onnxruntime::Tensor::InitOrtValue(element_type, onnxruntime::TensorShape(tensor.sizes().vec()), tensor.data_ptr(),
*mem_info, ort_tensor, 0L /* offset = 0 - because tensor.data_ptr() includes the underyling offset */,
tensor.strides().vec());
onnxruntime::Tensor::InitOrtValue(
element_type,
onnxruntime::TensorShape(tensor.sizes().vec()),
tensor.data_ptr(),
*mem_info, ort_tensor,
0L, // offset = 0 - because tensor.data_ptr() includes the underlying offset
Member: nit: should we fix up spelling here while we are tidying this file?
tensor.strides().vec());
return ort_tensor;
}
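
A note on the 0L offset above: for a PyTorch view, tensor.data_ptr() is already advanced past the storage offset, so the wrapped ORT tensor needs no additional offset. A small illustration using only public ATen APIs (nothing here is specific to this PR):

#include <ATen/ATen.h>
#include <cassert>

void offset_demo() {
  auto base = at::arange(12).reshape({3, 4});      // 12 int64 elements
  auto view = base.slice(/*dim=*/0, /*start=*/1);  // rows 1..2, storage_offset == 4
  // data_ptr() already accounts for the view's offset, which is why
  // InitOrtValue above can always be called with offset = 0.
  assert(view.data_ptr<int64_t>() == base.data_ptr<int64_t>() + 4);
}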

OrtValue create_ort_value(const at::Tensor& tensor){
OrtValue create_ort_value(const at::Tensor& tensor) {
auto& invoker = GetORTInvoker(tensor.device());
return create_ort_value(invoker, tensor);
}
@@ -146,7 +153,7 @@ std::vector<OrtValue> create_ort_value(
onnxruntime::ORTInvoker& invoker,
at::TensorList values) {
auto output = std::vector<OrtValue>{};
for (auto element: values){
for (auto element : values) {
output.push_back(create_ort_value(element));
}
return output;
@@ -157,7 +164,7 @@ onnx::AttributeProto create_ort_attribute(
at::Scalar value,
const bool isTensor,
at::ScalarType type) {
if (isTensor){
if (isTensor) {
onnx::AttributeProto attr;
attr.set_name(name);
attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR);
@@ -190,8 +197,7 @@ onnx::AttributeProto create_ort_attribute(
ORT_THROW("Unsupported: at::ScalarType::", value.type());
}
return attr;
}
else{
} else {
return create_ort_attribute(name, value, value.type());
}
}
@@ -254,33 +260,33 @@ onnx::AttributeProto create_ort_attribute(
return attr;
}

bool IsSupportedType(at::Scalar scalar, const std::vector<at::ScalarType>& valid_types){
bool IsSupportedType(at::Scalar scalar, const std::vector<at::ScalarType>& valid_types) {
return std::find(valid_types.begin(), valid_types.end(), scalar.type()) != valid_types.end();
}

bool IsSupportedType(at::Tensor tensor, const std::vector<at::ScalarType>& valid_types){
bool IsSupportedType(at::Tensor tensor, const std::vector<at::ScalarType>& valid_types) {
return std::find(valid_types.begin(), valid_types.end(), tensor.scalar_type()) != valid_types.end();
}

bool IsSupportedType(at::IntArrayRef array, const std::vector<at::ScalarType>& valid_types){
bool IsSupportedType(at::IntArrayRef array, const std::vector<at::ScalarType>& valid_types) {
return std::find(valid_types.begin(), valid_types.end(), at::kInt) != valid_types.end() ||
std::find(valid_types.begin(), valid_types.end(), at::kLong) != valid_types.end();
}

bool IsSupportedType(int64_t val, const std::vector<at::ScalarType>& valid_types){
bool IsSupportedType(int64_t val, const std::vector<at::ScalarType>& valid_types) {
return std::find(valid_types.begin(), valid_types.end(), at::kLong) != valid_types.end();
}

bool IsSupportedType(c10::optional<int64_t> val, const std::vector<at::ScalarType>& valid_types){
bool IsSupportedType(c10::optional<int64_t> val, const std::vector<at::ScalarType>& valid_types) {
return IsSupportedType(val.value(), valid_types);
}

bool IsSupportedType(at::TensorList tensors, const std::vector<at::ScalarType>& valid_types){
bool IsSupportedType(at::TensorList tensors, const std::vector<at::ScalarType>& valid_types) {
return IsSupportedType(tensors[0], valid_types);
}
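
A note on how these guards are called below: the refactored ops bind the supported-type list as auto st = { ... }, which deduces std::initializer_list<at::ScalarType>; each IsSupportedType call then implicitly materializes a std::vector for the const-reference parameter. A minimal standalone illustration of the mechanism (stand-in enum, not the real ATen types):

#include <algorithm>
#include <initializer_list>
#include <vector>

enum class ScalarType { kFloat, kHalf, kInt };  // stand-in for at::ScalarType

bool IsSupported(ScalarType t, const std::vector<ScalarType>& valid) {
  return std::find(valid.begin(), valid.end(), t) != valid.end();
}

int main() {
  auto st = {ScalarType::kFloat, ScalarType::kHalf};  // initializer_list
  // Each call converts the list to a std::vector, mirroring how
  // IsSupportedType(self, st) is used in add__Tensor and argmax_out below.
  return IsSupported(ScalarType::kInt, st) ? 1 : 0;
}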

ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype){
switch (dtype){
ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype) {
switch (dtype) {
case at::kFloat:
return ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
case at::kDouble:
@@ -349,7 +355,7 @@ c10::optional<at::ScalarType> PromoteScalarTypesWithCategory(
return typeFromTensor;
}

OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at::ScalarType type){
OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at::ScalarType type) {
std::vector<OrtValue> output(1);
NodeAttributes attrs(2);
attrs["to"] = create_ort_attribute(
@@ -425,7 +431,7 @@ void resize_output(
resize_impl_ort_(invoker, output, shape);
}

//#pragma endregion
// #pragma endregion

/*
* Resize backing store of a TensorImpl.
@@ -530,52 +536,44 @@ void resize_impl_ort_(
return;
}

//#pragma region Hand-Implemented ATen Ops
// #pragma region Hand-Implemented ATen Ops

namespace aten {

at::Tensor empty_memory_format(
at::Tensor empty_strided(
at::IntArrayRef size,
// *,
at::IntArrayRef stride,
c10::optional<at::ScalarType> dtype_opt,
c10::optional<at::Layout> layout_opt,
c10::optional<at::Device> device_opt,
c10::optional<bool> pin_memory,
c10::optional<at::MemoryFormat> memory_format) {
ORT_LOG_FN(size, dtype_opt, layout_opt, device_opt, pin_memory, memory_format);

assert(dtype_opt.has_value());
assert(device_opt.has_value());
c10::optional<at::Layout> layout_opt, // Ignored because there's no ONNX support.
c10::optional<at::Device> device_opt, // Will be ORT by the time this is dispatched.
c10::optional<bool> pin_memory_opt) { // Ignored because there's no ONNX support.
ORT_LOG_FN(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);

// TODO: validate options and memory format
// TODO: figure out how to get the correct element type.
OrtValue ot;
assert(device_opt.has_value());
at::ScalarType dtype = c10::dtype_or_default(dtype_opt);
auto& invoker = GetORTInvoker(*device_opt);
onnxruntime::Tensor::InitOrtValue(ort_scalar_type_from_aten(*dtype_opt), onnxruntime::TensorShape(size.vec()),
invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ot);
onnxruntime::Tensor::InitOrtValue(ort_scalar_type_from_aten(dtype), onnxruntime::TensorShape(size.vec()),
invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ot,
stride.vec());
return aten_tensor_from_ort(
std::move(ot),
at::TensorOptions()
.device(*device_opt)
.dtype(*dtype_opt));
.dtype(dtype));
}

at::Tensor empty_strided(at::IntArrayRef size, at::IntArrayRef stride, c10::optional<at::ScalarType> dtype_opt,
c10::optional<at::Layout> layout_opt, c10::optional<at::Device> device_opt,
c10::optional<bool> pin_memory_opt) {
ORT_LOG_FN(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
at::Tensor empty_memory_format(
at::IntArrayRef size,
c10::optional<at::ScalarType> dtype_opt,
c10::optional<at::Layout> layout_opt,
c10::optional<at::Device> device_opt,
c10::optional<bool> pin_memory,
c10::optional<at::MemoryFormat> memory_format) { // Ignored because there's no ONNX support.
ORT_LOG_FN(size, dtype_opt, layout_opt, device_opt, pin_memory, memory_format);

// TODO: how to handle type conversion
OrtValue ot;
assert(device_opt.has_value());
// TODO: how to support layout
// assert(!layout_opt.has_value());
at::ScalarType dtype = c10::dtype_or_default(dtype_opt);
auto& invoker = GetORTInvoker(*device_opt);
onnxruntime::Tensor::InitOrtValue(ort_scalar_type_from_aten(dtype), onnxruntime::TensorShape(size.vec()),
invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ot,
stride.vec());
return aten_tensor_from_ort(std::move(ot), at::TensorOptions().device(*device_opt).dtype(dtype));
// Use the strided impl with default (no strides specified).
return empty_strided(size, at::IntArrayRef({}), dtype_opt, layout_opt, device_opt, pin_memory);
}
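
With this refactor, empty_memory_format is a thin wrapper over empty_strided, so both PyTorch entry points share one ORT allocation path. A hedged usage sketch at the libtorch level (assumption: the eager ORT backend is loaded and registers the "ort" device type, matching the Python tests in this PR):

#include <torch/torch.h>

int main() {
  torch::Device ort_device("ort");  // assumed device string for torch_ort eager
  // Dispatches to empty_memory_format above, which forwards to empty_strided.
  auto a = torch::empty({3, 4}, torch::TensorOptions().device(ort_device));
  // Dispatches to empty_strided directly, with explicit contiguous strides.
  auto b = torch::empty_strided({3, 4}, {4, 1},
                                torch::TensorOptions().device(ort_device));
  return (a.sizes() == b.sizes()) ? 0 : 1;
}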

// aten::as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)
@@ -602,9 +600,9 @@ at::Tensor as_strided(
at::Tensor _reshape_alias(
const at::Tensor& self,
at::IntArrayRef size,
at::IntArrayRef stride){
at::IntArrayRef stride) {
ORT_LOG_FN(self, size, stride);
// TODO: support stride
// TODO(unknown): support stride
auto& invoker = GetORTInvoker(self.device());
auto ort_input = create_ort_value(invoker, self);
return aten_tensor_from_ort(
@@ -645,7 +643,7 @@ at::Tensor& copy_(
: src.device());
const auto ort_src = create_ort_value(invoker, src);
auto ort_self = create_ort_value(invoker, self);
if (self.scalar_type() != src.scalar_type()){
if (self.scalar_type() != src.scalar_type()) {
// invoke cast first
std::vector<OrtValue> ort_cast_output(1);
onnxruntime::NodeAttributes attrs(1);
@@ -661,8 +659,7 @@ at::Tensor& copy_(
"ORT return failure status:" + status.ErrorMessage());

copy(invoker, ort_cast_output[0], ort_self);
}
else{
} else {
copy(invoker, ort_src, ort_self);
}

@@ -671,7 +668,7 @@

at::Tensor _copy_from_and_resize(
const at::Tensor& self,
const at::Tensor& dst){
const at::Tensor& dst) {
ORT_LOG_FN(self, dst);

assert_tensor_supported(self);
@@ -688,11 +685,11 @@ at::Tensor _copy_from_and_resize(
return self;
}

at::Tensor& zero_(at::Tensor& self){
at::Tensor& zero_(at::Tensor& self) {
auto& invoker = GetORTInvoker(self.device());
auto ort_in_self = create_ort_value(invoker, self);
OrtValue flag_val;
//construct a constant tensor
// construct a constant tensor
auto element_type = onnxruntime::DataTypeImpl::GetType<int64_t>();
onnxruntime::Tensor::InitOrtValue(element_type, onnxruntime::TensorShape({}),
invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), flag_val);
@@ -715,18 +712,19 @@ at::Tensor& zero_(at::Tensor& self){
return self;
}
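
The Invoke call that consumes flag_val is collapsed in this view. As a hedged sketch of the pattern, here is how a scalar int64 constant OrtValue can be built and populated with the same ORT internals used above (InitOrtValue, GetMutable, and MutableData are existing onnxruntime APIs; the helper name is hypothetical and the flag's meaning for the collapsed op is not taken from this diff):

OrtValue MakeScalarFlag(onnxruntime::AllocatorPtr alloc, int64_t value) {
  OrtValue flag;
  auto element_type = onnxruntime::DataTypeImpl::GetType<int64_t>();
  onnxruntime::Tensor::InitOrtValue(element_type, onnxruntime::TensorShape({}),
                                    alloc, flag);
  // InitOrtValue allocates the scalar; MutableData exposes its payload.
  *flag.GetMutable<onnxruntime::Tensor>()->MutableData<int64_t>() = value;
  return flag;
}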

// TODO: enhance opgen.py to support inplace binary operations.
// TODO(unknown): enhance opgen.py to support inplace binary operations.
// aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
at::Tensor& add__Tensor(
at::Tensor& self,
const at::Tensor& other,
const at::Scalar& alpha) {
ORT_LOG_FN(self, other, alpha);

auto st = {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16};
if (
!IsSupportedType(alpha, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16}) ||
!IsSupportedType(other, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16}) ||
!IsSupportedType(self, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16})) {
!IsSupportedType(alpha, st) ||
!IsSupportedType(other, st) ||
!IsSupportedType(self, st)) {
return at::native::call_fallback_fn<
&at::native::cpu_fallback,
ATEN_OP(add__Tensor)>::call(self, other, alpha);
Expand Down Expand Up @@ -827,8 +825,9 @@ bool keepdim,
at::Tensor& out) {
ORT_LOG_FN(self, dim, keepdim, out);

auto st = {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16};
if (
!IsSupportedType(self, {at::kLong, at::kShort, at::kHalf, at::kBFloat16, at::kFloat, at::kByte, at::kInt, at::kDouble})) {
!IsSupportedType(self, st)) {
return at::native::call_fallback_fn<
&at::native::cpu_fallback,
ATEN_OP(argmax_out)>::call(self, dim, keepdim, out);
@@ -1034,7 +1033,7 @@ at::Tensor& _log_softmax_out(
ORT_LOG_FN(self, dim, half_to_float, out);

if (
!IsSupportedType(self, {at::kBFloat16,at::kDouble,at::kFloat,at::kHalf})) {
!IsSupportedType(self, {at::kBFloat16, at::kDouble, at::kFloat, at::kHalf})) {
return at::native::call_fallback_fn<
&at::native::cpu_fallback,
ATEN_OP(_log_softmax_out)>::call(self, dim, half_to_float, out);
@@ -1096,7 +1095,7 @@ at::Tensor& _log_softmax_out(
ort_outputs_2_Transpose[0] = ort_input_out;

NodeAttributes attrs_2(1);
attrs_2["perm"] = create_ort_attribute("perm", axes);;
attrs_2["perm"] = create_ort_attribute("perm", axes);

status = invoker.Invoke("Transpose", {
std::move(ort_outputs_1_LogSoftmax[0]),
@@ -1165,9 +1164,9 @@ at::Tensor& mm_out(
}


} // namespace aten
} // namespace aten

//#pragma endregion
// #pragma endregion

} // namespace eager
} // namespace torch_ort
} // namespace eager
} // namespace torch_ort
7 changes: 7 additions & 0 deletions orttraining/orttraining/eager/test/ort_ops.py
@@ -224,6 +224,13 @@ def test_zero_stride(self):
cpu_tensor_copied = ort_tensor.cpu()
assert cpu_tensor_copied.stride() == (0, 0, 0)

def test_empty(self):
device = self.get_device()
cpu_tensor = torch.empty(size=(3, 4))
ort_tensor = torch.empty(size=(3, 4), device=device)
assert ort_tensor.is_ort
assert ort_tensor.size() == cpu_tensor.size()

def test_softmax(self):
device = self.get_device()
cpu_tensor = torch.rand(3, 5)