diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgesRunner.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgesRunner.java index 80cdf4c9b4..a6b4c2e2bf 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgesRunner.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/FgesRunner.java @@ -200,7 +200,7 @@ public void execute() { } } - this.fges.setExternalGraph(this.externalGraph); + this.fges.setBoundGraph(this.externalGraph); this.fges.setKnowledge((Knowledge) getParams().get("knowledge", new Knowledge())); this.fges.setVerbose(true); Graph graph = this.fges.search(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/FgesConcatenated.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/FgesConcatenated.java index 008ef6d179..d67ccb29de 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/FgesConcatenated.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/FgesConcatenated.java @@ -75,7 +75,7 @@ public Graph search(List dataModels, Parameters parameters) { } if (initial != null) { - search.setExternalGraph(initial); + search.setBoundGraph(initial); } return search.search(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java index aa6866191c..fc65a24161 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fges.java @@ -3,6 +3,7 @@ import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; import edu.cmu.tetrad.algcomparison.score.ScoreWrapper; import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; +import edu.cmu.tetrad.algcomparison.utils.TakesExternalGraph; import edu.cmu.tetrad.algcomparison.utils.UsesScoreWrapper; import edu.cmu.tetrad.annotation.AlgType; import edu.cmu.tetrad.annotation.Bootstrapping; @@ -33,14 +34,15 @@ algoType = AlgType.forbid_latent_common_causes ) @Bootstrapping -public class Fges implements Algorithm, HasKnowledge, UsesScoreWrapper { +public class Fges implements Algorithm, HasKnowledge, UsesScoreWrapper, TakesExternalGraph { static final long serialVersionUID = 23L; private ScoreWrapper score; private Knowledge knowledge = new Knowledge(); - private Graph wxternalGraph = null; + private Graph externalGraph = null; + private Algorithm algorithm = null; public Fges() { @@ -63,12 +65,21 @@ public Graph search(DataModel dataModel, Parameters parameters) { knowledge = timeSeries.getKnowledge(); } + if (this.algorithm != null) { + Graph _graph = this.algorithm.search(dataModel, parameters); + + if (_graph != null) { + this.externalGraph = _graph; + } + } + Score score = this.score.getScore(dataModel, parameters); Graph graph; edu.cmu.tetrad.search.Fges search = new edu.cmu.tetrad.search.Fges(score); - search.setExternalGraph(wxternalGraph); +// search.setInitialGraph(externalGraph); + search.setBoundGraph(externalGraph); search.setKnowledge(this.knowledge); search.setVerbose(parameters.getBoolean(Params.VERBOSE)); search.setMeekVerbose(parameters.getBoolean(Params.MEEK_VERBOSE)); @@ -87,6 +98,9 @@ public Graph search(DataModel dataModel, Parameters parameters) { return graph; } else { Fges fges = new Fges(this.score); + if (this.externalGraph != null) { + fges.setExternalGraph(this.externalGraph); + } DataSet data = (DataSet) dataModel; GeneralResamplingTest search = new GeneralResamplingTest( @@ -137,7 +151,7 @@ public Knowledge getKnowledge() { @Override public void setKnowledge(Knowledge knowledge) { - this.knowledge = new Knowledge((Knowledge) knowledge); + this.knowledge = new Knowledge(knowledge); } @Override @@ -150,7 +164,24 @@ public void setScoreWrapper(ScoreWrapper score) { this.score = score; } + + @Override + public Graph getExternalGraph() { + return this.externalGraph; + } + + @Override public void setExternalGraph(Graph externalGraph) { - this.wxternalGraph = externalGraph; + this.externalGraph = externalGraph; + } + + @Override + public void setExternalGraph(Algorithm algorithm) { + if (algorithm == null) { + throw new IllegalArgumentException("This EB algorithm needs both data and a graph source as inputs; it \n" + + "will orient the edges in the input graph using the data."); + } + + this.algorithm = algorithm; } } \ No newline at end of file diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Bridges2.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Bridges2.java index 53152b8932..caf1f2922f 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Bridges2.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Bridges2.java @@ -138,7 +138,7 @@ public Graph search() { g.removeEdge(edge); g.addEdge(reversed); - fges.setExternalGraph(g); + fges.setInitialGraph(g); Graph g1 = fges.search(); double s1 = fges.getModelScore(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BridgesOld.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BridgesOld.java index dbc1629682..7c57109b09 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BridgesOld.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BridgesOld.java @@ -66,7 +66,7 @@ public Graph search() { meeks.orientImplied(g); - ges.setExternalGraph(g); + ges.setInitialGraph(g); Graph g1 = ges.search(); double s1 = ges.getModelScore(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java index a2af42d835..b82b3bc2df 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java @@ -89,7 +89,7 @@ public final class Fges implements GraphSearch, GraphScorer { /** * An initial graph to start from. */ - private Graph externalGraph; + private Graph initialGraph; /** * If non-null, edges not adjacent in this graph will not be added. */ @@ -192,8 +192,8 @@ public Graph search() { boundGraph = GraphUtils.replaceNodes(boundGraph, getVariables()); } - if (externalGraph != null) { - graph = new EdgeListGraph(externalGraph); + if (initialGraph != null) { + graph = new EdgeListGraph(initialGraph); graph = GraphUtils.replaceNodes(graph, getVariables()); } @@ -268,24 +268,24 @@ public LinkedList getTopGraphs() { /** * Sets the initial graph. */ - public void setExternalGraph(Graph externalGraph) { - if (externalGraph == null) { - this.externalGraph = null; + public void setInitialGraph(Graph initialGraph) { + if (initialGraph == null) { + this.initialGraph = initialGraph; return; } - externalGraph = GraphUtils.replaceNodes(externalGraph, variables); + initialGraph = GraphUtils.replaceNodes(initialGraph, variables); if (verbose) { - out.println("External graph variables: " + externalGraph.getNodes()); + out.println("External graph variables: " + initialGraph.getNodes()); out.println("Data set variables: " + variables); } - if (!new HashSet<>(externalGraph.getNodes()).equals(new HashSet<>(variables))) { + if (!new HashSet<>(initialGraph.getNodes()).equals(new HashSet<>(variables))) { throw new IllegalArgumentException("Variables aren't the same."); } - this.externalGraph = externalGraph; + this.initialGraph = initialGraph; } /** @@ -322,7 +322,11 @@ public void setOut(PrintStream out) { * If non-null, edges not adjacent in this graph will not be added. */ public void setBoundGraph(Graph boundGraph) { - this.boundGraph = GraphUtils.replaceNodes(boundGraph, getVariables()); + if (boundGraph == null) { + this.boundGraph = null; + } else { + this.boundGraph = GraphUtils.replaceNodes(boundGraph, getVariables()); + } } /** diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SearchGraphUtils.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SearchGraphUtils.java index 4f09f2c460..64aaad98a3 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SearchGraphUtils.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SearchGraphUtils.java @@ -1964,47 +1964,6 @@ public static int[][] graphComparison(Graph trueCpdag, Graph estCpdag, PrintStre return counts; } - public static Graph reorient(Graph graph, DataModel dataModel, Knowledge knowledge) { - if (dataModel instanceof DataModelList) { - DataModelList list = (DataModelList) dataModel; - List dataSets = new ArrayList<>(list); - Fges images = new Fges(new SemBicScoreImages(dataSets)); - - images.setBoundGraph(graph); - images.setKnowledge(knowledge); - return images.search(); - } else if (dataModel instanceof DataSet) { - DataSet dataSet = (DataSet) dataModel; - - Score score; - - if (dataModel.isContinuous()) { - score = new SemBicScore(new CovarianceMatrix(dataSet)); - } else if (dataSet.isDiscrete()) { - score = new BDeuScore(dataSet); - } else { - throw new NullPointerException(); - } - - Fges ges = new Fges(score); - - ges.setBoundGraph(graph); - ges.setKnowledge(knowledge); - return ges.search(); - } else if (dataModel instanceof CovarianceMatrix) { - ICovarianceMatrix cov = (CovarianceMatrix) dataModel; - Score score = new SemBicScore(cov); - - Fges ges = new Fges(score); - - ges.setBoundGraph(graph); - ges.setKnowledge(knowledge); - return ges.search(); - } - - throw new IllegalStateException("Can do that that reorientation."); - } - @NotNull public static Graph dagToPag(Graph trueGraph) { return new DagToPag(trueGraph).convert();