Skip to content

Commit

Permalink
Merge pull request #59 from Zinoex/fm/more_orthogonal_impls
Browse files Browse the repository at this point in the history
More workspace types for OrthogonalIntervalProbabilities
  • Loading branch information
Zinoex authored Sep 11, 2024
2 parents 24f0043 + 8e460d3 commit f1ab0bf
Show file tree
Hide file tree
Showing 9 changed files with 712 additions and 47 deletions.
13 changes: 9 additions & 4 deletions ext/cuda/workspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ struct CuDenseWorkspace <: AbstractCuWorkspace
max_actions::Int32
end

IntervalMDP.construct_workspace(::AbstractGPUMatrix, max_actions) =
CuDenseWorkspace(max_actions)
IntervalMDP.construct_workspace(
prob::IntervalProbabilities{R, VR, MR},
max_actions = 1,
) where {R, VR, MR <: AbstractGPUMatrix{R}} = CuDenseWorkspace(max_actions)

####################
# Sparse workspace #
Expand All @@ -23,5 +25,8 @@ function CuSparseWorkspace(p::AbstractCuSparseMatrix, max_actions)
return CuSparseWorkspace(max_nonzeros, max_actions)
end

IntervalMDP.construct_workspace(p::AbstractCuSparseMatrix, max_actions) =
CuSparseWorkspace(p, max_actions)
IntervalMDP.construct_workspace(
prob::IntervalProbabilities{R, VR, MR},
max_actions = 1,
) where {R, VR, MR <: AbstractCuSparseMatrix{R}} =
CuSparseWorkspace(max_actions, max_actions)
270 changes: 260 additions & 10 deletions src/bellman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Vcur = bellman(Vprev, prob; upper_bound = false)
"""
function bellman(V, prob; upper_bound = false)
Vres = similar(V, num_source(prob))
Vres = similar(V, source_shape(prob))
return bellman!(Vres, V, prob; upper_bound = upper_bound)
end

Expand Down Expand Up @@ -293,7 +293,7 @@ function gap_value(Vp, sum_lower)
return res
end

# Dense
# Dense orthogonal
function bellman!(
workspace::DenseOrthogonalWorkspace,
strategy_cache::AbstractStrategyCache,
Expand All @@ -309,15 +309,101 @@ function bellman!(

# For each higher-level state in the product space
for I in CartesianIndices(product_nstates[2:end])
perm = @view workspace.permutation[axes(V, 1)]
sortperm!(perm, @view(V[:, I]); rev = upper_bound, scratch = workspace.scratch)

copyto!(@view(workspace.first_level_perm[:, I]), perm)
sort_dense_orthogonal(workspace, workspace.first_level_perm, V, I, upper_bound)
end

# For each source state
@inbounds for (jₛ_cart, jₛ_linear) in
zip(CartesianIndices(axes(V)), LinearIndices(axes(V)))
bellman_dense_orthogonal!(
workspace,
workspace.first_level_perm,
strategy_cache,
Vres,
V,
prob,
stateptr,
product_nstates,
jₛ_cart,
jₛ_linear;
upper_bound = upper_bound,
maximize = maximize,
)
end

return Vres
end

function bellman!(
workspace::ThreadedDenseOrthogonalWorkspace,
strategy_cache::AbstractStrategyCache,
Vres,
V,
prob::OrthogonalIntervalProbabilities,
stateptr;
upper_bound = false,
maximize = true,
)
# Since sorting for the first level is shared among all higher levels, we can precompute it
product_nstates = num_target(prob)

# For each higher-level state in the product space
@threadstid tid for I in CartesianIndices(product_nstates[2:end])
ws = workspace.thread_workspaces[tid]
sort_dense_orthogonal(ws, workspace.first_level_perm, V, I, upper_bound)
end

# For each source state
I_linear = LinearIndices(axes(V))
@threadstid tid for jₛ_cart in CartesianIndices(axes(V))
# We can't use @threadstid over a zip, so we need to manually index
jₛ_linear = I_linear[jₛ_cart]

ws = workspace.thread_workspaces[tid]

bellman_dense_orthogonal!(
ws,
workspace.first_level_perm,
strategy_cache,
Vres,
V,
prob,
stateptr,
product_nstates,
jₛ_cart,
jₛ_linear;
upper_bound = upper_bound,
maximize = maximize,
)
end

return Vres
end

function sort_dense_orthogonal(workspace, first_level_perm, V, I, upper_bound)
@inbounds begin
perm = @view workspace.permutation[axes(V, 1)]
sortperm!(perm, @view(V[:, I]); rev = upper_bound, scratch = workspace.scratch)

copyto!(@view(first_level_perm[:, I]), perm)
end
end

function bellman_dense_orthogonal!(
workspace,
first_level_perm,
strategy_cache::AbstractStrategyCache,
Vres,
V,
prob::OrthogonalIntervalProbabilities,
stateptr,
product_nstates,
jₛ_cart,
jₛ_linear;
upper_bound = false,
maximize = true,
)
@inbounds begin
s₁, s₂ = stateptr[jₛ_linear], stateptr[jₛ_linear + 1]
actions = @view workspace.actions[1:(s₂ - s₁)]
for (i, jₐ) in enumerate(s₁:(s₂ - 1))
Expand All @@ -328,7 +414,8 @@ function bellman!(

# For the first dimension, we need to copy the values from V
v = orthogonal_inner_sorted_bellman!(
@view(workspace.first_level_perm[:, I]),
# Use shared first level permutation across threads
@view(first_level_perm[:, I]),
@view(V[:, I]),
prob[1],
jₐ,
Expand Down Expand Up @@ -359,12 +446,10 @@ function bellman!(

Vres[jₛ_cart] = extract_strategy!(strategy_cache, actions, V, jₛ_cart, s₁, maximize)
end

return Vres
end

Base.@propagate_inbounds function orthogonal_inner_bellman!(
workspace::DenseOrthogonalWorkspace,
workspace::Union{DenseOrthogonalWorkspace, ThreadDenseOrthogonalWorkspace},
V,
prob,
jₐ,
Expand All @@ -390,3 +475,168 @@ Base.@propagate_inbounds function orthogonal_inner_sorted_bellman!(

return dot(V, lowerⱼ) + gap_value(V, gapⱼ, used, perm)
end

# Sparse orthogonal
function bellman!(
workspace::SparseOrthogonalWorkspace,
strategy_cache::AbstractStrategyCache,
Vres,
V,
prob::OrthogonalIntervalProbabilities,
stateptr;
upper_bound = false,
maximize = true,
)
# For each source state
@inbounds for (jₛ_cart, jₛ_linear) in
zip(CartesianIndices(axes(V)), LinearIndices(axes(V)))
bellman_sparse_orthogonal!(
workspace,
strategy_cache,
Vres,
V,
prob,
stateptr,
jₛ_cart,
jₛ_linear;
upper_bound = upper_bound,
maximize = maximize,
)
end

return Vres
end
function bellman!(
workspace::ThreadedSparseOrthogonalWorkspace,
strategy_cache::AbstractStrategyCache,
Vres,
V,
prob::OrthogonalIntervalProbabilities,
stateptr;
upper_bound = false,
maximize = true,
)
# For each source state
I_linear = LinearIndices(axes(V))
@threadstid tid for jₛ_cart in CartesianIndices(axes(V))
# We can't use @threadstid over a zip, so we need to manually index
jₛ_linear = I_linear[jₛ_cart]

ws = workspace.thread_workspaces[tid]

bellman_sparse_orthogonal!(
ws,
strategy_cache,
Vres,
V,
prob,
stateptr,
jₛ_cart,
jₛ_linear;
upper_bound = upper_bound,
maximize = maximize,
)
end

return Vres
end

function bellman_sparse_orthogonal!(
workspace,
strategy_cache::AbstractStrategyCache,
Vres,
V,
prob::OrthogonalIntervalProbabilities,
stateptr,
jₛ_cart,
jₛ_linear;
upper_bound = false,
maximize = true,
)
@inbounds begin
s₁, s₂ = stateptr[jₛ_linear], stateptr[jₛ_linear + 1]
actions = @view workspace.actions[1:(s₂ - s₁)]
for (i, jₐ) in enumerate(s₁:(s₂ - 1))
nzinds_first = SparseArrays.nonzeroinds(@view(gap(prob[1])[:, jₐ]))
nzinds_per_prob =
[SparseArrays.nonzeroinds(@view(gap(p)[:, jₐ])) for p in prob[2:end]]

lower_nzvals_per_prob = [nonzeros(@view(lower(p)[:, jₐ])) for p in prob]
gap_nzvals_per_prob = [nonzeros(@view(gap(p)[:, jₐ])) for p in prob]
sum_lower_per_prob = [sum_lower(p)[jₐ] for p in prob]

nnz_per_prob = Tuple(nnz(@view(gap(p)[:, jₐ])) for p in prob)
Vₑ = [
@view(cache[1:nnz]) for
(cache, nnz) in zip(workspace.expectation_cache, nnz_per_prob[2:end])
]

# For each higher-level state in the product space
for I in CartesianIndices(nnz_per_prob[2:end])
Isparse = CartesianIndex(Tuple(map(enumerate(Tuple(I))) do (d, i)
nzinds_per_prob[d][i]
end))

# For the first dimension, we need to copy the values from V
v = orthogonal_sparse_inner_bellman!(
workspace,
@view(V[nzinds_first, Isparse]),
lower_nzvals_per_prob[1],
gap_nzvals_per_prob[1],
sum_lower_per_prob[1],
upper_bound,
)
Vₑ[1][I[1]] = v

# For the remaining dimensions, if "full", compute expectation and store in the next level
for d in 2:(ndims(prob) - 1)
if I[d - 1] == nnz_per_prob[d]
v = orthogonal_sparse_inner_bellman!(
workspace,
Vₑ[d - 1],
lower_nzvals_per_prob[d],
gap_nzvals_per_prob[d],
sum_lower_per_prob[d],
upper_bound,
)
Vₑ[d][I[d]] = v
else
break
end
end
end

# Last dimension
v = orthogonal_sparse_inner_bellman!(
workspace,
Vₑ[end],
lower_nzvals_per_prob[end],
gap_nzvals_per_prob[end],
sum_lower_per_prob[end],
upper_bound,
)
actions[i] = v
end

Vres[jₛ_cart] = extract_strategy!(strategy_cache, actions, V, jₛ_cart, s₁, maximize)
end
end

Base.@propagate_inbounds function orthogonal_sparse_inner_bellman!(
workspace::SparseOrthogonalWorkspace,
V,
lower,
gap,
sum_lower,
upper_bound::Bool,
)
Vp_workspace = @view workspace.values_gaps[1:length(gap)]
for (i, (v, p)) in enumerate(zip(V, gap))
Vp_workspace[i] = (v, p)
end

# rev=true for upper bound
sort!(Vp_workspace; rev = upper_bound, scratch = workspace.scratch)

return dot(V, lower) + gap_value(Vp_workspace, sum_lower)
end
8 changes: 6 additions & 2 deletions src/interval_probabilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ sum_lower(p::IntervalProbabilities) = p.sum_lower
Return the number of source states or source/action pairs.
"""
num_source(p::IntervalProbabilities) = size(gap(p), 2)
source_shape(p::IntervalProbabilities) = (num_source(p),)

"""
axes_source(p::IntervalProbabilities)
Expand Down Expand Up @@ -220,7 +221,7 @@ target states along each axis.
### Fields
- `probs::NTuple{N, P}`: A tuple of `IntervalProbabilities` transition probabilities along each axis.
- `dims::NTuple{N, Int32}`: The dimensions of the orthogonal probabilities.
- `source_dims::NTuple{N, Int32}`: The dimensions of the orthogonal probabilities for the source axis. This is flattened to a single dimension for indexing.
### Examples
# TODO: Update example
Expand All @@ -229,7 +230,7 @@ target states along each axis.
struct OrthogonalIntervalProbabilities{N, P <: IntervalProbabilities} <:
AbstractIntervalProbabilities
probs::NTuple{N, P}
dims::NTuple{N, Int32}
source_dims::NTuple{N, Int32}
end

"""
Expand Down Expand Up @@ -272,6 +273,7 @@ sum_lower(p::OrthogonalIntervalProbabilities, i) = p.probs[i].sum_lower
Return the number of source states or source/action pairs.
"""
num_source(p::OrthogonalIntervalProbabilities) = num_source(first(p.probs))
source_shape(p::OrthogonalIntervalProbabilities) = p.source_dims

"""
axes_source(p::OrthogonalIntervalProbabilities)
Expand All @@ -288,3 +290,5 @@ Base.getindex(p::OrthogonalIntervalProbabilities, i) = p.probs[i]
Base.lastindex(p::OrthogonalIntervalProbabilities) = ndims(p)
Base.firstindex(p::OrthogonalIntervalProbabilities) = 1
Base.length(p::OrthogonalIntervalProbabilities) = ndims(p)
Base.iterate(p::OrthogonalIntervalProbabilities) = (p[1], 2)
Base.iterate(p::OrthogonalIntervalProbabilities, i) = i > ndims(p) ? nothing : (p[i], i + 1)
5 changes: 4 additions & 1 deletion src/strategy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ construct_strategy_cache(mp::IntervalMarkovProcess, config) =
# Strategy cache for not storing policies - useful for dispatching
struct NoStrategyCache <: AbstractStrategyCache end

function construct_strategy_cache(::IntervalProbabilities, ::NoStrategyConfig)
function construct_strategy_cache(
::Union{IntervalProbabilities, OrthogonalIntervalProbabilities},
::NoStrategyConfig,
)
return NoStrategyCache()
end

Expand Down
Loading

0 comments on commit f1ab0bf

Please sign in to comment.