Skip to content

Commit

Permalink
Generalize belief propagation to tensor network message tensors (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 authored Apr 25, 2023
1 parent 04377fc commit ee90168
Show file tree
Hide file tree
Showing 9 changed files with 273 additions and 150 deletions.
83 changes: 62 additions & 21 deletions examples/belief_propagation/bpexample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@ using Random
using SplitApplyCombine

using ITensorNetworks:
compute_message_tensors, calculate_contraction, contract_inner, nested_graph_leaf_vertices
belief_propagation,
approx_network_region,
contract_inner,
message_tensors,
nested_graph_leaf_vertices

function main()
n = 4
dims = (n, n)
g = named_grid(dims)
g_dims = (n, n)
g = named_grid(g_dims)
s = siteinds("S=1/2", g)
chi = 2

Expand All @@ -26,43 +30,80 @@ function main()
v = (1, 1)

#Now do Simple Belief Propagation to Measure Sz on Site v
nsites = 1
mts = message_tensors(
ψψ; subgraph_vertices=collect(values(group(v -> v[1], vertices(ψψ))))
)

vertex_groups = nested_graph_leaf_vertices(
partition(partition(ψψ, group(v -> v[1], vertices(ψψ))); nvertices_per_partition=nsites)
mts = belief_propagation(ψψ, mts; contract_kwargs=(; alg="exact"))
numerator_network = approx_network_region(
ψψ, mts, [(v, 1)]; verts_tn=ITensorNetwork([apply(op("Sz", s[v]), ψ[v])])
)
mts = compute_message_tensors(ψψ; vertex_groups=vertex_groups)
sz_bp =
calculate_contraction(
ψψ, mts, [(v, 1)]; verts_tensors=ITensor[apply(op("Sz", s[v]), ψ[v])]
)[] / calculate_contraction(ψψ, mts, [(v, 1)])[]
denominator_network = approx_network_region(ψψ, mts, [(v, 1)])
sz_bp = contract(numerator_network)[] / contract(denominator_network)[]

println(
"Simple Belief Propagation Gives Sz on Site " * string(v) * " as " * string(sz_bp)
)

#Now do General Belief Propagation to Measure Sz on Site v
nsites = 4
vertex_groups = nested_graph_leaf_vertices(
partition(partition(ψψ, group(v -> v[1], vertices(ψψ))); nvertices_per_partition=nsites)
Zp = partition(
partition(ψψ, group(v -> v[1], vertices(ψψ))); nvertices_per_partition=nsites
)
Zpp = partition(ψψ; subgraph_vertices=nested_graph_leaf_vertices(Zp))
mts = message_tensors(Zpp)
mts = belief_propagation(ψψ, mts; contract_kwargs=(; alg="exact"))
numerator_network = approx_network_region(
ψψ, mts, [(v, 1)]; verts_tn=ITensorNetwork([apply(op("Sz", s[v]), ψ[v])])
)
mts = compute_message_tensors(ψψ; vertex_groups=vertex_groups)
sz_bp =
calculate_contraction(
ψψ, mts, [(v, 1)]; verts_tensors=ITensor[apply(op("Sz", s[v]), ψ[v])]
)[] / calculate_contraction(ψψ, mts, [(v, 1)])[]
denominator_network = approx_network_region(ψψ, mts, [(v, 1)])
sz_bp = contract(numerator_network)[] / contract(denominator_network)[]

println(
"General Belief Propagation (2-site subgraphs) Gives Sz on Site " *
"General Belief Propagation (4-site subgraphs) Gives Sz on Site " *
string(v) *
" as " *
string(sz_bp),
)

#Now do it exactly
#Now do General Belief Propagation with Matrix Product State Message Tensors Measure Sz on Site v
ψψ = flatten_networks(ψ, dag(ψ); combine_linkinds=false, map_bra_linkinds=prime)
= copy(ψ)
Oψ[v] = apply(op("Sz", s[v]), ψ[v])
sz_exact = contract_inner(Oψ, ψ) / contract_inner(ψ, ψ)
ψOψ = flatten_networks(ψ, dag(Oψ); combine_linkinds=false, map_bra_linkinds=prime)

combiners = linkinds_combiners(ψψ)
ψψ = combine_linkinds(ψψ, combiners)
ψOψ = combine_linkinds(ψOψ, combiners)

Z = partition(ψψ, group(v -> v[1], vertices(ψψ)))
maxdim = 8
mts = message_tensors(Z)

mts = belief_propagation(
ψψ,
mts;
contract_kwargs=(;
alg="density_matrix",
output_structure=path_graph_structure,
maxdim,
contraction_sequence_alg="optimal",
),
)

numerator_network = approx_network_region(ψψ, mts, [v]; verts_tn=ITensorNetwork(ψOψ[v]))
denominator_network = approx_network_region(ψψ, mts, [v])
sz_bp = contract(numerator_network)[] / contract(denominator_network)[]

println(
"General Belief Propagation with Column Partitioning and MPS Message Tensors (Max dim 8) Gives Sz on Site " *
string(v) *
" as " *
string(sz_bp),
)

#Now do it exactly
sz_exact = contract(ψOψ)[] / contract(ψψ)[]

return println("The exact value of Sz on Site " * string(v) * " is " * string(sz_exact))
end
Expand Down
18 changes: 18 additions & 0 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,14 @@ function linkinds(tn::AbstractITensorNetwork, edge)
return commoninds(tn, edge)
end

function internalinds(tn::AbstractITensorNetwork)
return unique(flatten([commoninds(tn, e) for e in edges(tn)]))
end

function externalinds(tn::AbstractITensorNetwork)
return unique(flatten([uniqueinds(tn, e) for e in edges(tn)]))
end

# Priming and tagging (changing Index identifiers)
function replaceinds(tn::AbstractITensorNetwork, is_is′::Pair{<:IndsNetwork,<:IndsNetwork})
tn = copy(tn)
Expand Down Expand Up @@ -819,6 +827,16 @@ function insert_missing_internal_inds(
return insert_internal_inds(tn, edges(tn); internal_inds_space)
end

function ITensors.commoninds(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork)
inds = Index[]
for v1 in vertices(tn1)
for v2 in vertices(tn2)
append!(inds, commoninds(tn1[v1], tn2[v2]))
end
end
return inds
end

## # TODO: should this make sure that internal indices
## # don't clash?
## function hvncat(
Expand Down
1 change: 1 addition & 0 deletions src/apply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ function ITensors.apply(
)
end

envs = Vector{ITensor}(envs)
if !isempty(envs)
extended_envs = vcat(envs, Qᵥ₁, prime(dag(Qᵥ₁)), Qᵥ₂, prime(dag(Qᵥ₂)))
Rᵥ₁, Rᵥ₂ = optimise_p_q(
Expand Down
150 changes: 71 additions & 79 deletions src/beliefpropagation.jl
Original file line number Diff line number Diff line change
@@ -1,99 +1,110 @@
function construct_initial_mts(
tn::ITensorNetwork, nvertices_per_partition::Integer; partition_kwargs=(;), kwargs...
function message_tensors(
tn::ITensorNetwork;
nvertices_per_partition=nothing,
npartitions=nothing,
subgraph_vertices=nothing,
kwargs...,
)
return construct_initial_mts(
tn, partition(tn; nvertices_per_partition, partition_kwargs...); kwargs...
return message_tensors(
partition(tn; nvertices_per_partition, npartitions, subgraph_vertices); kwargs...
)
end

function construct_initial_mts(
tn::ITensorNetwork, subgraphs::DataGraph; init=(I...) -> @compat allequal(I) ? 1 : 0
function message_tensors(
subgraphs::DataGraph; itensor_constructor=inds_e -> dense(delta(inds_e))
)
# TODO: This is dropping the vertex data for some reason.
# mts = DataGraph{vertextype(subgraphs),vertex_data_type(subgraphs),ITensor}(subgraphs)
mts = DataGraph{vertextype(subgraphs),vertex_data_type(subgraphs),ITensor}(
mts = DataGraph{vertextype(subgraphs),vertex_data_type(subgraphs),ITensorNetwork}(
directed_graph(underlying_graph(subgraphs))
)
for v in vertices(mts)
mts[v] = subgraphs[v]
end
for subgraph in vertices(subgraphs)
tns_to_contract = ITensor[]
for subgraph_neighbor in neighbors(subgraphs, subgraph)
edge_inds = Index[]
for vertex in vertices(subgraphs[subgraph])
psiv = tn[vertex]
for e in [edgetype(tn)(vertex => neighbor) for neighbor in neighbors(tn, vertex)]
if (find_subgraph(dst(e), subgraphs) == subgraph_neighbor)
append!(edge_inds, commoninds(tn, e))
end
end
end
mt = normalize!(
itensor(
[init(Tuple(I)...) for I in CartesianIndices(tuple(dim.(edge_inds)...))],
edge_inds,
),
)
mts[subgraph => subgraph_neighbor] = mt
end
for e in edges(subgraphs)
inds_e = commoninds(subgraphs[src(e)], subgraphs[dst(e)])
mts[e] = ITensorNetwork(map(itensor_constructor, inds_e))
mts[reverse(e)] = dag(mts[e])
end
return mts
end

"""
DO a single update of a message tensor using the current subgraph and the incoming mts
"""
function update_mt(
function update_message_tensor(
tn::ITensorNetwork,
subgraph_vertices::Vector,
mts::Vector{ITensor};
contraction_sequence::Function=tn -> contraction_sequence(tn; alg="optimal"),
mts::Vector{ITensorNetwork};
contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
)
contract_list = [mts; [tn[v] for v in subgraph_vertices]]
contract_list = ITensorNetwork[mts; ITensorNetwork([tn[v] for v in subgraph_vertices])]

new_mt = if isone(length(contract_list))
tn = if isone(length(contract_list))
copy(only(contract_list))
else
contract(contract_list; sequence=contraction_sequence(contract_list))
reduce(, contract_list)
end
return normalize!(new_mt)

contract_output = contract(tn; contract_kwargs...)
itn = if typeof(contract_output) == ITensor
ITensorNetwork(contract_output)
else
first(contract_output)
end
normalize!.(vertex_data(itn))

return itn
end

function update_mt(
tn::ITensorNetwork, subgraph::ITensorNetwork, mts::Vector{ITensor}; kwargs...
function update_message_tensor(
tn::ITensorNetwork, subgraph::ITensorNetwork, mts::Vector{ITensorNetwork}; kwargs...
)
return update_mt(tn, vertices(subgraph), mts; kwargs...)
return update_message_tensor(tn, vertices(subgraph), mts; kwargs...)
end

"""
Do an update of all message tensors for a given ITensornetwork and its partition into sub graphs
"""
function update_all_mts(
function belief_propagation_iteration(
tn::ITensorNetwork,
mts::DataGraph;
contraction_sequence::Function=tn -> contraction_sequence(tn; alg="optimal"),
contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
)
update_mts = copy(mts)
new_mts = copy(mts)
for e in edges(mts)
environment_tensors = ITensor[
environment_tensornetworks = ITensorNetwork[
mts[e_in] for e_in in setdiff(boundary_edges(mts, src(e); dir=:in), [reverse(e)])
]
update_mts[src(e) => dst(e)] = update_mt(
tn, mts[src(e)], environment_tensors; contraction_sequence

new_mts[src(e) => dst(e)] = update_message_tensor(
tn, mts[src(e)], environment_tensornetworks; contract_kwargs
)
end
return update_mts
return new_mts
end

function update_all_mts(
function belief_propagation(
tn::ITensorNetwork,
mts::DataGraph,
niters::Int;
contraction_sequence::Function=tn -> contraction_sequence(tn; alg="optimal"),
mts::DataGraph;
contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
niters=20,
)
for i in 1:niters
mts = update_all_mts(tn, mts; contraction_sequence)
mts = belief_propagation_iteration(tn, mts; contract_kwargs)
end
return mts
end

function belief_propagation(
tn::ITensorNetwork;
contract_kwargs=(; alg="density_matrix", output_structure=path_graph_structure, maxdim=1),
nvertices_per_partition=nothing,
npartitions=nothing,
subgraph_vertices=nothing,
niters=20,
)
mts = message_tensors(tn; nvertices_per_partition, npartitions, subgraph_vertices)
for i in 1:niters
mts = belief_propagation_iteration(tn, mts; contract_kwargs)
end
return mts
end
Expand All @@ -109,43 +120,24 @@ function get_environment(tn::ITensorNetwork, mts::DataGraph, verts::Vector; dir=
return get_environment(tn, mts, setdiff(vertices(tn), verts))
end

env_tensors = ITensor[mts[e] for e in boundary_edges(mts, subgraphs; dir=:in)]
return vcat(
env_tensors,
ITensor[tn[v] for v in setdiff(flatten([vertices(mts[s]) for s in subgraphs]), verts)],
)
env_tns = ITensorNetwork[mts[e] for e in boundary_edges(mts, subgraphs; dir=:in)]
central_tn = ITensorNetwork([
tn[v] for v in setdiff(flatten([vertices(mts[s]) for s in subgraphs]), verts)
])
return ITensorNetwork(vcat(env_tns, ITensorNetwork[central_tn]))
end

"""
Calculate the contraction of a tensor network centred on the vertices verts. Using message tensors.
Defaults to using tn[verts] as the local network but can be overriden
"""
function calculate_contraction(
function approx_network_region(
tn::ITensorNetwork,
mts::DataGraph,
verts::Vector;
verts_tensors=ITensor[tn[v] for v in verts],
contraction_sequence::Function=tn -> contraction_sequence(tn; alg="optimal"),
verts_tn=ITensorNetwork([tn[v] for v in verts]),
)
environment_tensors = get_environment(tn, mts, verts)
tensors_to_contract = vcat(environment_tensors, verts_tensors)
return contract(tensors_to_contract; sequence=contraction_sequence(tensors_to_contract))
end
environment_tn = get_environment(tn, mts, verts)

"""
Simulaneously initialise and update message tensors of a tensornetwork
"""
function compute_message_tensors(
tn::ITensorNetwork;
niters=10,
nvertices_per_partition=nothing,
npartitions=nothing,
vertex_groups=nothing,
kwargs...,
)
Z = partition(tn; nvertices_per_partition, npartitions, subgraph_vertices=vertex_groups)

mts = construct_initial_mts(tn, Z; kwargs...)
mts = update_all_mts(tn, mts, niters)
return mts
return environment_tn verts_tn
end
Loading

0 comments on commit ee90168

Please sign in to comment.