diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 87b390db92..b089172d31 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -12,7 +12,6 @@ from desc.utils import ( PRINT_WIDTH, Timer, - batched_vectorize, errorif, flatten_list, is_broadcastable, @@ -20,6 +19,7 @@ unique_list, warnif, ) +from desc.utils_batched_vectorize import batched_vectorize class ObjectiveFunction(IOAble): diff --git a/desc/utils.py b/desc/utils.py index df6ad3488d..d1483aa039 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -2,27 +2,13 @@ import operator import warnings -from functools import partial from itertools import combinations_with_replacement, permutations -from typing import Callable, Optional import numpy as np from scipy.special import factorial from termcolor import colored -from desc.backend import flatnonzero, fori_loop, functools, jax, jit, jnp, take - -if jax.__version_info__ >= (0, 4, 16): - from jax.extend import linear_util as lu -else: - from jax import linear_util as lu - -from jax._src.numpy.vectorize import ( - _apply_excluded, - _check_output_dims, - _parse_gufunc_signature, - _parse_input_dimensions, -) +from desc.backend import flatnonzero, fori_loop, functools, jit, jnp, take class Timer: @@ -705,348 +691,9 @@ def broadcast_tree(tree_in, tree_out, dtype=int): raise ValueError("trees must be nested lists of dicts") -# The following section of this code is derived from the NetKet project -# https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/ -# netket/jax/_chunk_utils.py -# -# The original copyright notice is as follows -# Copyright 2021 The NetKet Authors - All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); - - -def _treeify(f): - def _f(x, *args, **kwargs): - return jax.tree_util.tree_map(lambda y: f(y, *args, **kwargs), x) - - return _f - - -@_treeify -def _unchunk(x): - return x.reshape((-1,) + x.shape[2:]) - - -@_treeify -def _chunk(x, jac_chunk_size=None): - # jac_chunk_size=None -> add just a dummy chunk dimension, - # same as np.expand_dims(x, 0) - if x.ndim == 0: - raise ValueError("x cannot be chunked as it has 0 dimensions.") - n = x.shape[0] - if jac_chunk_size is None: - jac_chunk_size = n - - n_chunks, residual = divmod(n, jac_chunk_size) - if residual != 0: - raise ValueError( - "The first dimension of x must be divisible by jac_chunk_size." - + f"\n Got x.shape={x.shape} but jac_chunk_size={jac_chunk_size}." - ) - return x.reshape((n_chunks, jac_chunk_size) + x.shape[1:]) - - -def _jac_chunk_size(x): - b = set(map(lambda x: x.shape[:2], jax.tree_util.tree_leaves(x))) - if len(b) != 1: - raise ValueError( - "The arrays in x have inconsistent jac_chunk_size or number of chunks" - ) - return b.pop()[1] - - -def unchunk(x_chunked): - """Merge the first two axes of an array (or a pytree of arrays). - - Parameters - ---------- - x_chunked: an array (or pytree of arrays) of at least 2 dimensions - - Returns - ------- - (x, chunk_fn) : tuple - where x is x_chunked reshaped to (-1,)+x.shape[2:] - and chunk_fn is a function which restores x given x_chunked - - """ - return _unchunk(x_chunked), functools.partial( - _chunk, jac_chunk_size=_jac_chunk_size(x_chunked) - ) - - -def chunk(x, jac_chunk_size=None): - """Split an array (or a pytree of arrays) into chunks along the first axis. - - Parameters - ---------- - x: an array (or pytree of arrays) - jac_chunk_size: an integer or None (default) - The first axis in x must be a multiple of jac_chunk_size - - Returns - ------- - (x_chunked, unchunk_fn): tuple - - x_chunked is x reshaped to (-1, jac_chunk_size)+x.shape[1:] - if jac_chunk_size is None then it defaults to x.shape[0], i.e. just one chunk - - unchunk_fn is a function which restores x given x_chunked - """ - return _chunk(x, jac_chunk_size), _unchunk - - -#### - -# The following section of this code is derived from the NetKet project -# https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/ -# netket/jax/_scanmap.py -# -# The original copyright notice is as follows -# Copyright 2021 The NetKet Authors - All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); - - -# TODO in_axes a la vmap? -def _scanmap(fun, scan_fun, argnums=0): - """A helper function to wrap f with a scan_fun.""" - - def f_(*args, **kwargs): - f = lu.wrap_init(fun, kwargs) - f_partial, dyn_args = jax.api_util.argnums_partial( - f, argnums, args, require_static_args_hashable=False - ) - return scan_fun(lambda x: f_partial.call_wrapped(*x), dyn_args) - - return f_ - - -# The following section of this code is derived from the NetKet project -# https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/ -# netket/jax/_vmap_chunked.py -# -# The original copyright notice is as follows -# Copyright 2021 The NetKet Authors - All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); - - -def _eval_fun_in_chunks(vmapped_fun, jac_chunk_size, argnums, *args, **kwargs): - n_elements = jax.tree_util.tree_leaves(args[argnums[0]])[0].shape[0] - n_chunks, n_rest = divmod(n_elements, jac_chunk_size) - - if n_chunks == 0 or jac_chunk_size >= n_elements: - y = vmapped_fun(*args, **kwargs) - else: - # split inputs - def _get_chunks(x): - x_chunks = jax.tree_util.tree_map( - lambda x_: x_[: n_elements - n_rest, ...], x - ) - x_chunks = _chunk(x_chunks, jac_chunk_size) - return x_chunks - - def _get_rest(x): - x_rest = jax.tree_util.tree_map( - lambda x_: x_[n_elements - n_rest :, ...], x - ) - return x_rest - - args_chunks = [ - _get_chunks(a) if i in argnums else a for i, a in enumerate(args) - ] - args_rest = [_get_rest(a) if i in argnums else a for i, a in enumerate(args)] - - y_chunks = _unchunk( - _scanmap(vmapped_fun, jax.lax.scan, argnums)(*args_chunks, **kwargs) - ) - - if n_rest == 0: - y = y_chunks - else: - y_rest = vmapped_fun(*args_rest, **kwargs) - y = jax.tree_util.tree_map( - lambda y1, y2: jnp.concatenate((y1, y2)), y_chunks, y_rest - ) - return y - - -def _chunk_vmapped_function( - vmapped_fun: Callable, - jac_chunk_size: Optional[int], - argnums=0, -) -> Callable: - """Takes a vmapped function and computes it in chunks.""" - if jac_chunk_size is None: - return vmapped_fun - - if isinstance(argnums, int): - argnums = (argnums,) - return functools.partial(_eval_fun_in_chunks, vmapped_fun, jac_chunk_size, argnums) - - -def _parse_in_axes(in_axes): - if isinstance(in_axes, int): - in_axes = (in_axes,) - - if not set(in_axes).issubset((0, None)): - raise NotImplementedError("Only in_axes 0/None are currently supported") - - argnums = tuple( - map(lambda ix: ix[0], filter(lambda ix: ix[1] is not None, enumerate(in_axes))) - ) - return in_axes, argnums - - -def vmap_chunked( - f: Callable, - in_axes=0, - *, - jac_chunk_size: Optional[int], -) -> Callable: - """Behaves like jax.vmap but uses scan to chunk the computations in smaller chunks. - - Parameters - ---------- - f: The function to be vectorised. - in_axes: The axes that should be scanned along. Only supports `0` or `None` - jac_chunk_size: The maximum size of the chunks to be used. If it is `None`, - chunking is disabled - - - Returns - ------- - f: A vectorised and chunked function - """ - in_axes, argnums = _parse_in_axes(in_axes) - vmapped_fun = jax.vmap(f, in_axes=in_axes) - return _chunk_vmapped_function(vmapped_fun, jac_chunk_size, argnums) - - -def batched_vectorize( - pyfunc, *, excluded=frozenset(), signature=None, jac_chunk_size=None -): - """Define a vectorized function with broadcasting and batching. - - below is taken from JAX - FIXME: change restof docstring - :func:`vectorize` is a convenience wrapper for defining vectorized - functions with broadcasting, in the style of NumPy's - `generalized universal functions - `_. - It allows for defining functions that are automatically repeated across - any leading dimensions, without the implementation of the function needing to - be concerned about how to handle higher dimensional inputs. - - :func:`jax.numpy.vectorize` has the same interface as - :class:`numpy.vectorize`, but it is syntactic sugar for an auto-batching - transformation (:func:`vmap`) rather than a Python loop. This should be - considerably more efficient, but the implementation must be written in terms - of functions that act on JAX arrays. - - Parameters - ---------- - pyfunc: function to vectorize. - excluded: optional set of integers representing positional arguments for - which the function will not be vectorized. These will be passed directly - to ``pyfunc`` unmodified. - signature: optional generalized universal function signature, e.g., - ``(m,n),(n)->(m)`` for vectorized matrix-vector multiplication. If - provided, ``pyfunc`` will be called with (and expected to return) arrays - with shapes given by the size of corresponding core dimensions. By - default, pyfunc is assumed to take scalars arrays as input and output. - jac_chunk_size: the size of the batches to pass to vmap. if 1, will only - - Returns - ------- - Vectorized version of the given function. - - """ - if any(not isinstance(exclude, (str, int)) for exclude in excluded): - raise TypeError( - "jax.numpy.vectorize can only exclude integer or string arguments, " - "but excluded={!r}".format(excluded) - ) - if any(isinstance(e, int) and e < 0 for e in excluded): - raise ValueError(f"excluded={excluded!r} contains negative numbers") - - @functools.wraps(pyfunc) - def wrapped(*args, **kwargs): - error_context = ( - "on vectorized function with excluded={!r} and " - "signature={!r}".format(excluded, signature) - ) - excluded_func, args, kwargs = _apply_excluded(pyfunc, excluded, args, kwargs) - - if signature is not None: - input_core_dims, output_core_dims = _parse_gufunc_signature(signature) - else: - input_core_dims = [()] * len(args) - output_core_dims = None - - none_args = {i for i, arg in enumerate(args) if arg is None} - if any(none_args): - if any(input_core_dims[i] != () for i in none_args): - raise ValueError( - f"Cannot pass None at locations {none_args} with {signature=}" - ) - excluded_func, args, _ = _apply_excluded(excluded_func, none_args, args, {}) - input_core_dims = [ - dim for i, dim in enumerate(input_core_dims) if i not in none_args - ] - - args = tuple(map(jnp.asarray, args)) - - broadcast_shape, dim_sizes = _parse_input_dimensions( - args, input_core_dims, error_context - ) - - checked_func = _check_output_dims( - excluded_func, dim_sizes, output_core_dims, error_context - ) - - # Rather than broadcasting all arguments to full broadcast shapes, prefer - # expanding dimensions using vmap. By pushing broadcasting - # into vmap, we can make use of more efficient batching rules for - # primitives where only some arguments are batched (e.g., for - # lax_linalg.triangular_solve), and avoid instantiating large broadcasted - # arrays. - - squeezed_args = [] - rev_filled_shapes = [] - - for arg, core_dims in zip(args, input_core_dims): - noncore_shape = arg.shape[: arg.ndim - len(core_dims)] - - pad_ndim = len(broadcast_shape) - len(noncore_shape) - filled_shape = pad_ndim * (1,) + noncore_shape - rev_filled_shapes.append(filled_shape[::-1]) - - squeeze_indices = tuple( - i for i, size in enumerate(noncore_shape) if size == 1 - ) - squeezed_arg = jnp.squeeze(arg, axis=squeeze_indices) - squeezed_args.append(squeezed_arg) - - vectorized_func = checked_func - dims_to_expand = [] - for negdim, axis_sizes in enumerate(zip(*rev_filled_shapes)): - in_axes = tuple(None if size == 1 else 0 for size in axis_sizes) - if all(axis is None for axis in in_axes): - dims_to_expand.append(len(broadcast_shape) - 1 - negdim) - else: - # change the vmap here to chunked_vmap - vectorized_func = vmap_chunked( - vectorized_func, in_axes, jac_chunk_size=jac_chunk_size - ) - result = vectorized_func(*squeezed_args) - - if not dims_to_expand: - return result - elif isinstance(result, tuple): - return tuple(jnp.expand_dims(r, axis=dims_to_expand) for r in result) - else: - return jnp.expand_dims(result, axis=dims_to_expand) - - return wrapped - - -@partial(jnp.vectorize, signature="(m),(m)->(n)", excluded={"size", "fill_value"}) +@functools.partial( + jnp.vectorize, signature="(m),(m)->(n)", excluded={"size", "fill_value"} +) def take_mask(a, mask, /, *, size=None, fill_value=None): """JIT compilable method to return ``a[mask][:size]`` padded by ``fill_value``.