Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update 'Wasserstein_distances.jl' #20

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/Eirene.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Eirene. If not, see <http://www.gnu.org/licenses/>.
# along with Eirene. If not, swasserstein_distance([0 1], [3 5; 7 9], p=Inf, q=1)ee <http://www.gnu.org/licenses/>.
#
# PLEASE HELP US DOCUMENT Eirene's recent work! Bibtex entries and
# contact information for teaching and outreach can be found at the
Expand Down Expand Up @@ -7512,7 +7512,7 @@ function unittest()

numits = 5
maxdim = 2
x = Array{Any}(undef,22)
x = Array{Any}(undef,24)

x[1] = eirenevrVperseusvr() # correct answer: empty
x[2] = eirenevrVeirenepc(numits,maxdim) # correct answer: empty
Expand All @@ -7536,6 +7536,8 @@ function unittest()
x[20] = wd_test_1() # correct answer: empty
x[21] = wd_test_2() # correct answer: empty
x[22] = wd_test_3() # correct answer: empty
x[23] = wd_test_4() # correct answer: empty
x[24] = wd_test_5() # correct answer: empty

for p = 1:length(x)
if !isempty(x[p])
Expand Down
217 changes: 116 additions & 101 deletions src/wasserstein_distances.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function pad(u1,u2)
@assert size(u1)[2] == size(u2)[2] == 2

#need transpose as sometimes a 1D vector
n1 = size(u1)[1]
n1 = size(u1)[1]
n2 = size(u2)[1]
# note total n = n1 + n2
v1 = vcat(u1, zeros(n2,2))
Expand All @@ -31,16 +31,16 @@ function pad(u1,u2)
z = (v2[i,1]+v2[i,2])/2
v1[n1+i,1] = z
v1[n1+i,2] = z
end
end
################
for i = 1:n1
for i = 1:n1

z = (v1[i,1]+v1[i,2])/2
z = (v1[i,1]+v1[i,2])/2

v2[n2+i,1] = z
v2[n2+i,2] = z
v2[n2+i,1] = z
v2[n2+i,2] = z

end
end

return v1,v2,n1,n2
end
Expand All @@ -61,9 +61,9 @@ function dist_mat(v1,v2,n1,n2; p = 2)
#if l1 compute here in faster way.
if p == 1
for i = 1:n
for j in 1:n
cost[i,j] = abs(v1[i,1]-v2[j,1]) + abs(v1[i,2] - v2[j,2])
end
for j in 1:n
cost[i,j] = abs(v1[i,1]-v2[j,1]) + abs(v1[i,2] - v2[j,2])
end
end

elseif p == Inf
Expand All @@ -74,23 +74,23 @@ function dist_mat(v1,v2,n1,n2; p = 2)
end
else
for i = 1:n
for j in 1:n
cost[i,j] = ((abs(v1[i,1]-v)^p)+ abs(v1[i,2]-v2[j,2])^p)^(1/p)
for j in 1:n
cost[i,j] = ((abs(v1[i,1]-v2[j,1])^p)+ abs(v1[i,2]-v2[j,2])^p)^(1/p)
end
end
end

end


#set distance between diagonal points to be 0.
#this could just not be calculated if not using broadcast.
cost[(n-n2+1):n,(n-n1+1):n] = zeros(n2,n1)

cost[(n-n2+1):n,(n-n1+1):n] = zeros(n2,n1)
#print("cost matrix is ", cost,"\n")

return cost

end

function dist_inf(v1,v2,p=2)
function dist_inf(v1,v2)
#= else

takes in two vectors with all y points at infinity.
Expand All @@ -107,112 +107,106 @@ function dist_inf(v1,v2,p=2)
n = size(v1)[1]
cost = zeros(n,n)

if p == 1
for i = 1:n
for j in 1:n
cost[i,j] = abs(v1[i,1]-v2[j,1]) + abs(v1[i,2] - v2[j,2])
for j in 1:n
cost[i,j] = abs(v1[i,1]-v2[j,1]) + abs(v1[i,2] - v2[j,2])
end
end

elseif p == Inf
for i = 1:n
for j in 1:n
cost[i,j] = maximum(broadcast(abs,v1[i,:]-v2[j,:]))
end
end
else
for i = 1:n
for j in 1:n
cost[i,j] = ((abs(v1[i,1]-v)^p)+ abs(v1[i,2]-v2[j,2])^p)^(1/p)
end
end

#calculate cost matrix for only x co-ordinates
#for i = 1:n
# cost[i,:] = broadcast(abs,broadcast(-, v1[im1], v2[:,1]))
#end

return cost, hungarian(cost)[1]

end
#end
# elseif p == Inf
# for i = 1:n
# for j in 1:n
# cost[i,j] = maximum(broadcast(abs,v1[i,:]-v2[j,:]))
# end
# end
# else
# for i = 1:n
# for j in 1:n
# cost[i,j] = ((abs(v1[i,1]-v2[i,1])^p) )^(1/p)
# end
# end
# end
return cost, hungarian(cost)[1]
end
end
############# Main function #############

function wasserstein_distance(dgm1,dgm2; p = 2,q=p)

u1 = transpose(dgm1)
u2 = transpose(dgm2)

#=
takes two (possibly unequal size) vectors and calculates the W_(q,p)distance between their persistence diagrams. The default is that q=p=2
Can calculate lp distance between diagrams, l1 should be the fastest.
u1 = dgm1
u2 = dgm2
#=
takes two (possibly unequal size) vectors and calculates the W_(q,p)distance between their persistence diagrams. The default is that q=p=2
Can calculate lp distance between diagrams, l1 should be the fastest.
Can handle values of Inf in vectors.
=#

#if no Inf is present in either vector calculate as normal.
if all(i->(i!=Inf), u1) && all(i->(i!=Inf), u2)
=#
#if no Inf is present in either vector calculate as normal.
if all(i->(i!=Inf), u1) && all(i->(i!=Inf), u2)
v1,v2,n1,n2 = pad(u1,u2)

cost = dist_mat(v1,v2,n1,n2,p=p)
assignment = hungarian(cost)[1]

if q == Inf
values = [cost[i, assignment[i]] for i in 1:(n1+n2)]
distance = maximum(values)
return distance

else
distance = 0
for i in 1:length(assignment)
distance += cost[i, assignment[i]]^(q)
# I forgot that whilst Eirene processeprs points as nd x np matrices, the barcodes are np x 2
distance += cost[i, assignment[i]]^(q)
end
return distance^(1/q)

end




#if there are equal amounts of infinity calculate possibly finite distance.
elseif sum(u1[:,2] .== Inf) == sum(u2[:,2] .== Inf)

#get the number of infinities.
N_inf = sum(u1[:, 2] .== Inf)
#sort vectors by incresum(broadcast(abs,broadcast(-, v1[:,i], v2)),dims = 1)asing amount in y component.
u_sort_1 = u1[:, sortperm(u1[:,2], rev = true)]
u_sort_2 = u2[:, sortperm(u2[:,2], rev = true)]
#split into infinity part and finite part
u_sort_1_1 = u_sort_1[1:N_inf,:]
u_sort_2_1 = u_sort_2[1:N_inf,:]
u_sort_1_2 = u_sort_1[(1+N_inf):end,:]
u_sort_2_2 = u_sort_2[(1+N_inf):end,:]

#calculate infinite cost.
#if there are equal amounts of infinity calculate possibly finite distance.
elseif sum(u1[:,2] .== Inf) == sum(u2[:,2] .== Inf)
#get the number of infinities.
N_inf = sum(u1[:, 2] .== Inf)
#sort vectors by incresum(broadcast(abs,broadcast(-, v1[:,i], v2)),dims = 1)asing amount in y component.
u_sort_1 = u1[:, sortperm(u1[:,2], rev = true)]
u_sort_2 = u2[:, sortperm(u2[:,2], rev = true)]
#split into infinity part and finite part
u_sort_1_1 = u_sort_1[1:N_inf,:]
u_sort_2_1 = u_sort_2[1:N_inf,:]
u_sort_1_2 = u_sort_1[(1+N_inf):end,:]
u_sort_2_2 = u_sort_2[(1+N_inf):end,:]
#calculate infinite cost.
cost, assignment_inf = dist_inf(u_sort_1_1,u_sort_2_1)

if q == Inf
costs = [cost[i, assignment_inf[i]] for i in 1:(N_inf)]
cost_inf = maximum(costs)
else
cost_inf = 0
for i in 1:N_inf
cost_inf += cost[i, assignment_inf[i]]^(q)
end
for i in 1:N_inf
cost_inf += cost[i, assignment_inf[i]]^(q)
end
end
#calculate finite cost with self-reference.
cost_h = wasserstein_distance(u_sort_1_2,u_sort_2_2,p=p, q=q)


return cost_h + cost_inf

#unequal infinity return infinity.
else
return Inf

end

#calculate finite cost with self-reference.
cost_h = wasserstein_distance(u_sort_1_2,u_sort_2_2,p=p, q=q)

if q == Inf
return maximum(cost_h, cost_inf)
else
return (cost_h^q + cost_inf)^(1/q)
end

#unequal infinity return infinity.
else
return Inf

end

end


Expand All @@ -221,34 +215,55 @@ end
#

function wd_test_1()
val = wasserstein_distance([1,1], [1,1])
val = wasserstein_distance([1 2], [1 2])

if val == 0
return []
else
print("Error: wd_test_1, value = ",val)
print("Error: wd_test_1, value = ")
return val
end
end

function wd_test_2()
val = wasserstein_distance([1,2],[3,4], p=Inf )
val = wasserstein_distance([1 2],[3 4], p=Inf )

if val == 1.25
if val == 0.5
return []
else
print("Error: wd_test_2, value = ",val)
print("Error: wd_test_2, value = ")
return val
end
end

function wd_test_3()
val = wasserstein_distance([1,2],[3,3.5],p=1,q=2 )
val = wasserstein_distance([1 2],[3 3.5],p=1,q=Inf )

if val == 1
return []
else
print("Error: wd_test_3, value = ")
return val
end
end

if val == 2.125
function wd_test_4()
val = wasserstein_distance([0 1], [3 5; 7 9], p=Inf, q=1)

if val == 2.5
return []
else print("Error: wd_test_4, value = ")
return val
end
end

function wd_test_5()
val = wasserstein_distance([1 1], [2 2])

if val == 0
return []
else
print("Error: wd_test_3, value = ",val)
print("Error: wd_test_5, value = ")
return val
end
end