cpplint & Eager mode: refactor and add comments to empty_* functions, general lint cleanup in ort_aten #12238
Merged
Changes from all commits (6 commits):

- ee009cf: empty* comments and code reuse (msftlincoln)
- 7d9261a: lint (msftlincoln)
- 4a8948c: more cpplint (msftlincoln)
- da9738d: add cpplint settings (msftlincoln)
- 11452a9: test empty (msftlincoln)
- e493cda: Merge remote-tracking branch 'origin/master' into users/msftlincoln/o… (msftlincoln)
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,11 +10,14 @@
 #include <c10/util/irange.h>
 #include <ATen/WrapDimUtils.h>

+#include <algorithm>
+#include <vector>
+#include <utility>
+
 namespace torch_ort {
 namespace eager {

-//#pragma region Helpers
+// #pragma region Helpers
 using NodeAttributes = onnxruntime::NodeAttributes;
 namespace {
 inline bool is_device_supported(at::DeviceType type) {
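Aside: the three added headers line up with cpplint's build/include_what_you_use check; this file uses std::find (from <algorithm>), std::vector (from <vector>), and std::move (from <utility>) but previously relied on transitive includes. A minimal standalone sketch of what each header is buying:

```cpp
#include <algorithm>  // std::find, used by the IsSupportedType overloads below
#include <utility>    // std::move, used when handing OrtValues to aten_tensor_from_ort
#include <vector>     // std::vector, used for type lists and output buffers

#include <iostream>

int main() {
  std::vector<int> valid = {1, 2, 3};                                    // <vector>
  bool found = std::find(valid.begin(), valid.end(), 2) != valid.end();  // <algorithm>
  std::vector<int> moved = std::move(valid);                             // <utility>
  std::cout << found << " " << moved.size() << "\n";                     // prints "1 3"
  return 0;
}
```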
@@ -34,7 +37,7 @@ namespace {
     throw std::runtime_error("ORT copy: device not supported");
   }
 }
-}
+}  // namespace

 at::Tensor aten_tensor_from_ort(
     OrtValue&& ot,
@@ -59,7 +62,7 @@ const std::vector<at::Tensor> aten_tensor_from_ort(

 onnxruntime::MLDataType ort_scalar_type_from_aten(
     at::ScalarType dtype) {
-  switch (dtype){
+  switch (dtype) {
     case at::kFloat:
       return onnxruntime::DataTypeImpl::GetType<float>();
     case at::kDouble:
@@ -107,7 +110,7 @@ OrtValue create_ort_value(
       break;
     }
     default:
-      // TODO: support more types
+      // TODO(unknown): support more types
       // For most at::ScalarType, it should be safe to just call value.to<>
      // on it, but for now we want to explicitly know when we've encountered
      // a new scalar type while bringing up ORT eager mode.
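The TODO edits here and in later hunks satisfy cpplint's readability/todo check, which wants every TODO to name an owner in parentheses; (unknown) is this PR's placeholder when no owner applies. The convention, as cpplint enforces it (the named owner below is illustrative, not from this PR):

```cpp
// Flagged by cpplint [readability/todo]: missing username.
// TODO: support more types

// Accepted: TODO(owner), where owner is a username or a bug reference.
// TODO(unknown): support more types
// TODO(msftlincoln): support stride
```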
@@ -131,13 +134,17 @@ OrtValue create_ort_value(
   auto element_type = ort_scalar_type_from_aten(tensor.scalar_type());

   OrtValue ort_tensor;
-  onnxruntime::Tensor::InitOrtValue(element_type, onnxruntime::TensorShape(tensor.sizes().vec()), tensor.data_ptr(),
-                                    *mem_info, ort_tensor, 0L /* offset = 0 - because tensor.data_ptr() includes the underyling offset */,
-                                    tensor.strides().vec());
+  onnxruntime::Tensor::InitOrtValue(
+      element_type,
+      onnxruntime::TensorShape(tensor.sizes().vec()),
+      tensor.data_ptr(),
+      *mem_info, ort_tensor,
+      0L,  // offset = 0 - because tensor.data_ptr() includes the underyling offset
+      tensor.strides().vec());
   return ort_tensor;
 }

-OrtValue create_ort_value(const at::Tensor& tensor){
+OrtValue create_ort_value(const at::Tensor& tensor) {
   auto& invoker = GetORTInvoker(tensor.device());
   return create_ort_value(invoker, tensor);
 }

Review comment on the new offset line: nit: should we fix up spelling here while we are tidying this file?
@@ -146,7 +153,7 @@ std::vector<OrtValue> create_ort_value(
     onnxruntime::ORTInvoker& invoker,
     at::TensorList values) {
   auto output = std::vector<OrtValue>{};
-  for (auto element: values){
+  for (auto element : values) {
     output.push_back(create_ort_value(element));
   }
   return output;
@@ -157,7 +164,7 @@ onnx::AttributeProto create_ort_attribute(
     at::Scalar value,
     const bool isTensor,
     at::ScalarType type) {
-  if (isTensor){
+  if (isTensor) {
     onnx::AttributeProto attr;
     attr.set_name(name);
     attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR);
@@ -190,8 +197,7 @@ onnx::AttributeProto create_ort_attribute(
       ORT_THROW("Unsupported: at::ScalarType::", value.type());
     }
     return attr;
-  }
-  else{
+  } else {
     return create_ort_attribute(name, value, value.type());
   }
 }
@@ -254,33 +260,33 @@ onnx::AttributeProto create_ort_attribute(
   return attr;
 }

-bool IsSupportedType(at::Scalar scalar, const std::vector<at::ScalarType>& valid_types){
+bool IsSupportedType(at::Scalar scalar, const std::vector<at::ScalarType>& valid_types) {
   return std::find(valid_types.begin(), valid_types.end(), scalar.type()) != valid_types.end();
 }

-bool IsSupportedType(at::Tensor tensor, const std::vector<at::ScalarType>& valid_types){
+bool IsSupportedType(at::Tensor tensor, const std::vector<at::ScalarType>& valid_types) {
   return std::find(valid_types.begin(), valid_types.end(), tensor.scalar_type()) != valid_types.end();
 }

-bool IsSupportedType(at::IntArrayRef arrary, const std::vector<at::ScalarType>& valid_types){
+bool IsSupportedType(at::IntArrayRef arrary, const std::vector<at::ScalarType>& valid_types) {
   return std::find(valid_types.begin(), valid_types.end(), at::kInt) != valid_types.end() ||
          std::find(valid_types.begin(), valid_types.end(), at::kLong) != valid_types.end();
 }

-bool IsSupportedType(int64_t val, const std::vector<at::ScalarType>& valid_types){
+bool IsSupportedType(int64_t val, const std::vector<at::ScalarType>& valid_types) {
   return std::find(valid_types.begin(), valid_types.end(), at::kLong) != valid_types.end();
 }

-bool IsSupportedType(c10::optional<int64_t> val, const std::vector<at::ScalarType>& valid_types){
+bool IsSupportedType(c10::optional<int64_t> val, const std::vector<at::ScalarType>& valid_types) {
   return IsSupportedType(val.value(), valid_types);
 }

-bool IsSupportedType(at::TensorList tensors, const std::vector<at::ScalarType>& valid_types){
+bool IsSupportedType(at::TensorList tensors, const std::vector<at::ScalarType>& valid_types) {
   return IsSupportedType(tensors[0], valid_types);
 }

-ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype){
-  switch (dtype){
+ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype) {
+  switch (dtype) {
   case at::kFloat:
     return ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
   case at::kDouble:
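These overloads give every argument kind that generated operator code might pass a uniform supported-type test, so a guard can call IsSupportedType(x, types) without caring what x is. A compilable sketch of the overload-set pattern; ScalarType and Tensor below are simplified stand-ins for illustration, not the ATen types:

```cpp
#include <algorithm>
#include <iostream>
#include <vector>

// Hypothetical stand-in for at::ScalarType, for illustration only.
enum class ScalarType { Float, Double, Long, Int };

// Same shape as the overloads in this diff: each argument kind gets an
// overload, and all of them reduce to a find() over the valid-type list.
bool IsSupportedType(ScalarType t, const std::vector<ScalarType>& valid_types) {
  return std::find(valid_types.begin(), valid_types.end(), t) != valid_types.end();
}

struct Tensor { ScalarType dtype; };  // stand-in for at::Tensor

bool IsSupportedType(const Tensor& tensor, const std::vector<ScalarType>& valid_types) {
  return IsSupportedType(tensor.dtype, valid_types);
}

int main() {
  const std::vector<ScalarType> valid = {ScalarType::Float, ScalarType::Double};
  Tensor t{ScalarType::Long};
  // Mirrors the fallback gate in add__Tensor below: unsupported input -> CPU fallback.
  std::cout << (IsSupportedType(t, valid) ? "ORT path" : "cpu_fallback") << "\n";
  return 0;
}
```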
@@ -349,7 +355,7 @@ c10::optional<at::ScalarType> PromoteScalarTypesWithCategory(
   return typeFromTensor;
 }

-OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at::ScalarType type){
+OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at::ScalarType type) {
   std::vector<OrtValue> output(1);
   NodeAttributes attrs(2);
   attrs["to"] = create_ort_attribute(
@@ -425,7 +431,7 @@ void resize_output(
   resize_impl_ort_(invoker, output, shape);
 }

-//#pragma endregion
+// #pragma endregion

 /*
  * Resize backing store of a TensorImpl.
@@ -530,52 +536,44 @@ void resize_impl_ort_(
   return;
 }

-//#pragma region Hand-Implemented ATen Ops
+// #pragma region Hand-Implemented ATen Ops

 namespace aten {

-at::Tensor empty_memory_format(
+at::Tensor empty_strided(
   at::IntArrayRef size,
-  // *,
+  at::IntArrayRef stride,
   c10::optional<at::ScalarType> dtype_opt,
-  c10::optional<at::Layout> layout_opt,
-  c10::optional<at::Device> device_opt,
-  c10::optional<bool> pin_memory,
-  c10::optional<at::MemoryFormat> memory_format) {
-  ORT_LOG_FN(size, dtype_opt, layout_opt, device_opt, pin_memory, memory_format);
-
-  assert(dtype_opt.has_value());
-  assert(device_opt.has_value());
+  c10::optional<at::Layout> layout_opt,  // Ignored because there's no ONNX support.
+  c10::optional<at::Device> device_opt,  // Will be ORT by the time this is dispatched.
+  c10::optional<bool> pin_memory_opt) {  // Ignored because there's no ONNX support.
+  ORT_LOG_FN(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);

-  // TODO: validate options and memory format
-  // TODO: figure out how to get the correct element type.
   OrtValue ot;
+  assert(device_opt.has_value());
+  at::ScalarType dtype = c10::dtype_or_default(dtype_opt);
   auto& invoker = GetORTInvoker(*device_opt);
-  onnxruntime::Tensor::InitOrtValue(ort_scalar_type_from_aten(*dtype_opt), onnxruntime::TensorShape(size.vec()),
-                                    invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ot);
+  onnxruntime::Tensor::InitOrtValue(ort_scalar_type_from_aten(dtype), onnxruntime::TensorShape(size.vec()),
+                                    invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ot,
+                                    stride.vec());
   return aten_tensor_from_ort(
     std::move(ot),
     at::TensorOptions()
       .device(*device_opt)
-      .dtype(*dtype_opt));
+      .dtype(dtype));
 }

-at::Tensor empty_strided(at::IntArrayRef size, at::IntArrayRef stride, c10::optional<at::ScalarType> dtype_opt,
-                         c10::optional<at::Layout> layout_opt, c10::optional<at::Device> device_opt,
-                         c10::optional<bool> pin_memory_opt) {
-  ORT_LOG_FN(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
-
-  // TODO: how to handle type conversion
-  OrtValue ot;
-  assert(device_opt.has_value());
-  // TODO: how to support layout
-  // assert(!layout_opt.has_value());
-  at::ScalarType dtype = c10::dtype_or_default(dtype_opt);
-  auto& invoker = GetORTInvoker(*device_opt);
-  onnxruntime::Tensor::InitOrtValue(ort_scalar_type_from_aten(dtype), onnxruntime::TensorShape(size.vec()),
-                                    invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ot,
-                                    stride.vec());
-  return aten_tensor_from_ort(std::move(ot), at::TensorOptions().device(*device_opt).dtype(dtype));
+at::Tensor empty_memory_format(
+  at::IntArrayRef size,
+  c10::optional<at::ScalarType> dtype_opt,
+  c10::optional<at::Layout> layout_opt,
+  c10::optional<at::Device> device_opt,
+  c10::optional<bool> pin_memory,
+  c10::optional<at::MemoryFormat> memory_format) {  // Ignored because there's no ONNX support.
+  ORT_LOG_FN(size, dtype_opt, layout_opt, device_opt, pin_memory, memory_format);
+
+  // Use the strided impl with default (no strides specified).
+  return empty_strided(size, at::IntArrayRef({}), dtype_opt, layout_opt, device_opt, pin_memory);
 }

 // aten::as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)
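The substantive change in this hunk, beyond lint: empty_memory_format no longer duplicates the allocation logic and instead forwards to empty_strided with an empty stride list. A minimal sketch of that delegation shape, with simplified stand-ins for the ATen/ORT types (IntList and Tensor below are illustrative, not the real ones):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

using IntList = std::vector<int64_t>;          // stand-in for at::IntArrayRef
struct Tensor { IntList sizes; IntList strides; };  // stand-in for the ORT-backed tensor

// The general implementation: allocates with explicit strides.
Tensor empty_strided(const IntList& size, const IntList& stride) {
  // An empty stride list means "use the allocator's default (contiguous) layout".
  return Tensor{size, stride};
}

// The convenience overload delegates rather than duplicating the body,
// which is the shape of the refactor in this PR.
Tensor empty_memory_format(const IntList& size) {
  return empty_strided(size, /*stride=*/{});
}

int main() {
  Tensor t = empty_memory_format({2, 3});
  std::cout << "dims: " << t.sizes.size() << ", strides given: " << t.strides.size() << "\n";
  return 0;
}
```

Keeping one allocation path means later fixes, such as the layout-handling TODOs this hunk deletes, only need to land in one place.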
@@ -602,9 +600,9 @@ at::Tensor as_strided(
 at::Tensor _reshape_alias(
     const at::Tensor& self,
     at::IntArrayRef size,
-    at::IntArrayRef stride){
+    at::IntArrayRef stride) {
   ORT_LOG_FN(self, size, stride);
-  // TODO: support stride
+  // TODO(unknown): support stride
   auto& invoker = GetORTInvoker(self.device());
   auto ort_input = create_ort_value(invoker, self);
   return aten_tensor_from_ort(
@@ -645,7 +643,7 @@ at::Tensor& copy_(
           : src.device());
   const auto ort_src = create_ort_value(invoker, src);
   auto ort_self = create_ort_value(invoker, self);
-  if (self.scalar_type() != src.scalar_type()){
+  if (self.scalar_type() != src.scalar_type()) {
     // invoke cast first
     std::vector<OrtValue> ort_cast_output(1);
     onnxruntime::NodeAttributes attrs(1);
@@ -661,8 +659,7 @@ at::Tensor& copy_(
       "ORT return failure status:" + status.ErrorMessage());

     copy(invoker, ort_cast_output[0], ort_self);
-  }
-  else{
+  } else {
     copy(invoker, ort_src, ort_self);
   }
@@ -671,7 +668,7 @@ at::Tensor& copy_(

 at::Tensor _copy_from_and_resize(
     const at::Tensor& self,
-    const at::Tensor& dst){
+    const at::Tensor& dst) {
   ORT_LOG_FN(self, dst);

   assert_tensor_supported(self);
@@ -688,11 +685,11 @@ at::Tensor _copy_from_and_resize(
   return self;
 }

-at::Tensor& zero_(at::Tensor& self){
+at::Tensor& zero_(at::Tensor& self) {
   auto& invoker = GetORTInvoker(self.device());
   auto ort_in_self = create_ort_value(invoker, self);
   OrtValue flag_val;
-  //construct a constant tensor
+  // construct a constant tensor
   auto element_type = onnxruntime::DataTypeImpl::GetType<int64_t>();
   onnxruntime::Tensor::InitOrtValue(element_type, onnxruntime::TensorShape({}),
                                     invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), flag_val);
@@ -715,18 +712,19 @@ at::Tensor& zero_(at::Tensor& self){
   return self;
 }

-// TODO: enhance opgen.py to support inplace binary operations.
+// TODO(unknown): enhance opgen.py to support inplace binary operations.
 // aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
 at::Tensor& add__Tensor(
     at::Tensor& self,
     const at::Tensor& other,
     const at::Scalar& alpha) {
   ORT_LOG_FN(self, other, alpha);

+  auto st = {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16};
   if (
-    !IsSupportedType(alpha, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16}) ||
-    !IsSupportedType(other, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16}) ||
-    !IsSupportedType(self, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16})) {
+    !IsSupportedType(alpha, st) ||
+    !IsSupportedType(other, st) ||
+    !IsSupportedType(self, st)) {
     return at::native::call_fallback_fn<
       &at::native::cpu_fallback,
       ATEN_OP(add__Tensor)>::call(self, other, alpha);
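Hoisting the repeated type list into auto st = {...} removes the triplicated literal. One subtlety worth noting: auto with a braced initializer deduces std::initializer_list, and each IsSupportedType(x, st) call then builds a temporary std::vector through vector's implicit initializer-list constructor. A compilable illustration (ScalarType below is a stand-in, not the ATen one):

```cpp
#include <initializer_list>
#include <iostream>
#include <type_traits>
#include <vector>

enum class ScalarType { Float, Double, Long };  // stand-in for at::ScalarType

bool IsSupportedType(ScalarType t, const std::vector<ScalarType>& valid_types) {
  for (ScalarType v : valid_types) {
    if (v == t) return true;
  }
  return false;
}

int main() {
  // As in the diff: auto + braces deduces std::initializer_list, not vector.
  auto st = {ScalarType::Float, ScalarType::Double};
  static_assert(std::is_same<decltype(st), std::initializer_list<ScalarType>>::value,
                "braced auto deduces initializer_list");
  // The call compiles because std::vector has a non-explicit initializer_list
  // constructor; a temporary vector is materialized for each call.
  std::cout << IsSupportedType(ScalarType::Long, st) << "\n";  // prints 0
  return 0;
}
```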
@@ -827,8 +825,9 @@ bool keepdim,
     at::Tensor& out) {
   ORT_LOG_FN(self, dim, keepdim, out);

+  auto st = {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16};
   if (
-    !IsSupportedType(self, {at::kLong, at::kShort, at::kHalf, at::kBFloat16, at::kFloat, at::kByte, at::kInt, at::kDouble})) {
+    !IsSupportedType(self, st)) {
     return at::native::call_fallback_fn<
       &at::native::cpu_fallback,
       ATEN_OP(argmax_out)>::call(self, dim, keepdim, out);
@@ -1034,7 +1033,7 @@ at::Tensor& _log_softmax_out(
   ORT_LOG_FN(self, dim, half_to_float, out);

   if (
-    !IsSupportedType(self, {at::kBFloat16,at::kDouble,at::kFloat,at::kHalf})) {
+    !IsSupportedType(self, {at::kBFloat16, at::kDouble, at::kFloat, at::kHalf})) {
     return at::native::call_fallback_fn<
       &at::native::cpu_fallback,
       ATEN_OP(_log_softmax_out)>::call(self, dim, half_to_float, out);
@@ -1096,7 +1095,7 @@ at::Tensor& _log_softmax_out(
   ort_outputs_2_Transpose[0] = ort_input_out;

   NodeAttributes attrs_2(1);
-  attrs_2["perm"] = create_ort_attribute("perm", axes);;
+  attrs_2["perm"] = create_ort_attribute("perm", axes);

   status = invoker.Invoke("Transpose", {
     std::move(ort_outputs_1_LogSoftmax[0]),
@@ -1165,9 +1164,9 @@ at::Tensor& mm_out(
 }


-} // namespace aten
+}  // namespace aten

-//#pragma endregion
+// #pragma endregion

-} // namespace eager
-} // namespace torch_ort
+}  // namespace eager
+}  // namespace torch_ort
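The closing-brace pairs in this last hunk render almost identically because the diff view collapses repeated spaces; a plausible reading, given cpplint's rules, is a one-space to two-space fix before the trailing comments. The conventions involved are whitespace/comments (at least two spaces between code and a trailing comment) and readability/namespace (a closing comment on every namespace brace):

```cpp
namespace torch_ort {
namespace eager {

void fn() {}  // whitespace/comments: two or more spaces before a trailing comment.

}  // namespace eager
}  // namespace torch_ort
```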
Review comment: Is there documentation for the cpplint settings? Does this apply to a particular extension?
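For context on that question: cpplint reads a per-directory CPPLINT.cfg, applied to that directory and everything below it, for whichever file extensions cpplint is invoked on, so the settings are not tied to one extension unless configured that way. The settings file added in commit da9738d is not visible in this view; the sketch below only shows config keys cpplint actually supports, and every value is an illustrative assumption, not the PR's real configuration:

```
# CPPLINT.cfg - found automatically by cpplint for this directory and below.
set noparent            # stop searching parent directories for more config
linelength=120          # override the default 80-column limit
extensions=cc,cpp,h     # file extensions cpplint should consider
filter=-build/include_subdir,-whitespace/indent   # disable selected checks
```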