Skip to content

Commit

Permalink
Markov Check Test on same graph for different confusion matrix (Adj, …
Browse files Browse the repository at this point in the history
…AH, LG) for Gaussain DAG on Markov Blanket
  • Loading branch information
vbcwonderland committed Jun 18, 2024
1 parent 54d2aa6 commit 9c92055
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot
NumberFormat nf = new DecimalFormat("0.00");
// Classify nodes into accepts and rejects base on ADTest result, and update confusion stats lists accordingly.
for (Node x : allNodes) {
System.out.println("Target Node: " + x);
List<IndependenceFact> localIndependenceFacts = getLocalIndependenceFacts(x);
List<Double> ap_ar_ahp_ahr = getPrecisionAndRecallOnMarkovBlanketGraphPlotData(x, estimatedCpdag, trueGraph);
Double ap = ap_ar_ahp_ahr.get(0);
Expand Down Expand Up @@ -436,6 +437,7 @@ public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot
accepts_AHR_ADTestP.add(Arrays.asList(ahr, ADTestPValue));
}
}
System.out.println("-----------------------------");
}
accepts_rejects.add(accepts);
accepts_rejects.add(rejects);
Expand Down Expand Up @@ -540,6 +542,7 @@ public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot
NumberFormat nf = new DecimalFormat("0.00");
// Classify nodes into accepts and rejects base on ADTest result, and update confusion stats lists accordingly.
for (Node x : allNodes) {
System.out.println("Target Node: " + x);
List<IndependenceFact> localIndependenceFacts = getLocalIndependenceFacts(x);
List<Double> lgp_lgr = getPrecisionAndRecallOnMarkovBlanketGraphPlotData2(x, estimatedCpdag, trueGraph);
Double lgp = lgp_lgr.get(0);
Expand All @@ -549,6 +552,7 @@ public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot
List<Double> flatList = shuffledlocalPValues.stream()
.flatMap(List::stream)
.collect(Collectors.toList());
System.out.println("# p values feed into ADTest: " + flatList.size() );
Double ADTestPValue = checkAgainstAndersonDarlingTest(flatList);
// TODO VBC: what should we do for cases when ADTest is NaN and ∞ ?
if (ADTestPValue <= threshold) {
Expand All @@ -568,6 +572,7 @@ public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot
accepts_LGR_ADTestP.add(Arrays.asList(lgr, ADTestPValue));
}
}
System.out.println("-----------------------------");
}
accepts_rejects.add(accepts);
accepts_rejects.add(rejects);
Expand Down
66 changes: 35 additions & 31 deletions tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package edu.cmu.tetrad.test;

import edu.cmu.tetrad.algcomparison.statistic.*;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.*;
import edu.cmu.tetrad.search.*;
Expand Down Expand Up @@ -113,26 +114,55 @@ public void test2() {

@Test
public void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket() {
// TODO VBC: Also check different dense graph.
Graph trueGraph = RandomGraph.randomDag(20, 0, 80, 100, 100, 100, false);
Graph trueGraph = RandomGraph.randomDag(100, 0, 400, 100, 100, 100, false);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph size: " + trueGraph.getNodes().size());

SemPm pm = new SemPm(trueGraph);
// Parameters without additional setting default tobe Gaussian
SemIm im = new SemIm(pm, new Parameters());
DataSet data = im.simulateData(1000, false);
DataSet data = im.simulateData(10000, false);
edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false);
score.setPenaltyDiscount(2);
Graph estimatedCpdag = new PermutationSearch(new Boss(score)).search();
// TODO VBC: Next check different search algo to generate estimated graph. e.g. PC
System.out.println("Test Estimated CPDAG Graph: " + estimatedCpdag);
System.out.println("~~~~~~~~~~~~~~~~~~~~~~~~~~~~");
double whole_ap = new AdjacencyPrecision().getValue(trueGraph, estimatedCpdag, null);
double whole_ar = new AdjacencyRecall().getValue(trueGraph, estimatedCpdag, null);
double whole_ahp = new ArrowheadPrecision().getValue(trueGraph, estimatedCpdag, null);
double whole_ahr = new ArrowheadRecall().getValue(trueGraph, estimatedCpdag, null);
double whole_lgp = new LocalGraphPrecision().getValue(trueGraph, estimatedCpdag, null);
double whole_lgr = new LocalGraphRecall().getValue(trueGraph, estimatedCpdag, null);

testGaussianDAGPrecisionRecallForLocalOnMarkovBlanketUsingAdjAHConfusionMatrix(data, trueGraph, estimatedCpdag);
testGaussianDAGPrecisionRecallForLocalOnMarkovBlanketUsingLGConfusionMatrix(data, trueGraph, estimatedCpdag);
System.out.println("~~~~~~~~~~~~~Full Graph~~~~~~~~~~~~~~~");
System.out.println("whole_ap: " + whole_ap);
System.out.println("whole_ar: " + whole_ar );
System.out.println("whole_ahp: " + whole_ahp);
System.out.println("whole_ahr: " + whole_ahr);
System.out.println("whole_lgp: " + whole_lgp);
System.out.println("whole_lgr: " + whole_lgr);
}

public void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanketUsingAdjAHConfusionMatrix(DataSet data, Graph trueGraph, Graph estimatedCpdag) {
IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05);
MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET);
// List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05);
List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData(fisherZTest, estimatedCpdag, trueGraph, 0.05, 0.5);
// Using Adj, AH confusion matrix
List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData(fisherZTest, estimatedCpdag, trueGraph, 0.05, 1.0);
List<Node> accepts = accepts_rejects.get(0);
List<Node> rejects = accepts_rejects.get(1);
System.out.println("Accepts size: " + accepts.size());
System.out.println("Rejects size: " + rejects.size());
}

public void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanketUsingLGConfusionMatrix(DataSet data, Graph trueGraph, Graph estimatedCpdag) {
IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05);
MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET);
// Using Local Graph (LG) confusion matrix
// ADTest pass/fail threshold default to be 0.05. shuffleThreshold default to be 0.5
List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData2(fisherZTest, estimatedCpdag, trueGraph, 0.05, 1.0);
List<Node> accepts = accepts_rejects.get(0);
List<Node> rejects = accepts_rejects.get(1);
System.out.println("Accepts size: " + accepts.size());
Expand Down Expand Up @@ -406,32 +436,6 @@ public void testNonGaussianCPDAGPrecisionRecallForLocalOnParents() {
}
}

@Test
public void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket2() {
Graph trueGraph = RandomGraph.randomDag(20, 0, 80, 100, 100, 100, false);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph size: " + trueGraph.getNodes().size());

SemPm pm = new SemPm(trueGraph);
SemIm im = new SemIm(pm, new Parameters());
DataSet data = im.simulateData(1000, false);
edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false);
score.setPenaltyDiscount(2);
Graph estimatedCpdag = new PermutationSearch(new Boss(score)).search();
System.out.println("Test Estimated CPDAG Graph: " + estimatedCpdag);
System.out.println("~~~~~~~~~~~~~~~~~~~~~~~~~~~~");

IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05);
MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET);
// ADTest pass/fail threshold default to be 0.05. shuffleThreshold default to be 0.5
// List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05, 0.5);
List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData2(fisherZTest, estimatedCpdag, trueGraph, 0.05, 0.3);

List<Node> accepts = accepts_rejects.get(0);
List<Node> rejects = accepts_rejects.get(1);
System.out.println("Accepts size: " + accepts.size());
System.out.println("Rejects size: " + rejects.size());
}

@Test
public void testGaussianCPDAGPrecisionRecallForLocalOnMarkovBlanket2() {
Expand Down

0 comments on commit 9c92055

Please sign in to comment.