Skip to content

Commit

Permalink
Add a fusion rewrite for CAReduces with Elemwise inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Nov 21, 2022
1 parent 6bb2833 commit e9839a1
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 3 deletions.
2 changes: 1 addition & 1 deletion aesara/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4048,7 +4048,7 @@ def __init__(self, inputs, outputs):

@property
def fn(self):
return self._fn
return None

@property
def inner_inputs(self):
Expand Down
89 changes: 87 additions & 2 deletions aesara/tensor/rewriting/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,21 @@
import aesara
import aesara.scalar.basic as aes
from aesara import compile
from aesara.compile.mode import get_target_language
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, io_toposort
from aesara.graph.features import ReplaceValidate
from aesara.graph.op import compute_test_value, get_test_value
from aesara.graph.rewriting.basic import GraphRewriter, copy_stack_trace, node_rewriter
from aesara.graph.rewriting.basic import (
GraphRewriter,
copy_stack_trace,
in2out,
node_rewriter,
)
from aesara.graph.rewriting.db import SequenceDB
from aesara.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
from aesara.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.rewriting.basic import register_canonicalize, register_specialize
from aesara.tensor.shape import shape_padleft
Expand Down Expand Up @@ -944,3 +950,82 @@ def local_useless_composite(fgraph, node):
c = aes.Composite(inputs=comp.inputs, outputs=new_outputs)
e = Elemwise(scalar_op=c)(*node.inputs, return_list=True)
return dict(zip([node.outputs[i] for i in idx], e))


@node_rewriter([CAReduce])
def local_careduce_fusion(fgraph, node):
"""Fuse a `CAReduce` applied to an `Elemwise`."""

(car_input,) = node.inputs
elm_node = car_input.owner

if elm_node is None or not isinstance(elm_node.op, Elemwise):
return False

elm_inputs = elm_node.inputs
elm_outputs = elm_node.outputs

if len(elm_inputs) > 1 or len(elm_outputs) > 1:
# TODO: Implement the multiple inputs case
return False

if len(fgraph.clients[elm_outputs[0]]) > 1:
return False

# Don't form the fusion when the target language is Python
elm_scalar_op = elm_node.op.scalar_op
car_scalar_op = node.op.scalar_op

if get_target_language() == ("py",):
return False

try:
elm_scalar_op.c_code(
elm_node,
"test_presence_of_c_code",
["x" for x in elm_inputs],
["z" for z in elm_outputs],
{"fail": "%(fail)s"},
)

car_scalar_op.c_code(
node,
"test_presence_of_c_code",
["x" for x in node.inputs],
["z" for z in node.outputs],
{"fail": "%(fail)s"},
)
except (NotImplementedError, MethodNotDefined):
return False

car_axis = node.op.axis

scalar_elm_inputs = [
aes.get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs
]
elm_output = elm_scalar_op(*scalar_elm_inputs)
# This input represents the previous value in the `CAReduce` binary reduction
carried_car_input = elm_output.type()
scalar_fused_outputs = [car_scalar_op(carried_car_input, elm_output)]

fused_scalar_op = aes.Composite(
inputs=[carried_car_input] + scalar_elm_inputs, outputs=scalar_fused_outputs
)

# The fused `Op` needs to look and behave like a `BinaryScalarOp`
# TODO: Generate a new `type` and make this relationship official?
fused_scalar_op.identity = car_scalar_op.identity
fused_scalar_op.nin = 2
fused_scalar_op.nout = 1

new_car_op = CAReduce(fused_scalar_op, car_axis)

return [new_car_op(*elm_inputs)]


compile.optdb.register( # type: ignore
"local_careduce_fusion",
in2out(local_careduce_fusion),
"fusion",
position=49,
)
80 changes: 80 additions & 0 deletions tests/tensor/rewriting/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,6 +1113,86 @@ def test_test_values(self, test_value):
f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]]
)

@pytest.mark.parametrize("linker", ["cvm", "py"])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)])
def test_CAReduce_single_input(self, linker, axis):
"""Make sure that `CAReduce` and `Elemwise` fusions work with a single input."""

mode = Mode(linker=linker)
mode._optimizer = mode._optimizer.including(
"local_careduce_fusion",
"canonicalize",
"inplace",
)

x = tensor("floatX", shape=(None, None, None), name="x")
out = exp(x).sum(axis=axis)

out_fn = function([x], out, mode=mode)

if linker != "py":
(out_node,) = out_fn.maker.fgraph.toposort()
assert isinstance(getattr(out_node.op, "scalar_op"), aes.basic.Composite)

rng = np.random.default_rng(2320)
x_val = rng.random((4, 3, 2), dtype=config.floatX)

exp_res = np.exp(x_val).sum(axis=axis)

out_val = out_fn(x_val)
assert out_val.shape == exp_res.shape
assert np.allclose(out_val, exp_res)
else:
out_nodes = out_fn.maker.fgraph.toposort()
assert not any(
isinstance(out_node.op.scalar_op, aes.basic.Composite)
for out_node in out_nodes
if hasattr(out_node.op, "scalar_op")
)

# `Elemwise`s with more than one client shouldn't be rewritten
x = tensor("floatX", shape=(None, None, None), name="x")
exp_x = exp(x)
out = exp_x.sum(axis=axis) + exp(x)

out_fn = function([x], out, mode=mode)
out_nodes = out_fn.maker.fgraph.toposort()
assert not any(
isinstance(out_node.op.scalar_op, aes.basic.Composite)
for out_node in out_nodes
if hasattr(out_node.op, "scalar_op")
)

@pytest.mark.xfail(reason="Not implemented")
@pytest.mark.parametrize("linker", ["cvm", "py"])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)])
def test_CAReduce_multiple_inputs(self, linker, axis):
"""Make sure that `CAReduce` and `Elemwise` fusions work with multiple inputs."""

mode = Mode(linker=linker)
mode._optimizer = mode._optimizer.including(
"local_careduce_fusion",
"canonicalize",
"inplace",
)

x = tensor("floatX", shape=(None, None, None), name="x")
y = tensor("floatX", shape=(None, None, None), name="y")
out = (x + y).sum(axis=axis)

out_fn = function([x, y], out, mode=mode)
(out_node,) = out_fn.maker.fgraph.toposort()

assert isinstance(getattr(out_node.op, "scalar_op"), aes.basic.Composite)

rng = np.random.default_rng(2320)
x_val = rng.random((4, 3, 2), dtype=config.floatX)
y_val = rng.random((4, 3, 2), dtype=config.floatX)
exp_res = (x_val + y_val).sum(axis=axis)
out_val = out_fn(x_val, y_val)
assert out_val.shape == exp_res.shape
assert np.allclose(out_val, exp_res)


class TimesN(aes.basic.UnaryScalarOp):
"""
Expand Down

0 comments on commit e9839a1

Please sign in to comment.