Commit da359d5
place batched_vectorize utils in own file
dpanici committed Sep 9, 2024
1 parent 51e3e40 commit da359d5
Showing 2 changed files with 5 additions and 358 deletions.
desc/objectives/objective_funs.py: 2 changes (1 addition & 1 deletion)

@@ -12,14 +12,14 @@
 from desc.utils import (
     PRINT_WIDTH,
     Timer,
-    batched_vectorize,
     errorif,
     flatten_list,
     is_broadcastable,
     setdefault,
     unique_list,
     warnif,
 )
+from desc.utils_batched_vectorize import batched_vectorize


 class ObjectiveFunction(IOAble):
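The only change to this file is the import path: batched_vectorize now lives in its own module, and call sites are otherwise untouched. A minimal usage sketch of the relocated helper (the toy function, shapes, and chunk size are illustrative, not from this commit):

import jax.numpy as jnp

from desc.utils_batched_vectorize import batched_vectorize

# Vectorize a per-row function over the leading axis, evaluating at most
# 100 rows per scan step to bound peak memory.
f = batched_vectorize(lambda x: jnp.sum(x**2), signature="(n)->()", jac_chunk_size=100)
y = f(jnp.ones((1000, 3)))  # shape (1000,)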
desc/utils.py: 361 changes (4 additions & 357 deletions)

@@ -2,27 +2,13 @@

 import operator
 import warnings
-from functools import partial
-from itertools import combinations_with_replacement, permutations
-from typing import Callable, Optional
 
 import numpy as np
 from scipy.special import factorial
 from termcolor import colored
 
-from desc.backend import flatnonzero, fori_loop, functools, jax, jit, jnp, take
-
-if jax.__version_info__ >= (0, 4, 16):
-    from jax.extend import linear_util as lu
-else:
-    from jax import linear_util as lu
-
-from jax._src.numpy.vectorize import (
-    _apply_excluded,
-    _check_output_dims,
-    _parse_gufunc_signature,
-    _parse_input_dimensions,
-)
+from desc.backend import flatnonzero, fori_loop, functools, jit, jnp, take


 class Timer:

@@ -705,348 +691,9 @@ def broadcast_tree(tree_in, tree_out, dtype=int):
         raise ValueError("trees must be nested lists of dicts")


# The following section of this code is derived from the NetKet project
# https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/
# netket/jax/_chunk_utils.py
#
# The original copyright notice is as follows
# Copyright 2021 The NetKet Authors - All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");


def _treeify(f):
def _f(x, *args, **kwargs):
return jax.tree_util.tree_map(lambda y: f(y, *args, **kwargs), x)

return _f


@_treeify
def _unchunk(x):
return x.reshape((-1,) + x.shape[2:])


@_treeify
def _chunk(x, jac_chunk_size=None):
# jac_chunk_size=None -> add just a dummy chunk dimension,
# same as np.expand_dims(x, 0)
if x.ndim == 0:
raise ValueError("x cannot be chunked as it has 0 dimensions.")
n = x.shape[0]
if jac_chunk_size is None:
jac_chunk_size = n

n_chunks, residual = divmod(n, jac_chunk_size)
if residual != 0:
raise ValueError(
"The first dimension of x must be divisible by jac_chunk_size."
+ f"\n Got x.shape={x.shape} but jac_chunk_size={jac_chunk_size}."
)
return x.reshape((n_chunks, jac_chunk_size) + x.shape[1:])
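_chunk is nothing more than a leading-axis reshape, and _unchunk is its inverse; the same transformation in plain NumPy, for illustration:

import numpy as np

x = np.arange(12).reshape(6, 2)                        # leading axis of length 6
chunked = x.reshape((3, 2) + x.shape[1:])              # jac_chunk_size=2 -> shape (3, 2, 2)
restored = chunked.reshape((-1,) + chunked.shape[2:])  # what _unchunk computes
assert np.array_equal(restored, x)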


-def _jac_chunk_size(x):
-    b = set(map(lambda x: x.shape[:2], jax.tree_util.tree_leaves(x)))
-    if len(b) != 1:
-        raise ValueError(
-            "The arrays in x have inconsistent jac_chunk_size or number of chunks"
-        )
-    return b.pop()[1]
-
-
-def unchunk(x_chunked):
-    """Merge the first two axes of an array (or a pytree of arrays).
-
-    Parameters
-    ----------
-    x_chunked: an array (or pytree of arrays) of at least 2 dimensions
-
-    Returns
-    -------
-    (x, chunk_fn) : tuple
-        where x is x_chunked reshaped to (-1,)+x.shape[2:]
-        and chunk_fn is a function which restores x given x_chunked
-    """
-    return _unchunk(x_chunked), functools.partial(
-        _chunk, jac_chunk_size=_jac_chunk_size(x_chunked)
-    )
-
-
-def chunk(x, jac_chunk_size=None):
-    """Split an array (or a pytree of arrays) into chunks along the first axis.
-
-    Parameters
-    ----------
-    x: an array (or pytree of arrays)
-    jac_chunk_size: an integer or None (default)
-        The first axis in x must be a multiple of jac_chunk_size
-
-    Returns
-    -------
-    (x_chunked, unchunk_fn): tuple
-        - x_chunked is x reshaped to (-1, jac_chunk_size)+x.shape[1:]
-          if jac_chunk_size is None then it defaults to x.shape[0], i.e. just one chunk
-        - unchunk_fn is a function which restores x given x_chunked
-    """
-    return _chunk(x, jac_chunk_size), _unchunk
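chunk and unchunk are inverses, and each returns the function that undoes it, so a round trip looks like this (a sketch, assuming a JAX array input):

import jax.numpy as jnp

x = jnp.arange(8.0).reshape(4, 2)
x_chunked, unchunk_fn = chunk(x, jac_chunk_size=2)  # x_chunked has shape (2, 2, 2)
assert unchunk_fn(x_chunked).shape == (4, 2)

x_flat, chunk_fn = unchunk(x_chunked)               # back to shape (4, 2)
assert chunk_fn(x_flat).shape == (2, 2, 2)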


-####
-
-# The following section of this code is derived from the NetKet project
-# https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/
-# netket/jax/_scanmap.py
-#
-# The original copyright notice is as follows
-# Copyright 2021 The NetKet Authors - All rights reserved.
-# Licensed under the Apache License, Version 2.0 (the "License");
-
-
-# TODO in_axes a la vmap?
-def _scanmap(fun, scan_fun, argnums=0):
-    """A helper function to wrap f with a scan_fun."""
-
-    def f_(*args, **kwargs):
-        f = lu.wrap_init(fun, kwargs)
-        f_partial, dyn_args = jax.api_util.argnums_partial(
-            f, argnums, args, require_static_args_hashable=False
-        )
-        return scan_fun(lambda x: f_partial.call_wrapped(*x), dyn_args)
-
-    return f_
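For a single mapped argument, _scanmap turns a per-element function into one that is evaluated sequentially over the leading axis instead of in parallel. A self-contained sketch of that idea, using jax.lax.map (itself built on lax.scan) in place of the scan_fun wiring above:

import jax
import jax.numpy as jnp

def per_row(x):
    return jnp.sum(x**2)

# Evaluate per_row on each row of a (4, 3) array, one row at a time,
# producing the same values a vmap would but with scan's memory profile.
out = jax.lax.map(per_row, jnp.arange(12.0).reshape(4, 3))  # shape (4,)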


-# The following section of this code is derived from the NetKet project
-# https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/
-# netket/jax/_vmap_chunked.py
-#
-# The original copyright notice is as follows
-# Copyright 2021 The NetKet Authors - All rights reserved.
-# Licensed under the Apache License, Version 2.0 (the "License");
-
-
-def _eval_fun_in_chunks(vmapped_fun, jac_chunk_size, argnums, *args, **kwargs):
-    n_elements = jax.tree_util.tree_leaves(args[argnums[0]])[0].shape[0]
-    n_chunks, n_rest = divmod(n_elements, jac_chunk_size)
-
-    if n_chunks == 0 or jac_chunk_size >= n_elements:
-        y = vmapped_fun(*args, **kwargs)
-    else:
-        # split inputs
-        def _get_chunks(x):
-            x_chunks = jax.tree_util.tree_map(
-                lambda x_: x_[: n_elements - n_rest, ...], x
-            )
-            x_chunks = _chunk(x_chunks, jac_chunk_size)
-            return x_chunks
-
-        def _get_rest(x):
-            x_rest = jax.tree_util.tree_map(
-                lambda x_: x_[n_elements - n_rest :, ...], x
-            )
-            return x_rest
-
-        args_chunks = [
-            _get_chunks(a) if i in argnums else a for i, a in enumerate(args)
-        ]
-        args_rest = [_get_rest(a) if i in argnums else a for i, a in enumerate(args)]
-
-        y_chunks = _unchunk(
-            _scanmap(vmapped_fun, jax.lax.scan, argnums)(*args_chunks, **kwargs)
-        )
-
-        if n_rest == 0:
-            y = y_chunks
-        else:
-            y_rest = vmapped_fun(*args_rest, **kwargs)
-            y = jax.tree_util.tree_map(
-                lambda y1, y2: jnp.concatenate((y1, y2)), y_chunks, y_rest
-            )
-    return y
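The divmod bookkeeping splits the batch into scanned chunks plus a remainder that bypasses the scan; for example, with 10 elements and jac_chunk_size=4:

n_elements, jac_chunk_size = 10, 4
n_chunks, n_rest = divmod(n_elements, jac_chunk_size)  # (2, 2)
# Rows 0..7 are reshaped to (2, 4, ...) and scanned chunk by chunk;
# rows 8..9 go through the plain vmapped function; the two outputs are
# concatenated back into a single length-10 result.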


-def _chunk_vmapped_function(
-    vmapped_fun: Callable,
-    jac_chunk_size: Optional[int],
-    argnums=0,
-) -> Callable:
-    """Takes a vmapped function and computes it in chunks."""
-    if jac_chunk_size is None:
-        return vmapped_fun
-
-    if isinstance(argnums, int):
-        argnums = (argnums,)
-    return functools.partial(_eval_fun_in_chunks, vmapped_fun, jac_chunk_size, argnums)
-
-
-def _parse_in_axes(in_axes):
-    if isinstance(in_axes, int):
-        in_axes = (in_axes,)
-
-    if not set(in_axes).issubset((0, None)):
-        raise NotImplementedError("Only in_axes 0/None are currently supported")
-
-    argnums = tuple(
-        map(lambda ix: ix[0], filter(lambda ix: ix[1] is not None, enumerate(in_axes)))
-    )
-    return in_axes, argnums
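_parse_in_axes simply records which positional arguments carry a mapped axis; for example:

in_axes, argnums = _parse_in_axes((0, None, 0))
# in_axes == (0, None, 0); argnums == (0, 2), i.e. only arguments 0 and 2
# are split into chunks, while argument 1 is passed through unchanged.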


-def vmap_chunked(
-    f: Callable,
-    in_axes=0,
-    *,
-    jac_chunk_size: Optional[int],
-) -> Callable:
-    """Behaves like jax.vmap but uses scan to chunk the computations in smaller chunks.
-
-    Parameters
-    ----------
-    f: The function to be vectorised.
-    in_axes: The axes that should be scanned along. Only supports `0` or `None`
-    jac_chunk_size: The maximum size of the chunks to be used. If it is `None`,
-        chunking is disabled
-
-    Returns
-    -------
-    f: A vectorised and chunked function
-    """
-    in_axes, argnums = _parse_in_axes(in_axes)
-    vmapped_fun = jax.vmap(f, in_axes=in_axes)
-    return _chunk_vmapped_function(vmapped_fun, jac_chunk_size, argnums)
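Usage mirrors jax.vmap with one extra keyword argument; a sketch of the interface described in the docstring (the residual function, shapes, and chunk size are illustrative):

import jax.numpy as jnp

def residual(x):                  # acts on a single input vector
    return jnp.sum(jnp.sin(x))

f = vmap_chunked(residual, in_axes=0, jac_chunk_size=50)
y = f(jnp.ones((1000, 7)))        # evaluated 50 rows per scan step; shape (1000,)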


-def batched_vectorize(
-    pyfunc, *, excluded=frozenset(), signature=None, jac_chunk_size=None
-):
-    """Define a vectorized function with broadcasting and batching.
-
-    below is taken from JAX
-    FIXME: change rest of docstring
-    :func:`vectorize` is a convenience wrapper for defining vectorized
-    functions with broadcasting, in the style of NumPy's
-    `generalized universal functions
-    <https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html>`_.
-    It allows for defining functions that are automatically repeated across
-    any leading dimensions, without the implementation of the function needing to
-    be concerned about how to handle higher dimensional inputs.
-
-    :func:`jax.numpy.vectorize` has the same interface as
-    :class:`numpy.vectorize`, but it is syntactic sugar for an auto-batching
-    transformation (:func:`vmap`) rather than a Python loop. This should be
-    considerably more efficient, but the implementation must be written in terms
-    of functions that act on JAX arrays.
-
-    Parameters
-    ----------
-    pyfunc: function to vectorize.
-    excluded: optional set of integers representing positional arguments for
-        which the function will not be vectorized. These will be passed directly
-        to ``pyfunc`` unmodified.
-    signature: optional generalized universal function signature, e.g.,
-        ``(m,n),(n)->(m)`` for vectorized matrix-vector multiplication. If
-        provided, ``pyfunc`` will be called with (and expected to return) arrays
-        with shapes given by the size of corresponding core dimensions. By
-        default, pyfunc is assumed to take scalar arrays as input and output.
-    jac_chunk_size: the size of the batches to pass to vmap. If 1, will only
-        evaluate one batch element at a time.
-
-    Returns
-    -------
-    Vectorized version of the given function.
-    """
-    if any(not isinstance(exclude, (str, int)) for exclude in excluded):
-        raise TypeError(
-            "jax.numpy.vectorize can only exclude integer or string arguments, "
-            "but excluded={!r}".format(excluded)
-        )
-    if any(isinstance(e, int) and e < 0 for e in excluded):
-        raise ValueError(f"excluded={excluded!r} contains negative numbers")
-
-    @functools.wraps(pyfunc)
-    def wrapped(*args, **kwargs):
-        error_context = (
-            "on vectorized function with excluded={!r} and "
-            "signature={!r}".format(excluded, signature)
-        )
-        excluded_func, args, kwargs = _apply_excluded(pyfunc, excluded, args, kwargs)
-
-        if signature is not None:
-            input_core_dims, output_core_dims = _parse_gufunc_signature(signature)
-        else:
-            input_core_dims = [()] * len(args)
-            output_core_dims = None
-
-        none_args = {i for i, arg in enumerate(args) if arg is None}
-        if any(none_args):
-            if any(input_core_dims[i] != () for i in none_args):
-                raise ValueError(
-                    f"Cannot pass None at locations {none_args} with {signature=}"
-                )
-            excluded_func, args, _ = _apply_excluded(excluded_func, none_args, args, {})
-            input_core_dims = [
-                dim for i, dim in enumerate(input_core_dims) if i not in none_args
-            ]
-
-        args = tuple(map(jnp.asarray, args))
-
-        broadcast_shape, dim_sizes = _parse_input_dimensions(
-            args, input_core_dims, error_context
-        )
-
-        checked_func = _check_output_dims(
-            excluded_func, dim_sizes, output_core_dims, error_context
-        )
-
-        # Rather than broadcasting all arguments to full broadcast shapes, prefer
-        # expanding dimensions using vmap. By pushing broadcasting
-        # into vmap, we can make use of more efficient batching rules for
-        # primitives where only some arguments are batched (e.g., for
-        # lax_linalg.triangular_solve), and avoid instantiating large broadcasted
-        # arrays.
-
-        squeezed_args = []
-        rev_filled_shapes = []
-
-        for arg, core_dims in zip(args, input_core_dims):
-            noncore_shape = arg.shape[: arg.ndim - len(core_dims)]
-
-            pad_ndim = len(broadcast_shape) - len(noncore_shape)
-            filled_shape = pad_ndim * (1,) + noncore_shape
-            rev_filled_shapes.append(filled_shape[::-1])
-
-            squeeze_indices = tuple(
-                i for i, size in enumerate(noncore_shape) if size == 1
-            )
-            squeezed_arg = jnp.squeeze(arg, axis=squeeze_indices)
-            squeezed_args.append(squeezed_arg)
-
-        vectorized_func = checked_func
-        dims_to_expand = []
-        for negdim, axis_sizes in enumerate(zip(*rev_filled_shapes)):
-            in_axes = tuple(None if size == 1 else 0 for size in axis_sizes)
-            if all(axis is None for axis in in_axes):
-                dims_to_expand.append(len(broadcast_shape) - 1 - negdim)
-            else:
-                # change the vmap here to chunked_vmap
-                vectorized_func = vmap_chunked(
-                    vectorized_func, in_axes, jac_chunk_size=jac_chunk_size
-                )
-        result = vectorized_func(*squeezed_args)
-
-        if not dims_to_expand:
-            return result
-        elif isinstance(result, tuple):
-            return tuple(jnp.expand_dims(r, axis=dims_to_expand) for r in result)
-        else:
-            return jnp.expand_dims(result, axis=dims_to_expand)
-
-    return wrapped
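A gufunc-style sketch with core dimensions: batched matrix-vector products, evaluated 16 batch elements at a time (shapes and chunk size are illustrative; this assumes the function as defined above):

import jax.numpy as jnp

matvec = batched_vectorize(
    lambda A, v: A @ v,
    signature="(m,n),(n)->(m)",
    jac_chunk_size=16,
)
out = matvec(jnp.ones((64, 3, 5)), jnp.ones((64, 5)))  # shape (64, 3)

Setting jac_chunk_size=None recovers the behavior of jax.numpy.vectorize, since chunking is then disabled entirely.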


-@partial(jnp.vectorize, signature="(m),(m)->(n)", excluded={"size", "fill_value"})
+@functools.partial(
+    jnp.vectorize, signature="(m),(m)->(n)", excluded={"size", "fill_value"}
+)
 def take_mask(a, mask, /, *, size=None, fill_value=None):
     """JIT compilable method to return ``a[mask][:size]`` padded by ``fill_value``.
