Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a fusion rewrite for CAReduces with Elemwise inputs #1285

Merged

Conversation

brandonwillard
Copy link
Member

@brandonwillard brandonwillard commented Nov 4, 2022

This PR adds fusion rewrites for CAReduce nodes with Elemwise-derived inputs.

  • Make the Python backend work for the Composite Ops generated by this fusion
  • Do something about CAReduceDtype
    It's a fairly redundant subclass that probably should be merged with CAReduce anyway.
  • Add more/better tests
    • E.g. test the axis parameter
  • Consider only performing the rewrite when not using the Python backend (for performance reasons)
  • [ ] Support multiple inputs (optional)
    This will require some refactoring of CAReduce or a new subclass and should be split off into its own issue/PR. See Fuse CAReduces with multi-input Elemwises #1307.

@brandonwillard brandonwillard marked this pull request as draft November 4, 2022 22:38
@brandonwillard brandonwillard linked an issue Nov 4, 2022 that may be closed by this pull request
@brandonwillard brandonwillard force-pushed the fuse-CAReduce-and-Elemwise branch 2 times, most recently from cbf33e4 to b681459 Compare November 4, 2022 22:56
@brandonwillard brandonwillard force-pushed the fuse-CAReduce-and-Elemwise branch 5 times, most recently from 914f7f6 to c371651 Compare November 6, 2022 05:18
@ricardoV94
Copy link
Contributor

Should we only fuse when the unreduced output has a single client, and therefore is definitely never needed?

@brandonwillard brandonwillard force-pushed the fuse-CAReduce-and-Elemwise branch from c371651 to a9d8ca0 Compare November 6, 2022 14:50
@codecov
Copy link

codecov bot commented Nov 6, 2022

Codecov Report

Merging #1285 (91f3438) into main (3ad936f) will increase coverage by 0.03%.
The diff coverage is 94.53%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1285      +/-   ##
==========================================
+ Coverage   74.12%   74.15%   +0.03%     
==========================================
  Files         174      174              
  Lines       48652    48706      +54     
  Branches    10366    10372       +6     
==========================================
+ Hits        36064    36119      +55     
- Misses      10299    10301       +2     
+ Partials     2289     2286       -3     
Impacted Files Coverage Δ
aesara/compile/function/pfunc.py 84.18% <ø> (-0.24%) ⬇️
aesara/compile/function/types.py 79.16% <75.00%> (+0.16%) ⬆️
aesara/tensor/elemwise.py 88.07% <90.54%> (-0.52%) ⬇️
aesara/tensor/rewriting/elemwise.py 86.40% <94.44%> (+0.65%) ⬆️
aesara/scalar/basic.py 79.02% <95.16%> (+0.10%) ⬆️
aesara/compile/mode.py 84.47% <100.00%> (+1.22%) ⬆️
aesara/tensor/math.py 90.40% <100.00%> (+0.37%) ⬆️

@brandonwillard brandonwillard force-pushed the fuse-CAReduce-and-Elemwise branch from a9d8ca0 to 9f9f2a0 Compare November 6, 2022 18:17
@brandonwillard
Copy link
Member Author

Should we only fuse when the unreduced output has a single client, and therefore is definitely never needed?

Yeah, that and a few other things need/needed to be done before this stops being a draft. I just added it now, though—along with another fix.

@brandonwillard
Copy link
Member Author

brandonwillard commented Nov 6, 2022

Some current results:

import numpy as np

import aesara
import aesara.tensor as at

from aesara.compile.mode import get_mode


fusion_mode = get_mode("FAST_RUN").including("local_careduce_fusion")
no_fusion_mode = get_mode("FAST_RUN").excluding("local_careduce_fusion")


x = at.matrix("x")
y = at.exp(x).sum(axis=1)

y_fn = aesara.function([x], y, mode=no_fusion_mode)

aesara.dprint(y_fn)
# Sum{axis=[1], acc_dtype=float64} [id A] 1
#  |Elemwise{exp,no_inplace} [id B] 0
#    |x [id C]

y_fusion_fn = aesara.function([x], y, mode=fusion_mode)

aesara.dprint(y_fusion_fn)
# CAReduce{Composite{(i0 + exp(i1))}}{axis=[1], acc_dtype=float64} [id A] 0
#  |x [id B]

rng = np.random.default_rng(23920)

x_small_val = rng.random((10, 10))
x_large_val = rng.random((5000, 2000))

%timeit y_fn(x_small_val)
# 6.58 µs ± 151 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

%timeit y_fn(x_large_val)
# 198 ms ± 16.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

res = y_fn(x_large_val)
exp_res = np.exp(x_large_val).sum(axis=1)
assert res.shape == exp_res.shape
assert np.allclose(res, exp_res)

%timeit y_fusion_fn(x_small_val)
# 6.25 µs ± 558 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

%timeit y_fusion_fn(x_large_val)
# 55.3 ms ± 826 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

res = y_fusion_fn(x_large_val)
assert res.shape == exp_res.shape
assert np.allclose(res, exp_res)

@brandonwillard brandonwillard force-pushed the fuse-CAReduce-and-Elemwise branch from 9f9f2a0 to 01b8153 Compare November 11, 2022 23:55
@brandonwillard brandonwillard self-assigned this Nov 20, 2022
@brandonwillard brandonwillard force-pushed the fuse-CAReduce-and-Elemwise branch 2 times, most recently from 34ca8c3 to d977ee4 Compare November 21, 2022 01:53
@brandonwillard brandonwillard marked this pull request as ready for review November 21, 2022 01:58
@brandonwillard brandonwillard force-pushed the fuse-CAReduce-and-Elemwise branch from d977ee4 to d3830c4 Compare November 21, 2022 02:02
@brandonwillard brandonwillard requested a review from rlouf November 21, 2022 05:22
@brandonwillard brandonwillard force-pushed the fuse-CAReduce-and-Elemwise branch from d3830c4 to e9839a1 Compare November 21, 2022 18:21
- Lazily create and cache `FunctionGraph`s, the `Composite.perform`
  implementation, C code, and name values
- Use `fgraph_to_python` for `Composite.perform`
- Use the `HasInnerGraph` interface
@brandonwillard brandonwillard force-pushed the fuse-CAReduce-and-Elemwise branch from e9839a1 to 91f3438 Compare November 22, 2022 01:21
@brandonwillard brandonwillard merged commit ae20174 into aesara-devs:main Nov 22, 2022
@brandonwillard brandonwillard deleted the fuse-CAReduce-and-Elemwise branch November 22, 2022 15:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Fuse CAReduces and Elemwises
2 participants