diff --git a/.vscode/settings.json b/.vscode/settings.json
index 94a3a17c2caa2..fd28e2d7b335c 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -34,5 +34,10 @@
     "python.linting.pydocstyleArgs": [
         "--convention=google"
     ],
-    "python.linting.banditEnabled": true
+    "python.linting.banditEnabled": true,
+    "cpplint.lineLength": 120,
+    "cpplint.filters": [
+        "-build/include_subdir",
+        "-runtime/references"
+    ]
 }
diff --git a/orttraining/orttraining/eager/ort_aten.cpp b/orttraining/orttraining/eager/ort_aten.cpp
index f48b884520ddc..bafabb410b5e5 100644
--- a/orttraining/orttraining/eager/ort_aten.cpp
+++ b/orttraining/orttraining/eager/ort_aten.cpp
@@ -10,11 +10,14 @@
 #include
 #include
+#include <algorithm>
+#include <utility>
+#include <vector>
 
 namespace torch_ort {
 namespace eager {
 
-//#pragma region Helpers
+// #pragma region Helpers
 
 using NodeAttributes = onnxruntime::NodeAttributes;
 namespace {
   inline bool is_device_supported(at::DeviceType type) {
@@ -34,7 +37,7 @@ namespace {
       throw std::runtime_error("ORT copy: device not supported");
     }
   }
-}
+}  // namespace
 
 at::Tensor aten_tensor_from_ort(
   OrtValue&& ot,
@@ -59,7 +62,7 @@ const std::vector<at::Tensor> aten_tensor_from_ort(
 
 onnxruntime::MLDataType ort_scalar_type_from_aten(
   at::ScalarType dtype) {
-  switch (dtype){
+  switch (dtype) {
    case at::kFloat:
     return onnxruntime::DataTypeImpl::GetType<float>();
    case at::kDouble:
@@ -107,7 +110,7 @@ OrtValue create_ort_value(
       break;
     }
     default:
-      // TODO: support more types
+      // TODO(unknown): support more types
       // For most at::ScalarType, it should be safe to just call value.to<>
       // on it, but for now we want to explicitly know when we've encountered
       // a new scalar type while bringing up ORT eager mode.
@@ -131,13 +134,17 @@ OrtValue create_ort_value(
   auto element_type = ort_scalar_type_from_aten(tensor.scalar_type());
 
   OrtValue ort_tensor;
-  onnxruntime::Tensor::InitOrtValue(element_type, onnxruntime::TensorShape(tensor.sizes().vec()), tensor.data_ptr(),
-                                    *mem_info, ort_tensor, 0L /* offset = 0 - because tensor.data_ptr() includes the underyling offset */,
-                                    tensor.strides().vec());
+  onnxruntime::Tensor::InitOrtValue(
+      element_type,
+      onnxruntime::TensorShape(tensor.sizes().vec()),
+      tensor.data_ptr(),
+      *mem_info, ort_tensor,
+      0L,  // offset = 0 - because tensor.data_ptr() includes the underlying offset
+      tensor.strides().vec());
   return ort_tensor;
 }
 
-OrtValue create_ort_value(const at::Tensor& tensor){
+OrtValue create_ort_value(const at::Tensor& tensor) {
   auto& invoker = GetORTInvoker(tensor.device());
   return create_ort_value(invoker, tensor);
 }
@@ -146,7 +153,7 @@ std::vector<OrtValue> create_ort_value(
   onnxruntime::ORTInvoker& invoker,
   at::TensorList values) {
   auto output = std::vector<OrtValue>{};
-  for (auto element: values){
+  for (auto element : values) {
     output.push_back(create_ort_value(element));
   }
   return output;
@@ -157,7 +164,7 @@ onnx::AttributeProto create_ort_attribute(
   at::Scalar value,
   const bool isTensor,
   at::ScalarType type) {
-  if (isTensor){
+  if (isTensor) {
     onnx::AttributeProto attr;
     attr.set_name(name);
     attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR);
@@ -190,8 +197,7 @@ onnx::AttributeProto create_ort_attribute(
         ORT_THROW("Unsupported: at::ScalarType::", value.type());
     }
     return attr;
-  }
-  else{
+  } else {
     return create_ort_attribute(name, value, value.type());
   }
 }
@@ -254,33 +260,33 @@ onnx::AttributeProto create_ort_attribute(
   return attr;
 }
 
-bool IsSupportedType(at::Scalar scalar, const std::vector<at::ScalarType>& valid_types){
+bool IsSupportedType(at::Scalar scalar, const std::vector<at::ScalarType>& valid_types) {
   return std::find(valid_types.begin(), valid_types.end(), scalar.type()) != valid_types.end();
 }
 
-bool IsSupportedType(at::Tensor tensor, const std::vector<at::ScalarType>& valid_types){
+bool IsSupportedType(at::Tensor tensor, const std::vector<at::ScalarType>& valid_types) {
   return std::find(valid_types.begin(), valid_types.end(), tensor.scalar_type()) != valid_types.end();
 }
 
-bool IsSupportedType(at::IntArrayRef arrary, const std::vector<at::ScalarType>& valid_types){
+bool IsSupportedType(at::IntArrayRef array, const std::vector<at::ScalarType>& valid_types) {
   return std::find(valid_types.begin(), valid_types.end(), at::kInt) != valid_types.end() ||
          std::find(valid_types.begin(), valid_types.end(), at::kLong) != valid_types.end();
 }
 
-bool IsSupportedType(int64_t val, const std::vector<at::ScalarType>& valid_types){
+bool IsSupportedType(int64_t val, const std::vector<at::ScalarType>& valid_types) {
   return std::find(valid_types.begin(), valid_types.end(), at::kLong) != valid_types.end();
 }
 
-bool IsSupportedType(c10::optional<at::Scalar> val, const std::vector<at::ScalarType>& valid_types){
+bool IsSupportedType(c10::optional<at::Scalar> val, const std::vector<at::ScalarType>& valid_types) {
   return IsSupportedType(val.value(), valid_types);
 }
 
-bool IsSupportedType(at::TensorList tensors, const std::vector<at::ScalarType>& valid_types){
+bool IsSupportedType(at::TensorList tensors, const std::vector<at::ScalarType>& valid_types) {
   return IsSupportedType(tensors[0], valid_types);
 }
 
-ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype){
-  switch (dtype){
+ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype) {
+  switch (dtype) {
    case at::kFloat:
     return ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
    case at::kDouble:
@@ -349,7 +355,7 @@ c10::optional<at::ScalarType> PromoteScalarTypesWithCategory(
   return typeFromTensor;
 }
 
-OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at::ScalarType type){
+OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at::ScalarType type) {
   std::vector<OrtValue> output(1);
   NodeAttributes attrs(2);
   attrs["to"] = create_ort_attribute(
@@ -425,7 +431,7 @@ void resize_output(
   resize_impl_ort_(invoker, output, shape);
 }
 
-//#pragma endregion
+// #pragma endregion
 
 /*
  * Resize backing store of a TensorImpl.
@@ -530,52 +536,44 @@ void resize_impl_ort_(
   return;
 }
 
-//#pragma region Hand-Implemented ATen Ops
+// #pragma region Hand-Implemented ATen Ops
 
 namespace aten {
 
-at::Tensor empty_memory_format(
+at::Tensor empty_strided(
   at::IntArrayRef size,
-  // *,
+  at::IntArrayRef stride,
   c10::optional<at::ScalarType> dtype_opt,
-  c10::optional<at::Layout> layout_opt,
-  c10::optional<at::Device> device_opt,
-  c10::optional<bool> pin_memory,
-  c10::optional<at::MemoryFormat> memory_format) {
-  ORT_LOG_FN(size, dtype_opt, layout_opt, device_opt, pin_memory, memory_format);
-
-  assert(dtype_opt.has_value());
-  assert(device_opt.has_value());
+  c10::optional<at::Layout> layout_opt,  // Ignored because there's no ONNX support.
+  c10::optional<at::Device> device_opt,  // Will be ORT by the time this is dispatched.
+  c10::optional<bool> pin_memory_opt) {  // Ignored because there's no ONNX support.
+  ORT_LOG_FN(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
 
-  // TODO: validate options and memory format
-  // TODO: figure out how to get the correct element type.
   OrtValue ot;
+  assert(device_opt.has_value());
+  at::ScalarType dtype = c10::dtype_or_default(dtype_opt);
   auto& invoker = GetORTInvoker(*device_opt);
-  onnxruntime::Tensor::InitOrtValue(ort_scalar_type_from_aten(*dtype_opt), onnxruntime::TensorShape(size.vec()),
-                                    invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ot);
+  onnxruntime::Tensor::InitOrtValue(ort_scalar_type_from_aten(dtype), onnxruntime::TensorShape(size.vec()),
+                                    invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ot,
+                                    stride.vec());
   return aten_tensor_from_ort(
     std::move(ot),
     at::TensorOptions()
       .device(*device_opt)
-      .dtype(*dtype_opt));
+      .dtype(dtype));
 }
 
-at::Tensor empty_strided(at::IntArrayRef size, at::IntArrayRef stride, c10::optional<at::ScalarType> dtype_opt,
-                         c10::optional<at::Layout> layout_opt, c10::optional<at::Device> device_opt,
-                         c10::optional<bool> pin_memory_opt) {
-  ORT_LOG_FN(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
+at::Tensor empty_memory_format(
+  at::IntArrayRef size,
+  c10::optional<at::ScalarType> dtype_opt,
+  c10::optional<at::Layout> layout_opt,
+  c10::optional<at::Device> device_opt,
+  c10::optional<bool> pin_memory,
+  c10::optional<at::MemoryFormat> memory_format) {  // Ignored because there's no ONNX support.
+  ORT_LOG_FN(size, dtype_opt, layout_opt, device_opt, pin_memory, memory_format);
 
-  // TODO: how to handle type conversion
-  OrtValue ot;
-  assert(device_opt.has_value());
-  // TODO: how to support layout
-  // assert(!layout_opt.has_value());
-  at::ScalarType dtype = c10::dtype_or_default(dtype_opt);
-  auto& invoker = GetORTInvoker(*device_opt);
-  onnxruntime::Tensor::InitOrtValue(ort_scalar_type_from_aten(dtype), onnxruntime::TensorShape(size.vec()),
-                                    invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ot,
-                                    stride.vec());
-  return aten_tensor_from_ort(std::move(ot), at::TensorOptions().device(*device_opt).dtype(dtype));
+  // Use the strided impl with default (no strides specified).
+  return empty_strided(size, at::IntArrayRef({}), dtype_opt, layout_opt, device_opt, pin_memory);
 }
 
 // aten::as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)
@@ -602,9 +600,9 @@ at::Tensor as_strided(
 at::Tensor _reshape_alias(
   const at::Tensor& self,
   at::IntArrayRef size,
-  at::IntArrayRef stride){
+  at::IntArrayRef stride) {
   ORT_LOG_FN(self, size, stride);
-  // TODO: support stride
+  // TODO(unknown): support stride
   auto& invoker = GetORTInvoker(self.device());
   auto ort_input = create_ort_value(invoker, self);
   return aten_tensor_from_ort(
@@ -645,7 +643,7 @@ at::Tensor& copy_(
       : src.device());
   const auto ort_src = create_ort_value(invoker, src);
   auto ort_self = create_ort_value(invoker, self);
-  if (self.scalar_type() != src.scalar_type()){
+  if (self.scalar_type() != src.scalar_type()) {
     // invoke cast first
     std::vector<OrtValue> ort_cast_output(1);
     onnxruntime::NodeAttributes attrs(1);
@@ -661,8 +659,7 @@ at::Tensor& copy_(
                 "ORT return failure status:" + status.ErrorMessage());
 
     copy(invoker, ort_cast_output[0], ort_self);
-  }
-  else{
+  } else {
     copy(invoker, ort_src, ort_self);
   }
 
@@ -671,7 +668,7 @@ at::Tensor& copy_(
 
 at::Tensor _copy_from_and_resize(
   const at::Tensor& self,
-  const at::Tensor& dst){
+  const at::Tensor& dst) {
   ORT_LOG_FN(self, dst);
 
   assert_tensor_supported(self);
@@ -688,11 +685,11 @@ at::Tensor _copy_from_and_resize(
   return self;
 }
 
-at::Tensor& zero_(at::Tensor& self){
+at::Tensor& zero_(at::Tensor& self) {
   auto& invoker = GetORTInvoker(self.device());
   auto ort_in_self = create_ort_value(invoker, self);
   OrtValue flag_val;
-  //construct a constant tensor
+  // construct a constant tensor
   auto element_type = onnxruntime::DataTypeImpl::GetType<bool>();
   onnxruntime::Tensor::InitOrtValue(element_type, onnxruntime::TensorShape({}),
                                     invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), flag_val);
@@ -715,7 +712,7 @@ at::Tensor& zero_(at::Tensor& self){
   return self;
 }
 
-// TODO: enhance opgen.py to support inplace binary operations.
+// TODO(unknown): enhance opgen.py to support inplace binary operations.
 // aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
 at::Tensor& add__Tensor(
   at::Tensor& self,
@@ -723,10 +720,11 @@ at::Tensor& add__Tensor(
   const at::Scalar& alpha) {
   ORT_LOG_FN(self, other, alpha);
 
+  auto st = {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16};
   if (
-    !IsSupportedType(alpha, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16}) ||
-    !IsSupportedType(other, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16}) ||
-    !IsSupportedType(self, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16})) {
+    !IsSupportedType(alpha, st) ||
+    !IsSupportedType(other, st) ||
+    !IsSupportedType(self, st)) {
     return at::native::call_fallback_fn<
       &at::native::cpu_fallback,
       ATEN_OP(add__Tensor)>::call(self, other, alpha);
@@ -827,8 +825,9 @@ bool keepdim,
 at::Tensor& out) {
   ORT_LOG_FN(self, dim, keepdim, out);
 
+  auto st = {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16};
   if (
-    !IsSupportedType(self, {at::kLong, at::kShort, at::kHalf, at::kBFloat16, at::kFloat, at::kByte, at::kInt, at::kDouble})) {
+    !IsSupportedType(self, st)) {
     return at::native::call_fallback_fn<
       &at::native::cpu_fallback,
       ATEN_OP(argmax_out)>::call(self, dim, keepdim, out);
@@ -1034,7 +1033,7 @@ at::Tensor& _log_softmax_out(
   ORT_LOG_FN(self, dim, half_to_float, out);
 
   if (
-    !IsSupportedType(self, {at::kBFloat16,at::kDouble,at::kFloat,at::kHalf})) {
+    !IsSupportedType(self, {at::kBFloat16, at::kDouble, at::kFloat, at::kHalf})) {
     return at::native::call_fallback_fn<
       &at::native::cpu_fallback,
       ATEN_OP(_log_softmax_out)>::call(self, dim, half_to_float, out);
@@ -1096,7 +1095,7 @@ at::Tensor& _log_softmax_out(
     ort_outputs_2_Transpose[0] = ort_input_out;
 
     NodeAttributes attrs_2(1);
-    attrs_2["perm"] = create_ort_attribute("perm", axes);;
+    attrs_2["perm"] = create_ort_attribute("perm", axes);
 
     status = invoker.Invoke("Transpose", {
       std::move(ort_outputs_1_LogSoftmax[0]),
@@ -1165,9 +1164,9 @@ at::Tensor& mm_out(
 
 }
 
-} // namespace aten
+}  // namespace aten
 
-//#pragma endregion
+// #pragma endregion
 
-} // namespace eager
-} // namespace torch_ort
+}  // namespace eager
+}  // namespace torch_ort
diff --git a/orttraining/orttraining/eager/test/ort_ops.py b/orttraining/orttraining/eager/test/ort_ops.py
index a93c364335330..756e519f43366 100644
--- a/orttraining/orttraining/eager/test/ort_ops.py
+++ b/orttraining/orttraining/eager/test/ort_ops.py
@@ -224,6 +224,13 @@ def test_zero_stride(self):
         cpu_tensor_copied = ort_tensor.cpu()
         assert cpu_tensor_copied.stride() == (0, 0, 0)
 
+    def test_empty(self):
+        device = self.get_device()
+        cpu_tensor = torch.empty(size=(3, 4))
+        ort_tensor = torch.empty(size=(3, 4), device=device)
+        assert ort_tensor.is_ort
+        assert ort_tensor.size() == cpu_tensor.size()
+
     def test_softmax(self):
         device = self.get_device()
         cpu_tensor = torch.rand(3, 5)
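
Not part of the patch, but for orientation: a minimal Python sketch of what the reworked `empty`/`empty_strided` kernels enable, mirroring `test_empty` above. It assumes an ORT-enabled eager build; `torch.device("ort")` is hypothetical shorthand for whatever the test suite's `get_device()` returns, and `is_ort` is the backend-provided attribute the tests already rely on.

import torch

# Hypothetical device handle; the tests obtain it via self.get_device().
device = torch.device("ort")

# torch.empty dispatches to empty_memory_format, which now delegates to
# empty_strided with an empty stride list (contiguous layout, uninitialized data).
ort_tensor = torch.empty(size=(3, 4), device=device)
assert ort_tensor.is_ort
assert ort_tensor.size() == torch.Size([3, 4])

# torch.empty_strided pins an explicit layout; stride (1, 3) on a (3, 4)
# tensor is column-major. The new InitOrtValue call threads stride.vec()
# through to the ORT tensor, so the layout should be preserved.
strided = torch.empty_strided((3, 4), (1, 3), device=device)
assert strided.stride() == (1, 3)

The design choice on the C++ side is to keep a single allocation path: `empty_memory_format` forwards to `empty_strided` rather than duplicating the `InitOrtValue` logic.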