Remove redundant graph_inputs usage in OpFromGraph #1306

Merged (3 commits) on Nov 20, 2022
138 changes: 84 additions & 54 deletions aesara/compile/builders.py
@@ -2,7 +2,7 @@
from collections import OrderedDict
from copy import copy
from functools import partial
from typing import List, Optional, Sequence, cast
from typing import Dict, List, Optional, Sequence, Tuple, cast

import aesara.tensor as at
from aesara import function
@@ -19,7 +19,6 @@
clone_replace,
graph_inputs,
io_connection_pattern,
replace_nominals_with_dummies,
)
from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType
@@ -82,6 +81,81 @@ def local_traverse(out):
return ret


def construct_nominal_fgraph(
inputs: Sequence[Variable], outputs: Sequence[Variable]
) -> Tuple[
FunctionGraph,
Sequence[Variable],
Dict[Variable, Variable],
Dict[Variable, Variable],
]:
"""Construct an inner-`FunctionGraph` with ordered nominal inputs."""
dummy_inputs = []
for n, inp in enumerate(inputs):
if (
not isinstance(inp, Variable)
or isinstance(inp, Constant)
or isinstance(inp, SharedVariable)
):
raise TypeError(
f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
)

dummy_inputs.append(inp.type())

dummy_shared_inputs = []
shared_inputs = []
for var in graph_inputs(outputs, inputs):
if isinstance(var, SharedVariable):
# To correctly support shared variables the inner-graph should
# not see them; otherwise, there will be problems with
# gradients.
# That's why we collect the shared variables and replace them
# with dummies.
shared_inputs.append(var)
dummy_shared_inputs.append(var.type())
elif var not in inputs and not isinstance(var, Constant):
raise MissingInputError(f"OpFromGraph is missing an input: {var}")

replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs))

new = rebuild_collect_shared(
cast(Sequence[Variable], outputs),
inputs=inputs + shared_inputs,
replace=replacements,
copy_inputs_over=False,
)
(
local_inputs,
local_outputs,
(clone_d, update_d, update_expr, new_shared_inputs),
) = new

assert len(local_inputs) == len(inputs) + len(shared_inputs)
assert len(local_outputs) == len(outputs)
assert not update_d
assert not update_expr
assert not new_shared_inputs

fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)

# The inputs need to be `NominalVariable`s so that we can merge
# inner-graphs
nominal_local_inputs = tuple(
NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
)

fgraph.replace_all(zip(local_inputs, nominal_local_inputs))

for i, inp in enumerate(fgraph.inputs):
nom_inp = nominal_local_inputs[i]
fgraph.inputs[i] = nom_inp
fgraph.clients.pop(inp, None)
fgraph.add_input(nom_inp)

return fgraph, shared_inputs, update_d, update_expr


class OpFromGraph(Op, HasInnerGraph):
r"""
This creates an `Op` from inputs and outputs lists of variables.
@@ -333,66 +407,21 @@ def __init__(
if not (isinstance(inputs, list) and isinstance(outputs, list)):
raise TypeError("Inputs and outputs must be lists")

for i in inputs + outputs:
if not isinstance(i, Variable):
for out in outputs:
if not isinstance(out, Variable):
raise TypeError(
f"Inputs and outputs must be Variable instances; got {i}"
f"Inputs and outputs must be Variable instances; got {out}"
)
if i in inputs:
if isinstance(i, Constant):
raise TypeError(f"Constants not allowed as inputs; {i}")
if isinstance(i, SharedVariable):
raise TypeError(f"SharedVariables not allowed as inputs; {i}")

for var in graph_inputs(outputs, inputs):
if var not in inputs and not isinstance(var, (Constant, SharedVariable)):
raise MissingInputError(f"OpFromGraph is missing an input: {var}")

if "updates" in kwargs or "givens" in kwargs:
raise NotImplementedError("Updates and givens are not allowed here")
raise NotImplementedError("Updates and givens are not supported")

self.is_inline = inline

# To correctly support shared variables the inner fct should
# not see them. Otherwise there is a problem with the gradient.
self.shared_inputs = []
for var in graph_inputs(outputs):
if isinstance(var, SharedVariable):
self.shared_inputs.append(var)

inputs, outputs = replace_nominals_with_dummies(inputs, outputs)

# The inputs should be `NominalVariable`s, so that graphs can be merged
replacements = {}
for n, v in enumerate(inputs):
replacements[v] = NominalVariable(n, v.type)

shared_vars = [
NominalVariable(n, var.type)
for n, var in enumerate(self.shared_inputs, start=len(inputs) + 1)
]

replacements.update(dict(zip(self.shared_inputs, shared_vars)))

new = rebuild_collect_shared(
cast(Sequence[Variable], outputs),
inputs=inputs + shared_vars,
replace=replacements,
copy_inputs_over=False,
self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph(
inputs, outputs
)
(
local_inputs,
local_outputs,
(clone_d, update_d, update_expr, shared_inputs),
) = new

assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
assert len(local_outputs) == len(outputs)
assert not update_d
assert not update_expr
assert not shared_inputs

self.fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)

self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs]
@@ -415,6 +444,7 @@ def __init__(
else:
self.set_lop_overrides("default")
self._lop_type = "lop"

self.set_rop_overrides(rop_overrides)

self._connection_pattern = connection_pattern
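Note: the new `construct_nominal_fgraph` helper above consolidates the input validation, shared-variable substitution, and `NominalVariable` conversion that `OpFromGraph.__init__` and `Scan.__init__` previously duplicated. A minimal sketch of how this surfaces through the public `OpFromGraph` API, assuming an Aesara build that includes this PR; the variable names and values are illustrative only, not part of the diff:

```python
import numpy as np

import aesara
import aesara.tensor as at
from aesara.compile.builders import OpFromGraph

x = at.vector("x")
w = aesara.shared(np.array([2.0, 3.0]), name="w")

# `w` is not listed as an explicit input; construct_nominal_fgraph detects it,
# replaces it with a dummy inside the inner graph, and reports it back so that
# OpFromGraph can record it in `op.shared_inputs`.
op = OpFromGraph([x], [x * w])
assert w in op.shared_inputs

xv = at.vector("xv")
f = aesara.function([xv], op(xv))
print(f([1.0, 1.0]))  # expected: [2. 3.]
```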
41 changes: 10 additions & 31 deletions aesara/scan/op.py
@@ -55,8 +55,7 @@

import aesara
from aesara import tensor as at
from aesara.compile import SharedVariable
from aesara.compile.builders import infer_shape
from aesara.compile.builders import construct_nominal_fgraph, infer_shape
from aesara.compile.function.pfunc import pfunc
from aesara.compile.io import In, Out
from aesara.compile.mode import Mode, get_default_mode, get_mode
@@ -65,17 +64,13 @@
from aesara.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
from aesara.graph.basic import (
Apply,
Constant,
NominalVariable,
Variable,
clone_replace,
equal_computations,
graph_inputs,
io_connection_pattern,
replace_nominals_with_dummies,
)
from aesara.graph.features import NoOutputFromInplace
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.utils import InconsistencyError, MissingInputError
from aesara.link.c.basic import CLinker
@@ -755,22 +750,12 @@ def __init__(
If ``True``, all the shared variables used in the inner-graph must be provided.

"""
inputs, outputs = replace_nominals_with_dummies(inputs, outputs)
self.fgraph, shared_inputs, _, _ = construct_nominal_fgraph(inputs, outputs)

input_replacements = []
for n, v in enumerate(inputs):
if not isinstance(v, (SharedVariable, Constant)):
input_replacements.append((v, NominalVariable(n, v.type)))

assert not isinstance(v, NominalVariable)

outputs = clone_replace(outputs, replace=input_replacements)

if input_replacements:
_, inputs_ = zip(*input_replacements)
inputs = list(inputs_)
else:
inputs = []
# The shared variables should have been removed, so, if there are
# any, it's because the user didn't specify an input.
if shared_inputs:
raise MissingInputError(f"Scan is missing inputs: {shared_inputs}")

self.info = info
self.truncate_gradient = truncate_gradient
@@ -782,7 +767,7 @@
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
if self.name:
message = self.name + " sub profile"
message = f"{self.name} sub profile"
else:
message = "Scan sub profile"

@@ -805,7 +790,7 @@ def tensorConstructor(shape, dtype):
while idx < info.n_mit_mot_outs:
# Note that for mit_mot there are several output slices per
# output sequence
o = outputs[idx]
o = self.fgraph.outputs[idx]
self.output_types.append(
# TODO: What can we actually say about the shape of this
# added dimension?
@@ -818,15 +803,15 @@ def tensorConstructor(shape, dtype):
# mit_sot / sit_sot / nit_sot
end = idx + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot

for o in outputs[idx:end]:
for o in self.fgraph.outputs[idx:end]:
self.output_types.append(
# TODO: What can we actually say about the shape of this
# added dimension?
typeConstructor((None,) + o.type.shape, o.type.dtype)
)

# shared outputs + possibly the ending condition
for o in outputs[end:]:
for o in self.fgraph.outputs[end:]:
self.output_types.append(o.type)

if info.as_while:
@@ -862,19 +847,13 @@ def tensorConstructor(shape, dtype):
self.n_outer_inputs = info.n_outer_inputs
self.n_outer_outputs = info.n_outer_outputs

self.fgraph = FunctionGraph(inputs, outputs, clone=False)

_ = self.prepare_fgraph(self.fgraph)

if any(node.op.destroy_map for node in self.fgraph.apply_nodes):
raise InconsistencyError(
"Inner-graphs must not contain in-place operations."
)

# Do the missing inputs check here to have the error early.
for var in graph_inputs(self.inner_outputs, self.inner_inputs):
if var not in self.inner_inputs and not isinstance(var, Constant):
raise MissingInputError(f"ScanOp is missing an input: {repr(var)}")
self._cmodule_key = CLinker().cmodule_key_variables(
self.inner_inputs, self.inner_outputs, []
)
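Note: for `Scan`, the inner fgraph is now also built with `construct_nominal_fgraph`, and any shared variable left over in the inner graph is treated as a missing input. This matches the user-facing `aesara.scan`, which converts shared variables into explicit outer inputs before the `Scan` `Op` is constructed. A rough sketch, assuming Aesara is installed; names and values are illustrative:

```python
import numpy as np

import aesara
import aesara.tensor as at

coeff = aesara.shared(np.float64(0.5), name="coeff")
xs = at.vector("xs")

# The step function closes over `coeff`; `aesara.scan` passes it to the inner
# Scan Op as an explicit non-sequence input, so the inner graph built by
# construct_nominal_fgraph never contains a stray SharedVariable.
ys, updates = aesara.scan(fn=lambda x_t: x_t * coeff, sequences=[xs])

f = aesara.function([xs], ys, updates=updates)
print(f(np.arange(4.0)))  # expected: [0.  0.5 1.  1.5]
```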
4 changes: 0 additions & 4 deletions tests/scan/test_basic.py
@@ -586,10 +586,6 @@ def f_rnn_shared(u_t, x_tm1, tmp_W_in, tmp_W):
assert np.allclose(aesara_values, v_out)

def test_oinp_iinp_iout_oout_mappings(self):
"""
Test the mapping produces by
ScanOp.get_oinp_iinp_iout_oout_mappings()
"""

rng = RandomStream(123)

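Note: for completeness, a sketch of calling the new helper directly. It is an internal API, so the return values shown are only what the diff above suggests, and the names are illustrative:

```python
import numpy as np

import aesara
import aesara.tensor as at
from aesara.compile.builders import construct_nominal_fgraph
from aesara.graph.basic import NominalVariable

a = at.scalar("a")
s = aesara.shared(np.float64(1.0), name="s")

fgraph, shared_inputs, update_d, update_expr = construct_nominal_fgraph([a], [a + s])

# Inner-graph inputs are NominalVariables, which is what allows equivalent
# inner graphs to be merged; the shared variable is reported separately.
assert all(isinstance(inp, NominalVariable) for inp in fgraph.inputs)
assert shared_inputs == [s]
assert not update_d and not update_expr
```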