diff --git a/src/moment_matrix.jl b/src/moment_matrix.jl index e4a1110..018222d 100644 --- a/src/moment_matrix.jl +++ b/src/moment_matrix.jl @@ -1,4 +1,4 @@ -export SymMatrix, MomentMatrix, getmat, moment_matrix +export SymMatrix, MomentMatrix, getmat, moment_matrix, symmetric_setindex! using SemialgebraicSets @@ -29,8 +29,17 @@ end Base.size(Q::SymMatrix) = (Q.n, Q.n) +""" + symmetric_setindex!(Q::SymMatrix, value, i::Integer, j::Integer) + +Set `Q[i, j]` and `Q[j, i]` to the value `value`. +""" +function symmetric_setindex!(Q::SymMatrix, value, i::Integer, j::Integer) + Q.Q[trimap(max(i, j), min(i, j))] = value +end + function Base.getindex(Q::SymMatrix, i::Integer, j::Integer) - Q.Q[trimap(max(i, j), min(i, j))] + return Q.Q[trimap(max(i, j), min(i, j))] end Base.getindex(Q::SymMatrix, I::Tuple) = Q[I...] Base.getindex(Q::SymMatrix, I::CartesianIndex) = Q[I.I] diff --git a/test/moment_matrix.jl b/test/moment_matrix.jl index c6d6be8..8b28805 100644 --- a/test/moment_matrix.jl +++ b/test/moment_matrix.jl @@ -1,3 +1,13 @@ +@testset "SymMatrix" begin + Q = SymMatrix([1, 2, 3], 2) + symmetric_setindex!(Q, 4, 1, 1) + @test Q.Q == [4, 2, 3] + symmetric_setindex!(Q, 5, 1, 2) + @test Q.Q == [4, 5, 3] + symmetric_setindex!(Q, 6, 2, 2) + @test Q.Q == [4, 5, 6] +end + @testset "MomentMatrix" begin Mod.@polyvar x y @test_throws ArgumentError moment_matrix(measure([1], [x]), [y])