Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor updates functionality and Scan inner-graph compilation #848

Merged

Conversation

brandonwillard
Copy link
Member

@brandonwillard brandonwillard commented Mar 8, 2022

This PR consists of independent FunctionGraph, Scan, and Function construction changes that have been split off from #824.

One of the main results of these changes is that Scan's inner-graph is handled almost exclusively as a single coherent FunctionGraph, and the inner-graph compilation and Scan.perform logic is much more concise.

For instance, when pre-allocation is applied, Scan currently removes/reorders some of its inner-graph outputs just before compiling. This is necessary because the Function compilation pipeline is somewhat rigid with respect to the way it handles updates (e.g. it assumes that all update steps are at the end of the outputs list of a FunctionGraph). Those issues have been fixed in this PR and, now, a Scan's inner-graph is consistent for the lifetime of the Scan and requires no manipulation before compilation.

This PR also adds some long-overdue refactoring of the VM interface and implementations. Previously, the CVM was the only VM that would perform updates, and Function would perform the updates for all other VMs. This meant that the updating logic needed to be reproduced in Scan.perform (see below)—mostly because Scan.perform was (perhaps unnecessarily) designed to use a Function's VM and not the Function itself. Now, every VM implementation performs updates, so there's no need for that logic/scope leak in Scan (or Function for that matter, but I've left it in for now).

  • Look into removing the manual update logic from Scan.perform (and its Cython implementation).
    This would be another big step toward simplifying the Op.perform logic in Scan.

@brandonwillard brandonwillard added enhancement New feature or request refactor This issue involves refactoring Scan Involves the `Scan` `Op` labels Mar 8, 2022
@brandonwillard brandonwillard self-assigned this Mar 8, 2022
@brandonwillard brandonwillard force-pushed the more-inner-graph-updates branch 5 times, most recently from 028787a to 18c1b10 Compare March 13, 2022 02:23
@brandonwillard brandonwillard force-pushed the more-inner-graph-updates branch 14 times, most recently from bda1cd2 to a412af3 Compare April 5, 2022 23:50
@codecov
Copy link

codecov bot commented Apr 6, 2022

Codecov Report

Merging #848 (36e29a4) into main (344ff6b) will increase coverage by 0.02%.
The diff coverage is 86.71%.

❗ Current head 36e29a4 differs from pull request most recent head fba4184. Consider uploading reports for the commit fba4184 to get more accurate results

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #848      +/-   ##
==========================================
+ Coverage   78.99%   79.02%   +0.02%     
==========================================
  Files         152      152              
  Lines       47809    47947     +138     
  Branches    10893    10900       +7     
==========================================
+ Hits        37768    37889     +121     
- Misses       7541     7547       +6     
- Partials     2500     2511      +11     
Impacted Files Coverage Δ
aesara/compile/io.py 83.01% <0.00%> (ø)
aesara/configdefaults.py 66.48% <ø> (+0.09%) ⬆️
aesara/link/jax/dispatch.py 80.34% <ø> (-0.52%) ⬇️
aesara/misc/check_blas.py 0.00% <ø> (ø)
aesara/scan/basic.py 85.24% <ø> (+0.15%) ⬆️
aesara/compile/debugmode.py 60.45% <50.00%> (-0.16%) ⬇️
aesara/printing.py 48.58% <50.00%> (-3.08%) ⬇️
aesara/tensor/basic_opt.py 85.10% <54.54%> (+0.01%) ⬆️
aesara/gradient.py 77.23% <71.42%> (ø)
aesara/link/c/basic.py 87.36% <75.00%> (ø)
... and 52 more

@brandonwillard brandonwillard marked this pull request as ready for review April 6, 2022 00:59
@brandonwillard brandonwillard changed the title More inner-graph updates Refactor updates functionality and Scan inner-graph compilation Apr 6, 2022
@brandonwillard brandonwillard force-pushed the more-inner-graph-updates branch from a412af3 to c5f4093 Compare April 7, 2022 00:50
These changes allow one to pass an `fgraph` argument to all key functions in the
`Function` compilation pipeline.  The given `FunctionGraph` will be directly
compiled without cloning.  Unlike the previous `FunctionMaker.__init__`, this
one's `fgraph` argument *will* be optimized according to the given `mode` unless
the keyword argument `no_fgraph_prep` is `True`.
A few core `Feature` implementations were using `assert`s and generic
`Exception` instead of `AlreadyThere`, which unnecessarily produces errors
when already-attached `Feature` requirements are added by `Rewriter`s.
This also prevents `SeqOptimizer.apply` from calling `Rewriter.add_requirements`
on each iteration through its list of sub-`Rewriter`s.
This changes the `update_storage` parameter from a list containing
the input indices that are to be updated with the last N-many outputs
to a tuple of tuples specifying input/output indices.

Now, arbitrary output-to-input update pairings are possible, instead of
forcing graphs and code to compensate for this unnecessary restriction.
In order to use shared updates to pre-allocate the storage for mit-mot input and
output loops, `Scan` would need to remove the corresponding mit-mot outputs from
its inner-`FunctionGraph` before compilation and it would expect the `Function`
compilation pipeline to add them back at the end of the remaining outputs.

Now, `Scan`'s inner-`FunctionGraph`s maintain the same form at every point, and
no special logic is needed to compensate for post-compilation changes in the
order/location of inputs and outputs.
All VM implementations now perform variable updates themselves.  This leaves
some now redundant update code in `Function`, but it also removes some from
`Scan.perform`.
This new method provides an interface for stateful `Feature`s to be easily
cloned and attached to other `FunctionGraph`s.
This method adds the ability to clone a `HasInnerGraph` `Op` and its
inner-`FunctionGraph`.
Since it is no longer necessary to clone `Constant`s, and they add extra work
for the merge rewrites, `clone_get_equiv`'s `copy_orphans` option has been
prevented from cloning `Constant`s.  The option is mostly applicable to
constants, but it has been retained for other non-`Constant` cases (just in
case) and backward compatibility.
This adds the missing `update_mapping`s step during cloning, as well as a new
`Feature` cloning step that prevents issues when features are copied to their
clones.
@brandonwillard brandonwillard force-pushed the more-inner-graph-updates branch from 36e29a4 to fba4184 Compare May 9, 2022 22:03
@brandonwillard brandonwillard merged commit d11f303 into aesara-devs:main May 9, 2022
@brandonwillard brandonwillard deleted the more-inner-graph-updates branch May 9, 2022 23:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request important refactor This issue involves refactoring Scan Involves the `Scan` `Op`
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants