From 06baa48df9d358aeb280e07d0592d5b9d3605062 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Fri, 10 Feb 2023 03:20:11 -0500 Subject: [PATCH 1/9] 1. Put intron reasoning back in BOSS 2. Do BES for BOSS only if better mutation improves the score. 3. Took out isLegalPag() calls in Edgewise comparison; this hangs the interface for large graphs. 4. Removed path statistics from the Stats List comparison, as they hang the interface for large graph. 5. Changing default for ZSB risk bound to 0.1. 6. For BOSS1 and BOSS2, added an epsilon to score comparisons (default 1e-10). --- docs/manual/index.html | 2 +- .../cmu/tetradapp/editor/StatsListEditor.java | 18 ++-- .../algcomparison/examples/TestBoss.java | 12 +-- .../main/java/edu/cmu/tetrad/search/Boss.java | 30 ++++-- .../cmu/tetrad/search/SearchGraphUtils.java | 95 ++++++++++--------- 5 files changed, 84 insertions(+), 73 deletions(-) diff --git a/docs/manual/index.html b/docs/manual/index.html index 6fa5dcabf4..9667fc1f6f 100755 --- a/docs/manual/index.html +++ b/docs/manual/index.html @@ -7246,7 +7246,7 @@

zSRiskBound

  • Long Description: This is the probability of getting the true model if a correct model is discovered. Could underfit.
  • -
  • Default Value: 0.001
  • +
  • Default Value: 0.1
  • Lower Bound: 0
  • Upper Bound: 1
  • Value Type: Double
  • diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/StatsListEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/StatsListEditor.java index 2c94694f20..e9570a3726 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/StatsListEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/StatsListEditor.java @@ -172,15 +172,15 @@ private List statistics() { // Greg table - statistics.add(new AncestorPrecision()); - statistics.add(new AncestorRecall()); - statistics.add(new AncestorF1()); - statistics.add(new SemidirectedPrecision()); - statistics.add(new SemidirectedRecall()); - statistics.add(new SemidirectedPathF1()); - statistics.add(new NoSemidirectedPrecision()); - statistics.add(new NoSemidirectedRecall()); - statistics.add(new NoSemidirectedF1()); +// statistics.add(new AncestorPrecision()); +// statistics.add(new AncestorRecall()); +// statistics.add(new AncestorF1()); +// statistics.add(new SemidirectedPrecision()); +// statistics.add(new SemidirectedRecall()); +// statistics.add(new SemidirectedPathF1()); +// statistics.add(new NoSemidirectedPrecision()); +// statistics.add(new NoSemidirectedRecall()); +// statistics.add(new NoSemidirectedF1()); // statistics.add(new LegalPag()); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/TestBoss.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/TestBoss.java index 192101af04..0ca19aa618 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/TestBoss.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/TestBoss.java @@ -44,10 +44,10 @@ public class TestBoss { public static void main(String... args) { Parameters parameters = new Parameters(); - parameters.set(Params.NUM_RUNS, 10); + parameters.set(Params.NUM_RUNS, 1); parameters.set(Params.DIFFERENT_GRAPHS, true); - parameters.set(Params.NUM_MEASURES, 60); - parameters.set(Params.AVG_DEGREE, 6); + parameters.set(Params.NUM_MEASURES, 400); + parameters.set(Params.AVG_DEGREE, 20); parameters.set(Params.SAMPLE_SIZE, 1000); parameters.set(Params.COEF_LOW, 0); parameters.set(Params.COEF_HIGH, 1); @@ -58,7 +58,7 @@ public static void main(String... args) { parameters.set(Params.SEM_BIC_STRUCTURE_PRIOR, 0); parameters.set(Params.ALPHA, 1e-2); - parameters.set("verbose", false); + parameters.set("verbose", true); Statistics statistics = new Statistics(); statistics.add(new AdjacencyPrecision()); @@ -68,9 +68,9 @@ public static void main(String... args) { statistics.add(new ElapsedCpuTime()); Algorithms algorithms = new Algorithms(); - algorithms.add(new Fges(new SemBicScore())); +// algorithms.add(new Fges(new SemBicScore())); algorithms.add(new BDCE(new SemBicScore())); - algorithms.add(new BOSSDC(new SemBicScore())); +// algorithms.add(new BOSSDC(new SemBicScore())); Simulations simulations = new Simulations(); simulations.add(new SemSimulation(new RandomForward())); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java index 8f2e309f32..61555b516b 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java @@ -36,6 +36,7 @@ public class Boss { private int numStarts = 1; private AlgType algType = AlgType.BOSS1; private boolean caching = true; + private double epsilon = 1e-10; public Boss(@NotNull IndependenceTest test, Score score) { this.test = test; @@ -103,17 +104,26 @@ public List bestOrder(@NotNull List order) { if (algType == AlgType.BOSS1) { betterMutation1(scorer); - besMutation(scorer); + + if (scorer.score() > s1 + epsilon) { + besMutation(scorer); + } } else if (algType == AlgType.BOSS2) { betterMutation2(scorer); - besMutation(scorer); + + if (scorer.score() > s1 + epsilon) { + besMutation(scorer); + } } else if (algType == AlgType.BOSS3) { betterMutationBryan(scorer); - besMutation(scorer); + + if (scorer.score() > s1 + epsilon) { + besMutation(scorer); + } } s2 = scorer.score(); - } while (s2 > s1 || (++count <= 5)); + } while (s2 > s1 + epsilon || (ensureMinimumCount && ++count <= 5)); if (this.scorer.score() > best) { best = this.scorer.score(); @@ -210,14 +220,14 @@ public void betterMutation1(@NotNull TeyssierScorer scorer) { for (int i = 1; i < scorer.size(); i++) { Node x = scorer.get(i); -// if (!introns1.contains(x)) continue; + if (!introns1.contains(x)) continue; for (int j = i - 1; j >= 0; j--) { if (!scorer.adjacent(scorer.get(j), x)) continue; tuck(x, j, scorer, range); - if (scorer.score() > bestScore || violatesKnowledge(scorer.getPi())) { + if (scorer.score() > bestScore + epsilon || violatesKnowledge(scorer.getPi())) { for (int l = range[0]; l <= range[1]; l++) { introns2.add(scorer.get(l)); } @@ -236,7 +246,7 @@ public void betterMutation1(@NotNull TeyssierScorer scorer) { if (verbose) { System.out.println(); } - } while (bestScore > originalScore); + } while (bestScore > originalScore + epsilon); } @@ -261,12 +271,12 @@ public void betterMutation2(@NotNull TeyssierScorer scorer) { double _sp = NEGATIVE_INFINITY; scorer.bookmark(); -// if (!introns1.contains(k)) continue; + if (!introns1.contains(k)) continue; for (int j = 0; j < scorer.size(); j++) { scorer.moveTo(k, j); - if (scorer.score() >= _sp) { + if (scorer.score() >= _sp + epsilon) { if (!violatesKnowledge(scorer.getPi())) { _sp = scorer.score(); scorer.bookmark(); @@ -292,7 +302,7 @@ public void betterMutation2(@NotNull TeyssierScorer scorer) { } s2 = scorer.score(); - } while (s2 > s1); + } while (s2 > s1 + epsilon); scorer.goToBookmark(1); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SearchGraphUtils.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SearchGraphUtils.java index dcb12b600a..3efff08019 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SearchGraphUtils.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SearchGraphUtils.java @@ -1818,56 +1818,57 @@ public static String graphComparisonString(String trueGraphName, Graph trueGraph if (!edge1.equals(edge2)) { incorrect.add(adj); - if (SearchGraphUtils.isLegalPag(trueGraph).isLegalPag() && SearchGraphUtils.isLegalPag(targetGraph).isLegalPag()) { - GraphUtils.addPagColoring(trueGraph); - GraphUtils.addPagColoring(targetGraph); - - if (edge2 == null) continue; - - if (GraphUtils.compatible(edge1, edge2)) { - compatible.add(edge1); - } else { - incompatible.add(edge1); - } - } +// if (SearchGraphUtils.isLegalPag(trueGraph).isLegalPag() && SearchGraphUtils.isLegalPag(targetGraph).isLegalPag()) { +// GraphUtils.addPagColoring(trueGraph); +// GraphUtils.addPagColoring(targetGraph); +// +// if (edge2 == null) continue; +// +// if (GraphUtils.compatible(edge1, edge2)) { +// compatible.add(edge1); +// } else { +// incompatible.add(edge1); +// } +// } } } - if (SearchGraphUtils.isLegalPag(trueGraph).isLegalPag() && SearchGraphUtils.isLegalPag(targetGraph).isLegalPag()) { - builder.append("\n\n" + "Edges incorrectly oriented (incompatible)"); - - if (incompatible.isEmpty()) { - builder.append("\n --NONE--"); - } else { - sort(incompatible); - - int j1 = 0; - - for (Edge adj : incompatible) { - Edge edge1 = trueGraph.getEdge(adj.getNode1(), adj.getNode2()); - Edge edge2 = targetGraph.getEdge(adj.getNode1(), adj.getNode2()); - if (edge1 == null || edge2 == null) continue; - builder.append("\n").append(++j1).append(". ").append(edge1).append(" ====> ").append(edge2); - } - } - - builder.append("\n\n" + "Edges incorrectly oriented (compatible)"); - - sort(compatible); - - if (compatible.isEmpty()) { - builder.append("\n --NONE--"); - } else { - int j1 = 0; - - for (Edge adj : compatible) { - Edge edge1 = trueGraph.getEdge(adj.getNode1(), adj.getNode2()); - Edge edge2 = targetGraph.getEdge(adj.getNode1(), adj.getNode2()); - if (edge1 == null || edge2 == null) continue; - builder.append("\n").append(++j1).append(". ").append(edge1).append(" ====> ").append(edge2); - } - } - } else { +// if (SearchGraphUtils.isLegalPag(trueGraph).isLegalPag() && SearchGraphUtils.isLegalPag(targetGraph).isLegalPag()) { +// builder.append("\n\n" + "Edges incorrectly oriented (incompatible)"); +// +// if (incompatible.isEmpty()) { +// builder.append("\n --NONE--"); +// } else { +// sort(incompatible); +// +// int j1 = 0; +// +// for (Edge adj : incompatible) { +// Edge edge1 = trueGraph.getEdge(adj.getNode1(), adj.getNode2()); +// Edge edge2 = targetGraph.getEdge(adj.getNode1(), adj.getNode2()); +// if (edge1 == null || edge2 == null) continue; +// builder.append("\n").append(++j1).append(". ").append(edge1).append(" ====> ").append(edge2); +// } +// } +// +// builder.append("\n\n" + "Edges incorrectly oriented (compatible)"); +// +// sort(compatible); +// +// if (compatible.isEmpty()) { +// builder.append("\n --NONE--"); +// } else { +// int j1 = 0; +// +// for (Edge adj : compatible) { +// Edge edge1 = trueGraph.getEdge(adj.getNode1(), adj.getNode2()); +// Edge edge2 = targetGraph.getEdge(adj.getNode1(), adj.getNode2()); +// if (edge1 == null || edge2 == null) continue; +// builder.append("\n").append(++j1).append(". ").append(edge1).append(" ====> ").append(edge2); +// } +// } +// } else + { builder.append("\n\n" + "Edges incorrectly oriented"); if (incorrect.isEmpty()) { From b1409c05be0b28ab44cb8e10ce73096715bdb7ea Mon Sep 17 00:00:00 2001 From: jdramsey Date: Fri, 10 Feb 2023 08:27:40 -0500 Subject: [PATCH 2/9] Removing intron reasoning from BOSS2, as it's not doing anything. --- .../main/java/edu/cmu/tetrad/search/Boss.java | 23 ++----------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java index 61555b516b..b20fadea4e 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java @@ -255,24 +255,14 @@ public void betterMutation2(@NotNull TeyssierScorer scorer) { scorer.bookmark(); double s1, s2; - Set introns1; - Set introns2; - - introns2 = new HashSet<>(scorer.getPi()); - do { s1 = scorer.score(); scorer.bookmark(1); - introns1 = introns2; - introns2 = new HashSet<>(); - for (Node k : scorer.getPi()) { double _sp = NEGATIVE_INFINITY; scorer.bookmark(); - if (!introns1.contains(k)) continue; - for (int j = 0; j < scorer.size(); j++) { scorer.moveTo(k, j); @@ -280,27 +270,18 @@ public void betterMutation2(@NotNull TeyssierScorer scorer) { if (!violatesKnowledge(scorer.getPi())) { _sp = scorer.score(); scorer.bookmark(); - - if (scorer.index(k) <= j) { - for (int m = scorer.index(k); m <= j; m++) { - introns2.add(scorer.get(m)); - } - } else if (scorer.index(k) > j) { - for (int m = j; m <= scorer.index(k); m++) { - introns2.add(scorer.get(m)); - } - } } } if (verbose) { - System.out.print("\rIndex = " + (j + 1) + " Score = " + scorer.score() + " (betterMutation2)" + " Elapsed " + ((MillisecondTimes.timeMillis() - start) / 1000.0 + " s")); + System.out.print("\rIndex = " + (scorer.index(k) + 1) + " Score = " + scorer.score() + " (betterMutation2)" + " Elapsed " + ((MillisecondTimes.timeMillis() - start) / 1000.0 + " s")); } } scorer.goToBookmark(); } + System.out.println(); s2 = scorer.score(); } while (s2 > s1 + epsilon); From 26a9dedfc05b792e86e25c43985375a42579dc9a Mon Sep 17 00:00:00 2001 From: jdramsey Date: Fri, 10 Feb 2023 09:32:02 -0500 Subject: [PATCH 3/9] Removing intron reasoning from BOSS2, as it's not doing anything. --- .../src/main/java/edu/cmu/tetrad/graph/Paths.java | 2 +- .../src/main/java/edu/cmu/tetrad/search/Boss.java | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java index 14700e1dba..a21704b564 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java @@ -20,7 +20,7 @@ public Paths(Graph graph) { } /** - * Returns a valid causal order for either a DAG or a CPDAG. + * Returns a valid causal order for either a DAG or a CPDAG. (bryanandrews) * @param initialOrder Variables in the order will be kept as close to this * initial order as possible, either the forward order * or the reverse order, depending on the next parameter. diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java index b20fadea4e..0f6cda7907 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java @@ -90,7 +90,7 @@ public List bestOrder(@NotNull List order) { shuffle(order); } - this.start = MillisecondTimes.timeMillis(); + this.start = MillisecondTimes.timeMillis(); makeValidKnowledgeOrder(order); @@ -133,7 +133,7 @@ public List bestOrder(@NotNull List order) { this.scorer.score(bestPerm); - this.stop = MillisecondTimes.timeMillis(); + this.stop = MillisecondTimes.timeMillis(); if (this.verbose) { TetradLogger.getInstance().forceLogMessage("\nFinal " + algType + " order = " + this.scorer.getPi()); @@ -250,7 +250,6 @@ public void betterMutation1(@NotNull TeyssierScorer scorer) { } - public void betterMutation2(@NotNull TeyssierScorer scorer) { scorer.bookmark(); double s1, s2; @@ -274,14 +273,17 @@ public void betterMutation2(@NotNull TeyssierScorer scorer) { } if (verbose) { - System.out.print("\rIndex = " + (scorer.index(k) + 1) + " Score = " + scorer.score() + " (betterMutation2)" + " Elapsed " + ((MillisecondTimes.timeMillis() - start) / 1000.0 + " s")); + TetradLogger.getInstance().forceLogMessage("\rIndex = " + (scorer.index(k) + 1) + " Score = " + scorer.score() + " (betterMutation2)" + " Elapsed " + ((MillisecondTimes.timeMillis() - start) / 1000.0 + " s")); } } scorer.goToBookmark(); } - System.out.println(); + if (verbose) { + TetradLogger.getInstance().forceLogMessage("\n"); + } + s2 = scorer.score(); } while (s2 > s1 + epsilon); From a85958fc11b0bbeb2b97ff591c6e7f9081d823e2 Mon Sep 17 00:00:00 2001 From: bja43 Date: Fri, 10 Feb 2023 13:03:02 -0600 Subject: [PATCH 4/9] Updated BDC methods --- .../algorithm/oracle/cpdag/BOSSDC.java | 2 + .../java/edu/cmu/tetrad/search/BossDC.java | 64 +++++++++++++++++-- 2 files changed, 59 insertions(+), 7 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/BOSSDC.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/BOSSDC.java index 8e6bfc9688..6355515eea 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/BOSSDC.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/BOSSDC.java @@ -54,6 +54,8 @@ public Graph search(DataModel dataModel, Parameters parameters) { boss.setAlgType(Boss.AlgType.BOSS1); } else if (parameters.getInt(Params.BOSS_ALG) == 2) { boss.setAlgType(Boss.AlgType.BOSS2); + } else if (parameters.getInt(Params.BOSS_ALG) == 3) { + boss.setAlgType(Boss.AlgType.BOSS3); } else { throw new IllegalArgumentException("Unrecognized boss algorithm type."); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BossDC.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BossDC.java index 627665cd82..f9f3c7d1e7 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BossDC.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BossDC.java @@ -97,37 +97,41 @@ public void divide(@NotNull TeyssierScorer scorer, int a, int b, int c) { conquerRTL(scorer, a, b, c); } else if (algType == Boss.AlgType.BOSS2){ conquerLTR(scorer, a, b, c); + } else if (algType == Boss.AlgType.BOSS3){ + conquerMT(scorer, a, b, c); } } public void conquerRTL(@NotNull TeyssierScorer scorer, int a, int b, int c) { - double currentScore = scorer.score();; + double currentScore = scorer.score(); double bestScore = currentScore; scorer.bookmark(); for (int i = b; i < c; i++) { Node x = scorer.get(i); + Set ancestors = scorer.getAncestors(x); for (int j = (b-1); j >= a; j--) { if (!scorer.adjacent(scorer.get(j), x)) continue; - tuck(x, j, scorer); + tuck(x, j, scorer, ancestors); currentScore = scorer.score(); - if (currentScore > bestScore) { + if (currentScore > bestScore + 1e-10) { bestScore = currentScore; + ancestors = scorer.getAncestors(x); scorer.bookmark(); } } - if (currentScore < bestScore) { + if (currentScore < bestScore + 1e-10) { scorer.goToBookmark(); } } } public void conquerLTR(@NotNull TeyssierScorer scorer, int a, int b, int c) { - double currentScore = scorer.score();; + double currentScore = scorer.score(); double bestScore = currentScore; scorer.bookmark(); @@ -140,7 +144,7 @@ public void conquerLTR(@NotNull TeyssierScorer scorer, int a, int b, int c) { tuck(x, j, scorer); currentScore = scorer.score(); - if (currentScore > bestScore){ + if (currentScore > bestScore + 1e-10){ bestScore = currentScore; scorer.bookmark(); break; @@ -151,6 +155,42 @@ public void conquerLTR(@NotNull TeyssierScorer scorer, int a, int b, int c) { } } + public void conquerMT(@NotNull TeyssierScorer scorer, int a, int b, int c) { + double currentScore = scorer.score(); + double bestScore = currentScore; + scorer.bookmark(); + + for (int i = a; i < b; i++) { + Node x = scorer.get(i); + for (int j = (c-1); j >= b; j--) { + scorer.moveTo(x, j); + currentScore = scorer.score(); + if (currentScore > bestScore + 1e-10) { + bestScore = currentScore; + scorer.bookmark(); + } + } + if (currentScore < bestScore + 1e-10) { + scorer.goToBookmark(); + } + } + + for (int i = b; i < c; i++) { + Node x = scorer.get(i); + for (int j = (b-1); j >= a; j--) { + scorer.moveTo(x, j); + currentScore = scorer.score(); + if (currentScore > bestScore + 1e-10) { + bestScore = currentScore; + scorer.bookmark(); + } + } + if (currentScore < bestScore + 1e-10) { + scorer.goToBookmark(); + } + } + } + private void tuck(Node k, int j, TeyssierScorer scorer) { if (scorer.index(k) < j) return; Set ancestors = scorer.getAncestors(k); @@ -162,6 +202,16 @@ private void tuck(Node k, int j, TeyssierScorer scorer) { } } + private void tuck(Node k, int j, TeyssierScorer scorer, Set ancestors) { + if (scorer.index(k) < j) return; + + for (int i = j + 1; i <= scorer.index(k); i++) { + if (ancestors.contains(scorer.get(i))) { + scorer.moveTo(scorer.get(i), j++); + } + } + } + public void besMutation(TeyssierScorer scorer) { Graph graph = scorer.getGraph(true); Bes bes = new Bes(score); @@ -218,5 +268,5 @@ public void setCaching(boolean caching) { this.caching = caching; } - public enum AlgType {BOSS1, BOSS2} + public enum AlgType {BOSS1, BOSS2, BOSS3} } \ No newline at end of file From c9c170cbb8d2a952dc420cc0dc497755132f911d Mon Sep 17 00:00:00 2001 From: bja43 Date: Sat, 11 Feb 2023 01:42:02 -0600 Subject: [PATCH 5/9] Adding GST (but not applying it to any of the algorithms) --- .../edu/cmu/tetrad/search/GrowShrinkTree.java | 142 ++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 tetrad-lib/src/main/java/edu/cmu/tetrad/search/GrowShrinkTree.java diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GrowShrinkTree.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GrowShrinkTree.java new file mode 100644 index 0000000000..f5a0b73552 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GrowShrinkTree.java @@ -0,0 +1,142 @@ +package edu.cmu.tetrad.search; + +import edu.cmu.tetrad.graph.Node; +import org.jetbrains.annotations.NotNull; + +import java.util.*; + +public class GrowShrinkTree { + private static Score score; + private static List variables; + private final Map roots; + + public GrowShrinkTree(Score score) { + GrowShrinkTree.score = score; + GrowShrinkTree.variables = score.getVariables(); + this.roots = new HashMap<>(); + for (Node node : GrowShrinkTree.variables) { + this.roots.put(node, new GSTNode(node)); + } + } + + public double GrowShrink(Node node, Set prefix, LinkedHashSet parents) { + return this.roots.get(node).GrowShrink(node, prefix, parents); + } + + private static class GSTNode implements Comparable { + private final Node add; + private boolean grow; + private boolean shrink; + private final double growScore; + private double shrinkScore; + private List branches; + private Set remove; + + private GSTNode(Node node) { + this.add = null; + this.grow = false; + this.shrink = false; + + int y = GrowShrinkTree.variables.indexOf(node); + this.growScore = GrowShrinkTree.score.localScore(y); + } + + private GSTNode(Node node, Node add, Set parents) { + this.add = add; + this.grow = false; + this.shrink = false; + + int y = GrowShrinkTree.variables.indexOf(node); + int[] X = new int[parents.size() + 1]; + + int i = 0; + for (Node parent : parents) X[i++] = GrowShrinkTree.variables.indexOf(parent); + X[i] = GrowShrinkTree.variables.indexOf(add); + + this.growScore = GrowShrinkTree.score.localScore(y, X); + } + + public double GrowShrink(Node node, Set prefix, LinkedHashSet parents) { + + if (!this.grow) { + this.grow = true; + this.branches = new ArrayList<>(); + + for (Node add : GrowShrinkTree.variables) { + if (parents.contains(add) || add == node) continue; + GSTNode branch = new GSTNode(node, add, parents); + if (this.compareTo(branch) < 0) this.branches.add(branch); + } + this.branches.sort(Collections.reverseOrder()); + } + + for (GSTNode branch : this.branches) { + Node add = branch.getAdd(); + if (prefix.contains(add)) { + prefix.remove(add); + parents.add(add); + return branch.GrowShrink(node, prefix, parents); + } + } + + if (!this.shrink) { + this.shrink = true; + this.remove = new HashSet<>(); + this.shrinkScore = this.growScore; + + if (parents.isEmpty()) return this.shrinkScore; + + int y = GrowShrinkTree.variables.indexOf(node); + int[] X = new int[parents.size() - 1]; + + int i = 0; + Iterator itr = parents.iterator(); + itr.next(); + while (itr.hasNext()) X[i++] = GrowShrinkTree.variables.indexOf(itr.next()); + Node best; + + do { + i = 0; + itr = parents.iterator(); + Node remove = itr.next(); + best = null; + + do { + double s = GrowShrinkTree.score.localScore(y, X); + if (s > this.shrinkScore) { + this.shrinkScore = s; + best = remove; + } + + if (i < parents.size() - 1) { + remove = itr.next(); + X[i++] = GrowShrinkTree.variables.indexOf(remove); + } + } while (i < parents.size() - 1); + + if (best != null) { + parents.remove(best); + this.remove.add(best); + } + + } while (best != null); + } + parents.removeAll(this.remove); + return this.shrinkScore; + } + + public Node getAdd() { + return this.add; + } + + public double getGrowScore() { + return this.growScore; + } + + @Override + public int compareTo(@NotNull GrowShrinkTree.GSTNode branch) { + return Double.compare(this.growScore, branch.getGrowScore()); + } + } +} + From aae7a4959980197dcc5ad8159b6af0b2a39a4613 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Sat, 11 Feb 2023 12:23:33 -0500 Subject: [PATCH 6/9] Adjustment to Boss2 --- .../src/main/java/edu/cmu/tetrad/search/Boss.java | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java index 0f6cda7907..e5831f596f 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java @@ -269,19 +269,20 @@ public void betterMutation2(@NotNull TeyssierScorer scorer) { if (!violatesKnowledge(scorer.getPi())) { _sp = scorer.score(); scorer.bookmark(); + + if (verbose) { + System.out.print("\rIndex = " + (scorer.index(k) + 1) + " Score = " + scorer.score() + " (betterMutation2)" + " Elapsed " + ((MillisecondTimes.timeMillis() - start) / 1000.0 + " s")); + } } - } + } - if (verbose) { - TetradLogger.getInstance().forceLogMessage("\rIndex = " + (scorer.index(k) + 1) + " Score = " + scorer.score() + " (betterMutation2)" + " Elapsed " + ((MillisecondTimes.timeMillis() - start) / 1000.0 + " s")); - } } scorer.goToBookmark(); } if (verbose) { - TetradLogger.getInstance().forceLogMessage("\n"); + System.out.println(); } s2 = scorer.score(); From 4b8eb1a049eb13e5a6b24b4121f4d55bbfa4bb04 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Sat, 11 Feb 2023 15:14:19 -0500 Subject: [PATCH 7/9] Adjustment to Boss2 --- .../edu/cmu/tetrad/algcomparison/examples/TestBoss.java | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/TestBoss.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/TestBoss.java index 0ca19aa618..be54579c92 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/TestBoss.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/examples/TestBoss.java @@ -30,6 +30,7 @@ import edu.cmu.tetrad.algcomparison.graph.RandomForward; import edu.cmu.tetrad.algcomparison.independence.FisherZ; import edu.cmu.tetrad.algcomparison.score.SemBicScore; +import edu.cmu.tetrad.algcomparison.score.ZhangShenBoundScore; import edu.cmu.tetrad.algcomparison.simulation.SemSimulation; import edu.cmu.tetrad.algcomparison.simulation.Simulations; import edu.cmu.tetrad.algcomparison.statistic.*; @@ -46,7 +47,7 @@ public static void main(String... args) { Parameters parameters = new Parameters(); parameters.set(Params.NUM_RUNS, 1); parameters.set(Params.DIFFERENT_GRAPHS, true); - parameters.set(Params.NUM_MEASURES, 400); + parameters.set(Params.NUM_MEASURES, 100); parameters.set(Params.AVG_DEGREE, 20); parameters.set(Params.SAMPLE_SIZE, 1000); parameters.set(Params.COEF_LOW, 0); @@ -58,7 +59,7 @@ public static void main(String... args) { parameters.set(Params.SEM_BIC_STRUCTURE_PRIOR, 0); parameters.set(Params.ALPHA, 1e-2); - parameters.set("verbose", true); + parameters.set(Params.VERBOSE, true); Statistics statistics = new Statistics(); statistics.add(new AdjacencyPrecision()); @@ -69,7 +70,7 @@ public static void main(String... args) { Algorithms algorithms = new Algorithms(); // algorithms.add(new Fges(new SemBicScore())); - algorithms.add(new BDCE(new SemBicScore())); + algorithms.add(new BOSS(new FisherZ(), new ZhangShenBoundScore())); // algorithms.add(new BOSSDC(new SemBicScore())); Simulations simulations = new Simulations(); @@ -81,7 +82,7 @@ public static void main(String... args) { comparison.setShowSimulationIndices(true); comparison.setSortByUtility(false); comparison.setShowUtilities(false); - comparison.setParallelized(true); + comparison.setParallelized(false); comparison.setComparisonGraph(Comparison.ComparisonGraph.CPDAG_of_the_true_DAG); From 1ccc21e7ff5014511b0c87a83f313be47097a350 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Sat, 11 Feb 2023 23:20:32 -0500 Subject: [PATCH 8/9] Changed initialGraph to externalGraph for this branch too... --- .../edu/cmu/tetradapp/model/FgesRunner.java | 2 +- .../algorithm/multi/FgesConcatenated.java | 2 +- .../algorithm/oracle/cpdag/Fges.java | 8 +- .../algorithm/oracle/cpdag/PC.java | 8 +- .../annotation/AlgorithmAnnotations.java | 2 +- .../java/edu/cmu/tetrad/search/Bridges.java | 2 +- .../java/edu/cmu/tetrad/search/Bridges2.java | 2 +- .../edu/cmu/tetrad/search/BridgesOld.java | 2 +- .../main/java/edu/cmu/tetrad/search/Fges.java | 15 +- .../edu/cmu/tetrad/search/GrowShrinkTree.java | 38 +- .../java/edu/cmu/tetrad/search/PcAll.java | 6 +- .../edu/cmu/tetrad/search/TeyssierScorer.java | 193 +-- .../cmu/tetrad/search/TeyssierScorer3.java | 1384 +++++++++++++++++ 13 files changed, 1534 insertions(+), 130 deletions(-) create mode 100644 tetrad-lib/src/main/java/edu/cmu/tetrad/search/TeyssierScorer3.java diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgesRunner.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgesRunner.java index 6a4292352d..80cdf4c9b4 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgesRunner.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgesRunner.java @@ -200,7 +200,7 @@ public void execute() { } } - this.fges.setInitialGraph(this.externalGraph); + this.fges.setExternalGraph(this.externalGraph); this.fges.setKnowledge((Knowledge) getParams().get("knowledge", new Knowledge())); this.fges.setVerbose(true); Graph graph = this.fges.search(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/FgesConcatenated.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/FgesConcatenated.java index 8cc215ee31..008ef6d179 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/FgesConcatenated.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/FgesConcatenated.java @@ -75,7 +75,7 @@ public Graph search(List dataModels, Parameters parameters) { } if (initial != null) { - search.setInitialGraph(initial); + search.setExternalGraph(initial); } return search.search(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java index 563117a6bc..aa6866191c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java @@ -40,7 +40,7 @@ public class Fges implements Algorithm, HasKnowledge, UsesScoreWrapper { private ScoreWrapper score; private Knowledge knowledge = new Knowledge(); - private Graph initialGraph = null; + private Graph wxternalGraph = null; public Fges() { @@ -68,7 +68,7 @@ public Graph search(DataModel dataModel, Parameters parameters) { edu.cmu.tetrad.search.Fges search = new edu.cmu.tetrad.search.Fges(score); -// search.setInitialGraph(initialGraph); + search.setExternalGraph(wxternalGraph); search.setKnowledge(this.knowledge); search.setVerbose(parameters.getBoolean(Params.VERBOSE)); search.setMeekVerbose(parameters.getBoolean(Params.MEEK_VERBOSE)); @@ -150,7 +150,7 @@ public void setScoreWrapper(ScoreWrapper score) { this.score = score; } - public void setInitialGraph(Graph initialGraph) { - this.initialGraph = initialGraph; + public void setExternalGraph(Graph externalGraph) { + this.wxternalGraph = externalGraph; } } \ No newline at end of file diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/PC.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/PC.java index 7dfe86b5dc..c8e5a8def9 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/PC.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/PC.java @@ -39,7 +39,7 @@ public class PC implements Algorithm, HasKnowledge, TakesIndependenceWrapper { private IndependenceWrapper test; private Knowledge knowledge = new Knowledge(); - private Graph initialGraph = null; + private Graph externalGraph = null; public PC() { } @@ -102,7 +102,7 @@ public Graph search(DataModel dataModel, Parameters parameters) { search.setUseHeuristic(parameters.getBoolean(Params.USE_MAX_P_ORIENTATION_HEURISTIC)); search.setMaxPathLength(parameters.getInt(Params.MAX_P_ORIENTATION_MAX_PATH_LENGTH)); search.setMaxPathLength(parameters.getInt(Params.MAX_P_ORIENTATION_MAX_PATH_LENGTH)); - search.setInitialGraph(initialGraph); + search.setExternalGraph(externalGraph); search.setVerbose(parameters.getBoolean(Params.VERBOSE)); return search.search(); @@ -171,7 +171,7 @@ public void setIndependenceWrapper(IndependenceWrapper test) { this.test = test; } - public void setInitialGraph(Graph initialGraph) { - this.initialGraph = initialGraph; + public void setExternalGraph(Graph externalGraph) { + this.externalGraph = externalGraph; } } \ No newline at end of file diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/annotation/AlgorithmAnnotations.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/annotation/AlgorithmAnnotations.java index 43ad2cee49..7a5d667ff0 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/annotation/AlgorithmAnnotations.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/annotation/AlgorithmAnnotations.java @@ -54,7 +54,7 @@ public boolean takesKnowledge(Class clazz) { return clazz != null && HasKnowledge.class.isAssignableFrom(clazz); } - public boolean takesInitialGraph(Class clazz) { + public boolean takesExternalGraph(Class clazz) { return clazz != null && TakesExternalGraph.class.isAssignableFrom(clazz); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Bridges.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Bridges.java index 057fd8d9e7..b8a576b0d4 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Bridges.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Bridges.java @@ -90,7 +90,7 @@ public final class Bridges implements GraphSearch, GraphScorer { /** * An initial graph to start from. */ - private Graph initialGraph; + private Graph externalGraph; /** * If non-null, edges not adjacent in this graph will not be added. */ diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Bridges2.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Bridges2.java index caf1f2922f..53152b8932 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Bridges2.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Bridges2.java @@ -138,7 +138,7 @@ public Graph search() { g.removeEdge(edge); g.addEdge(reversed); - fges.setInitialGraph(g); + fges.setExternalGraph(g); Graph g1 = fges.search(); double s1 = fges.getModelScore(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BridgesOld.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BridgesOld.java index 7c57109b09..dbc1629682 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BridgesOld.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BridgesOld.java @@ -66,7 +66,7 @@ public Graph search() { meeks.orientImplied(g); - ges.setInitialGraph(g); + ges.setExternalGraph(g); Graph g1 = ges.search(); double s1 = ges.getModelScore(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java index b07a6d3706..a2af42d835 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java @@ -89,7 +89,7 @@ public final class Fges implements GraphSearch, GraphScorer { /** * An initial graph to start from. */ - private Graph initialGraph; + private Graph externalGraph; /** * If non-null, edges not adjacent in this graph will not be added. */ @@ -192,8 +192,8 @@ public Graph search() { boundGraph = GraphUtils.replaceNodes(boundGraph, getVariables()); } - if (initialGraph != null) { - graph = new EdgeListGraph(initialGraph); + if (externalGraph != null) { + graph = new EdgeListGraph(externalGraph); graph = GraphUtils.replaceNodes(graph, getVariables()); } @@ -268,7 +268,12 @@ public LinkedList getTopGraphs() { /** * Sets the initial graph. */ - public void setInitialGraph(Graph externalGraph) { + public void setExternalGraph(Graph externalGraph) { + if (externalGraph == null) { + this.externalGraph = null; + return; + } + externalGraph = GraphUtils.replaceNodes(externalGraph, variables); if (verbose) { @@ -280,7 +285,7 @@ public void setInitialGraph(Graph externalGraph) { throw new IllegalArgumentException("Variables aren't the same."); } - this.initialGraph = externalGraph; + this.externalGraph = externalGraph; } /** diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GrowShrinkTree.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GrowShrinkTree.java index f5a0b73552..ab56d1b1de 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GrowShrinkTree.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GrowShrinkTree.java @@ -7,14 +7,18 @@ public class GrowShrinkTree { private static Score score; - private static List variables; + private static HashMap index; private final Map roots; public GrowShrinkTree(Score score) { GrowShrinkTree.score = score; - GrowShrinkTree.variables = score.getVariables(); + GrowShrinkTree.index = new HashMap<>(); + + int i = 0; + for (Node node : score.getVariables()) GrowShrinkTree.index.put(node, i++); + this.roots = new HashMap<>(); - for (Node node : GrowShrinkTree.variables) { + for (Node node : GrowShrinkTree.score.getVariables()) { this.roots.put(node, new GSTNode(node)); } } @@ -37,7 +41,7 @@ private GSTNode(Node node) { this.grow = false; this.shrink = false; - int y = GrowShrinkTree.variables.indexOf(node); + int y = GrowShrinkTree.index.get(node); this.growScore = GrowShrinkTree.score.localScore(y); } @@ -46,12 +50,12 @@ private GSTNode(Node node, Node add, Set parents) { this.grow = false; this.shrink = false; - int y = GrowShrinkTree.variables.indexOf(node); + int y = GrowShrinkTree.index.get(node); int[] X = new int[parents.size() + 1]; int i = 0; - for (Node parent : parents) X[i++] = GrowShrinkTree.variables.indexOf(parent); - X[i] = GrowShrinkTree.variables.indexOf(add); + for (Node parent : parents) X[i++] = GrowShrinkTree.index.get(parent); + X[i] = GrowShrinkTree.index.get(add); this.growScore = GrowShrinkTree.score.localScore(y, X); } @@ -62,7 +66,7 @@ public double GrowShrink(Node node, Set prefix, LinkedHashSet parent this.grow = true; this.branches = new ArrayList<>(); - for (Node add : GrowShrinkTree.variables) { + for (Node add : GrowShrinkTree.score.getVariables()) { if (parents.contains(add) || add == node) continue; GSTNode branch = new GSTNode(node, add, parents); if (this.compareTo(branch) < 0) this.branches.add(branch); @@ -86,17 +90,17 @@ public double GrowShrink(Node node, Set prefix, LinkedHashSet parent if (parents.isEmpty()) return this.shrinkScore; - int y = GrowShrinkTree.variables.indexOf(node); - int[] X = new int[parents.size() - 1]; - - int i = 0; - Iterator itr = parents.iterator(); - itr.next(); - while (itr.hasNext()) X[i++] = GrowShrinkTree.variables.indexOf(itr.next()); + int y = GrowShrinkTree.index.get(node); Node best; do { - i = 0; + int[] X = new int[parents.size() - 1]; + + int i = 0; + Iterator itr = parents.iterator(); + itr.next(); + while (itr.hasNext()) X[i++] = GrowShrinkTree.index.get(itr.next()); + itr = parents.iterator(); Node remove = itr.next(); best = null; @@ -110,7 +114,7 @@ public double GrowShrink(Node node, Set prefix, LinkedHashSet parent if (i < parents.size() - 1) { remove = itr.next(); - X[i++] = GrowShrinkTree.variables.indexOf(remove); + X[i++] = GrowShrinkTree.index.get(remove); } } while (i < parents.size() - 1); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/PcAll.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/PcAll.java index 979a94fcca..90d82a6c2c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/PcAll.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/PcAll.java @@ -95,7 +95,7 @@ public final class PcAll implements GraphSearch { private Concurrent concurrent = Concurrent.YES; private ColliderDiscovery colliderDiscovery = ColliderDiscovery.FAS_SEPSETS; private ConflictRule conflictRule = ConflictRule.OVERWRITE; - private Graph initialGraph = null; + private Graph externalGraph = null; /** * Constructs a CPC algorithm that uses the given independence test as oracle. This does not make a copy of the @@ -597,8 +597,8 @@ private void orientCollidersUsingSepsets(SepsetMap set, Knowledge knowledge, Gra } } - public void setInitialGraph(Graph initialGraph) { - this.initialGraph = initialGraph; + public void setExternalGraph(Graph externalGraph) { + this.externalGraph = externalGraph; } public enum FasType {REGULAR, STABLE} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TeyssierScorer.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TeyssierScorer.java index 346b5a6657..2f49c3dd3d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TeyssierScorer.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TeyssierScorer.java @@ -47,12 +47,16 @@ public class TeyssierScorer { private double runningScore = 0f; private Graph mag = null; + private GrowShrinkTree GST; + public TeyssierScorer(IndependenceTest test, Score score) { NodeEqualityMode.setEqualityMode(NodeEqualityMode.Type.OBJECT); this.score = score; this.test = test; + this.GST = new GrowShrinkTree(score); + if (score != null) { this.variables = score.getVariables(); this.pi = new ArrayList<>(this.variables); @@ -1102,120 +1106,127 @@ private boolean lastMoveSame(int i1, int i2) { private Pair getGrowShrinkScore(int p) { Node n = this.pi.get(p); - Set parents = new HashSet<>(); - boolean changed = true; - - double sMax = score(n, new HashSet<>()); - List prefix = new ArrayList<>(getPrefix(p)); - - // Backward scoring only from the prefix variables - if (this.useBackwardScoring) { - parents.addAll(prefix); - sMax = score(n, parents); - changed = false; - } - - // Grow-shrink - while (changed) { - changed = false; - - // Let z be the node that maximizes the score... - Node z = null; - - for (Node z0 : prefix) { - if (parents.contains(z0)) continue; - - if (!knowledge.isEmpty() && this.knowledge.isForbidden(z0.getName(), n.getName())) continue; - - parents.add(z0); - - if (!knowledge.isEmpty() && knowledge.isRequired(z0.getName(), n.getName())) continue; - - double s2 = score(n, parents); - - if (s2 >= sMax) { - sMax = s2; - z = z0; - } - - parents.remove(z0); - } - - if (z != null) { - parents.add(z); - changed = true; - } - - } - - boolean changed2 = true; - - while (changed2) { - changed2 = false; - - Node w = null; +// Set parents = new HashSet<>(); +// boolean changed = true; - for (Node z0 : new HashSet<>(parents)) { - if (!knowledge.isEmpty() && knowledge.isRequired(z0.getName(), n.getName())) continue; +// double sMax = score(n, new HashSet<>()); +// List prefix = new ArrayList<>(getPrefix(p)); - parents.remove(z0); + Set prefix = new HashSet<>(getPrefix(p)); + LinkedHashSet parents = new LinkedHashSet<>(); + double sMax = GST.GrowShrink(n, prefix, parents); - double s2 = score(n, parents); + return new Pair(parents, Double.isNaN(sMax) ? Double.NEGATIVE_INFINITY : sMax); - if (s2 > sMax) { - sMax = s2; - w = z0; - } - - parents.add(z0); - } - if (w != null) { - parents.remove(w); - changed2 = true; - } - } - -// while (changed2) { -// changed2 = false; +// // Backward scoring only from the prefix variables +// if (this.useBackwardScoring) { +// parents.addAll(prefix); +// sMax = score(n, parents); +// changed = false; +// } // -// List aaa = null; +// // Grow-shrink +// while (changed) { +// changed = false; // -// List pp = new ArrayList<>(parents); +// // Let z be the node that maximizes the score... +// Node z = null; // -// SublistGenerator gen = new SublistGenerator(parents.size(), 2); -// int[] choice; +// for (Node z0 : prefix) { +// if (parents.contains(z0)) continue; // -// while ((choice = gen.next()) != null) { -// List aa = GraphUtils.asList(choice, pp); +// if (!knowledge.isEmpty() && this.knowledge.isForbidden(z0.getName(), n.getName())) continue; // -//// for (Node z0 : new HashSet<>(parents)) { -//// if (!knowledge.isEmpty() && knowledge.isRequired(z0.getName(), n.getName())) continue; -//// -// aa.forEach(parents::remove); +// parents.add(z0); +// +// if (!knowledge.isEmpty() && knowledge.isRequired(z0.getName(), n.getName())) continue; +// +// double s2 = score(n, parents); +// +// if (s2 >= sMax) { +// sMax = s2; +// z = z0; +// } +// +// parents.remove(z0); +// } +// +// if (z != null) { +// parents.add(z); +// changed = true; +// } +// +// } +// +// boolean changed2 = true; +// +// while (changed2) { +// changed2 = false; +// +// Node w = null; +// +// for (Node z0 : new HashSet<>(parents)) { +// if (!knowledge.isEmpty() && knowledge.isRequired(z0.getName(), n.getName())) continue; +// +// parents.remove(z0); // // double s2 = score(n, parents); // // if (s2 > sMax) { // sMax = s2; -// aaa = aa; +// w = z0; // } // -// parents.addAll(aa); +// parents.add(z0); // } // -// if (aaa != null) { -// aaa.forEach(parents::remove); +// if (w != null) { +// parents.remove(w); // changed2 = true; // } +// } +// +//// while (changed2) { +//// changed2 = false; +//// +//// List aaa = null; +//// +//// List pp = new ArrayList<>(parents); +//// +//// SublistGenerator gen = new SublistGenerator(parents.size(), 2); +//// int[] choice; +//// +//// while ((choice = gen.next()) != null) { +//// List aa = GraphUtils.asList(choice, pp); +//// +////// for (Node z0 : new HashSet<>(parents)) { +////// if (!knowledge.isEmpty() && knowledge.isRequired(z0.getName(), n.getName())) continue; +////// +//// aa.forEach(parents::remove); +//// +//// double s2 = score(n, parents); +//// +//// if (s2 > sMax) { +//// sMax = s2; +//// aaa = aa; +//// } +//// +//// parents.addAll(aa); +//// } +//// +//// if (aaa != null) { +//// aaa.forEach(parents::remove); +//// changed2 = true; +//// } +//// +//// } // +// if (this.useScore) { +// return new Pair(parents, Double.isNaN(sMax) ? Double.NEGATIVE_INFINITY : sMax); +// } else { +// return new Pair(parents, -parents.size()); // } - - if (this.useScore) { - return new Pair(parents, Double.isNaN(sMax) ? Double.NEGATIVE_INFINITY : sMax); - } else { - return new Pair(parents, -parents.size()); - } } private Pair getGrowShrinkIndependent(int p) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TeyssierScorer3.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TeyssierScorer3.java new file mode 100644 index 0000000000..b6c6bd25c1 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/TeyssierScorer3.java @@ -0,0 +1,1384 @@ +package edu.cmu.tetrad.search; + +import edu.cmu.tetrad.data.Knowledge; +import edu.cmu.tetrad.data.KnowledgeEdge; +import edu.cmu.tetrad.graph.*; +import org.jetbrains.annotations.NotNull; + +import java.util.*; +import java.util.concurrent.Callable; + +import static edu.cmu.tetrad.util.RandomUtil.shuffle; +import static java.util.Collections.sort; +import static org.apache.commons.math3.util.FastMath.floor; + + +/** + * Implements a scorer extending Teyssier, M., and Koller, D. (2012). Ordering-based search: A simple and effective + * algorithm for learning Bayesian networks. arXiv preprint arXiv:1207.1429. You give it a score function + * and a variable ordering, and it computes the score. You can move any variable left or right, and it will + * keep track of the score using the Teyssier and Kohler method. You can move a variable to a new position, + * and you can bookmark a state and come back to it. + * + * @author josephramsey + * @author bryanandrews + */ +public class TeyssierScorer3 { + private final List variables; + private final Map variablesHash; + private final Score score; + private final IndependenceTest test; + private int maxIndegree; + private Map> bookmarkedOrders = new HashMap<>(); + private Map> bookmarkedScores = new HashMap<>(); + private Map> bookmarkedOrderHashes = new HashMap<>(); + private Map bookmarkedRunningScores = new HashMap<>(); + private Map, Double>> cache = new HashMap<>(); + private Map orderHash; + private ArrayList pi; // The current permutation. + private ArrayList scores; + private Knowledge knowledge = new Knowledge(); + private ArrayList> prefixes; + + private boolean useScore = true; + private boolean useRaskuttiUhler; + private boolean useBackwardScoring; + private boolean cachingScores = true; + private double runningScore = 0f; + private Graph mag = null; + + private GrowShrinkTree GST; + + public TeyssierScorer3(IndependenceTest test, Score score) { + NodeEqualityMode.setEqualityMode(NodeEqualityMode.Type.OBJECT); + + this.score = score; + this.test = test; + + this.GST = new GrowShrinkTree(score); + + if (score != null) { + this.variables = score.getVariables(); + this.pi = new ArrayList<>(this.variables); + } else if (test != null) { + this.variables = test.getVariables(); + this.pi = new ArrayList<>(this.variables); + } else { + throw new IllegalArgumentException("Need both a score and a test,"); + } + + this.orderHash = new HashMap<>(); + nodesHash(this.orderHash, this.pi); + + this.variablesHash = new HashMap<>(); + nodesHash(this.variablesHash, this.variables); + + if (score instanceof GraphScore) { + this.useScore = false; + } + } + + public TeyssierScorer3(TeyssierScorer3 scorer) { + this.variables = new ArrayList<>(scorer.variables); + this.variablesHash = new HashMap<>(); + + for (Node key : scorer.variablesHash.keySet()) { + this.variablesHash.put(key, scorer.variablesHash.get(key)); + } + + this.score = scorer.score; + this.test = scorer.test; + + this.bookmarkedOrders = new HashMap<>(); + + for (Object key : scorer.bookmarkedOrders.keySet()) { + this.bookmarkedOrders.put(key, scorer.bookmarkedOrders.get(key)); + } + + this.bookmarkedScores = new HashMap<>(); + + for (Object key : scorer.bookmarkedScores.keySet()) { + this.bookmarkedScores.put(key, new ArrayList<>(scorer.bookmarkedScores.get(key))); + } + + this.bookmarkedOrderHashes = new HashMap<>(); + + for (Object key : scorer.bookmarkedOrderHashes.keySet()) { + this.bookmarkedOrderHashes.put(key, new HashMap<>(scorer.bookmarkedOrderHashes.get(key))); + } + + this.bookmarkedRunningScores = new HashMap<>(scorer.bookmarkedRunningScores); + + this.orderHash = new HashMap<>(scorer.orderHash); + + this.pi = new ArrayList<>(scorer.pi); + + this.scores = new ArrayList<>(scorer.scores); + this.knowledge = scorer.knowledge; + this.useScore = scorer.useScore; + this.useRaskuttiUhler = scorer.useRaskuttiUhler; + this.useBackwardScoring = scorer.useBackwardScoring; + this.cachingScores = scorer.cachingScores; + this.runningScore = scorer.runningScore; + this.maxIndegree = scorer.maxIndegree; + + this.prefixes = new ArrayList<>(scorer.prefixes); + } + + public void moveToEnd(Node z) { + moveTo(z, size() - 1); + } + + /** + * @param useScore True if the score should be used; false if the test should be used. + */ + public void setUseScore(boolean useScore) { + if (!(this.score instanceof GraphScore)) { + this.useScore = useScore; + } + } + + /** + * @param cachingScores True if scores should be cached (potentially expensive for memory); + * false if not (potentially expensive for time). + */ + public void setCachingScores(boolean cachingScores) { + this.cachingScores = cachingScores; + } + + /** + * @param knowledge Knowledge of forbidden edges. + */ + public void setKnowledge(Knowledge knowledge) { + this.knowledge = knowledge; + } + + /** + * @param useRaskuttiUhler True if Pearl's method for building a DAG should be used. + */ + public void setUseRaskuttiUhler(boolean useRaskuttiUhler) { + this.useRaskuttiUhler = useRaskuttiUhler; + this.useScore = false; + } + + public void setUseBackwardScoring(boolean useBackwardScoring) { + this.useBackwardScoring = useBackwardScoring; + } + + public double score(List order, Graph mag) { + this.mag = mag; + return score(order); + } + + /** + * Scores the given permutation. This needs to be done initially before any move or tuck + * operations are performed. + * + * @param order The permutation to score. + * @return The score of it. + */ + public double score(List order) { + this.pi = new ArrayList<>(order); + this.scores = new ArrayList<>(); + + for (int i1 = 0; i1 < order.size(); i1++) { + this.scores.add(null); + } + + this.prefixes = new ArrayList<>(); + for (int i1 = 0; i1 < order.size(); i1++) this.prefixes.add(null); + initializeScores(); + return score(); + } + + /** + * @return The score of the current permutation. + */ + public double score() { + return sum(); +// return runningScore; + } + + private double sum() { + double score = 0; + + for (int i = 0; i < this.pi.size(); i++) { + if (this.scores.get(i) == null) { + recalculate(i); + } + double score1 = this.scores.get(i).getScore(); + score += score1; + } + + return score; + } + + /** + * Performs a tuck operation. + */ + public boolean swaptuck(Node x, Node y, Node z, boolean doZ) { + boolean moved = false; + + if (index(y) < index(x)) { + moveTo(x, index(y)); + moved = true; + } + + if (doZ) { + if (index(y) < index(z)) { + moveTo(z, index(y)); + moved = true; + } + } + + return moved; + } + + public boolean swaptuck(Node x, Node y) { + if (index(x) < index(y)) { +// moveTo(y, index(x)); + return false; + } else if (index(y) < index(x)) { + moveTo(x, index(y)); + return true; + } + + return false; + } + + public void tuckWithoutMovingAncestors(Node x, Node y) { + if (index(x) > index(y)) { + moveTo(x, index(y)); + } + } + + public boolean tuck(Node k, Node j) { + return tuck(k, index(j)); + } + + public boolean tuck(Node k, int j) { + if (adjacent(k, get(j))) return false; +// if (scorer.coveredEdge(k, scorer.get(j))) return false; + if (j >= index(k)) return false; + + Set ancestors = getAncestors(k); + for (int i = j + 1; i <= index(k); i++) { + if (ancestors.contains(get(i))) { + moveTo(get(i), j++); + } + } + + return true; + } + + /** + * Moves v to a new index. + * + * @param v The variable to move. + * @param toIndex The index to move v to. + */ + public void moveTo(Node v, int toIndex) { + int vIndex = index(v); + if (vIndex == toIndex) return; + if (lastMoveSame(vIndex, toIndex)) return; + + this.pi.remove(v); + this.pi.add(toIndex, v); + + if (toIndex < vIndex) { + updateScores(toIndex, vIndex); + } else { + updateScores(vIndex, toIndex); + } + } + + /** + * Swaps m and n in the permutation. + * + * @param m The first variable. + * @param n The second variable. + * @return True iff the swap was done. + */ + public boolean swap(Node m, Node n) { + int i = this.orderHash.get(m); + int j = this.orderHash.get(n); + + this.pi.set(i, n); + this.pi.set(j, m); + + if (violatesKnowledge(this.pi)) { + this.pi.set(i, m); + this.pi.set(j, n); + return false; + } + + if (i < j) { + updateScores(i, j); + } else { + updateScores(j, i); + } + + return true; + } + + /** + * Returns true iff x->y or y->x is a covered edge. x->y is a covered edge if + * parents(x) = parents(y) \ {x} + * + * @param x The first variable. + * @param y The second variable. + * @return True iff x->y or y->x is a covered edge. + */ + public boolean coveredEdge(Node x, Node y) { +// if (!adjacent(x, y)) return false; + Set px = getParents(x); + Set py = getParents(y); + px.remove(y); + py.remove(x); + return px.equals(py); + } + + /** + * @return A copy of the current permutation. + */ + public List getPi() { + return new ArrayList<>(this.pi); + } + + /** + * Returns the current permutation without making a copy. Could be dangerous! + * + * @return the current permutation. + */ + public List getOrderShallow() { + return this.pi; + } + + /** + * Return the index of v in the current permutation. + * + * @param v The variable. + * @return Its index. + */ + public int index(Node v) { + Integer integer = this.orderHash.get(v); + + if (integer == null) + throw new IllegalArgumentException("First 'evaluate' a permutation containing variable " + + v + "."); + + return integer; + } + + /** + * Returns the parents of the node at index p. + * + * @param p The index of the node. + * @return Its parents. + */ + public Set getParents(int p) { + if (this.scores.get(p) == null) recalculate(p); + return new HashSet<>(this.scores.get(p).getParents()); + } + + public Set getChildren(int p) { + Set adj = getAdjacentNodes(get(p)); + adj.removeAll(getParents(p)); + return adj; + } + + /** + * Returns the parents of a node v. + * + * @param v The variable. + * @return Its parents. + */ + public Set getParents(Node v) { + return getParents(index(v)); + } + + public Set getChildren(Node v) { + return getChildren(index(v)); + } + + /** + * Returns the nodes adjacent to v. + * + * @param v The variable. + * @return Its adjacent nodes. + */ + public Set getAdjacentNodes(Node v) { + Set adj = new HashSet<>(); + + for (Node w : this.pi) { + if (getParents(v).contains(w) || getParents(w).contains(v)) { + adj.add(w); + } + } + + return adj; + } + + public Set getAncestralNodes(Node v) { + Set adj = new HashSet<>(); + + for (Node w : this.pi) { + if (getAncestors(v).contains(w) || getAncestors(w).contains(v)) { + adj.add(w); + } + } + + return adj; + } + + /** + * Returns the DAG build for the current permutation, or its CPDAG. + * + * @param cpDag True iff the CPDAG should be returned, False if the DAG. + * @return This graph. + */ + public Graph getGraph(boolean cpDag) { + if (cpDag) { + List order = getPi(); + Graph G1 = new EdgeListGraph(this.variables); + + for (int p = 0; p < order.size(); p++) { + for (Node z : getParents(p)) { + G1.addDirectedEdge(z, order.get(p)); + } + } + + GraphUtils.replaceNodes(G1, this.variables); + + MeekRules rules = new MeekRules(); + rules.setKnowledge(knowledge); + rules.orientImplied(G1); + + return G1; + +// return findCompelled(); + } else { + List order = getPi(); + Graph G1 = new EdgeListGraph(this.variables); + + for (int p = 0; p < order.size(); p++) { + for (Node z : getParents(p)) { + G1.addDirectedEdge(z, order.get(p)); + } + } + + GraphUtils.replaceNodes(G1, this.variables); + + return G1; + } + } + + public void orientbk(Knowledge bk, Graph graph, List variables) { + for (Iterator it = bk.forbiddenEdgesIterator(); it.hasNext(); ) { + if (Thread.currentThread().isInterrupted()) { + break; + } + + KnowledgeEdge edge = it.next(); + + //match strings to variables in the graph. + Node from = SearchGraphUtils.translate(edge.getFrom(), variables); + Node to = SearchGraphUtils.translate(edge.getTo(), variables); + + if (from == null || to == null) { + continue; + } + + if (graph.getEdge(from, to) == null) { + continue; + } + + // Orient to*->from + graph.setEndpoint(to, from, Endpoint.ARROW); +// graph.setEndpoint(from, to, Endpoint.CIRCLE); +// this.changeFlag = true; +// this.logger.forceLogMessage(SearchLogUtils.edgeOrientedMsg("Knowledge", graph.getEdge(from, to))); + } + + for (Iterator it = bk.requiredEdgesIterator(); it.hasNext(); ) { + if (Thread.currentThread().isInterrupted()) { + break; + } + + KnowledgeEdge edge = it.next(); + + //match strings to variables in the graph. + Node from = SearchGraphUtils.translate(edge.getFrom(), variables); + Node to = SearchGraphUtils.translate(edge.getTo(), variables); + + if (from == null || to == null) { + continue; + } + + if (graph.getEdge(from, to) == null) { + continue; + } + + // Orient to*->from + graph.setEndpoint(from, to, Endpoint.ARROW); +// graph.setEndpoint(from, to, Endpoint.CIRCLE); +// this.changeFlag = true; +// this.logger.forceLogMessage(SearchLogUtils.edgeOrientedMsg("Knowledge", graph.getEdge(from, to))); + } + } + + +// public Graph getGraph(boolean cpDag) { +// +// if(cpDag) { +// return findCompelled(); +// } +// +// List order = getPi(); +// Graph G1 = new EdgeListGraph(this.variables); +// +// for (int p = 0; p < order.size(); p++) { +// for (Node z : getParents(p)) { +// G1.addDirectedEdge(z, order.get(p)); +// } +// } +// +// GraphUtils.replaceNodes(G1, this.variables); +// +// if (cpDag) { +// return SearchGraphUtils.cpdagForDag(G1); +// } else { +// return G1; +// } +// } + + private List orderEdges() { + + List orderedEdges = new ArrayList<>(); + + for (int i = this.pi.size(); i-- > 0; ) { + Node y = this.pi.get(i); + Set pa = this.getParents(i); + for (int j = 0; j < i; j++) { + Node x = this.pi.get(j); + if (pa.contains(x)) { + Edge e = new Edge(x, y, Endpoint.TAIL, Endpoint.ARROW); + orderedEdges.add(e); + pa.remove(x); + if (pa.isEmpty()) { + break; + } + } + } + } + return orderedEdges; + } + + public Graph findCompelled() { + + Graph G = new EdgeListGraph(this.variables); + + List orderedEdges = orderEdges(); + + Node remainderCompelled = null; + Node remainderReversible = null; + + EDGES: + while (!orderedEdges.isEmpty()) { + + Edge e = orderedEdges.remove(0); + Node x = e.getNode1(); + Node y = e.getNode2(); + + if (remainderCompelled != null) { + if (remainderCompelled == y) { + G.addEdge(e); + continue; + } else { + remainderCompelled = null; + } + } + + if (remainderReversible != null) { + if (remainderReversible == y) { + G.addUndirectedEdge(x, y); + continue; + } else { + remainderReversible = null; + } + } + + if (G.isParentOf(x, y)) { + continue; + } + + List compelled = G.getParents(x); + Set xPa = getParents(x); + Set yPa = getParents(y); + + for (Node w : compelled) { + if (yPa.contains(w)) { + G.addEdge(e); + continue EDGES; + } else { + G.addDirectedEdge(w, y); + } + } + + yPa.remove(x); + for (Node z : yPa) { + if (!xPa.contains(z)) { + G.addEdge(e); + remainderCompelled = y; + continue EDGES; + } + } + + G.addUndirectedEdge(x, y); + remainderReversible = y; + } + + return G; + } + + /** + * Returns a list of adjacent node pairs in the current graph. + * + * @return This list. + */ + public List getAdjacencies() { + List order = getPi(); + Set pairs = new HashSet<>(); + + for (int i = 0; i < order.size(); i++) { + for (int j = 0; j < i; j++) { + Node x = order.get(i); + Node y = order.get(j); + + if (adjacent(x, y)) { + pairs.add(new NodePair(x, y)); + } + } + } + + return new ArrayList<>(pairs); + } + + public Map> getAdjMap() { + Map> adjMap = new HashMap<>(); + for (Node node1 : getPi()) { + if (!adjMap.containsKey(node1)) { + adjMap.put(node1, new HashSet<>()); + } + for (Node node2 : getParents(node1)) { + if (!adjMap.containsKey(node2)) { + adjMap.put(node2, new HashSet<>()); + } + adjMap.get(node1).add(node2); + adjMap.get(node2).add(node1); + } + } + return adjMap; + } + + + public Map> getChildMap() { + Map> childMap = new HashMap<>(); + for (Node node1 : getPi()) { + for (Node node2 : getParents(node1)) { + if (!childMap.containsKey(node2)) { + childMap.put(node2, new HashSet<>()); + } + childMap.get(node2).add(node1); + } + } + return childMap; + } + + public Set getAncestors(Node node) { + Set ancestors = new HashSet<>(); + collectAncestorsVisit(node, ancestors); + + return ancestors; + } + + public Set getDescendants(Node node) { + Set descendants = new HashSet<>(); + collectDescendantVisit(node, descendants); + return descendants; + } + + private void collectAncestorsVisit(Node node, Set ancestors) { + if (ancestors.contains(node)) { + return; + } + + ancestors.add(node); + Set parents = getParents(node); + + if (!parents.isEmpty()) { + for (Node parent : parents) { + collectAncestorsVisit(parent, ancestors); + } + } + } + + private void collectDescendantVisit(Node node, Set ancestors) { + if (ancestors.contains(node)) { + return; + } + + ancestors.add(node); + Set children = getChildren(node); + + if (!children.isEmpty()) { + for (Node parent : children) { + collectDescendantVisit(parent, ancestors); + } + } + } + + /** + * Returns a list of edges for the current graph as a list of ordered pairs. + * + * @return This list. + */ + public List> getEdges() { + List order = getPi(); + List> edges = new ArrayList<>(); + + for (Node y : order) { + for (Node x : getParents(y)) { + edges.add(new OrderedPair<>(x, y)); + } + } + + return edges; + } + + /** + * @return The number of edges in the current graph. + */ + public int getNumEdges() { + int numEdges = 0; + + for (int p = 0; p < this.pi.size(); p++) { + numEdges += getParents(p).size(); + } + + return numEdges; + } + + /** + * Returns the node at index j in pi. + * + * @param j The index. + * @return The node at that index. + */ + public Node get(int j) { + return this.pi.get(j); + } + + /** + * Bookmarks the current pi as index key. + * + * @param key This bookmark may be retrieved using the index 'key', an integer. + * This bookmark will be stored until it is retrieved and then removed. + */ + public void bookmark(int key) { + try { + this.bookmarkedOrders.put(key, new ArrayList<>(this.pi)); + this.bookmarkedScores.put(key, new ArrayList<>(this.scores)); + this.bookmarkedOrderHashes.put(key, new HashMap<>(this.orderHash)); + this.bookmarkedRunningScores.put(key, runningScore); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** + * Bookmarks the current pi with index Integer.MIN_VALUE. + */ + public void bookmark() { + bookmark(Integer.MIN_VALUE); + } + + /** + * Retrieves the bookmarked state for index 'key' and removes that bookmark. + * + * @param key The integer key for this bookmark. + */ + public void goToBookmark(int key) { + if (!this.bookmarkedOrders.containsKey(key)) { +// bookmark(key); +// return; + throw new IllegalArgumentException("That key was not bookmarked: " + key); + } + + this.pi = new ArrayList<>(this.bookmarkedOrders.get(key)); + this.scores = new ArrayList<>(this.bookmarkedScores.get(key)); + this.orderHash = new HashMap<>(this.bookmarkedOrderHashes.get(key)); + this.runningScore = this.bookmarkedRunningScores.get(key); + + } + + /** + * Retries the bookmark with key = Integer.MIN_VALUE and removes the bookmark. + */ + public void goToBookmark() { + goToBookmark(Integer.MIN_VALUE); + } + + /** + * Clears all bookmarks. + */ + public void clearBookmarks() { + this.bookmarkedOrders.clear(); + this.bookmarkedScores.clear(); + this.bookmarkedOrderHashes.clear(); + this.bookmarkedRunningScores.clear(); + } + + /** + * @return The size of pi, the current permutation. + */ + public int size() { + return this.pi.size(); + } + + /** + * Shuffles the current permutation and rescores it. + */ + public void shuffleVariables() { + this.pi = new ArrayList<>(this.pi); + shuffle(this.pi); + score(this.pi); + } + + public List getShuffledVariables() { + List variables = getPi(); + shuffle(variables); + return variables; + } + + /** + * Returns True iff a is adjacent to b in the current graph. + * + * @param a The first node. + * @param b The second node. + * @return True iff adj(a, b). + */ + public boolean adjacent(Node a, Node b) { + if (a == b) return false; + return getParents(a).contains(b) || getParents(b).contains(a); + } + + public boolean ancestorAdjacent(Node a, Node b) { + if (a == b) return false; + return getAncestors(a).contains(b) || getAncestors(b).contains(a); + } + + /** + * Returns true iff [a, b, c] is a collider. + * + * @param a The first node. + * @param b The second node. + * @param c The third node. + * @return True iff a->b<-c in the current DAG. + */ + public boolean collider(Node a, Node b, Node c) { + return getParents(b).contains(a) && getParents(b).contains(c); + } + + public boolean ancestralCollider(Node a, Node b, Node c) { + return getAncestors(b).contains(a) && getAncestors(b).contains(c); + } + + /** + * Returns true iff [a, b, c] is a triangle. + * + * @param a The first node. + * @param b The second node. + * @param c The third node. + * @return True iff adj(a, b) and adj(b, c) and adj(a, c). + */ + public boolean triangle(Node a, Node b, Node c) { + return adjacent(a, b) && adjacent(b, c) && adjacent(a, c); + } + + /** + * True iff the nodes in W form a clique in the current DAG. + * + * @param W The nodes. + * @return True iff these nodes form a clique. + */ + public boolean clique(List W) { + for (int i = 0; i < W.size(); i++) { + for (int j = i + 1; j < W.size(); j++) { + if (!adjacent(W.get(i), W.get(j))) { + return false; + } + } + } + + return true; + } + +// /** +// * A convenience method to reset the score cache if it becomes larger than a certain +// * size. +// * +// * @param maxSize The maximum size of the score cache; it the if the score cache is +// * larger than this it will be cleared. +// */ +// public void resetCacheIfTooBig(int maxSize) { +// if (this.cache.size() > maxSize) { +// this.cache = new HashMap<>(); +// System.out.println("Clearing cacche..."); +// System.gc(); +// } +// } + + private boolean violatesKnowledge(List order) { + if (!knowledge.isEmpty()) { + for (int i = 0; i < order.size(); i++) { + for (int j = i + 1; j < order.size(); j++) { + if (this.knowledge.isForbidden(order.get(i).getName(), order.get(j).getName())) { + return true; + } + } + } + } + + return false; + } + + private void initializeScores() { + for (int i1 = 0; i1 < this.pi.size(); i1++) this.prefixes.set(i1, null); + updateScores(0, this.pi.size() - 1); + } + + private void updateScores(int i1, int i2) { + for (int i = i1; i <= i2; i++) { + this.orderHash.put(this.pi.get(i), i); + this.scores.set(i, null); +// recalculate(i); + } + +// for (int i = i1; i <= i2; i++) { +// this.orderHash.put(this.pi.get(i), i); +// } +// +// int chunk = getChunkSize(i2 - i1 + 1); +// List tasks = new ArrayList<>(); +// +// for (int w = 0; w < size(); w += chunk) { +// tasks.add(new MyTask(pi, this, chunk, orderHash, w, w + chunk)); +// } +// +// ForkJoinPool.commonPool().invokeAll(tasks); + } + +// private int getChunkSize(int n) { +// int chunk = n / Runtime.getRuntime().availableProcessors(); +// if (chunk < 100) chunk = 100; +// return chunk; +// } + + private double score(Node n, Set pi) { + if (this.cachingScores) { + this.cache.computeIfAbsent(n, w -> new HashMap<>()); + Double score = this.cache.get(n).get(pi); + + if (score != null) { + return score; + } + } + + int[] parentIndices = new int[pi.size()]; + + int k = 0; + + for (Node p : pi) { + parentIndices[k++] = this.variablesHash.get(p); + } + + if (mag != null) { + ((edu.cmu.tetrad.search.MagSemBicScore) score).setMag(mag); + } + + double v = (double) this.score.localScore(this.variablesHash.get(n), parentIndices); + + if (this.cachingScores) { + this.cache.computeIfAbsent(n, w -> new HashMap<>()); + this.cache.get(n).put(new HashSet<>(pi), v); + } + + return v; + } + + public Set getPrefix(int i) { + Set prefix = new HashSet<>(); + + for (int j = 0; j < i; j++) { + prefix.add(this.pi.get(j)); + } + + return prefix; + } + + public Score getScoreObject() { + return this.score; + } + + public IndependenceTest getTestObject() { + return this.test; + } + + class MyTask implements Callable { + final List pi; + final Map orderHash; + TeyssierScorer scorer; + int chunk; + private final int from; + private final int to; + + MyTask(List pi, TeyssierScorer scorer, int chunk, Map orderHash, + int from, int to) { + this.pi = pi; + this.scorer = scorer; + this.chunk = chunk; + this.orderHash = orderHash; + this.from = from; + this.to = to; + } + + @Override + public Boolean call() throws InterruptedException { + for (int i = from; i <= to; i++) { + if (Thread.currentThread().isInterrupted()) throw new InterruptedException(); + recalculate(i); + } + + return true; + } + } + + private void recalculate(int p) { + if (this.prefixes.get(p) == null || !this.prefixes.get(p).containsAll(getPrefix(p))) { + Pair p2 = getParentsInternal(p); + if (scores.get(p) == null) { + this.runningScore += p2.score; + } else { + this.runningScore += p2.score - scores.get(p).score; + } + this.scores.set(p, p2); + } + } + + private void nodesHash(Map nodesHash, List variables) { + for (int i = 0; i < variables.size(); i++) { + nodesHash.put(variables.get(i), i); + } + } + + private boolean lastMoveSame(int i1, int i2) { + if (i1 <= i2) { + Set prefix0 = getPrefix(i1); + + for (int i = i1; i <= i2; i++) { + prefix0.add(get(i)); + if (!prefix0.equals(this.prefixes.get(i))) return false; + } + } else { + Set prefix0 = getPrefix(i1); + + for (int i = i2; i <= i1; i++) { + prefix0.add(get(i)); + if (!prefix0.equals(this.prefixes.get(i))) return false; + } + } + + return true; + } + + @NotNull + private Pair getGrowShrinkScore(int p) { + Node n = this.pi.get(p); + +// Set parents = new HashSet<>(); +// boolean changed = true; + +// double sMax = score(n, new HashSet<>()); +// List prefix = new ArrayList<>(getPrefix(p)); + + Set prefix = new HashSet<>(getPrefix(p)); + LinkedHashSet parents = new LinkedHashSet<>(); + double sMax = GST.GrowShrink(n, prefix, parents); + + return new Pair(parents, Double.isNaN(sMax) ? Double.NEGATIVE_INFINITY : sMax); + + +// // Backward scoring only from the prefix variables +// if (this.useBackwardScoring) { +// parents.addAll(prefix); +// sMax = score(n, parents); +// changed = false; +// } +// +// // Grow-shrink +// while (changed) { +// changed = false; +// +// // Let z be the node that maximizes the score... +// Node z = null; +// +// for (Node z0 : prefix) { +// if (parents.contains(z0)) continue; +// +// if (!knowledge.isEmpty() && this.knowledge.isForbidden(z0.getName(), n.getName())) continue; +// +// parents.add(z0); +// +// if (!knowledge.isEmpty() && knowledge.isRequired(z0.getName(), n.getName())) continue; +// +// double s2 = score(n, parents); +// +// if (s2 >= sMax) { +// sMax = s2; +// z = z0; +// } +// +// parents.remove(z0); +// } +// +// if (z != null) { +// parents.add(z); +// changed = true; +// } +// +// } +// +// boolean changed2 = true; +// +// while (changed2) { +// changed2 = false; +// +// Node w = null; +// +// for (Node z0 : new HashSet<>(parents)) { +// if (!knowledge.isEmpty() && knowledge.isRequired(z0.getName(), n.getName())) continue; +// +// parents.remove(z0); +// +// double s2 = score(n, parents); +// +// if (s2 > sMax) { +// sMax = s2; +// w = z0; +// } +// +// parents.add(z0); +// } +// +// if (w != null) { +// parents.remove(w); +// changed2 = true; +// } +// } +// +//// while (changed2) { +//// changed2 = false; +//// +//// List aaa = null; +//// +//// List pp = new ArrayList<>(parents); +//// +//// SublistGenerator gen = new SublistGenerator(parents.size(), 2); +//// int[] choice; +//// +//// while ((choice = gen.next()) != null) { +//// List aa = GraphUtils.asList(choice, pp); +//// +////// for (Node z0 : new HashSet<>(parents)) { +////// if (!knowledge.isEmpty() && knowledge.isRequired(z0.getName(), n.getName())) continue; +////// +//// aa.forEach(parents::remove); +//// +//// double s2 = score(n, parents); +//// +//// if (s2 > sMax) { +//// sMax = s2; +//// aaa = aa; +//// } +//// +//// parents.addAll(aa); +//// } +//// +//// if (aaa != null) { +//// aaa.forEach(parents::remove); +//// changed2 = true; +//// } +//// +//// } +// +// if (this.useScore) { +// return new Pair(parents, Double.isNaN(sMax) ? Double.NEGATIVE_INFINITY : sMax); +// } else { +// return new Pair(parents, -parents.size()); +// } + } + + private Pair getGrowShrinkIndependent(int p) { + Node n = this.pi.get(p); + + Set parents = new HashSet<>(); + + Set prefix = getPrefix(p); + + boolean changed1 = true; + + while (changed1) { + changed1 = false; + + for (Node z0 : prefix) { + if (parents.contains(z0)) continue; + if (!knowledge.isEmpty() && this.knowledge.isForbidden(z0.getName(), n.getName())) continue; + + if (!knowledge.isEmpty() && this.knowledge.isRequired(z0.getName(), n.getName())) { + parents.add(z0); + continue; + } + + if (this.test.checkIndependence(n, z0, new ArrayList<>(parents)).dependent()) { + parents.add(z0); + changed1 = true; + } + } + + for (Node z1 : new HashSet<>(parents)) { + if (!knowledge.isEmpty() && this.knowledge.isRequired(z1.getName(), n.getName())) { + continue; + } + + parents.remove(z1); + + if (this.test.checkIndependence(n, z1, new ArrayList<>(parents)).dependent()) { + parents.add(z1); + } else { + changed1 = true; + } + } + } + + return new Pair(parents, -parents.size()); + } + + private Pair getParentsInternal(int p) { + if (this.useRaskuttiUhler) { + return getRaskuttiUhlerParents(p); + } else { + if (this.useScore) { + return getGrowShrinkScore(p); + } else { + return getGrowShrinkIndependent(p); + } + } + } + + + /** + * Returns the parents of the node at index p, calculated using Pearl's method. + * + * @param p The index. + * @return The parents, as a Pair object (parents + score). + */ + private Pair getRaskuttiUhlerParents(int p) { + Node x = this.pi.get(p); + Set parents = new HashSet<>(); + Set prefix = getPrefix(p); + + for (Node y : prefix) { + Set minus = new HashSet<>(prefix); + minus.remove(y); + ArrayList z = new ArrayList<>(minus); + sort(z); + + if (this.test.checkIndependence(x, y, z).dependent()) { + parents.add(y); + } + } + + return new Pair(parents, -parents.size()); + } + + public Set> getSkeleton() { + List order = getPi(); + Set> skeleton = new HashSet<>(); + + for (Node y : order) { + for (Node x : getParents(y)) { + Set adj = new HashSet<>(); + adj.add(x); + adj.add(y); + skeleton.add(adj); + } + } + + return skeleton; + } + + +// public void moveToNoUpdate(Node v, int toIndex) { +// bookmark(-55); +// +// if (!this.pi.contains(v)) return; +// +// int vIndex = index(v); +// +// if (vIndex == toIndex) return; +// +// if (lastMoveSame(vIndex, toIndex)) return; +// +// this.pi.remove(v); +// this.pi.add(toIndex, v); +// +// if (violatesKnowledge(this.pi)) { +// goToBookmark(-55); +// } +// +// } + + public boolean parent(Node k, Node j) { + return getParents(j).contains(k); + } + + private static class Pair { + private final Set parents; + private final double score; + + private Pair(Set parents, double score) { + this.parents = parents; + this.score = score; + } + + public Set getParents() { + return this.parents; + } + + public double getScore() { + return this.score; + } + + public int hashCode() { + return this.parents.hashCode() + (int) floor(10000D * this.score); + } + + public boolean equals(Object o) { + if (o == null) return false; + if (!(o instanceof Pair)) return false; + Pair thatPair = (Pair) o; + return this.parents.equals(thatPair.parents) && this.score == thatPair.score; + } + } +} From 36083e67b36c37c1c12d3cb3867e1b857e6fccba Mon Sep 17 00:00:00 2001 From: jdramsey Date: Tue, 14 Feb 2023 12:15:10 -0500 Subject: [PATCH 9/9] Fixed problem of SEM and Bayes simulation parameters going missing on restarting the Simulation box after simulating data. --- .../tetradapp/editor/simulation/ParameterTab.java | 1 + .../algcomparison/simulation/BayesNetSimulation.java | 12 ++++++------ .../algcomparison/simulation/SemSimulation.java | 8 ++++---- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/simulation/ParameterTab.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/simulation/ParameterTab.java index 38633aa56b..9494ec0546 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/simulation/ParameterTab.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/simulation/ParameterTab.java @@ -217,6 +217,7 @@ private void showParameters() { this.parameterBox.removeAll(); if (this.simulation.getSimulation() != null) { Set params = new LinkedHashSet<>(this.simulation.getSimulation().getParameters()); + if (params.isEmpty()) { this.parameterBox.add(ParameterTab.NO_PARAM_LBL, BorderLayout.NORTH); } else { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java index 87a311478e..b65893c0f9 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java @@ -107,13 +107,13 @@ public List getParameters() { parameters.addAll(this.randomGraph.getParameters()); } - if (this.pm == null) { - parameters.addAll(BayesPm.getParameterNames()); - } +// if (this.pm == null) { + parameters.addAll(BayesPm.getParameterNames()); +// } - if (this.im == null) { - parameters.addAll(MlBayesIm.getParameterNames()); - } +// if (this.im == null) { + parameters.addAll(MlBayesIm.getParameterNames()); +// } parameters.add(Params.NUM_RUNS); parameters.add(Params.DIFFERENT_GRAPHS); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemSimulation.java index 6a048323f4..77ed568ed4 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemSimulation.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemSimulation.java @@ -125,9 +125,9 @@ public List getParameters() { parameters.addAll(this.randomGraph.getParameters()); } - if (this.im == null) { - parameters.addAll(SemIm.getParameterNames()); - } +// if (this.im == null) { + parameters.addAll(SemIm.getParameterNames()); +// } parameters.add(Params.MEASUREMENT_VARIANCE); parameters.add(Params.NUM_RUNS); @@ -175,7 +175,7 @@ private DataSet simulate(Graph graph, Parameters parameters) { // Not setting this im messes up algcomparison. -JR 20230206 // if (this.im == null) { - this.im = im; + this.im = im; // } // Need this in case the SEM IM is given externally.