Skip to content

Commit

Permalink
extended jump broadcast and jump parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Dec 9, 2017
1 parent 0174d2a commit e211ba5
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
29 changes: 29 additions & 0 deletions src/extended_jump_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,35 @@ add_idxs2{T<:ExtendedJumpArray}(::Type{T},expr) = :($(expr).jump_u)
res
end

Base.Broadcast.promote_containertype(::Type{T}, ::Type{T}) where {T<:ExtendedJumpArray} = T
Base.Broadcast.promote_containertype(::Type{T}, ::Type{S}) where {T<:ExtendedJumpArray, S<:AbstractArray} = T
Base.Broadcast.promote_containertype(::Type{S}, ::Type{T}) where {T<:ExtendedJumpArray, S<:AbstractArray} = T
Base.Broadcast.promote_containertype(::Type{T}, ::Type{<:Any}) where {T<:ExtendedJumpArray} = T
Base.Broadcast.promote_containertype(::Type{<:Any}, ::Type{T}) where {T<:ExtendedJumpArray} = T
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{T}) where {T<:ExtendedJumpArray} = T
Base.Broadcast.promote_containertype(::Type{T}, ::Type{Array}) where {T<:ExtendedJumpArray} = T
Base.Broadcast._containertype(::Type{T}) where {T<:ExtendedJumpArray} = T
Base.Broadcast.broadcast_indices(::Type{<:ExtendedJumpArray}, A) = indices(A)

@inline function Base.Broadcast.broadcast_c(f, ::Type{S}, A, Bs...) where S<:ExtendedJumpArray
T = Base.Broadcast._broadcast_eltype(f, A, Bs...)
shape = Base.Broadcast.broadcast_indices(A, Bs...)
broadcast!(f, similar(A), A, Bs...)
end

@inline function Base.Broadcast.broadcast_c(f, ::Type{S}, A::ExtendedJumpArray, Bs::Union{ExtendedJumpArray,Number}...) where S<:ExtendedJumpArray
new_A = similar(A)
broadcast!(f,new_A,A,Bs...)
new_A
end

@inline function Base.Broadcast.broadcast_c(f, ::Type{S}, A::ExtendedJumpArray) where S<:ExtendedJumpArray
new_A = similar(A)
broadcast!(f,new_A,A)
new_A
end

@inline Base.broadcast!(::typeof(identity), u::DiffEqJump.ExtendedJumpArray, x::Number) = fill!(u,x)
#=
Base.Broadcast._containertype(::Type{<:ExtendedJumpArray}) = ExtendedJumpArray
Base.Broadcast.promote_containertype(::Type{ExtendedJumpArray}, _) = ExtendedJumpArray
Expand Down
12 changes: 10 additions & 2 deletions src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,27 @@ function JumpProblem(prob,aggregator::Direct,jumps::JumpSet;
end

function extend_problem(prob::AbstractODEProblem,jumps)
jump_f = function (t,u,du)
function jump_f(t,u,du)
prob.f(t,u.u,@view du[1:length(u.u)])
update_jumps!(du,t,u,length(u.u),jumps.variable_jumps...)
end
function jump_f(t,u,p,du)
prob.f(t,u.u,p,@view du[1:length(u.u)])
update_jumps!(du,t,u,length(u.u),jumps.variable_jumps...)
end
u0 = ExtendedJumpArray(prob.u0,[-randexp() for i in 1:length(jumps.variable_jumps)])
ODEProblem(jump_f,u0,prob.tspan)
end

function extend_problem(prob::AbstractSDEProblem,jumps)
jump_f = function (t,u,du)
function jump_f(t,u,du)
prob.f(t,u.u,@view du[1:length(u.u)])
update_jumps!(du,t,u,length(u.u),jumps.variable_jumps...)
end
function jump_f(t,u,p,du)
prob.f(t,u.u,p,@view du[1:length(u.u)])
update_jumps!(du,t,u,length(u.u),jumps.variable_jumps...)
end
u0 = ExtendedJumpArray(prob.u0,[-randexp() for i in 1:length(jumps.variable_jumps)])
SDEProblem(jump_f,prob.g,u0,prob.tspan)
end
Expand Down

0 comments on commit e211ba5

Please sign in to comment.