Skip to content

Commit

Permalink
Merge pull request #69 from Zinoex/fm/mixture_probabilities
Browse files Browse the repository at this point in the history
Mixture model
  • Loading branch information
Zinoex authored Oct 20, 2024
2 parents f4e8241 + eef5719 commit 6ead194
Show file tree
Hide file tree
Showing 21 changed files with 1,152 additions and 252 deletions.
6 changes: 6 additions & 0 deletions docs/src/reference/systems.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,23 @@ stateptr(mdp::IntervalMarkovDecisionProcess)
OrthogonalIntervalMarkovChain
OrthogonalIntervalMarkovDecisionProcess
stateptr(mdp::OrthogonalIntervalMarkovDecisionProcess)
MixtureIntervalMarkovChain
MixtureIntervalMarkovDecisionProcess
stateptr(mdp::MixtureIntervalMarkovDecisionProcess)
```

## Probability representation
```@docs
IntervalProbabilities
OrthogonalIntervalProbabilities
MixtureIntervalProbabilities
lower
upper
gap
sum_lower
num_source
num_target
axes_source
mixture_probs
weighting_probs
```
7 changes: 0 additions & 7 deletions ext/cuda/strategy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,6 @@ function IntervalMDP.construct_action_cache(
return CUDA.zeros(Int32, dims)
end

function IntervalMDP.construct_action_cache(
::OrthogonalIntervalProbabilities{N, <:IntervalProbabilities{R, VR}},
dims,
) where {N, R <: Real, VR <: AbstractGPUVector{R}}
return CUDA.zeros(Int32, dims)
end

abstract type ActiveCache end

struct NoStrategyActiveCache <: ActiveCache end
Expand Down
17 changes: 3 additions & 14 deletions src/IntervalMDP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,14 @@ include("utils.jl")
include("errors.jl")
export InvalidStateError, StateDimensionMismatch

include("interval_probabilities.jl")
export IntervalProbabilities, OrthogonalIntervalProbabilities
export lower, upper, gap, sum_lower
export num_source, axes_source, num_target, axes_target

include("models/IntervalMarkovProcess.jl")
include("models/IntervalMarkovDecisionProcess.jl")
include("models/OrthogonalIntervalMarkovDecisionProcess.jl")
export IntervalMarkovProcess
export AllStates
export IntervalMarkovDecisionProcess, IntervalMarkovChain
export OrthogonalIntervalMarkovDecisionProcess, OrthogonalIntervalMarkovChain
export transition_prob, num_states, initial_states, stateptr, tomarkovchain, time_length
include("probabilities/probabilities.jl")
include("models/models.jl")

include("strategy.jl")
export GivenStrategyConfig,
NoStrategyConfig, TimeVaryingStrategyConfig, StationaryStrategyConfig
export StationaryStrategy, TimeVaryingStrategy
export construct_strategy_cache
export construct_strategy_cache, time_length

include("specification.jl")
export Property, LTLFormula, LTLfFormula, PCTLFormula
Expand Down
115 changes: 83 additions & 32 deletions src/bellman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ bellman_precomputation!(
) = nothing

function state_bellman!(
workspace,
workspace::Union{DenseWorkspace, SparseWorkspace},
strategy_cache::OptimizingStrategyCache,
Vres,
V,
Expand All @@ -198,7 +198,7 @@ function state_bellman!(
end

function state_bellman!(
workspace,
workspace::Union{DenseWorkspace, SparseWorkspace},
strategy_cache::NonOptimizingStrategyCache,
Vres,
V,
Expand Down Expand Up @@ -226,11 +226,8 @@ Base.@propagate_inbounds function state_action_bellman(
end

Base.@propagate_inbounds function dense_sorted_state_action_bellman(V, prob, jₐ, perm)
lowerⱼ = @view lower(prob)[:, jₐ]
gapⱼ = @view gap(prob)[:, jₐ]
used = sum_lower(prob)[jₐ]

return dot(V, lowerⱼ) + gap_value(V, gapⱼ, used, perm)
return dot(V, lower(prob, :, jₐ)) +
gap_value(V, gap(prob, :, jₐ), sum_lower(prob, jₐ), perm)
end

Base.@propagate_inbounds function gap_value(
Expand Down Expand Up @@ -262,8 +259,8 @@ Base.@propagate_inbounds function state_action_bellman(
jₐ,
upper_bound,
)
lowerⱼ = @view lower(prob)[:, jₐ]
gapⱼ = @view gap(prob)[:, jₐ]
lowerⱼ = lower(prob, :, jₐ)
gapⱼ = gap(prob, :, jₐ)
used = sum_lower(prob)[jₐ]

Vp_workspace = @view workspace.values_gaps[1:nnz(gapⱼ)]
Expand Down Expand Up @@ -298,14 +295,12 @@ end
################################################################
# Bellman operator for OrthogonalIntervalMarkovDecisionProcess #
################################################################

# Dense orthogonal
function bellman!(
workspace::Union{DenseOrthogonalWorkspace, SparseOrthogonalWorkspace},
workspace::Union{DenseOrthogonalWorkspace, SparseOrthogonalWorkspace, MixtureWorkspace},
strategy_cache::AbstractStrategyCache,
Vres,
V,
prob::OrthogonalIntervalProbabilities,
prob,
stateptr;
upper_bound = false,
maximize = true,
Expand Down Expand Up @@ -335,11 +330,15 @@ function bellman!(
end

function bellman!(
workspace::Union{ThreadedDenseOrthogonalWorkspace, ThreadedSparseOrthogonalWorkspace},
workspace::Union{
ThreadedDenseOrthogonalWorkspace,
ThreadedSparseOrthogonalWorkspace,
ThreadedMixtureWorkspace,
},
strategy_cache::AbstractStrategyCache,
Vres,
V,
prob::OrthogonalIntervalProbabilities,
prob,
stateptr;
upper_bound = false,
maximize = true,
Expand Down Expand Up @@ -414,11 +413,11 @@ function sort_dense_orthogonal(workspace, V, I, upper_bound)
end

function state_bellman!(
workspace,
workspace::Union{DenseOrthogonalWorkspace, SparseOrthogonalWorkspace, MixtureWorkspace},
strategy_cache::OptimizingStrategyCache,
Vres,
V,
prob::OrthogonalIntervalProbabilities,
prob,
stateptr,
jₛ_cart,
jₛ_linear;
Expand All @@ -427,22 +426,22 @@ function state_bellman!(
)
@inbounds begin
s₁, s₂ = stateptr[jₛ_linear], stateptr[jₛ_linear + 1]
actions = @view workspace.actions[1:(s₂ - s₁)]
act_vals = @view actions(workspace)[1:(s₂ - s₁)]

for (i, jₐ) in enumerate(s₁:(s₂ - 1))
actions[i] = state_action_bellman(workspace, V, prob, jₐ, upper_bound)
act_vals[i] = state_action_bellman(workspace, V, prob, jₐ, upper_bound)
end

Vres[jₛ_cart] = extract_strategy!(strategy_cache, actions, V, jₛ_cart, maximize)
Vres[jₛ_cart] = extract_strategy!(strategy_cache, act_vals, V, jₛ_cart, maximize)
end
end

function state_bellman!(
workspace,
workspace::Union{DenseOrthogonalWorkspace, SparseOrthogonalWorkspace, MixtureWorkspace},
strategy_cache::NonOptimizingStrategyCache,
Vres,
V,
prob::OrthogonalIntervalProbabilities,
prob,
stateptr,
jₛ_cart,
jₛ_linear;
Expand Down Expand Up @@ -513,7 +512,7 @@ Base.@propagate_inbounds function state_action_bellman(
end

Base.@propagate_inbounds function orthogonal_inner_bellman!(
workspace::DenseOrthogonalWorkspace,
workspace,
V,
prob,
jₐ,
Expand All @@ -535,17 +534,15 @@ Base.@propagate_inbounds function state_action_bellman(
upper_bound,
)
# This function uses ntuple extensively to avoid allocations (a list comprehension allocates, while ntuple does not)
nzinds_first = SparseArrays.nonzeroinds(@view(gap(prob[1])[:, jₐ]))
nzinds_per_prob = ntuple(
i -> SparseArrays.nonzeroinds(@view(gap(prob[i + 1])[:, jₐ])),
ndims(prob) - 1,
)
nzinds_first = SparseArrays.nonzeroinds(gap(prob, 1, :, jₐ))
nzinds_per_prob =
ntuple(i -> SparseArrays.nonzeroinds(gap(prob, i + 1, :, jₐ)), ndims(prob) - 1)

lower_nzvals_per_prob = ntuple(i -> nonzeros(@view(lower(prob[i])[:, jₐ])), ndims(prob))
gap_nzvals_per_prob = ntuple(i -> nonzeros(@view(gap(prob[i])[:, jₐ])), ndims(prob))
sum_lower_per_prob = ntuple(i -> sum_lower(prob[i])[jₐ], ndims(prob))
lower_nzvals_per_prob = ntuple(i -> nonzeros(lower(prob, i, :, jₐ)), ndims(prob))
gap_nzvals_per_prob = ntuple(i -> nonzeros(gap(prob, i, :, jₐ)), ndims(prob))
sum_lower_per_prob = ntuple(i -> sum_lower(prob, i, jₐ), ndims(prob))

nnz_per_prob = ntuple(i -> nnz(@view(gap(prob[i])[:, jₐ])), ndims(prob))
nnz_per_prob = ntuple(i -> nnz(gap(prob, i, :, jₐ)), ndims(prob))
Vₑ = ntuple(
i -> @view(workspace.expectation_cache[i][1:nnz_per_prob[i + 1]]),
ndims(prob) - 1,
Expand Down Expand Up @@ -627,3 +624,57 @@ Base.@propagate_inbounds function orthogonal_sparse_inner_bellman!(

return dot(V, lower) + gap_value(Vp_workspace, sum_lower)
end

################################################################
# Bellman operator for MixtureIntervalMarkovDecisionProcess     #
################################################################
# Precomputation for a `MixtureWorkspace`: forward the call unchanged to the
# method for the wrapped `orthogonal_workspace` and return its result.
function bellman_precomputation!(workspace::MixtureWorkspace, V, prob, upper_bound)
    inner_workspace = workspace.orthogonal_workspace
    return bellman_precomputation!(inner_workspace, V, prob, upper_bound)
end

# Precomputation for a threaded mixture workspace built on dense orthogonal
# workspaces. The sorting for the first level is shared among all higher levels,
# so it is carried out once here for every higher-level index.
function bellman_precomputation!(
    workspace::ThreadedMixtureWorkspace{<:DenseOrthogonalWorkspace},
    V,
    prob,
    upper_bound,
)
    target_dims = num_target(prob)

    # Visit every index over all but the first entry of `target_dims`; each
    # iteration sorts into the orthogonal workspace selected by `tid`.
    @threadstid tid for I in CartesianIndices(target_dims[2:end])
        sort_dense_orthogonal(workspace[tid].orthogonal_workspace, V, I, upper_bound)
    end
end

# Threaded mixture workspaces built on sparse orthogonal workspaces perform no
# precomputation; this method is a no-op that returns `nothing`.
function bellman_precomputation!(
    workspace::ThreadedMixtureWorkspace{<:SparseOrthogonalWorkspace},
    V,
    prob,
    upper_bound,
)
    return nothing
end

# Bellman value of source-action pair `jₐ` for a mixture model.
#
# Step 1: evaluate `state_action_bellman` on the wrapped orthogonal workspace for
#         each element obtained by iterating `prob`, storing the k-th result in
#         `workspace.mixture_cache[k]`.
# Step 2: pass the cached values, in place of `V`, together with
#         `weighting_probs(prob)` to `orthogonal_inner_bellman!` and return its result.
Base.@propagate_inbounds function state_action_bellman(
    workspace::MixtureWorkspace,
    V,
    prob,
    jₐ,
    upper_bound,
)
    component_values = workspace.mixture_cache
    inner_workspace = workspace.orthogonal_workspace

    # One value per model in the mixture
    for (k, component) in enumerate(prob)
        component_values[k] =
            state_action_bellman(inner_workspace, V, component, jₐ, upper_bound)
    end

    # Combine the per-model values using the weighting probabilities
    return orthogonal_inner_bellman!(
        workspace,
        component_values,
        weighting_probs(prob),
        jₐ,
        upper_bound,
    )
end
Loading

0 comments on commit 6ead194

Please sign in to comment.