Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clean up jax internals in preparation for only allowing hashable values passed as static args #3712

Merged
merged 3 commits into from
Oct 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/notebooks/vmapped_log_probs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@
" beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon\n",
" return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))\n",
" \n",
"elbo = jax.jit(elbo, static_argnums=(1, 2))\n",
"elbo = jax.jit(elbo)\n",
"elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))"
],
"execution_count": 0,
Expand Down Expand Up @@ -653,4 +653,4 @@
"outputs": []
}
]
}
}
37 changes: 33 additions & 4 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import operator
import os
import types
from typing import Sequence, Set, Tuple, Union
from typing import Sequence, FrozenSet, Tuple, Union
from textwrap import dedent as _dedent
import warnings

Expand Down Expand Up @@ -2277,7 +2277,34 @@ def _pad_edge(array, pad_width):
def _pad(array, pad_width, mode, constant_values):
array = asarray(array)
nd = ndim(array)
pad_width = np.broadcast_to(np.asarray(pad_width), (nd, 2))
mattjj marked this conversation as resolved.
Show resolved Hide resolved

if nd == 0:
return array

pad_width_shape = np.shape(pad_width)
if pad_width_shape == (nd, 2):
# ((before_1, after_1), ..., (before_N, after_N))
pass
elif pad_width_shape == (1, 2):
# ((before, after),)
pad_width = pad_width * nd
elif pad_width_shape == (2,):
# (before, after) (not in the numpy docstring but works anyway)
before, after = pad_width
pad_width = (pad_width,) * nd
elif pad_width_shape == (1,):
# (pad,)
pad_width, = pad_width
pad_width = ((pad_width, pad_width),) * nd
elif pad_width_shape == ():
# pad
pad_width = ((pad_width, pad_width),) * nd
else:
raise ValueError(f"pad_width given unexpected structure: {pad_width}. "
"See docstring for valid pad_width formats.")
pad_width = np.array(pad_width)
assert pad_width.shape == (nd, 2), pad_width

if np.any(pad_width < 0):
raise ValueError("index can't contain negative values")

Expand Down Expand Up @@ -3291,7 +3318,7 @@ def einsum(*operands, optimize='greedy', precision=None):
# using einsum_call=True here is an internal api for opt_einsum
operands, contractions = opt_einsum.contract_path(
*operands, einsum_call=True, use_blas=True, optimize=optimize)
contractions = tuple(data[:3] for data in contractions)
contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
return _einsum(operands, contractions, precision)

@_wraps(np.einsum_path)
Expand All @@ -3304,7 +3331,7 @@ def _removechars(s, chars):

@partial(jit, static_argnums=(1, 2))
def _einsum(operands: Sequence,
contractions: Sequence[Tuple[Tuple[int, ...], Set[str], str]],
contractions: Sequence[Tuple[Tuple[int, ...], FrozenSet[str], str]],
precision):
operands = list(_promote_dtypes(*operands))
def sum(x, axes):
Expand Down Expand Up @@ -3649,6 +3676,8 @@ def _roll(a, shift, axis):

@_wraps(np.roll)
def roll(a, shift, axis=None):
if isinstance(axis, list):
axis = tuple(axis)
return _roll(a, shift, axis)


Expand Down
20 changes: 17 additions & 3 deletions tests/generated_fun_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,22 @@ def write(v, x):

def maybe_jit(f, num_args):
static_argnums = thin(range(num_args), 0.5)
return jit(f, static_argnums=static_argnums)

def fun(*args):
partial_args = list(args)
for i in static_argnums:
partial_args[i] = None

@jit
def jitted_fun(*partial_args):
full_args = list(partial_args)
for i in static_argnums:
full_args[i] = args[i]
return f(*full_args)

return jitted_fun(*partial_args)

return fun

counter = it.count()
def fresh_var(ty):
Expand Down Expand Up @@ -228,8 +243,7 @@ def testJitIsIdentity(self, fun):
vals = gen_vals(fun.in_vars)
fun = partial(eval_fun, fun)
ans = fun(*vals)
static_argnums = thin(range(len(vals)), 0.5)
ans_jitted = jit(fun, static_argnums=static_argnums)(*vals)
ans_jitted = maybe_jit(fun, len(vals))(*vals)
try:
check_all_close(ans, ans_jitted)
except:
Expand Down
82 changes: 49 additions & 33 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,43 +1190,58 @@ def testOperatorRound(self):
check_dtypes=False)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_mode={}_rpadwidth={}_rconstantvalues={}".format(
jtu.format_shape_dtype_string(shape, dtype), mode, pad_width_rank,
constant_values_rank),
{"testcase_name": "_shape={}_mode={}_padwidth={}_constantvalues={}".format(
jtu.format_shape_dtype_string(shape, dtype), mode, pad_width,
constant_values),
"shape": shape, "dtype": dtype, "mode": mode,
"pad_width_rank": pad_width_rank,
"constant_values_rank": constant_values_rank,
"rng_factory": jtu.rand_default,
"irng_factory": partial(jtu.rand_int, high=3)}
for mode, constant_values_rank, shapes in [
('constant', 0, all_shapes),
('constant', 1, all_shapes),
('constant', 2, all_shapes),
('symmetric', None, nonempty_shapes),
('reflect', None, nonempty_shapes),
('wrap', None, nonempty_shapes),
('edge', None, nonempty_shapes),
"pad_width": pad_width, "constant_values": constant_values,
"rng_factory": jtu.rand_default}
for mode, shapes in [
('constant', all_shapes),
('symmetric', nonempty_shapes),
('reflect', nonempty_shapes),
('wrap', nonempty_shapes),
('edge', nonempty_shapes),
]
for shape, dtype in _shape_and_dtypes(shapes, all_dtypes)
for pad_width_rank in range(3)))
def testPad(self, shape, dtype, mode, pad_width_rank, constant_values_rank,
rng_factory, irng_factory):
for constant_values in [
# None is used for modes other than 'constant'
None,
# constant
0, 1,
# (constant,)
(0,), (2.718,),
# ((before_const, after_const),)
((0, 2),), ((-1, 3.14),),
# ((before_1, after_1), ..., (before_N, after_N))
tuple((i / 2, -3.14 * i) for i in range(len(shape))),
]
for pad_width in [
# ((before_1, after_1), ..., (before_N, after_N))
tuple((i % 3, (i + 1) % 3) for i in range(len(shape))),
# ((before, after),)
((1, 2),), ((2, 0),),
# (before, after) (not in the docstring but works in numpy)
(2, 0), (0, 0),
# (pad,)
(1,), (2,),
# pad
0, 1,
]
if (pad_width != () and constant_values != () and
((mode == 'constant' and constant_values is not None) or
(mode != 'constant' and constant_values is None)))))
def testPad(self, shape, dtype, mode, pad_width, constant_values, rng_factory):
rng = rng_factory(self.rng())
irng = irng_factory(self.rng())
pad_width = irng([len(shape), 2][2 - pad_width_rank:], np.int32)
def np_fun(x, kwargs):
if pad_width.size == 0:
return x
return np.pad(x, pad_width, mode=mode, **kwargs)
def jnp_fun(x, kwargs):
return jnp.pad(x, pad_width, mode=mode, **kwargs)

def args_maker():
kwargs = {}
if constant_values_rank:
kwargs["constant_values"] = rng(
[len(shape), 2][2 - constant_values_rank:], dtype)
return rng(shape, dtype), kwargs
args_maker = lambda: [rng(shape, dtype)]
if constant_values is None:
np_fun = partial(np.pad, pad_width=pad_width, mode=mode)
jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode)
else:
np_fun = partial(np.pad, pad_width=pad_width, mode=mode,
constant_values=constant_values)
jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode,
constant_values=constant_values)

self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE)
Expand Down Expand Up @@ -3730,6 +3745,7 @@ def testMathSpecialFloatValues(self, op, dtype):

def testIssue883(self):
# from https://github.com/google/jax/issues/883
raise SkipTest("we decided to disallow arrays as static args")

@partial(api.jit, static_argnums=(1,))
def f(x, v):
Expand Down
6 changes: 3 additions & 3 deletions tests/pmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1867,13 +1867,13 @@ def f(x):
def testPmapStaticArgnums(self):
@partial(pmap, axis_name='i', static_broadcasted_argnums=1)
def f(x, y):
return jnp.sin(x + y)
return jnp.sin(x + y())
shape = (xla_bridge.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
y = np.arange(4, dtype=np.float32)
y = lambda: 3.

ans = f(x, y)
expected = np.sin(x + y[None])
expected = np.sin(x + 3.)
self.assertAllClose(ans, expected, check_dtypes=False)


Expand Down
2 changes: 2 additions & 0 deletions tests/polynomial_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from functools import partial
import numpy as np
import unittest

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -131,6 +132,7 @@ def args_maker():
for nonzeros in [0, 3]))
@jtu.skip_on_devices("gpu")
def testRootsInvalid(self, zeros, nonzeros, dtype, rng_factory):
raise unittest.SkipTest("getting segfaults on MKL") # TODO(#3711)
rng = rng_factory(np.random.RandomState(0))

# The polynomial coefficients here start with zero and would have to
Expand Down