Skip to content

Commit

Permalink
Merge pull request #21 from anh-tong/assoc-scan
Browse files Browse the repository at this point in the history
add associative scan for computing signature of long paths
  • Loading branch information
anh-tong authored Sep 27, 2022
2 parents b84b573 + 87ec656 commit d817573
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 20 deletions.
2 changes: 1 addition & 1 deletion signax/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.0"
__version__ = "0.1.1"
34 changes: 21 additions & 13 deletions signax/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,20 @@ def _body(i, val):
return exp_term


@partial(jax.jit, static_argnames=["depth", "n_chunks"])
def signature_batch(path: jnp.ndarray, depth: int, n_chunks: int):
"""Compute signature for a long path
The path will be divided into chunks. The numbers of chunks
is set manually.
Args:
path: size (length, dim)
depth: signature depth
n_chunks:
Returns:
signature in a form of [(n,), (n,n), ...]
"""
length, dim = path.shape
chunk_length = int((length - 1) / n_chunks)
remainder = (length - 1) % n_chunks
Expand All @@ -49,7 +62,7 @@ def signature_batch(path: jnp.ndarray, depth: int, n_chunks: int):
basepoints = jnp.roll(path_bulk[:, -1], shift=1, axis=0)
basepoints = basepoints.at[0].set(path[0])
path_bulk = jnp.concatenate([basepoints[:, None, :], path_bulk], axis=1)
path_remainder = path[bulk_length:]
path_remainder = path[bulk_length - 1 :] # noqa

def _signature(path):
return signature(path, depth)
Expand Down Expand Up @@ -121,16 +134,11 @@ def multi_signature_combine(signatures: List[jnp.ndarray]):
Returns:
size [(n,), (n, n), (n, n, n), ...]
"""
batch_size = signatures[0].shape[0]

init_val = [x[0] for x in signatures]

def _body_fn(i, val):
current = [x[i] for x in signatures]
ret = mult(val, current)
return ret

combined = jax.lax.fori_loop(
lower=1, upper=batch_size, body_fun=_body_fn, init_val=init_val
result = jax.lax.associative_scan(
fn=jax.vmap(signature_combine),
elems=signatures,
)
return combined
# return the last index after associative scan
result = jax.tree_map(lambda x: x[-1], result)

return result
27 changes: 21 additions & 6 deletions test/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_signature_1d_path():


def test_multi_signature_combine():
batch_size = 10
batch_size = 5
dim = 5
signatures = [
np.random.randn(batch_size, dim),
Expand All @@ -48,23 +48,38 @@ def test_multi_signature_combine():
torch_signatures, input_channels=dim, depth=len(signatures)
)
torch_sum = torch_output.sum().item()
assert jnp.allclose(jax_sum, torch_sum)
assert jnp.allclose(jax_sum, torch_sum, rtol=1e-2, atol=1e-1)


def test_signature_batch():
# TODO: not complete yet
# no remainder case
depth = 3

# no remainder case
length = 1001
dim = 100
n_chunks = 10

path = np.random.randn(length, dim)
jax_signature = signature_batch(path, depth, n_chunks)
jax_sum = sum(jnp.sum(x) for x in jax_signature)

signature_batch(path, depth, n_chunks)
torch_path = torch.tensor(path)
torch_signature = signatory.signature(torch_path[None, ...], depth=depth)
torch_sum = torch_signature.sum().item()

# TODO: this has a low precision error
assert jnp.allclose(jax_sum, torch_sum, rtol=1e-2, atol=1e-1)

# has remainder case
length = 1005
path = np.random.randn(length, dim)

signature_batch(path, depth, n_chunks)
jax_signature = signature_batch(path, depth, n_chunks)
jax_sum = sum(jnp.sum(x) for x in jax_signature)

torch_path = torch.tensor(path)
torch_signature = signatory.signature(torch_path[None, ...], depth=depth)
torch_sum = torch_signature.sum().item()

# TODO: this has a low precision error
assert jnp.allclose(jax_sum, torch_sum, rtol=1e-2, atol=1e-1)

0 comments on commit d817573

Please sign in to comment.