
add associative scan for computing signature of long paths #21

Merged
merged 5 commits into main from assoc-scan on Sep 27, 2022

Conversation

anh-tong
Owner

Improve signature_batch for computing signatures of long paths

  • Previously: chunks were combined sequentially.
  • This update: use jax.lax.associative_scan. Since signatures of chunks can be combined in any order (the combination is associative), the combinations can be done in parallel (see doc).

Thanks @patrick-kidger for the suggestions.
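
Here is a minimal, self-contained sketch of the pattern at truncation depth 2 (the representation, shapes, and helper names are illustrative only, not the actual API in this repo):

```python
import jax
import jax.numpy as jnp

# Toy depth-2 truncated signature: a pair (level-1 term of shape (d,),
# level-2 term of shape (d, d)). Chen's identity says the signature of a
# concatenated path is the tensor product of the signatures of its pieces;
# truncated at depth 2 this reads
#   level1 = a1 + a2
#   level2 = A1 + A2 + a1 (outer) a2
def combine(sig_left, sig_right):
    a1, A1 = sig_left
    a2, A2 = sig_right
    return a1 + a2, A1 + A2 + jnp.einsum("...i,...j->...ij", a1, a2)

# Per-chunk signatures stacked along the leading axis (random stand-ins here).
n_chunks, d = 8, 3
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
chunk_sigs = (jax.random.normal(k1, (n_chunks, d)),
              jax.random.normal(k2, (n_chunks, d, d)))

# Because `combine` is associative, chunks can be merged in any bracketing,
# so the scan needs only O(log n) parallel combination steps instead of a
# sequential fold over the chunks.
running = jax.lax.associative_scan(combine, chunk_sigs)
full_signature = jax.tree_util.tree_map(lambda x: x[-1], running)
```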

@anh-tong anh-tong merged commit d817573 into main Sep 27, 2022
@anh-tong anh-tong deleted the assoc-scan branch September 27, 2022 12:33
@patrick-kidger

FYI, I think this means you can remove signature_batch altogether. Just do an associative scan down the whole length.

@anh-tong
Owner Author

@patrick-kidger I just realized we can do that with the exponential operator. We'll update this later. Thanks!

@patrick-kidger

More specifically, you can do it with the operator A, b -> A \otimes exp(b). (This can itself be computed more efficiently than just composing exp and \otimes; see the Signatory paper if you haven't seen it already.)
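
Roughly, in the depth-N truncated tensor algebra the exponential and the suggested combine are

```latex
\exp(b) \;=\; \sum_{k=0}^{N} \frac{b^{\otimes k}}{k!},
\qquad
(A, b) \;\longmapsto\; A \otimes \exp(b),
```

i.e. the combine attaches one new path increment b to a running signature A, and can be evaluated in a fused way rather than by forming exp(b) first.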

@anh-tong
Owner Author

Can you expand on this point more?

For now, I'm still thinking about how to implement f: A, b -> A \otimes exp(b) efficiently by scanning down the whole length with jax.lax.associative_scan.

First, f is not directly usable with the current jax.lax.associative_scan because A and b have different structures (A lives in the truncated tensor algebra while b is a vector), whereas an associative combine needs both arguments to have the same type.

However, I think this can be done with a new version of the combine function:

def fn(Ab1, Ab2):
    A1, b1, flag1 = Ab1
    A2, b2, flag2 = Ab2
    # flag1, flag2 indicate whether A1, A2 already hold a computed signature.
    # Based on the flags we choose either
    #   A1 \otimes exp(b2)   (the right element is still a raw increment), or
    #   A1 \otimes A2        (both elements are already computed signatures).
    A_combined = ...
    return (A_combined, b2, 1)

The computational complexity of A \otimes \exp(b) is O(d^N), while that of a general tensor-algebra product A1 \otimes A2 is O(N d^N) (d: number of channels, N: signature depth).
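For example, with d = 4 channels and depth N = 3, that is on the order of 4^3 = 64 multiplications for A \otimes \exp(b) versus 3 · 4^3 = 192 for A1 \otimes A2, so the fused form saves roughly a factor of N per combination.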

The current JAX associative scan follows this computation graph (a + sign marks where a combination happens; taken from Wikipedia):
[Image: prefix-sum computation graph, Prefix_sum_16.svg from Wikipedia]

The first and last rows of + signs use A \otimes \exp(b), which takes O(L d^N) work in parallel. The middle rows of + signs require computing A1 \otimes A2, which takes roughly O(L N d^N) work in parallel and is the main computation here. Here L is the path length.

Using jax.lax.associative_scan (a prefix sum) does not take full advantage of A \otimes exp(b), because the effective chunk length is only 2.
My view is that A \otimes exp(b) pays off when chunks have a "reasonable" length, and A1 \otimes A2 is still unavoidable.

I do think your suggestion is the way to go for computing signatures over expanding intervals (the stream option in Signatory). This JAX implementation is currently missing that feature.

@patrick-kidger

Oh, right. Hmm. I'd definitely need to think about this some more / don't really have the time to. Good luck though!
