
Make broadcast recursion in flatten structural #29816

Merged 2 commits into master from kf/broadcastflatten on Oct 28, 2018
Conversation

@Keno (Member) commented Oct 26, 2018

The inference enhancements in #29294 work quite well to prevent limiting
on many kinds of code. However, targeting TPUs, one code pattern it
struggled with was a fairly large broadcast fusion in Flux:

 λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))

The reason #29294 doesn't trigger is that the make_makeargs function used by the
implementation of Broadcast.flatten (which the TPU backend uses) had
a non-decreasing first argument (passing the return value of a previous
invocation of make_makeargs back in as the first argument). However,
that's not a fundamental limitation of the operation, but rather an
implementation choice. This PR switches that function's recursion pattern
to be purely structural, allowing inference to infer through it (with
the changes in #29294). As a result, ResNet50 infers properly.
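To illustrate the kind of change described above, here is a minimal, hypothetical sketch (not Base's actual `make_makeargs` code; `f`, `acc_style`, and `struct_style` are made-up names). In the accumulator style, each recursive call passes the result built so far back in as the first argument, so inference sees a non-decreasing argument and applies its recursion limit. In the structural style, recursion is driven only by the shrinking tuple, so every recursive call is on a strictly smaller type:

```julia
# Hypothetical per-element transformation used by both variants.
f(x) = 2x

# Accumulator style: the first argument grows on every call
# (the return value of the previous step is fed back in), which
# is the pattern that defeated the #29294 inference improvements.
acc_style(acc, args::Tuple{}) = acc
acc_style(acc, args::Tuple) = acc_style((acc..., f(first(args))), Base.tail(args))

# Structural style: recursion follows the structure of `args` alone,
# so each call's argument types strictly decrease and inference
# can recurse all the way through.
struct_style(args::Tuple{}) = ()
struct_style(args::Tuple) = (f(first(args)), struct_style(Base.tail(args))...)
```

Both compute the same result, e.g. `acc_style((), (1, 2, 3))` and `struct_style((1, 2, 3))` each yield `(2, 4, 6)`; only the recursion pattern, and hence what inference sees, differs.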

@Keno Keno requested a review from mbauman October 26, 2018 18:24
@Keno (Member, Author) commented Oct 26, 2018

@nanosoldier runbenchmarks(ALL, vs = ":master")

Review comment on base/broadcast.jl (outdated; resolved)
@mbauman mbauman self-assigned this Oct 26, 2018
Co-Authored-By: mbauman <[email protected]>
@mbauman (Member) left a comment:

Thank you for these comments!

@nanosoldier (Collaborator) commented:

Your benchmark job has completed - possible performance regressions were detected. A full report can be found here. cc @ararslan

@Keno Keno merged commit b84fe52 into master Oct 28, 2018
@vtjnash vtjnash deleted the kf/broadcastflatten branch October 28, 2018 19:26
@andreasnoack (Member) commented:

This PR seems to have broken Flux. cc @MikeInnes

5 participants