From e51163af32f229fd36b74589f923215daf4a0504 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 16 Oct 2020 13:11:56 -0700 Subject: [PATCH 1/3] only pass hashable values as static args --- docs/notebooks/vmapped_log_probs.ipynb | 4 +- jax/numpy/lax_numpy.py | 37 ++++++++++-- tests/generated_fun_test.py | 20 ++++++- tests/lax_numpy_test.py | 82 +++++++++++++++----------- tests/lax_scipy_test.py | 2 +- tests/pmap_test.py | 6 +- tests/polynomial_test.py | 2 + 7 files changed, 107 insertions(+), 46 deletions(-) diff --git a/docs/notebooks/vmapped_log_probs.ipynb b/docs/notebooks/vmapped_log_probs.ipynb index 01278132c262..5db621930f66 100644 --- a/docs/notebooks/vmapped_log_probs.ipynb +++ b/docs/notebooks/vmapped_log_probs.ipynb @@ -415,7 +415,7 @@ " beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon\n", " return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))\n", " \n", - "elbo = jax.jit(elbo, static_argnums=(1, 2))\n", + "elbo = jax.jit(elbo)\n", "elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))" ], "execution_count": 0, @@ -653,4 +653,4 @@ "outputs": [] } ] -} \ No newline at end of file +} diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 555d2313c7e1..9c233d9fcb89 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -29,7 +29,7 @@ import operator import os import types -from typing import Sequence, Set, Tuple, Union +from typing import Sequence, FrozenSet, Tuple, Union from textwrap import dedent as _dedent import warnings @@ -2277,7 +2277,34 @@ def _pad_edge(array, pad_width): def _pad(array, pad_width, mode, constant_values): array = asarray(array) nd = ndim(array) - pad_width = np.broadcast_to(np.asarray(pad_width), (nd, 2)) + + if nd == 0: + return array + + pad_width_shape = np.shape(pad_width) + if pad_width_shape == (nd, 2): + # ((before_1, after_1), ..., (before_N, after_N)) + pass + elif pad_width_shape == (1, 2): + # ((before, after),) + pad_width = pad_width * nd + elif pad_width_shape == (2,): + # (before, after) (not in the numpy docstring but works anyway) + before, after = pad_width + pad_width = (pad_width,) * nd + elif pad_width_shape == (1,): + # (pad,) + pad_width, = pad_width + pad_width = ((pad_width, pad_width),) * nd + elif pad_width_shape == (): + # pad + pad_width = ((pad_width, pad_width),) * nd + else: + raise ValueError(f"pad_width given unexpected structure: {pad_width}. 
" + "See docstring for valid pad_width formats.") + pad_width = np.array(pad_width) + assert pad_width.shape == (nd, 2), pad_width + if np.any(pad_width < 0): raise ValueError("index can't contain negative values") @@ -3291,7 +3318,7 @@ def einsum(*operands, optimize='greedy', precision=None): # using einsum_call=True here is an internal api for opt_einsum operands, contractions = opt_einsum.contract_path( *operands, einsum_call=True, use_blas=True, optimize=optimize) - contractions = tuple(data[:3] for data in contractions) + contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions) return _einsum(operands, contractions, precision) @_wraps(np.einsum_path) @@ -3304,7 +3331,7 @@ def _removechars(s, chars): @partial(jit, static_argnums=(1, 2)) def _einsum(operands: Sequence, - contractions: Sequence[Tuple[Tuple[int, ...], Set[str], str]], + contractions: Sequence[Tuple[Tuple[int, ...], FrozenSet[str], str]], precision): operands = list(_promote_dtypes(*operands)) def sum(x, axes): @@ -3649,6 +3676,8 @@ def _roll(a, shift, axis): @_wraps(np.roll) def roll(a, shift, axis=None): + if isinstance(axis, list): + axis = tuple(axis) return _roll(a, shift, axis) diff --git a/tests/generated_fun_test.py b/tests/generated_fun_test.py index bbd9de37558b..8b0ca20f988c 100644 --- a/tests/generated_fun_test.py +++ b/tests/generated_fun_test.py @@ -96,7 +96,22 @@ def write(v, x): def maybe_jit(f, num_args): static_argnums = thin(range(num_args), 0.5) - return jit(f, static_argnums=static_argnums) + + def fun(*args): + partial_args = list(args) + for i in static_argnums: + partial_args[i] = None + + @jit + def jitted_fun(*partial_args): + full_args = list(partial_args) + for i in static_argnums: + full_args[i] = args[i] + return f(*full_args) + + return jitted_fun(*partial_args) + + return fun counter = it.count() def fresh_var(ty): @@ -228,8 +243,7 @@ def testJitIsIdentity(self, fun): vals = gen_vals(fun.in_vars) fun = partial(eval_fun, fun) ans = fun(*vals) - static_argnums = thin(range(len(vals)), 0.5) - ans_jitted = jit(fun, static_argnums=static_argnums)(*vals) + ans_jitted = maybe_jit(fun, len(vals))(*vals) try: check_all_close(ans, ans_jitted) except: diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index ad69b0c499b0..0ecd3255ef0b 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1190,43 +1190,58 @@ def testOperatorRound(self): check_dtypes=False) @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_mode={}_rpadwidth={}_rconstantvalues={}".format( - jtu.format_shape_dtype_string(shape, dtype), mode, pad_width_rank, - constant_values_rank), + {"testcase_name": "_shape={}_mode={}_padwidth={}_constantvalues={}".format( + jtu.format_shape_dtype_string(shape, dtype), mode, pad_width, + constant_values), "shape": shape, "dtype": dtype, "mode": mode, - "pad_width_rank": pad_width_rank, - "constant_values_rank": constant_values_rank, - "rng_factory": jtu.rand_default, - "irng_factory": partial(jtu.rand_int, high=3)} - for mode, constant_values_rank, shapes in [ - ('constant', 0, all_shapes), - ('constant', 1, all_shapes), - ('constant', 2, all_shapes), - ('symmetric', None, nonempty_shapes), - ('reflect', None, nonempty_shapes), - ('wrap', None, nonempty_shapes), - ('edge', None, nonempty_shapes), + "pad_width": pad_width, "constant_values": constant_values, + "rng_factory": jtu.rand_default} + for mode, shapes in [ + ('constant', all_shapes), + ('symmetric', nonempty_shapes), + ('reflect', nonempty_shapes), + ('wrap', 
nonempty_shapes), + ('edge', nonempty_shapes), ] for shape, dtype in _shape_and_dtypes(shapes, all_dtypes) - for pad_width_rank in range(3))) - def testPad(self, shape, dtype, mode, pad_width_rank, constant_values_rank, - rng_factory, irng_factory): + for constant_values in [ + # None is used for modes other than 'constant' + None, + # constant + 0, 1, + # (constant,) + (0,), (2.718,), + # ((before_const, after_const),) + ((0, 2),), ((-1, 3.14),), + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i / 2, -3.14 * i) for i in range(len(shape))), + ] + for pad_width in [ + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2, 0),), + # (before, after) (not in the docstring but works in numpy) + (2, 0), (0, 0), + # (pad,) + (1,), (2,), + # pad + 0, 1, + ] + if (pad_width != () and constant_values != () and + ((mode == 'constant' and constant_values is not None) or + (mode != 'constant' and constant_values is None))))) + def testPad(self, shape, dtype, mode, pad_width, constant_values, rng_factory): rng = rng_factory(self.rng()) - irng = irng_factory(self.rng()) - pad_width = irng([len(shape), 2][2 - pad_width_rank:], np.int32) - def np_fun(x, kwargs): - if pad_width.size == 0: - return x - return np.pad(x, pad_width, mode=mode, **kwargs) - def jnp_fun(x, kwargs): - return jnp.pad(x, pad_width, mode=mode, **kwargs) - - def args_maker(): - kwargs = {} - if constant_values_rank: - kwargs["constant_values"] = rng( - [len(shape), 2][2 - constant_values_rank:], dtype) - return rng(shape, dtype), kwargs + args_maker = lambda: [rng(shape, dtype)] + if constant_values is None: + np_fun = partial(np.pad, pad_width=pad_width, mode=mode) + jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode) + else: + np_fun = partial(np.pad, pad_width=pad_width, mode=mode, + constant_values=constant_values) + jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode, + constant_values=constant_values) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) @@ -3730,6 +3745,7 @@ def testMathSpecialFloatValues(self, op, dtype): def testIssue883(self): # from https://github.com/google/jax/issues/883 + raise SkipTest("we decided to disallow arrays as static args") @partial(api.jit, static_argnums=(1,)) def f(x, v): diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 029a21e14670..f2a0aee98d96 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -136,7 +136,7 @@ def lax_fun(array_to_reduce): return_sign=return_sign) args_maker = lambda: [rng(shapes[0], dtype)] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker) + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-5) self._CompileAndCheck(lax_fun, args_maker) @parameterized.named_parameters(itertools.chain.from_iterable( diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 0f53e07a7672..078b73fa925b 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -1867,13 +1867,13 @@ def f(x): def testPmapStaticArgnums(self): @partial(pmap, axis_name='i', static_broadcasted_argnums=1) def f(x, y): - return jnp.sin(x + y) + return jnp.sin(x + y()) shape = (xla_bridge.device_count(), 4) x = np.arange(prod(shape), dtype=np.float32).reshape(shape) - y = np.arange(4, dtype=np.float32) + y = lambda: 3. ans = f(x, y) - expected = np.sin(x + y[None]) + expected = np.sin(x + 3.) 
self.assertAllClose(ans, expected, check_dtypes=False) diff --git a/tests/polynomial_test.py b/tests/polynomial_test.py index 812347fe208b..cc6abcfb2c0c 100644 --- a/tests/polynomial_test.py +++ b/tests/polynomial_test.py @@ -14,6 +14,7 @@ from functools import partial import numpy as np +import unittest from absl.testing import absltest from absl.testing import parameterized @@ -131,6 +132,7 @@ def args_maker(): for nonzeros in [0, 3])) @jtu.skip_on_devices("gpu") def testRootsInvalid(self, zeros, nonzeros, dtype, rng_factory): + raise unittest.SkipTest("getting segfaults on MKL") rng = rng_factory(np.random.RandomState(0)) # The polynomial coefficients here start with zero and would have to From 552d4bdc12e15c90273311af248083f3f7883c0c Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 16 Oct 2020 13:48:38 -0700 Subject: [PATCH 2/3] add todo --- tests/polynomial_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/polynomial_test.py b/tests/polynomial_test.py index cc6abcfb2c0c..fde0e51b306e 100644 --- a/tests/polynomial_test.py +++ b/tests/polynomial_test.py @@ -132,7 +132,7 @@ def args_maker(): for nonzeros in [0, 3])) @jtu.skip_on_devices("gpu") def testRootsInvalid(self, zeros, nonzeros, dtype, rng_factory): - raise unittest.SkipTest("getting segfaults on MKL") + raise unittest.SkipTest("getting segfaults on MKL") # TODO(#3711) rng = rng_factory(np.random.RandomState(0)) # The polynomial coefficients here start with zero and would have to From c8a1a8ded4dd1fc49a3445ad108500bb4346f38b Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 16 Oct 2020 15:07:02 -0700 Subject: [PATCH 3/3] remove extraneous test tol change --- tests/lax_scipy_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index f2a0aee98d96..029a21e14670 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -136,7 +136,7 @@ def lax_fun(array_to_reduce): return_sign=return_sign) args_maker = lambda: [rng(shapes[0], dtype)] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-5) + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker) @parameterized.named_parameters(itertools.chain.from_iterable(
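
For context, here is a minimal sketch of the usage pattern this series enforces: values passed via static_argnums must be hashable, so arrays can no longer be marked static and are instead closed over (the same workaround maybe_jit in generated_fun_test.py uses above). This sketch assumes a JAX version that includes this change; the function and variable names below are illustrative only and are not part of the patch.

    import numpy as np
    from functools import partial

    import jax
    import jax.numpy as jnp

    shift_table = np.arange(4.0)        # a non-hashable ndarray

    # Marking the array argument static would now be rejected, e.g.:
    #   f = jax.jit(lambda x, t: x + t[0], static_argnums=(1,))
    #   f(jnp.ones(3), shift_table)     # error: non-hashable static argument

    # Workaround: close over the array instead of marking it static.
    @jax.jit
    def add_first_shift(x):
      return x + shift_table[0]

    print(add_first_shift(jnp.ones(3)))   # [1. 1. 1.]

    # Hashable Python values (ints, strings, tuples) remain fine as static args.
    @partial(jax.jit, static_argnums=(1,))
    def roll_by(x, amount):
      return jnp.roll(x, amount)

    print(roll_by(jnp.arange(4), 1))      # [3 0 1 2]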