Skip to content

Commit

Permalink
Environments (#145)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 authored Mar 21, 2024
1 parent a4f3592 commit 0477e9d
Show file tree
Hide file tree
Showing 11 changed files with 134 additions and 63 deletions.
1 change: 1 addition & 0 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ include(joinpath("solvers", "contract.jl"))
include(joinpath("solvers", "linsolve.jl"))
include(joinpath("solvers", "sweep_plans", "sweep_plans.jl"))
include("apply.jl")
include("environment.jl")

include("exports.jl")

Expand Down
29 changes: 19 additions & 10 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
default_message(inds_e) = ITensor[denseblocks(delta(inds_e))]
default_messages(ptn::PartitionedGraph) = Dictionary()
function default_message_update(contract_list::Vector{ITensor}; kwargs...)
return contract_exact(contract_list; kwargs...)
sequence = optimal_contraction_sequence(contract_list)
updated_messages = contract(contract_list; sequence, kwargs...)
updated_messages /= norm(updated_messages)
return ITensor[updated_messages]
end
default_message_update_kwargs() = (; normalize=true, contraction_sequence_alg="optimal")
@traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing
@traitfn function default_bp_maxiter(g::::IsDirected)
return default_bp_maxiter(undirected_graph(underlying_graph(g)))
end
default_partitioned_vertices::AbstractITensorNetwork) = group(v -> v, vertices(ψ))
default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5)

function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
lhs, rhs = contract(message_a), contract(message_b)
return 0.5 *
Expand All @@ -27,11 +32,15 @@ function BeliefPropagationCache(
return BeliefPropagationCache(ptn, messages, default_message)
end

function BeliefPropagationCache(tn::ITensorNetwork, partitioned_vertices; kwargs...)
function BeliefPropagationCache(tn, partitioned_vertices; kwargs...)
ptn = PartitionedGraph(tn, partitioned_vertices)
return BeliefPropagationCache(ptn; kwargs...)
end

function BeliefPropagationCache(tn; kwargs...)
return BeliefPropagationCache(tn, default_partitioning(tn); kwargs...)
end

function partitioned_itensornetwork(bp_cache::BeliefPropagationCache)
return bp_cache.partitioned_itensornetwork
end
Expand Down Expand Up @@ -92,7 +101,7 @@ function set_messages(cache::BeliefPropagationCache, messages)
)
end

function incoming_messages(
function environment(
bp_cache::BeliefPropagationCache,
partition_vertices::Vector{<:PartitionVertex};
ignore_edges=PartitionEdge[],
Expand All @@ -102,15 +111,15 @@ function incoming_messages(
return reduce(vcat, ms; init=[])
end

function incoming_messages(
function environment(
bp_cache::BeliefPropagationCache, partition_vertex::PartitionVertex; kwargs...
)
return incoming_messages(bp_cache, [partition_vertex]; kwargs...)
return environment(bp_cache, [partition_vertex]; kwargs...)
end

function incoming_messages(bp_cache::BeliefPropagationCache, verts::Vector)
function environment(bp_cache::BeliefPropagationCache, verts::Vector)
partition_verts = partitionvertices(bp_cache, verts)
messages = incoming_messages(bp_cache, partition_verts)
messages = environment(bp_cache, partition_verts)
central_tensors = ITensor[
tensornetwork(bp_cache)[v] for v in setdiff(vertices(bp_cache, partition_verts), verts)
]
Expand All @@ -129,10 +138,10 @@ function update_message(
bp_cache::BeliefPropagationCache,
edge::PartitionEdge;
message_update=default_message_update,
message_update_kwargs=default_message_update_kwargs(),
message_update_kwargs=(;),
)
vertex = src(edge)
messages = incoming_messages(bp_cache, vertex; ignore_edges=PartitionEdge[reverse(edge)])
messages = environment(bp_cache, vertex; ignore_edges=PartitionEdge[reverse(edge)])
state = factor(bp_cache, vertex)

return message_update(ITensor[messages; state]; message_update_kwargs...)
Expand Down
14 changes: 0 additions & 14 deletions src/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,3 @@ function contract_density_matrix(
end
return out
end

function contract_exact(
contract_list::Vector{ITensor};
contraction_sequence_alg="optimal",
normalize=true,
contractor_kwargs...,
)
seq = contraction_sequence(contract_list; alg=contraction_sequence_alg)
out = ITensors.contract(contract_list; sequence=seq, contractor_kwargs...)
if normalize
normalize!(out)
end
return ITensor[out]
end
42 changes: 42 additions & 0 deletions src/environment.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
default_environment_algorithm() = "exact"

function environment(
ψ::AbstractITensorNetwork,
vertices::Vector;
alg=default_environment_algorithm(),
kwargs...,
)
return environment(Algorithm(alg), ψ, vertices; kwargs...)
end

function environment(
::Algorithm"exact",
ψ::AbstractITensorNetwork,
verts::Vector;
contraction_sequence_alg="optimal",
kwargs...,
)
ψ_reduced = Vector{ITensor}(subgraph(ψ, setdiff(vertices(ψ), verts)))
sequence = contraction_sequence(ψ_reduced; alg=contraction_sequence_alg)
return ITensor[contract(ψ_reduced; sequence, kwargs...)]
end

function environment(
::Algorithm"bp",
ψ::AbstractITensorNetwork,
vertices::Vector;
(cache!)=nothing,
partitioned_vertices=default_partitioned_vertices(ψ),
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(cache!),
)
if isnothing(cache!)
cache! = Ref(BeliefPropagationCache(ψ, partitioned_vertices))
end

if update_cache
cache![] = update(cache![]; cache_update_kwargs...)
end

return environment(cache![], vertices)
end
46 changes: 29 additions & 17 deletions src/formnetworks/abstractformnetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,22 @@ function ket_vertices(f::AbstractFormNetwork)
return filter(v -> last(v) == ket_vertex_suffix(f), vertices(f))
end

function bra_ket_vertices(f::AbstractFormNetwork)
return vcat(bra_vertices(f), ket_vertices(f))
function bra_vertices(f::AbstractFormNetwork, original_state_vertices::Vector)
return [bra_vertex_map(f)(osv) for osv in original_state_vertices]
end

function bra_vertices(f::AbstractFormNetwork, state_vertices::Vector)
return [bra_vertex_map(f)(sv) for sv in state_vertices]
function ket_vertices(f::AbstractFormNetwork, original_state_vertices::Vector)
return [ket_vertex_map(f)(osv) for osv in original_state_vertices]
end

function ket_vertices(f::AbstractFormNetwork, state_vertices::Vector)
return [ket_vertex_map(f)(sv) for sv in state_vertices]
function state_vertices(f::AbstractFormNetwork)
return vcat(bra_vertices(f), ket_vertices(f))
end

function bra_ket_vertices(f::AbstractFormNetwork, state_vertices::Vector)
return vcat(bra_vertices(f, state_vertices), ket_vertices(f, state_vertices))
function state_vertices(f::AbstractFormNetwork, original_state_vertices::Vector)
return vcat(
bra_vertices(f, original_state_vertices), ket_vertices(f, original_state_vertices)
)
end

function Graphs.induced_subgraph(f::AbstractFormNetwork, vertices::Vector)
Expand All @@ -57,18 +59,28 @@ function operator_network(f::AbstractFormNetwork)
)
end

function derivative(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
tn_vertices = derivative_vertices(f, state_vertices)
return derivative(tensornetwork(f), tn_vertices; kwargs...)
end

function derivative_vertices(f::AbstractFormNetwork, state_vertices::Vector; kwargs...)
return setdiff(
vertices(f), vcat(bra_vertices(f, state_vertices), ket_vertices(f, state_vertices))
)
function environment(
f::AbstractFormNetwork,
original_state_vertices::Vector;
alg=default_environment_algorithm(),
kwargs...,
)
form_vertices = state_vertices(f, original_state_vertices)
if alg == "bp"
partitioned_vertices = group(v -> original_state_vertex(f, v), vertices(f))
return environment(
tensornetwork(f), form_vertices; alg, partitioned_vertices, kwargs...
)
else
return environment(tensornetwork(f), form_vertices; alg, kwargs...)
end
end

operator_vertex_map(f::AbstractFormNetwork) = v -> (v, operator_vertex_suffix(f))
bra_vertex_map(f::AbstractFormNetwork) = v -> (v, bra_vertex_suffix(f))
ket_vertex_map(f::AbstractFormNetwork) = v -> (v, ket_vertex_suffix(f))
inv_vertex_map(f::AbstractFormNetwork) = v -> first(v)
operator_vertex(f::AbstractFormNetwork, v) = operator_vertex_map(f)(v)
bra_vertex(f::AbstractFormNetwork, v) = bra_vertex_map(f)(v)
ket_vertex(f::AbstractFormNetwork, v) = ket_vertex_map(f)(v)
original_state_vertex(f::AbstractFormNetwork, v) = inv_vertex_map(f)(v)
10 changes: 7 additions & 3 deletions src/formnetworks/bilinearformnetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,15 @@ function BilinearFormNetwork(
end

function update(
blf::BilinearFormNetwork, state_vertex, bra_state::ITensor, ket_state::ITensor
blf::BilinearFormNetwork, original_state_vertex, bra_state::ITensor, ket_state::ITensor
)
blf = copy(blf)
# TODO: Maybe add a check that it really does preserve the graph.
setindex_preserve_graph!(tensornetwork(blf), bra_state, bra_vertex_map(blf)(state_vertex))
setindex_preserve_graph!(tensornetwork(blf), ket_state, ket_vertex_map(blf)(state_vertex))
setindex_preserve_graph!(
tensornetwork(blf), bra_state, bra_vertex(blf, original_state_vertex)
)
setindex_preserve_graph!(
tensornetwork(blf), ket_state, ket_vertex(blf, original_state_vertex)
)
return blf
end
4 changes: 2 additions & 2 deletions src/formnetworks/quadraticformnetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ function QuadraticFormNetwork(
return QuadraticFormNetwork(blf, dual_index_map, dual_inv_index_map)
end

function update(qf::QuadraticFormNetwork, state_vertex, ket_state::ITensor)
function update(qf::QuadraticFormNetwork, original_state_vertex, ket_state::ITensor)
state_inds = inds(ket_state)
bra_state = replaceinds(dag(ket_state), state_inds, dual_index_map(qf).(state_inds))
new_blf = update(bilinear_formnetwork(qf), state_vertex, bra_state, ket_state)
new_blf = update(bilinear_formnetwork(qf), original_state_vertex, bra_state, ket_state)
return QuadraticFormNetwork(new_blf, dual_index_map(qf), dual_index_map(qf))
end
1 change: 0 additions & 1 deletion src/gauging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ function default_norm_cache(ψ::ITensorNetwork)
ψψ = norm_network(ψ)
return BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
end
default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5)

function ITensorNetwork(
ψ_vidal::VidalITensorNetwork; (cache!)=nothing, update_gauge=false, update_kwargs...
Expand Down
6 changes: 3 additions & 3 deletions test/test_apply.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using ITensorNetworks
using ITensorNetworks:
incoming_messages,
environment,
update,
contract_inner,
norm_network,
Expand Down Expand Up @@ -29,14 +29,14 @@ using SplitApplyCombine
#Simple Belief Propagation Grouping
bp_cache = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bp_cache = update(bp_cache; maxiter=20)
envsSBP = incoming_messages(bp_cache, PartitionVertex.([v1, v2]))
envsSBP = environment(bp_cache, PartitionVertex.([v1, v2]))

ψv = VidalITensorNetwork(ψ)

#This grouping will correspond to calculating the environments exactly (each column of the grid is a partition)
bp_cache = BeliefPropagationCache(ψψ, group(v -> v[1][1], vertices(ψψ)))
bp_cache = update(bp_cache; maxiter=20)
envsGBP = incoming_messages(bp_cache, [(v1, 1), (v1, 2), (v2, 1), (v2, 2)])
envsGBP = environment(bp_cache, [(v1, 1), (v1, 2), (v2, 1), (v2, 2)])

ngates = 5

Expand Down
13 changes: 7 additions & 6 deletions test/test_belief_propagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ using ITensorNetworks:
tensornetwork,
update,
update_factor,
incoming_messages
environment,
contract
using Test
using Compat
using ITensors
Expand Down Expand Up @@ -40,7 +41,7 @@ ITensors.disable_warn_order()

bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bpc = update(bpc)
env_tensors = incoming_messages(bpc, [PartitionVertex(v)])
env_tensors = environment(bpc, [PartitionVertex(v)])
numerator = contract(vcat(env_tensors, ITensor[ψ[v], op("Sz", s[v]), dag(prime(ψ[v]))]))[]
denominator = contract(vcat(env_tensors, ITensor[ψ[v], op("I", s[v]), dag(prime(ψ[v]))]))[]

Expand Down Expand Up @@ -70,7 +71,7 @@ ITensors.disable_warn_order()

bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bpc = update(bpc)
env_tensors = incoming_messages(bpc, [PartitionVertex(v)])
env_tensors = environment(bpc, [PartitionVertex(v)])
numerator = contract(vcat(env_tensors, ITensor[ψ[v], op("Sz", s[v]), dag(prime(ψ[v]))]))[]
denominator = contract(vcat(env_tensors, ITensor[ψ[v], op("I", s[v]), dag(prime(ψ[v]))]))[]

Expand All @@ -93,7 +94,7 @@ ITensors.disable_warn_order()
bpc = BeliefPropagationCache(ψψ, group(v -> v[1], vertices(ψψ)))
bpc = update(bpc; maxiter=20)

env_tensors = incoming_messages(bpc, vs)
env_tensors = environment(bpc, vs)
numerator = contract(vcat(env_tensors, ITensor[ψOψ[v] for v in vs]))[]
denominator = contract(vcat(env_tensors, ITensor[ψψ[v] for v in vs]))[]

Expand All @@ -112,7 +113,7 @@ ITensors.disable_warn_order()
bpc = update(bpc; maxiter=20)

ψψsplit = split_index(ψψ, NamedEdge.([(v, 1) => (v, 2) for v in vs]))
env_tensors = incoming_messages(bpc, [(v, 2) for v in vs])
env_tensors = environment(bpc, [(v, 2) for v in vs])
rdm = ITensors.contract(
vcat(env_tensors, ITensor[ψψsplit[vp] for vp in [(v, 2) for v in vs]])
)
Expand Down Expand Up @@ -148,7 +149,7 @@ ITensors.disable_warn_order()
message_update_kwargs=(; cutoff=1e-6, maxdim=4),
)

env_tensors = incoming_messages(bpc, [v])
env_tensors = environment(bpc, [v])
numerator = contract(vcat(env_tensors, ITensor[ψOψ[v]]))[]
denominator = contract(vcat(env_tensors, ITensor[ψψ[v]]))[]

Expand Down
31 changes: 24 additions & 7 deletions test/test_forms.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
using ITensors
using Graphs
using Graphs: nv
using NamedGraphs
using ITensorNetworks
using ITensorNetworks:
delta_network,
update,
tensornetwork,
bra_vertex_map,
ket_vertex_map,
bra_vertex,
ket_vertex,
dual_index_map,
bra_network,
ket_network,
operator_network
operator_network,
environment,
BeliefPropagationCache
using Test
using Random
using SplitApplyCombine

@testset "FormNetworkss" begin
@testset "FormNetworks" begin
g = named_grid((1, 4))
s_ket = siteinds("S=1/2", g)
s_bra = prime(s_ket; links=[])
Expand All @@ -42,10 +45,24 @@ using Random
new_tensor = randomITensor(inds(ψket[v]))
qf_updated = update(qf, v, copy(new_tensor))

@test tensornetwork(qf_updated)[bra_vertex_map(qf_updated)(v)]
@test tensornetwork(qf_updated)[bra_vertex(qf_updated, v)]
dual_index_map(qf_updated)(dag(new_tensor))
@test tensornetwork(qf_updated)[ket_vertex_map(qf_updated)(v)] new_tensor
@test tensornetwork(qf_updated)[ket_vertex(qf_updated, v)] new_tensor

@test underlying_graph(ket_network(qf)) == underlying_graph(ψket)
@test underlying_graph(operator_network(qf)) == underlying_graph(A)

∂qf_∂v = only(environment(qf, [v]))
@test (∂qf_∂v) * (qf[ket_vertex(qf, v)] * qf[bra_vertex(qf, v)]) contract(qf)

∂qf_∂v_bp = environment(qf, [v]; alg="bp", update_cache=false)
∂qf_∂v_bp = contract(∂qf_∂v_bp)
∂qf_∂v_bp /= norm(∂qf_∂v_bp)
∂qf_∂v /= norm(∂qf_∂v)
@test ∂qf_∂v_bp != ∂qf_∂v

∂qf_∂v_bp = environment(qf, [v]; alg="bp", update_cache=true)
∂qf_∂v_bp = contract(∂qf_∂v_bp)
∂qf_∂v_bp /= norm(∂qf_∂v_bp)
@test ∂qf_∂v_bp ∂qf_∂v
end

0 comments on commit 0477e9d

Please sign in to comment.