[WIP] implement LazyArray + autocompile #9

Merged
merged 9 commits into from
Jun 4, 2021

jcmgray
Owner

@jcmgray jcmgray commented May 8, 2021

This adds two nice and fairly natural features to autoray, any feedback on interface welcome!

Lazy Computation (LazyArray)

from autoray import lazy
... 
lx = lazy.array(x)
lf = fn(lx)
lf.compute()

If you write a function / algorithm with do calls, then you can trace through the entire thing lazily and:

  • e.g. check max array size encountered, number of calls to specific functions
  • plot the computational graph
  • identify and cache shared intermediates only (with lazy.shared_intermediates(): ...)
  • pre-compute (fold) all constant parts of the computational graph

This is all implemented in a very lightweight manner, making it about 100x faster than e.g. dask (which of course offers many other features) on some examples, and suitable for computational graphs with hundreds of thousands of nodes.
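
To make the idea of lazy tracing concrete, here is a minimal pure-Python sketch. It is a toy illustration only and not autoray's implementation: the `LazyNode` class and its methods are hypothetical names invented for this example. It shows deferred evaluation, inspecting the graph without executing it, and computing a shared intermediate only once.

```python
# Toy illustration only: NOT autoray's LazyArray implementation.
import operator


class LazyNode:
    """A deferred computation: a function plus its (possibly lazy) arguments."""

    def __init__(self, fn=None, args=(), value=None):
        self.fn = fn
        self.args = args
        self._value = value  # set for leaf nodes, or once computed

    @classmethod
    def leaf(cls, value):
        return cls(value=value)

    def __add__(self, other):
        return LazyNode(operator.add, (self, other))

    def __mul__(self, other):
        return LazyNode(operator.mul, (self, other))

    def nodes(self):
        """Return every distinct node reachable from this one, without
        executing anything (iterative, so deep graphs are fine)."""
        seen, stack = {}, [self]
        while stack:
            node = stack.pop()
            if id(node) in seen:
                continue
            seen[id(node)] = node
            stack.extend(a for a in node.args if isinstance(a, LazyNode))
        return list(seen.values())

    def compute(self):
        """Evaluate the graph; each node's result is stored, so a shared
        intermediate is only computed once (recursive for brevity)."""
        if self._value is None:
            args = [a.compute() if isinstance(a, LazyNode) else a for a in self.args]
            self._value = self.fn(*args)
        return self._value


x = LazyNode.leaf(3)
y = x * 2              # nothing is computed yet
z = y + y              # `y` is a shared intermediate
print(len(z.nodes()))  # 3 nodes: x, y and z
print(z.compute())     # 12
```

Because the graph is just a set of small Python objects, walking it to collect statistics (node count, largest intermediate, calls per function) is cheap compared to executing it.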

[Image: spiral]

TODO:

  • implement a few remaining functions
  • eventually might be possible to estimate FLOPs etc
  • eventually might be useful to execute the computational graph with e.g. ray

Auto compilation / unified JIT interface (@autocompile)

@autocompile
def fn(x, y):
    ...

fn(x, y, backend='torch')

The aim here is to have a single decorator for marking an autoray function to be JIT compiled, which makes it very easy to switch backends ('jax', 'tensorflow', 'torch') or just dispatches to the correct one automatically. cc @RikVoorhaar and #5.

  • currently the signature must be fn(*arrays) -> array
  • there is a 'python' backend that 'simply' uses the LazyArray computational graph to compile and exec an unravelled version of the function, with shared intermediates, folded constants + surrounding logic stripped away
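
As a rough picture of what "compile and exec an unravelled version" can mean, the sketch below turns a recorded list of steps into one flat Python function. It is a toy under stated assumptions: `compile_graph` and the `(name, fn, arg_names)` step format are hypothetical and are not how the PR represents its graph.

```python
# Toy illustration only: the step format and `compile_graph` are hypothetical.
import operator


def compile_graph(steps, inputs, output):
    """Turn recorded ``(name, fn, arg_names)`` steps into a single flat Python
    function via ``exec``, so no tracing or dispatch logic runs at call time."""
    namespace = {}
    lines = [f"def compiled({', '.join(inputs)}):"]
    for i, (name, fn, arg_names) in enumerate(steps):
        namespace[f"_fn{i}"] = fn  # make the callable visible to the generated code
        lines.append(f"    {name} = _fn{i}({', '.join(arg_names)})")
    lines.append(f"    return {output}")
    source = "\n".join(lines)
    exec(source, namespace)
    return namespace["compiled"], source


steps = [
    ("t0", operator.mul, ("x", "y")),
    ("t1", operator.add, ("t0", "t0")),  # the shared intermediate t0 is reused
]
fn, source = compile_graph(steps, inputs=("x", "y"), output="t1")
print(source)
print(fn(2, 3))  # 12
```

The generated function is straight-line code, which is why the per-call overhead of such an approach can be very small.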

TODO:

  • CompilePython probably needs to dispatch based on the input array shapes (overhead is very minimal)

@jcmgray jcmgray added the enhancement New feature or request label May 8, 2021
@pep8speaks

pep8speaks commented May 8, 2021

Hello @jcmgray! Thanks for updating this PR. We checked the lines you've touched for PEP 8 issues, and found:

Line 397:62: E231 missing whitespace after ','
Line 412:62: E231 missing whitespace after ','

Comment last updated at 2021-05-13 20:46:28 UTC

@codecov

codecov bot commented May 8, 2021

Codecov Report

Merging #9 (aa73c36) into master (b62baa9) will decrease coverage by 1.96%.
The diff coverage is 94.58%.


@@            Coverage Diff             @@
##           master       #9      +/-   ##
==========================================
- Coverage   97.67%   95.71%   -1.97%     
==========================================
  Files           2        6       +4     
  Lines         560     1471     +911     
==========================================
+ Hits          547     1408     +861     
- Misses         13       63      +50     

Impacted Files             Coverage Δ
autoray/lazy/core.py       93.84% <93.84%> (ø)
autoray/compiler.py        94.57% <94.57%> (ø)
autoray/autoray.py         97.24% <94.66%> (-0.41%) ⬇️
autoray/__init__.py        100.00% <100.00%> (ø)
autoray/lazy/__init__.py   100.00% <100.00%> (ø)
autoray/lazy/linalg.py     100.00% <100.00%> (ø)

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@jcmgray jcmgray merged commit 24bf989 into master Jun 4, 2021
@jcmgray jcmgray deleted the lazy-compile branch June 4, 2021 20:20
@FHof

FHof commented May 19, 2022

When I added support for compilation in torchquad, I noticed some limitations of autojit and decided to use backend-specific functions instead.
https://github.com/esa/torchquad/blob/bbbb3782cda4ff56f0e8093102dadaf87517b2a0/torchquad/integration/monte_carlo.py#L111
https://github.com/esa/torchquad/blob/bbbb3782cda4ff56f0e8093102dadaf87517b2a0/torchquad/integration/newton_cotes.py#L89

Inconsistent random number generation

In comparison to PyTorch and TensorFlow, with JAX a compiled non-deterministic function needs the PRNGKey as input and output; otherwise JAX would always generate the same numbers. PyTorch and TensorFlow use their default global PRNG state but for JAX autoray maintains the global variable _JAX_RANDOM_KEY.

from autoray import numpy as anp
from autoray import autojit

for _ in range(2):
    for backend in ("torch", "tensorflow", "jax"):
        func = lambda x: x + anp.random.uniform(like=backend)
        func_compiled = autojit(func)
        one = anp.array([1.0], like=backend)
        random_numbers = [func_compiled(one) for _ in range(4)]
        print(f"Random numbers for {backend}: {random_numbers}")

This example code produces errors and the print output is as follows:

Random numbers for torch: [tensor([1.3615]), tensor([1.4931]), tensor([1.9064]), tensor([1.7361])]
Random numbers for tensorflow: [<tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.586278], dtype=float32)>, <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.9022352], dtype=float32)>, <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.8580912], dtype=float32)>, <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.0506557], dtype=float32)>]
Random numbers for jax: [DeviceArray([1.9223499], dtype=float32), DeviceArray([1.9223499], dtype=float32), DeviceArray([1.9223499], dtype=float32), DeviceArray([1.9223499], dtype=float32)]
Random numbers for torch: [tensor([1.1129]), tensor([1.4516]), tensor([1.8102]), tensor([1.0571])]
Random numbers for tensorflow: [<tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.3308468], dtype=float32)>, <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.8029245], dtype=float32)>, <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.8157065], dtype=float32)>, <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.2073383], dtype=float32)>]

With JAX, it first always produces the same numbers and then crashes because a global variable is used during tracing.

PyTorch redundant torch.jit.script call

I think torch.jit.script and torch.jit.trace are alternative ways to compile code: torch.jit.script compiles (a subset of) Python3 code, whereas torch.jit.trace generates and compiles a computational graph from an execution of the Python3 code. The documentation about mixing both shows situations where both are used together but it doesn't mention that directly passing a traced function to torch.jit.script does something useful.

autojit with PyTorch by default executes self.torch.jit.script(self._jit_fn) and I think calling torch.jit.script is redundant here.

PyTorch .grad attribute access warning

torch.jit.trace currently (PyTorch 1.9.1.post3) emits a spurious warning when it is passed a tensor that carries gradient information:

python3.9/site-packages/torch/jit/_trace.py:154: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more information.
  if a.grad is not None:
tensor([2., 3.], grad_fn=<AddBackward0>)

Example code:

import autoray, torch

func = lambda x: x
x = torch.tensor([1.0, 2.0])
x.requires_grad = True
print(autoray.autojit(func)(x + 1))
# torch.jit.trace(func, example_inputs=(x + 1,))

If I replace x + 1 with x, the warning doesn't appear because x is a leaf.
To avoid this warning in autoray, it is possible to make the example input a leaf tensor, e.g.:

z = x
if z.requires_grad:
    z = z.detach()
    z.requires_grad = True

TensorFlow deprecated experimental_compile

The documentation of tf.function currently (TensorFlow 2.7.0) shows that the experimental_compile argument should be replaced by jit_compile:

Warning: SOME ARGUMENTS ARE DEPRECATED: (experimental_compile). They will be removed in a future version.
Instructions for updating:
experimental_compile is deprecated, use jit_compile instead

autojit with TensorFlow always sets experimental_compile=False as a default argument, so I assume that it may not work with a future TensorFlow version. I think TensorFlow currently disables XLA (jit_compile) by default, so removing the experimental_compile default argument from CompileTensorFlow could be safe.
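
If existing callers still pass the old keyword, one option is to translate it before forwarding the options to tf.function. The helper below is a hypothetical sketch (the name `normalize_tf_function_kwargs` is made up here and is not part of autoray or TensorFlow); it only manipulates a dict and does not import TensorFlow.

```python
def normalize_tf_function_kwargs(kwargs):
    """Return a copy of ``kwargs`` with the deprecated ``experimental_compile``
    key renamed to ``jit_compile``. An explicit ``jit_compile`` takes priority."""
    kwargs = dict(kwargs)  # never mutate the caller's dict
    if "experimental_compile" in kwargs:
        legacy = kwargs.pop("experimental_compile")
        kwargs.setdefault("jit_compile", legacy)
    return kwargs


print(normalize_tf_function_kwargs({"experimental_compile": False}))
# {'jit_compile': False}
```

The returned dict could then be passed on as `tf.function(fn, **kwargs)`, assuming a TensorFlow version that accepts `jit_compile`.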

static_argnames support

In my opinion, the static_argnames argument of JAX helps to reduce the number of lines of code, because the functions I compile often take both tensors and configuration options as arguments. Without this argument (e.g. with PyTorch), I need to explicitly wrap my function in a new one that takes only tensors as input.

It would be convenient to have an analogous argument in the autojit function.
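
A backend-agnostic way to get this behaviour is to bake the static options into the function and cache one compiled version per combination of option values. The following is a minimal sketch, not an autoray API: `with_static_argnames` and `compile_fn` are hypothetical names, and the identity function stands in for a real backend JIT.

```python
import functools


def with_static_argnames(compile_fn, fn, static_argnames):
    """Compile ``fn`` once per distinct combination of static keyword values.

    ``compile_fn`` is any callable that takes an arrays-only function and
    returns a compiled version of it (a stand-in for a backend JIT here).
    """
    cache = {}

    @functools.wraps(fn)
    def wrapper(*arrays, **kwargs):
        unexpected = set(kwargs) - set(static_argnames)
        if unexpected:
            raise TypeError(f"non-static keyword arguments: {sorted(unexpected)}")
        key = tuple(sorted(kwargs.items()))  # static values must be hashable
        if key not in cache:
            # Bake the static options in, leaving an arrays-only signature.
            cache[key] = compile_fn(functools.partial(fn, **kwargs))
        return cache[key](*arrays)

    wrapper.cache = cache
    return wrapper


def scaled_sum(x, y, scale=1.0, negate=False):
    out = scale * (x + y)
    return -out if negate else out


compiled = with_static_argnames(lambda f: f, scaled_sum, ("scale", "negate"))
print(compiled(1.0, 2.0, scale=2.0))  # 6.0, compiles for scale=2.0
print(compiled(3.0, 4.0, scale=2.0))  # 14.0, reuses the cached function
print(len(compiled.cache))            # 1
```

The trade-off is the usual one for static arguments: every new combination of option values triggers a fresh compilation, so this suits configuration flags rather than values that change on every call.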

@jcmgray
Owner Author

jcmgray commented May 20, 2022

Hi @FHof, thanks for the feedback on this admittedly quite experimental feature! I'm open to making all these changes.

Inconsistent random number generation

Yes, with global state it is a bit harder to offer a universal interface for this. Mostly, the stuff I implemented was just to make generating test cases easier. I did come across the jax problem, one solution to which is something like the following:

# autoray.py
import contextlib


@contextlib.contextmanager
def jittable_jax_random_seed(seed):
    # Temporarily reseed the module-level JAX key, then restore the old one.
    global _JAX_RANDOM_KEY
    old = _JAX_RANDOM_KEY
    try:
        jax_random_seed(seed)
        yield
    finally:
        _JAX_RANDOM_KEY = old

but I wonder if simply supporting only an RNG interface is the sustainable way to go (which is what it looks like you have in torchquad). Even then, it does seem one has to hide a jax key somewhere in order to get an impure function.

PyTorch redundant torch.jit.script call

Yes - I'm not totally familiar with these. Apparently script can handle some dynamic control behaviour but that is beyond what autojit should provide.

PyTorch .grad attribute access warning

Happy to change this. It's just a warning when supplying the example tensors it seems.

TensorFlow deprecated experimental_compile

Thanks for pointing this out - I don't use tensorflow much.

static_argnames support

This seems nice to have and falls under the general 'improve caching behavior'.


Would you be interested in a PR covering any of these things?

@FHof

FHof commented May 21, 2022

For the random number generation in torchquad with JAX, I have added the jax_get_key and jax_set_key methods to the RNG class and use them to pass the current PRNGKey values. A PRNGKey is passed to the compiled function, which uses it to generate random numbers and in the end the compiled function returns a new PRNGKey, so the compiled function is deterministic and pure. With TensorFlow on the other hand, the tf.random.Generator has to be initialised outside of the compiled function and it generates different random numbers in each invocation; I assume it saves the state in a tf.Variable.
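
The key-in, key-out pattern described above can be sketched directly in JAX as follows. This is an illustrative example, not the torchquad code: the function name `noisy_add` is made up, and it assumes a JAX installation.

```python
import jax
import jax.numpy as jnp


@jax.jit
def noisy_add(x, key):
    """Pure function: takes a PRNGKey and returns the result plus a fresh key."""
    key, subkey = jax.random.split(key)
    return x + jax.random.uniform(subkey, shape=x.shape), key


key = jax.random.PRNGKey(0)
x = jnp.ones(1)
samples = []
for _ in range(4):
    y, key = noisy_add(x, key)  # thread the returned key into the next call
    samples.append(float(y[0]))
print(samples)
```

Because the key is an explicit input and output, nothing is read from global state during tracing, and each call sees a different key and therefore draws a different number.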

I'm currently not interested in implementing the changes.
