[Relay][Training] Add checkpoint annotation for checkpointing memory optimization (#4146)

* add checkpoint annotation for checkpointing memory optimization

* add alpha-equivalence checkpoint test and fix gradient type issue

* fix build issues

* ignore checkpoint annotation when checking missing gradients

* refactor, fix checkpoint compute for tuple and add tests
altanh authored and jroesch committed Oct 27, 2019
1 parent 7732873 commit 93d610a
Showing 6 changed files with 309 additions and 36 deletions.
19 changes: 18 additions & 1 deletion python/tvm/relay/op/annotation/annotation.py
@@ -17,10 +17,10 @@
"""Annotation operations."""
from __future__ import absolute_import as _abs
from . import _make
from ..op import register_schedule, schedule_injective
from .... import nd as _nd
from .... import TVMContext as _TVMContext


def on_device(data, device):
    """Annotate an expression with a certain device type.
@@ -61,3 +61,20 @@ def stop_fusion(data):
        The annotated expression.
    """
    return _make.stop_fusion(data)

def checkpoint(data):
    """Annotate an expression to be a checkpoint for the checkpointing memory optimization.

    Parameters
    ----------
    data : tvm.relay.Expr
        The expression to be annotated.

    Returns
    -------
    result : tvm.relay.Expr
        The annotated expression.
    """
    return _make.checkpoint(data)

register_schedule("annotation.checkpoint", schedule_injective)
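
A minimal usage sketch (not part of this diff) of the new annotation: intermediates inside the checkpointed sub-expression are recomputed during the backward pass instead of being kept alive from the forward pass. The import of check_grad from tvm.relay.testing is an assumption based on how the new test below exercises the op.

# Illustrative sketch only; check_grad's location is assumed, not confirmed by this diff.
from tvm import relay
from tvm.relay.testing import check_grad

a = relay.var("a", shape=(1,))
b = relay.var("b", shape=(1,))
c = relay.var("c", shape=(1,))

# Intermediates inside the checkpointed block are rebuilt in the backward pass.
block = relay.annotation.checkpoint(relay.multiply(relay.add(a, b), c))
func = relay.Function([a, b, c], block)

check_grad(func)  # numerically verifies the gradients, as the new test below does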
27 changes: 27 additions & 0 deletions src/relay/op/annotation/annotation.cc
@@ -144,5 +144,32 @@ Mark the end of bitpacking.
return {topi::identity(inputs[0])};
});

TVM_REGISTER_API("relay.op.annotation._make.checkpoint")
.set_body_typed<Expr(Expr)>([](Expr data) {
static const Op& op = Op::Get("annotation.checkpoint");
return CallNode::make(op, {data}, Attrs{}, {});
});

RELAY_REGISTER_OP("annotation.checkpoint")
.describe(R"code(
Mark a checkpoint for checkpointing memory optimization.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
Array<Tensor> outputs;
for (size_t i = 0; i < inputs.size(); ++i) {
outputs.push_back(topi::identity(inputs[i]));
}
return outputs;
});

} // namespace relay
} // namespace tvm
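
Since the FTVMCompute registration above lowers checkpoint to topi::identity on each input, the annotation is a no-op in the forward pass; only the gradient pass treats it specially. A quick forward-evaluation sketch, assuming the interpreter is exposed as relay.create_executor in this version of TVM:

# Sketch only; executor API names are assumptions for this TVM version.
import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(4,))
func = relay.Function([x], relay.annotation.checkpoint(x))

x_np = np.random.rand(4).astype("float32")
intrp = relay.create_executor(ctx=tvm.cpu(0), target="llvm")
out = intrp.evaluate(func)(x_np)
np.testing.assert_allclose(out.asnumpy(), x_np)  # identity in the forward pass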
4 changes: 3 additions & 1 deletion src/relay/pass/de_duplicate.cc
@@ -52,7 +52,9 @@ Expr DeDup(const Expr& e) {
}

Expr VisitExpr(const Expr& e) final {
return ExprMutator::VisitExpr(e);
auto ret = ExprMutator::VisitExpr(e);
ret->checked_type_ = e->checked_type_;
return ret;
}

Expr VisitExpr_(const VarNode* op) final {
162 changes: 128 additions & 34 deletions src/relay/pass/gradient.cc
@@ -273,50 +273,93 @@ Type ReverseType(const Type& t) {
* by doing a structure preserving map.
*/
Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
const Type& t,
const std::function<Type(const Type&)>& tf,
const Type& forward_type,
const Expr& e,
LetList* ll) {
CHECK(IsAtomic(e)) << e;
if (t.as<TensorTypeNode>()) {
if (forward_type.as<TensorTypeNode>()) {
auto ret = f(e);
ret->checked_type_ = t;
ret->checked_type_ = tf(forward_type);
return ret;
} else if (auto* tt = t.as<TupleTypeNode>()) {
} else if (auto* tt = forward_type.as<TupleTypeNode>()) {
tvm::Array<Expr> fields;
tvm::Array<Type> types;
for (size_t i = 0; i < tt->fields.size(); ++i) {
fields.push_back(LiftTensor(f,
tt->fields[i],
ll->Push(GetField(e, i)),
ll));
auto field = LiftTensor(f,
tf,
tt->fields[i],
ll->Push(GetField(e, i)),
ll);
fields.push_back(field);
types.push_back(field->checked_type_);
}
auto ret = TupleNode::make(fields);
ret->checked_type_ = t;
ret->checked_type_ = TupleTypeNode::make(types);
return std::move(ret);
} else {
LOG(FATAL) << "unsupported input/output type: " << tt;
throw;
}
}

/*! \brief Transfers the gradients from an Expr to a deep duplication of the Expr,
* by stitching the references in the AD values.
*/
void TransferGrads(const Type& forward_type,
const Expr& from,
const Expr& to,
LetList* ll) {
CHECK(IsAtomic(from)) << from;
CHECK(IsAtomic(to)) << to;
if (forward_type.as<TensorTypeNode>()) {
auto from_ref = TupleGetItemNode::make(from, 1);
auto to_ref = TupleGetItemNode::make(to, 1);
ll->Push(RefWriteNode::make(to_ref, RefReadNode::make(from_ref)));
} else if (auto* tt = forward_type.as<TupleTypeNode>()) {
for (size_t i = 0; i < tt->fields.size(); ++i) {
TransferGrads(tt->fields[i],
ll->Push(TupleGetItemNode::make(from, i)),
ll->Push(TupleGetItemNode::make(to, i)),
ll);
}
} else {
LOG(FATAL) << "Unsupported input/output type: " << forward_type;
throw;
}
}

/*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
Expr GetRev(const Type& t, const Expr& e, LetList* ll) {
Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
auto rev = [&](const Expr& e) {
return Pair(e, ll->Push(RefCreateNode::make(ZerosLike(e))));
};
return LiftTensor(rev, t, e, ll);
auto rev_type = [&](const Type& forward_type) {
return ReverseType(forward_type);
};
return LiftTensor(rev, rev_type, forward_type, e, ll);
}

/*! \brief ReverseType(t) -> t. Get the original value. */
Expr GetValue(const Type& t, const Expr& e, LetList* ll) {
return LiftTensor([&](const Expr& e) { return GetField(e, 0); }, t, e, ll);
Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {
auto val = [&](const Expr& e) {
return GetField(e, 0);
};
auto val_type = [&](const Type& forward_type) {
return forward_type;
};
return LiftTensor(val, val_type, forward_type, e, ll);
}

/*! \brief ReverseType(t) -> t. Get the gradient. */
Expr GetGrad(const Type& t, const Expr& e, LetList* ll) {
Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
auto grad = [&](const Expr& e) {
return ll->Push(RefReadNode::make(GetField(e, 1)));
};
return LiftTensor(grad, t, e, ll);
auto grad_type = [&](const Type& forward_type) {
return forward_type;
};
return LiftTensor(grad, grad_type, forward_type, e, ll);
}

void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
@@ -337,42 +380,87 @@ void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
}
}

Expr BPEmpty() {
Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
return RefCreateNode::make(unitF);
}

struct ReverseAD : ExprMutator {
using ADVarMap = std::unordered_map<Var, Var, NodeHash, NodeEqual>;

Var bp;
std::shared_ptr<ADVarMap> ad_vars;
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");

explicit ReverseAD(const Var& bp) : bp(bp) { }
explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars)
: bp(bp), ad_vars(ad_vars) { }

Expr VisitExpr_(const OpNode* op) final {
LOG(FATAL) << "op should only be inside call";
throw;
}

Expr VisitExpr_(const CallNode* op) final {
if (const OpNode* op_node = op->op.as<OpNode>()) {
Expr VisitCheckpoint(const CallNode *call) {
const OpNode* op_node = call->op.as<OpNode>();
CHECK(op_node) << "expected op in call";
Op op_ref = GetRef<Op>(op_node);
CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation";
auto x = call->args[0];
return LetList::With([&](LetList* ll) {
auto x_var = ll->Push(x);
auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
auto bpv = ll->Push(RefReadNode::make(bp));
Expr nbp = FunctionNode::make(
{},
LetList::With([&](LetList* ll) {
// we need a new ReverseAD visitor to avoid clobbering the bp local var
auto dup_bp = ll->Push(BPEmpty());
ReverseAD dup_diff(dup_bp, ad_vars);
auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));

TransferGrads(call->checked_type(), ret, dup_ad, ll);
ll->Push(CallNode::make(RefReadNode::make(dup_bp), {}));
return CallNode::make(bpv, {});
}),
TupleTypeNode::make({}),
{});
ll->Push(RefWriteNode::make(bp, nbp));
return ret;
});
}

Expr VisitExpr_(const CallNode* call) final {
if (const OpNode* op_node = call->op.as<OpNode>()) {
Op op_ref = GetRef<Op>(op_node);

if (op_ref->name == "annotation.checkpoint") {
return VisitCheckpoint(call);
}

CHECK(rev_map.count(op_ref))
<< op_node->name << " does not have reverse mode defined";
return LetList::With([&](LetList* ll) {
std::vector<Var> args;
for (const auto& arg : op->args) {
for (const auto& arg : call->args) {
args.push_back(ll->Push(VisitExpr(arg)));
}
std::vector<Expr> orig_args;
for (size_t i = 0; i < args.size(); i++) {
orig_args.push_back(GetValue(op->args[i]->checked_type(), args[i], ll));
orig_args.push_back(GetValue(call->args[i]->checked_type(), args[i], ll));
}
Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
orig->checked_type_ = op->checked_type();
Expr orig = CallNode::make(call->op, orig_args, call->attrs, call->type_args);
orig->checked_type_ = call->checked_type();
Var orig_var = ll->Push(orig);
orig_var->checked_type_ = op->checked_type();
auto ret = ll->Push(GetRev(op->checked_type(), orig_var, ll));
orig_var->checked_type_ = call->checked_type();
auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
auto bpv = ll->Push(RefReadNode::make(bp));
Expr nbp = FunctionNode::make(
{},
LetList::With([&](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(op->checked_type(), ret, ll));
tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
UpdateGrad(op->args[i]->checked_type(), args[i], rev[i], ll);
UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
}
return CallNode::make(bpv, {});
}),
@@ -382,7 +470,7 @@ struct ReverseAD : ExprMutator {
return ret;
});
}
return ExprMutator::VisitExpr_(op);
return ExprMutator::VisitExpr_(call);
}

Expr VisitExpr_(const ConstantNode* op) final {
@@ -396,24 +484,30 @@ struct ReverseAD : ExprMutator {
VisitExpr(op->false_branch));
}

Expr VisitExpr_(const VarNode* var) final {
// memoize Var -> ADVar so we don't end up with free Vars when checkpointing
auto var_ref = GetRef<Var>(var);
if (!ad_vars->count(var_ref)) {
auto res = Downcast<Var>(ExprMutator::VisitExpr_(var));
(*ad_vars)[var_ref] = res;
}

return ad_vars->at(var_ref);
}

Type VisitType(const Type& t) final {
return t.defined() ? ReverseType(t) : t;
}
};

Expr BPEmpty() {
Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
return RefCreateNode::make(unitF);
}

bool MissingGrad(const Expr& e) {
struct MGVisitor : ExprVisitor {
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
std::unordered_set<std::string> op_names;

void VisitExpr_(const OpNode* op) final {
Op op_ref = GetRef<Op>(op);
if (!rev_map.count(op_ref)) {
if (op_ref->name != "annotation.checkpoint" && !rev_map.count(op_ref)) {
op_names.insert(op_ref->name);
}
ExprVisitor::VisitExpr_(op);
Expand Down Expand Up @@ -445,7 +539,7 @@ Expr Gradient(const Expr& re, const Module& mod) {
CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
Expr body = LetList::With([&](LetList* ll) {
Var bp = ll->Push(BPEmpty());
Expr rev = ReverseAD(bp)(e);
Expr rev = ReverseAD(bp, std::make_shared<ReverseAD::ADVarMap>())(e);
std::vector<Expr> args;
for (const auto& p : f->params) {
args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
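
The VisitCheckpoint handler above builds the usual forward value, but the backward closure re-visits a DeDup'd copy of the annotated sub-expression with a fresh ReverseAD visitor and stitches the gradient references across with TransferGrads, so the sub-expression's intermediates are recomputed rather than stored. A rough, framework-free Python analogue of that compute-for-memory trade (illustrative only, not TVM API):

# Illustrative only: nothing here is TVM API.
def forward_with_checkpoint(h, g, dh, dg, x):
    # Forward pass: keep only the checkpoint input x; the intermediate
    # h(x) is dropped instead of being saved for the backward pass.
    y = g(h(x))

    def backward(dy):
        t = h(x)                    # recompute the dropped intermediate
        return dh(x) * dg(t) * dy   # chain rule: dy/dx = h'(x) * g'(h(x)) * dy

    return y, backward

# h(x) = x + 1, g(t) = t * t  =>  y = (x + 1)**2, dy/dx = 2 * (x + 1)
y, backward = forward_with_checkpoint(
    h=lambda x: x + 1.0, g=lambda t: t * t,
    dh=lambda x: 1.0, dg=lambda t: 2.0 * t, x=2.0)
print(y, backward(1.0))  # 9.0 6.0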
12 changes: 12 additions & 0 deletions tests/python/relay/test_op_grad_level10.py
@@ -30,6 +30,18 @@ def test_cross_entropy_with_logits_grad():
    x = relay.var("x", shape=(2, 5))
    y = relay.var("y", shape=(2, 5))
    check_grad(relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)), eps=0.01, scale=0.1, mean=1)

def test_checkpoint():
    inputs = [relay.var("x{}".format(i), shape=(1,)) for i in range(4)]
    output = relay.multiply(relay.add(inputs[0], inputs[1]),
                            relay.add(inputs[2], inputs[3]))
    check_grad(relay.Function(inputs, relay.annotation.checkpoint(output)))

    out_tuple = relay.Tuple([relay.add(inputs[0], inputs[1]),
                             relay.multiply(inputs[2], inputs[3])])
    out_single = relay.subtract(relay.TupleGetItem(relay.annotation.checkpoint(out_tuple), 0),
                                relay.TupleGetItem(out_tuple, 1))
    check_grad(relay.Function(inputs, out_single))


if __name__ == "__main__":