[WIP] implement LazyArray + autocompile #9
Conversation
Hello @jcmgray! Thanks for updating this PR. We checked the lines you've touched for PEP 8 issues, and found:
Comment last updated at 2021-05-13 20:46:28 UTC
Codecov Report
```diff
@@            Coverage Diff             @@
##           master       #9      +/-   ##
==========================================
- Coverage   97.67%   95.71%    -1.97%
==========================================
  Files           2        6        +4
  Lines         560     1471      +911
==========================================
+ Hits          547     1408      +861
- Misses         13       63       +50
```
Continue to review full report at Codecov.
When I added support for compilation in torchquad, I noticed some limitations of autoray's `autojit`.

**Inconsistent random number generation**

In comparison to PyTorch and TensorFlow, with JAX a compiled non-deterministic function needs the PRNGKey as input and output; otherwise JAX would always generate the same numbers. PyTorch and TensorFlow use their default global PRNG state, but for JAX autoray maintains the global variable `_JAX_RANDOM_KEY`. Example code:

```python
from autoray import numpy as anp
from autoray import autojit

for _ in range(2):
    for backend in ("torch", "tensorflow", "jax"):
        func = lambda x: x + anp.random.uniform(like=backend)
        func_compiled = autojit(func)
        one = anp.array([1.0], like=backend)
        random_numbers = [func_compiled(one) for _ in range(4)]
        print(f"Random numbers for {backend}: {random_numbers}")
```

This example code produces errors.
With JAX it first always produces the same numbers, and then it crashes because a global variable is used during tracing.
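For reference, here is a minimal sketch of the explicit-key pattern that plain JAX requires for non-deterministic compiled functions (standard `jax.random` API, nothing autoray-specific; the function name `add_noise` is just illustrative):

```python
import jax
import jax.numpy as jnp

@jax.jit
def add_noise(x, key):
    # split the key so each call consumes fresh randomness,
    # and return the new key alongside the result
    key, subkey = jax.random.split(key)
    return x + jax.random.uniform(subkey, x.shape), key

key = jax.random.PRNGKey(0)
x = jnp.ones(3)
for _ in range(4):
    x, key = add_noise(x, key)  # different numbers on every call, even when jitted
```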
**PyTorch redundant `torch.jit.script` call**

I think the `torch.jit.script` call in `autojit` is redundant.

**PyTorch `.grad` attribute access warning**
Example code:

```python
import autoray, torch

func = lambda x: x
x = torch.tensor([1.0, 2.0])
x.requires_grad = True
print(autoray.autojit(func)(x + 1))
# torch.jit.trace(func, example_inputs=(x + 1,))
```

If I replace the traced input with a detached leaf copy, for example

```python
z = x
if z.requires_grad:
    z = z.detach()
    z.requires_grad = True
```

then the warning is avoided.

**TensorFlow deprecated `experimental_compile`**

The documentation of `tf.function` states that the `experimental_compile` argument is deprecated and has been renamed to `jit_compile`.
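As a small illustration of the renamed argument (standard TensorFlow API; `fn` here is just a placeholder function, not something from autoray):

```python
import tensorflow as tf

fn = lambda x: x + 1.0

# deprecated spelling:
# compiled = tf.function(fn, experimental_compile=True)

# current spelling (TensorFlow >= 2.5):
compiled = tf.function(fn, jit_compile=True)
print(compiled(tf.constant([1.0, 2.0])))
```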
**`static_argnames` support**

In my opinion the `static_argnames` argument of `jax.jit` is very useful. It would be convenient to have an analogous argument in the `autojit` interface.
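For reference, a minimal sketch of how `jax.jit` exposes this (standard JAX API; the function `repeat_sum` is only an illustrative example):

```python
import jax
import jax.numpy as jnp

def repeat_sum(x, n):
    # `n` determines Python-level control flow, so it must be a static
    # (compile-time) argument rather than a traced value
    return sum(x * i for i in range(n))

repeat_sum_jit = jax.jit(repeat_sum, static_argnames=("n",))
print(repeat_sum_jit(jnp.ones(3), n=4))  # recompiles only when `n` changes
```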
Hi @FHof, thanks for the feedback on this admittedly quite experimental feature! I'm open to making all these changes.

**Inconsistent random number generation**

Yes, this is a bit harder to offer a universal interface for with global state. Mostly the stuff I implemented was just to make generating test cases easier. I did come across the jax problem, one solution to which is something like the following:

```python
# autoray.py
# (uses the module-level `_JAX_RANDOM_KEY` state and `jax_random_seed` helper)
@contextlib.contextmanager
def jittable_jax_random_seed(seed):
    global _JAX_RANDOM_KEY
    old = _JAX_RANDOM_KEY
    try:
        jax_random_seed(seed)
        yield
    finally:
        _JAX_RANDOM_KEY = old
```

but I wonder if simply only supporting a RNG generator interface is the sustainable way to go (which is what it looks like you have in torchquad).

**PyTorch redundant `torch.jit.script` call**

Yes, I'm not totally familiar with these. Apparently the `torch.jit.script` call can be dropped.

**PyTorch `.grad` attribute access warning**

Happy to change this. It's just a warning when supplying the example tensors, it seems.

**TensorFlow deprecated `experimental_compile`**

Thanks for pointing this out - I don't use tensorflow much.

**`static_argnames` support**

This seems nice to have and falls under the general 'improve caching behavior'.

Would you be interested in a PR covering any of these things?
For the random number generation in torchquad with JAX, I have added explicit PRNGKey handling. I'm currently not interested in implementing the changes.
This adds two nice and fairly natural features to `autoray`, any feedback on interface welcome!

**Lazy Computation (`LazyArray`)**

If you write a function / algorithm with `do` calls, then you can trace through the entire thing lazily and, among other things, share intermediates (`with lazy.shared_intermediates(): ...`).

This is all implemented in a very lightweight manner, making it about 100x faster than e.g. dask (which of course offers many other features) on some examples, and suitable for computational graphs with >100,000s of nodes.
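As a rough usage sketch (illustrative only: the entry points `lazy.array` and `.compute()` are assumptions based on the description above, with `shared_intermediates` as mentioned):

```python
import numpy as np
from autoray import do, lazy

# wrap a concrete array so subsequent `do` calls are traced, not executed
x = lazy.array(np.random.uniform(size=(10, 10)))

with lazy.shared_intermediates():
    y = do("tanh", do("matmul", x, x))
    z = do("sum", y)   # still lazy, just nodes in the graph

print(z.compute())     # evaluate the whole computational graph
```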
TODO:

- `ray`
**Auto compilation / unified JIT interface (`@autocompile`)**

The aim here is to have a single decorator for marking an `autoray` function to be JIT compiled, which makes it very easy to switch backends ('jax', 'tensorflow', 'torch') or just dispatch to the correct one automatically. cc @RikVoorhaar and #5.

- the decorated function should have the signature `fn(*arrays) -> array` (see the sketch below)
- there is also a `'python'` backend that 'simply' uses the `LazyArray` computational graph to `compile` and `exec` an unravelled version of the function, with shared intermediates, folded constants + surrounding logic stripped away

TODO:

- `CompilePython` probably needs to dispatch based on the input array shapes (overhead is very minimal)
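A sketch of the decorator usage described above. This is illustrative only: the decorator appears in this PR as both `@autocompile` and `autojit`, and whether plain numpy inputs route to the `'python'` backend is an assumption.

```python
import numpy as np
from autoray import autojit, do

# follows the fn(*arrays) -> array convention
@autojit
def scaled_norm(x, y):
    return do("sum", do("abs", x * y)) ** 0.5

# the backend is dispatched automatically from the input arrays;
# passing jax / tensorflow / torch arrays instead would use their JITs
x = np.random.uniform(size=(5, 5))
y = np.random.uniform(size=(5, 5))
print(scaled_norm(x, y))
```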