[IR] [Refactor] Support return in serial kernels #909

Open
yuanming-hu opened this issue May 2, 2020 · 3 comments

@yuanming-hu
Member

yuanming-hu commented May 2, 2020

Remove ArgStoreStmt

class ArgStoreStmt : public Stmt {

Currently the only use of ArgStoreStmt is for returning values in a kernel. I personally feel this is a confusing design (#839 (comment)). I suggest removing this statement and replacing it with KernelReturnStmt.

Add KernelReturnStmt

KernelReturnStmt can only work in serial offloaded tasks.

  • Set return values;
  • Terminate the kernel. Actually, we should enforce that KernelReturnStmt can only be the final statement of a serial kernel. If we add this restriction, there is no need to worry about terminating the kernel.
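
The placement rule above could be checked by a small verification pass. What follows is only a minimal sketch in plain Python, not Taichi's actual IR: the `Stmt`, `KernelReturnStmt`, and `OffloadedTask` classes, the `task_type` strings, and the `return_placement_error` helper are all hypothetical stand-ins made up for illustration.

```python
class Stmt:
    """Hypothetical stand-in for a generic IR statement."""

    def __init__(self, name):
        self.name = name


class KernelReturnStmt(Stmt):
    """Hypothetical stand-in for the proposed KernelReturnStmt."""

    def __init__(self):
        super().__init__("kernel_return")


class OffloadedTask:
    """Hypothetical stand-in for one offloaded task.

    task_type is "serial" or "range_for" in this toy model.
    """

    def __init__(self, task_type, body):
        self.task_type = task_type
        self.body = list(body)


def return_placement_error(task):
    """Return an error message if a KernelReturnStmt is misplaced, else None."""
    last_index = len(task.body) - 1
    for index, stmt in enumerate(task.body):
        if not isinstance(stmt, KernelReturnStmt):
            continue
        if task.task_type != "serial":
            return "KernelReturnStmt is only allowed in serial offloaded tasks"
        if index != last_index:
            return "KernelReturnStmt must be the final statement"
    return None


valid = OffloadedTask("serial", [Stmt("load"), KernelReturnStmt()])
in_parallel_task = OffloadedTask("range_for", [KernelReturnStmt()])
not_last = OffloadedTask("serial", [KernelReturnStmt(), Stmt("store")])

for task in (valid, in_parallel_task, not_last):
    print(return_placement_error(task))
```

With the final-statement restriction enforced up front, later passes would not need any explicit "terminate the kernel" logic.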

Benefits

  1. SNode readers can finally use return
  2. We can allow return values of arbitrary kernels, e.g.
@ti.kernel
def compute_sum():
  s = 0.0  
  for i in x:
    s += x[i]
  return s

s = compute_sum()

Discussions

  • How about returning a vector/matrix?
@archibate
Collaborator

Cool! This is so cool because we never use ArgStoreStmt in Python-defined kernels. But if we remove ArgStoreStmt, what if we want multiple return values, like the Pythonic return a, b?

@archibate added this to the v0.7.0 milestone May 2, 2020
@yuanming-hu
Member Author

Good point. KernelReturnStmt should have a member std::vector<Stmt *> return_values to support multiple return values.
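
To make the idea concrete, here is a rough Python model of a return statement that carries a list of values, mirroring the proposed `std::vector<Stmt *> return_values` member. It is a sketch under assumptions, not the real C++ class: `ConstStmt` and `evaluate_return` are hypothetical names, and unwrapping a single value (so that `return s` yields a scalar rather than a 1-tuple) is my own assumption about the desired Python-side behavior.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ConstStmt:
    """Hypothetical stand-in for a statement that produces a value."""

    value: float


@dataclass
class KernelReturnStmt:
    """Toy model of the proposed statement, holding several return values."""

    return_values: List[ConstStmt]


def evaluate_return(stmt):
    """Collect the returned values in order, like Python's `return a, b`.

    Assumption: a single return value is unwrapped to a scalar.
    """
    values = tuple(v.value for v in stmt.return_values)
    if len(values) == 1:
        return values[0]
    return values


pair = evaluate_return(KernelReturnStmt([ConstStmt(1.0), ConstStmt(2.0)]))
single = evaluate_return(KernelReturnStmt([ConstStmt(3.0)]))
print(pair, single)
```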

@archibate
Collaborator

Does return make sense in diff-kernels? e.g.:

@ti.kernel
def func(x: ti.i32):
  return x ** 2

@ti.kernel
def func_grad(x: ti.i32):
  return 2 * x  # am I understanding this correctly?
