Skip to content

Commit

Permalink
Revert "Fix convergence for dolly+stage3 training (#17685)"
Browse files Browse the repository at this point in the history
This reverts commit 7201def.
  • Loading branch information
yf711 committed Oct 10, 2023
1 parent 28a556f commit f77b21b
Show file tree
Hide file tree
Showing 11 changed files with 155 additions and 657 deletions.
85 changes: 32 additions & 53 deletions orttraining/orttraining/core/framework/torch/torch_proxy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
#include "orttraining/core/framework/torch/gil.h"
#include "core/platform/env.h"

namespace onnxruntime::language_interop_ops::torch {
namespace onnxruntime {
namespace language_interop_ops {
namespace torch {

// Deleter used with PythonObjectPtr (a smart pointer over PyObject*):
// drops one reference to the held object. Py_XDECREF is null-safe, so an
// empty/moved-from pointer is handled correctly.
// Fix: removed the stray ';' after the function body (it compiled as an
// empty declaration, but is non-idiomatic and trips some linters).
void PythonObjectDeleter(PyObject* ptr) { Py_XDECREF(ptr); }

Expand Down Expand Up @@ -128,18 +130,6 @@ PyObject* CreateRequiresGradFlags(
return flags;
}

// Converts `inplace_map` into a Python list of ints for the Python runner:
// element i holds the input index recorded for output i (per the op-schema
// text elsewhere in this change, -1 appears to mean "output i reuses no
// input" — TODO confirm; this function copies the values verbatim either way).
// Returns a new reference owned by the caller.
// NOTE(review): the std::to_string(__LINE__) argument tags the ref-count
// tracker with this call site, so reflowing these lines changes the emitted
// tag — keep the layout stable.
PyObject* CreateInplaceMap(
const std::vector<int64_t>& inplace_map) {
PyObject* inplace_map_obj = Ort_PyList_New(inplace_map.size(), "inplace_map");

// One PyLong per entry; list position == output index.
for (size_t output_index = 0; output_index < inplace_map.size(); ++output_index) {
PyObject* input_index = PyLong_FromLong(inplace_map[output_index]);
// _NoIncref: presumably the list steals the fresh PyLong reference
// (PyList_SetItem semantics) — no extra DECREF needed here.
Ort_PyList_SetItem_NoIncref(inplace_map_obj, output_index, input_index, std::to_string(__LINE__));
}

return inplace_map_obj;
}

void InvokeRunner(
PyObject* callback_runner,
PyObject* args,
Expand Down Expand Up @@ -207,15 +197,14 @@ PythonObjectPtr CreatePythonCallArguments(
const std::vector<void*>& obj_args,
const std::vector<int64_t>& obj_indices,
const bool is_training_mode,
const std::vector<int64_t>& inplace_map,
const std::string& invoke_id,
const std::string& func_name) {
const bool is_inplace,
const std::string& invoke_id) {
ORT_ENFORCE(PyCallable_Check(callback), "Forward callback is not callable.");
// The number of variables before those of
// autograd.Function.apply and autograd.Function.backward.
// The extra variables are used to configure the launch
// forward and backward runners.
constexpr int64_t num_control_args = 7;
constexpr int64_t num_control_args = 6;

// All arguments created for Python call will be destroyed along with PythonObjectPtr.
PythonObjectPtr args(Ort_PyTuple_New(num_control_args + len, "forward_arguments_tuple"), PythonObjectDeleter);
Expand All @@ -227,16 +216,11 @@ PythonObjectPtr CreatePythonCallArguments(
Ort_PyTuple_SetItem_NoIncref(args.get(), 2, tensor_flags, "tensor_flags");
PyObject* is_training_mode_arg = is_training_mode ? Py_True : Py_False;
Ort_PyTuple_SetItem_Incref(args.get(), 3, is_training_mode_arg, "is_training_mode");

PyObject* inplace_map_arg = CreateInplaceMap(inplace_map);
Ort_PyTuple_SetItem_NoIncref(args.get(), 4, inplace_map_arg, "inplace_map");

PyObject* is_inplace_arg = is_inplace ? Py_True : Py_False;
Ort_PyTuple_SetItem_Incref(args.get(), 4, is_inplace_arg, "is_inplace_mode");
PyObject* kernel_invoke_id_arg = PyBytes_FromStringAndSize(invoke_id.c_str(), invoke_id.size());
Ort_PyTuple_SetItem_NoIncref(args.get(), 5, kernel_invoke_id_arg, "kernel_invoke_id_arg");

PyObject* func_name_arg = PyBytes_FromStringAndSize(func_name.c_str(), func_name.size());
Ort_PyTuple_SetItem_NoIncref(args.get(), 6, func_name_arg, "func_name_arg");

// Tensor inputs to call autograd.Function.apply or autograd.Function.backward.
for (size_t i = 0; i < tensor_args.size(); ++i) {
if (!tensor_args[i].has_value()) {
Expand All @@ -262,19 +246,18 @@ PythonObjectPtr CreatePythonCallArguments(
}

void Invoke(
const std::string& func_name,
PyObject* runner,
PyObject* callback,
const std::vector<int64_t>& requires_grads,
const std::vector<std::optional<OrtValue>>& tensor_args,
const std::vector<int64_t>& tensor_indices,
const std::vector<void*>& obj_args,
const std::vector<int64_t>& obj_indices,
const bool is_training_mode,
const std::vector<int64_t>& inplace_map,
const std::string& invoke_id,
void** diff_ctx,
std::vector<OrtValue>& returned_ortvalues) {
std::vector<OrtValue>& returned_ortvalues,
const bool is_training_mode,
const bool is_inplace,
const std::string& invoke_id) {
const auto len = tensor_args.size() + obj_args.size();
CheckArguments(len, requires_grads, tensor_args, tensor_indices, obj_args, obj_indices);
RefCountTracker::GetInstance().Reset();
Expand All @@ -288,9 +271,8 @@ void Invoke(
obj_args,
obj_indices,
is_training_mode,
inplace_map,
invoke_id,
func_name);
is_inplace,
invoke_id);

RefCountTracker::GetInstance().DumpDetails("Before Invoke Python Call");
InvokeRunner(runner, args.get(), is_training_mode, diff_ctx, returned_ortvalues);
Expand All @@ -300,18 +282,17 @@ void Invoke(
}

void TorchProxy::Forward(
const std::string& func_name,
void* callback,
const std::vector<int64_t>& requires_grads,
const std::vector<std::optional<OrtValue>>& tensor_args,
const std::vector<int64_t>& tensor_indices,
const std::vector<void*>& obj_args,
const std::vector<int64_t>& obj_indices,
const bool is_training_mode,
const std::vector<int64_t>& inplace_map,
const std::string& invoke_id,
void** diff_ctx,
std::vector<OrtValue>& returned_ortvalues) {
std::vector<OrtValue>& returned_ortvalues,
const bool is_training_mode,
const bool is_inplace,
const std::string& invoke_id) {
// Semantically, this lock uniquely takes the ownership of TorchProxy
// so that there will be only one of TorchProxy::Forward TorchProxy::Backward
// can be run at one time.
Expand All @@ -320,31 +301,29 @@ void TorchProxy::Forward(
GilGuard guard;
auto runner = OrtTorchFunctionPool::GetInstance().GetForwardRunner();
Invoke(
func_name,
runner,
reinterpret_cast<PyObject*>(callback),
requires_grads,
tensor_args,
tensor_indices,
obj_args,
obj_indices,
is_training_mode,
inplace_map,
invoke_id,
diff_ctx,
returned_ortvalues);
returned_ortvalues,
is_training_mode,
is_inplace,
invoke_id);
}

void TorchProxy::Backward(
const std::string& func_name,
void* callback,
const std::vector<std::optional<OrtValue>>& tensor_args,
const std::vector<int64_t>& tensor_indices,
const std::vector<void*>& obj_args,
const std::vector<int64_t>& obj_indices,
const std::vector<int64_t>& inplace_map,
const std::string& invoke_id,
std::vector<OrtValue>& returned_ortvalues) {
std::vector<OrtValue>& returned_ortvalues,
const bool is_inplace,
const std::string& invoke_id) {
// Semantically, this lock uniquely takes the ownership of TorchProxy
// so that there will be only one of TorchProxy::Forward TorchProxy::Backward
// can be run at one time.
Expand All @@ -357,19 +336,19 @@ void TorchProxy::Backward(
const auto all_input_count = tensor_args.size() + obj_args.size();
const std::vector<int64_t> requires_grads(all_input_count, 0);
Invoke(
func_name,
runner,
reinterpret_cast<PyObject*>(callback),
requires_grads,
tensor_args,
tensor_indices,
obj_args,
obj_indices,
true /* is_training_mode */,
inplace_map,
invoke_id,
nullptr /* context to store */,
returned_ortvalues);
returned_ortvalues,
true /* is_training_mode */,
is_inplace,
invoke_id);
}

} // namespace onnxruntime::language_interop_ops::torch
} // namespace torch
} // namespace language_interop_ops
} // namespace onnxruntime
16 changes: 7 additions & 9 deletions orttraining/orttraining/core/framework/torch/torch_proxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,27 @@ class TorchProxy {
};

void Forward(
const std::string& func_name,
void* callback,
const std::vector<int64_t>& requires_grads,
const std::vector<std::optional<OrtValue>>& tensor_args,
const std::vector<int64_t>& tensor_indices,
const std::vector<void*>& obj_args,
const std::vector<int64_t>& obj_indices,
const bool is_training_mode,
const std::vector<int64_t>& inplace_map,
const std::string& invoke_id,
void** diff_ctx,
std::vector<OrtValue>& returned_ortvalues);
std::vector<OrtValue>& returned_ortvalues,
const bool is_training_mode,
const bool is_inplace,
const std::string& invoke_id);

void Backward(
const std::string& func_name,
void* callback,
const std::vector<std::optional<OrtValue>>& tensor_args,
const std::vector<int64_t>& tensor_indices,
const std::vector<void*>& obj_args,
const std::vector<int64_t>& obj_indices,
const std::vector<int64_t>& inplace_map,
const std::string& invoke_id,
std::vector<OrtValue>& return_args);
std::vector<OrtValue>& return_args,
const bool is_inplace,
const std::string& invoke_id);

private:
TorchProxy(){};
Expand Down
1 change: 1 addition & 0 deletions orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1765,6 +1765,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetPythonOpGradient) {
ORT_ENFORCE(utils::HasString(src_attrs.at("func_name")));
attrs.push_back(MakeAttribute("func_name", src_attrs.at("func_name").s()));
attrs.push_back(MakeAttribute("output_convention", src_attrs.at("input_convention").s()));
attrs.push_back(MakeAttribute("inplace", src_attrs.at("inplace").i()));

// input_tensor_types[i] store the type of autograd.Function.apply's ith output.
// Note that PythonOpGrad's 0-th input is the Python context generated by PythonOp.
Expand Down
29 changes: 9 additions & 20 deletions orttraining/orttraining/core/graph/training_op_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3908,16 +3908,10 @@ Return true if all elements are true and false otherwise.
AttributeProto::INTS)
// Other attributes.
.Attr(
"tensor_reuse_map",
"A int array indicating whether output at each index is reusing specific input or not."
"If the given index is -1, it means the output is not reusing any input."
"For example, there are 2 tensor inputs and 3 tensor outputs (including ctx), "
"tensor_reuse_map = [-1, 1, 0] means"
"- the output 0 (ctx) don't reuse any input buffer."
"- the output 1 reuses the input 1."
"- the output 2 reuses the input 0.",
AttributeProto::INTS,
false)
"inplace",
"Indicate if the output should reuse input memory.",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr(
"training_mode",
"Indicate if the model is exported in training_mode, by default, False.",
Expand Down Expand Up @@ -4039,6 +4033,11 @@ Return true if all elements are true and false otherwise.
"func_name",
"Name of custom class.",
AttributeProto::STRING)
.Attr(
"inplace",
"Indicate if the output should reuse input memory. Todo(pengwa): do we need it?",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr(
"input_tensor_types",
"Input types of autograd.Function.backward (including only tensor inputs)."
Expand Down Expand Up @@ -4070,16 +4069,6 @@ Return true if all elements are true and false otherwise.
"A string inidicating autograd.Function.backward outputs's type."
"value 'c' - non-tensor output; value 'd' - tensor output.",
AttributeProto::STRING)
.Attr(
"tensor_reuse_map",
"A int array indicating whether output at each index is reusing specific input or not."
"If the given index is -1, it means the output is not reusing any input."
"For example, there are 3 inputs (including ctx) and 2 outputs, tensor_reuse_map = [2, 1] means"
"- the output 0 reuses the input 2."
"- the output 1 reuses the input 1."
"Be noted: the input 0 is ctx.",
AttributeProto::INTS,
false)
.Attr(
"comment",
"comment only for debugging purposes.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def _export_pt_1_10(g, n, *args, **kwargs):
"wrap exportable sub-nn.Module's as ORTModule."
)

inplace = kwargs["inplace"]
# TODO move to public API once the exporter team exposes that
training_mode = None
if get_runtime_pytorch_version() >= version.parse("1.12"):
Expand Down Expand Up @@ -259,6 +260,7 @@ def _export_pt_1_10(g, n, *args, **kwargs):

attrs = {
"func_name_s": func_full_qual_name,
"inplace_i": inplace,
"input_convention_s": cconv,
"outputs": n.outputsSize(),
"input_tensor_types_i": input_tensor_types,
Expand Down
Loading

0 comments on commit f77b21b

Please sign in to comment.