From eb0bb38ca153e36d3dd3586b61740676d414b14a Mon Sep 17 00:00:00 2001 From: Joseph Tindall <51231103+JoeyT1994@users.noreply.github.com> Date: Fri, 24 Feb 2023 11:08:45 -0500 Subject: [PATCH] "Full update" gate application --- src/apply.jl | 219 ++++++++++++++++++++++++++++++++++++++-- src/tebd.jl | 12 ++- test/test_fullupdate.jl | 72 +++++++++++++ 3 files changed, 291 insertions(+), 12 deletions(-) create mode 100644 test/test_fullupdate.jl diff --git a/src/apply.jl b/src/apply.jl index b140a93d..4b153d15 100644 --- a/src/apply.jl +++ b/src/apply.jl @@ -5,6 +5,10 @@ function ITensors.apply( maxdim=nothing, normalize=false, ortho=false, + envs=ITensor[], + nfullupdatesweeps=10, + print_fidelity_loss=true, + envisposdef=false, ) ψ = copy(ψ) v⃗ = neighbor_vertices(ψ, o) @@ -25,16 +29,52 @@ function ITensors.apply( if ortho ψ = orthogonalize(ψ, v⃗[1]) end - oψᵥ = apply(o, ψ[v⃗[1]] * ψ[v⃗[2]]) - ψᵥ₁, ψᵥ₂ = factorize( - oψᵥ, inds(ψ[v⃗[1]]); cutoff, maxdim, tags=ITensorNetworks.edge_tag(e) - ) + + outer_dim_v1, outer_dim_v2 = dim(uniqueinds(ψ[v⃗[1]], o, ψ[v⃗[2]])), + dim(uniqueinds(ψ[v⃗[2]], o, ψ[v⃗[1]])) + dim_shared = dim(commoninds(ψ[v⃗[1]], ψ[v⃗[2]])) + d1, d2 = dim(commoninds(ψ[v⃗[1]], o)), dim(commoninds(ψ[v⃗[2]], o)) + if outer_dim_v1 * outer_dim_v2 <= dim_shared * dim_shared * d1 * d2 + Qᵥ₁, Rᵥ₁ = ITensor(true), copy(ψ[v⃗[1]]) + Qᵥ₂, Rᵥ₂ = ITensor(true), copy(ψ[v⃗[2]]) + else + Qᵥ₁, Rᵥ₁ = factorize( + ψ[v⃗[1]], uniqueinds(uniqueinds(ψ[v⃗[1]], ψ[v⃗[2]]), uniqueinds(ψ, v⃗[1])) + ) + Qᵥ₂, Rᵥ₂ = factorize( + ψ[v⃗[2]], uniqueinds(uniqueinds(ψ[v⃗[2]], ψ[v⃗[1]]), uniqueinds(ψ, v⃗[2])) + ) + end + + if !isempty(envs) + extended_envs = vcat(envs, Qᵥ₁, prime(dag(Qᵥ₁)), Qᵥ₂, prime(dag(Qᵥ₂))) + Rᵥ₁, Rᵥ₂ = optimise_p_q( + Rᵥ₁, + Rᵥ₂, + extended_envs, + o; + nfullupdatesweeps, + maxdim, + print_fidelity_loss, + envisposdef, + ) + else + Rᵥ₁, Rᵥ₂ = factorize( + apply(o, Rᵥ₁ * Rᵥ₂), inds(Rᵥ₁); cutoff, maxdim, tags=ITensorNetworks.edge_tag(e) + ) + end + + ψᵥ₁ = Qᵥ₁ * Rᵥ₁ + ψᵥ₂ = Qᵥ₂ * Rᵥ₂ + if normalize ψᵥ₁ ./= norm(ψᵥ₁) ψᵥ₂ ./= norm(ψᵥ₂) end + ψ[v⃗[1]] = ψᵥ₁ ψ[v⃗[2]] = ψᵥ₂ + elseif length(v⃗) < 1 error("Gate being applied does not share indices with tensor network.") elseif length(v⃗) > 2 @@ -50,6 +90,7 @@ function ITensors.apply( maxdim=typemax(Int), normalize=false, ortho=false, + kwargs..., ) o⃗ψ = ψ for oᵢ in o⃗ @@ -59,24 +100,182 @@ function ITensors.apply( end function ITensors.apply( - o⃗::Scaled, ψ::AbstractITensorNetwork; cutoff, maxdim, normalize=false, ortho=false + o⃗::Scaled, + ψ::AbstractITensorNetwork; + cutoff, + maxdim, + normalize=false, + ortho=false, + kwargs..., ) return maybe_real(Ops.coefficient(o⃗)) * - apply(Ops.argument(o⃗), ψ; cutoff, maxdim, normalize, ortho) + apply(Ops.argument(o⃗), ψ; cutoff, maxdim, normalize, ortho, kwargs...) end function ITensors.apply( - o⃗::Prod, ψ::AbstractITensorNetwork; cutoff, maxdim, normalize=false, ortho=false + o⃗::Prod, + ψ::AbstractITensorNetwork; + cutoff, + maxdim, + normalize=false, + ortho=false, + kwargs..., ) o⃗ψ = ψ for oᵢ in o⃗ - o⃗ψ = apply(oᵢ, o⃗ψ; cutoff, maxdim, normalize, ortho) + o⃗ψ = apply(oᵢ, o⃗ψ; cutoff, maxdim, normalize, ortho, kwargs...) end return o⃗ψ end function ITensors.apply( - o::Op, ψ::AbstractITensorNetwork; cutoff, maxdim, normalize=false, ortho=false + o::Op, ψ::AbstractITensorNetwork; cutoff, maxdim, normalize=false, ortho=false, kwargs... +) + return apply(ITensor(o, siteinds(ψ)), ψ; cutoff, maxdim, normalize, ortho, kwargs...) +end + +### Full Update Routines ### + +"""Calculate the overlap of the gate acting on the previous p and q versus the new p and q in the presence of environments. This is the cost function that optimise_p_q will minimise""" +function fidelity( + envs::Vector{ITensor}, + p_cur::ITensor, + q_cur::ITensor, + p_prev::ITensor, + q_prev::ITensor, + gate::ITensor, ) - return apply(ITensor(o, siteinds(ψ)), ψ; cutoff, maxdim, normalize, ortho) + p_sind, q_sind = commonind(p_cur, gate), commonind(q_cur, gate) + p_sind_sim, q_sind_sim = sim(p_sind), sim(q_sind) + gate_sq = + gate * replaceinds(dag(gate), Index[p_sind, q_sind], Index[p_sind_sim, q_sind_sim]) + term1_tns = vcat( + [ + p_prev, + q_prev, + replaceind(prime(dag(p_prev)), prime(p_sind), p_sind_sim), + replaceind(prime(dag(q_prev)), prime(q_sind), q_sind_sim), + gate_sq, + ], + envs, + ) + term1 = ITensors.contract( + term1_tns; sequence=ITensors.optimal_contraction_sequence(term1_tns) + ) + + term2_tns = vcat( + [ + p_cur, + q_cur, + replaceind(prime(dag(p_cur)), prime(p_sind), p_sind), + replaceind(prime(dag(q_cur)), prime(q_sind), q_sind), + ], + envs, + ) + term2 = ITensors.contract( + term2_tns; sequence=ITensors.optimal_contraction_sequence(term2_tns) + ) + term3_tns = vcat([p_prev, q_prev, prime(dag(p_cur)), prime(dag(q_cur)), gate], envs) + term3 = ITensors.contract( + term3_tns; sequence=ITensors.optimal_contraction_sequence(term3_tns) + ) + + f = term3[] / sqrt(term1[] * term2[]) + return f * conj(f) end + +"""Do Full Update Sweeping, Optimising the tensors p and q in the presence of the environments envs, +Specifically this functions find the p_cur and q_cur which optimise envs*gate*p*q*dag(prime(p_cur))*dag(prime(q_cur))""" +function optimise_p_q( + p::ITensor, + q::ITensor, + envs::Vector{ITensor}, + o::ITensor; + nfullupdatesweeps=10, + maxdim=nothing, + print_fidelity_loss=false, + envisposdef=true, +) + p_cur, q_cur = factorize(apply(o, p * q), inds(p); maxdim, tags=tags(commonind(p, q))) + + fstart = print_fidelity_loss ? fidelity(envs, p_cur, q_cur, p, q, o) : 0 + + qs_ind = setdiff(inds(q_cur), collect(Iterators.flatten(inds.(vcat(envs, p_cur))))) + ps_ind = setdiff(inds(p_cur), collect(Iterators.flatten(inds.(vcat(envs, q_cur))))) + + opt_b_seq = ITensors.optimal_contraction_sequence( + vcat(ITensor[p, q, o, dag(prime(q_cur))], envs) + ) + opt_b_tilde_seq = ITensors.optimal_contraction_sequence( + vcat(ITensor[p, q, o, dag(prime(p_cur))], envs) + ) + opt_M_seq = ITensors.optimal_contraction_sequence( + vcat(ITensor[q_cur, replaceinds(prime(dag(q_cur)), prime(qs_ind), qs_ind), p_cur], envs) + ) + opt_M_tilde_seq = ITensors.optimal_contraction_sequence( + vcat(ITensor[p_cur, replaceinds(prime(dag(p_cur)), prime(ps_ind), ps_ind), q_cur], envs) + ) + + function b( + p::ITensor, + q::ITensor, + o::ITensor, + envs::Vector{ITensor}, + r::ITensor; + opt_sequence=nothing, + ) + return noprime( + ITensors.contract(vcat(ITensor[p, q, o, dag(prime(r))], envs); sequence=opt_sequence) + ) + end + + function M_p( + envs::Vector{ITensor}, + p_q_tensor::ITensor, + s_ind, + apply_tensor::ITensor; + opt_sequence=nothing, + ) + return noprime( + ITensors.contract( + vcat( + ITensor[ + p_q_tensor, + replaceinds(prime(dag(p_q_tensor)), prime(s_ind), s_ind), + apply_tensor, + ], + envs, + ); + sequence=opt_sequence, + ), + ) + end + for i in 1:nfullupdatesweeps + b_vec = b(p, q, o, envs, q_cur; opt_sequence=opt_b_seq) + M_p_partial = partial(M_p, envs, q_cur, qs_ind; opt_sequence=opt_M_seq) + + p_cur, info = linsolve( + M_p_partial, b_vec, p_cur; isposdef=envisposdef, ishermitian=false + ) + + b_tilde_vec = b(p, q, o, envs, p_cur; opt_sequence=opt_b_tilde_seq) + M_p_tilde_partial = partial(M_p, envs, p_cur, ps_ind; opt_sequence=opt_M_tilde_seq) + + q_cur, info = linsolve( + M_p_tilde_partial, b_tilde_vec, q_cur; isposdef=envisposdef, ishermitian=false + ) + end + + fend = print_fidelity_loss ? fidelity(envs, p_cur, q_cur, p, q, o) : 0 + + diff = real(fend - fstart) + if print_fidelity_loss && diff < -eps(diff) && nfullupdatesweeps >= 1 + println( + "Warning: Krylov Solver Didn't Find a Better Solution by Sweeping. Something might be amiss.", + ) + end + + return p_cur, q_cur +end + +partial = (f, a...; c...) -> (b...) -> f(a..., b...; c...) diff --git a/src/tebd.jl b/src/tebd.jl index 155ec941..ebc1bc0d 100644 --- a/src/tebd.jl +++ b/src/tebd.jl @@ -1,5 +1,13 @@ function tebd( - ℋ::Sum, ψ::AbstractITensorNetwork; β, Δβ, maxdim, cutoff, print_frequency=10, ortho=false + ℋ::Sum, + ψ::AbstractITensorNetwork; + β, + Δβ, + maxdim, + cutoff, + print_frequency=10, + ortho=false, + kwargs..., ) 𝒰 = exp(-Δβ * ℋ; alg=Trotter{2}()) # Imaginary time evolution terms @@ -11,7 +19,7 @@ function tebd( @show step, (step - 1) * Δβ, β end ψ = insert_links(ψ) - ψ = apply(u⃗, ψ; cutoff, maxdim, normalize=true, ortho) + ψ = apply(u⃗, ψ; cutoff, maxdim, normalize=true, ortho, kwargs...) if ortho for v in vertices(ψ) ψ = orthogonalize(ψ, v) diff --git a/test/test_fullupdate.jl b/test/test_fullupdate.jl new file mode 100644 index 00000000..f7e50a9c --- /dev/null +++ b/test/test_fullupdate.jl @@ -0,0 +1,72 @@ +using ITensorNetworks +using ITensorNetworks: + compute_message_tensors, get_environment, nested_graph_leaf_vertices, contract_inner +using Test +using Compat +using ITensors +using Metis +using NamedGraphs +using Random +using LinearAlgebra +using SplitApplyCombine + +@testset "full_update" begin + Random.seed!(5623) + dims = (2, 3) + n = prod(dims) + g = named_grid(dims) + s = siteinds("S=1/2", g) + χ = 2 + ψ = randomITensorNetwork(s; link_space=χ) + v1, v2 = (2, 2), (1, 2) + + ψψ = ψ ⊗ prime(dag(ψ); sites=[]) + + #Simple Belief Propagation Grouping + vertex_groupsSBP = nested_graph_leaf_vertices( + partition(partition(ψψ, group(v -> v[1], vertices(ψψ))); nvertices_per_partition=1) + ) + mtsSBP = compute_message_tensors(ψψ; vertex_groups=vertex_groupsSBP) + envsSBP = get_environment(ψψ, mtsSBP, [(v1, 1), (v1, 2), (v2, 1), (v2, 2)]) + + #This grouping will correspond to calculating the environments exactly (each column of the grid is a partition) + vertex_groupsGBP = nested_graph_leaf_vertices( + partition(partition(ψψ, group(v -> v[1][1], vertices(ψψ))); nvertices_per_partition=1) + ) + mtsGBP = compute_message_tensors(ψψ; vertex_groups=vertex_groupsGBP) + envsGBP = get_environment(ψψ, mtsGBP, [(v1, 1), (v1, 2), (v2, 1), (v2, 2)]) + + ngates = 5 + + for i in 1:ngates + o = ITensors.op("RandomUnitary", s[v1]..., s[v2]...) + + ψOexact = apply(o, ψ; maxdim=4 * χ) + ψOSBP = apply( + o, + ψ; + envs=envsSBP, + maxdim=χ, + normalize=true, + print_fidelity_loss=true, + envisposdef=true, + ) + ψOGBP = apply( + o, + ψ; + envs=envsGBP, + maxdim=χ, + normalize=true, + print_fidelity_loss=true, + envisposdef=true, + ) + fSBP = + contract_inner(ψOSBP, ψOexact) / + sqrt(contract_inner(ψOexact, ψOexact) * contract_inner(ψOSBP, ψOSBP)) + fGBP = + contract_inner(ψOGBP, ψOexact) / + sqrt(contract_inner(ψOexact, ψOexact) * contract_inner(ψOGBP, ψOGBP)) + + @test real(fGBP * conj(fGBP)) >= real(fSBP * conj(fSBP)) + end +end