Skip to content

Commit

Permalink
Efficient probability check
Browse files Browse the repository at this point in the history
  • Loading branch information
Zinoex committed Apr 24, 2024
1 parent 3788292 commit 84a97b3
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/interval_probabilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ end

# Constructor from lower and gap with sanity assertions
function IntervalProbabilities(lower::MR, gap::MR) where {R, MR <: AbstractMatrix{R}}
@assert all(lower .>= 0) "The lower bound transition probabilities must be non-negative."
@assert all(gap .>= 0) "The gap transition probabilities must be non-negative."
@assert all(gap .<= 1) "The gap transition probabilities must be less than or equal to 1."
@assert all(lower .+ gap .<= 1) "The sum of lower and gap transition probabilities must be less than or equal to 1."
checkprobabilities!(lower, gap)

sum_lower = vec(sum(lower; dims = 1))

Expand All @@ -65,6 +62,20 @@ function IntervalProbabilities(lower::MR, gap::MR) where {R, MR <: AbstractMatri
return IntervalProbabilities(lower, gap, sum_lower)
end

function checkprobabilities!(lower::AbstractMatrix, gap::AbstractMatrix)
@assert all(lower .>= 0) "The lower bound transition probabilities must be non-negative."
@assert all(gap .>= 0) "The gap transition probabilities must be non-negative."
@assert all(gap .<= 1) "The gap transition probabilities must be less than or equal to 1."
@assert all(lower .+ gap .<= 1) "The sum of lower and gap transition probabilities must be less than or equal to 1."
end

function checkprobabilities!(lower::AbstractSparseMatrix, gap::AbstractSparseMatrix)
@assert all(nonzeros(lower) .>= 0) "The lower bound transition probabilities must be non-negative."
@assert all(nonzeros(gap) .>= 0) "The gap transition probabilities must be non-negative."
@assert all(nonzeros(gap) .<= 1) "The gap transition probabilities must be less than or equal to 1."
@assert all(nonzeros(lower) .+ nonzeros(gap) .<= 1) "The sum of lower and gap transition probabilities must be less than or equal to 1."
end

# Keyword constructor from lower and upper
function IntervalProbabilities(; lower::MR, upper::MR) where {MR <: AbstractMatrix}
lower, gap = compute_gap(lower, upper)
Expand Down

0 comments on commit 84a97b3

Please sign in to comment.