Get grads types from cpp for adam to speed up (#47769)
0x45f authored Nov 10, 2022
1 parent 8d99dd0 commit 5900129
Showing 4 changed files with 86 additions and 21 deletions.
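
The change moves the per-parameter gradient type check in Adam's multi-tensor path from Python into a single C++ binding, core.eager.get_grads_types, which classifies all gradients in one call instead of inspecting each (param, grad) pair from Python. Below is a minimal, hedged sketch of that call pattern reconstructed from the adam.py hunk further down; the toy model and the direct call are illustrative assumptions, not code from the commit.

import paddle
from paddle.fluid import core

# Sketch only: a toy model stands in for the optimizer's real parameter list;
# core.eager.get_grads_types is the binding added in eager_functions.cc below.
model = paddle.nn.Linear(8, 2)
loss = model(paddle.rand([4, 8])).mean()
loss.backward()

parameters_and_grads = [(p, p.grad) for p in model.parameters()]
params = [pair[0] for pair in parameters_and_grads]

# One C-level call classifies every gradient:
#   int(paddle.float32) for trainable FP32 dense grads,
#   int(paddle.float16) for trainable FP16 dense grads,
#   -1 when the parameter stops gradient or the grad is uninitialized.
grads_types = core.eager.get_grads_types(params)

fp32_grads = [
    parameters_and_grads[i][1]
    for i, tp in enumerate(grads_types)
    if tp == int(paddle.float32)
]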
38 changes: 38 additions & 0 deletions paddle/fluid/pybind/eager_functions.cc
@@ -248,6 +248,40 @@ PyObject* eager_api_get_grads_lists(PyObject* self,
   EAGER_CATCH_AND_THROW_RETURN_NULL
 }
 
+PyObject* eager_api_get_grads_types(PyObject* self,
+                                    PyObject* args,
+                                    PyObject* kwargs) {
+  EAGER_TRY
+  auto tensor_list = CastPyArg2VectorOfTensor(PyTuple_GET_ITEM(args, 0), 0);
+
+  std::vector<int> ret;
+
+  for (auto& tensor : tensor_list) {
+    VLOG(6) << "Get grad for tensor: " << tensor.name();
+    auto meta = egr::EagerUtils::nullable_autograd_meta(tensor);
+    if (!meta || meta->StopGradient()) {
+      ret.emplace_back(-1);
+      continue;
+    }
+
+    auto& grad = meta->Grad();
+    if (meta && grad.initialized()) {
+      if (grad.is_dense_tensor() &&
+          (tensor.dtype() == paddle::experimental::DataType::FLOAT32 ||
+           tensor.dtype() == paddle::experimental::DataType::FLOAT16)) {
+        ret.emplace_back(
+            paddle::framework::TransToProtoVarType(tensor.dtype()));
+      }
+    } else {
+      ret.emplace_back(-1);
+    }
+  }
+
+  return ToPyObject(ret);
+
+  EAGER_CATCH_AND_THROW_RETURN_NULL
+}
+
 static PyObject* eager_api_read_next_tensor_list(PyObject* self,
                                                  PyObject* args,
                                                  PyObject* kwargs) {
@@ -1067,6 +1101,10 @@ PyMethodDef variable_functions[] = {
      (PyCFunction)(void (*)(void))eager_api_get_grads_lists,
      METH_VARARGS | METH_KEYWORDS,
      NULL},
+    {"get_grads_types",
+     (PyCFunction)(void (*)(void))eager_api_get_grads_types,
+     METH_VARARGS | METH_KEYWORDS,
+     NULL},
     {"read_next_tensor_list",
      (PyCFunction)(void (*)(void))eager_api_read_next_tensor_list,
      METH_VARARGS | METH_KEYWORDS,
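
For readers of the adam.py hunk further down, the following is a hedged pure-Python paraphrase of what the binding above computes; it is not code from the commit, and it omits the dense-tensor check.

import paddle

def get_grads_types_reference(params):
    # Mirrors eager_api_get_grads_types above: -1 for parameters that stop
    # gradient or have no initialized grad; otherwise the integer dtype code
    # for FP32/FP16 gradients, which is what int(paddle.float32) and
    # int(paddle.float16) evaluate to and what adam.py's GRAD_TYPES contains.
    types = []
    for p in params:
        if p.stop_gradient or p.grad is None:
            types.append(-1)
        elif p.dtype in (paddle.float32, paddle.float16):
            types.append(int(p.dtype))
        else:
            # The C++ code simply skips other dtypes; -1 keeps this sketch's
            # output aligned with `params`.
            types.append(-1)
    return types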
7 changes: 5 additions & 2 deletions paddle/fluid/pybind/eager_utils.cc
@@ -721,11 +721,14 @@ PyObject* ToPyObject(const std::vector<paddle::experimental::Tensor>& value,
 }
 
 PyObject* ToPyObject(
-    const std::vector<std::vector<paddle::experimental::Tensor>>& value) {
+    const std::vector<std::vector<paddle::experimental::Tensor>>& value,
+    bool return_py_none_if_not_initialize) {
   PyObject* result = PyList_New((Py_ssize_t)value.size());
 
   for (size_t i = 0; i < value.size(); i++) {
-    PyList_SET_ITEM(result, static_cast<Py_ssize_t>(i), ToPyObject(value[i]));
+    PyList_SET_ITEM(result,
+                    static_cast<Py_ssize_t>(i),
+                    ToPyObject(value[i], return_py_none_if_not_initialize));
   }
 
   return result;
3 changes: 2 additions & 1 deletion paddle/fluid/pybind/eager_utils.h
@@ -104,7 +104,8 @@ PyObject* ToPyObject(const std::vector<std::vector<size_t>>& value);
 PyObject* ToPyObject(const std::vector<paddle::experimental::Tensor>& value,
                      bool return_py_none_if_not_initialize = false);
 PyObject* ToPyObject(
-    const std::vector<std::vector<paddle::experimental::Tensor>>& value);
+    const std::vector<std::vector<paddle::experimental::Tensor>>& value,
+    bool return_py_none_if_not_initialize = false);
 PyObject* ToPyObject(const platform::Place& value);
 PyObject* ToPyObject(const phi::DenseTensor* value);
 PyObject* ToPyObject(const phi::SelectedRows* value);
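
Both eager_utils edits just thread the existing return_py_none_if_not_initialize flag through the nested std::vector<std::vector<Tensor>> overload of ToPyObject, so uninitialized gradients inside nested lists can also come back to Python as None. A small hedged sketch of the resulting Python-side handling follows; the nested_grads value is hypothetical and not produced by code in this commit.

import paddle

# Hypothetical nested result from a binding built on the extended overload
# with return_py_none_if_not_initialize=True: inner entries are None where
# the corresponding gradient was never initialized.
nested_grads = [
    [paddle.zeros([2], dtype='float32'), None],
    [None, paddle.zeros([2], dtype='float16')],
]

# Callers filter the None placeholders before handing groups to a fused op.
fp32_group = [g for g in nested_grads[0] if g is not None]
fp16_group = [g for g in nested_grads[1] if g is not None]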
59 changes: 41 additions & 18 deletions python/paddle/optimizer/adam.py
@@ -28,6 +28,8 @@
 
 __all__ = []
 
+GRAD_TYPES = [int(paddle.float32), int(paddle.float16)]
+
 
 class Adam(Optimizer):
     r"""
@@ -644,26 +646,47 @@ def _append_optimize_multi_tensor_op(
         lr_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
 
         if isinstance(parameters_and_grads, list):
-            for param_and_grad in parameters_and_grads:
-                if param_and_grad[1] is None:
-                    continue
-                if param_and_grad[0].stop_gradient is False:
-                    if (
-                        param_and_grad[0].dtype == paddle.float32
-                        and param_and_grad[1].type
-                        == core.VarDesc.VarType.LOD_TENSOR
-                    ):
-                        grad_dict['FP32_LODTensor'].append(param_and_grad[1])
-                        lr = self._create_param_lr(param_and_grad)
+            if framework.in_dygraph_mode():
+                params = [pair[0] for pair in parameters_and_grads]
+                grads_types = core.eager.get_grads_types(params)
+                for index, tp in enumerate(grads_types):
+                    if tp == GRAD_TYPES[0]:
+                        grad_dict['FP32_LODTensor'].append(
+                            parameters_and_grads[index][1]
+                        )
+                        lr = self._create_param_lr(parameters_and_grads[index])
                         lr_dict['FP32_LODTensor'].append(lr)
-                    elif (
-                        param_and_grad[0].dtype == paddle.float16
-                        and param_and_grad[1].type
-                        == core.VarDesc.VarType.LOD_TENSOR
-                    ):
-                        grad_dict['FP16_LODTensor'].append(param_and_grad[1])
-                        lr = self._create_param_lr(param_and_grad)
+                    elif tp == GRAD_TYPES[1]:
+                        grad_dict['FP16_LODTensor'].append(
+                            parameters_and_grads[index][1]
+                        )
+                        lr = self._create_param_lr(parameters_and_grads[index])
                         lr_dict['FP16_LODTensor'].append(lr)
+            else:
+                for param_and_grad in parameters_and_grads:
+                    if param_and_grad[1] is None:
+                        continue
+                    if param_and_grad[0].stop_gradient is False:
+                        if (
+                            param_and_grad[0].dtype == paddle.float32
+                            and param_and_grad[1].type
+                            == core.VarDesc.VarType.LOD_TENSOR
+                        ):
+                            grad_dict['FP32_LODTensor'].append(
+                                param_and_grad[1]
+                            )
+                            lr = self._create_param_lr(param_and_grad)
+                            lr_dict['FP32_LODTensor'].append(lr)
+                        elif (
+                            param_and_grad[0].dtype == paddle.float16
+                            and param_and_grad[1].type
+                            == core.VarDesc.VarType.LOD_TENSOR
+                        ):
+                            grad_dict['FP16_LODTensor'].append(
+                                param_and_grad[1]
+                            )
+                            lr = self._create_param_lr(param_and_grad)
+                            lr_dict['FP16_LODTensor'].append(lr)
         else:
             for param_and_grad in parameters_and_grads['params']:
                 if param_and_grad[1] is None:
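
The rewritten branch only runs on Adam's multi-tensor path in dygraph mode. A hedged end-to-end sketch of a step that exercises it follows, assuming use_multi_tensor=True is the existing Adam argument that routes into _append_optimize_multi_tensor_op.

import paddle

# Minimal dygraph step (sketch): with use_multi_tensor=True the optimizer
# goes through _append_optimize_multi_tensor_op, which now classifies grad
# types once per step via core.eager.get_grads_types.
model = paddle.nn.Linear(16, 4)
opt = paddle.optimizer.Adam(
    learning_rate=1e-3,
    parameters=model.parameters(),
    use_multi_tensor=True,
)

x = paddle.rand([8, 16])
loss = model(x).mean()
loss.backward()
opt.step()
opt.clear_grad()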
