Skip to content

Commit

Permalink
【Hackathon 6th Fundable Projects 3 No.82】fluid operator cudnn_lstm (#…
Browse files Browse the repository at this point in the history
…63936)

* Fix

* Fix

* Fix

* Fix
  • Loading branch information
co63oc authored May 7, 2024
1 parent cb6aa58 commit b37b71f
Show file tree
Hide file tree
Showing 10 changed files with 83 additions and 345 deletions.
12 changes: 12 additions & 0 deletions paddle/fluid/eager/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,18 @@ std::vector<AutogradMeta*> EagerUtils::nullable_autograd_meta(
return metas;
}

std::vector<AutogradMeta*> EagerUtils::nullable_autograd_meta(
const paddle::optional<std::vector<paddle::Tensor>>& targets) {
std::vector<AutogradMeta*> metas;
if (targets.get_ptr() != nullptr) {
metas.reserve(targets.get_ptr()->size());
for (const paddle::Tensor& t : (*(targets.get_ptr()))) {
metas.emplace_back(nullable_autograd_meta(t));
}
}
return metas;
}

std::vector<AutogradMeta*> EagerUtils::nullable_autograd_meta(
const std::vector<paddle::Tensor*>& targets) {
std::vector<AutogradMeta*> metas;
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/eager/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ class TEST_API EagerUtils {
const paddle::optional<paddle::Tensor>& target);
static std::vector<AutogradMeta*> nullable_autograd_meta(
const std::vector<paddle::Tensor>& targets);
static std::vector<AutogradMeta*> nullable_autograd_meta(
const paddle::optional<std::vector<paddle::Tensor>>& targets);
static std::vector<AutogradMeta*> nullable_autograd_meta(
const std::vector<paddle::Tensor*>& targets);
static AutogradMeta* unsafe_autograd_meta(const paddle::Tensor& target);
Expand Down
285 changes: 0 additions & 285 deletions paddle/fluid/operators/cudnn_lstm_op.cc

This file was deleted.

59 changes: 0 additions & 59 deletions paddle/fluid/operators/ops_signature/cudnn_lstm_sig.cc

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,12 @@ def GenBuildOutputsPart2(
}}
"""

# In cudnn_lstm operator, the output weight_list_grad requires the use of optional input weight_list,
# so "pir::VectorType {name}" outside the "if" block.
CREATE_OPTIONAL_INPUT_VEC_METATENSOR_TEMPLATE = """ std::vector<paddle::dialect::IrTensor> vec_ir_tensor_{name};
pir::VectorType {name};
if ({name}_.impl() != nullptr) {{
pir::VectorType {name} = {name}_.type().dyn_cast<pir::VectorType>();
{name} = {name}_.type().dyn_cast<pir::VectorType>();
for (size_t i=0; i < static_cast<size_t>({name}.size()); i++) {{
if({name}[i].isa<paddle::dialect::DenseTensorType>()) {{
auto {name}_type = {name}[i].dyn_cast<paddle::dialect::DenseTensorType>();
Expand Down
Loading

0 comments on commit b37b71f

Please sign in to comment.