-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from Joshuaalbert/incremental-api
Incremental api
- Loading branch information
Showing
10 changed files
with
1,054 additions
and
191 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
317 changes: 317 additions & 0 deletions
317
docs/examples/excitable_damped_harmonic_oscillator.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from typing import Tuple, Callable, TypeVar | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
from jax import lax | ||
|
||
PT = TypeVar('PT') | ||
|
||
|
||
def pytree_unravel(example_tree: PT) -> Tuple[Callable[[PT], jax.Array], Callable[[jax.Array], PT]]: | ||
""" | ||
Returns functions to ravel and unravel a pytree. | ||
Args: | ||
example_tree: a pytree to be unravelled | ||
Returns: | ||
ravel_fun: a function to ravel the pytree | ||
unravel_fun: a function to unravel | ||
""" | ||
leaf_list, tree_def = jax.tree.flatten(example_tree) | ||
|
||
sizes = [np.size(leaf) for leaf in leaf_list] | ||
shapes = [np.shape(leaf) for leaf in leaf_list] | ||
dtypes = [leaf.dtype for leaf in leaf_list] | ||
|
||
def ravel_fun(pytree: PT) -> jax.Array: | ||
leaf_list, tree_def = jax.tree.flatten(pytree) | ||
# promote types to common one | ||
common_dtype = jnp.result_type(*dtypes) | ||
leaf_list = [leaf.astype(common_dtype) for leaf in leaf_list] | ||
return jnp.concatenate([lax.reshape(leaf, (size,)) for leaf, size in zip(leaf_list, sizes)]) | ||
|
||
def unravel_fun(flat_array: jax.Array) -> PT: | ||
leaf_list = [] | ||
start = 0 | ||
for size, shape, dtype in zip(sizes, shapes, dtypes): | ||
leaf_list.append(lax.reshape(flat_array[start:start + size], shape).astype(dtype)) | ||
start += size | ||
return jax.tree.unflatten(tree_def, leaf_list) | ||
|
||
return ravel_fun, unravel_fun |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from typing import Tuple, NamedTuple | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
|
||
class SparseRepresentation(NamedTuple): | ||
shape: Tuple[int, ...] | ||
rows: jax.Array | ||
cols: jax.Array | ||
vals: jax.Array | ||
|
||
|
||
def create_sparse_rep(m: np.ndarray) -> SparseRepresentation: | ||
""" | ||
Creates a sparse rep from matrix m. Use in linear models with materialise_jacobian=False for 2x speed up. | ||
Args: | ||
m: [N,M] matrix | ||
Returns: | ||
sparse rep | ||
""" | ||
rows, cols = np.where(m) | ||
sort_indices = np.lexsort((cols, rows)) | ||
rows = rows[sort_indices] | ||
cols = cols[sort_indices] | ||
return SparseRepresentation( | ||
shape=np.shape(m), | ||
rows=jnp.asarray(rows), | ||
cols=jnp.asarray(cols), | ||
vals=jnp.asarray(m[rows, cols]) | ||
) | ||
|
||
|
||
def to_dense(m: SparseRepresentation, out: jax.Array | None = None) -> jax.Array: | ||
""" | ||
Form dense matrix. | ||
Args: | ||
m: sparse rep | ||
out: output buffer | ||
Returns: | ||
out + M | ||
""" | ||
if out is None: | ||
out = jnp.zeros(m.shape, m.vals.dtype) | ||
|
||
return out.at[m.rows, m.cols].add(m.vals, unique_indices=True, indices_are_sorted=True) | ||
|
||
|
||
def matvec_sparse(m: SparseRepresentation, v: jax.Array, out: jax.Array | None = None) -> jax.Array: | ||
""" | ||
Compute matmul for sparse rep. Speeds up large sparse linear models by about 2x. | ||
Args: | ||
m: sparse rep | ||
v: vec | ||
out: output buffer to add to. | ||
Returns: | ||
out + M @ v | ||
""" | ||
if out is None: | ||
out = jnp.zeros(m.shape[0]) | ||
return out.at[m.rows].add(m.vals * v[m.cols], unique_indices=True, indices_are_sorted=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.