Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
Co-authored-by: Adam Paszke <[email protected]>
  • Loading branch information
mattjj and apaszke committed Mar 4, 2022
1 parent bdc28b9 commit 627a67a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/dex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _from_ptr(cls, ptr):
return self

def __del__(self):
if api.nofree: return
if api is None or api.nofree: return
api.destroyContext(self)

def __getattr__(self, name):
Expand Down
12 changes: 11 additions & 1 deletion python/dex/interop/jax2dex.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ def pprint(self) -> str:
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
safe_zip, cache)

from .. import eval

map = safe_map
zip = safe_zip

Expand Down Expand Up @@ -174,6 +176,7 @@ def dex_fun(*args, **kwargs):
return tree_unflatten(out_tree, out_flat)
return dex_fun

# TODO try to delete this, rely on existing jax functions instead
@cache()
def make_jaxpr(fun: Callable, in_tree: PyTreeDef,
in_avals: tuple[core.AbstractValue], # with DBIdx in them
Expand Down Expand Up @@ -238,7 +241,10 @@ def write(v: core.Var, val: str) -> None:
block = Block([], expr)
print(jaxpr, end='\n\n')
print(pprint(expr), end='\n\n')
return dex.eval(pprint(expr)).compile()
return eval(pprint(expr)).compile()

def pprint(e):
return e.pprint()

ExprMaker = Callable[[Any, ...], Expr]
expr_makers: Dict[core.Primitive, ExprMaker] = {}
Expand Down Expand Up @@ -294,3 +300,7 @@ def aval_to_type(aval: core.AbstractValue) -> Type:
else:
raise NotImplementedError(aval)


###


11 changes: 11 additions & 0 deletions python/tests/jax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import dex
from dex.interop.jax import primitive
from dex.interop.jax2dex import dexjit


@unittest.skip
Expand Down Expand Up @@ -206,6 +207,16 @@ def grad_jax(x, y):
jax.jit(grad_dex)(x, y),
jax.jit(grad_jax)(x, y))

class JAX2DexTest(unittest.TestCase):

def test_basic(self):
@dexjit
def f(x, y):
assert x.ndim == y.ndim == 0
return x + y

f(1., 2.)


if __name__ == "__main__":
unittest.main()

0 comments on commit 627a67a

Please sign in to comment.