diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc
index 2606910d3906..8b06b8721994 100644
--- a/src/relay/pass/gradient.cc
+++ b/src/relay/pass/gradient.cc
@@ -351,8 +351,6 @@ struct ReverseAD : ExprMutator {
   Expr VisitExpr_(const CallNode* op) final {
     if (const OpNode* op_node = op->op.as<OpNode>()) {
       Op op_ref = GetRef<Op>(op_node);
-      CHECK(rev_map.count(op_ref))
-        << op_node->name << " does not have reverse mode defined";
       return LetList::With([&](LetList* ll) {
         std::vector<Var> args;
         for (const auto& arg : op->args) {
@@ -408,6 +406,34 @@ Expr BPEmpty() {
   return RefCreateNode::make(unitF);
 }
 
+bool MissingGrad(const Expr& e) {
+  struct MGVisitor : ExprVisitor {
+    const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
+    std::unordered_set<std::string> op_names;
+
+    void VisitExpr_(const OpNode* op) final {
+      Op op_ref = GetRef<Op>(op);
+      if (!rev_map.count(op_ref)) {
+        op_names.insert(op_ref->name);
+      }
+      ExprVisitor::VisitExpr_(op);
+    }
+  };
+
+  MGVisitor mg;
+  mg.VisitExpr(e);
+
+  if (mg.op_names.size() > 0) {
+    LOG(WARNING) << "found operators with missing gradients:";
+    for (const auto& op : mg.op_names) {
+      LOG(WARNING) << " " << op;
+    }
+    return true;
+  }
+
+  return false;
+}
+
 Expr Gradient(const Expr& re, const Module& mod) {
   auto e = DeGlobal(mod, re);
   auto f = e.as<FunctionNode>();
@@ -416,6 +442,7 @@ Expr Gradient(const Expr& re, const Module& mod) {
   for (const auto& p : f->params) {
     CHECK(p->checked_type().as<TensorTypeNode>()) << "input parameters need to be tensor";
   }
+  CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
   Expr body = LetList::With([&](LetList* ll) {
     Var bp = ll->Push(BPEmpty());
     Expr rev = ReverseAD(bp)(e);
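
Context for reviewers (not part of the patch): `rev_map` is keyed on the `FPrimalGradient` op attribute, so `MissingGrad` flags exactly those operators that never registered one, and `Gradient` now reports all of them up front instead of failing on the first offending op deep inside `ReverseAD`. Below is a minimal sketch of what such a registration looks like on the C++ side; the op name "my_op" and the identity gradient body are hypothetical placeholders, and in practice most gradients are registered from Python via `register_gradient`.

// Sketch only: "my_op" and its gradient body are hypothetical.
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>

namespace tvm {
namespace relay {

RELAY_REGISTER_OP("my_op")
.set_attr<FPrimalGradient>(
    "FPrimalGradient",
    [](const Expr& orig_call, const Expr& output_grad) {
      // Must return one gradient expression per input of the original call;
      // the identity gradient here is only a placeholder.
      return tvm::Array<Expr>{output_grad};
    });

}  // namespace relay
}  // namespace tvm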