Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move main functionality to top-level of module + add example notebooks to CI #30

Merged
merged 11 commits into from
Jun 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ jobs:
- name: Upload coverage report
uses: codecov/[email protected]

- name: Run examples
run: |
cd examples
python -m pip install -r requirements.txt
python -m pip install jupyter nbclient
jupyter execute estimate_hurst.ipynb
jupyter execute generative_model.ipynb
jupyter execute inversion.ipynb

Comment on lines +73 to +81
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know much about running notebooks in CI. I wonder if this is a common practice. Anyway, I think it's okay for these notebooks since they may not take much time.

dist:
needs: [pre-commit]
name: Distribution build
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,7 @@ Thumbs.db
# Common editor files
*~
*.swp

# notebooks
Untitiled.ipynb
t.ipynb
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ Basic usage
```python
import jax
import jax.random as jrandom
import signax

from signax.signature import signature

key = jrandom.PRNGKey(0)
depth = 3
Expand All @@ -41,17 +41,17 @@ depth = 3
length = 100
dim = 20
path = jrandom.normal(shape=(length, dim), key=key)
output = signature(path, depth)
output = signax.signature(path, depth)
# output is a list of array representing tensor algebra

# compute signature for batches (multiple) of paths
# this is done via `jax.vmap`
batch_size = 20
path = jrandom.normal(shape=(batch_size, length, dim), key=key)
output = jax.vmap(lambda x: signature(x, depth))(path)
output = jax.vmap(lambda x: signax.signature(x, depth))(path)
```

Integrate with [equinox](https://github.com/patrick-kidger/equinox) library
Integrate with the [equinox](https://github.com/patrick-kidger/equinox) library

```python
import equinox as eqx
Expand Down
5 changes: 1 addition & 4 deletions examples/compare.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@
"import numpy as np\n",
"import signatory\n",
"import torch\n",
"from signax.signature import signature\n",
"\n",
"# set white background\n",
"plt.rcParams[\"figure.facecolor\"] = \"white\""
"from signax import signature"
]
},
{
Expand Down
114 changes: 60 additions & 54 deletions examples/estimate_hurst.ipynb

Large diffs are not rendered by default.

110 changes: 25 additions & 85 deletions examples/generative_model.ipynb

Large diffs are not rendered by default.

62 changes: 36 additions & 26 deletions examples/inversion.ipynb

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions examples/nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import jax.numpy as jnp
import jax.random as jrandom

from signax import signature, signature_combine
from signax.module import SignatureTransform
from signax.signature import signature, signature_combine
from signax.utils import flatten


Expand Down Expand Up @@ -70,6 +70,8 @@ def __init__(
def __call__(
self,
x: jnp.ndarray,
*,
key=None,
):
"""x size (length, dim)"""
length, _ = x.shape
Expand Down Expand Up @@ -164,7 +166,7 @@ def __init__(self, length, adjusted_length, signature_depth=2) -> None:
self.adjusted_length = adjusted_length
self.signature_depth = signature_depth

def __call__(self, x):
def __call__(self, x, *, key=None):
"""
Transform input `x` into a smaller window.
Each window starts at index 0 with increasing size according
Expand Down
2 changes: 2 additions & 0 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
matplotlib
optax
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,5 @@ isort.required-imports = ["from __future__ import annotations"]
[tool.ruff.per-file-ignores]
"tests/**" = ["T20"]
"noxfile.py" = ["T20"]
"src/signax/module.py" = ["ARG002"] # unused argument key in __call__ of Module (equinox)
"examples/nets.py" = ["ARG002"] # unused argument key in __call__ of Module (equinox)
23 changes: 22 additions & 1 deletion src/signax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,25 @@

__version__ = "0.1.2"

__all__ = ("__version__",)
__all__ = (
"__version__",
"module",
"utils",
"tensor_ops",
"signature",
"logsignature",
"signature_combine",
"signature_to_logsignature",
"multi_signature_combine",
"signature_batch",
)

from signax import module, tensor_ops, utils
from signax.signatures import (
logsignature,
multi_signature_combine,
signature,
signature_batch,
signature_combine,
signature_to_logsignature,
)
18 changes: 15 additions & 3 deletions src/signax/module.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from __future__ import annotations

from typing import Any

__all__ = (
"SignatureTransform",
"SignatureCombine",
)

import equinox as eqx
import jax

from signax.signature_flattened import signature, signature_combine
from signax.signatures import signature, signature_combine
from signax.utils import flatten, unravel_signature


class SignatureTransform(eqx.Module):
Expand All @@ -15,8 +23,10 @@ def __init__(self, depth: int):
def __call__(
self,
path: jax.Array,
*,
key: Any | None = None,
) -> jax.Array:
return signature(path, self.depth)
return flatten(signature(path, self.depth))


class SignatureCombine(eqx.Module):
Expand All @@ -28,4 +38,6 @@ def __init__(self, dim: int, depth: int):
self.depth = depth

def __call__(self, signature1: jax.Array, signature2: jax.Array):
return signature_combine(signature1, signature2, self.dim, self.depth)
sig1 = unravel_signature(signature1, self.dim, self.depth)
sig2 = unravel_signature(signature2, self.dim, self.depth)
return flatten(signature_combine(sig1, sig2))
58 changes: 0 additions & 58 deletions src/signax/signature_flattened.py

This file was deleted.

13 changes: 11 additions & 2 deletions src/signax/signature.py → src/signax/signatures.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
from __future__ import annotations

__all__ = (
"signature",
"logsignature",
"signature_batch",
"signature_combine",
"signature_to_logsignature",
"multi_signature_combine",
)

from functools import partial

import jax
import jax.numpy as jnp

from .tensor_ops import log, mult, mult_fused_restricted_exp, restricted_exp
from .utils import compress, lyndon_words
from signax.tensor_ops import log, mult, mult_fused_restricted_exp, restricted_exp
from signax.utils import compress, lyndon_words


@partial(jax.jit, static_argnames="depth")
Expand Down
9 changes: 9 additions & 0 deletions src/signax/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from __future__ import annotations

__all__ = (
"index_select",
"lyndon_words",
"compress",
"unravel_signature",
"flatten",
"term_at",
)

from collections import defaultdict
from functools import partial
from typing import cast
Expand Down
2 changes: 1 addition & 1 deletion tests/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from numpy.random import default_rng

from signax.signature import (
from signax import (
multi_signature_combine,
signature,
signature_batch,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_tensor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from numpy.random import default_rng

from signax.signature import signature, signature_to_logsignature
from signax import signature, signature_to_logsignature
from signax.tensor_ops import (
addcmul,
mult,
Expand Down