This repository has been archived by the owner on May 22, 2023. It is now read-only.

[DISCUSS] Relax control flow support #93

Open
yongwww opened this issue Mar 14, 2022 · 2 comments

Comments

@yongwww
Collaborator

yongwww commented Mar 14, 2022

Author: @YuchenJin, @yongwww

Introduction

We observe that more and more dynamic models are built with control flow, especially in object detection and natural language processing. We plan to enable control flow support in Relax to meet this growing need. This design document outlines the plan and key steps for supporting models with control flow (e.g., LSTM, Mask-RCNN) in Relax.

When importing models such as SSD or LSTM from PyTorch or TensorFlow into Relay, the converted Relay IR contains while_loop constructs and TensorArray. The while_loop in Relay is implemented with tail recursion, which requires closures and lambda lifting. TensorArray is expressed via Relay ADTs. In Relax, we can flexibly use an opaque Object type and write opaque packed functions that operate on TensorArray; for more details on TensorArray support in Relax, see discussion #87. The following is part of the Relay IR converted from a PyTorch LSTM model.

...
%39 = (
    let %while_loop = fn (%i.1: int32, %outputs.11: List[meta[IncompleteType][0]],
                          %state.11: (Tensor[(2, 4), float32], Tensor[(2, 4), float32]),
                          %input.1: Tensor[(5, 2, 3), float32]) {
      %2 = less(%i.1, 5);
      if (%2) {
        %3 = transpose(%cell.weight_ih, axes=[1, 0]);
        %4 = take(%input.1, %i.1, axis=0, mode="wrap");
        ...
        %35 = add(%i.1, 1);
        %36 = @concat(%outputs.11, %34);
        %37 = %31.1;
        %while_loop(%35, %36, %37, %input.1)
      } else {
        (%i.1, %outputs.11, %state.11, %input.1)
      }
    };
    %while_loop
  );
  %40 = %39(0, %38, %states, %input);
  %41 = %40.1;
  %42 = @map(tensor_constructor_float32_2_4(Tensor[(2, 4), float32]), %41);
  %43 = @tensor_array_stack_float32_2_4(%42);
  %44 = @tensor_get_data_float32_any_2_4(%43);
  %45 = %40.2;
  (%44, %45)

Key goals:

  1. If-then-else (done)
  2. Closure
  3. Lambda Lifting
  4. Loop through recursion

If-then-else

Relax currently supports if-then-else statements. If the condition is true, the true branch is executed; otherwise, execution jumps to the false branch.

@relax.function
def foo(cond: Tensor[(), "bool"], x: Tensor[(3, 4), "float32"]):
    if cond:
        r = relax.call_packed("test.vm.add", x, x)
    else:
        r = relax.call_packed("test.vm.mul", x, x)
    return r

We introduced two VM instructions, If and Goto. When codegen visits an IfNode, it emits the corresponding If and Goto instructions with the computed offsets.
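
As a rough illustration of how a conditional can lower to If and Goto with relative offsets, here is a toy interpreter in Python. The tuple encoding, the operand layout, and the `Call`/`Ret` opcodes as modeled here are assumptions made for this sketch; they are not the actual Relax VM instruction format.

```python
# Toy model only: opcodes and operand layout are hypothetical,
# not the real Relax VM instruction encoding.

def run(program, registers):
    """Interpret a list of tuple-encoded instructions."""
    pc = 0
    while True:
        op = program[pc]
        if op[0] == "If":
            # ("If", cond_reg, false_offset): fall through when the
            # condition is true, otherwise jump forward by false_offset.
            _, cond_reg, false_offset = op
            pc += 1 if registers[cond_reg] else false_offset
        elif op[0] == "Goto":
            # ("Goto", offset): unconditional relative jump.
            pc += op[1]
        elif op[0] == "Call":
            # ("Call", func, dst_reg, arg_regs...)
            _, func, dst, *args = op
            registers[dst] = func(*(registers[a] for a in args))
            pc += 1
        elif op[0] == "Ret":
            return registers[op[1]]
        else:
            raise ValueError(f"unknown opcode {op[0]!r}")


def lower_if(cond_reg, true_block, false_block):
    """Emit If/Goto around two already-lowered blocks, computing offsets."""
    return (
        [("If", cond_reg, len(true_block) + 2)]  # skip true block and Goto
        + true_block
        + [("Goto", len(false_block) + 1)]       # skip false block
        + false_block
    )


def add(a, b):
    return a + b


def mul(a, b):
    return a * b


# Mirrors foo above: r = x + x if cond, else r = x * x.
program = lower_if(
    "cond",
    [("Call", add, "r", "x", "x")],
    [("Call", mul, "r", "x", "x")],
) + [("Ret", "r")]

true_result = run(program, {"cond": True, "x": 3})
false_result = run(program, {"cond": False, "x": 3})
```

The point of the sketch is that both offsets depend only on the lengths of the two lowered branches, which is why they can be computed at the time the IfNode is visited.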

Closure

A closure is a function that remembers its environment (the captured variables). For example, consider the following program:

fn f1() {
  y = 1
  z = 2
  g = fn f2(x) {  # g is a closure
    return x + y + z
  }
  return g
}
result = f1()
result(3) # evaluate to 6

After lambda lifting, f2 becomes a global function f2':

fn f2'(x, env) {
  return x + env[0] + env[1];
}

fn f1'() {
  y = 1
  z = 2
  env = (y, z)
  g = make_closure(f2', env)
  return g
}
result = f1'()
result(3)
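
The same transformation can be written as runnable Python to check that the nested and lifted forms agree. This is only a sketch: `make_closure` below is a plain Python stand-in written for this example, not the Relax intrinsic, and the `_lifted` names are illustrative.

```python
# Hypothetical helper: a Python stand-in for make_closure,
# not the Relax intrinsic.
def make_closure(func, env):
    """Pair a lifted global function with its captured environment."""
    def closure(*args):
        return func(*args, env)
    return closure


def f2_lifted(x, env):
    # y and z are no longer free variables; they are read from env.
    return x + env[0] + env[1]


def f1_lifted():
    y = 1
    z = 2
    env = (y, z)
    return make_closure(f2_lifted, env)


def f1():
    # Original nested form, kept for comparison.
    y = 1
    z = 2

    def f2(x):
        return x + y + z

    return f2


original = f1()(3)
lifted = f1_lifted()(3)
```

Both calls evaluate to 6, matching the pseudocode above.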

In this design, we considered three options for supporting closures in the Relax VM:

  • O1: Have a VMClosure object that stores the function index and the captured variables (env), like the current VMClosureObj in the Relay VM.
  • O2: PackedFunc is a tvm.runtime.Object, so a closure can be stored in a Relax VM register as a PackedFunc.
    The challenge with O2 is that, for a closure to behave like a standalone PackedFunc, the closure needs to capture the VM itself, while the VM also references the closure because the closure is stored in a VM register. This creates a cyclic reference: the lifetimes of the VM and the closure depend on each other, so neither reference count ever drops to 0, neither object is ever destructed, and the memory leaks.
  • O3: To resolve the cyclic reference in O2, the closure can capture the VM through a weak pointer instead. However, weak capture leaves open the possibility of invoking the closure after the VM has been destructed.

In O1, VMClosure holds no reference to the VM and always requires the VM to be available when it is invoked, so there is no cyclic reference. We therefore plan to start with O1, and we can migrate to O3 if the need arises.
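
To make the trade-off concrete, here is a small Python model of O1 and O3. All class and function names (`VM`, `invoke_closure`, `standalone_closure`) are hypothetical and chosen for this sketch; Python's `weakref` stands in for a C++ weak pointer, so this illustrates the ownership pattern only, not the Relax runtime.

```python
# Sketch only: names are illustrative, not the Relax VM API.
import gc
import weakref
from dataclasses import dataclass


@dataclass(frozen=True)
class VMClosure:
    """O1: a closure is plain data (function index plus captured values)."""
    func_index: int
    env: tuple


class VM:
    def __init__(self, functions):
        self.functions = functions  # global function table
        self.registers = {}

    def invoke_closure(self, closure, *args):
        # The caller supplies the VM, so the closure never holds it.
        func = self.functions[closure.func_index]
        return func(*args, closure.env)


def standalone_closure(vm, closure):
    """O3: a callable that captures the VM weakly to avoid a cycle."""
    vm_ref = weakref.ref(vm)

    def call(*args):
        live_vm = vm_ref()
        if live_vm is None:
            raise RuntimeError("VM was destroyed before the closure ran")
        return live_vm.invoke_closure(closure, *args)

    return call


def f2_lifted(x, env):
    return x + env[0] + env[1]


vm = VM([f2_lifted])
g = VMClosure(func_index=0, env=(1, 2))
vm.registers["g"] = g  # O1: the register holds g, g holds nothing of vm
o1_result = vm.invoke_closure(g, 3)

weak_g = standalone_closure(vm, g)
o3_result = weak_g(3)

del vm        # the VM goes away while weak_g is still alive
gc.collect()
try:
    weak_g(3)
    dangling_detected = False
except RuntimeError:
    dangling_detected = True
```

With O1 the closure can only be invoked through a live VM, so the dangling case cannot arise. With O3 the closure is callable on its own, but every call has to handle the case where the VM is already gone.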

Lambda Lifting

Lambda lifting will be implemented as a LambdaLift pass, similar to the one in Relay. The main difference is that we will introduce a make_closure intrinsic, whereas Relay sets an attribute Closure=1 on the function to tell codegen that it is a closure.

Recursive function:

Before lambda lifting:

x1 = 1
f1 = fn(x) {
    y = x - x1
    f1(y)
}

After lambda lifting:

f1'(x, env) {
    y = x - env[0]
    f1'(y, env)
}

x1 = 1
f1 = make_closure(f1', [x1])

Mutually recursive function calls (functions that call into each other):

(left for future support, once we have use cases)

Before lambda lifting:

x1 = 1
x2 = 2
f1 = fn (x) {
    f2(x) + x1
}
f2 = fn (x) {
    f1(x) + x2
}

How to lift: compute the maximal group of functions that call into each other (in this case f1 and f2) and lift them together. Importantly, the environment of the closures (f1, f2) should be the union of the variables captured by both functions.

After lambda lifting:

f1'(x, env) {
    f2'(x, env) + env[0]
}
f2'(x, env) {
    f1'(x, env) + env[1]
}

f1 = make_closure(f1', [x1, x2])
f2 = make_closure(f2', [x1, x2])
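
The grouping step can be sketched in Python as follows. This is an assumption about how such an analysis might look, not the planned pass: the call graph and captured-variable tables are hand-written inputs, and `mutual_group` and `shared_env` are hypothetical helper names.

```python
# Sketch of the analysis only; the inputs are hand-written and the
# helper names are hypothetical.

def reachable(call_graph, src):
    """All functions reachable from src through one or more calls."""
    seen, stack = set(), [src]
    while stack:
        node = stack.pop()
        for callee in call_graph.get(node, ()):
            if callee not in seen:
                seen.add(callee)
                stack.append(callee)
    return seen


def mutual_group(call_graph, start):
    """Functions that start can reach and that can reach start again."""
    return {start} | {
        f for f in reachable(call_graph, start)
        if start in reachable(call_graph, f)
    }


def shared_env(group, captured):
    """Ordered union of the variables captured by every function in group."""
    env = []
    for func in sorted(group):
        for var in captured[func]:
            if var not in env:
                env.append(var)
    return env


call_graph = {"main": ["f1"], "f1": ["f2"], "f2": ["f1"]}
captured = {"main": [], "f1": ["x1"], "f2": ["x2"]}

group = mutual_group(call_graph, "f1")
env = shared_env(group, captured)
```

Here `group` contains f1 and f2 but not main, and `env` lists x1 then x2, which matches the env[0] and env[1] indices used in the lifted functions above.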

Loop through recursion

We use recursive functions to construct loops; a dedicated loop operator is planned for the future. Whenever a model containing a loop is imported into Relax, the frontend converter constructs the corresponding recursive function. The following example shows the conversion.

# TensorFlow while_loop
i = tf.constant(0, name="while/constant")
def c(i):
    return tf.less(i, 10)
def b(i):
    return tf.add(i, 1)
r = tf.while_loop(c, b, [i])
# relax
loop = fn(loop_var_0: Tensor[(), "int32"]) {
  c = less(loop_var_0, 10)
  cond = min(c)
  if (cond) {
    v = add(loop_var_0, 1)
    return loop(v)
  } else {
    return loop_var_0
  }
};
loop(0)
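
The general shape of this conversion can be sketched in plain Python. The `while_loop` helper below is hypothetical and only mimics the cond/body/loop_vars calling convention of tf.while_loop; it is not a frontend implementation.

```python
# Sketch only: a hypothetical helper that expresses a while loop as
# recursion, following the cond/body/loop_vars convention.

def while_loop(cond, body, loop_vars):
    """Run body while cond holds, using recursion instead of a loop."""
    def loop(*current):
        if cond(*current):
            # Recursive call in tail position with the updated loop variables.
            return loop(*body(*current))
        return current

    return loop(*loop_vars)


# Same computation as the example above: count from 0 up to 10.
result = while_loop(lambda i: i < 10, lambda i: (i + 1,), (0,))
```

`result` is the tuple `(10,)`. Python does not eliminate tail calls, so this form is only suitable for short loops; it is meant to show the structure the converter produces, not how it would be executed.
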
@hypercubestart
Collaborator

Thanks for the great proposal! A few questions:

  • In the future, do we plan to move away from recursion and use if/goto for loops instead, as the Relax converter develops so we don't rely on Relay for model conversion?
  • Could you help me understand the differences between the Relay and Relax VMs? My understanding is that, to support primitives such as closures/arrays/tuples, we are taking the same approach as the Relay VM, with VM intrinsics and TVM objects. Is this right?

@yongwww
Collaborator Author

yongwww commented Mar 18, 2022

@hypercubestart sorry for my late response, I forgot to check the notification.

We plan to have a loop operator in the future, something like PyTorch prim::Loop or TensorFlow while_loop. We will still keep recursion, since it is more friendly to automatic differentiation. AFAIK, we don't yet have a clear decision on whether to rely on Relay for model conversion. Personally, as a long-term goal, I think Relax needs its own converter that converts models directly from the original inputs, instead of going through the model->relay->relax path.

Relax VM takes a similar approach to primitives as Relay VM. The main differences between the Relay VM and the Relax VM are:

  1. We minimized the instruction set compared with the Relay VM, but have a Call instruction to call builtin packed functions;
  2. We added the capability to do runtime shape computation. For more details, please take a look at the Relax-VM-Design doc.

@YuchenJin YuchenJin changed the title [DISCUSSION] Relax control flow support [DISCUSS] Relax control flow support Mar 23, 2022
@yongwww yongwww mentioned this issue May 3, 2022
MasterJH5574 pushed a commit to MasterJH5574/tlc-relax that referenced this issue Jan 13, 2023
MasterJH5574 pushed a commit to MasterJH5574/tlc-relax that referenced this issue Jan 13, 2023
MasterJH5574 pushed a commit to MasterJH5574/tlc-relax that referenced this issue Jan 16, 2023
vinx13 pushed a commit to vinx13/relax that referenced this issue Jan 31, 2023
vinx13 pushed a commit to vinx13/relax that referenced this issue Jan 31, 2023
vinx13 pushed a commit to vinx13/relax that referenced this issue Feb 8, 2023
MasterJH5574 pushed a commit to MasterJH5574/tlc-relax that referenced this issue Feb 12, 2023
vinx13 pushed a commit to vinx13/relax that referenced this issue Feb 13, 2023