Skip to content

Commit

Permalink
[Cherry-pick] fix set_value with scalar grad (#60930)
Browse files Browse the repository at this point in the history
* Fix set value grad (#59034)

* first fix the UT

* fix set value grad

* polish code

* add static mode backward test

* always has input valuetensor

* add dygraph test

* Fix shape error in combined-indexing setitem (#60447)

* add ut

* fix shape error in combine-indexing

* fix ut

* Set value with scalar (#60452)

* set_value with scalar

* fix ut

* remove test_pir

* remove one test since 2.6 not support uint8-add
  • Loading branch information
zoooo0820 authored Jan 19, 2024
1 parent d788e9b commit 1aa5f4b
Show file tree
Hide file tree
Showing 13 changed files with 530 additions and 118 deletions.
44 changes: 19 additions & 25 deletions paddle/fluid/operators/set_value_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,32 +151,26 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {

protected:
void Apply(GradOpPtr<T> op) const override {
if (this->HasInput("ValueTensor")) {
op->SetType("set_value_grad");

op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("ValueTensor", this->Input("ValueTensor"));
if (this->HasInput("StartsTensorList")) {
op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
}
if (this->HasInput("EndsTensorList")) {
op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
}
if (this->HasInput("StepsTensorList")) {
op->SetInput("StepsTensorList", this->Input("StepsTensorList"));
}

op->SetAttrMap(this->Attrs());

op->SetOutput(framework::GradVarName("ValueTensor"),
this->InputGrad("ValueTensor"));
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));

} else {
op->SetType("assign");
op->SetInput("X", this->OutputGrad("Out"));
op->SetOutput("Out", this->InputGrad("Input"));
op->SetType("set_value_grad");
op->SetInput("ValueTensor", this->Input("ValueTensor"));
op->SetOutput(framework::GradVarName("ValueTensor"),
this->InputGrad("ValueTensor"));

op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));

if (this->HasInput("StartsTensorList")) {
op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
}
if (this->HasInput("EndsTensorList")) {
op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
}
if (this->HasInput("StepsTensorList")) {
op->SetInput("StepsTensorList", this->Input("StepsTensorList"));
}

op->SetAttrMap(this->Attrs());

op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
}
};

Expand Down
108 changes: 63 additions & 45 deletions paddle/fluid/pybind/eager_method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1375,7 +1375,7 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,

// step3: Dealing with advanced indexing
std::vector<paddle::Tensor> transed_index;
std::vector<int> trans_back_dim;
std::vector<int> trans_back_dim, trans_dim;
int pos_of_new_dim = INT_MAX, rank_of_new_dim = 1;

paddle::Tensor transed_tensor = dealWithAdvancedIndex(out,
Expand All @@ -1385,7 +1385,8 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
&transed_index,
&trans_back_dim,
&pos_of_new_dim,
&rank_of_new_dim);
&rank_of_new_dim,
&trans_dim);

if (transed_index.size() == 1 &&
transed_index[0].dtype() == phi::DataType::BOOL) {
Expand Down Expand Up @@ -1607,58 +1608,70 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
&use_strided_slice);

// step2: Parse values
PADDLE_ENFORCE(
PyCheckTensor(value_obj),
platform::errors::InvalidArgument("The value must be a Tensor"));

std::vector<phi::Scalar> values;
paddle::Tensor value_tensor =
reinterpret_cast<TensorObject*>(value_obj)->tensor;
dealWithValues(tensor, value_obj, &values, has_advanced_index);

if (!has_advanced_index) {
// use set_value OP if there is no advanced index

// Release gil and do tracing
py::gil_scoped_release release;
// use inplace set_value_ operator
if (value_tensor.initialized() &&
(self->tensor.dtype() != value_tensor.dtype())) {
if (egr::Controller::Instance().GetAMPLevel() !=
paddle::imperative::AmpLevel::O0) {
paddle::small_vector<std::vector<paddle::Tensor>,
egr::kSlotSmallVectorSize>
tmps = {{self->tensor}, {value_tensor}};
auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
self->tensor = egr::EagerAmpAutoCast(
self->tensor.name(), self->tensor, amp_dtype, "set_value");
value_tensor = egr::EagerAmpAutoCast(
value_tensor.name(), value_tensor, amp_dtype, "set_value");
}
if (value_tensor.initialized()) {
if (self->tensor.dtype() != value_tensor.dtype()) {
value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
if (egr::Controller::Instance().GetAMPLevel() !=
paddle::imperative::AmpLevel::O0) {
paddle::small_vector<std::vector<paddle::Tensor>,
egr::kSlotSmallVectorSize>
tmps = {{self->tensor}, {value_tensor}};
auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
self->tensor = egr::EagerAmpAutoCast(
self->tensor.name(), self->tensor, amp_dtype, "set_value");
value_tensor = egr::EagerAmpAutoCast(
value_tensor.name(), value_tensor, amp_dtype, "set_value");
}
if (self->tensor.dtype() != value_tensor.dtype()) {
value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
}
}
}

// step3.1: Only basic indexing, use OP set_value.
const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) {
ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor);
}
self->tensor = set_value_with_tensor__ad_func(self->tensor,
value_tensor,
slice_starts,
slice_ends,
slice_strides,
slice_axes,
decrease_axis,
none_axes);
if (PyCheckTensor(value_obj)) {
// pass the stop_gradient from value to tensor.
// pass stop gradient should be done after CheckInplace in
// set_value__dygraph_function.
if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
// step3.1: Only basic indexing, use OP set_value.
const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) {
ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor);
}
self->tensor = set_value_with_tensor__ad_func(self->tensor,
value_tensor,
slice_starts,
slice_ends,
slice_strides,
slice_axes,
decrease_axis,
none_axes);
if (PyCheckTensor(value_obj)) {
// pass the stop_gradient from value to tensor.
// pass stop gradient should be done after CheckInplace in
// set_value__dygraph_function.
if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
}
}
} else {
const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh, self->tensor)) {
ConvertAllInputsToDistTensor(mesh, self->tensor);
}
self->tensor = set_value__ad_func(self->tensor,
slice_starts,
slice_ends,
slice_strides,
slice_axes,
decrease_axis,
none_axes,
{1},
values);
}
} else {
// step3.2: Case for there are advanced indexing.
Expand All @@ -1679,9 +1692,9 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
&use_strided_slice);

std::vector<paddle::Tensor> transed_index;
std::vector<int> trans_back_dim;
std::vector<int> trans_back_dim, trans_dim;

int pos_of_new_dim = 0, rank_of_new_dim = 0;
int pos_of_new_dim = INT_MAX, rank_of_new_dim = 1;

paddle::Tensor transed_sub_tensor =
dealWithAdvancedIndex(sub_tensor,
Expand All @@ -1691,7 +1704,8 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
&transed_index,
&trans_back_dim,
&pos_of_new_dim,
&rank_of_new_dim);
&rank_of_new_dim,
&trans_dim);

// Release gil and do tracing
py::gil_scoped_release release;
Expand All @@ -1714,6 +1728,10 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
}
}

if (value_tensor.dims().size() > 1 && pos_of_new_dim != 0) {
value_tensor = transpose_ad_func(value_tensor, trans_dim);
}

// TODO(zoooo0820) 1.Using inplace version index_put
// 2.Remove following code after backward bug fixed.
transed_sub_tensor = assign_ad_func(transed_sub_tensor);
Expand Down
Loading

0 comments on commit 1aa5f4b

Please sign in to comment.