-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Batch apply_lse_kernel
for online=True
#23
Conversation
There was a bug which caused the coupling/marginals not to match |
Thanks a lot Michal! this is fantastic. |
Batch `apply_lse_kernel` for `online=True`
As discussed in #20 , this PR fixes
online
by batchingapply_lse_kernel
and running the fully-vectorized computation on a batch of shapen * batch_size
orm * batch_size
(depends of the axis) instead ofn * m
. Minor inefficiency comes from that this approach computed the kernel application for extra{n,m} % batch_size
points, given that result of each iteration ofjax.lax.scan
must have the same shape.@LaetitiaPapaxanthos there are 2 points I am unsure how you want to handle:
shall the backward compatibility be retained and allowUPDATE:backward=True
? In that case, we'd need to use some default value, that can either be fixed or depend on the number of points (from our benchmarks, value of1024
seems to work well)online=True
is the same asonline=1024
currently, the batch size is the same when usingUPDATE: kept the same batch size for both axesaxis=0
oraxis=1
, but this could be done in axis-specific mannerTODOs:
online=True
)closes #20