-
-
Notifications
You must be signed in to change notification settings - Fork 152
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a fusion rewrite for CAReduce
s with Elemwise
inputs
#1285
Add a fusion rewrite for CAReduce
s with Elemwise
inputs
#1285
Conversation
cbf33e4
to
b681459
Compare
914f7f6
to
c371651
Compare
Should we only fuse when the unreduced output has a single client, and therefore is definitely never needed? |
c371651
to
a9d8ca0
Compare
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #1285 +/- ##
==========================================
+ Coverage 74.12% 74.15% +0.03%
==========================================
Files 174 174
Lines 48652 48706 +54
Branches 10366 10372 +6
==========================================
+ Hits 36064 36119 +55
- Misses 10299 10301 +2
+ Partials 2289 2286 -3
|
a9d8ca0
to
9f9f2a0
Compare
Yeah, that and a few other things need/needed to be done before this stops being a draft. I just added it now, though—along with another fix. |
Some current results: import numpy as np
import aesara
import aesara.tensor as at
from aesara.compile.mode import get_mode
fusion_mode = get_mode("FAST_RUN").including("local_careduce_fusion")
no_fusion_mode = get_mode("FAST_RUN").excluding("local_careduce_fusion")
x = at.matrix("x")
y = at.exp(x).sum(axis=1)
y_fn = aesara.function([x], y, mode=no_fusion_mode)
aesara.dprint(y_fn)
# Sum{axis=[1], acc_dtype=float64} [id A] 1
# |Elemwise{exp,no_inplace} [id B] 0
# |x [id C]
y_fusion_fn = aesara.function([x], y, mode=fusion_mode)
aesara.dprint(y_fusion_fn)
# CAReduce{Composite{(i0 + exp(i1))}}{axis=[1], acc_dtype=float64} [id A] 0
# |x [id B]
rng = np.random.default_rng(23920)
x_small_val = rng.random((10, 10))
x_large_val = rng.random((5000, 2000))
%timeit y_fn(x_small_val)
# 6.58 µs ± 151 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
%timeit y_fn(x_large_val)
# 198 ms ± 16.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
res = y_fn(x_large_val)
exp_res = np.exp(x_large_val).sum(axis=1)
assert res.shape == exp_res.shape
assert np.allclose(res, exp_res)
%timeit y_fusion_fn(x_small_val)
# 6.25 µs ± 558 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
%timeit y_fusion_fn(x_large_val)
# 55.3 ms ± 826 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
res = y_fusion_fn(x_large_val)
assert res.shape == exp_res.shape
assert np.allclose(res, exp_res) |
9f9f2a0
to
01b8153
Compare
34ca8c3
to
d977ee4
Compare
d977ee4
to
d3830c4
Compare
d3830c4
to
e9839a1
Compare
- Lazily create and cache `FunctionGraph`s, the `Composite.perform` implementation, C code, and name values - Use `fgraph_to_python` for `Composite.perform` - Use the `HasInnerGraph` interface
e9839a1
to
91f3438
Compare
This PR adds fusion rewrites for
CAReduce
nodes withElemwise
-derived inputs.Composite
Op
s generated by this fusionCAReduceDtype
It's a fairly redundant subclass that probably should be merged with
CAReduce
anyway.E.g. test theaxis
parameter[ ] Support multiple inputs (optional)This will require some refactoring of
CAReduce
or a new subclass and should be split off into its own issue/PR. See FuseCAReduce
s with multi-inputElemwise
s #1307.