
Fusing Taichi with JAX #6367

Open
chaoming0625 opened this issue Oct 18, 2022 · 9 comments
Assignees
Labels
feature request Suggest an idea on this project

Comments

@chaoming0625
Contributor

We have already seen some examples that use Taichi as part of a PyTorch program.

However, is it possible to integrate Taichi into JAX?

Taichi can generate highly optimized operators and is well suited to implementing operators that involve sparse computation. If Taichi kernels could be used in a JAX program, it would be valuable to a broad range of programmers.

I think the key to the integration is obtaining the address of the compiled Taichi kernel. There are examples that launch a GPU kernel compiled by Triton from JAX, so perhaps a similar approach would be straightforward for Taichi as well.

@chaoming0625 chaoming0625 added the feature request Suggest an idea on this project label Oct 18, 2022
@taichi-gardener taichi-gardener moved this to Untriaged in Taichi Lang Oct 18, 2022
@feisuzhu feisuzhu moved this from Untriaged to Todo in Taichi Lang Oct 21, 2022
@ailzhang
Contributor

@chaoming0625 I'm by no means a JAX expert, so my guess could be wrong. IIUC, JAX device arrays don't expose a raw pointer to their storage the way PyTorch tensors do, which makes a torch-like zero-copy integration between JAX and Taichi rather hard. If you have to copy the device array out of JAX anyway, copying it to a NumPy array or a PyTorch tensor (both of which Taichi can operate on fairly efficiently) could be a possible workaround.

Note that Taichi's sparse computation requires a specific data layout (depending on your SNode structure) in a root buffer managed by Taichi, so dense NumPy arrays / PyTorch tensors are still the recommended way to exchange data with other libraries for those sparse fields.
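A minimal sketch of the copy-based workaround described above, assuming only JAX and NumPy are available; the function `process_on_host` is a hypothetical stand-in for a Taichi kernel that reads and writes a NumPy array:

```python
import numpy as np
import jax.numpy as jnp

def process_on_host(x_np):
    # Hypothetical stand-in for a Taichi kernel operating on a NumPy array.
    return x_np * 2.0

x = jnp.arange(4.0)            # JAX device array
x_host = np.asarray(x)         # copy to host NumPy (no raw-pointer sharing)
y_host = process_on_host(x_host)
y = jnp.asarray(y_host)        # copy the result back into JAX
print(y)                       # → [0. 2. 4. 6.]
```

This pays two copies per call (device-to-host and back), so it is only a stopgap compared with true zero-copy sharing.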

@chaoming0625
Contributor Author

Dear @ailzhang , one way to interoperate JAX data with Taichi is via DLPack:

```python
import jax.dlpack
import torch
import torch.utils.dlpack

def j2t(x_jax):
    # JAX -> PyTorch: hand the buffer over without copying
    x_torch = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax))
    return x_torch

def t2j(x_torch):
    # PyTorch -> JAX; contiguous() may copy if the tensor is non-contiguous
    x_torch = x_torch.contiguous()
    x_jax = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x_torch))
    return x_jax
```

This makes a zero-copy conversion from a JAX array to a PyTorch tensor. The PyTorch tensor can then be used in Taichi kernels, and the tensors returned from the Taichi kernel can likewise be zero-copied back to JAX.

I think this may be one possible solution.
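The zero-copy behavior that DLPack provides can be illustrated without JAX or PyTorch installed, using only NumPy (which gained DLPack support in version 1.22); the same buffer-sharing mechanism is what `jax.dlpack` and `torch.utils.dlpack` rely on:

```python
import numpy as np

# Demonstrate DLPack's zero-copy semantics with NumPy alone
# (np.from_dlpack requires NumPy >= 1.22).
a = np.arange(6, dtype=np.float32)
b = np.from_dlpack(a)      # consumes a.__dlpack__(); shares a's buffer

a[0] = 42.0                # mutate through the original array...
print(b[0])                # → 42.0: the change is visible through b
print(np.shares_memory(a, b))  # → True
```

Because the consumer sees the producer's buffer directly, no data crosses the host/device boundary in the conversion itself.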

@chaoming0625
Contributor Author

We are just wondering where we can get the address of the compiled Taichi kernels. Thanks.

@ailzhang
Contributor

ailzhang commented Oct 31, 2022

@chaoming0625 Sounds good! Taichi ndarrays are just contiguous memory, so it should be pretty straightforward to support the DLPack format (although it isn't supported yet).
Taichi's compiled kernels are here: https://github.com/taichi-dev/taichi/blob/master/python/taichi/lang/kernel_impl.py#L574

@chaoming0625
Contributor Author

Dear @ailzhang , that's wonderful. Thanks very much!

@salykova

salykova commented May 9, 2023

Hi @ailzhang

I just wanted to ask if there is any update on this issue, or an alternative to @chaoming0625's solution. Do you plan to implement support for JAX arrays via taichi.ndarray, as was done for PyTorch?

@maedoc

maedoc commented Oct 11, 2023

Also curious about this, since I'd like to use some packages written for JAX (numpyro specifically) and try out the Taichi autodiff system.

@jarmitage

As further motivation, I would love to be able to tap into these JAX projects with Taichi:

@chaoming0625
Contributor Author

See examples in brainpy/BrainPy#553
