Class methods vs. quaxified functions #33

Open
vadmbertr opened this issue Oct 15, 2024 · 2 comments
Labels
question User queries

Comments

@vadmbertr
Contributor

Hi,

I am really happy to see a multiple-dispatch mechanism brought to JAX, thanks for that!

I ran a small toy example to compare timings between quaxified functions and class methods, and I was quite surprised by the results.
From the snippet below, is there something I have missed? Maybe quax is aimed more at large, already-implemented models/functions?

from __future__ import annotations

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import ArrayLike
import quax


# using equinox first
class ArrayIsh(eqx.Module):
    array: ArrayLike

    @staticmethod
    def __get_other_array(other: ArrayIsh | ArrayLike) -> ArrayLike:
        if isinstance(other, ArrayIsh):
            return other.array
        else:
            return other

    def __sub__(self, other: ArrayIsh | ArrayLike):
        return self.array - self.__get_other_array(other)

    @classmethod
    def from_array(cls, array: ArrayLike):
        return cls(array)


array_ish1 = ArrayIsh.from_array(jnp.full((10, 2), 5.))
array_ish2 = ArrayIsh.from_array(jnp.full((10, 2), 3.))
array_like = jnp.full((10, 2), 2.)

# subtract array attributes directly
%timeit (array_ish1.array - array_ish2.array).block_until_ready()
3.6 μs ± 48.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# subtract ArrayIsh objects using __sub__ method
%timeit (array_ish1 - array_ish2).block_until_ready()
5.1 μs ± 246 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# subtract ArrayIsh object and ArrayLike using __sub__ method
%timeit (array_ish1 - array_like).block_until_ready()
5.26 μs ± 21.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# now using quax
class ArrayIsh(quax.ArrayValue):
    array: ArrayLike

    def aval(self):
        shape = jnp.shape(self.array)
        dtype = jnp.result_type(self.array)
        return jax.core.ShapedArray(shape, dtype)

    def materialise(self):
        raise ValueError("Refusing to materialise ArrayIsh array.")

    @classmethod
    def from_array(cls, array: ArrayLike):
        return cls(array)


@quax.register(jax.lax.sub_p)
def _(x: ArrayIsh, y: ArrayIsh):
    return x.array - y.array

@quax.register(jax.lax.sub_p)
def _(x: ArrayIsh, y: ArrayLike):
    return x.array - y

@quax.quaxify
def sub(x, y):
    return x - y


array_ish1 = ArrayIsh.from_array(jnp.full((10, 2), 5.))
array_ish2 = ArrayIsh.from_array(jnp.full((10, 2), 3.))

# subtract array attributes directly
%timeit (array_ish1.array - array_ish2.array).block_until_ready()
3.64 μs ± 44.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# subtract ArrayIsh objects using quaxify registered function sub
%timeit sub(array_ish1, array_ish2).block_until_ready()
440 μs ± 4.89 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# subtract ArrayIsh object and ArrayLike using quaxify registered function sub
%timeit sub(array_ish1, array_like).block_until_ready()
425 μs ± 3.77 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Again, thanks for this library and the feedback.

Vadim

@patrick-kidger
Owner

Make sure to wrap your programs in jax.jit or equinox.filter_jit. This is an important thing to do for all JAX programs if you want good performance :)

(Once this has happened, Quax itself will have zero overhead, as it only affects the compilation process -- the result will be a computation graph like any other jit-compiled JAX computation graph.)

@patrick-kidger added the "question" (User queries) label on Oct 15, 2024
@vadmbertr
Contributor Author

Yes, of course... I did not because the operation is very cheap.
But indeed, the jit-compiled versions have very similar run times.
