Skip to content

Commit

Permalink
Merge pull request #1515 from cmu-phil/fix_initial_graph_usage_throug…
Browse files Browse the repository at this point in the history
…hout

Making the external graph API more consistent.
  • Loading branch information
jdramsey authored Feb 14, 2023
2 parents e1f308c + 0bc6a1e commit fd96cea
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ public void execute() {
}
}

this.fges.setExternalGraph(this.externalGraph);
this.fges.setBoundGraph(this.externalGraph);
this.fges.setKnowledge((Knowledge) getParams().get("knowledge", new Knowledge()));
this.fges.setVerbose(true);
Graph graph = this.fges.search();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public Graph search(List<DataModel> dataModels, Parameters parameters) {
}

if (initial != null) {
search.setExternalGraph(initial);
search.setBoundGraph(initial);
}

return search.search();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import edu.cmu.tetrad.algcomparison.algorithm.Algorithm;
import edu.cmu.tetrad.algcomparison.score.ScoreWrapper;
import edu.cmu.tetrad.algcomparison.utils.HasKnowledge;
import edu.cmu.tetrad.algcomparison.utils.TakesExternalGraph;
import edu.cmu.tetrad.algcomparison.utils.UsesScoreWrapper;
import edu.cmu.tetrad.annotation.AlgType;
import edu.cmu.tetrad.annotation.Bootstrapping;
Expand Down Expand Up @@ -33,14 +34,15 @@
algoType = AlgType.forbid_latent_common_causes
)
@Bootstrapping
public class Fges implements Algorithm, HasKnowledge, UsesScoreWrapper {
public class Fges implements Algorithm, HasKnowledge, UsesScoreWrapper, TakesExternalGraph {

static final long serialVersionUID = 23L;

private ScoreWrapper score;
private Knowledge knowledge = new Knowledge();

private Graph wxternalGraph = null;
private Graph externalGraph = null;
private Algorithm algorithm = null;

public Fges() {

Expand All @@ -63,12 +65,21 @@ public Graph search(DataModel dataModel, Parameters parameters) {
knowledge = timeSeries.getKnowledge();
}

if (this.algorithm != null) {
Graph _graph = this.algorithm.search(dataModel, parameters);

if (_graph != null) {
this.externalGraph = _graph;
}
}

Score score = this.score.getScore(dataModel, parameters);
Graph graph;

edu.cmu.tetrad.search.Fges search
= new edu.cmu.tetrad.search.Fges(score);
search.setExternalGraph(wxternalGraph);
// search.setInitialGraph(externalGraph);
search.setBoundGraph(externalGraph);
search.setKnowledge(this.knowledge);
search.setVerbose(parameters.getBoolean(Params.VERBOSE));
search.setMeekVerbose(parameters.getBoolean(Params.MEEK_VERBOSE));
Expand All @@ -87,6 +98,9 @@ public Graph search(DataModel dataModel, Parameters parameters) {
return graph;
} else {
Fges fges = new Fges(this.score);
if (this.externalGraph != null) {
fges.setExternalGraph(this.externalGraph);
}

DataSet data = (DataSet) dataModel;
GeneralResamplingTest search = new GeneralResamplingTest(
Expand Down Expand Up @@ -137,7 +151,7 @@ public Knowledge getKnowledge() {

@Override
public void setKnowledge(Knowledge knowledge) {
this.knowledge = new Knowledge((Knowledge) knowledge);
this.knowledge = new Knowledge(knowledge);
}

@Override
Expand All @@ -150,7 +164,24 @@ public void setScoreWrapper(ScoreWrapper score) {
this.score = score;
}


@Override
public Graph getExternalGraph() {
return this.externalGraph;
}

@Override
public void setExternalGraph(Graph externalGraph) {
this.wxternalGraph = externalGraph;
this.externalGraph = externalGraph;
}

@Override
public void setExternalGraph(Algorithm algorithm) {
if (algorithm == null) {
throw new IllegalArgumentException("This EB algorithm needs both data and a graph source as inputs; it \n"
+ "will orient the edges in the input graph using the data.");
}

this.algorithm = algorithm;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ public Graph search() {
g.removeEdge(edge);
g.addEdge(reversed);

fges.setExternalGraph(g);
fges.setInitialGraph(g);
Graph g1 = fges.search();
double s1 = fges.getModelScore();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public Graph search() {

meeks.orientImplied(g);

ges.setExternalGraph(g);
ges.setInitialGraph(g);
Graph g1 = ges.search();
double s1 = ges.getModelScore();

Expand Down
26 changes: 15 additions & 11 deletions tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fges.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public final class Fges implements GraphSearch, GraphScorer {
/**
* An initial graph to start from.
*/
private Graph externalGraph;
private Graph initialGraph;
/**
* If non-null, edges not adjacent in this graph will not be added.
*/
Expand Down Expand Up @@ -192,8 +192,8 @@ public Graph search() {
boundGraph = GraphUtils.replaceNodes(boundGraph, getVariables());
}

if (externalGraph != null) {
graph = new EdgeListGraph(externalGraph);
if (initialGraph != null) {
graph = new EdgeListGraph(initialGraph);
graph = GraphUtils.replaceNodes(graph, getVariables());
}

Expand Down Expand Up @@ -268,24 +268,24 @@ public LinkedList<ScoredGraph> getTopGraphs() {
/**
* Sets the initial graph.
*/
public void setExternalGraph(Graph externalGraph) {
if (externalGraph == null) {
this.externalGraph = null;
public void setInitialGraph(Graph initialGraph) {
if (initialGraph == null) {
this.initialGraph = initialGraph;
return;
}

externalGraph = GraphUtils.replaceNodes(externalGraph, variables);
initialGraph = GraphUtils.replaceNodes(initialGraph, variables);

if (verbose) {
out.println("External graph variables: " + externalGraph.getNodes());
out.println("External graph variables: " + initialGraph.getNodes());
out.println("Data set variables: " + variables);
}

if (!new HashSet<>(externalGraph.getNodes()).equals(new HashSet<>(variables))) {
if (!new HashSet<>(initialGraph.getNodes()).equals(new HashSet<>(variables))) {
throw new IllegalArgumentException("Variables aren't the same.");
}

this.externalGraph = externalGraph;
this.initialGraph = initialGraph;
}

/**
Expand Down Expand Up @@ -322,7 +322,11 @@ public void setOut(PrintStream out) {
* If non-null, edges not adjacent in this graph will not be added.
*/
public void setBoundGraph(Graph boundGraph) {
this.boundGraph = GraphUtils.replaceNodes(boundGraph, getVariables());
if (boundGraph == null) {
this.boundGraph = null;
} else {
this.boundGraph = GraphUtils.replaceNodes(boundGraph, getVariables());
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1964,47 +1964,6 @@ public static int[][] graphComparison(Graph trueCpdag, Graph estCpdag, PrintStre
return counts;
}

public static Graph reorient(Graph graph, DataModel dataModel, Knowledge knowledge) {
if (dataModel instanceof DataModelList) {
DataModelList list = (DataModelList) dataModel;
List<DataModel> dataSets = new ArrayList<>(list);
Fges images = new Fges(new SemBicScoreImages(dataSets));

images.setBoundGraph(graph);
images.setKnowledge(knowledge);
return images.search();
} else if (dataModel instanceof DataSet) {
DataSet dataSet = (DataSet) dataModel;

Score score;

if (dataModel.isContinuous()) {
score = new SemBicScore(new CovarianceMatrix(dataSet));
} else if (dataSet.isDiscrete()) {
score = new BDeuScore(dataSet);
} else {
throw new NullPointerException();
}

Fges ges = new Fges(score);

ges.setBoundGraph(graph);
ges.setKnowledge(knowledge);
return ges.search();
} else if (dataModel instanceof CovarianceMatrix) {
ICovarianceMatrix cov = (CovarianceMatrix) dataModel;
Score score = new SemBicScore(cov);

Fges ges = new Fges(score);

ges.setBoundGraph(graph);
ges.setKnowledge(knowledge);
return ges.search();
}

throw new IllegalStateException("Can do that that reorientation.");
}

@NotNull
public static Graph dagToPag(Graph trueGraph) {
return new DagToPag(trueGraph).convert();
Expand Down

0 comments on commit fd96cea

Please sign in to comment.