Skip to content

Commit

Permalink
Make bellman! for SparseOrthogonalWorkspace type stable
Browse files Browse the repository at this point in the history
  • Loading branch information
Zinoex committed Sep 26, 2024
1 parent e10b8ef commit 18712b3
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/bellman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ function bellman_sparse_orthogonal!(
gap_nzvals_per_prob = [nonzeros(@view(gap(p)[:, jₐ])) for p in prob]
sum_lower_per_prob = [sum_lower(p)[jₐ] for p in prob]

nnz_per_prob = Tuple(nnz(@view(gap(p)[:, jₐ])) for p in prob)
nnz_per_prob = ntuple(i -> nnz(@view(gap(prob[i])[:, jₐ])), ndims(prob))
Vₑ = [
@view(cache[1:nnz]) for
(cache, nnz) in zip(workspace.expectation_cache, nnz_per_prob[2:end])
Expand All @@ -591,9 +591,7 @@ function bellman_sparse_orthogonal!(
else
# For each higher-level state in the product space
for I in CartesianIndices(nnz_per_prob[2:end])
Isparse = CartesianIndex(Tuple(map(enumerate(Tuple(I))) do (d, i)
nzinds_per_prob[d][i]
end))
Isparse = CartesianIndex(ntuple(d -> nzinds_per_prob[d][I[d]], ndims(prob) - 1))

# For the first dimension, we need to copy the values from V
v = orthogonal_sparse_inner_bellman!(
Expand Down

0 comments on commit 18712b3

Please sign in to comment.