Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC/WIP: Use macro to auto switch between for loop and broadcast #639

Closed
wants to merge 4 commits into from

Conversation

YingboMa
Copy link
Member

@YingboMa YingboMa commented Feb 3, 2019

@loop u = uprev+dt*(a31*k1 + a32*k2) is functionally equivalent with @. u = uprev+dt*(a31*k1 + a32*k2).

@loop u = uprev+dt*(a31*k1 + a32*k2) is functionally equivalent with

k = similar(u)
@. k = uprev+dt*(a31*k1 + a32*k2)

.

julia> @macroexpand @loop u = uprev+dt*(a31*k1 + a32*k2)
quote
    #= REPL[115]:8 =#
    if u isa Array
        #= REPL[115]:9 =#
        for ##II#1344 = Base.eachindex(u)
            #= REPL[115]:10 =#
            u[##II#1344] = (muladd)(_getindex(dt, ##II#1344), (muladd)(_getindex(a32, ##II#1344), _getindex(k2, ##II#1344), _getindex(a31, ##II#1344) * _getindex(k1, ##II#1344)), _getindex(uprev, ##II#1344))
        end
    else
        #= REPL[115]:13 =#
        u .= (muladd).(dt, (muladd).(a32, k2, (*).(a31, k1)), uprev)
    end
end

julia> @macroexpand @loop u uprev+dt*(a31*k1 + a32*k2)
quote
    #= REPL[114]:4 =#
    if u isa Array
        #= REPL[114]:5 =#
        ##1345 = Base.similar(u)
        #= REPL[114]:6 =#
        begin
            #= REPL[115]:8 =#
            if ##1345 isa Array
                #= REPL[115]:9 =#
                for ##II#1346 = Base.eachindex(##1345)
                    #= REPL[115]:10 =#
                    ##1345[##II#1346] = (muladd)(_getindex(dt, ##II#1346), (muladd)(_getindex(a32, ##II#1346), _getindex(k2, ##II#1346), _getindex(a31, ##II#1346) * _getindex(k1, ##II#1346)), _getindex(uprev, ##II#1346))
                end
            else
                #= REPL[115]:13 =#
                ##1345 .= (muladd).(dt, (muladd).(a32, k2, (*).(a31, k1)), uprev)
            end
        end
    else
        #= REPL[114]:8 =#
        (+).(uprev, (*).(dt, (+).((*).(a31, k1), (*).(a32, k2))))
    end
end

@ChrisRackauckas
Copy link
Member

lgtm. We should do a quick performance test though before throwing it on the whole package. But this has a good change of solving all of our issues with broadcast.

@ChrisRackauckas
Copy link
Member

This should probably be a FastAtDot.jl or something like that, since we'll want to use it in other packages as well.

Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. I agree, that seems like a good way to circumvent most broadcasting related issues at the moment 👍

@YingboMa
Copy link
Member Author

YingboMa commented Feb 3, 2019

This should probably be a FastAtDot.jl or something like that, since we'll want to use it in other packages as well.

It assumes that the LHS has the same axes with any non-scalar RHS. It is a very strong assumption. Also, I need to add more checks, @inbounds, @simd, and a way to disable all the checks to make it more mature.

src/misc_utils.jl Outdated Show resolved Hide resolved
@ChrisRackauckas
Copy link
Member

It assumes that the LHS has the same axes with any non-scalar RHS. It is a very strong assumption. Also, I need to add more checks, @inbounds, @simd, and a way to disable all the checks to make it more mature.

Okay yeah, so very DiffEq. Let's make it a DiffEqBase util then.

@YingboMa
Copy link
Member Author

YingboMa commented Feb 4, 2019

master

julia> @btime solve($(ODEProblem((du,u,p,t)->@.(du=0.01*u), ones(10), (0,100.))), Tsit5());
  10.497 μs (148 allocations: 17.11 KiB)

julia> @btime solve($(ODEProblem((du,u,p,t)->@.(du=0.01*u), cu(ones(10)), (0,100.))), Tsit5());
  24.335 ms (16613 allocations: 737.70 KiB)

julia> @btime solve($(ODEProblem((du,u,p,t)->@.(du=0.01*u), @MVector(ones(10)), (0,100.))), Tsit5());
  8.646 μs (159 allocations: 12.69 KiB)

julia> @btime solve($(ODEProblem((u,p,t)->@.(0.01*u), @SVector(ones(10)), (0,100.))), Tsit5());
  8.040 μs (104 allocations: 17.25 KiB)

PR : Array

julia> @btime solve($(ODEProblem((du,u,p,t)->@.(du=0.01*u), ones(10), (0,100.))), Tsit5());
  11.094 μs (148 allocations: 17.11 KiB)

julia> @btime solve($(ODEProblem((du,u,p,t)->@.(du=0.01*u), cu(ones(10)), (0,100.))), Tsit5());
  1.503 ms (6829 allocations: 304.88 KiB)

julia> @btime solve($(ODEProblem((du,u,p,t)->@.(du=0.01*u), @MVector(ones(10)), (0,100.))), Tsit5());
  668.722 μs (16899 allocations: 307.06 KiB)

julia> @btime solve($(ODEProblem((u,p,t)->@.(0.01*u), @SVector(ones(10)), (0,100.))), Tsit5());
  8.332 μs (101 allocations: 16.97 KiB)

PR : Array, MArray

julia> @btime solve($(ODEProblem((du,u,p,t)->@.(du=0.01*u), @MVector(ones(10)), (0,100.))), Tsit5());
  8.770 μs (159 allocations: 12.69 KiB)

@devmotion
Copy link
Member

Is the plan to not add broadcasts to the non-mutating algorithms since they should be used only with numbers and static arrays? Otherwise the benchmarks for SVector should be repeated with a broadcasted implementation of Tsit5ConstantCache, I guess.

@YingboMa
Copy link
Member Author

YingboMa commented Feb 4, 2019

julia> @btime solve($(ODEProblem((u,p,t)->@.(0.01*u), @SVector(ones(10)), (0,100.))), Tsit5()); #PR
  7.680 μs (125 allocations: 18.28 KiB)

@ChrisRackauckas
Copy link
Member

Looks like it's done but just needs to move to DiffEqBase

@ChrisRackauckas
Copy link
Member

Superseded by #716

@YingboMa YingboMa deleted the myb/bc branch April 14, 2019 02:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants