diff --git a/core/conversion/evaluators/BUILD b/core/conversion/evaluators/BUILD index e4ada604d1..38bb7bb0d3 100644 --- a/core/conversion/evaluators/BUILD +++ b/core/conversion/evaluators/BUILD @@ -15,6 +15,7 @@ cc_library( srcs = [ "NodeEvaluatorRegistry.cpp", "prim.cpp", + "aten.cpp" ], deps = [ "//core/util:prelude", diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp new file mode 100644 index 0000000000..9d5844347c --- /dev/null +++ b/core/conversion/evaluators/aten.cpp @@ -0,0 +1,36 @@ +#include "torch/csrc/jit/ir/ir.h" +#include "torch/csrc/jit/ir/constants.h" +#include "ATen/core/functional.h" +#include "ATen/core/ivalue.h" +#include "ATen/core/List.h" +#include "ATen/core/stack.h" +#include "c10/util/intrusive_ptr.h" +#include "torch/torch.h" + +#include "core/conversion/evaluators/evaluators.h" + +namespace trtorch { +namespace core { +namespace conversion { +namespace evaluators { +namespace { + +auto aten_registrations = RegisterNodeEvaluators() + .evaluator({ + c10::Symbol::fromQualString("aten::zeros"), + // aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor) + [](const torch::jit::Node* n, kwargs& args) -> c10::optional { + auto options = torch::TensorOptions() + .dtype(c10::ScalarType(args.at(&(n->output()[1])).unwrapToInt())) + .layout(torch::kStrided) + .device(torch::kCUDA); + + auto out_tensor = torch::zeros(args.at(&(n->output()[0])).unwrapToIntList().vec(), options); + return out_tensor; + } + }); +} +} // namespace evaluators +} // namespace conversion +} // namespace core +} // namespace trtorch