Rescale tensors and choose correct contraction order optimizer in tests #16

Merged: 6 commits, Nov 4, 2022
1 change: 1 addition & 0 deletions .gitignore
@@ -3,3 +3,4 @@ Manifest.toml
*.jl.cov
*.jl.mem
/docs/build/
.vscode/
3 changes: 2 additions & 1 deletion Project.toml
@@ -22,7 +22,8 @@ julia = "1.3"

[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Documenter"]
test = ["Test", "Documenter", "KaHyPar"]
49 changes: 26 additions & 23 deletions src/Core.jl
@@ -48,22 +48,22 @@ Probabilistic modeling with a tensor network.
* `tensors` are the tensors fed into the tensor network.
* `fixedvertices` is a dictionary specifying the degrees of freedom fixed to certain values.
"""
struct TensorNetworkModeling{LT,ET,MT<:AbstractArray}
struct TensorNetworkModel{LT,ET,MT<:AbstractArray}
vars::Vector{LT}
code::ET
tensors::Vector{MT}
fixedvertices::Dict{LT,Int}
end

function Base.show(io::IO, tn::TensorNetworkModeling)
function Base.show(io::IO, tn::TensorNetworkModel)
open = getiyv(tn.code)
variables = join([string_var(var, open, tn.fixedvertices) for var in tn.vars], ", ")
tc, sc, rw = timespacereadwrite_complexity(tn)
println(io, "$(typeof(tn))")
println(io, "variables: $variables")
print_tcscrw(io, tc, sc, rw)
end
Base.show(io::IO, ::MIME"text/plain", tn::TensorNetworkModeling) = Base.show(io, tn)
Base.show(io::IO, ::MIME"text/plain", tn::TensorNetworkModel) = Base.show(io, tn)
function string_var(var, open, fixedvertices)
if var ∈ open && haskey(fixedvertices, var)
"$var (open, fixed to $(fixedvertices[var]))"
@@ -83,61 +83,62 @@ end
"""
$(TYPEDSIGNATURES)
"""
function TensorNetworkModeling(instance::UAIInstance; openvertices=(), optimizer=GreedyMethod(), simplifier=nothing)::TensorNetworkModeling
return TensorNetworkModeling(1:instance.nvars, instance.factors; fixedvertices=Dict(zip(instance.obsvars, instance.obsvals .- 1)), optimizer, simplifier, openvertices)
function TensorNetworkModel(instance::UAIInstance; openvertices=(), optimizer=GreedyMethod(), simplifier=nothing)::TensorNetworkModel
return TensorNetworkModel(1:instance.nvars, instance.cards, instance.factors; fixedvertices=Dict(zip(instance.obsvars, instance.obsvals .- 1)), optimizer, simplifier, openvertices)
end

"""
$(TYPEDSIGNATURES)
"""
function TensorNetworkModeling(vars::AbstractVector{LT}, factors::Vector{<:Factor{T}}; openvertices=(), fixedvertices=Dict{LT,Int}(), optimizer=GreedyMethod(), simplifier=nothing)::TensorNetworkModeling where {T,LT}
function TensorNetworkModel(vars::AbstractVector{LT}, cards::AbstractVector{Int}, factors::Vector{<:Factor{T}}; openvertices=(), fixedvertices=Dict{LT,Int}(), optimizer=GreedyMethod(), simplifier=nothing)::TensorNetworkModel where {T,LT}
    # The 1st argument of `EinCode` is a vector of vectors of labels, one per input tensor.
    # The 2nd argument of `EinCode` is a vector of labels for the output tensor,
# e.g.
# `EinCode([[1, 2], [2, 3]], [1, 3])` is the EinCode for matrix multiplication.
rawcode = EinCode([[[var] for var in vars]..., [[factor.vars...] for factor in factors]...], collect(LT, openvertices)) # labels for vertex tensors (unity tensors) and edge tensors
tensors = [[ones(T, 2) for _=1:length(vars)]..., getfield.(factors, :vals)...]
return TensorNetworkModeling(collect(LT, vars), rawcode, tensors; fixedvertices, optimizer, simplifier)
tensors = Array{T}[[ones(T, cards[i]) for i=1:length(vars)]..., [t.vals for t in factors]...]
return TensorNetworkModel(collect(LT, vars), rawcode, tensors; fixedvertices, optimizer, simplifier)
end
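As a side note for readers (not part of the PR), the matrix-multiplication example mentioned in the comment above can be checked directly with OMEinsum; this is a minimal illustrative sketch:

using OMEinsum
code = EinCode([[1, 2], [2, 3]], [1, 3])   # contract the shared label 2, keep labels 1 and 3
A, B = randn(2, 3), randn(3, 4)
code(A, B) ≈ A * B                          # true: this einsum is exactly matrix multiplication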
"""
$(TYPEDSIGNATURES)
"""
function TensorNetworkModeling(vars::AbstractVector{LT}, rawcode::EinCode, tensors::Vector{<:AbstractArray}; fixedvertices=Dict{LT,Int}(), optimizer=GreedyMethod(), simplifier=nothing)::TensorNetworkModeling where LT
function TensorNetworkModel(vars::AbstractVector{LT}, rawcode::EinCode, tensors::Vector{<:AbstractArray}; fixedvertices=Dict{LT,Int}(), optimizer=GreedyMethod(), simplifier=nothing)::TensorNetworkModel where LT
# `optimize_code` optimizes the contraction order of a raw tensor network without a contraction order specified.
# The 1st argument is the contraction pattern to be optimized (without contraction order).
    # The 2nd argument is the size dictionary, which maps labels to integer sizes.
    # The 3rd and 4th arguments are the optimizer and the simplifier, which configure the contraction-order algorithm and the network simplification.
code = optimize_code(rawcode, OMEinsum.get_size_dict(getixsv(rawcode), tensors), optimizer, simplifier)
TensorNetworkModeling(collect(LT, vars), code, tensors, fixedvertices)
size_dict = OMEinsum.get_size_dict(getixsv(rawcode), tensors)
code = optimize_code(rawcode, size_dict, optimizer, simplifier)
TensorNetworkModel(collect(LT, vars), code, tensors, fixedvertices)
end
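A small sketch of the `optimize_code` step described above (illustration only, not part of the PR), using an assumed uniform cardinality of 2 for all labels:

using OMEinsum
rawcode = EinCode([[1, 2], [2, 3], [3, 4]], [1, 4])
sizes = uniformsize(rawcode, 2)                         # label-to-size dictionary, every label set to size 2
optcode = optimize_code(rawcode, sizes, GreedyMethod(), MergeVectors())
timespacereadwrite_complexity(optcode, sizes)           # log2 time, space and read-write costs of the optimized order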

"""
$(TYPEDSIGNATURES)

Get the variables in this tensor network; they are also known as legs, labels, or degrees of freedom.
"""
get_vars(tn::TensorNetworkModeling)::Vector = tn.vars
get_vars(tn::TensorNetworkModel)::Vector = tn.vars

"""
$(TYPEDSIGNATURES)

Get the cardinalities of variables in this tensor network.
"""
function get_cards(tn::TensorNetworkModeling; fixedisone=false)::Vector
function get_cards(tn::TensorNetworkModel; fixedisone=false)::Vector
vars = get_vars(tn)
[fixedisone && haskey(tn.fixedvertices, vars[k]) ? 1 : length(tn.tensors[k]) for k=1:length(vars)]
end

chfixedvertices(tn::TensorNetworkModeling, fixedvertices) = TensorNetworkModeling(tn.vars, tn.code, tn.tensors, fixedvertices)
chfixedvertices(tn::TensorNetworkModel, fixedvertices) = TensorNetworkModel(tn.vars, tn.code, tn.tensors, fixedvertices)

"""
$(TYPEDSIGNATURES)

Evaluate the probability of `config`.
Evaluate the log probability of `config`.
"""
function probability(tn::TensorNetworkModeling, config)::Real
assign = Dict(zip(get_vars(tn), config .+ 1))
return mapreduce(x->x[2][getindex.(Ref(assign), x[1])...], *, zip(getixsv(tn.code), tn.tensors))
function log_probability(tn::TensorNetworkModel, config::Union{Dict, AbstractVector})::Real
assign = config isa AbstractVector ? Dict(zip(get_vars(tn), config)) : config
return sum(x->log(x[2][(getindex.(Ref(assign), x[1]) .+ 1)...]), zip(getixsv(tn.code), tn.tensors))
end
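A hedged usage note, assuming a model `tn` built with the constructors above: `config` uses the 0-based convention of the UAI files, and a `Dict` keyed by variables is accepted as well.

config = zeros(Int, length(get_vars(tn)))   # all variables set to their first (0-based) value
log_probability(tn, config)                 # same result as log_probability(tn, Dict(zip(get_vars(tn), config)))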

"""
@@ -146,11 +147,13 @@ $(TYPEDSIGNATURES)
Contract the tensor network and return a probability array with its rank specified in the contraction code `tn.code`.
The returned array may not be l1-normalized even if the total probability is normalized, because the evidence `tn.fixedvertices` may not be empty.
"""
function probability(tn::TensorNetworkModeling; usecuda=false)::AbstractArray
return tn.code(generate_tensors(tn; usecuda)...)
function probability(tn::TensorNetworkModel; usecuda=false, rescale=true)::AbstractArray
return tn.code(adapt_tensors(tn; usecuda, rescale)...)
end

function OMEinsum.timespacereadwrite_complexity(tn::TensorNetworkModeling)
return timespacereadwrite_complexity(tn.code, Dict(zip(get_vars(tn), get_cards(tn; fixedisone=true))))
function OMEinsum.contraction_complexity(tn::TensorNetworkModel)
return contraction_complexity(tn.code, Dict(zip(get_vars(tn), get_cards(tn; fixedisone=true))))
end
OMEinsum.timespace_complexity(tn::TensorNetworkModeling) = timespacereadwrite_complexity(tn)[1:2]

# adapt array type with the target array type
match_arraytype(::Type{<:Array{T, N}}, target::AbstractArray{T, N}) where {T, N} = Array(target)
47 changes: 47 additions & 0 deletions src/RescaledArray.jl
@@ -0,0 +1,47 @@
"""
$(TYPEDEF)
RescaledArray(α, T) -> RescaledArray

An array data type with a log-prefactor and an l∞-normalized storage, i.e. the maximum element of the stored tensor is 1.
This tensor type helps avoid underflow/overflow of numbers during tensor network contraction.
The constructor `RescaledArray(α, T)` creates a rescaled array equal to `exp(α) * T`.
"""
struct RescaledArray{T, N, AT<:AbstractArray{T, N}} <: AbstractArray{T, N}
log_factor::T
normalized_value::AT
end
Base.show(io::IO, c::RescaledArray) = print(io, "exp($(c.log_factor)) * $(c.normalized_value)")
Base.show(io::IO, ::MIME"text/plain", c::RescaledArray) = Base.show(io, c)
Base.Array(c::RescaledArray) = rmul!(Array(c.normalized_value), exp(c.log_factor))
Base.copy(c::RescaledArray) = RescaledArray(c.log_factor, copy(c.normalized_value))

"""
$(TYPEDSIGNATURES)

Returns a rescaled array that is equivalent to the input tensor.
"""
function rescale_array(tensor::AbstractArray{T})::RescaledArray where T
maxf = maximum(tensor)
if iszero(maxf)
@warn("The maximum value of the array to rescale is 0!")
return RescaledArray(zero(T), tensor)
end
return RescaledArray(log(maxf), OMEinsum.asarray(tensor ./ maxf, tensor))
end
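A brief illustrative check of the behaviour documented above (not part of the PR), using the unexported `rescale_array` helper defined here:

t = [1e3 2e3; 3e3 4e3]
r = rescale_array(t)     # RescaledArray with log_factor == log(4e3) and maximum(r.normalized_value) == 1
Array(r) ≈ t             # true: exp(log_factor) .* normalized_value recovers the original tensor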

for CT in [:DynamicEinCode, :StaticEinCode]
@eval function OMEinsum.einsum(code::$CT, @nospecialize(xs::NTuple{N,RescaledArray}), size_dict::Dict) where N
# The following equality holds
        # einsum(code, exp(α) * A, exp(β) * B, ...) = exp(α + β + ...) * einsum(code, A, B, ...)
# Hence the einsum is performed on the normalized values, and the factors are added later.
res = einsum(code, getfield.(xs, :normalized_value), size_dict)
rescaled = rescale_array(res)
        # return a new rescaled array whose log-factor is the sum of the inputs' log-factors plus the factor extracted when rescaling the result.
return RescaledArray(sum(x->x.log_factor, xs) + rescaled.log_factor, rescaled.normalized_value)
end
end

Base.size(arr::RescaledArray) = size(arr.normalized_value)
Base.size(arr::RescaledArray, i::Int) = size(arr.normalized_value, i)

match_arraytype(::Type{<:RescaledArray{T, N}}, target::AbstractArray{T, N}) where {T, N} = rescale_array(target)
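An illustrative check of the identity used above (a sketch, not part of the PR): contracting rescaled tensors and converting back should agree with contracting the raw tensors, assuming the `DynamicEinCode` einsum overload defined here is hit when the code is called directly and `rescale_array` is accessible (e.g. `TensorInference.rescale_array`).

code = EinCode([[1, 2], [2, 3]], [1, 3])
A, B = rand(2, 2) .* 1e3, rand(2, 2) .* 1e3
Array(code(rescale_array(A), rescale_array(B))) ≈ code(A, B)   # true up to floating-point error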
6 changes: 4 additions & 2 deletions src/TensorInference.jl
@@ -5,21 +5,23 @@ using DocStringExtensions, TropicalNumbers
using Artifacts

# reexport OMEinsum functions
export RescaledArray
export timespace_complexity, timespacereadwrite_complexity, TreeSA, GreedyMethod, KaHyParBipartite, SABipartite, MergeGreedy, MergeVectors

# read and load uai files
export read_uai_file, read_td_file, read_uai_evid_file, read_uai_mar_file, read_uai_problem

# marginals
export TensorNetworkModeling, get_vars, get_cards, probability, marginals
export TensorNetworkModel, get_vars, get_cards, log_probability, probability, marginals

# MAP
export most_probable_config, maximum_logp

# MMAP
export MMAPModeling
export MMAPModel

include("Core.jl")
include("RescaledArray.jl")
include("utils.jl")
include("inference.jl")
include("maxprob.jl")
5 changes: 4 additions & 1 deletion src/cuda.jl
@@ -4,4 +4,7 @@ function onehot_like(A::CuArray, j)
mask = zero(A)
CUDA.@allowscalar mask[j] = one(eltype(mask))
return mask
end
end

# NOTE: this interface should be in OMEinsum
match_arraytype(::Type{<:CuArray{T, N}}, target::AbstractArray{T, N}) where {T, N} = CuArray(target)
27 changes: 17 additions & 10 deletions src/inference.jl
@@ -1,21 +1,22 @@
# generate tensors based on which vertices are fixed.
generate_tensors(gp::TensorNetworkModeling; usecuda) = generate_tensors(gp.code, gp.tensors, gp.fixedvertices; usecuda)
function generate_tensors(code, tensors, fixedvertices; usecuda)
isempty(fixedvertices) && return tensors
adapt_tensors(gp::TensorNetworkModel; usecuda, rescale) = adapt_tensors(gp.code, gp.tensors, gp.fixedvertices; usecuda, rescale)
function adapt_tensors(code, tensors, fixedvertices; usecuda, rescale)
ixs = getixsv(code)
    # `ix` is the vector of labels (degrees of freedom) of a tensor;
    # if a label in `ix` is fixed to a value, slice the associated tensor accordingly.
map(tensors, ixs) do t, ix
dims = map(ixi->ixi ∉ keys(fixedvertices) ? Colon() : (fixedvertices[ixi]+1:fixedvertices[ixi]+1), ix)
usecuda ? CuArray(t[dims...]) : t[dims...]
t2 = t[dims...]
t3 = usecuda ? CuArray(t2) : t2
rescale ? rescale_array(t3) : t3
end
end
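An illustrative example of the slicing above (not part of the PR): fixing variable 2 to the 0-based value 1 keeps a singleton axis, so the tensor's labels are preserved.

fixedvertices = Dict(2 => 1)      # variable 2 is observed to take its second value
ix = [1, 2]                       # labels of this tensor
t = rand(2, 3)
dims = map(ixi -> ixi ∉ keys(fixedvertices) ? Colon() : (fixedvertices[ixi]+1:fixedvertices[ixi]+1), ix)
size(t[dims...])                  # (2, 1)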

# ######### Inference by back propagation ############
# `CacheTree` stores intermediate `NestedEinsum` contraction results.
# It is a tree structure that is isomorphic to the contraction tree.
# `siblings` are the siblings of the current node.
# `content` is the cached intermediate contraction result.
# `siblings` are the siblings of the current node.
struct CacheTree{T}
content::AbstractArray{T}
siblings::Vector{CacheTree{T}}
@@ -60,7 +61,7 @@ function generate_gradient_tree(code::NestedEinsum, cache::CacheTree{T}, dy::Abs
if OMEinsum.isleaf(code)
return CacheTree(dy, CacheTree{T}[])
else
xs = (getfield.(cache.siblings, :content)...,)
xs = ntuple(i->cache.siblings[i].content, length(cache.siblings))
        # `einsum_grad` is the back-propagation rule for the einsum function.
# If the forward pass is `y = einsum(EinCode(inputs_labels, output_labels), (A, B, ...), size_dict)`
# Then the back-propagation pass is
@@ -87,7 +88,7 @@ function gradient_tree(code, xs)
# forward compute and cache intermediate results.
cache = cached_einsum(code, xs, size_dict)
# initialize `y̅` as `1`. Note we always start from `L̅ := 1`.
dy = fill!(similar(cache.content), one(eltype(cache.content)))
dy = match_arraytype(typeof(cache.content), ones(eltype(cache.content), size(cache.content)))
# back-propagate
return copy(cache.content), generate_gradient_tree(code, cache, dy, size_dict)
end
@@ -125,8 +126,14 @@ $(TYPEDSIGNATURES)
Returns the marginal probability distributions of the variables.
One can use `get_vars(tn)` to get the full list of variables in this tensor network.
"""
function marginals(tn::TensorNetworkModeling; usecuda=false)::Vector
function marginals(tn::TensorNetworkModel; usecuda=false, rescale=true)::Vector
vars = get_vars(tn)
_, grads = cost_and_gradient(tn.code, generate_tensors(tn; usecuda))
return LinearAlgebra.normalize!.(grads[1:length(vars)], 1)
    # Sometimes the cost can overflow; in that case, the tensors need to be rescaled during contraction.
cost, grads = cost_and_gradient(tn.code, adapt_tensors(tn; usecuda, rescale))
@debug "cost = $cost"
if rescale
return LinearAlgebra.normalize!.(getfield.(grads[1:length(vars)], :normalized_value), 1)
else
return LinearAlgebra.normalize!.(grads[1:length(vars)], 1)
end
end
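A hedged end-to-end sketch of the marginals API (the problem name and the exact reader signature are assumptions; see the exported `read_uai_*` helpers for the actual loading interface):

using TensorInference
instance = read_uai_problem("Promedus_14")          # hypothetical problem name from the bundled artifacts
tn = TensorNetworkModel(instance; optimizer=TreeSA())
mars = marginals(tn)                                # one l1-normalized vector per variable, in get_vars(tn) order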
8 changes: 4 additions & 4 deletions src/maxprob.jl
@@ -44,9 +44,9 @@ $(TYPEDSIGNATURES)

Returns the largest log-probability and the most probable configuration.
"""
function most_probable_config(tn::TensorNetworkModeling; usecuda=false)::Tuple{Tropical,Vector}
function most_probable_config(tn::TensorNetworkModel; usecuda=false)::Tuple{Tropical,Vector}
vars = get_vars(tn)
tensors = map(t->Tropical.(log.(t)), generate_tensors(tn; usecuda))
tensors = map(t->Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale=false))
logp, grads = cost_and_gradient(tn.code, tensors)
# use Array to convert CuArray to CPU arrays
return Array(logp)[], map(k->haskey(tn.fixedvertices, vars[k]) ? tn.fixedvertices[vars[k]] : argmax(grads[k]) - 1, 1:length(vars))
@@ -57,8 +57,8 @@ $(TYPEDSIGNATURES)

Returns an output array containing the largest log-probabilities.
"""
function maximum_logp(tn::TensorNetworkModeling; usecuda=false)::AbstractArray{<:Tropical}
function maximum_logp(tn::TensorNetworkModel; usecuda=false)::AbstractArray{<:Tropical}
    # generate tropical tensors with elements being log(p).
tensors = map(t->Tropical.(log.(t)), generate_tensors(tn; usecuda))
tensors = map(t->Tropical.(log.(t)), adapt_tensors(tn; usecuda, rescale=false))
return tn.code(tensors...)
end
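A short hedged sketch of the MAP queries above, reusing a model `tn` constructed as in the marginals example:

logp, config = most_probable_config(tn)   # Tropical log-probability and the 0-based most probable assignment
maximum_logp(tn)                          # array of largest log-probabilities (Tropical elements)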