Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sampling configurations #28

Merged
merged 10 commits into from
Jul 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ jobs:
fail-fast: false
matrix:
version:
- '1.8'
- '1'
#- 'nightly'
- 'nightly'
os:
- ubuntu-latest
arch:
Expand Down
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,18 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TropicalGEMM = "a4ad3063-64a7-4bad-8738-34ed09bc0236"
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"

[compat]
Artifacts = "1"
CUDA = "4"
DocStringExtensions = "0.8.6, 0.9"
OMEinsum = "0.7"
Requires = "1"
PrecompileTools = "1"
Requires = "1"
StatsBase = "0.34"
TropicalGEMM = "0.1"
TropicalNumbers = "0.5.4"
julia = "1.3"
6 changes: 6 additions & 0 deletions example/asia/asia.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ probability(tnet)
# Get the marginal probabilities (MAR)
marginals(tnet) .|> first

# The corresponding variables are
get_vars(tnet)

# Set the evidence variables "X-ray" (7) to be positive.
set_evidence!(instance, 7=>0)

Expand All @@ -19,6 +22,9 @@ tnet = TensorNetworkModel(instance)
# Get the maximum log-probabilities (MAP)
maximum_logp(tnet)

# To sample from the probability model
sample(tnet, 10)

# Get not only the maximum log-probability, but also the most probable configuration.
# In the most probable configuration, the patient smokes (3) and has lung cancer (4).
logp, cfg = most_probable_config(tnet)
Expand Down
7 changes: 6 additions & 1 deletion src/TensorInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using DocStringExtensions, TropicalNumbers
using Artifacts
# The Tropical GEMM support
using TropicalGEMM
using StatsBase

# reexport OMEinsum functions
export RescaledArray
Expand All @@ -20,6 +21,9 @@ export TensorNetworkModel, get_vars, get_cards, log_probability, probability, ma
# MAP
export most_probable_config, maximum_logp

# sampling
export sample

# MMAP
export MMAPModel

Expand All @@ -29,6 +33,7 @@ include("utils.jl")
include("inference.jl")
include("maxprob.jl")
include("mmap.jl")
include("sampling.jl")

using Requires
function __init__()
Expand All @@ -40,7 +45,7 @@ PrecompileTools.@setup_workload begin
# Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the
# precompile file and potentially make loading faster.
#PrecompileTools.@compile_workload begin
#include("../example/asia/asia.jl")
# include("../example/asia/asia.jl")
#end
end

Expand Down
113 changes: 113 additions & 0 deletions src/sampling.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
############ Sampling ############
"""
$TYPEDEF

### Fields
$TYPEDFIELDS

The sampled configurations are stored in `samples`, which is a vector of vectors:
one integer vector per sample, whose entries are 0-based state indices.
`labels` is a vector of variable names for labeling configurations; entry `i` of a
configuration corresponds to `labels[i]`.
`setmask` is a boolean indicator denoting, per variable, whether its sampling
process is complete (the variable has already been assigned during back-propagation).
"""
struct Samples{L}
    samples::Vector{Vector{Int}}
    labels::Vector{L}
    setmask::BitVector
end
"""
    setmask!(samples::Samples, eliminated_variables)

Mark each variable in `eliminated_variables` as sampled in `samples.setmask`.
Throws an `ArgumentError` for a variable not present in `samples.labels`, and an
error if a variable was already eliminated. Returns `samples` for chaining.
"""
function setmask!(samples::Samples, eliminated_variables)
    for var in eliminated_variables
        loc = findfirst(==(var), samples.labels)
        # fail loudly on unknown variables instead of indexing with `nothing`
        loc === nothing && throw(ArgumentError("variable `$var` does not exist in the labels."))
        # fixed typo in the error message: "varaible" -> "variable"
        samples.setmask[loc] && error("variable `$var` is already eliminated.")
        samples.setmask[loc] = true
    end
    return samples
end

"""
    idx4labels(totalset, labels)

Return the position of each element of `labels` inside the collection `totalset`
(`nothing` for elements that are absent).
"""
idx4labels(totalset, labels) = [findfirst(==(lbl), totalset) for lbl in labels]

"""
$(TYPEDSIGNATURES)

The backward process for sampling configurations.

* `ixs` and `xs` are labels and tensor data for input tensors,
* `iy` and `y` are labels and tensor data for the output tensor,
* `samples` is the samples generated for eliminated variables,
* `size_dict` is a key-value map from tensor label to dimension size.
"""
function backward_sampling!(ixs, @nospecialize(xs::Tuple), iy, @nospecialize(y), samples::Samples, size_dict)
eliminated_variables = setdiff(vcat(ixs...), iy)
eliminated_locs = idx4labels(samples.labels, eliminated_variables)
setmask!(samples, eliminated_variables)

# the contraction code to get probability
newiy = eliminated_variables
iy_in_sample = idx4labels(samples.labels, iy)
slice_y_dim = collect(1:length(iy))
newixs = map(ix->setdiff(ix, iy), ixs)
ix_in_sample = map(ix->idx4labels(samples.labels, ix ∩ iy), ixs)
slice_xs_dim = map(ix->idx4labels(ix, ix ∩ iy), ixs)
code = DynamicEinCode(newixs, newiy)

totalset = CartesianIndices((map(x->size_dict[x], eliminated_variables)...,))
for (i, sample) in enumerate(samples.samples)
newxs = [get_slice(x, dimx, sample[ixloc]) for (x, dimx, ixloc) in zip(xs, slice_xs_dim, ix_in_sample)]
newy = get_element(y, slice_y_dim, sample[iy_in_sample])
probabilities = einsum(code, (newxs...,), size_dict) / newy
config = StatsBase.sample(totalset, Weights(vec(probabilities)))
# update the samples
samples.samples[i][eliminated_locs] .= config.I .- 1
end
return samples
end

# type unstable: the rank of the returned array depends on runtime values.
"""
    get_slice(x, dim, config)

Slice tensor `x` by pinning the axes listed in `dim` to the corresponding
0-based states in `config`, keeping all other axes whole. The result is wrapped
with `asarray` so scalar-like slices remain arrays compatible with `x`.
"""
function get_slice(x, dim, config)
    indexer = map(1:ndims(x)) do axis
        pos = findfirst(==(axis), dim)
        pos === nothing ? Colon() : config[pos] + 1
    end
    return asarray(x[indexer...], x)
end
"""
    get_element(x, dim, config)

Read the single entry of `x` addressed by the 0-based states in `config`, where
`dim` maps each axis of `x` to its position in `config`.
"""
function get_element(x, dim, config)
    index = ntuple(axis -> config[findfirst(==(axis), dim)] + 1, ndims(x))
    return x[index...]
end

"""
$(TYPEDSIGNATURES)

Generate samples from a tensor network based probabilistic model.
Returns a vector of vector, each element being a configurations defined on `get_vars(tn)`.

### Arguments
* `tn` is the tensor network model.
* `n` is the number of samples to be returned.
"""
function sample(tn::TensorNetworkModel, n::Int; usecuda = false)::Vector{Vector{Int}}
# generate tropical tensors with its elements being log(p).
xs = adapt_tensors(tn; usecuda, rescale = false)
# infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary.
size_dict = OMEinsum.get_size_dict!(getixsv(tn.code), xs, Dict{Int, Int}())
# forward compute and cache intermediate results.
cache = cached_einsum(tn.code, xs, size_dict)
# initialize `y̅` as the initial batch of samples.
labels = get_vars(tn)
iy = getiyv(tn.code)
setmask = falses(length(labels))
idx = map(l->findfirst(==(l), labels), iy)
setmask[idx] .= true
indices = StatsBase.sample(CartesianIndices(size(cache.content)), Weights(normalize!(vec(LinearAlgebra.normalize!(cache.content)))), n)
configs = map(indices) do ind
c=zeros(Int, length(labels))
c[idx] .= ind.I .- 1
c
end
samples = Samples(configs, labels, setmask)
# back-propagate
generate_samples(tn.code, cache, samples, size_dict)
return samples.samples
end

"""
    generate_samples(code, cache, samples, size_dict)

Recursively walk the contraction tree from the root, completing the partially
assigned `samples` by calling `backward_sampling!` at every internal node.
"""
function generate_samples(code::NestedEinsum, cache::CacheTree{T}, samples, size_dict::Dict) where {T}
    if !OMEinsum.isleaf(code)
        # cached forward results of the children are the inputs of this node
        xs = ntuple(i -> cache.siblings[i].content, length(cache.siblings))
        backward_sampling!(OMEinsum.getixs(code.eins), xs, OMEinsum.getiy(code.eins), cache.content, samples, size_dict)
        # NOTE(review): children are accessed as `code.args` here but cached results
        # as `cache.siblings` — confirm both enumerate children in the same order.
        generate_samples.(code.args, cache.siblings, Ref(samples), Ref(size_dict))
    end
end
14 changes: 11 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@ The UAI file formats are defined in:
https://personal.utdallas.edu/~vibhav.gogate/uai16-evaluation/uaiformat.html
"""
function read_uai_file(uai_filepath; factor_eltype = Float64)

# Read the uai file into an array of lines
rawlines = open(uai_filepath) do file
readlines(file)
str = open(uai_filepath) do file
read(file, String)
end
return read_uai_string(str; factor_eltype)
end

function read_uai_string(str; factor_eltype = Float64)
rawlines = split(str, "\n")
# Filter out empty lines
lines = filter(!isempty, rawlines)

Expand Down Expand Up @@ -193,5 +196,10 @@ function uai_problem_from_file(uai_filepath::String; uai_evid_filepath="", uai_m
return UAIInstance(nvars, ncliques, cards, factors, obsvars, obsvals, reference_marginals)
end

function uai_problem_from_string(uai::String; eltype=Float64)::UAIInstance
nvars, cards, ncliques, factors = read_uai_string(uai; factor_eltype = eltype)
return UAIInstance(nvars, ncliques, cards, factors, Int[], Int[], Vector{eltype}[])
end

# patch to get content by broadcasting into array, while keep array size unchanged.
broadcasted_content(x) = asarray(content.(x), x)
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ end
@testset "MMAP" begin
include("mmap.jl")
end
# fixed copy-pasted testset name: this group runs the sampling tests, not MMAP
@testset "sampling" begin
    include("sampling.jl")
end

using CUDA
if CUDA.functional()
Expand Down
55 changes: 55 additions & 0 deletions test/sampling.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using TensorInference, Test

@testset "sampling" begin
    # The 8-variable ASIA Bayesian network in UAI MARKOV format: a preamble
    # (variable count, cardinalities, clique count, clique scopes) followed by
    # one probability table per clique.
    instance = TensorInference.uai_problem_from_string("""MARKOV
8
2 2 2 2 2 2 2 2
8
1 0
2 1 0
1 2
2 3 2
2 4 2
3 5 3 1
2 6 5
3 7 5 4

2
0.01
0.99

4
0.05 0.01
0.95 0.99

2
0.5
0.5

4
0.1 0.01
0.9 0.99

4
0.6 0.3
0.4 0.7

8
1 1 1 0
0 0 0 1

4
0.98 0.05
0.02 0.95

8
0.9 0.7 0.8 0.1
0.1 0.3 0.2 0.9
""")
    n = 10000
    tnet = TensorNetworkModel(instance)
    samples = sample(tnet, n)
    # exact marginal probability of state 1 for each variable
    mars = getindex.(marginals(tnet), 2)
    # empirical frequency of state 1 among the drawn samples
    mars_sample = [count(s->s[k]==(1), samples) for k=1:8] ./ n
    # sampling is stochastic: only require agreement within a loose tolerance
    @test isapprox(mars, mars_sample, atol=0.05)
end