From c6cd869250746fb04105b5c0c7945d9bd25ad477 Mon Sep 17 00:00:00 2001 From: Xuanzhao Gao <45324209+ArrogantGao@users.noreply.github.com> Date: Tue, 30 Jan 2024 23:29:57 +0800 Subject: [PATCH] add load tn from code (#90) --- src/Core.jl | 15 +++++++++++++++ test/map.jl | 12 ++++++++++++ 2 files changed, 27 insertions(+) diff --git a/src/Core.jl b/src/Core.jl index 092b499..3651228 100644 --- a/src/Core.jl +++ b/src/Core.jl @@ -175,6 +175,21 @@ function TensorNetworkModel( TensorNetworkModel(collect(LT, vars), code, tensors, evidence, mars) end +""" +$(TYPEDSIGNATURES) +""" +function TensorNetworkModel( + model::UAIModel{T}, code; + evidence = Dict{Int,Int}(), + mars = [[i] for i=1:model.nvars], + vars = [1:model.nvars...] +)::TensorNetworkModel where{T} + @debug "constructing tensor network model from code" + tensors = Array{T}[[ones(T, [model.cards[i] for i in mar]...) for mar in mars]..., [t.vals for t in model.factors]...] + + return TensorNetworkModel(vars, code, tensors, evidence, mars) +end + """ $(TYPEDSIGNATURES) diff --git a/test/map.jl b/test/map.jl index 1195254..abc0c98 100644 --- a/test/map.jl +++ b/test/map.jl @@ -2,6 +2,18 @@ using Test using OMEinsum using TensorInference +@testset "load from code" begin + model = problem_from_artifact("uai2014", "MAR", "Promedus", 14) + + tn1 = TensorNetworkModel(read_model(model); + evidence=read_evidence(model), + optimizer = TreeSA(ntrials = 3, niters = 2, βs = 1:0.1:80)) + + tn2 = TensorNetworkModel(read_model(model), tn1.code, evidence=read_evidence(model)) + + @test tn1.code == tn2.code +end + @testset "gradient-based tensor network solvers" begin model = problem_from_artifact("uai2014", "MAR", "Promedus", 14)