Skip to content

Commit

Permalink
Merge pull request #30 from anh-tong/top-level
Browse files Browse the repository at this point in the history
Move main functionality to top-level of module + add example notebooks to CI (#30)
  • Loading branch information
anh-tong authored Jun 1, 2023
2 parents 96cf22c + 8877b4d commit f177dab
Show file tree
Hide file tree
Showing 17 changed files with 206 additions and 241 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ jobs:
- name: Upload coverage report
uses: codecov/[email protected]

- name: Run examples
run: |
cd examples
python -m pip install -r requirements.txt
python -m pip install jupyter nbclient
jupyter execute estimate_hurst.ipynb
jupyter execute generative_model.ipynb
jupyter execute inversion.ipynb
dist:
needs: [pre-commit]
name: Distribution build
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,7 @@ Thumbs.db
# Common editor files
*~
*.swp

# notebooks
Untitiled.ipynb
t.ipynb
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ Basic usage
```python
import jax
import jax.random as jrandom
import signax

from signax.signature import signature

key = jrandom.PRNGKey(0)
depth = 3
Expand All @@ -41,17 +41,17 @@ depth = 3
length = 100
dim = 20
path = jrandom.normal(shape=(length, dim), key=key)
output = signature(path, depth)
output = signax.signature(path, depth)
# output is a list of array representing tensor algebra

# compute signature for batches (multiple) of paths
# this is done via `jax.vmap`
batch_size = 20
path = jrandom.normal(shape=(batch_size, length, dim), key=key)
output = jax.vmap(lambda x: signature(x, depth))(path)
output = jax.vmap(lambda x: signax.signature(x, depth))(path)
```

Integrate with [equinox](https://github.com/patrick-kidger/equinox) library
Integrate with the [equinox](https://github.com/patrick-kidger/equinox) library

```python
import equinox as eqx
Expand Down
5 changes: 1 addition & 4 deletions examples/compare.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@
"import numpy as np\n",
"import signatory\n",
"import torch\n",
"from signax.signature import signature\n",
"\n",
"# set white background\n",
"plt.rcParams[\"figure.facecolor\"] = \"white\""
"from signax import signature"
]
},
{
Expand Down
114 changes: 60 additions & 54 deletions examples/estimate_hurst.ipynb

Large diffs are not rendered by default.

110 changes: 25 additions & 85 deletions examples/generative_model.ipynb

Large diffs are not rendered by default.

62 changes: 36 additions & 26 deletions examples/inversion.ipynb

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions examples/nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import jax.numpy as jnp
import jax.random as jrandom

from signax import signature, signature_combine
from signax.module import SignatureTransform
from signax.signature import signature, signature_combine
from signax.utils import flatten


Expand Down Expand Up @@ -70,6 +70,8 @@ def __init__(
def __call__(
self,
x: jnp.ndarray,
*,
key=None,
):
"""x size (length, dim)"""
length, _ = x.shape
Expand Down Expand Up @@ -164,7 +166,7 @@ def __init__(self, length, adjusted_length, signature_depth=2) -> None:
self.adjusted_length = adjusted_length
self.signature_depth = signature_depth

def __call__(self, x):
def __call__(self, x, *, key=None):
"""
Transform input `x` into a smaller window.
Each window starts at index 0 with increasing size according
Expand Down
2 changes: 2 additions & 0 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
matplotlib
optax
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,5 @@ isort.required-imports = ["from __future__ import annotations"]
[tool.ruff.per-file-ignores]
"tests/**" = ["T20"]
"noxfile.py" = ["T20"]
"src/signax/module.py" = ["ARG002"] # unused argument key in __call__ of Module (equinox)
"examples/nets.py" = ["ARG002"] # unused argument key in __call__ of Module (equinox)
23 changes: 22 additions & 1 deletion src/signax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,25 @@

__version__ = "0.1.2"

__all__ = ("__version__",)
__all__ = (
"__version__",
"module",
"utils",
"tensor_ops",
"signature",
"logsignature",
"signature_combine",
"signature_to_logsignature",
"multi_signature_combine",
"signature_batch",
)

from signax import module, tensor_ops, utils
from signax.signatures import (
logsignature,
multi_signature_combine,
signature,
signature_batch,
signature_combine,
signature_to_logsignature,
)
18 changes: 15 additions & 3 deletions src/signax/module.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from __future__ import annotations

from typing import Any

__all__ = (
"SignatureTransform",
"SignatureCombine",
)

import equinox as eqx
import jax

from signax.signature_flattened import signature, signature_combine
from signax.signatures import signature, signature_combine
from signax.utils import flatten, unravel_signature


class SignatureTransform(eqx.Module):
Expand All @@ -15,8 +23,10 @@ def __init__(self, depth: int):
def __call__(
self,
path: jax.Array,
*,
key: Any | None = None,
) -> jax.Array:
return signature(path, self.depth)
return flatten(signature(path, self.depth))


class SignatureCombine(eqx.Module):
Expand All @@ -28,4 +38,6 @@ def __init__(self, dim: int, depth: int):
self.depth = depth

def __call__(self, signature1: jax.Array, signature2: jax.Array):
return signature_combine(signature1, signature2, self.dim, self.depth)
sig1 = unravel_signature(signature1, self.dim, self.depth)
sig2 = unravel_signature(signature2, self.dim, self.depth)
return flatten(signature_combine(sig1, sig2))
58 changes: 0 additions & 58 deletions src/signax/signature_flattened.py

This file was deleted.

13 changes: 11 additions & 2 deletions src/signax/signature.py → src/signax/signatures.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
from __future__ import annotations

__all__ = (
"signature",
"logsignature",
"signature_batch",
"signature_combine",
"signature_to_logsignature",
"multi_signature_combine",
)

from functools import partial

import jax
import jax.numpy as jnp

from .tensor_ops import log, mult, mult_fused_restricted_exp, restricted_exp
from .utils import compress, lyndon_words
from signax.tensor_ops import log, mult, mult_fused_restricted_exp, restricted_exp
from signax.utils import compress, lyndon_words


@partial(jax.jit, static_argnames="depth")
Expand Down
9 changes: 9 additions & 0 deletions src/signax/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from __future__ import annotations

__all__ = (
"index_select",
"lyndon_words",
"compress",
"unravel_signature",
"flatten",
"term_at",
)

from collections import defaultdict
from functools import partial
from typing import cast
Expand Down
2 changes: 1 addition & 1 deletion tests/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from numpy.random import default_rng

from signax.signature import (
from signax import (
multi_signature_combine,
signature,
signature_batch,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_tensor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from numpy.random import default_rng

from signax.signature import signature, signature_to_logsignature
from signax import signature, signature_to_logsignature
from signax.tensor_ops import (
addcmul,
mult,
Expand Down

0 comments on commit f177dab

Please sign in to comment.