Skip to content

Commit

Permalink
Merge pull request #1521 from cmu-phil/parameter_disappearing_issue
Browse files Browse the repository at this point in the history
Fixed problem of SEM and Bayes simultion parameters going missing after simulation and re-opening the Simulation box.
  • Loading branch information
jdramsey authored Feb 14, 2023
2 parents 5da6ae6 + 6889e8e commit e1f308c
Show file tree
Hide file tree
Showing 24 changed files with 1,803 additions and 217 deletions.
2 changes: 1 addition & 1 deletion docs/manual/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -7246,7 +7246,7 @@ <h3 id="zSRiskBound" class="parameter_description">zSRiskBound</h3>
<li>Long Description: <span id="zSRiskBound_long_desc">
This is the probability of getting the true model if a correct model is discovered. Could underfit.</span>
</li>
<li>Default Value: <span id="zSRiskBound_default_value">0.001</span></li>
<li>Default Value: <span id="zSRiskBound_default_value">0.1</span></li>
<li>Lower Bound: <span id="zSRiskBound_lower_bound">0</span></li>
<li>Upper Bound: <span id="zSRiskBound_upper_bound">1</span></li>
<li>Value Type: <span id="zSRiskBound_value_type">Double</span></li>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,15 @@ private List<Statistic> statistics() {


// Greg table
statistics.add(new AncestorPrecision());
statistics.add(new AncestorRecall());
statistics.add(new AncestorF1());
statistics.add(new SemidirectedPrecision());
statistics.add(new SemidirectedRecall());
statistics.add(new SemidirectedPathF1());
statistics.add(new NoSemidirectedPrecision());
statistics.add(new NoSemidirectedRecall());
statistics.add(new NoSemidirectedF1());
// statistics.add(new AncestorPrecision());
// statistics.add(new AncestorRecall());
// statistics.add(new AncestorF1());
// statistics.add(new SemidirectedPrecision());
// statistics.add(new SemidirectedRecall());
// statistics.add(new SemidirectedPathF1());
// statistics.add(new NoSemidirectedPrecision());
// statistics.add(new NoSemidirectedRecall());
// statistics.add(new NoSemidirectedF1());

// statistics.add(new LegalPag());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ private void showParameters() {
this.parameterBox.removeAll();
if (this.simulation.getSimulation() != null) {
Set<String> params = new LinkedHashSet<>(this.simulation.getSimulation().getParameters());

if (params.isEmpty()) {
this.parameterBox.add(ParameterTab.NO_PARAM_LBL, BorderLayout.NORTH);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ public void execute() {
}
}

this.fges.setInitialGraph(this.externalGraph);
this.fges.setExternalGraph(this.externalGraph);
this.fges.setKnowledge((Knowledge) getParams().get("knowledge", new Knowledge()));
this.fges.setVerbose(true);
Graph graph = this.fges.search();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public Graph search(List<DataModel> dataModels, Parameters parameters) {
}

if (initial != null) {
search.setInitialGraph(initial);
search.setExternalGraph(initial);
}

return search.search();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ public Graph search(DataModel dataModel, Parameters parameters) {
boss.setAlgType(Boss.AlgType.BOSS1);
} else if (parameters.getInt(Params.BOSS_ALG) == 2) {
boss.setAlgType(Boss.AlgType.BOSS2);
} else if (parameters.getInt(Params.BOSS_ALG) == 3) {
boss.setAlgType(Boss.AlgType.BOSS3);
} else {
throw new IllegalArgumentException("Unrecognized boss algorithm type.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class Fges implements Algorithm, HasKnowledge, UsesScoreWrapper {
private ScoreWrapper score;
private Knowledge knowledge = new Knowledge();

private Graph initialGraph = null;
private Graph wxternalGraph = null;

public Fges() {

Expand Down Expand Up @@ -68,7 +68,7 @@ public Graph search(DataModel dataModel, Parameters parameters) {

edu.cmu.tetrad.search.Fges search
= new edu.cmu.tetrad.search.Fges(score);
// search.setInitialGraph(initialGraph);
search.setExternalGraph(wxternalGraph);
search.setKnowledge(this.knowledge);
search.setVerbose(parameters.getBoolean(Params.VERBOSE));
search.setMeekVerbose(parameters.getBoolean(Params.MEEK_VERBOSE));
Expand Down Expand Up @@ -150,7 +150,7 @@ public void setScoreWrapper(ScoreWrapper score) {
this.score = score;
}

public void setInitialGraph(Graph initialGraph) {
this.initialGraph = initialGraph;
public void setExternalGraph(Graph externalGraph) {
this.wxternalGraph = externalGraph;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class PC implements Algorithm, HasKnowledge, TakesIndependenceWrapper {
private IndependenceWrapper test;
private Knowledge knowledge = new Knowledge();

private Graph initialGraph = null;
private Graph externalGraph = null;

public PC() {
}
Expand Down Expand Up @@ -102,7 +102,7 @@ public Graph search(DataModel dataModel, Parameters parameters) {
search.setUseHeuristic(parameters.getBoolean(Params.USE_MAX_P_ORIENTATION_HEURISTIC));
search.setMaxPathLength(parameters.getInt(Params.MAX_P_ORIENTATION_MAX_PATH_LENGTH));
search.setMaxPathLength(parameters.getInt(Params.MAX_P_ORIENTATION_MAX_PATH_LENGTH));
search.setInitialGraph(initialGraph);
search.setExternalGraph(externalGraph);
search.setVerbose(parameters.getBoolean(Params.VERBOSE));

return search.search();
Expand Down Expand Up @@ -171,7 +171,7 @@ public void setIndependenceWrapper(IndependenceWrapper test) {
this.test = test;
}

public void setInitialGraph(Graph initialGraph) {
this.initialGraph = initialGraph;
public void setExternalGraph(Graph externalGraph) {
this.externalGraph = externalGraph;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import edu.cmu.tetrad.algcomparison.graph.RandomForward;
import edu.cmu.tetrad.algcomparison.independence.FisherZ;
import edu.cmu.tetrad.algcomparison.score.SemBicScore;
import edu.cmu.tetrad.algcomparison.score.ZhangShenBoundScore;
import edu.cmu.tetrad.algcomparison.simulation.SemSimulation;
import edu.cmu.tetrad.algcomparison.simulation.Simulations;
import edu.cmu.tetrad.algcomparison.statistic.*;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,13 @@ public List<String> getParameters() {
parameters.addAll(this.randomGraph.getParameters());
}

if (this.pm == null) {
parameters.addAll(BayesPm.getParameterNames());
}
// if (this.pm == null) {
parameters.addAll(BayesPm.getParameterNames());
// }

if (this.im == null) {
parameters.addAll(MlBayesIm.getParameterNames());
}
// if (this.im == null) {
parameters.addAll(MlBayesIm.getParameterNames());
// }

parameters.add(Params.NUM_RUNS);
parameters.add(Params.DIFFERENT_GRAPHS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ public List<String> getParameters() {
parameters.addAll(this.randomGraph.getParameters());
}

if (this.im == null) {
parameters.addAll(SemIm.getParameterNames());
}
// if (this.im == null) {
parameters.addAll(SemIm.getParameterNames());
// }

parameters.add(Params.MEASUREMENT_VARIANCE);
parameters.add(Params.NUM_RUNS);
Expand Down Expand Up @@ -175,7 +175,7 @@ private DataSet simulate(Graph graph, Parameters parameters) {
// Not setting this im messes up algcomparison. -JR 20230206

// if (this.im == null) {
this.im = im;
this.im = im;
// }

// Need this in case the SEM IM is given externally.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public boolean takesKnowledge(Class clazz) {
return clazz != null && HasKnowledge.class.isAssignableFrom(clazz);
}

public boolean takesInitialGraph(Class clazz) {
public boolean takesExternalGraph(Class clazz) {
return clazz != null && TakesExternalGraph.class.isAssignableFrom(clazz);
}

Expand Down
2 changes: 1 addition & 1 deletion tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public Paths(Graph graph) {
}

/**
* Returns a valid causal order for either a DAG or a CPDAG.
* Returns a valid causal order for either a DAG or a CPDAG. (bryanandrews)
* @param initialOrder Variables in the order will be kept as close to this
* initial order as possible, either the forward order
* or the reverse order, depending on the next parameter.
Expand Down
43 changes: 14 additions & 29 deletions tetrad-lib/src/main/java/edu/cmu/tetrad/search/Boss.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public class Boss {
private int numStarts = 1;
private AlgType algType = AlgType.BOSS1;
private boolean caching = true;
private double epsilon = 1e-10;

public Boss(@NotNull IndependenceTest test, Score score) {
this.test = test;
Expand Down Expand Up @@ -89,7 +90,7 @@ public List<Node> bestOrder(@NotNull List<Node> order) {
shuffle(order);
}

this.start = MillisecondTimes.timeMillis();
this.start = MillisecondTimes.timeMillis();

makeValidKnowledgeOrder(order);

Expand Down Expand Up @@ -131,7 +132,7 @@ public List<Node> bestOrder(@NotNull List<Node> order) {

this.scorer.score(bestPerm);

this.stop = MillisecondTimes.timeMillis();
this.stop = MillisecondTimes.timeMillis();

if (this.verbose) {
TetradLogger.getInstance().forceLogMessage("\nFinal " + algType + " order = " + this.scorer.getPi());
Expand Down Expand Up @@ -225,7 +226,7 @@ public void betterMutation1(@NotNull TeyssierScorer scorer) {

tuck(x, j, scorer, range);

if (scorer.score() > bestScore || violatesKnowledge(scorer.getPi())) {
if (scorer.score() > bestScore + epsilon || violatesKnowledge(scorer.getPi())) {
for (int l = range[0]; l <= range[1]; l++) {
introns2.add(scorer.get(l));
}
Expand All @@ -244,63 +245,47 @@ public void betterMutation1(@NotNull TeyssierScorer scorer) {
if (verbose) {
System.out.println();
}
} while (bestScore > originalScore);
} while (bestScore > originalScore + epsilon);
}



public void betterMutation2(@NotNull TeyssierScorer scorer) {
scorer.bookmark();
double s1, s2;

Set<Node> introns1;
Set<Node> introns2;

introns2 = new HashSet<>(scorer.getPi());

do {
s1 = scorer.score();
scorer.bookmark(1);

introns1 = introns2;
introns2 = new HashSet<>();

for (Node k : scorer.getPi()) {
double _sp = NEGATIVE_INFINITY;
scorer.bookmark();

if (!introns1.contains(k)) continue;

for (int j = 0; j < scorer.size(); j++) {
scorer.moveTo(k, j);

if (scorer.score() >= _sp) {
if (scorer.score() >= _sp + epsilon) {
if (!violatesKnowledge(scorer.getPi())) {
_sp = scorer.score();
scorer.bookmark();

if (scorer.index(k) <= j) {
for (int m = scorer.index(k); m <= j; m++) {
introns2.add(scorer.get(m));
}
} else if (scorer.index(k) > j) {
for (int m = j; m <= scorer.index(k); m++) {
introns2.add(scorer.get(m));
}
if (verbose) {
System.out.print("\rIndex = " + (scorer.index(k) + 1) + " Score = " + scorer.score() + " (betterMutation2)" + " Elapsed " + ((MillisecondTimes.timeMillis() - start) / 1000.0 + " s"));
}
}
}
}

if (verbose) {
System.out.print("\rIndex = " + (j + 1) + " Score = " + scorer.score() + " (betterMutation2)" + " Elapsed " + ((MillisecondTimes.timeMillis() - start) / 1000.0 + " s"));
}
}

scorer.goToBookmark();
}

if (verbose) {
System.out.println();
}

s2 = scorer.score();
} while (s2 > s1);
} while (s2 > s1 + epsilon);

scorer.goToBookmark(1);
}
Expand Down
Loading

0 comments on commit e1f308c

Please sign in to comment.