From a9f4dde0bc4aaa7dbb194b51e97aed2cb0a028d3 Mon Sep 17 00:00:00 2001
From: jdramsey
Date: Thu, 7 Jan 2021 15:27:24 -0500
Subject: [PATCH 1/4] Mainly fixing a knowledge bug but also fixing issues in
IndTestFisherZ and SemBicScore--these were using testwise deletion when they
didn't need to. To pull over the SEM BIC code from another branch, I also
needed to pull over a modest revision to the FGES code, where I shore up some
of the reasoning.
---
docs/manual/index.html | 34 +
tetrad-lib/pom.xml | 6 +
.../algcomparison/score/SemBicScore.java | 20 +-
.../java/edu/cmu/tetrad/data/Knowledge2.java | 26 +-
.../main/java/edu/cmu/tetrad/search/Fges.java | 631 ++++++++----------
.../edu/cmu/tetrad/search/IndTestFisherZ.java | 18 +
.../edu/cmu/tetrad/search/SemBicScore.java | 263 ++++----
.../main/java/edu/cmu/tetrad/util/Params.java | 2 +
.../java/edu/cmu/tetrad/test/TestGFci.java | 9 +-
9 files changed, 525 insertions(+), 484 deletions(-)
diff --git a/docs/manual/index.html b/docs/manual/index.html
index ef9c3ac1d6..cddd68a64e 100755
--- a/docs/manual/index.html
+++ b/docs/manual/index.html
@@ -4789,6 +4789,40 @@ selfLoopCoef
Value Type: Double
+ semBicRule
+
+ - Short Description: 1 = Chickering, 2 = Nandy, 3 = High Dimensional
+
+ - Long Description: The rule used for calculating a score.
+ The Chickering Rule is the local scoring consistency criterion in Chickering's formulation of
+ GES, though we allow a multiplier on the penalty term called "penalty discount".
+ The Nandy et al. rule is a reformultion of the Chickering rule using a single calculation
+ of a partial correlation in place of the likelihood difference. For the high-dimensional score
+ we use the formulation in Gao and Chen, described on the Wikipedia page for Bayesian Information
+ Criterion, which is a pseudo-BIC formulation. It has two parameters, gamma (> 0) and omega (>= 1).
+ We fix gamma at 0.25 and allow omega to be adjusted using the "penalty discount" parameter.
+ In all cases we include a structure prior, which amounts to a prior on the number of
+ parents in any scoring, a binomial formula.
+
+ - Default Value: 1
+ - Lower Bound: 1
+ - Upper Bound: 4
+ - Value Type: Integer
+
+
+ semBicStructurePrior
+
+ - Short Description: Structure Prior for SEM BIC (default 0)
+
+ - Long Description: Structure prior; default is 0 (turned off); may be
+ any positive number otherwise
+
+ - Default Value: 0
+ - Lower Bound: 0
+ - Upper Bound: Infinity
+ - Value Type: Double
+
+
skipNumRecords
- Short Description: Number of records that should be skipped between recordings (min = 0)
diff --git a/tetrad-lib/pom.xml b/tetrad-lib/pom.xml
index 8a7515b2c5..797ab1cd15 100644
--- a/tetrad-lib/pom.xml
+++ b/tetrad-lib/pom.xml
@@ -146,6 +146,12 @@
data-reader
${project.version}
+
+ org.jetbrains
+ annotations
+ RELEASE
+ compile
+
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/SemBicScore.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/SemBicScore.java
index d4a748f628..eef7f35730 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/SemBicScore.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/score/SemBicScore.java
@@ -42,7 +42,22 @@ public Score getScore(DataModel dataSet, Parameters parameters) {
}
semBicScore.setPenaltyDiscount(parameters.getDouble(Params.PENALTY_DISCOUNT));
- semBicScore.setStructurePrior(parameters.getDouble(Params.STRUCTURE_PRIOR));
+ semBicScore.setStructurePrior(parameters.getDouble(Params.SEM_BIC_STRUCTURE_PRIOR));
+
+ switch (parameters.getInt(Params.SEM_BIC_RULE)) {
+ case 1:
+ semBicScore.setRuleType(edu.cmu.tetrad.search.SemBicScore.RuleType.CHICKERING);
+ break;
+ case 2:
+ semBicScore.setRuleType(edu.cmu.tetrad.search.SemBicScore.RuleType.NANDY);
+ break;
+ case 3:
+ semBicScore.setRuleType(edu.cmu.tetrad.search.SemBicScore.RuleType.HIGH_DIMENSIONAL);
+ break;
+ default:
+ throw new IllegalStateException("Expecting 1, 2, or 3: " + parameters.getInt(Params.SEM_BIC_RULE));
+ }
+
return semBicScore;
}
@@ -60,7 +75,8 @@ public DataType getDataType() {
public List getParameters() {
List parameters = new ArrayList<>();
parameters.add(Params.PENALTY_DISCOUNT);
- parameters.add(Params.STRUCTURE_PRIOR);
+ parameters.add(Params.SEM_BIC_STRUCTURE_PRIOR);
+ parameters.add(Params.SEM_BIC_RULE);
return parameters;
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge2.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge2.java
index dcf84f0f61..acf928f605 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge2.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/Knowledge2.java
@@ -75,8 +75,8 @@ public final class Knowledge2 implements TetradSerializable, IKnowledge {
private boolean defaultToKnowledgeLayout;
private final Set variables;
- private final Set>> forbiddenRulesSpecs;
- private final Set>> requiredRulesSpecs;
+ private final List>> forbiddenRulesSpecs;
+ private final List>> requiredRulesSpecs;
private final List> tierSpecs;
// Legacy.
@@ -85,8 +85,8 @@ public final class Knowledge2 implements TetradSerializable, IKnowledge {
public Knowledge2() {
this.variables = new HashSet<>();
- this.forbiddenRulesSpecs = new HashSet<>();
- this.requiredRulesSpecs = new HashSet<>();
+ this.forbiddenRulesSpecs = new ArrayList<>();
+ this.requiredRulesSpecs = new ArrayList<>();
this.tierSpecs = new ArrayList<>();
this.knowledgeGroups = new LinkedList<>();
this.knowledgeGroupRules = new HashMap<>();
@@ -108,8 +108,8 @@ public Knowledge2(Knowledge2 knowledge) {
this.defaultToKnowledgeLayout = knowledge.defaultToKnowledgeLayout;
this.variables = new HashSet<>(knowledge.variables);
- this.forbiddenRulesSpecs = new HashSet<>(knowledge.forbiddenRulesSpecs);
- this.requiredRulesSpecs = new HashSet<>(knowledge.requiredRulesSpecs);
+ this.forbiddenRulesSpecs = new ArrayList<>(knowledge.forbiddenRulesSpecs);
+ this.requiredRulesSpecs = new ArrayList<>(knowledge.requiredRulesSpecs);
this.tierSpecs = new ArrayList<>(knowledge.tierSpecs);
this.knowledgeGroups = knowledge.knowledgeGroups;
@@ -420,8 +420,8 @@ public boolean isDefaultToKnowledgeLayout() {
private boolean isForbiddenByRules(String var1, String var2) {
return forbiddenRulesSpecs.stream()
.anyMatch(rule -> !var1.equals(var2)
- && rule.getFirst().contains(var1)
- && rule.getSecond().contains(var2));
+ && rule.getFirst().contains(var1)
+ && rule.getSecond().contains(var2));
}
/**
@@ -456,7 +456,7 @@ public boolean isForbiddenByGroups(String var1, String var2) {
return s.stream()
.anyMatch(rule -> rule.getFirst().contains(var1)
- && rule.getSecond().contains(var2));
+ && rule.getSecond().contains(var2));
}
/**
@@ -471,7 +471,7 @@ public boolean isForbiddenByGroups(String var1, String var2) {
public boolean isForbiddenByTiers(String var1, String var2) {
return forbiddenTierRules().stream()
.anyMatch(rule -> rule.getFirst().contains(var1)
- && rule.getSecond().contains(var2));
+ && rule.getSecond().contains(var2));
}
/**
@@ -485,8 +485,8 @@ public boolean isForbiddenByTiers(String var1, String var2) {
public boolean isRequired(String var1, String var2) {
return requiredRulesSpecs.stream()
.anyMatch(rule -> !var1.equals(var2)
- && rule.getFirst().contains(var1)
- && rule.getSecond().contains(var2));
+ && rule.getFirst().contains(var1)
+ && rule.getSecond().contains(var2));
}
/**
@@ -505,7 +505,7 @@ public boolean isRequiredByGroups(String var1, String var2) {
return s.stream()
.anyMatch(rule -> rule.getFirst().contains(var1)
- && rule.getSecond().contains(var2));
+ && rule.getSecond().contains(var2));
}
/**
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java
index be6cf104c8..dd480e3e7e 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java
@@ -20,9 +20,14 @@
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetrad.search;
-import edu.cmu.tetrad.data.*;
+import edu.cmu.tetrad.data.IKnowledge;
+import edu.cmu.tetrad.data.Knowledge2;
+import edu.cmu.tetrad.data.KnowledgeEdge;
import edu.cmu.tetrad.graph.*;
-import edu.cmu.tetrad.util.*;
+import edu.cmu.tetrad.util.DepthChoiceGenerator;
+import edu.cmu.tetrad.util.TaskManager;
+import edu.cmu.tetrad.util.TetradLogger;
+import org.jetbrains.annotations.NotNull;
import java.io.PrintStream;
import java.text.DecimalFormat;
@@ -53,10 +58,20 @@
*/
public final class Fges implements GraphSearch, GraphScorer {
+ private IndependenceTest graphScore = null;
+ private boolean turning = false;
+
+ public boolean isTurning() {
+ return turning;
+ }
+
+ public void setTurning(boolean turning) {
+ this.turning = turning;
+ }
+
/**
* Internal.
*/
-
private enum Mode {
allowUnfaithfulness, heuristicSpeedup, coverNoncolliders
}
@@ -106,12 +121,12 @@ private enum Mode {
/**
* The logger for this class. The config needs to be set.
*/
- private TetradLogger logger = TetradLogger.getInstance();
+ private final TetradLogger logger = TetradLogger.getInstance();
/**
* The top n graphs found by the algorithm, where n is numPatternsToStore.
*/
- private LinkedList topGraphs = new LinkedList<>();
+ private final LinkedList topGraphs = new LinkedList<>();
/**
* True if verbose output should be printed.
@@ -119,10 +134,7 @@ private enum Mode {
private boolean verbose = false;
// Potential arrows sorted by bump high to low. The first one is a candidate for adding to the graph.
- private SortedSet sortedArrows = null;
-
- // Arrows added to sortedArrows for each .
- private Map, Set> lookupArrows = null;
+ private ConcurrentSkipListSet sortedArrows = null;
// A utility map to help with orientation.
private Map> neighbors = null;
@@ -170,6 +182,9 @@ private enum Mode {
// The maximum number of threads to use.
private final int maxThreads;
+ // Max number of nodes reoriented in any step.
+ private int tdepth = -1;
+
//===========================CONSTRUCTORS=============================//
/**
@@ -193,11 +208,15 @@ public Fges(Score score, int parallelism) {
setScore(score);
this.maxThreads = parallelism;
this.pool = new ForkJoinPool(parallelism);
- this.graph = new EdgeListGraphSingleConnections(getVariables());
+ this.graph = new EdgeListGraph(getVariables());
}
//==========================PUBLIC METHODS==========================//
+ public void setTDepth(int tdepth) {
+ this.tdepth = tdepth;
+ }
+
/**
* Set to true if it is assumed that all path pairs with one length 1 path
* do not cancel.
@@ -225,16 +244,15 @@ public Graph search() {
long start = System.currentTimeMillis();
topGraphs.clear();
- lookupArrows = new ConcurrentHashMap<>();
final List nodes = new ArrayList<>(variables);
- graph = new EdgeListGraphSingleConnections(nodes);
+ graph = new EdgeListGraph(nodes);
if (adjacencies != null) {
adjacencies = GraphUtils.replaceNodes(adjacencies, nodes);
}
if (initialGraph != null) {
- graph = new EdgeListGraphSingleConnections(initialGraph);
+ graph = new EdgeListGraph(initialGraph);
graph = GraphUtils.replaceNodes(graph, nodes);
}
@@ -245,25 +263,36 @@ public Graph search() {
// Do forward search.
this.mode = Mode.heuristicSpeedup;
+
fes();
+ turning();
bes();
this.mode = Mode.coverNoncolliders;
initializeTwoStepEdges(getVariables());
+
fes();
+ turning();
bes();
+
+ turning();
} else {
initializeForwardEdgesFromEmptyGraph(getVariables());
// Do forward search.
this.mode = Mode.heuristicSpeedup;
+
fes();
+ turning();
bes();
this.mode = Mode.allowUnfaithfulness;
initializeForwardEdgesFromExistingGraph(getVariables());
+
fes();
+ turning();
bes();
+
}
this.modelScore = scoreDag(SearchGraphUtils.dagFromPattern(graph), true);
@@ -282,6 +311,38 @@ public Graph search() {
return graph;
}
+ private void turning() {
+ if (!turning) return;
+
+ {
+ int count = 1;
+
+ do {
+ Graph ref = new EdgeListGraph(graph);
+
+ for (Edge edge : ref.getEdges()) {
+ Node x = edge.getNode1();
+ Node y = edge.getNode2();
+
+ if (graph.isAdjacentTo(x, y)) {
+ Edge _edge = graph.getEdge(x, y);
+
+ graph.removeEdge(_edge);
+
+ sortedArrows.clear();
+ calculateArrowsForward(x, y);
+ calculateArrowsForward(y, x);
+ fes();
+ bes();
+ }
+ }
+
+ if (graph.equals(ref)) break;
+ System.out.println("Turning " + count);
+ } while (count++ <= 20);
+ }
+ }
+
/**
* @return the background knowledge.
*/
@@ -311,7 +372,9 @@ public long getElapsedTime() {
* the true edges.
*/
public void setTrueGraph(Graph trueGraph) {
+ trueGraph = GraphUtils.replaceNodes(trueGraph, variables);
this.trueGraph = trueGraph;
+ this.graphScore = new IndTestDSep(trueGraph);
}
/**
@@ -344,8 +407,8 @@ public void setInitialGraph(Graph initialGraph) {
if (initialGraph != null) {
if (verbose) {
- out.println("Initial graph variables: " + initialGraph.getNodes());
- out.println("Data set variables: " + variables);
+ TetradLogger.getInstance().forceLogMessage("Initial graph variables: " + initialGraph.getNodes());
+ TetradLogger.getInstance().forceLogMessage("Data set variables: " + variables);
}
if (!new HashSet<>(initialGraph.getNodes()).equals(new HashSet<>(variables))) {
@@ -421,10 +484,10 @@ public void setBoundGraph(Graph boundGraph) {
* @deprecated Use the getters on the individual scores instead.
*/
public double getPenaltyDiscount() {
- if (score instanceof ISemBicScore) {
- return ((ISemBicScore) score).getPenaltyDiscount();
+ if (score instanceof SemBicScore) {
+ return ((SemBicScore) score).getPenaltyDiscount();
} else {
- return 2.0;
+ return 1.0;
}
}
@@ -537,7 +600,7 @@ public Boolean call() {
for (int i = from; i < to; i++) {
if ((i + 1) % 1000 == 0) {
count[0] += 1000;
- out.println("Initializing effect edges: " + (count[0]));
+ TetradLogger.getInstance().forceLogMessage("Initializing effect edges: " + (count[0]));
}
Node y = nodes.get(i);
@@ -563,122 +626,97 @@ public Boolean call() {
// start: changed by Fattaneh
int child = hashIndices.get(y);
int parent = hashIndices.get(x);
- double bump = 0.0, bump2 = 0.0;
+ double bump, bump2 = 0.0;
// if the initial graph graph is empty, proceed as usual
- if (initialGraph == null){
+ if (initialGraph == null) {
bump = score.localScoreDiff(parent, child);
- }
- else{
+ } else {
// if x or y has no adjacency in the initial graph, then proceed as if initial graph is empty
if (initialGraph.getAdjacentNodes(x).isEmpty() && initialGraph.getAdjacentNodes(y).isEmpty()) {
bump = score.localScoreDiff(parent, child);
}
// if x or y has adjacencies in the initial graph, then that should be considered in scoring
- else{
+ else {
int[] parentIndicesY;
Set parentsY = new HashSet<>(initialGraph.getParents(y));
parentIndicesY = new int[parentsY.size()];
- int c = 0;
+ int c = 0;
for (Node p : parentsY) {
parentIndicesY[c++] = hashIndices.get(p);
}
- bump = score.localScoreDiff(parent, child, parentIndicesY);
-
-// if (verbose2){
-// System.out.println("bump: " + bump);
-// System.out.println("bump w/o parents y: " + score.localScoreDiff(parent, child));
-// }
+ bump = score.localScoreDiff(parent, child, parentIndicesY);
}
}
// computing the bump of an edge from y (child) --> x (parent)
if (symmetricFirstStep) {
- if (initialGraph == null){
+ if (initialGraph == null) {
bump2 = score.localScoreDiff(child, parent);
- }
- else{
+ } else {
// if x or y has no adjacency, then proceed as an empty initial graph
if (initialGraph.getAdjacentNodes(x).isEmpty() && initialGraph.getAdjacentNodes(y).isEmpty()) {
bump2 = score.localScoreDiff(child, parent);
- }
- else{
+ } else {
int[] parentIndicesX;
Set parentsX = new HashSet<>(initialGraph.getParents(x));
parentIndicesX = new int[parentsX.size()];
- int c = 0;
+ int c = 0;
for (Node p : parentsX) {
parentIndicesX[c++] = hashIndices.get(p);
}
- bump2 = score.localScoreDiff(child, parent, parentIndicesX);
-
-// if (verbose2){
-// System.out.println("bump2: " + bump2);
-// System.out.println("bump2 w/o parents y: " + score.localScoreDiff(child, parent));
-// }
+ bump2 = score.localScoreDiff(child, parent, parentIndicesX);
}
}
- bump = bump > bump2 ? bump : bump2;
+ bump = Math.max(bump, bump2);
}
-// if (symmetricFirstStep) {
-// double bump2 = score.localScoreDiff(child, parent);
-// bump = bump > bump2 ? bump : bump2;
-// }
-
if (boundGraph != null && !boundGraph.isAdjacentTo(x, y)) {
continue;
}
- if (bump > 0) {
+ if (bump > 0.0) {
final Edge edge = Edges.undirectedEdge(x, y);
effectEdgesGraph.addEdge(edge);
}
- if (bump > 0) {
- if (initialGraph == null ){
- addArrow(x, y, emptySet, emptySet, emptySet, bump);
+ if (bump > 0.0) {
+ if (initialGraph == null) {
+ addArrow(x, y, emptySet, emptySet, emptySet, bump, new HashSet<>());
- if (!symmetricFirstStep){
- addArrow(y, x, emptySet, emptySet, emptySet, bump2);
+ if (!symmetricFirstStep) {
+ addArrow(y, x, emptySet, emptySet, emptySet, bump2, new HashSet<>());
}
- }
- else{
- if( initialGraph.getAdjacentNodes(x).isEmpty() && initialGraph.getAdjacentNodes(y).isEmpty()){
- addArrow(x, y, emptySet, emptySet, emptySet, bump);
+ } else {
+ if (initialGraph.getAdjacentNodes(x).isEmpty() && initialGraph.getAdjacentNodes(y).isEmpty()) {
+ addArrow(x, y, emptySet, emptySet, emptySet, bump, new HashSet<>());
- if (!symmetricFirstStep){
- addArrow(y, x, emptySet, emptySet, emptySet, bump2);
+ if (!symmetricFirstStep) {
+ addArrow(y, x, emptySet, emptySet, emptySet, bump2, new HashSet<>());
}
- }
- else{
-// System.out.println("x: " + x+ ", y: " + y);
-// System.out.println("sortedArrows before calculateArrowsForward: " + sortedArrows);
+ } else {
calculateArrowsForward(x, y);
-// System.out.println("sortedArrows after calculateArrowsForward: " + sortedArrows);
calculateArrowsForward(y, x);
-// System.out.println("sortedArrows after calculateArrowsForward IN REVERSE : " + sortedArrows);
}
}
}
- if (symmetricFirstStep){
+ if (symmetricFirstStep) {
if (bump2 > 0) {
- if (initialGraph == null ){
- addArrow(y, x, emptySet, emptySet, emptySet, bump2);
+ if (initialGraph == null) {
+ addArrow(y, x, emptySet, emptySet, emptySet, bump2, new HashSet<>());
- }
- else{
- if( initialGraph.getAdjacentNodes(x).isEmpty() && initialGraph.getAdjacentNodes(y).isEmpty()){
- addArrow(y, x, emptySet, emptySet, emptySet, bump2);
+ } else {
+ if (initialGraph.getAdjacentNodes(x).isEmpty() && initialGraph.getAdjacentNodes(y).isEmpty()) {
+ addArrow(y, x, emptySet, emptySet, emptySet, bump2, new HashSet<>());
}
}
@@ -693,12 +731,11 @@ public Boolean call() {
private void initializeForwardEdgesFromEmptyGraph(final List nodes) {
sortedArrows = new ConcurrentSkipListSet<>();
- lookupArrows = new ConcurrentHashMap<>();
neighbors = new ConcurrentHashMap<>();
final Set emptySet = new HashSet<>();
long start = System.currentTimeMillis();
- this.effectEdgesGraph = new EdgeListGraphSingleConnections(nodes);
+ this.effectEdgesGraph = new EdgeListGraph(nodes);
List> tasks = new ArrayList<>();
@@ -715,7 +752,7 @@ private void initializeForwardEdgesFromEmptyGraph(final List nodes) {
long stop = System.currentTimeMillis();
if (verbose) {
- out.println("Elapsed initializeForwardEdgesFromEmptyGraph = " + (stop - start) + " ms");
+ TetradLogger.getInstance().forceLogMessage("Elapsed initializeForwardEdgesFromEmptyGraph = " + (stop - start) + " ms");
}
}
@@ -723,7 +760,6 @@ private void initializeTwoStepEdges(final List nodes) {
count[0] = 0;
sortedArrows = new ConcurrentSkipListSet<>();
- lookupArrows = new ConcurrentHashMap<>();
neighbors = new ConcurrentHashMap<>();
if (this.effectEdgesGraph == null) {
@@ -742,9 +778,9 @@ private void initializeTwoStepEdges(final List nodes) {
class InitializeFromExistingGraphTask extends RecursiveTask {
- private int chunk;
- private int from;
- private int to;
+ private final int chunk;
+ private final int from;
+ private final int to;
private InitializeFromExistingGraphTask(int chunk, int from, int to) {
this.chunk = chunk;
@@ -762,7 +798,7 @@ protected Boolean compute() {
for (int i = from; i < to && !Thread.currentThread().isInterrupted(); i++) {
if ((i + 1) % 1000 == 0) {
count[0] += 1000;
- out.println("Initializing effect edges: " + (count[0]));
+ TetradLogger.getInstance().forceLogMessage("Initializing effect edges: " + (count[0]));
}
Node y = nodes.get(i);
@@ -818,7 +854,6 @@ protected Boolean compute() {
}
}
- return true;
} else {
int mid = (to + from) / 2;
@@ -829,8 +864,8 @@ protected Boolean compute() {
right.compute();
left.join();
- return true;
}
+ return true;
}
}
@@ -841,7 +876,6 @@ private void initializeForwardEdgesFromExistingGraph(final List nodes) {
count[0] = 0;
sortedArrows = new ConcurrentSkipListSet<>();
- lookupArrows = new ConcurrentHashMap<>();
neighbors = new ConcurrentHashMap<>();
if (this.effectEdgesGraph == null) {
@@ -864,9 +898,9 @@ private void initializeForwardEdgesFromExistingGraph(final List nodes) {
class InitializeFromExistingGraphTask extends RecursiveTask {
- private int chunk;
- private int from;
- private int to;
+ private final int chunk;
+ private final int from;
+ private final int to;
private InitializeFromExistingGraphTask(int chunk, int from, int to) {
this.chunk = chunk;
@@ -884,7 +918,7 @@ protected Boolean compute() {
for (int i = from; i < to && !Thread.currentThread().isInterrupted(); i++) {
if ((i + 1) % 1000 == 0) {
count[0] += 1000;
- out.println("Initializing effect edges: " + (count[0]));
+ TetradLogger.getInstance().forceLogMessage("Initializing effect edges: " + (count[0]));
}
// We want to recapture the variables that would have been effect edges if paths hadn't
@@ -893,7 +927,7 @@ protected Boolean compute() {
Node y = nodes.get(i);
Set D = new HashSet<>(getUnconditionallyDconnectedVars(y, graph));
D.remove(y);
- D.removeAll(effectEdgesGraph.getAdjacentNodes(y));
+// D.removeAll(effectEdgesGraph.getAdjacentNodes(y));
for (Node x : D) {
if (Thread.currentThread().isInterrupted()) {
@@ -918,7 +952,6 @@ protected Boolean compute() {
}
}
- return true;
} else {
int mid = (to + from) / 2;
@@ -929,8 +962,8 @@ protected Boolean compute() {
right.compute();
left.join();
- return true;
}
+ return true;
}
}
@@ -940,7 +973,6 @@ protected Boolean compute() {
private void fes() {
if (verbose) {
TetradLogger.getInstance().forceLogMessage("** FORWARD EQUIVALENCE SEARCH");
- out.println("** FORWARD EQUIVALENCE SEARCH");
}
int maxDegree = this.maxDegree == -1 ? 1000 : this.maxDegree;
@@ -968,7 +1000,11 @@ private void fes() {
continue;
}
- if (!new HashSet<>(getTNeighbors(x, y)).equals(arrow.getTNeighbors())) {
+ if (!getTNeighbors(x, y).equals(arrow.getTNeighbors())) {
+ continue;
+ }
+
+ if (!new HashSet<>(graph.getParents(y)).equals(arrow.getParents())) {
continue;
}
@@ -985,14 +1021,13 @@ private void fes() {
Set visited = reapplyOrientation(x, y, null);
Set toProcess = new HashSet<>();
- for (Node node : visited) {
+ visited.forEach(node -> {
final Set neighbors1 = getNeighbors(node);
final Set storedNeighbors = this.neighbors.get(node);
-
if (!(neighbors1.equals(storedNeighbors))) {
toProcess.add(node);
}
- }
+ });
toProcess.add(x);
toProcess.add(y);
@@ -1010,18 +1045,20 @@ private Set getCommonAdjacents(Node x, Node y) {
private void bes() {
if (verbose) {
TetradLogger.getInstance().forceLogMessage("** BACKWARD EQUIVALENCE SEARCH");
- out.println("** BACKWARD EQUIVALENCE SEARCH");
}
sortedArrows = new ConcurrentSkipListSet<>();
- lookupArrows = new ConcurrentHashMap<>();
neighbors = new ConcurrentHashMap<>();
initializeArrowsBackward();
while (!sortedArrows.isEmpty()) {
- Arrow arrow = sortedArrows.first();
- sortedArrows.remove(arrow);
+ Arrow arrow;
+
+ synchronized (this) {
+ arrow = sortedArrows.first();
+ sortedArrows.remove(arrow);
+ }
Node x = arrow.getA();
Node y = arrow.getB();
@@ -1040,11 +1077,21 @@ private void bes() {
continue;
}
- if (!validDelete(x, y, arrow.getHOrT(), arrow.getNaYX())) {
+ List parents = graph.getParents(y);
+ Set parents1 = arrow.getParents();
+
+ parents.remove(x);
+ parents1.remove(x);
+
+ if (!new HashSet<>(parents).equals(parents1)) {
continue;
}
- boolean deleted = delete(x, y, arrow.getHOrT(), arrow.getBump(), arrow.getNaYX());
+ boolean deleted = false;
+
+ if (validDelete(x, y, arrow.getHOrT(), arrow.getNaYX())) {
+ deleted = delete(x, y, arrow.getHOrT(), arrow.getBump(), arrow.getNaYX());
+ }
if (!deleted) {
continue;
@@ -1054,14 +1101,13 @@ private void bes() {
Set toProcess = new HashSet<>();
- for (Node node : visited) {
+ visited.forEach(node -> {
final Set neighbors1 = getNeighbors(node);
final Set storedNeighbors = this.neighbors.get(node);
-
if (!(neighbors1.equals(storedNeighbors))) {
toProcess.add(node);
}
- }
+ });
toProcess.add(x);
toProcess.add(y);
@@ -1119,8 +1165,8 @@ private void reevaluateForward(final Set nodes) {
class AdjTask implements Callable {
private final List nodes;
- private int from;
- private int to;
+ private final int from;
+ private final int to;
private AdjTask(List nodes, int from, int to) {
this.nodes = nodes;
@@ -1194,93 +1240,55 @@ public Boolean call() {
pool.invokeAll(tasks);
}
- // Calculates the new arrows for an a->b edge.
- private void calculateArrowsForward(Node a, Node b) {
- if (mode == Mode.heuristicSpeedup && !effectEdgesGraph.isAdjacentTo(a, b)) {
+ // Calculates the new arrows for an x->y edge.
+ private void calculateArrowsForward(Node x, Node y) {
+ if (mode == Mode.heuristicSpeedup && !effectEdgesGraph.isAdjacentTo(x, y)) {
return;
}
- if (adjacencies != null && !adjacencies.isAdjacentTo(a, b)) {
+ if (adjacencies != null && !adjacencies.isAdjacentTo(x, y)) {
return;
}
- this.neighbors.put(b, getNeighbors(b));
-
- clearArrow(a, b);
+ this.neighbors.put(y, getNeighbors(y));
- if (a == b) {
+ if (x == y) {
throw new IllegalArgumentException();
}
if (existsKnowledge()) {
- if (getKnowledge().isForbidden(a.getName(), b.getName())) {
+ if (getKnowledge().isForbidden(x.getName(), y.getName())) {
return;
}
}
- Set naYX = getNaYX(a, b);
- if (!isClique(naYX)) {
- return;
- }
-
- List TNeighbors = getTNeighbors(a, b);
-
- Set> previousCliques = new HashSet<>();
- previousCliques.add(new HashSet<>());
- Set> newCliques = new HashSet<>();
-
- Set _T = null;
- double _bump = Double.NEGATIVE_INFINITY;
-
- FOR:
- for (int i = 0; i <= TNeighbors.size(); i++) {
- final ChoiceGenerator gen = new ChoiceGenerator(TNeighbors.size(), i);
- int[] choice;
-
- while ((choice = gen.next()) != null) {
- Set T = GraphUtils.asSet(choice, TNeighbors);
-
- Set union = new HashSet<>(naYX);
- union.addAll(T);
-
- boolean foundAPreviousClique = false;
-
- for (Set clique : previousCliques) {
- if (union.containsAll(clique)) {
- foundAPreviousClique = true;
- break;
- }
- }
-
- if (!foundAPreviousClique) {
- break FOR;
- }
+ Set naYX = getNaYX(x, y);
+ Set tNeighbors = getTNeighbors(x, y);
+ List TNeighbors = new ArrayList<>(tNeighbors);
+ HashSet parents = new HashSet<>(graph.getParents(y));
- if (!isClique(union)) {
- continue;
- }
- newCliques.add(union);
+ int depth = this.tdepth == -1 ? Integer.MAX_VALUE : this.tdepth;
+ depth = Math.min(depth, TNeighbors.size());
- double bump = insertEval(a, b, T, naYX, hashIndices);
+ final DepthChoiceGenerator gen = new DepthChoiceGenerator(TNeighbors.size(), depth);// TNeighbors.size());
+ int[] choice;
- if (bump > 0) {
- _T = T;
- _bump = bump;
-// addArrow(a, b, TNeighbors, naYX, bump);
- }
- }
+ while ((choice = gen.next()) != null) {
+ Set T = GraphUtils.asSet(choice, TNeighbors);
+ double bump = insertEval(x, y, T, naYX, hashIndices);
- if (_bump > Double.NEGATIVE_INFINITY) {
- addArrow(a, b, _T, new HashSet<>(TNeighbors), naYX, _bump);
+ if (bump > 0.0) {
+ addArrow(x, y, T, tNeighbors, naYX, bump, parents);
}
-
- previousCliques = newCliques;
- newCliques = new HashSet<>();
}
}
- private void addArrow(Node a, Node b, Set hOrT, Set TNeighbors, Set naYX, double bump) {
- Arrow arrow = new Arrow(bump, a, b, hOrT, TNeighbors, naYX, arrowIndex++);
+ private synchronized void addArrow(Node a, Node b, Set hOrT, Set TNeighbors, Set naYX, double bump,
+ Set parents) {
+ Arrow arrow = new Arrow(bump, a, b, hOrT, TNeighbors, naYX, arrowIndex++, parents);
sortedArrows.add(arrow);
- addLookupArrow(a, b, arrow);
+
+ synchronized (this) {
+ sortedArrows.add(arrow);
+ }
}
// Reevaluates arrows after removing an edge from the graph.
@@ -1288,11 +1296,11 @@ private void reevaluateBackward(Set toProcess) {
class BackwardTask extends RecursiveTask {
private final Node r;
- private List adj;
- private Map hashIndices;
- private int chunk;
- private int from;
- private int to;
+ private final List adj;
+ private final Map hashIndices;
+ private final int chunk;
+ private final int from;
+ private final int to;
private BackwardTask(Node r, List adj, int chunk, int from, int to,
Map hashIndices) {
@@ -1323,7 +1331,6 @@ protected Boolean compute() {
}
}
- return true;
} else {
int mid = (to - from) / 2;
@@ -1334,8 +1341,8 @@ protected Boolean compute() {
invokeAll(tasks);
- return true;
}
+ return true;
}
}
@@ -1348,47 +1355,39 @@ protected Boolean compute() {
}
// Calculates the arrows for the removal in the backward direction.
- private void calculateArrowsBackward(Node a, Node b) {
+ private void calculateArrowsBackward(Node x, Node y) {
if (existsKnowledge()) {
- if (!getKnowledge().noEdgeRequired(a.getName(), b.getName())) {
+ if (!getKnowledge().noEdgeRequired(x.getName(), y.getName())) {
return;
}
}
- clearArrow(a, b);
-
- Set naYX = getNaYX(a, b);
-
+ Set naYX = getNaYX(x, y);
List _naYX = new ArrayList<>(naYX);
+ HashSet parents = new HashSet<>(graph.getParents(y));
+ parents.remove(x);
- final int _depth = _naYX.size();
-
- Set _h = null;
- double _bump = Double.NEGATIVE_INFINITY;
+ int depth = this.tdepth == -1 ? Integer.MAX_VALUE : this.tdepth;
+ depth = Math.min(depth, naYX.size());
- final DepthChoiceGenerator gen = new DepthChoiceGenerator(_naYX.size(), _depth);
+ final DepthChoiceGenerator gen = new DepthChoiceGenerator(naYX.size(), depth);
int[] choice;
while ((choice = gen.next()) != null) {
Set h = GraphUtils.asSet(choice, _naYX);
if (existsKnowledge()) {
- if (invalidSetByKnowledge(b, h)) {
+ if (invalidSetByKnowledge(y, h)) {
continue;
}
}
- double bump = deleteEval(a, b, h, naYX, hashIndices);
+ double bump = deleteEval(x, y, h, naYX, hashIndices);
if (bump >= 0.0) {
- _h = h;
- _bump = bump;
+ addArrow(x, y, h, null, naYX, bump, parents);
}
}
-
- if (_bump > Double.NEGATIVE_INFINITY) {
- addArrow(a, b, _h, null, naYX, _bump);
- }
}
// Basic data structure for an arrow a->b considered for addition or removal from the graph, together with
@@ -1398,15 +1397,16 @@ private void calculateArrowsBackward(Node a, Node b) {
// as the "bump".
private static class Arrow implements Comparable {
- private double bump;
- private Node a;
- private Node b;
- private Set hOrT;
+ private final double bump;
+ private final Node a;
+ private final Node b;
+ private final Set hOrT;
private Set TNeighbors;
- private Set naYX;
- private int index;
+ private final Set naYX;
+ private final int index;
+ private final Set parents;
- Arrow(double bump, Node a, Node b, Set hOrT, Set capTorH, Set naYX, int index) {
+ Arrow(double bump, Node a, Node b, Set hOrT, Set capTorH, Set naYX, int index, Set parents) {
this.bump = bump;
this.a = a;
this.b = b;
@@ -1414,6 +1414,7 @@ private static class Arrow implements Comparable {
this.hOrT = hOrT;
this.naYX = naYX;
this.index = index;
+ this.parents = parents;
}
public double getBump() {
@@ -1442,11 +1443,7 @@ Set getNaYX() {
// The fastest way to do this is using a hash code, though it's still possible for two Arrows to have the
// same hash code but not be equal. If we're paranoid, in this case we calculate a determinate comparison
// not equal to zero by keeping a list. This last part is commened out by default.
- public int compareTo(Arrow arrow) {
- if (arrow == null) {
- throw new NullPointerException();
- }
-
+ public int compareTo(@NotNull Arrow arrow) {
final int compare = Double.compare(arrow.getBump(), getBump());
if (compare == 0) {
@@ -1464,36 +1461,24 @@ public int getIndex() {
return index;
}
- Set getTNeighbors() {
+ public Set getTNeighbors() {
return TNeighbors;
}
- void setTNeighbors(Set TNeighbors) {
+ public void setTNeighbors(Set TNeighbors) {
this.TNeighbors = TNeighbors;
}
+ public Set getParents() {
+ return parents;
+ }
}
// Get all adj that are connected to Y by an undirected edge and not adjacent to X.
- private List getTNeighbors(Node x, Node y) {
- List yEdges = graph.getEdges(y);
- List tNeighbors = new ArrayList<>();
-
- for (Edge edge : yEdges) {
- if (!Edges.isUndirectedEdge(edge)) {
- continue;
- }
-
- Node z = edge.getDistalNode(y);
-
- if (graph.isAdjacentTo(z, x)) {
- continue;
- }
-
- tNeighbors.add(z);
- }
-
- return tNeighbors;
+ private Set getTNeighbors(Node x, Node y) {
+ Set nb = getNeighbors(y);
+ nb.removeAll(graph.getAdjacentNodes(x));
+ return nb;
}
// Get all adj that are connected to Y.
@@ -1514,31 +1499,10 @@ private Set getNeighbors(Node y) {
return neighbors;
}
- // Evaluate the Insert(X, Y, TNeighbors) operator (Definition 12 from Chickering, 2002).
- private double insertEval(Node x, Node y, Set t, Set naYX,
- Map hashIndices) {
- if (x == y) {
- throw new IllegalArgumentException();
- }
- Set set = new HashSet<>(naYX);
- set.addAll(t);
- set.addAll(graph.getParents(y));
- return scoreGraphChange(x, y, set, hashIndices);
- }
-
- // Evaluate the Delete(X, Y, TNeighbors) operator (Definition 12 from Chickering, 2002).
- private double deleteEval(Node x, Node y, Set h, Set naYX,
- Map hashIndices) {
- Set set = new HashSet<>(naYX);
- set.removeAll(h);
- final List parents = graph.getParents(y);
- parents.remove(x);
- set.addAll(parents);
- return -scoreGraphChange(x, y, set, hashIndices);
- }
-
// Do an actual insertion. (Definition 12 from Chickering, 2002).
private boolean insert(Node x, Node y, Set T, double bump) {
+
+
if (graph.isAdjacentTo(x, y)) {
return false; // The initial graph may already have put this edge in the graph.
}
@@ -1557,18 +1521,19 @@ private boolean insert(Node x, Node y, Set T, double bump) {
graph.addDirectedEdge(x, y);
-// if (verbose) {
-// String label = trueGraph != null && trueEdge != null ? "*" : "";
-// TetradLogger.getInstance().forceLogMessage(graph.getNumEdges() + ". INSERT " + graph.getEdge(x, y)
-// + " " + T + " " + bump + " " + label);
-//// out.println(graph.getNumEdges() + ". INSERT " + graph.getEdge(x, y)
-//// + " " + T + " " + bump + " " + label);
-// }
+ if (verbose) {
+ Set set = new HashSet<>(getNaYX(x, y));
+ set.addAll(T);
+ set.addAll(graph.getParents(y));
+ String label = trueGraph != null && trueEdge != null ? "*" : "";
+ TetradLogger.getInstance().forceLogMessage("graph.getNumEdges()" + ". INSERT " + graph.getEdge(x, y)
+ + " T = " + T + " conditioning on = " + set + " " + bump + " " + label);
+ }
int numEdges = graph.getNumEdges();
if (numEdges % 1000 == 0) {
- out.println("Num edges added: " + numEdges);
+ TetradLogger.getInstance().forceLogMessage("Num edges added: " + numEdges);
}
if (verbose) {
@@ -1578,12 +1543,18 @@ private boolean insert(Node x, Node y, Set T, double bump) {
+ " degree = " + GraphUtils.getDegree(graph)
+ " indegree = " + GraphUtils.getIndegree(graph);
TetradLogger.getInstance().forceLogMessage(message);
-// out.println(message);
+
+ if (trueGraph != null) {
+ Set set = new HashSet<>(getNaYX(x, y));
+ set.addAll(T);
+ set.addAll(graph.getParents(y));
+ if (dseparated(x, y, new ArrayList<>(set))) {
+ TetradLogger.getInstance().forceLogMessage("...d-separated but judged dependent");
+ }
+ }
}
for (Node _t : T) {
- if (graph.isAdjacentTo(x, _t)) continue;
-
graph.removeEdge(_t, y);
if (boundGraph != null && !boundGraph.isAdjacentTo(_t, y)) {
continue;
@@ -1594,7 +1565,6 @@ private boolean insert(Node x, Node y, Set T, double bump) {
if (verbose) {
String message = "--- Directing " + graph.getEdge(_t, y);
TetradLogger.getInstance().forceLogMessage(message);
-// out.println(message);
}
}
@@ -1612,23 +1582,28 @@ private boolean delete(Node x, Node y, Set H, double bump, Set naYX)
}
Edge oldxy = graph.getEdge(x, y);
-
- Set diff = new HashSet<>(naYX);
- diff.removeAll(H);
-
graph.removeEdge(oldxy);
int numEdges = graph.getNumEdges();
if (numEdges % 1000 == 0) {
- out.println("Num edges (backwards) = " + numEdges);
+ TetradLogger.getInstance().forceLogMessage("Num edges (backwards) = " + numEdges);
}
if (verbose) {
+ Set set = new HashSet<>(naYX);
+ set.removeAll(H);
+ set.addAll(graph.getParents(y));
+
String label = trueGraph != null && trueEdge != null ? "*" : "";
String message = (graph.getNumEdges()) + ". DELETE " + x + "-->" + y
- + " H = " + H + " NaYX = " + naYX + " diff = " + diff + " (" + bump + ") " + label;
+ + " H = " + H + " conditioning on = " + set + " (" + bump + ") " + label;
TetradLogger.getInstance().forceLogMessage(message);
-// out.println(message);
+
+ if (trueGraph != null) {
+ if (!dseparated(x, y, new ArrayList<>(set))) {
+ TetradLogger.getInstance().forceLogMessage("...d-connected but judged independent");
+ }
+ }
}
for (Node h : H) {
@@ -1639,26 +1614,22 @@ private boolean delete(Node x, Node y, Set H, double bump, Set naYX)
Edge oldyh = graph.getEdge(y, h);
graph.removeEdge(oldyh);
-
graph.addEdge(Edges.directedEdge(y, h));
if (verbose) {
TetradLogger.getInstance().forceLogMessage("--- Directing " + oldyh + " to "
+ graph.getEdge(y, h));
- out.println("--- Directing " + oldyh + " to " + graph.getEdge(y, h));
}
Edge oldxh = graph.getEdge(x, h);
if (Edges.isUndirectedEdge(oldxh)) {
graph.removeEdge(oldxh);
-
graph.addEdge(Edges.directedEdge(x, h));
if (verbose) {
TetradLogger.getInstance().forceLogMessage("--- Directing " + oldxh + " to "
+ graph.getEdge(x, h));
-// out.println("--- Directing " + oldxh + " to " + graph.getEdge(x, h));
}
}
}
@@ -1728,7 +1699,6 @@ private void addRequiredEdges(Graph graph) {
if (verbose) {
TetradLogger.getInstance().forceLogMessage("Adding edge by knowledge: " + graph.getEdge(nodeA, nodeB));
- out.println("Adding edge by knowledge: " + graph.getEdge(nodeA, nodeB));
}
}
}
@@ -1754,7 +1724,6 @@ private void addRequiredEdges(Graph graph) {
if (verbose) {
TetradLogger.getInstance().forceLogMessage("Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
-// out.println("Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
}
}
}
@@ -1766,7 +1735,6 @@ private void addRequiredEdges(Graph graph) {
if (verbose) {
TetradLogger.getInstance().forceLogMessage("Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
-// out.println("Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
}
}
}
@@ -1784,7 +1752,6 @@ private void addRequiredEdges(Graph graph) {
if (verbose) {
TetradLogger.getInstance().forceLogMessage("Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
-// out.println("Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
}
}
}
@@ -1795,7 +1762,6 @@ private void addRequiredEdges(Graph graph) {
if (verbose) {
TetradLogger.getInstance().forceLogMessage("Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
-// out.println("Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
}
}
}
@@ -1818,24 +1784,9 @@ private boolean invalidSetByKnowledge(Node y, Set subset) {
// Find all adj that are connected to Y by an undirected edge that are adjacent to X (that is, by undirected or
// directed edge).
private Set getNaYX(Node x, Node y) {
- List adj = graph.getAdjacentNodes(y);
- Set nayx = new HashSet<>();
-
- for (Node z : adj) {
- if (z == x) {
- continue;
- }
- Edge yz = graph.getEdge(y, z);
- if (!Edges.isUndirectedEdge(yz)) {
- continue;
- }
- if (!graph.isAdjacentTo(z, x)) {
- continue;
- }
- nayx.add(z);
- }
-
- return nayx;
+ Set nb = getNeighbors(y);
+ nb.retainAll(graph.getAdjacentNodes(x));
+ return nb;
}
// Returns true iif the given set forms a clique in the given graph.
@@ -1938,32 +1889,6 @@ private void buildIndexing(List nodes) {
}
}
- // Removes information associated with an edge x->y.
- private synchronized void clearArrow(Node x, Node y) {
- final OrderedPair pair = new OrderedPair<>(x, y);
- final Set lookupArrows = this.lookupArrows.get(pair);
-
- if (lookupArrows != null) {
- sortedArrows.removeAll(lookupArrows);
- }
-
- this.lookupArrows.remove(pair);
- }
-
- // Adds the given arrow for the adjacency i->j. These all are for i->j but may have
- // different TNeighbors or H or NaYX sets, and so different bumps.
- private void addLookupArrow(Node i, Node j, Arrow arrow) {
- OrderedPair pair = new OrderedPair<>(i, j);
- Set arrows = lookupArrows.get(pair);
-
- if (arrows == null) {
- arrows = new ConcurrentSkipListSet<>();
- lookupArrows.put(pair, arrows);
- }
-
- arrows.add(arrow);
- }
-
//===========================SCORING METHODS===================//
private double scoreDag(Graph dag, boolean recordScores) {
@@ -1971,20 +1896,6 @@ private double scoreDag(Graph dag, boolean recordScores) {
Score score2 = score;
- if (score instanceof SemBicScore) {
- DataSet dataSet = ((SemBicScore) score).getDataSet();
-
- if (dataSet != null) {
- score2 = new SemBicScore(dataSet);
- } else {
- ICovarianceMatrix cov = ((SemBicScore) score).getCovariances();
-
- if (cov != null) {
- score2 = new SemBicScore(cov);
- }
- }
- }
-
dag = GraphUtils.replaceNodes(dag, getVariables());
if (dag == null) throw new NullPointerException("DAG not specified.");
@@ -2020,6 +1931,29 @@ private double scoreDag(Graph dag, boolean recordScores) {
return _score;
}
+ // Evaluate the Insert(X, Y, TNeighbors) operator (Definition 12 from Chickering, 2002).
+ private double insertEval(Node x, Node y, Set t, Set naYX,
+ Map hashIndices) {
+ Set set = new HashSet<>(naYX);
+ set.addAll(t);
+ set.addAll(graph.getParents(y));
+ if (set.contains(x)) return 0;
+ set.remove(x);
+ return scoreGraphChange(x, y, set, hashIndices);
+ }
+
+ // Evaluate the Delete(X, Y, TNeighbors) operator (Definition 12 from Chickering, 2002).
+ private double deleteEval(Node x, Node y, Set h, Set naYX,
+ Map hashIndices) {
+ Set set = new HashSet<>(naYX);
+ set.removeAll(h);
+ final List parents = graph.getParents(y);
+// if (parents.contains(x)) return NaN;
+ parents.remove(x);
+ set.addAll(parents);
+ return -(scoreGraphChange(x, y, set, hashIndices));// + 0.05 * score.getSampleSize();
+ }
+
private double scoreGraphChange(Node x, Node y, Set parents,
Map hashIndices) {
int yIndex = hashIndices.get(y);
@@ -2031,6 +1965,8 @@ private double scoreGraphChange(Node x, Node y, Set parents,
throw new IllegalArgumentException();
}
+ if (parents.contains(x)) return 0;
+
int[] parentIndices = new int[parents.size()];
int count = 0;
@@ -2093,9 +2029,8 @@ private static Set getUnconditionallyDconnectedVars(Node x, Graph graph) {
Set Y = new HashSet<>();
class EdgeNode {
-
- private Edge edge;
- private Node node;
+ private final Edge edge;
+ private final Node node;
private EdgeNode(Edge edge, Node node) {
this.edge = edge;
@@ -2161,4 +2096,12 @@ private static boolean reachable(Edge e1, Edge e2, Node a) {
return !collider;
}
+
+ private boolean dseparated(Node x, Node y, List z) {
+ if (trueGraph != null) {
+ return graphScore.isIndependent(x, y, z);
+ }
+
+ throw new IllegalArgumentException("True graph not given.");
+ }
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/IndTestFisherZ.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/IndTestFisherZ.java
index 561ca82a41..44244d89ef 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/IndTestFisherZ.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/IndTestFisherZ.java
@@ -90,6 +90,24 @@ public IndTestFisherZ(DataSet dataSet, double alpha) {
throw new IllegalArgumentException("Data set must be continuous.");
}
+ if (!dataSet.existsMissingValue()) {
+ this.cor = new CorrelationMatrix(dataSet);
+ this.variables = cor.getVariables();
+ this.indexMap = indexMap(variables);
+ this.nameMap = nameMap(variables);
+ setAlpha(alpha);
+
+ Map nodesHash = new HashMap<>();
+
+ for (int j = 0; j < variables.size(); j++) {
+ nodesHash.put(variables.get(j), j);
+ }
+
+ this.nodesHash = nodesHash;
+
+ return;
+ }
+
if (!(alpha >= 0 && alpha <= 1)) {
throw new IllegalArgumentException("Alpha mut be in [0, 1]");
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScore.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScore.java
index 6ed29025f5..133049f14a 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScore.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SemBicScore.java
@@ -21,23 +21,18 @@
package edu.cmu.tetrad.search;
+import edu.cmu.tetrad.data.CorrelationMatrix;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.Matrix;
-import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.StatUtils;
-import edu.cmu.tetrad.util.Vector;
+import org.jetbrains.annotations.NotNull;
-import java.io.PrintStream;
-import java.text.DecimalFormat;
-import java.text.NumberFormat;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
+import static edu.cmu.tetrad.util.MatrixUtils.convertCovToCorr;
import static java.lang.Double.NaN;
import static java.lang.Math.*;
@@ -51,7 +46,7 @@ public class SemBicScore implements Score {
// The dataset.
private DataSet dataSet;
- // The covariances.
+ // The correlation matrix.
private ICovarianceMatrix covariances;
// The variables of the covariance matrix.
@@ -60,9 +55,6 @@ public class SemBicScore implements Score {
// The sample size of the covariance matrix.
private final int sampleSize;
- // The printstream output should be sent to.
- private PrintStream out = System.out;
-
// True if verbose output should be sent to out.
private boolean verbose = false;
@@ -75,6 +67,12 @@ public class SemBicScore implements Score {
// The structure prior, 0 for standard BIC.
private double structurePrior = 0.0;
+ // Equivalent sample size
+ private Matrix matrix;
+
+ // The rule type to use.
+ private RuleType ruleType = RuleType.CHICKERING;
+
/**
* Constructs the score using a covariance matrix.
*/
@@ -83,7 +81,7 @@ public SemBicScore(ICovarianceMatrix covariances) {
throw new NullPointerException();
}
- setCovariances(new CovarianceMatrix(covariances));
+ setCovariances(covariances);
this.variables = covariances.getVariables();
this.sampleSize = covariances.getSampleSize();
this.indexMap = indexMap(this.variables);
@@ -97,6 +95,17 @@ public SemBicScore(DataSet dataSet) {
throw new NullPointerException();
}
+ if (!dataSet.existsMissingValue()) {
+// setCovariances(new CorrelationMatrix(new CovarianceMatrix(dataSet, false)));
+ setCovariances(new CovarianceMatrix(dataSet, false));
+ this.variables = covariances.getVariables();
+ this.sampleSize = covariances.getSampleSize();
+ this.indexMap = indexMap(this.variables);
+
+
+ return;
+ }
+
this.dataSet = dataSet;
this.variables = dataSet.getVariables();
@@ -106,6 +115,14 @@ public SemBicScore(DataSet dataSet) {
@Override
public double localScoreDiff(int x, int y, int[] z) {
+ if (ruleType == RuleType.NANDY) {
+ return nandyBic(x, y, z);
+ } else {
+ return localScore(y, append(z, x)) - localScore(y, z);
+ }
+ }
+
+ public double nandyBic(int x, int y, int[] z) {
double sp1 = getStructurePrior(z.length + 1);
double sp2 = getStructurePrior(z.length);
@@ -113,24 +130,18 @@ public double localScoreDiff(int x, int y, int[] z) {
Node _y = variables.get(y);
List _z = getVariableList(z);
- int n;
- double r;
-
- if (covariances == null) {
- List rows = getRows(x, z);
- rows.retainAll(getRows(y, z));
+ List rows = getRows(x, z);
- n = rows.size();
- r = partialCorrelation(_x, _y, _z, rows);
- } else {
- n = covariances.getSampleSize();
- r = partialCorrelation(_x, _y, _z, null);
+ if (rows != null) {
+ rows.retainAll(Objects.requireNonNull(getRows(y, z)));
}
- // r could be NaN if the matrix is not invertible; this NaN will be returned.
- return -n * Math.log(1.0 - r * r) - getPenaltyDiscount() * log(n)
- + signum(getStructurePrior()) * (sp1 - sp2);
-// return (localScore(y, append(z, x)) - localScore(y, z));
+ double r = partialCorrelation(_x, _y, _z, rows);
+
+ double c = getPenaltyDiscount();
+
+ return -sampleSize * log(1.0 - r * r) - c * log(sampleSize)
+ - 2.0 * (sp1 - sp2);
}
@Override
@@ -139,43 +150,69 @@ public double localScoreDiff(int x, int y) {
}
public double localScore(int i, int... parents) {
- List rows = getRows(i, parents);
+ List rows = getRows(i, parents);
- try {
- final int p = parents.length;
- int k = p + 1;
- double n = sampleSize;
+ final int p = parents.length;
- int[] ii = {i};
- Matrix X = getCov(rows, parents, parents);
- Matrix Y = getCov(rows, parents, ii);
- double s2 = getCov(rows, ii, ii).get(0, 0);
+ int k = p + 1;
- Vector coefs = getCoefs(X, Y).getColumn(0);
+ int[] all = concat(i, parents);
- for (int q = 0; q < X.rows(); q++) {
- for (int r = 0; r < X.columns(); r++) {
- s2 -= coefs.get(q) * coefs.get(r) * X.get(r, q);
- }
- }
+ // Only do this once.
+ Matrix cov = getCov(rows, all, all);
- if (s2 <= 0) {
- if (isVerbose()) {
- out.println("Nonpositive residual varianceY: resVar / varianceY = " + (s2 / getCovariances().getValue(i, i)));
- }
- return NaN;
- }
+ double m = variables.size();
+ double n = sampleSize;
- return -n * log(s2) - getPenaltyDiscount() * k * log(n)
- + 2 * getStructurePrior(parents.length);
- } catch (Exception e) {
- e.printStackTrace();
- return NaN;
+ int[] pp = indexedParents(parents);
+
+ Matrix covxx = cov.getSelection(pp, pp);
+ Matrix covxy = cov.getSelection(pp, new int[]{0});
+
+ double varey;
+
+// if (false) {
+// // Ricardo's calculation.
+// Matrix b = covxx.inverse().times(covxy);
+// varey = cov.get(0, 0);
+// Vector _cxy = covxy.getColumn(0);
+// Vector _b = b.getColumn(0);
+// varey -= _cxy.dotProduct(_b);
+// } else {
+ // My calculation.
+ Matrix b = covxx.inverse().times(covxy);
+ Matrix b2 = adjustedCoefs(p, b);
+ Matrix times = b2.transpose().times(cov).times(b2);
+ varey = times.get(0, 0);
+// }
+
+ double c = getPenaltyDiscount();
+
+ if (ruleType == RuleType.CHICKERING || ruleType == RuleType.NANDY) {
+
+ // Standard BIC, with penalty discount and structure prior.
+ return -n * log(varey) - c * k * log(n) - 2 * getStructurePrior(p);
+
+ } else if (ruleType == RuleType.HIGH_DIMENSIONAL) {
+
+ // Pseudo-BIC formula, set Wikipedia page for BIC. With penalty discount and structure prior.
+
+ // We will just take 6 * omega * (1 + gamma) to be a number >= 6. To be compatible with other scores,
+ // we will use c + 5 for this value, where c is the penalty discount. So a penalty discount of 1 (the usual)
+ // will correspond to 6 * omega * (1 + gamma) of 6, the minimum.
+
+ return -n * log(varey) - c * k * 6 * log(m) - 2 * getStructurePrior(p);
+ } else {
+ throw new IllegalStateException("That rule type is not implemented: " + ruleType);
}
}
- private Matrix getCoefs(Matrix x, Matrix y) {
- return (x.inverse()).times(y);
+ @NotNull
+ public Matrix adjustedCoefs(int p, Matrix b) {
+ Matrix byx = new Matrix(p + 1, 1);
+ byx.set(0, 0, 1);
+ for (int j = 0; j < p; j++) byx.set(j + 1, 0, -b.get(j, 0));
+ return byx;
}
/**
@@ -192,10 +229,6 @@ public double localScore(int i) {
return localScore(i, new int[0]);
}
- public void setOut(PrintStream out) {
- this.out = out;
- }
-
public double getPenaltyDiscount() {
return penaltyDiscount;
}
@@ -283,15 +316,46 @@ public boolean determines(List z, Node y) {
private void setCovariances(ICovarianceMatrix covariances) {
this.covariances = covariances;
+ this.matrix = this.covariances.getMatrix();
+
+ Matrix cor = new CorrelationMatrix(covariances).getMatrix();
+
+ double m = covariances.getSize();
+ double n = covariances.getSampleSize();
+
+ double tr = cor.times(cor).trace();
+ double rho = sqrt((1. / m) * tr - 1.) / (n - 1.);
+ double ess = n / (1. + (n - 1.) * rho);
+
+ System.out.println("n = " + n + " ess = " + ess + " rho = " + rho);
+
+ }
+
+ private static int[] append(int[] z, int x) {
+ int[] _z = Arrays.copyOf(z, z.length + 1);
+ _z[z.length] = x;
+ return _z;
+ }
+
+ private static int[] indexedParents(int[] parents) {
+ int[] pp = new int[parents.length];
+ for (int j = 0; j < pp.length; j++) pp[j] = j + 1;
+ return pp;
+ }
+
+ private static int[] concat(int i, int[] parents) {
+ int[] all = new int[parents.length + 1];
+ all[0] = i;
+ System.arraycopy(parents, 0, all, 1, parents.length);
+ return all;
}
private double getStructurePrior(int parents) {
if (abs(getStructurePrior()) <= 0) {
return 0;
} else {
- int c = variables.size();
- double p = abs(getStructurePrior()) / (double) c;
- return (parents * Math.log(p) + (c - parents) * Math.log(1.0 - p));
+ double p = (getStructurePrior()) / (variables.size());
+ return -((parents) * Math.log(p) + (variables.size() - (parents)) * Math.log(1.0 - p));
}
}
@@ -313,54 +377,9 @@ private Map indexMap(List variables) {
return indexMap;
}
- // Prints a smallest subset of parents that causes a singular matrix exception.
-// private boolean printMinimalLinearlyDependentSet(int[] parents, ICovarianceMatrix cov) {
-// List _parents = new ArrayList<>();
-// for (int p : parents) _parents.add(variables.get(p));
-//
-// DepthChoiceGenerator gen = new DepthChoiceGenerator(_parents.size(), _parents.size());
-// int[] choice;
-//
-// while ((choice = gen.next()) != null) {
-// int[] sel = new int[choice.length];
-// List _sel = new ArrayList<>();
-// for (int m = 0; m < choice.length; m++) {
-// sel[m] = parents[m];
-// _sel.add(variables.get(sel[m]));
-// }
-//
-// Matrix m = cov.getSelection(sel, sel);
-//
-// try {
-// m.inverse();
-// } catch (Exception e2) {
-// out.println("### Linear dependence among variables: " + _sel);
-// out.println("### Removing " + _sel.get(0));
-// return true;
-// }
-// }
-//
-// return false;
-// }
-
-// private int[] append(int[] parents, int extra) {
-// int[] all = new int[parents.length + 1];
-// System.arraycopy(parents, 0, all, 0, parents.length);
-// all[parents.length] = extra;
-// return all;
-// }
-
- /**
- * @return a string representation of this score.
- */
- public String toString() {
- NumberFormat nf = new DecimalFormat("0.00");
- return "SEM BIC Score penalty " + nf.format(penaltyDiscount);
- }
-
private Matrix getCov(List rows, int[] _rows, int[] cols) {
- if (getCovariances() != null) {
- return getCovariances().getSelection(_rows, cols);
+ if (rows == null) {
+ return convertCovToCorr(getCovariances().getSelection(_rows, cols));
}
Matrix cov = new Matrix(_rows.length, cols.length);
@@ -389,17 +408,12 @@ private Matrix getCov(List rows, int[] _rows, int[] cols) {
}
}
- return cov;
+ return convertCovToCorr(cov);
}
private List getRows(int i, int[] parents) {
if (dataSet == null) {
- List rows = new ArrayList<>();
- for (int k = 0; k < getSampleSize(); k++) {
- rows.add(k);
- }
-
- return rows;
+ return null;
}
List rows = new ArrayList<>();
@@ -418,11 +432,10 @@ private List getRows(int i, int[] parents) {
return rows;
}
- private double partialCorrelation(Node x, Node y, List z, List rows) {
+ private double partialCorrelation(Node x, Node y, List z, List rows) {
try {
- return StatUtils.partialCorrelation(MatrixUtils.convertCovToCorr(getCov(rows, indices(x, y, z))));
+ return StatUtils.partialCorrelation(convertCovToCorr(getCov(rows, indices(x, y, z))));
} catch (Exception e) {
-// e.printStackTrace();
return NaN;
}
}
@@ -437,7 +450,7 @@ private int[] indices(Node x, Node y, List z) {
private Matrix getCov(List rows, int[] cols) {
if (dataSet == null) {
- return getCovariances().getMatrix().getSelection(cols, cols);
+ return matrix.getSelection(cols, cols);
}
Matrix cov = new Matrix(cols.length, cols.length);
@@ -488,6 +501,12 @@ private Matrix getCov(List rows, int[] cols) {
return cov;
}
+
+ public void setRuleType(RuleType ruleType) {
+ this.ruleType = ruleType;
+ }
+
+ public enum RuleType {CHICKERING, NANDY, HIGH_DIMENSIONAL}
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/Params.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/Params.java
index 344785765a..6b22ba0de4 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/Params.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/Params.java
@@ -188,6 +188,8 @@ public final class Params {
// System prameters that are not supposed to put in the HTML manual documentation
public static final String PRINT_STREAM = "printStream";
+ public static final String SEM_BIC_RULE = "semBicRule";
+ public static final String SEM_BIC_STRUCTURE_PRIOR = "semBicStructurePrior";
// All parameters that are found in HTML manual documentation
private static final Set ALL_PARAMS_IN_HTML_MANUAL = new HashSet<>(Arrays.asList(
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java
index 2784d2e013..8d8dfa831e 100644
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java
@@ -27,6 +27,8 @@
import edu.cmu.tetrad.graph.*;
import edu.cmu.tetrad.search.*;
import edu.cmu.tetrad.sem.LargeScaleSimulation;
+import edu.cmu.tetrad.sem.SemIm;
+import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.DataConvertUtils;
import edu.cmu.tetrad.util.DelimiterUtils;
import edu.cmu.tetrad.util.RandomUtil;
@@ -200,10 +202,10 @@ public void testFromData() {
Graph g = GraphUtils.randomGraphRandomForwardEdges(variables, numLatents, numEdges, 10, 10, 10, false, false);
- LargeScaleSimulation semSimulator = new LargeScaleSimulation(g);
- semSimulator.setErrorsNormal(true);
+ SemPm pm = new SemPm(g);
+ SemIm im = new SemIm(pm);
- DataSet data = semSimulator.simulateDataFisher(sampleSize);
+ DataSet data = im.simulateData(1000, false);
data = DataUtils.restrictToMeasured(data);
@@ -211,6 +213,7 @@ public void testFromData() {
IndependenceTest test = new IndTestFisherZ(data, 0.001);
SemBicScore score = new SemBicScore(data);
+ score.setRuleType(SemBicScore.RuleType.CHICKERING);
score.setPenaltyDiscount(2);
GFci gFci = new GFci(test, score);
gFci.setFaithfulnessAssumed(true);
From 1f2d1cdb5e3bd8105b078c6a2028e91ace7a06c4 Mon Sep 17 00:00:00 2001
From: jdramsey
Date: Fri, 8 Jan 2021 00:30:17 -0500
Subject: [PATCH 2/4] Also pulling over a fix for an annoying interface bug, in
the simulation editor, where if you click the simulate bug twice you don't
get a news simulation. You have to select a new model type. Fixed this.
Affected several files.
For FGES, decided to add the turning and tdepth parameters. Adjusted the manual.
---
docs/manual/index.html | 32 +++++-
.../editor/simulation/ParameterTab.java | 59 ++++++-----
.../edu/cmu/tetradapp/model/Simulation.java | 2 +-
.../cmu/tetrad/algcomparison/Comparison.java | 38 ++++---
.../algcomparison/TimeoutComparison.java | 52 +++++-----
.../algorithm/oracle/pattern/Fges.java | 8 +-
.../simulation/BayesNetSimulation.java | 4 +-
.../simulation/BooleanGlassSimulation.java | 5 +-
.../ConditionalGaussianSimulation.java | 4 +-
.../simulation/GeneralSemSimulation.java | 4 +-
.../GeneralSemSimulationSpecial1.java | 4 +-
.../simulation/LeeHastieSimulation.java | 4 +-
.../simulation/LinearFisherModel.java | 4 +-
.../simulation/LinearSineSimulation.java | 4 +-
.../simulation/SemSimulation.java | 4 +-
.../simulation/SemThenDiscretize.java | 2 +-
.../algcomparison/simulation/Simulation.java | 2 +-
.../simulation/StandardizedSemSimulation.java | 4 +-
.../simulation/TimeSeriesSemSimulation.java | 4 +-
.../LoadContinuousDataAndGraphs.java | 4 +-
.../LoadContinuousDataAndSingleGraph.java | 2 +-
.../LoadContinuousDataSmithSim.java | 3 +-
.../data/simulation/LoadDataAndGraphs.java | 3 +-
.../LoadDataFromFileWithoutGraph.java | 2 +-
.../main/java/edu/cmu/tetrad/util/Params.java | 2 +
.../LoadContinuousDataAndSingleGraph.java | 3 +-
.../LoadContinuousDataAndSingleGraphKun.java | 2 +-
.../test/LoadContinuousDataSmithSim.java | 2 +-
.../edu/cmu/tetrad/test/LoadMadelynData.java | 2 +-
.../edu/cmu/tetrad/test/SpecialDataClark.java | 7 +-
.../java/edu/cmu/tetrad/test/TestFges.java | 99 +++++++++++++++----
31 files changed, 250 insertions(+), 121 deletions(-)
diff --git a/docs/manual/index.html b/docs/manual/index.html
index cddd68a64e..23efff326b 100755
--- a/docs/manual/index.html
+++ b/docs/manual/index.html
@@ -2253,8 +2253,12 @@
Output Format
Parameters
samplePrior, structurePrior, penaltyDiscount,
- symmetricFirstStep, faithfulnessAssumed, maxDegree
+ symmetricFirstStep,
+ faithfulnessAssumed,
+ maxDegree
+ tDepth
+ turning
+
The IMaGES Discrete Algorithm (BDeu Score)
@@ -4892,6 +4896,19 @@ targetName
Value Type: String
+ tDepth
+
+ - Short Description: "T-Depth", the maximum number of neighbors considered in power set calculations
+
+ - Long Description: For FGES, this is the maximum number of T-neighbors or H-complement-neights
+ that are considered in any scoring step. Default is -1 (unlimited).
+
+ - Default Value: -1
+ - Lower Bound: -1
+ - Upper Bound: 2147483647
+ - Value Type: Integer
+
+
thr
- Short Description: THR parameter (GLASSO) (min = 0.0)
@@ -4948,6 +4965,17 @@ twoCycleAlpha
- Value Type: Double
+ turning
+
+ - Short Description: Yes, if the turning step should be included
+
+ - Long Description: Performs a turning step similar to (but not quite teh same as) that of PCALG's GES. The search in FGES repeats the sequence FES followed by BES, twice. The turning step is inserted between FGES and GES, if the user opts to use it. It is not used by default. The turning step is still somewhat experimental and may change in the future.
+ - Default Value: false
+ - Lower Bound:
g
+ - Upper Bound:
+ - Value Type: Boolean
+
+
upperBound
- Short Description: Upper bound cutoff threshold
diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/simulation/ParameterTab.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/simulation/ParameterTab.java
index b1e837675b..767241ee19 100644
--- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/simulation/ParameterTab.java
+++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/simulation/ParameterTab.java
@@ -43,6 +43,8 @@
import edu.cmu.tetradapp.ui.PaddingPanel;
import edu.cmu.tetradapp.util.ParameterComponents;
import edu.cmu.tetradapp.util.WatchedProcess;
+import org.jetbrains.annotations.NotNull;
+
import java.awt.BorderLayout;
import java.awt.Component;
import java.awt.Dimension;
@@ -68,20 +70,20 @@ public class ParameterTab extends JPanel {
private static final long serialVersionUID = 7074205549192562786L;
private static final String[] GRAPH_ITEMS = new String[]{
- GraphTypes.RANDOM_FOWARD_DAG,
- GraphTypes.SCALE_FREE_DAG,
- GraphTypes.CYCLIC_CONSTRUCTED_FROM_SMALL_LOOPS,
- GraphTypes.RANDOM_ONE_FACTOR_MIM,
- GraphTypes.RANDOM_TWO_FACTOR_MIM
+ GraphTypes.RANDOM_FOWARD_DAG,
+ GraphTypes.SCALE_FREE_DAG,
+ GraphTypes.CYCLIC_CONSTRUCTED_FROM_SMALL_LOOPS,
+ GraphTypes.RANDOM_ONE_FACTOR_MIM,
+ GraphTypes.RANDOM_TWO_FACTOR_MIM
};
private static final String[] SOURCE_GRAPH_ITEMS = {
- SimulationTypes.BAYS_NET,
- SimulationTypes.STRUCTURAL_EQUATION_MODEL,
- SimulationTypes.LINEAR_FISHER_MODEL,
- SimulationTypes.LEE_AND_HASTIE,
- SimulationTypes.CONDITIONAL_GAUSSIAN,
- SimulationTypes.TIME_SERIES
+ SimulationTypes.BAYS_NET,
+ SimulationTypes.STRUCTURAL_EQUATION_MODEL,
+ SimulationTypes.LINEAR_FISHER_MODEL,
+ SimulationTypes.LEE_AND_HASTIE,
+ SimulationTypes.CONDITIONAL_GAUSSIAN,
+ SimulationTypes.TIME_SERIES
};
private static final JLabel NO_PARAM_LBL = new JLabel("No parameters to edit");
@@ -123,6 +125,16 @@ private void initComponents() {
}
private void refreshParameters() {
+ RandomGraph randomGraph = newRandomGraph();
+ newSimulation(randomGraph);
+
+ showParameters();
+
+ firePropertyChange("refreshParameters", null, null);
+ }
+
+ @NotNull
+ private RandomGraph newRandomGraph() {
RandomGraph randomGraph = (simulation.getSourceGraph() == null)
? new SingleGraph(new EdgeListGraph())
: new SingleGraph(simulation.getSourceGraph());
@@ -151,7 +163,10 @@ private void refreshParameters() {
throw new IllegalArgumentException("Unrecognized simulation type: " + graphItem);
}
}
+ return randomGraph;
+ }
+ private void newSimulation(RandomGraph randomGraph) {
if (!simulation.isFixedSimulation()) {
String simulationItem = simulationsDropdown.getItemAt(simulationsDropdown.getSelectedIndex());
simulation.getParams().set("simulationsDropdownPreference", simulationItem);
@@ -209,10 +224,6 @@ private void refreshParameters() {
}
}
}
-
- showParameters();
-
- firePropertyChange("refreshParameters", null, null);
}
private void showParameters() {
@@ -289,7 +300,7 @@ private Box createSimulationOptionBox() {
simulation.getParams().getString("simulationsDropdownPreference", simulationItems[0]));
simulationsDropdown.addActionListener(e -> refreshParameters());
- simOptBox.add(createLabeledComponent("For a New Simulation, Select (or re-select) Type: ", simulationsDropdown));
+ simOptBox.add(createLabeledComponent("Type of Simulation: ", simulationsDropdown));
simOptBox.add(Box.createVerticalStrut(20));
return simOptBox;
@@ -300,7 +311,9 @@ private void simulate() {
@Override
public void watch() {
try {
- simulation.getSimulation().createData(simulation.getParams());
+ RandomGraph randomGraph = newRandomGraph();
+ newSimulation(randomGraph);
+ simulation.getSimulation().createData(simulation.getParams(), false);
firePropertyChange("modelChanged", null, null);
} catch (Exception exception) {
@@ -339,27 +352,27 @@ private String[] getSimulationItems(Simulation simulation) {
if (simulation.isFixedSimulation()) {
if (simulation.getSimulation() instanceof BayesNetSimulation) {
items = new String[]{
- SimulationTypes.BAYS_NET
+ SimulationTypes.BAYS_NET
};
} else if (simulation.getSimulation() instanceof SemSimulation) {
items = new String[]{
- SimulationTypes.STRUCTURAL_EQUATION_MODEL
+ SimulationTypes.STRUCTURAL_EQUATION_MODEL
};
} else if (simulation.getSimulation() instanceof LinearFisherModel) {
items = new String[]{
- SimulationTypes.LINEAR_FISHER_MODEL
+ SimulationTypes.LINEAR_FISHER_MODEL
};
} else if (simulation.getSimulation() instanceof StandardizedSemSimulation) {
items = new String[]{
- SimulationTypes.STANDARDIZED_STRUCTURAL_EQUATION_MODEL
+ SimulationTypes.STANDARDIZED_STRUCTURAL_EQUATION_MODEL
};
} else if (simulation.getSimulation() instanceof GeneralSemSimulation) {
items = new String[]{
- SimulationTypes.GENERAL_STRUCTURAL_EQUATION_MODEL
+ SimulationTypes.GENERAL_STRUCTURAL_EQUATION_MODEL
};
} else if (simulation.getSimulation() instanceof LoadContinuousDataAndGraphs) {
items = new String[]{
- SimulationTypes.LOADED_FROM_FILES
+ SimulationTypes.LOADED_FROM_FILES
};
} else {
throw new IllegalStateException("Not expecting that model type: "
diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/Simulation.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/Simulation.java
index 27abeba16f..b5a35d7b9c 100644
--- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/Simulation.java
+++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/Simulation.java
@@ -282,7 +282,7 @@ public void createSimulation() {
// Every time the users click the Simulate button, new data needs to be created
// regardless of already created data - Zhou
//if (simulation.getNumDataModels() == 0) {
- simulation.createData(parameters);
+ simulation.createData(parameters, false);
//}
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/Comparison.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/Comparison.java
index 764516c52d..bde96e9bec 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/Comparison.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/Comparison.java
@@ -112,6 +112,7 @@ public enum ComparisonGraph {
private String resultsPath = null;
private boolean parallelized = false;
private boolean savePatterns = false;
+ private boolean saveData = true;
private boolean savePags = false;
// private boolean saveTrueDags = false;
private ArrayList dirs = null;
@@ -264,7 +265,7 @@ public void compareFromSimulations(String resultsPath, Simulations simulations,
List wrappers = getSimulationWrappers(simulation, parameters);
for (SimulationWrapper wrapper : wrappers) {
- wrapper.createData(wrapper.getSimulationSpecificParameters());
+ wrapper.createData(wrapper.getSimulationSpecificParameters(), true);
simulationWrappers.add(wrapper);
}
}
@@ -543,7 +544,7 @@ public void saveToFiles(String dataPath, Simulation simulation, Parameters param
parameters.set(param, simulationWrapper.getValue(param));
}
- simulationWrapper.createData(simulationWrapper.getSimulationSpecificParameters());
+ simulationWrapper.createData(simulationWrapper.getSimulationSpecificParameters(), false);
File subdir = dir;
if (simulationWrappers.size() > 1) {
@@ -585,11 +586,13 @@ public void saveToFiles(String dataPath, Simulation simulation, Parameters param
GraphUtils.saveGraph(graph, file2, false);
- File file = new File(dir2, "data." + (j + 1) + ".txt");
- Writer out = new FileWriter(file);
- DataModel dataModel = (DataModel) simulationWrapper.getDataModel(j);
- DataWriter.writeRectangularData((DataSet) dataModel, out, '\t');
- out.close();
+ if (isSaveData()) {
+ File file = new File(dir2, "data." + (j + 1) + ".txt");
+ Writer out = new FileWriter(file);
+ DataModel dataModel = (DataModel) simulationWrapper.getDataModel(j);
+ DataWriter.writeRectangularData((DataSet) dataModel, out, '\t');
+ out.close();
+ }
if (isSavePatterns()) {
File file3 = new File(dir3, "pattern." + (j + 1) + ".txt");
@@ -1148,6 +1151,20 @@ public void setSavePags(boolean savePags) {
// this.saveTrueDags = saveTrueDags;
// }
+ /**
+ * @return True if patterns should be saved out.
+ */
+ public boolean isSaveData() {
+ return this.saveData;
+ }
+
+ /**
+ * @return True if patterns should be saved out.
+ */
+ public void setSaveData(boolean saveData) {
+ this.saveData = saveData;
+ }
+
/**
* @return True iff tables should be tab delimited (e.g. for easy pasting
* into Excel).
@@ -1939,7 +1956,6 @@ public AlgorithmWrapper getAlgorithmWrapper() {
}
private class SimulationWrapper implements Simulation {
-
static final long serialVersionUID = 23L;
private Simulation simulation;
private List graphs;
@@ -1955,12 +1971,12 @@ public SimulationWrapper(Simulation simulation, Parameters parameters) {
}
@Override
- public void createData(Parameters parameters) {
- simulation.createData(parameters);
+ public void createData(Parameters parameters, boolean newModel) {
+ simulation.createData(parameters, newModel);
this.graphs = new ArrayList<>();
this.dataModels = new ArrayList<>();
for (int i = 0; i < simulation.getNumDataModels(); i++) {
- this.graphs.add(new EdgeListGraph(simulation.getTrueGraph(i)));
+ this.graphs.add(simulation.getTrueGraph(i));
this.dataModels.add(simulation.getDataModel(i));
}
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/TimeoutComparison.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/TimeoutComparison.java
index 9930a60a0e..8a468e7204 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/TimeoutComparison.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/TimeoutComparison.java
@@ -85,7 +85,7 @@ public enum ComparisonGraph {
private ComparisonGraph comparisonGraph = ComparisonGraph.true_DAG;
public void compareFromFiles(String filePath, Algorithms algorithms,
- Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
+ Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
compareFromFiles(filePath, filePath, algorithms, statistics, parameters, timeout, unit);
}
@@ -101,7 +101,7 @@ public void compareFromFiles(String filePath, Algorithms algorithms,
* @param parameters The list of parameters and their values.
*/
public void compareFromFiles(String dataPath, String resultsPath, Algorithms algorithms,
- Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
+ Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
for (Algorithm algorithm : algorithms.getAlgorithms()) {
if (algorithm instanceof ExternalAlgorithm) {
throw new IllegalArgumentException("Not expecting any implementations of ExternalAlgorithm here.");
@@ -141,13 +141,13 @@ public void compareFromFiles(String dataPath, String resultsPath, Algorithms alg
}
public void generateReportFromExternalAlgorithms(String dataPath, String resultsPath, Algorithms algorithms,
- Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
+ Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
generateReportFromExternalAlgorithms(dataPath, resultsPath, "Comparison.txt", algorithms,
statistics, parameters, timeout, unit);
}
public void generateReportFromExternalAlgorithms(String dataPath, String resultsPath, String outputFileName, Algorithms algorithms,
- Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
+ Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
this.saveGraphs = false;
this.dataPath = dataPath;
@@ -190,7 +190,7 @@ public void generateReportFromExternalAlgorithms(String dataPath, String results
}
public void compareFromSimulations(String resultsPath, Simulations simulations, Algorithms algorithms,
- Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
+ Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
compareFromSimulations(resultsPath, simulations, "Comparison.txt", algorithms, statistics, parameters, timeout, unit);
}
@@ -205,7 +205,7 @@ public void compareFromSimulations(String resultsPath, Simulations simulations,
* algorithm, and their utility weights.
*/
public void compareFromSimulations(String resultsPath, Simulations simulations, String outputFileName, Algorithms algorithms,
- Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
+ Statistics statistics, Parameters parameters, long timeout, TimeUnit unit) {
this.resultsPath = resultsPath;
// Create output file.
@@ -230,7 +230,7 @@ public void compareFromSimulations(String resultsPath, Simulations simulations,
List wrappers = getSimulationWrappers(simulation, parameters);
for (SimulationWrapper wrapper : wrappers) {
- wrapper.createData(wrapper.getSimulationSpecificParameters());
+ wrapper.createData(wrapper.getSimulationSpecificParameters(), false);
simulationWrappers.add(wrapper);
}
}
@@ -488,7 +488,7 @@ public void saveToFiles(String dataPath, Simulation simulation, Parameters param
parameters.set(param, simulationWrapper.getValue(param));
}
- simulationWrapper.createData(simulationWrapper.getSimulationSpecificParameters());
+ simulationWrapper.createData(simulationWrapper.getSimulationSpecificParameters(), false);
index++;
File subdir = new File(dir, "" + index);
@@ -743,7 +743,7 @@ public void configuration(String path) {
}
}
-// private void printParameters(HasParameters hasParameters, PrintStream out, Parameters allParams) {
+ // private void printParameters(HasParameters hasParameters, PrintStream out, Parameters allParams) {
// List paramDescriptions = new ArrayList<>(hasParameters.getParameters());
// if (paramDescriptions.isEmpty()) return;
// out.print("\tParameters: ");
@@ -827,8 +827,8 @@ private List getSimulationWrappers(Simulation simulation, Par
}
private double[][][][] calcStats(final List algorithmSimulationWrappers,
- List algorithmWrappers, List simulationWrappers,
- Statistics statistics, int numRuns, long timeout, TimeUnit unit) {
+ List algorithmWrappers, List simulationWrappers,
+ Statistics statistics, int numRuns, long timeout, TimeUnit unit) {
int numGraphTypes = 4;
graphTypeUsed = new boolean[4];
@@ -1051,8 +1051,8 @@ private class AlgorithmTask implements Callable {
private final Run run;
public AlgorithmTask(List algorithmSimulationWrappers,
- List algorithmWrappers, List simulationWrappers,
- Statistics statistics, int numGraphTypes, double[][][][] allStats, Run run) {
+ List algorithmWrappers, List simulationWrappers,
+ Statistics statistics, int numGraphTypes, double[][][][] allStats, Run run) {
this.algorithmSimulationWrappers = algorithmSimulationWrappers;
this.simulationWrappers = simulationWrappers;
this.algorithmWrappers = algorithmWrappers;
@@ -1116,9 +1116,9 @@ private void deleteFilesThenDirectory(File dir) {
}
private void doRun(List algorithmSimulationWrappers,
- List algorithmWrappers, List simulationWrappers,
- Statistics statistics,
- int numGraphTypes, double[][][][] allStats, Run run) {
+ List algorithmWrappers, List simulationWrappers,
+ Statistics statistics,
+ int numGraphTypes, double[][][][] allStats, Run run) {
System.out.println();
System.out.println("Run " + (run.getRunIndex() + 1));
System.out.println();
@@ -1267,7 +1267,7 @@ private void doRun(List algorithmSimulationWrappers,
}
private void saveGraph(String resultsPath, Graph graph, int i, int simIndex, int algIndex,
- AlgorithmWrapper algorithmWrapper, long elapsed) {
+ AlgorithmWrapper algorithmWrapper, long elapsed) {
if (!saveGraphs) {
return;
}
@@ -1332,7 +1332,7 @@ private String getHeader(int u) {
}
private double[][][] calcStatTables(double[][][][] allStats, Mode mode, int numTables,
- List wrappers, int numStats, Statistics statistics) {
+ List wrappers, int numStats, Statistics statistics) {
double[][][] statTables = new double[numTables][wrappers.size()][numStats + 1];
for (int u = 0; u < numTables; u++) {
@@ -1396,10 +1396,10 @@ private double[][][] calcStatTables(double[][][][] allStats, Mode mode, int numT
}
private void printStats(double[][][] statTables, Statistics statistics, Mode mode, int[] newOrder,
- List algorithmSimulationWrappers,
- List algorithmWrappers,
- List simulationWrappers, double[] utilities,
- Parameters parameters) {
+ List algorithmSimulationWrappers,
+ List algorithmWrappers,
+ List simulationWrappers, double[] utilities,
+ Parameters parameters) {
if (mode == Mode.Average) {
out.println("AVERAGE STATISTICS");
@@ -1519,7 +1519,7 @@ private void printStats(double[][][] statTables, Statistics statistics, Mode mod
}
private double[] calcUtilities(Statistics statistics, List wrappers,
- double[][] stats) {
+ double[][] stats) {
// Calculate utilities for the first table.
double[] utilities = new double[wrappers.size()];
@@ -1551,7 +1551,7 @@ private double[] calcUtilities(Statistics statistics, List algorithmSimulationWrappers,
- final double[] utilities) {
+ final double[] utilities) {
List order = new ArrayList<>();
for (int t = 0; t < algorithmSimulationWrappers.size(); t++) {
order.add(t);
@@ -1763,8 +1763,8 @@ public SimulationWrapper(Simulation simulation, Parameters parameters) {
}
@Override
- public void createData(Parameters parameters) {
- simulation.createData(parameters);
+ public void createData(Parameters parameters, boolean newModel) {
+ simulation.createData(parameters, false);
this.graphs = new ArrayList<>();
this.dataModels = new ArrayList<>();
for (int i = 0; i < simulation.getNumDataModels(); i++) {
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Fges.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Fges.java
index 0e6d4e3500..bf70b2acb8 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Fges.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Fges.java
@@ -66,9 +66,9 @@ public Fges(ScoreWrapper score, Algorithm algorithm) {
@Override
public Graph search(DataModel dataSet, Parameters parameters) {
if (parameters.getInt(Params.NUMBER_RESAMPLING) < 1) {
- if (algorithm != null) {
+// if (algorithm != null) {
// initialGraph = algorithm.search(dataSet, parameters);
- }
+// }
edu.cmu.tetrad.search.Fges search
= new edu.cmu.tetrad.search.Fges(score.getScore(dataSet, parameters), Runtime.getRuntime().availableProcessors());
@@ -77,6 +77,8 @@ public Graph search(DataModel dataSet, Parameters parameters) {
search.setVerbose(parameters.getBoolean(Params.VERBOSE));
search.setMaxDegree(parameters.getInt(Params.MAX_DEGREE));
search.setSymmetricFirstStep(parameters.getBoolean(Params.SYMMETRIC_FIRST_STEP));
+ search.setTDepth(parameters.getInt(Params.TDEPTH));
+ search.setTurning(parameters.getBoolean(Params.TURNING));
Object obj = parameters.get(Params.PRINT_STREAM);
if (obj instanceof PrintStream) {
@@ -145,6 +147,8 @@ public List getParameters() {
parameters.add(Params.FAITHFULNESS_ASSUMED);
parameters.add(Params.SYMMETRIC_FIRST_STEP);
parameters.add(Params.MAX_DEGREE);
+ parameters.add(Params.TDEPTH);
+ parameters.add(Params.TURNING);
parameters.add(Params.VERBOSE);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java
index 8d924bf369..b4d5eb4116 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java
@@ -47,8 +47,8 @@ public BayesNetSimulation(BayesIm im) {
}
@Override
- public void createData(Parameters parameters) {
- if (!dataSets.isEmpty()) return;
+ public void createData(Parameters parameters, boolean newModel) {
+ if (!newModel && !dataSets.isEmpty()) return;
Graph graph = randomGraph.createGraph(parameters);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BooleanGlassSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BooleanGlassSimulation.java
index e11e0d3bce..6ba28d9837 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BooleanGlassSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BooleanGlassSimulation.java
@@ -13,6 +13,7 @@
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.TimeLagGraph;
import edu.cmu.tetrad.util.Parameters;
+
import java.util.*;
/**
@@ -33,8 +34,8 @@ public BooleanGlassSimulation(RandomGraph graph) {
}
@Override
- public void createData(Parameters parameters) {
- if (!dataSets.isEmpty()) return;
+ public void createData(Parameters parameters, boolean newModel) {
+ if (!newModel && !dataSets.isEmpty()) return;
this.graph = randomGraph.createGraph(parameters);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java
index 2fd9cb395e..42a7ed2fd8 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java
@@ -37,8 +37,8 @@ public ConditionalGaussianSimulation(RandomGraph graph) {
}
@Override
- public void createData(Parameters parameters) {
- if (!dataSets.isEmpty()) return;
+ public void createData(Parameters parameters, boolean newModel) {
+ if (!newModel && !dataSets.isEmpty()) return;
setVarLow(parameters.getDouble(Params.VAR_LOW));
setVarHigh(parameters.getDouble(Params.VAR_HIGH));
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/GeneralSemSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/GeneralSemSimulation.java
index e8598b5654..a5be6cc8e3 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/GeneralSemSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/GeneralSemSimulation.java
@@ -54,8 +54,8 @@ public GeneralSemSimulation(GeneralizedSemIm im) {
}
@Override
- public void createData(Parameters parameters) {
- if (!dataSets.isEmpty()) return;
+ public void createData(Parameters parameters, boolean newModel) {
+ if (!newModel && !dataSets.isEmpty()) return;
Graph graph = randomGraph.createGraph(parameters);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/GeneralSemSimulationSpecial1.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/GeneralSemSimulationSpecial1.java
index cc9d4d7c59..5a118b6d94 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/GeneralSemSimulationSpecial1.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/GeneralSemSimulationSpecial1.java
@@ -33,8 +33,8 @@ public GeneralSemSimulationSpecial1(RandomGraph randomGraph) {
}
@Override
- public void createData(Parameters parameters) {
- if (!dataSets.isEmpty()) return;
+ public void createData(Parameters parameters, boolean newModel) {
+ if (!newModel && !dataSets.isEmpty()) return;
Graph graph = randomGraph.createGraph(parameters);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LeeHastieSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LeeHastieSimulation.java
index 5f9d5ac27a..875ebf9ea1 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LeeHastieSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LeeHastieSimulation.java
@@ -36,8 +36,8 @@ public LeeHastieSimulation(RandomGraph graph) {
}
@Override
- public void createData(Parameters parameters) {
- if (!dataSets.isEmpty()) return;
+ public void createData(Parameters parameters, boolean newModel) {
+ if (!newModel && !dataSets.isEmpty()) return;
double percentDiscrete = parameters.getDouble(Params.PERCENT_DISCRETE);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearFisherModel.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearFisherModel.java
index 848de1e052..d04daa6bd1 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearFisherModel.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearFisherModel.java
@@ -54,8 +54,8 @@ public LinearFisherModel(RandomGraph graph, List shocks) {
}
@Override
- public void createData(Parameters parameters) {
- if (!dataSets.isEmpty()) return;
+ public void createData(Parameters parameters, boolean newModel) {
+ if (!newModel && !dataSets.isEmpty()) return;
boolean saveLatentVars = parameters.getBoolean(Params.SAVE_LATENT_VARS);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearSineSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearSineSimulation.java
index 68dfcf724e..88258340ea 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearSineSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/LinearSineSimulation.java
@@ -42,8 +42,8 @@ public LinearSineSimulation(RandomGraph graph) {
}
@Override
- public void createData(Parameters parameters) {
- if (!dataSets.isEmpty()) return;
+ public void createData(Parameters parameters, boolean newModel) {
+ if (!newModel && !dataSets.isEmpty()) return;
setInterceptLow(parameters.getDouble("interceptLow"));
setInterceptHigh(parameters.getDouble("interceptHigh"));
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemSimulation.java
index 144b0a1ae7..a2480c31b6 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemSimulation.java
@@ -51,8 +51,8 @@ public SemSimulation(SemIm im) {
}
@Override
- public void createData(Parameters parameters) {
- if (!dataSets.isEmpty()) return;
+ public void createData(Parameters parameters, boolean newModel) {
+ if (!newModel && !dataSets.isEmpty()) return;
Graph graph = randomGraph.createGraph(parameters);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemThenDiscretize.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemThenDiscretize.java
index 397857f150..72ca0c5ff2 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemThenDiscretize.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/SemThenDiscretize.java
@@ -37,7 +37,7 @@ public SemThenDiscretize(RandomGraph randomGraph, DataType dataType) {
}
@Override
- public void createData(Parameters parameters) {
+ public void createData(Parameters parameters, boolean newModel) {
double percentDiscrete = parameters.getDouble(Params.PERCENT_DISCRETE);
boolean discrete = parameters.getString("dataType").equals("discrete");
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/Simulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/Simulation.java
index b07a3dde50..6b190d6797 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/Simulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/Simulation.java
@@ -20,7 +20,7 @@ public interface Simulation extends HasParameters, TetradSerializable {
/**
* Creates a data set and simulates data.
*/
- void createData(Parameters parameters);
+ void createData(Parameters parameters, boolean newModel);
/**
* @return The number of data sets to simulate.
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/StandardizedSemSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/StandardizedSemSimulation.java
index 9aaebc861a..647ff0db72 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/StandardizedSemSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/StandardizedSemSimulation.java
@@ -44,8 +44,8 @@ public StandardizedSemSimulation(StandardizedSemIm im) {
}
@Override
- public void createData(Parameters parameters) {
- if (!dataSets.isEmpty()) return;
+ public void createData(Parameters parameters, boolean newModel) {
+ if (!newModel && !dataSets.isEmpty()) return;
Graph graph = randomGraph.createGraph(parameters);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/TimeSeriesSemSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/TimeSeriesSemSimulation.java
index c8b8ba2ffb..31376994ad 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/TimeSeriesSemSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/TimeSeriesSemSimulation.java
@@ -37,8 +37,8 @@ public TimeSeriesSemSimulation(RandomGraph randomGraph) {
}
@Override
- public void createData(Parameters parameters) {
- if (!dataSets.isEmpty()) return;
+ public void createData(Parameters parameters, boolean newModel) {
+ if (!newModel && !dataSets.isEmpty()) return;
dataSets = new ArrayList<>();
graphs = new ArrayList<>();
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadContinuousDataAndGraphs.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadContinuousDataAndGraphs.java
index 159a18fb92..3fcd586a8b 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadContinuousDataAndGraphs.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadContinuousDataAndGraphs.java
@@ -30,8 +30,8 @@ public LoadContinuousDataAndGraphs(String path) {
}
@Override
- public void createData(Parameters parameters) {
- if (!dataSets.isEmpty()) return;
+ public void createData(Parameters parameters, boolean newModel) {
+ if (!newModel && !dataSets.isEmpty()) return;
this.dataSets = new ArrayList<>();
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadContinuousDataAndSingleGraph.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadContinuousDataAndSingleGraph.java
index 72eb649dc3..73b0a729c5 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadContinuousDataAndSingleGraph.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadContinuousDataAndSingleGraph.java
@@ -34,7 +34,7 @@ public LoadContinuousDataAndSingleGraph(String path) {
}
@Override
- public void createData(Parameters parameters) {
+ public void createData(Parameters parameters, boolean newModel) {
this.dataSets = new ArrayList<>();
File dir = new File(path + "/data_noise");
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadContinuousDataSmithSim.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadContinuousDataSmithSim.java
index 9321b1dc20..13d2cd8f70 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadContinuousDataSmithSim.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadContinuousDataSmithSim.java
@@ -34,7 +34,8 @@ public LoadContinuousDataSmithSim(String path) {
}
@Override
- public void createData(Parameters parameters) {
+ public void createData(Parameters parameters, boolean newModel) {
+ if (!newModel && !dataSets.isEmpty()) return;
if (!dataSets.isEmpty()) return;
this.dataSets = new ArrayList<>();
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadDataAndGraphs.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadDataAndGraphs.java
index ecab6eee22..bc78037e91 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadDataAndGraphs.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadDataAndGraphs.java
@@ -38,7 +38,8 @@ public LoadDataAndGraphs(String path) {
}
@Override
- public void createData(Parameters parameters) {
+ public void createData(Parameters parameters, boolean newModel) {
+ if (!newModel && !dataSets.isEmpty()) return;
if (!dataSets.isEmpty()) return;
this.dataSets = new ArrayList<>();
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadDataFromFileWithoutGraph.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadDataFromFileWithoutGraph.java
index 2a7dc6e579..2bacb7049d 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadDataFromFileWithoutGraph.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/data/simulation/LoadDataFromFileWithoutGraph.java
@@ -30,7 +30,7 @@ public LoadDataFromFileWithoutGraph(String path) {
}
@Override
- public void createData(Parameters parameters) {
+ public void createData(Parameters parameters, boolean newModel) {
try {
File file = new File(path);
System.out.println("Loading data from " + file.getAbsolutePath());
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/Params.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/Params.java
index 6b22ba0de4..e31d7debb5 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/Params.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/Params.java
@@ -167,11 +167,13 @@ public final class Params {
public static final String STRUCTURE_PRIOR = "structurePrior";
public static final String SYMMETRIC_FIRST_STEP = "symmetricFirstStep";
public static final String TARGET_NAME = "targetName";
+ public static final String TDEPTH = "tDepth";
public static final String TESTWISE_DELETION = "testwiseDeletion";
public static final String THR = "thr";
public static final String THRESHOLD_FOR_NUM_EIGENVALUES = "thresholdForNumEigenvalues";
public static final String THRESHOLD_NO_RANDOM_CONSTRAIN_SEARCH = "thresholdNoRandomConstrainSearch";
public static final String THRESHOLD_NO_RANDOM_DATA_SEARCH = "thresholdNoRandomDataSearch";
+ public static final String TURNING = "turning";
public static final String TWO_CYCLE_ALPHA = "twoCycleAlpha";
public static final String UPPER_BOUND = "upperBound";
public static final String USE_CORR_DIFF_ADJACENCIES = "useCorrDiffAdjacencies";
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/LoadContinuousDataAndSingleGraph.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/LoadContinuousDataAndSingleGraph.java
index 06d9eb384a..2717ef5056 100644
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/LoadContinuousDataAndSingleGraph.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/LoadContinuousDataAndSingleGraph.java
@@ -34,7 +34,8 @@ public LoadContinuousDataAndSingleGraph(String path, String subdir) {
}
@Override
- public void createData(Parameters parameters) {
+ public void createData(Parameters parameters, boolean newModel) {
+ if (!newModel && !dataSets.isEmpty()) return;
if (!dataSets.isEmpty()) return;
this.dataSets = new ArrayList<>();
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/LoadContinuousDataAndSingleGraphKun.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/LoadContinuousDataAndSingleGraphKun.java
index e6d65e412d..3306aba332 100644
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/LoadContinuousDataAndSingleGraphKun.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/LoadContinuousDataAndSingleGraphKun.java
@@ -30,7 +30,7 @@ public LoadContinuousDataAndSingleGraphKun(String path, String prefix) {
}
@Override
- public void createData(Parameters parameters) {
+ public void createData(Parameters parameters, boolean newModel) {
this.covs = new ArrayList<>();
File dir = new File(path);
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/LoadContinuousDataSmithSim.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/LoadContinuousDataSmithSim.java
index dea3a3aa07..5aca023e53 100644
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/LoadContinuousDataSmithSim.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/LoadContinuousDataSmithSim.java
@@ -34,7 +34,7 @@ public LoadContinuousDataSmithSim(String path, int index) {
}
@Override
- public void createData(Parameters parameters) {
+ public void createData(Parameters parameters, boolean newModel) {
if (!dataSets.isEmpty()) return;
this.dataSets = new ArrayList<>();
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/LoadMadelynData.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/LoadMadelynData.java
index bf276f4309..1164ab9289 100644
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/LoadMadelynData.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/LoadMadelynData.java
@@ -34,7 +34,7 @@ public LoadMadelynData(String directory, String suffix, int structure) {
}
@Override
- public void createData(Parameters parameters) {
+ public void createData(Parameters parameters, boolean newModel) {
this.dataSets = new ArrayList<>();
for (int run = 1; run <= 10; run++) {
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/SpecialDataClark.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/SpecialDataClark.java
index 5d0052d837..2e13f90e93 100644
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/SpecialDataClark.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/SpecialDataClark.java
@@ -17,12 +17,9 @@
import edu.cmu.tetrad.sem.GeneralizedSemPm;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.RandomUtil;
-import org.apache.commons.lang3.RandomUtils;
-import java.awt.*;
import java.util.ArrayList;
import java.util.List;
-import java.util.Random;
import static edu.cmu.tetrad.util.StatUtils.skewness;
import static java.lang.Math.abs;
@@ -44,8 +41,8 @@ public SpecialDataClark(RandomGraph graph) {
}
@Override
- public void createData(Parameters parameters) {
- if (!dataSets.isEmpty()) return;
+ public void createData(Parameters parameters, boolean newModel) {
+ if (!newModel && !dataSets.isEmpty()) return;
Graph graph = randomGraph.createGraph(parameters);
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java
index cd76f9237e..33794a6a10 100644
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java
@@ -29,7 +29,6 @@
import edu.cmu.tetrad.algcomparison.independence.SemBicTest;
import edu.cmu.tetrad.algcomparison.score.ScoreWrapper;
import edu.cmu.tetrad.algcomparison.simulation.LinearFisherModel;
-import edu.cmu.tetrad.algcomparison.simulation.SemSimulation;
import edu.cmu.tetrad.algcomparison.simulation.Simulation;
import edu.cmu.tetrad.algcomparison.statistic.*;
import edu.cmu.tetrad.bayes.BayesIm;
@@ -42,16 +41,16 @@
import edu.cmu.tetrad.util.*;
import edu.pitt.csb.mgm.MGM;
import edu.pitt.csb.mgm.MixedUtils;
+
import java.io.*;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.*;
-import java.util.concurrent.ConcurrentHashMap;
-import static java.lang.Math.log;
import static junit.framework.TestCase.assertFalse;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
+
import org.junit.Test;
/**
@@ -301,10 +300,10 @@ public void testFgesMbFromGraph() {
RandomUtil.getInstance().setSeed(1450184147770L);
int numNodes = 20;
- int numIterations = 1;
+ int numIterations = 2;
for (int i = 0; i < numIterations; i++) {
-// System.out.println("Iteration " + (i + 1));
+ System.out.println("Iteration " + (i + 1));
Graph dag = GraphUtils.randomDag(numNodes, 0, numNodes, 10, 10, 10, false);
GraphScore fgesScore = new GraphScore(dag);
@@ -394,7 +393,7 @@ public void clarkTest() {
parameters.set(Params.ALPHA, 0.01);
- simulation.createData(parameters);
+ simulation.createData(parameters, false);
DataSet dataSet = (DataSet) simulation.getDataModel(0);
Graph trueGraph = simulation.getTrueGraph(0);
@@ -571,6 +570,7 @@ public void testCites() {
knowledge.addToTier(6, "CITES");
SemBicScore score = new SemBicScore(cov);
+ score.setRuleType(SemBicScore.RuleType.NANDY);
score.setPenaltyDiscount(1);
score.setStructurePrior(0);
Fges fges = new Fges(score);
@@ -583,7 +583,8 @@ public void testCites() {
System.out.println(pattern);
String trueString = "Graph Nodes:\n" +
- "ABILITY;GPQ;PREPROD;QFJ;SEX;CITES;PUBS\n" +
+ "Graph Nodes:\n" +
+ "Graph Nodes:;ABILITY;GPQ;PREPROD;QFJ;SEX;CITES;PUBS\n" +
"\n" +
"Graph Edges:\n" +
"1. ABILITY --> GPQ\n" +
@@ -706,21 +707,85 @@ public void testPcStable2() {
@Test
public void testFromGraph() {
- int numNodes = 10;
+ int numNodes = 20;
+ int aveDegree = 4;
int numIterations = 1;
for (int i = 0; i < numIterations; i++) {
-// System.out.println("Iteration " + (i + 1));
- Graph dag = GraphUtils.randomDag(numNodes, 0, 2 * numNodes, 10, 10, 10, false);
+ Graph dag = GraphUtils.randomDag(numNodes, 0, aveDegree * numNodes / 2, 10, 10, 10, false);
Fges fges = new Fges(new GraphScore(dag));
- fges.setFaithfulnessAssumed(false);
+ fges.setFaithfulnessAssumed(true);
+ fges.setVerbose(true);
+ fges.setTrueGraph(dag);
+ fges.setTDepth(1);
Graph pattern1 = fges.search();
Graph pattern2 = new Pc(new IndTestDSep(dag)).search();
-// System.out.println(pattern2);
assertEquals(pattern2, pattern1);
}
}
+ // @Test
+ public void testFromData() {
+ int numIterations = 1;
+
+ Parameters params = new Parameters();
+
+ int[] nodeOptions = {5, 10, 20, 30, 40, 50, 75, 100};
+ int[] avgDegreeOptions = {2, 4, 6};
+ int[] sampleSizeOptions = {100, 500, 1000, 10000, 100000};
+
+ int numRowsInTable = nodeOptions.length * avgDegreeOptions.length * sampleSizeOptions.length;
+
+ TextTable table = new TextTable(numRowsInTable + 1, 5);
+
+ table.setToken(0, 0, "# Nodes");
+ table.setToken(0, 1, "Avg Degree");
+ table.setToken(0, 2, "# Samples");
+ table.setToken(0, 3, "True # edges");
+ table.setToken(0, 4, "Est # Edges");
+
+ int count = 0;
+
+ for (int numNodes : nodeOptions) {
+ for (int avgDegree : avgDegreeOptions) {
+ for (int sampleSize : sampleSizeOptions) {
+ for (int q = 0; q < 1; q++) {
+ for (int i = 0; i < numIterations; i++) {
+ Graph dag = GraphUtils.randomDag(numNodes, 0,
+ (avgDegree * numNodes) / 2, 100, 100, 100, false);
+ SemPm pm = new SemPm(dag);
+ SemIm im = new SemIm(pm, params);
+ DataSet data = im.simulateData(sampleSize, false);
+ SemBicScore score = new SemBicScore(data);
+ score.setPenaltyDiscount(.5);
+ Fges fges = new Fges(score);
+ fges.setFaithfulnessAssumed(false);
+ fges.setVerbose(false);
+ fges.setTrueGraph(dag);
+ Graph pattern1 = fges.search();
+ System.out.println("num nodes = " + numNodes + " avg degree = " + avgDegree
+ + " sample size = " + sampleSize
+ + " true # edges = " + dag.getNumEdges()
+ + " est # edges = " + pattern1.getNumEdges());
+
+ count++;
+ table.setToken(count, 0, "" + numNodes);
+ table.setToken(count, 1, "" + avgDegree);
+ table.setToken(count, 2, "" + sampleSize);
+ table.setToken(count, 3, "" + dag.getNumEdges());
+ table.setToken(count, 4, "" + pattern1.getNumEdges());
+
+ }
+ }
+ }
+ }
+ }
+
+ System.out.println("\n==========================\n");
+ System.out.println(table);
+
+ }
+
@Test
public void testFromGraphWithForbiddenKnowledge() {
@@ -779,7 +844,7 @@ public void testFromGraphWithRequiredKnowledge() {
System.out.println("Knowledge violated: " + edge + " x = " + x + " y = " + y);
}
- assertTrue (pattern1.isParentOf(x, y));
+ assertTrue(pattern1.isParentOf(x, y));
}
}
}
@@ -1597,7 +1662,7 @@ public void test9() {
RandomGraph graph = new RandomForward();
LinearFisherModel sim = new LinearFisherModel(graph);
- sim.createData(parameters);
+ sim.createData(parameters, false);
Graph previous = null;
int prevDiff = Integer.MAX_VALUE;
@@ -1658,7 +1723,7 @@ public void testSemBicDiffs() {
final int N = 1000;
int numCond = 3;
- Graph graph = GraphUtils.randomGraph(10,0, 20, 100,
+ Graph graph = GraphUtils.randomGraph(10, 0, 20, 100,
100, 100, false);
final List nodes = graph.getNodes();
buildIndexing(nodes);
@@ -1682,7 +1747,7 @@ public void testSemBicDiffs() {
}
final boolean _dsep = dsep.isIndependent(x, y, new ArrayList<>(z));
- final double diff = scoreGraphChange(x, y, z, hashIndices, score) ;
+ final double diff = scoreGraphChange(x, y, z, hashIndices, score);
final boolean diffNegative = diff < 0;
if (!_dsep && _dsep != diffNegative) {
@@ -1763,7 +1828,7 @@ public static void main(String... args) {
RandomGraph graph = new RandomForward();
LinearFisherModel sim = new LinearFisherModel(graph);
- sim.createData(parameters);
+ sim.createData(parameters, false);
ScoreWrapper score = new edu.cmu.tetrad.algcomparison.score.SemBicScore();
Algorithm alg = new edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.Fges(score);
From 4f71efbadb80e47ace04f8a13fd293e7fc51789d Mon Sep 17 00:00:00 2001
From: jdramsey
Date: Fri, 8 Jan 2021 09:41:10 -0500
Subject: [PATCH 3/4] Turing turning off.
---
.../cmu/tetrad/algcomparison/algorithm/oracle/pattern/Fges.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Fges.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Fges.java
index bf70b2acb8..35e5307cfc 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Fges.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pattern/Fges.java
@@ -148,7 +148,7 @@ public List getParameters() {
parameters.add(Params.SYMMETRIC_FIRST_STEP);
parameters.add(Params.MAX_DEGREE);
parameters.add(Params.TDEPTH);
- parameters.add(Params.TURNING);
+// parameters.add(Params.TURNING);
parameters.add(Params.VERBOSE);
From b0129b4e4fd3acc1d75781e20e1aa272ff9e5da4 Mon Sep 17 00:00:00 2001
From: jdramsey
Date: Fri, 8 Jan 2021 09:47:14 -0500
Subject: [PATCH 4/4] Commented turning out of the manual.
---
docs/manual/index.html | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/manual/index.html b/docs/manual/index.html
index 23efff326b..a8e6647637 100755
--- a/docs/manual/index.html
+++ b/docs/manual/index.html
@@ -2257,7 +2257,7 @@ Parameters
faithfulnessAssumed,
maxDegree
tDepth
- turning
+