Skip to content

Commit

Permalink
Make sample type agnostic and GPU compatible (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
tipfom authored Nov 15, 2024
1 parent 8896bb1 commit f5dee09
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,8 @@ function sample(rng::AbstractRNG, m::MPS)
error("sample: MPS is not normalized, norm=$(norm(m[1]))")
end

ElT = scalartype(m)

result = zeros(Int, N)
A = m[1]

Expand All @@ -664,16 +666,16 @@ function sample(rng::AbstractRNG, m::MPS)
# Compute the probability of each state
# one-by-one and stop when the random
# number r is below the total prob so far
pdisc = 0.0
pdisc = zero(real(ElT))
r = rand(rng)
# Will need n,An, and pn below
n = 1
An = ITensor()
pn = 0.0
pn = zero(real(ElT))
while n <= d
projn = ITensor(s)
projn[s => n] = 1.0
An = A * dag(projn)
projn[s => n] = one(ElT)
An = A * dag(adapt(datatype(A), projn))
pn = real(scalar(dag(An) * An))
pdisc += pn
(r < pdisc) && break
Expand All @@ -682,7 +684,7 @@ function sample(rng::AbstractRNG, m::MPS)
result[j] = n
if j < N
A = m[j + 1] * An
A *= (1.0 / sqrt(pn))
A *= (one(ElT) / sqrt(pn))
end
end
return result
Expand Down Expand Up @@ -749,7 +751,7 @@ function correlation_matrix(
end_site = last(sites)

N = length(psi)
ElT = promote_itensor_eltype(psi)
ElT = scalartype(psi)
s = siteinds(psi)

Op1 = _Op1 #make copies into which we can insert "F" string operators, and then restore.
Expand Down Expand Up @@ -983,7 +985,7 @@ updens, dndens = expect(psi, "Nup", "Ndn") # pass more than one operator
function expect(psi::MPS, ops; sites=1:length(psi), site_range=nothing)
psi = copy(psi)
N = length(psi)
ElT = promote_itensor_eltype(psi)
ElT = scalartype(psi)
s = siteinds(psi)

if !isnothing(site_range)
Expand Down

0 comments on commit f5dee09

Please sign in to comment.