refactor, fix checkpoint compute for tuple and add tests
altanh committed Oct 24, 2019
1 parent 7bc21ca commit 19b68f3
Showing 4 changed files with 88 additions and 23 deletions.
6 changes: 5 additions & 1 deletion src/relay/op/annotation/annotation.cc
@@ -164,7 +164,11 @@ Mark a checkpoint for checkpointing memory optimization.
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<Tensor>& inputs,
const Type& out_dtype, const Target& target) -> Array<Tensor> {
-return {topi::identity(inputs[0])};
+Array<Tensor> outputs;
+for (size_t i = 0; i < inputs.size(); ++i) {
+  outputs.push_back(topi::identity(inputs[i]));
+}
+return outputs;
});

} // namespace relay
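Context for the compute change above: annotation.checkpoint is semantically an identity op, but the old FTVMCompute returned only topi::identity(inputs[0]), so a checkpoint wrapping a tuple lowered to the wrong number of outputs. Below is a minimal sketch (not part of this commit) of the case the fix enables, written against the Relay Python API and assuming the debug interpreter lowers the checkpoint op through its FTVMCompute, as the tests below do via check_grad:

import numpy as np
from tvm import relay

# A checkpoint around a tuple-valued expression; with the fix, lowering emits
# one identity tensor per tuple field instead of only the first.
x = relay.var("x", relay.TensorType((1,), "float32"))
y = relay.var("y", relay.TensorType((1,), "float32"))
body = relay.annotation.checkpoint(
    relay.Tuple([relay.add(x, y), relay.multiply(x, y)]))
func = relay.Function([x, y], body)

# Evaluate on the debug interpreter; the checkpoint op gets compiled here.
ex = relay.create_executor("debug", target="llvm")
res = ex.evaluate(func)(np.array([1.0], "float32"),
                        np.array([2.0], "float32"))
# res should hold both tuple fields, e.g. ([3.0], [2.0]).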
52 changes: 30 additions & 22 deletions src/relay/pass/gradient.cc
@@ -400,33 +400,41 @@ struct ReverseAD : ExprMutator {
throw;
}

+Expr VisitCheckpoint(const CallNode *call) {
+  const OpNode* op_node = call->op.as<OpNode>();
+  CHECK(op_node) << "expected op in call";
+  Op op_ref = GetRef<Op>(op_node);
+  CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation";
+  auto x = call->args[0];
+  return LetList::With([&](LetList* ll) {
+    auto x_var = ll->Push(x);
+    auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
+    auto bpv = ll->Push(RefReadNode::make(bp));
+    Expr nbp = FunctionNode::make(
+      {},
+      LetList::With([&](LetList* ll) {
+        // we need a new ReverseAD visitor to avoid clobbering the bp local var
+        auto dup_bp = ll->Push(BPEmpty());
+        ReverseAD dup_diff(dup_bp, ad_vars);
+        auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));
+
+        TransferGrads(call->checked_type(), ret, dup_ad, ll);
+        ll->Push(CallNode::make(RefReadNode::make(dup_bp), {}));
+        return CallNode::make(bpv, {});
+      }),
+      TupleTypeNode::make({}),
+      {});
+    ll->Push(RefWriteNode::make(bp, nbp));
+    return ret;
+  });
+}

Expr VisitExpr_(const CallNode* call) final {
if (const OpNode* op_node = call->op.as<OpNode>()) {
Op op_ref = GetRef<Op>(op_node);

if (op_ref->name == "annotation.checkpoint") {
-  auto x = call->args[0];
-  return LetList::With([&](LetList* ll) {
-    auto x_var = ll->Push(x);
-    auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
-    auto bpv = ll->Push(RefReadNode::make(bp));
-    Expr nbp = FunctionNode::make(
-      {},
-      LetList::With([&](LetList* ll) {
-        // we need a new ReverseAD visitor to avoid clobbering the bp local var
-        auto dup_bp = ll->Push(BPEmpty());
-        ReverseAD dup_diff(dup_bp, ad_vars);
-        auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));
-
-        TransferGrads(call->checked_type(), ret, dup_ad, ll);
-        ll->Push(CallNode::make(RefReadNode::make(dup_bp), {}));
-        return CallNode::make(bpv, {});
-      }),
-      TupleTypeNode::make({}),
-      {});
-    ll->Push(RefWriteNode::make(bp, nbp));
-    return ret;
-  });
+  return VisitCheckpoint(call);
}

CHECK(rev_map.count(op_ref))
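The refactor above pulls the checkpoint handling out of VisitExpr_ into a dedicated VisitCheckpoint helper; its behavior is unchanged. The backward closure re-runs a fresh ReverseAD over a DeDup'd copy of the checkpointed subexpression, using its own backprop reference so the outer bp local is not clobbered, then transfers the resulting gradients: the checkpointed computation is redone during the backward pass instead of its reverse-mode intermediates being kept live. A short sketch of driving this path from Python, mirroring the new test below (assumes the relay.Module / transform API used in the tests):

from tvm import relay
from tvm.relay import transform

# Differentiate a function whose body is wrapped in annotation.checkpoint;
# the checkpointed subgraph is rebuilt inside the backward closure.
xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)]
body = relay.annotation.checkpoint(
    relay.Tuple([relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3])]))
f = relay.Function(xs, body)

mod = relay.Module.from_expr(f)
mod = transform.InferType()(mod)
# gradient returns a function computing (forward outputs, per-input gradients)
df = transform.gradient(mod["main"], mod=mod)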
6 changes: 6 additions & 0 deletions tests/python/relay/test_op_grad_level10.py
@@ -30,6 +30,12 @@ def test_checkpoint():
                            relay.add(inputs[2], inputs[3]))
    check_grad(relay.Function(inputs, relay.annotation.checkpoint(output)))

+    out_tuple = relay.Tuple([relay.add(inputs[0], inputs[1]),
+                             relay.multiply(inputs[2], inputs[3])])
+    out_single = relay.subtract(relay.TupleGetItem(relay.annotation.checkpoint(out_tuple), 0),
+                                relay.TupleGetItem(out_tuple, 1))
+    check_grad(relay.Function(inputs, out_single))


if __name__ == "__main__":
    test_cross_entropy_grad()
47 changes: 47 additions & 0 deletions tests/python/relay/test_op_level10.py
@@ -105,6 +105,53 @@ def test_checkpoint_alpha_equal():

    relay.analysis.assert_alpha_equal(df, df_parsed)

+def test_checkpoint_alpha_equal_tuple():
+    xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)]
+    f = relay.Function(xs, relay.annotation.checkpoint(
+        relay.Tuple([relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3])])
+    ))
+    df = transform.gradient(run_infer_type(f))
+
+    # run PE and DCE
+    with transform.PassContext(opt_level=3):
+        passes = [transform.PartialEvaluate(),
+                  transform.DeadCodeElimination(inline_once=True)]
+        mod = transform.Sequential(passes)(relay.Module.from_expr(df))
+        df = mod["main"]
+
+    df_parsed = relay.parser.fromtext(
+        """
+        v0.0.4
+        fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32],
+            %z: Tensor[(1), float32], %w: Tensor[(1), float32])
+            -> ((Tensor[(1), float32], Tensor[(1), float32]),
+                (Tensor[(1), float32], Tensor[(1), float32],
+                 Tensor[(1), float32], Tensor[(1), float32])) {
+            let %x1: Tensor[(1), float32] = add(%x, %y) /* ty=Tensor[(1), float32] */;
+            let %x2: Tensor[(1), float32] = add(%z, %w) /* ty=Tensor[(1), float32] */;
+            let %x3: Tensor[(1), float32] = zeros_like(%x2) /* ty=Tensor[(1), float32] */;
+            let %x4: Tensor[(1), float32] = ones_like(%x1) /* ty=Tensor[(1), float32] */;
+            %0 = (%x1, %x2);
+            %1 = zeros_like(%x) /* ty=Tensor[(1), float32] */;
+            %2 = collapse_sum_like(%x4, %x) /* ty=Tensor[(1), float32] */;
+            %3 = add(%1, %2) /* ty=Tensor[(1), float32] */;
+            %4 = zeros_like(%y) /* ty=Tensor[(1), float32] */;
+            %5 = collapse_sum_like(%x4, %y) /* ty=Tensor[(1), float32] */;
+            %6 = add(%4, %5) /* ty=Tensor[(1), float32] */;
+            %7 = zeros_like(%z) /* ty=Tensor[(1), float32] */;
+            %8 = collapse_sum_like(%x3, %z) /* ty=Tensor[(1), float32] */;
+            %9 = add(%7, %8) /* ty=Tensor[(1), float32] */;
+            %10 = zeros_like(%w) /* ty=Tensor[(1), float32] */;
+            %11 = collapse_sum_like(%x3, %w) /* ty=Tensor[(1), float32] */;
+            %12 = add(%10, %11) /* ty=Tensor[(1), float32] */;
+            %13 = (%3, %6, %9, %12);
+            (%0, %13)
+        }
+        """
+    )
+
+    relay.analysis.assert_alpha_equal(df, df_parsed)

def test_collapse_sum_like():
    shape = (3, 4, 5, 6)
    shape_like = (4, 5, 6)
