Skip to content

Commit

Permalink
add(::ITensorNetwork, ::ITensorNetwork) with directsum backend (#110
Browse files Browse the repository at this point in the history
)
  • Loading branch information
JoeyT1994 authored Aug 11, 2023
1 parent 77ec3cb commit fd19eef
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 6 deletions.
77 changes: 73 additions & 4 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -591,27 +591,34 @@ function neighbor_vertices(ψ::AbstractITensorNetwork, T::ITensor)
return first.(v⃗)
end

function linkinds_combiners(tn::AbstractITensorNetwork)
function linkinds_combiners(tn::AbstractITensorNetwork; edges=edges(tn))
combiners = DataGraph(directed_graph(underlying_graph(tn)), ITensor, ITensor)
for e in edges(tn)
for e in edges
C = combiner(linkinds(tn, e); tags=edge_tag(e))
combiners[e] = C
combiners[reverse(e)] = dag(C)
end
return combiners
end

function combine_linkinds(tn::AbstractITensorNetwork, combiners=linkinds_combiners(tn))
function combine_linkinds(tn::AbstractITensorNetwork, combiners)
combined_tn = copy(tn)
for e in edges(tn)
if !isempty(linkinds(tn, e))
if !isempty(linkinds(tn, e)) && haskey(edge_data(combiners), e)
combined_tn[src(e)] = combined_tn[src(e)] * combiners[e]
combined_tn[dst(e)] = combined_tn[dst(e)] * combiners[reverse(e)]
end
end
return combined_tn
end

function combine_linkinds(
tn::AbstractITensorNetwork; edges::Vector{<:Union{Pair,AbstractEdge}}=edges(tn)
)
combiners = linkinds_combiners(tn; edges)
return combine_linkinds(tn, combiners)
end

function split_index(
tn::AbstractITensorNetwork,
edges_to_split;
Expand Down Expand Up @@ -870,6 +877,68 @@ function ITensors.commoninds(tn1::AbstractITensorNetwork, tn2::AbstractITensorNe
return inds
end

"""Check if the edge of an itensornetwork has multiple indices"""
is_multi_edge(tn::AbstractITensorNetwork, e) = length(linkinds(tn, e)) > 1
is_multi_edge(tn::AbstractITensorNetwork) = Base.Fix1(is_multi_edge, tn)

"""Add two itensornetworks together by growing the bond dimension. The network structures need to be have the same vertex names, same site index on each vertex """
function add(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork)
@assert issetequal(vertices(tn1), vertices(tn2))

tn1 = combine_linkinds(tn1; edges=filter(is_multi_edge(tn1), edges(tn1)))
tn2 = combine_linkinds(tn2; edges=filter(is_multi_edge(tn2), edges(tn2)))

edges_tn1, edges_tn2 = edges(tn1), edges(tn2)

if !issetequal(edges_tn1, edges_tn2)
new_edges = union(edges_tn1, edges_tn2)
tn1 = insert_missing_internal_inds(tn1, new_edges)
tn2 = insert_missing_internal_inds(tn2, new_edges)
end

edges_tn1, edges_tn2 = edges(tn1), edges(tn2)
@assert issetequal(edges_tn1, edges_tn2)

tn12 = copy(tn1)
new_edge_indices = Dict(
zip(
edges_tn1,
[
Index(
dim(only(linkinds(tn1, e))) + dim(only(linkinds(tn2, e))),
tags(only(linkinds(tn1, e))),
) for e in edges_tn1
],
),
)

#Create vertices of tn12 as direct sum of tn1[v] and tn2[v]. Work out the matching indices by matching edges. Make index tags those of tn1[v]
for v in vertices(tn1)
@assert siteinds(tn1, v) == siteinds(tn2, v)

e1_v = filter(x -> src(x) == v || dst(x) == v, edges_tn1)
e2_v = filter(x -> src(x) == v || dst(x) == v, edges_tn2)

@assert issetequal(e1_v, e2_v)
tn1v_linkinds = Index[only(linkinds(tn1, e)) for e in e1_v]
tn2v_linkinds = Index[only(linkinds(tn2, e)) for e in e1_v]
tn12v_linkinds = Index[new_edge_indices[e] for e in e1_v]

@assert length(tn1v_linkinds) == length(tn2v_linkinds)

tn12[v] = ITensors.directsum(
tn12v_linkinds,
tn1[v] => Tuple(tn1v_linkinds),
tn2[v] => Tuple(tn2v_linkinds);
tags=tags.(Tuple(tn1v_linkinds)),
)
end

return tn12
end

+(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork) = add(tn1, tn2)

## # TODO: should this make sure that internal indices
## # don't clash?
## function hvncat(
Expand Down
7 changes: 5 additions & 2 deletions src/imports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import Base:
isapprox,
isassigned,
iterate,
union
union,
+

import NamedGraphs:
vertextype,
Expand Down Expand Up @@ -92,7 +93,9 @@ import ITensors:
nsite,
# promotion and conversion
promote_itensor_eltype,
scalartype
scalartype,
#adding
add

using ITensors.ContractionSequenceOptimization: deepmap

Expand Down
77 changes: 77 additions & 0 deletions test/test_additensornetworks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
using ITensorNetworks
using ITensorNetworks: inner_network
using Test
using Compat
using ITensors
using Metis
using NamedGraphs
using NamedGraphs: hexagonal_lattice_graph, rem_edge!
using Random
using LinearAlgebra
using SplitApplyCombine

using Random

@testset "add_itensornetworks" begin
Random.seed!(5623)
g = named_grid((2, 3))
s = siteinds("S=1/2", g)
ψ1 = ITensorNetwork(s, v -> "")
ψ2 = ITensorNetwork(s, v -> "")

ψ_GHZ = ψ1 + ψ2

v = (2, 2)
Oψ_GHZ = copy(ψ_GHZ)
Oψ_GHZ[v] = apply(op("Sz", s[v]), Oψ_GHZ[v])

ψψ_GHZ = inner_network(ψ_GHZ, ψ_GHZ)
ψOψ_GHZ = inner_network(ψ_GHZ, Oψ_GHZ)

@test ITensors.contract(ψOψ_GHZ)[] / ITensors.contract(ψψ_GHZ)[] == 0.0

χ = 3
g = hexagonal_lattice_graph(1, 2)

s1 = siteinds("S=1/2", g)
s2 = copy(s1)
rem_edge!(s2, NamedEdge((1, 1) => (1, 2)))

v = rand(vertices(g))
ψ1 = randomITensorNetwork(s1; link_space=χ)
ψ2 = randomITensorNetwork(s2; link_space=χ)

ψ12 = ψ1 + ψ2

Oψ12 = copy(ψ12)
Oψ12[v] = apply(op("Sz", s1[v]), Oψ12[v])

Oψ1 = copy(ψ1)
Oψ1[v] = apply(op("Sz", s1[v]), Oψ1[v])

Oψ2 = copy(ψ2)
Oψ2[v] = apply(op("Sz", s2[v]), Oψ2[v])

ψψ_12 = inner_network(ψ12, ψ12)
ψOψ_12 = inner_network(ψ12, Oψ12)

ψ1ψ2 = inner_network(ψ1, ψ2)
ψ1Oψ2 = inner_network(ψ1, Oψ2)

ψψ_2 = inner_network(ψ2, ψ2)
ψOψ_2 = inner_network(ψ2, Oψ2)

ψψ_1 = inner_network(ψ1, ψ1)
ψOψ_1 = inner_network(ψ1, Oψ1)

expec_method1 =
(
ITensors.contract(ψOψ_1)[] +
ITensors.contract(ψOψ_2)[] +
2 * ITensors.contract(ψ1Oψ2)[]
) /
(ITensors.contract(ψψ_1)[] + ITensors.contract(ψψ_2)[] + 2 * ITensors.contract(ψ1ψ2)[])
expec_method2 = ITensors.contract(ψOψ_12)[] / ITensors.contract(ψψ_12)[]

@test expec_method1 expec_method2
end

0 comments on commit fd19eef

Please sign in to comment.