From dd5b89b4940b6589c5030f5d9679e5d338483807 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20Wien=C3=B6bst?= Date: Wed, 25 Jan 2023 12:37:27 +0100 Subject: [PATCH 1/3] initial commit for mec algorithms --- src/cpdag.jl | 53 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/cpdag.jl b/src/cpdag.jl index 79fa1dac..16fc30e2 100644 --- a/src/cpdag.jl +++ b/src/cpdag.jl @@ -81,6 +81,59 @@ function alt_cpdag(g) return nvertexdigraphfromedgelist(nv(g), edgelist) end +# TODOs: +# - add function for checking whether a graph is a CDPAG +# - add function for checking whether a PDAG has a DAG extension +function hasdirectedcycle(G) + # TODO: implement + return true +end + +function ismaximallyoriented(G) + # TODO: implement + return true +end + +function checkundirectedcomps(G) + # TODO: implement + return true, true +end + +function isstronglyprotected(G) + # TODO: implement + return true +end + +""" + classifygraph(G) +Classifies a graph efficiently into the following classes: +- "cyclic": has a directed cycle (the graph can also satisfy some of the other criteria) +- "not maximally oriented": there is an undirected edge whose direction would follow from one of the Meek rules (Meek 1995) +- "not extendable": there is no consistent DAG extension of this graph +- "CPDAG": the graph is a CPDAG +- "MPDAG, counting works" and "MPDAG, counting not yet implemented": in both cases the graph is an MPDAG, however, in the second case it has not enough structure for the counting algorithm to work (TODO: add more details) +""" + +function classifygraph(G) + hasdirectedcycle(G) && return "cyclic" + ismaximallyoriented(G) && return "not maximally oriented" + areinduced, arechordal = checkundirectedcomps(G) + !arechordal && return "not extendable" + if areinduced + if isstronglyprotected(G) + return "CPDAG" + else + return "MPDAG, counting works" + end + else + return "MPDAG, counting not yet implemented" + end +end + +function consistent_extension(G) + # TODO: implement Dor-Tarsi algorithm +end + """ ordered_edges(dag) From 76e0d1020486647f2864a18fa43b1a48305fc92b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20Wien=C3=B6bst?= Date: Mon, 7 Aug 2023 12:50:17 +0200 Subject: [PATCH 2/3] rebase --- Project.toml | 14 +++- src/CausalInference.jl | 2 + src/chordal.jl | 165 +++++++++++++++++++++++++++++++++++++++++ src/cpdag.jl | 125 +++++++++++++++++++++++++++++++ 4 files changed, 302 insertions(+), 4 deletions(-) create mode 100644 src/chordal.jl diff --git a/Project.toml b/Project.toml index f3ac5e42..b17e0328 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" +LinkedLists = "70f5e60a-1556-5f34-a19e-a48b3e4aaee9" MetaGraphs = "626554b9-1ddb-594c-aa3c-2596fe9399a5" NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" @@ -22,6 +23,15 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" TabularDisplay = "3eeacb1d-13c2-54cc-9b18-30c86af3cadb" ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d" +[weakdeps] +GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742" + +[extensions] +GraphRecipesExt = ["GraphRecipes", "Plots"] +TikzGraphsExt = "TikzGraphs" + [compat] Combinatorics = "1.0" DelimitedFiles = "1.6, 1.7, 1.8, 1.9" @@ -43,10 +53,6 @@ ThreadsX = "0.1" TikzGraphs = "1.3, 1.4" julia = "1.6, 1.7, 1.8, 1.9" -[extensions] -GraphRecipesExt = ["GraphRecipes", "Plots"] -TikzGraphsExt = "TikzGraphs" - [extras] DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73" diff --git a/src/CausalInference.jl b/src/CausalInference.jl index 1d4b70ea..29456bc6 100644 --- a/src/CausalInference.jl +++ b/src/CausalInference.jl @@ -4,6 +4,7 @@ using Graphs using Graphs.SimpleGraphs using Combinatorics using Base.Iterators +using LinkedLists using Memoization, LRUCache using ThreadsX @@ -44,6 +45,7 @@ include("backdoor.jl") include("ges.jl") include("gensearch.jl") include("workloads.jl") +include("chordal.jl") # Compatibility with the new "Package Extensions" (https://github.com/JuliaLang/julia/pull/47695) const EXTENSIONS_SUPPORTED = isdefined(Base, :get_extension) diff --git a/src/chordal.jl b/src/chordal.jl new file mode 100644 index 00000000..1008c682 --- /dev/null +++ b/src/chordal.jl @@ -0,0 +1,165 @@ +import LinkedLists + +""" + ischordal(g) + +Return true if the given graph is chordal +""" +function ischordal(G) + mcsorder, invmcsorder, _ = mcs(G, Set()) + + n = length(mcsorder) + + f = zeros(Int, n) + index = zeros(Int, n) + for i=n:-1:1 + w = mcsorder[i] + f[w] = w + index[w] = i + for v in neighbors(G, w) + if invmcsorder[v] > i + index[v] = i + if f[v] == v + f[v] = w + end + end + end + for v in neighbors(G, w) + if invmcsorder[v] > i + if index[f[v]] > i + return false + end + end + end + end + return true +end + +function cliquetreefrommcs(G, mcsorder, invmcsorder) + n = nv(G) + # data structures for the algorithm + K = Vector{Set}() + push!(K, Set()) + s = 1 + edgelist = Set{Edge}() + visited = falses(n) + clique = zeros(Int, n) + + for i = 1:n + x = mcsorder[i] + S = Set{Int}() + for w in inneighbors(G, x) + if visited[w] + push!(S, w) + end + end + + # if necessary create new maximal clique + if K[s] != S + s += 1 + push!(K, S) + k, _ = findmax(map(x -> invmcsorder[x], collect(S))) + p = clique[mcsorder[k]] + push!(edgelist, Edge(p, s)) + end + + union!(K[s], x) + clique[x] = s + visited[x] = true; + end + + T = SimpleGraphFromIterator(edgelist) + # ensure graph is not empty + nv(T) == 0 && add_vertices!(T,1) + return K, T +end + +@inline function vispush!(l::LinkedList, pointers, x, vis) + if vis + pointers[x] = push!(l,x) + else + pointers[x] = pushfirst!(l,x) + end +end + +# TODO: separate mcs and mcs plus cgk?? +# Returns the visit order of the vertices, its inverse and the subgraphs C_G(K) (see Def. 1 in [1,2]). If K is empty a normal MCS is performed. +function mcs(G, K) + n = nv(G) + copy_K = copy(K) + + # data structures for MCS + sets = [LinkedList{Int}() for _ = 1:n+1] + pointers = Vector(undef,n) + size = Vector{Int}(undef, n) + visited = falses(n) + + # output data structures + mcsorder = Vector{Int}(undef, n) + invmcsorder = Vector{Int}(undef, n) + subgraphs = Array[] + + # init + visited[collect(copy_K)] .= true + for v in vertices(G) + size[v] = 1 + vispush!(sets[1], pointers, v, visited[v]) + end + maxcard = 1 + + for i = 1:n + # first, the vertices in K are chosen + # they are always in the set of maximum cardinality vertices + if !isempty(copy_K) + v = pop!(copy_K) + # afterwards, the algorithm chooses any vertex from maxcard + else + v = first(sets[maxcard]) + end + # v is the ith vertex in the mcsorder + mcsorder[i] = v + invmcsorder[v] = i + size[v] = -1 + + # immediately append possible subproblems to the output + if !visited[v] + vertexset = Vector{Int}() + for x in sets[maxcard] + visited[x] && break + visited[x] = true + push!(vertexset, x) + end + sg = induced_subgraph(G, vertexset) + subgraphs = vcat(subgraphs, (map(x -> sg[2][x], connected_components(sg[1])))) + end + + deleteat!(sets[maxcard], pointers[v]) + + # update the neighbors + for w in inneighbors(G, v) + if size[w] >= 1 + deleteat!(sets[size[w]], pointers[w]) + size[w] += 1 + vispush!(sets[size[w]], pointers, w, visited[w]) + end + end + maxcard += 1 + while maxcard >= 1 && isempty(sets[maxcard]) + maxcard -= 1 + end + end + + return mcsorder, invmcsorder, subgraphs +end + +""" + cliquetree(G) + +Computes a clique tree of a graph G. A vector K of maximal cliques and a tree T on 1,2,...,|K| is returned. + +""" +function cliquetree(G) + mcsorder, invmcsorder, _ = mcs(G, Set()) + K, T = cliquetreefrommcs(G, mcsorder, invmcsorder) + return K, T +end diff --git a/src/cpdag.jl b/src/cpdag.jl index 16fc30e2..e8e1f861 100644 --- a/src/cpdag.jl +++ b/src/cpdag.jl @@ -1,5 +1,130 @@ import Base: iterate, length +# replace by struct TODO +function fac(n, fmemo) + fmemo[n] != 0 && return fmemo[n] + n == 1 && return BigInt(1) + res = fac(n-1, fmemo) * n + return fmemo[n] = res +end + +function phi(cliquesize, i, fp, fmemo, pmemo) + pmemo[i] != 0 && return pmemo[i] + sum = fac(cliquesize-fp[i], fmemo) + for j = (i+1):length(fp) + sum -= fac(fp[j]-fp[i], fmemo) * phi(cliquesize, j, fp, fmemo, pmemo) + end + return pmemo[i] = sum +end + +# TODO: rename +function subproblems(G, K) + _, _, subgraphs = mcs(G, K) + return subgraphs +end + +function count(cc, memo, fmemo) + G = cc[1] # graph + mapping = cc[2] # mapping to original vertex numbers + n = nv(G) + + # check memoization table + mapG = Set(map(x -> mapping[x], vertices(G))) + haskey(memo, mapG) && return memo[mapG] + + # do bfs over the clique tree + # TODO: this is dfs! + K, T = cliquetree(G) + sum = BigInt(0) + Q = [1] + vis = falses(nv(T)) + vis[1] = true + pred = -1 * ones(Int, nv(T)) + while !isempty(Q) + v = pop!(Q) + for x in inneighbors(T, v) + if !vis[x] + push!(Q, x) + vis[x] = true + pred[x] = v + end + end + + # product of #AMOs for the subproblems + prod = BigInt(1) + for H in subproblems(G, K[v]) + HH = induced_subgraph(G, H) + prod *= count((HH[1], map(x -> mapping[x], HH[2])), memo, fmemo) + end + + # compute correction term phi + FP = [] + curr = v + curr_succ = -1 + intersect_pred = -1 + while pred[curr] != -1 + curr = pred[curr] + intersect_v = length(intersect(K[v], K[curr])) + if curr_succ != -1 + intersect_pred = length(intersect(K[curr], K[curr_succ])) + end + curr_succ = curr + if intersect_v == 0 + break + end + #if lastcut were strictly greater, v is not in bouquet + # defined by cut between curr and curr_succ + if intersect_v >= intersect_pred && (isempty(FP) || intersect_v < FP[end]) + push!(FP, intersect_v) + end + end + push!(FP, 0) + pmemo = zeros(BigInt, length(FP)) + sum += prod * phi(length(K[v]), 1, reverse(FP), fmemo, pmemo) + end + return memo[mapG] = sum +end + +""" + MECsize(G) + +Return the number of Markov equivalent DAGs in the class represented by CPDAG G. + +# Examples +```julia-repl +julia> G = readgraph("example.in", true) +{6, 22} directed simple Int64 graph +julia> MECsize(G) +54 +``` +""" +function MECsize(G) + n = nv(G) + memo = Dict{Set, BigInt}() #mapping set of vertices -> AMO sum + fmemo = zeros(BigInt, n) + U = copy(G) + U.ne = 0 + for i = 1:n + filter!(j->has_edge(G, j, i), U.fadjlist[i]) + filter!(j->has_edge(G, i, j), U.badjlist[i]) + U.ne += length(U.fadjlist[i]) + end + tres = 1 + for component in connected_components(U) + cc = induced_subgraph(U, component) + if !ischordal(cc[1]) + println("Undirected connected components are NOT chordal...Abort") + println("Are you sure the graph is a CPDAG?") + # is there anything more clever than just returning? + return + end + tres *= count(cc, memo, fmemo) + end + + return tres +end + + # REMARKS: # - implemented own topological sort temporarily as topological_sort_by_dfs from Julia Graphs appears to have an issue (quadratic run-time for sparse graphs) # - cpdag(g) potentially has O(m * sqrt(m)) worst-case run-time due to iterating over all w -> x, not only compelled ones. However, in contrast to topological_sort_by_dfs, quadratic run-time behaviour for sparse graphs (e.g. O(n*n*)/O(m*m)) does not seem to appear. From c551035ce0182eae80df87321275ea14b0bbcd1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20Wien=C3=B6bst?= Date: Mon, 7 Aug 2023 12:51:50 +0200 Subject: [PATCH 3/3] rebase --- src/CausalInference.jl | 2 +- src/cpdag.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/CausalInference.jl b/src/CausalInference.jl index 29456bc6..b9b268c3 100644 --- a/src/CausalInference.jl +++ b/src/CausalInference.jl @@ -13,7 +13,7 @@ import Base: ==, show export ancestors, descendants, alt_test_dsep, test_covariate_adjustment, alt_test_backdoor, find_dsep, find_min_dsep, find_covariate_adjustment, find_backdoor_adjustment, find_frontdoor_adjustment, find_min_covariate_adjustment, find_min_backdoor_adjustment, find_min_frontdoor_adjustment, list_dseps, list_covariate_adjustment, list_backdoor_adjustment, list_frontdoor_adjustment export dsep, skeleton, gausscitest, dseporacle, partialcor export unshielded, pcalg, vskel, vskel!, alt_vskel -export cpdag, alt_cpdag, meek_rules! +export cpdag, alt_cpdag, meek_rules!, MECsize export digraph, vpairs, skel_oracle, pc_oracle, randdag export cmitest, kl_entropy, kl_renyi, kl_mutual_information export kl_cond_mi, kl_perm_mi_test, kl_perm_cond_mi_test diff --git a/src/cpdag.jl b/src/cpdag.jl index e8e1f861..70cf3c30 100644 --- a/src/cpdag.jl +++ b/src/cpdag.jl @@ -1,6 +1,6 @@ import Base: iterate, length -# replace by struct TODO +# replace by struct? function fac(n, fmemo) fmemo[n] != 0 && return fmemo[n] n == 1 && return BigInt(1)