Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include Non Gaussian cases for Local Precision and Recall tests for DAG and CPDAG respectively #1770

Merged
merged 1 commit into from
May 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 188 additions & 4 deletions tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
import org.junit.Test;

import java.util.ArrayList;
Expand Down Expand Up @@ -111,12 +112,13 @@ public void test2() {
}

@Test
public void testDAGPrecisionRecallForLocalOnMarkovBlanket() {
public void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket() {
Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph size: " + trueGraph.getNodes().size());

SemPm pm = new SemPm(trueGraph);
// Parameters without additional settings default to Gaussian
SemIm im = new SemIm(pm, new Parameters());
DataSet data = im.simulateData(1000, false);
edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false);
Expand Down Expand Up @@ -149,14 +151,15 @@ public void testDAGPrecisionRecallForLocalOnMarkovBlanket() {
}

@Test
public void testCPDAGPrecisionRecallForLocalOnMarkovBlanket() {
public void testGaussianCPDAGPrecisionRecallForLocalOnMarkovBlanket() {
Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
// The completed partially directed acyclic graph (CPDAG) for the given DAG.
Graph trueGraphCPDAG = GraphTransforms.dagToCpdag(trueGraph);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph CPDAG: " + trueGraphCPDAG);

SemPm pm = new SemPm(trueGraph);
// Parameters without additional settings default to Gaussian
SemIm im = new SemIm(pm, new Parameters());
DataSet data = im.simulateData(1000, false);
edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false);
Expand Down Expand Up @@ -188,12 +191,104 @@ public void testCPDAGPrecisionRecallForLocalOnMarkovBlanket() {
}

@Test
public void testDAGPrecisionRecallForLocalOnParents() {
public void testNonGaussianDAGPrecisionRecallForLocalOnMarkovBlanket() {
    // Purpose: simulate data from a random DAG with the non-Gaussian error settings below,
    // estimate a graph with BOSS, split the nodes into Anderson-Darling "accepts" and
    // "rejects", and request Markov-blanket precision/recall for every node against the
    // true DAG. The test makes no assertions; it only exercises the calls and prints output.
    Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
    System.out.println("Test True Graph: " + trueGraph);
    System.out.println("Test True Graph size: " + trueGraph.getNodes().size());

    SemPm pm = new SemPm(trueGraph);

    Parameters params = new Parameters();
    // Manually set non-Gaussian.
    // NOTE(review): the meaning of error type 3 / param1 = 1 is defined by Params and the
    // simulation code, not visible here -- confirm which distribution this selects.
    params.set(Params.SIMULATION_ERROR_TYPE, 3);
    params.set(Params.SIMULATION_PARAM1, 1);

    SemIm im = new SemIm(pm, params);
    DataSet data = im.simulateData(1000, false);
    SemBicScore score = new SemBicScore(data, false);
    score.setPenaltyDiscount(2);
    Graph estimatedCpdag = new PermutationSearch(new Boss(score)).search();
    System.out.println("Test Estimated CPDAG Graph: " + estimatedCpdag);
    System.out.println("~~~~~~~~~~~~~~~~~~~~~~~~~~~~");

    IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05);
    MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET);
    // Index 0 holds the accepted nodes, index 1 the rejected nodes.
    List<List<Node>> acceptsRejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05);
    List<Node> accepts = acceptsRejects.get(0);
    List<Node> rejects = acceptsRejects.get(1);
    System.out.println("Accepts size: " + accepts.size());
    System.out.println("Rejects size: " + rejects.size());

    // Compare the estimated graph with the true DAG, accepted nodes first.
    for (Node a : accepts) {
        System.out.println("=====================");
        markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraph);
        System.out.println("=====================");
    }
    for (Node a : rejects) {
        System.out.println("=====================");
        markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraph);
        System.out.println("=====================");
    }
}

/**
 * Exercises Markov-blanket precision/recall reporting against the CPDAG of the true DAG,
 * using data simulated with the non-Gaussian error settings configured below.
 *
 * <p>A graph is estimated with BOSS, the nodes are split into Anderson-Darling accepts and
 * rejects, and precision/recall is requested for every node (accepts first, then rejects).
 * The test makes no assertions; it only prints diagnostics.
 */
@Test
public void testNonGaussianCPDAGPrecisionRecallForLocalOnMarkovBlanket() {
    final Graph dag = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
    // Reference pattern: the CPDAG obtained from the generating DAG.
    final Graph referenceCpdag = GraphTransforms.dagToCpdag(dag);
    System.out.println("Test True Graph: " + dag);
    System.out.println("Test True Graph CPDAG: " + referenceCpdag);

    final SemPm semPm = new SemPm(dag);

    // Error settings the original author flags as non-Gaussian.
    final Parameters simulationParams = new Parameters();
    simulationParams.set(Params.SIMULATION_ERROR_TYPE, 3);
    simulationParams.set(Params.SIMULATION_PARAM1, 1);

    final SemIm semIm = new SemIm(semPm, simulationParams);
    final DataSet sample = semIm.simulateData(1000, false);

    final SemBicScore bic = new SemBicScore(sample, false);
    bic.setPenaltyDiscount(2);
    final Boss boss = new Boss(bic);
    final Graph estimate = new PermutationSearch(boss).search();
    System.out.println("Test Estimated CPDAG Graph: " + estimate);
    System.out.println("~~~~~~~~~~~~~~~~~~~~~~~~~~~~");

    final IndependenceTest fisherZ = new IndTestFisherZ(sample, 0.05);
    final MarkovCheck checker = new MarkovCheck(estimate, fisherZ, ConditioningSetType.MARKOV_BLANKET);
    final List<List<Node>> verdicts =
            checker.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZ, estimate, 0.05);
    final List<Node> passing = verdicts.get(0);
    final List<Node> failing = verdicts.get(1);
    System.out.println("Accepts size: " + passing.size());
    System.out.println("Rejects size: " + failing.size());

    // Compare the Est CPDAG with True graph's CPDAG: accepted nodes first, then rejected.
    final List<Node> inReportOrder = new ArrayList<>(passing);
    inReportOrder.addAll(failing);
    for (final Node node : inReportOrder) {
        System.out.println("=====================");
        checker.getPrecisionAndRecallOnMarkovBlanketGraph(node, estimate, referenceCpdag);
        System.out.println("=====================");
    }
}



@Test
public void testGaussianDAGPrecisionRecallForLocalOnParents() {
Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph size: " + trueGraph.getNodes().size());

SemPm pm = new SemPm(trueGraph);
// Parameters without additional settings default to Gaussian
SemIm im = new SemIm(pm, new Parameters());
DataSet data = im.simulateData(1000, false);
edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false);
Expand Down Expand Up @@ -225,14 +320,15 @@ public void testDAGPrecisionRecallForLocalOnParents() {
}

@Test
public void testCPDAGPrecisionRecallForLocalOnParents() {
public void testGaussianCPDAGPrecisionRecallForLocalOnParents() {
Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
// The completed partially directed acyclic graph (CPDAG) for the given DAG.
Graph trueGraphCPDAG = GraphTransforms.dagToCpdag(trueGraph);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph CPDAG: " + trueGraphCPDAG);

SemPm pm = new SemPm(trueGraph);
// Parameters without additional settings default to Gaussian
SemIm im = new SemIm(pm, new Parameters());
DataSet data = im.simulateData(1000, false);
edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false);
Expand Down Expand Up @@ -262,4 +358,92 @@ public void testCPDAGPrecisionRecallForLocalOnParents() {
System.out.println("=====================");
}
}

/**
 * Exercises local precision/recall reporting against the true DAG with a
 * {@code LOCAL_MARKOV} conditioning set, using data simulated with the non-Gaussian
 * error settings configured below.
 *
 * <p>A graph is estimated with BOSS, the nodes are split into Anderson-Darling accepts and
 * rejects, and precision/recall is requested for every node (accepts first, then rejects).
 * The test makes no assertions; it only prints diagnostics.
 */
@Test
public void testNonGaussianDAGPrecisionRecallForLocalOnParents() {
    final Graph dag = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
    System.out.println("Test True Graph: " + dag);
    System.out.println("Test True Graph size: " + dag.getNodes().size());

    final SemPm semPm = new SemPm(dag);
    // Manually set non-Gaussian
    final Parameters simulationParams = new Parameters();
    simulationParams.set(Params.SIMULATION_ERROR_TYPE, 3);
    simulationParams.set(Params.SIMULATION_PARAM1, 1);

    final DataSet sample = new SemIm(semPm, simulationParams).simulateData(1000, false);
    final SemBicScore bic = new SemBicScore(sample, false);
    bic.setPenaltyDiscount(2);
    final Graph estimate = new PermutationSearch(new Boss(bic)).search();
    System.out.println("Test Estimated CPDAG Graph: " + estimate);
    System.out.println("~~~~~~~~~~~~~~~~~~~~~~~~~~~~");

    final IndependenceTest fisherZ = new IndTestFisherZ(sample, 0.05);
    // TODO VBC: confirm on the choice of ConditioningSetType.
    final MarkovCheck checker = new MarkovCheck(estimate, fisherZ, ConditioningSetType.LOCAL_MARKOV);
    final List<List<Node>> verdicts =
            checker.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZ, estimate, 0.05);
    final List<Node> accepted = verdicts.get(0);
    final List<Node> rejected = verdicts.get(1);
    System.out.println("Accepts size: " + accepted.size());
    System.out.println("Rejects size: " + rejected.size());

    // Walk the accepts bucket (index 0) and then the rejects bucket (index 1).
    for (final List<Node> bucket : verdicts.subList(0, 2)) {
        for (final Node node : bucket) {
            System.out.println("=====================");
            checker.getPrecisionAndRecallOnMarkovBlanketGraph(node, estimate, dag);
            System.out.println("=====================");
        }
    }
}

/**
 * Exercises local precision/recall reporting against the CPDAG of the true DAG with a
 * {@code LOCAL_MARKOV} conditioning set, using data simulated with the non-Gaussian
 * error settings configured below.
 *
 * <p>A graph is estimated with BOSS, the nodes are split into Anderson-Darling accepts and
 * rejects, and precision/recall is requested for every node (accepts first, then rejects).
 * The test makes no assertions; it only prints diagnostics.
 */
@Test
public void testNonGaussianCPDAGPrecisionRecallForLocalOnParents() {
    Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
    // The completed partially directed acyclic graph (CPDAG) for the given DAG.
    Graph trueGraphCPDAG = GraphTransforms.dagToCpdag(trueGraph);
    System.out.println("Test True Graph: " + trueGraph);
    System.out.println("Test True Graph CPDAG: " + trueGraphCPDAG);

    SemPm pm = new SemPm(trueGraph);

    Parameters params = new Parameters();
    // Manually set non-Gaussian.
    // NOTE(review): the meaning of error type 3 / param1 = 1 is defined by Params and the
    // simulation code, not visible here -- confirm which distribution this selects.
    params.set(Params.SIMULATION_ERROR_TYPE, 3);
    params.set(Params.SIMULATION_PARAM1, 1);

    SemIm im = new SemIm(pm, params);
    DataSet data = im.simulateData(1000, false);
    SemBicScore score = new SemBicScore(data, false);
    score.setPenaltyDiscount(2);
    Graph estimatedCpdag = new PermutationSearch(new Boss(score)).search();
    System.out.println("Test Estimated CPDAG Graph: " + estimatedCpdag);
    System.out.println("~~~~~~~~~~~~~~~~~~~~~~~~~~~~");

    IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05);
    // Use LOCAL_MARKOV, as the DAG "OnParents" test does. With MARKOV_BLANKET this test
    // was statement-for-statement identical to the CPDAG Markov-blanket test.
    // TODO VBC: confirm on the choice of ConditioningSetType.
    MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.LOCAL_MARKOV);
    // Index 0 holds the accepted nodes, index 1 the rejected nodes.
    List<List<Node>> acceptsRejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05);
    List<Node> accepts = acceptsRejects.get(0);
    List<Node> rejects = acceptsRejects.get(1);
    System.out.println("Accepts size: " + accepts.size());
    System.out.println("Rejects size: " + rejects.size());

    // Compare the Est CPDAG with True graph's CPDAG.
    for (Node a : accepts) {
        System.out.println("=====================");
        markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraphCPDAG);
        System.out.println("=====================");
    }
    for (Node a : rejects) {
        System.out.println("=====================");
        markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraphCPDAG);
        System.out.println("=====================");
    }
}

}