Skip to content

Commit

Permalink
fix eig test, make lazy functions autowrap args as LazyArray
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed May 13, 2021
1 parent 10ab57e commit aa73c36
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 12 deletions.
36 changes: 29 additions & 7 deletions autoray/lazy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,12 @@ def __repr__(self):
)


def ensure_lazy(array):
if not isinstance(array, LazyArray):
return LazyArray.from_data(array)
return array


def find_lazy(x):
"""Recursively search for ``LazyArray`` instances in pytrees.
"""
Expand Down Expand Up @@ -850,10 +856,12 @@ def array(x):

@lazy_cache("transpose")
def transpose(a, axes=None):
a = ensure_lazy(a)
fn_transpose = get_lib_fn(a.backend, "transpose")

if axes is None:
axes = range(a.ndim)[::-1]
newshape = tuple(a.shape[i] for i in axes)
fn_transpose = get_lib_fn(a.backend, "transpose")

# check for chaining transpositions
if a._fn is fn_transpose:
Expand All @@ -868,6 +876,7 @@ def transpose(a, axes=None):

@lazy_cache("reshape")
def _reshape_tuple(a, newshape):
a = ensure_lazy(a)
fn_reshape = get_lib_fn(a.backend, "reshape")

# check for redundant reshapes
Expand Down Expand Up @@ -910,6 +919,7 @@ def getitem_hasher(_, a, key):

@lazy_cache("getitem", hasher=getitem_hasher)
def getitem(a, key):
a = ensure_lazy(a)

deps = (a,)

Expand Down Expand Up @@ -1001,6 +1011,7 @@ def einsum(*operands):

@lazy_cache("trace")
def trace(a):
a = ensure_lazy(a)
return a.to(fn=get_lib_fn(a.backend, "trace"), args=(a,), shape=(),)


Expand All @@ -1022,23 +1033,27 @@ def matmul(x1, x2):

@lazy_cache("clip")
def clip(a, a_min, a_max):
a = ensure_lazy(a)
fn_clip = get_lib_fn(a.backend, "clip")
return a.to(fn_clip, (a, a_min, a_max))


@lazy_cache("flip")
def flip(a, axis=None):
a = ensure_lazy(a)
fn_flip = get_lib_fn(a.backend, "flip")
return a.to(fn_flip, (a, axis))


@lazy_cache("sort")
def sort(a, axis=-1):
a = ensure_lazy(a)
return a.to(get_lib_fn(a.backend, "sort"), (a, axis))


@lazy_cache("argsort")
def argsort(a, axis=-1):
a = ensure_lazy(a)
return a.to(
fn=get_lib_fn(a.backend, "argsort"), args=(a, axis), dtype="int",
)
Expand Down Expand Up @@ -1093,14 +1108,18 @@ def binary_func(x1, x2):


def make_unary_func(name, to_real=False):
@lazy_cache(name)
def unary_func(x):

if to_real:
newdtype = dtype_real_equiv(x.dtype)
else:
newdtype = None
if to_real:
def get_newdtype(x):
return dtype_real_equiv(x.dtype)
else:
def get_newdtype(x):
return None

@lazy_cache(name)
def unary_func(x):
x = ensure_lazy(x)
newdtype = get_newdtype(x)
return x.to(fn=get_lib_fn(x.backend, name), args=(x,), dtype=newdtype,)

return unary_func
Expand Down Expand Up @@ -1131,8 +1150,10 @@ def unary_func(x):


def make_reduction_func(name):

@lazy_cache(name)
def reduction_func(a, axis=None):
a = ensure_lazy(a)
nd = a.ndim
if axis is None:
axis = tuple(range(nd))
Expand Down Expand Up @@ -1170,6 +1191,7 @@ def lazy_get_dtype_name(x):

@lazy_cache("astype")
def lazy_astype(x, dtype_name):
x = ensure_lazy(x)
return x.to(fn=astype, args=(x, dtype_name), dtype=dtype_name,)


Expand Down
10 changes: 10 additions & 0 deletions autoray/lazy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ..autoray import get_lib_fn

from .core import (
ensure_lazy,
lazy_cache,
dtype_real_equiv,
dtype_complex_equiv,
Expand All @@ -16,6 +17,7 @@

@lazy_cache("linalg.svd")
def svd(a):
a = ensure_lazy(a)
fn_svd = get_lib_fn(a.backend, "linalg.svd")
lsvd = a.to(fn_svd, (a,), shape=(3,))
m, n = a.shape
Expand All @@ -29,6 +31,7 @@ def svd(a):

@lazy_cache("linalg.qr")
def qr(a):
a = ensure_lazy(a)
lQR = a.to(get_lib_fn(a.backend, "linalg.qr"), (a,), shape=(2,))
m, n = a.shape
k = min(m, n)
Expand All @@ -39,6 +42,7 @@ def qr(a):

@lazy_cache("linalg.eig")
def eig(a):
a = ensure_lazy(a)
fn_eig = get_lib_fn(a.backend, "linalg.eig")
leig = a.to(fn_eig, (a,), shape=(2,))
m = a.shape[0]
Expand All @@ -50,6 +54,7 @@ def eig(a):

@lazy_cache("linalg.eigh")
def eigh(a):
a = ensure_lazy(a)
fn_eigh = get_lib_fn(a.backend, "linalg.eigh")
leigh = a.to(fn_eigh, (a,), shape=(2,))
m = a.shape[0]
Expand All @@ -61,18 +66,22 @@ def eigh(a):

@lazy_cache("linalg.inv")
def inv(a):
a = ensure_lazy(a)
fn_inv = get_lib_fn(a.backend, "linalg.inv")
return a.to(fn_inv, (a,))


@lazy_cache("linalg.cholesky")
def cholesky(a):
a = ensure_lazy(a)
fn_inv = get_lib_fn(a.backend, "linalg.cholesky")
return a.to(fn_inv, (a,))


@lazy_cache("linalg.solve")
def solve(a, b):
a = ensure_lazy(a)
b = ensure_lazy(b)
backend = find_common_backend(a, b)
fn_solve = get_lib_fn(backend, "linalg.solve")
dtype = find_common_dtype(a, b)
Expand All @@ -83,6 +92,7 @@ def solve(a, b):

@lazy_cache("linalg.norm")
def norm(x, order=None):
x = ensure_lazy(x)
fn_inv = get_lib_fn(x.backend, "linalg.norm")
newshape = ()
newdtype = dtype_real_equiv(x.dtype)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,12 @@ def test_eig_inv(backend, dtype):
pytest.xfail(f"{backend} doesn't support 'linalg.eig' yet...")

# N.B. the prob that a real gaussian matrix has all real eigenvalues is
# ``2**(-d * (d - 1) / 4)`` - see Edelman 1997, need d >> 5

x = lazy.array(gen_rand((20, 20), backend, dtype))
# ``2**(-d * (d - 1) / 4)`` - see Edelman 1997 - so need ``d >> 5``
d = 20
x = lazy.array(gen_rand((d, d), backend, dtype))
el, ev = do("linalg.eig", x)
assert el.shape == (5,)
assert ev.shape == (5, 5)
assert el.shape == (d,)
assert ev.shape == (d, d)
ly = ev @ (do("reshape", el, (-1, 1)) * do("linalg.inv", ev))
make_strict(ly)
assert_allclose(
Expand Down

0 comments on commit aa73c36

Please sign in to comment.