From a80132e05700d1b7b1fef639d1b1030888aa1859 Mon Sep 17 00:00:00 2001 From: anh tong Date: Tue, 27 Sep 2022 20:21:34 +0900 Subject: [PATCH 1/5] use jax associative_scan --- signax/signature.py | 33 ++++++++++++++++++++------------- test/test_signature.py | 24 +++++++++++++++++++----- 2 files changed, 39 insertions(+), 18 deletions(-) diff --git a/signax/signature.py b/signax/signature.py index ff955e0..c3a5739 100644 --- a/signax/signature.py +++ b/signax/signature.py @@ -39,6 +39,18 @@ def _body(i, val): def signature_batch(path: jnp.ndarray, depth: int, n_chunks: int): + """Compute signature for a long path + + The path will be divided into chunks. The numbers of chunks + is set manually. + + Args: + path: size (length, dim) + depth: signature depth + n_chunks: + Returns: + signature in a form of [(n,), (n,n), ...] + """ length, dim = path.shape chunk_length = int((length - 1) / n_chunks) remainder = (length - 1) % n_chunks @@ -49,7 +61,7 @@ def signature_batch(path: jnp.ndarray, depth: int, n_chunks: int): basepoints = jnp.roll(path_bulk[:, -1], shift=1, axis=0) basepoints = basepoints.at[0].set(path[0]) path_bulk = jnp.concatenate([basepoints[:, None, :], path_bulk], axis=1) - path_remainder = path[bulk_length:] + path_remainder = path[bulk_length - 1 :] # noqa def _signature(path): return signature(path, depth) @@ -121,16 +133,11 @@ def multi_signature_combine(signatures: List[jnp.ndarray]): Returns: size [(n,), (n, n), (n, n, n), ...] """ - batch_size = signatures[0].shape[0] - - init_val = [x[0] for x in signatures] - - def _body_fn(i, val): - current = [x[i] for x in signatures] - ret = mult(val, current) - return ret - - combined = jax.lax.fori_loop( - lower=1, upper=batch_size, body_fun=_body_fn, init_val=init_val + result = jax.lax.associative_scan( + fn=jax.vmap(signature_combine), + elems=signatures, ) - return combined + # return the last index after associative scan + result = jax.tree_map(lambda x: x[-1], result) + + return result diff --git a/test/test_signature.py b/test/test_signature.py index 07e16ab..16878b1 100644 --- a/test/test_signature.py +++ b/test/test_signature.py @@ -24,7 +24,7 @@ def test_signature_1d_path(): def test_multi_signature_combine(): - batch_size = 10 + batch_size = 5 dim = 5 signatures = [ np.random.randn(batch_size, dim), @@ -52,19 +52,33 @@ def test_multi_signature_combine(): def test_signature_batch(): - # TODO: not complete yet - # no remainder case depth = 3 + # no remainder case length = 1001 dim = 100 n_chunks = 10 + path = np.random.randn(length, dim) + jax_signature = signature_batch(path, depth, n_chunks) + jax_sum = sum(jnp.sum(x) for x in jax_signature) - signature_batch(path, depth, n_chunks) + torch_path = torch.tensor(path) + torch_signature = signatory.signature(torch_path[None, ...], depth=depth) + torch_sum = torch_signature.sum().item() + + # TODO: this has a low precision error + assert jnp.allclose(jax_sum, torch_sum, rtol=1e-2, atol=1e-1) # has remainder case length = 1005 path = np.random.randn(length, dim) - signature_batch(path, depth, n_chunks) + jax_signature = signature_batch(path, depth, n_chunks) + jax_sum = sum(jnp.sum(x) for x in jax_signature) + + torch_path = torch.tensor(path) + torch_signature = signatory.signature(torch_path[None, ...], depth=depth) + torch_sum = torch_signature.sum().item() + + assert jnp.allclose(jax_sum, torch_sum, rtol=1e-3, atol=1e-5) From 09f68ea984d940695de60e3d41ba7fcfc1223597 Mon Sep 17 00:00:00 2001 From: anh tong Date: Tue, 27 Sep 2022 20:22:29 +0900 Subject: [PATCH 2/5] bump version --- signax/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/signax/__init__.py b/signax/__init__.py index 3dc1f76..485f44a 100644 --- a/signax/__init__.py +++ b/signax/__init__.py @@ -1 +1 @@ -__version__ = "0.1.0" +__version__ = "0.1.1" From 40740ed134007a0bcf2efc2cbf23eec9c8624e68 Mon Sep 17 00:00:00 2001 From: anh tong Date: Tue, 27 Sep 2022 20:59:23 +0900 Subject: [PATCH 3/5] add jit --- signax/signature.py | 1 + 1 file changed, 1 insertion(+) diff --git a/signax/signature.py b/signax/signature.py index c3a5739..effaba4 100644 --- a/signax/signature.py +++ b/signax/signature.py @@ -38,6 +38,7 @@ def _body(i, val): return exp_term +@partial(jax.jit, static_argnames=["depth", "n_chunks"]) def signature_batch(path: jnp.ndarray, depth: int, n_chunks: int): """Compute signature for a long path From f015c4b3adb45190fcbe2d42ebf8d8f5b3732997 Mon Sep 17 00:00:00 2001 From: anh tong Date: Tue, 27 Sep 2022 21:13:32 +0900 Subject: [PATCH 4/5] lower error tolerance --- test/test_signature.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_signature.py b/test/test_signature.py index 16878b1..e85957d 100644 --- a/test/test_signature.py +++ b/test/test_signature.py @@ -81,4 +81,5 @@ def test_signature_batch(): torch_signature = signatory.signature(torch_path[None, ...], depth=depth) torch_sum = torch_signature.sum().item() - assert jnp.allclose(jax_sum, torch_sum, rtol=1e-3, atol=1e-5) + # TODO: this has a low precision error + assert jnp.allclose(jax_sum, torch_sum, rtol=1e-2, atol=1e-1) From 87ec6569fbe80feb69b926181ea46edd18cd6c0b Mon Sep 17 00:00:00 2001 From: anh tong Date: Tue, 27 Sep 2022 21:19:57 +0900 Subject: [PATCH 5/5] set tolerance again --- test/test_signature.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_signature.py b/test/test_signature.py index e85957d..2669a78 100644 --- a/test/test_signature.py +++ b/test/test_signature.py @@ -48,7 +48,7 @@ def test_multi_signature_combine(): torch_signatures, input_channels=dim, depth=len(signatures) ) torch_sum = torch_output.sum().item() - assert jnp.allclose(jax_sum, torch_sum) + assert jnp.allclose(jax_sum, torch_sum, rtol=1e-2, atol=1e-1) def test_signature_batch():