From c7c662383913fdcedffe2879046b11b816f2bead Mon Sep 17 00:00:00 2001 From: vbcwonderland Date: Thu, 23 May 2024 16:34:47 -0400 Subject: [PATCH 1/4] Introduce Plot data functions for Local Graph Precision and Recal --- .../edu/cmu/tetrad/search/MarkovCheck.java | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java index 2181efe713..ab12d3c551 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java @@ -479,6 +479,101 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot return accepts_rejects; } + public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData2(IndependenceTest independenceTest, Graph estimatedCpdag, Graph trueGraph, Double threshold, Double shuffleThreshold) { + // When calling, default reject null as <=0.05 + List> accepts_rejects = new ArrayList<>(); + List accepts = new ArrayList<>(); + List rejects = new ArrayList<>(); + List allNodes = graph.getNodes(); + + // Confusion stats lists for data processing. + Map fileContentMap = new HashMap<>(); + + // Local Graph Precision and Recall + List> accepts_LGP_ADTestP = new ArrayList<>(); + List> accepts_LGR_ADTestP = new ArrayList<>(); + fileContentMap.put("accepts_LGP_ADTestP_data.csv", ""); + fileContentMap.put("accepts_LGR_ADTestP_data.csv", ""); + + List> rejects_LGP_ADTestP = new ArrayList<>(); + List> rejects_LGR_ADTestP = new ArrayList<>(); + fileContentMap.put("rejects_LGP_ADTestP_data.csv", ""); + fileContentMap.put("rejects_LGR_ADTestP_data.csv", ""); + + NumberFormat nf = new DecimalFormat("0.00"); + // Classify nodes into accepts and rejects base on ADTest result, and update confusion stats lists accordingly. + for (Node x : allNodes) { + List localIndependenceFacts = getLocalIndependenceFacts(x); + List lgp_lgr = getPrecisionAndRecallOnMarkovBlanketGraphPlotData2(x, estimatedCpdag, trueGraph); + Double lgp = lgp_lgr.get(0); + Double lgr = lgp_lgr.get(1); + // All local nodes' p-values for node x. + List> shuffledlocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold); // shuffleThreshold default to be 0.5 + for (List localPValues: shuffledlocalPValues) { + // P value obtained from AD test + Double ADTest = checkAgainstAndersonDarlingTest(localPValues); + // TODO VBC: what should we do for cases when ADTest is NaN and ∞ ? + if (ADTest <= threshold) { + rejects.add(x); + if (!Double.isNaN(lgp)) { + rejects_LGP_ADTestP.add(Arrays.asList(lgp, ADTest)); + } + if (!Double.isNaN(lgr)) { + rejects_LGR_ADTestP.add(Arrays.asList(lgr, ADTest)); + } + } else { + accepts.add(x); + if (!Double.isNaN(lgp)) { + accepts_LGP_ADTestP.add(Arrays.asList(lgp, ADTest)); + } + if (!Double.isNaN(lgr)) { + accepts_LGR_ADTestP.add(Arrays.asList(lgr, ADTest)); + } + } + } + } + accepts_rejects.add(accepts); + accepts_rejects.add(rejects); + // Write into data files. + for (Map.Entry entry : fileContentMap.entrySet()) { + try (BufferedWriter writer = new BufferedWriter(new FileWriter(entry.getKey()))) { + writer.write(entry.getValue()); + switch (entry.getKey()) { + case "accepts_LGP_ADTestP_data.csv": + for (List LGP_ADTestP_pair : accepts_LGP_ADTestP) { + writer.write(nf.format(LGP_ADTestP_pair.get(0)) + "," + nf.format(LGP_ADTestP_pair.get(1)) + "\n"); + } + break; + + case "accepts_LGR_ADTestP_data.csv": + for (List LGR_ADTestP_pair : accepts_LGR_ADTestP) { + writer.write(nf.format(LGR_ADTestP_pair.get(0)) + "," + nf.format(LGR_ADTestP_pair.get(1)) + "\n"); + } + break; + + case "rejects_LGP_ADTestP_data.csv": + for (List LGP_ADTestP_pair : rejects_LGP_ADTestP) { + writer.write(nf.format(LGP_ADTestP_pair.get(0)) + "," + nf.format(LGP_ADTestP_pair.get(1)) + "\n"); + } + break; + + case "rejects_LGR_ADTestP_data.csv": + for (List LGR_ADTestP_pair : rejects_LGR_ADTestP) { + writer.write(nf.format(LGR_ADTestP_pair.get(0)) + "," + nf.format(LGR_ADTestP_pair.get(1)) + "\n"); + } + break; + + default: + break; + } + System.out.println("Successfully written to " + entry.getKey()); + } catch (IOException e) { + e.printStackTrace(); + } + } + return accepts_rejects; + } + /** * Calculates the precision and recall on the Markov Blanket graph for a given node. Prints the statistics to the * console. @@ -547,6 +642,19 @@ public void getPrecisionAndRecallOnMarkovBlanketGraph2(Node x, Graph estimatedGr " LocalGraphPrecision = " + nf.format(lgp) + " LocalGraphRecall = " + nf.format(lgr) + " \n"); } + public List getPrecisionAndRecallOnMarkovBlanketGraphPlotData2(Node x, Graph estimatedGraph, Graph trueGraph) { + // Lookup graph is the same structure as trueGraph's structure but node objects replaced by estimated graph nodes. + Graph lookupGraph = GraphUtils.replaceNodes(trueGraph, estimatedGraph.getNodes()); + Graph xMBLookupGraph = GraphUtils.getMarkovBlanketSubgraphWithTargetNode(lookupGraph, x); + System.out.println("xMBLookupGraph:" + xMBLookupGraph); + Graph xMBEstimatedGraph = GraphUtils.getMarkovBlanketSubgraphWithTargetNode(estimatedGraph, x); + System.out.println("xMBEstimatedGraph:" + xMBEstimatedGraph); + + double lgp = new LocalGraphPrecision().getValue(xMBLookupGraph, xMBEstimatedGraph, null); + double lgr = new LocalGraphRecall().getValue(xMBLookupGraph, xMBEstimatedGraph, null); + return Arrays.asList(lgp, lgr); + } + /** * Returns the variables of the independence test. * From baf1377f7d35c78d24132c46f06203b3550158be Mon Sep 17 00:00:00 2001 From: vbcwonderland Date: Thu, 23 May 2024 16:35:53 -0400 Subject: [PATCH 2/4] update test when using lgp and lgr --- .../edu/cmu/tetrad/test/TestCheckMarkov.java | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java index 7e07a6cf6f..27584cd642 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java @@ -444,7 +444,7 @@ public void testNonGaussianCPDAGPrecisionRecallForLocalOnParents() { } @Test - public void testDAGPrecisionRecall2ForLocalOnMarkovBlanket() { + public void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket2() { Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false); System.out.println("Test True Graph: " + trueGraph); System.out.println("Test True Graph size: " + trueGraph.getNodes().size()); @@ -461,25 +461,27 @@ public void testDAGPrecisionRecall2ForLocalOnMarkovBlanket() { IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05); MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET); // ADTest pass/fail threshold default to be 0.05. shuffleThreshold default to be 0.5 - List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05, 0.5); + // List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05, 0.5); + List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData2(fisherZTest, estimatedCpdag, trueGraph, 0.05, 0.3); + List accepts = accepts_rejects.get(0); List rejects = accepts_rejects.get(1); System.out.println("Accepts size: " + accepts.size()); System.out.println("Rejects size: " + rejects.size()); - List acceptsPrecision = new ArrayList<>(); - List acceptsRecall = new ArrayList<>(); - for(Node a: accepts) { - System.out.println("====================="); - markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph2(a, estimatedCpdag, trueGraph); - System.out.println("====================="); - - } - for (Node a: rejects) { - System.out.println("====================="); - markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph2(a, estimatedCpdag, trueGraph); - System.out.println("====================="); - } +// List acceptsPrecision = new ArrayList<>(); +// List acceptsRecall = new ArrayList<>(); +// for(Node a: accepts) { +// System.out.println("====================="); +// markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph2(a, estimatedCpdag, trueGraph); +// System.out.println("====================="); +// +// } +// for (Node a: rejects) { +// System.out.println("====================="); +// markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph2(a, estimatedCpdag, trueGraph); +// System.out.println("====================="); +// } } } From 4c5b5120ab6c6b6d802da8e5a7a294cba0aad54d Mon Sep 17 00:00:00 2001 From: vbcwonderland Date: Thu, 23 May 2024 16:44:55 -0400 Subject: [PATCH 3/4] nit --- .../edu/cmu/tetrad/search/MarkovCheck.java | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java index ab12d3c551..e0f40149a7 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java @@ -381,35 +381,35 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot List> shuffledlocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold); // shuffleThreshold default to be 0.5 for (List localPValues: shuffledlocalPValues) { // P value obtained from AD test - Double ADTest = checkAgainstAndersonDarlingTest(localPValues); + Double ADTestPValue = checkAgainstAndersonDarlingTest(localPValues); // TODO VBC: what should we do for cases when ADTest is NaN and ∞ ? - if (ADTest <= threshold) { + if (ADTestPValue <= threshold) { rejects.add(x); if (!Double.isNaN(ap)) { - rejects_AdjP_ADTestP.add(Arrays.asList(ap, ADTest)); + rejects_AdjP_ADTestP.add(Arrays.asList(ap, ADTestPValue)); } if (!Double.isNaN(ar)) { - rejects_AdjR_ADTestP.add(Arrays.asList(ap, ADTest)); + rejects_AdjR_ADTestP.add(Arrays.asList(ap, ADTestPValue)); } if (!Double.isNaN(ahp)) { - rejects_AHP_ADTestP.add(Arrays.asList(ap, ADTest)); + rejects_AHP_ADTestP.add(Arrays.asList(ap, ADTestPValue)); } if (!Double.isNaN(ahr)) { - rejects_AHR_ADTestP.add(Arrays.asList(ap, ADTest)); + rejects_AHR_ADTestP.add(Arrays.asList(ap, ADTestPValue)); } } else { accepts.add(x); if (!Double.isNaN(ap)) { - accepts_AdjP_ADTestP.add(Arrays.asList(ap, ADTest)); + accepts_AdjP_ADTestP.add(Arrays.asList(ap, ADTestPValue)); } if (!Double.isNaN(ar)) { - accepts_AdjR_ADTestP.add(Arrays.asList(ap, ADTest)); + accepts_AdjR_ADTestP.add(Arrays.asList(ap, ADTestPValue)); } if (!Double.isNaN(ahp)) { - accepts_AHP_ADTestP.add(Arrays.asList(ap, ADTest)); + accepts_AHP_ADTestP.add(Arrays.asList(ap, ADTestPValue)); } if (!Double.isNaN(ahr)) { - accepts_AHR_ADTestP.add(Arrays.asList(ap, ADTest)); + accepts_AHR_ADTestP.add(Arrays.asList(ap, ADTestPValue)); } } } @@ -421,7 +421,7 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot try (BufferedWriter writer = new BufferedWriter(new FileWriter(entry.getKey()))) { writer.write(entry.getValue()); switch (entry.getKey()) { - case "acceptsAdjP_ADTestP_data.csv": + case "accepts_AdjP_ADTestP_data.csv": for (List AdjP_ADTestP_pair : accepts_AdjP_ADTestP) { writer.write(nf.format(AdjP_ADTestP_pair.get(0)) + "," + nf.format(AdjP_ADTestP_pair.get(1)) + "\n"); } @@ -489,7 +489,7 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot // Confusion stats lists for data processing. Map fileContentMap = new HashMap<>(); - // Local Graph Precision and Recall + // Using Local Graph Precision and Recall to calculate Confusion statistics. List> accepts_LGP_ADTestP = new ArrayList<>(); List> accepts_LGR_ADTestP = new ArrayList<>(); fileContentMap.put("accepts_LGP_ADTestP_data.csv", ""); @@ -511,23 +511,23 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot List> shuffledlocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold); // shuffleThreshold default to be 0.5 for (List localPValues: shuffledlocalPValues) { // P value obtained from AD test - Double ADTest = checkAgainstAndersonDarlingTest(localPValues); + Double ADTestPValue = checkAgainstAndersonDarlingTest(localPValues); // TODO VBC: what should we do for cases when ADTest is NaN and ∞ ? - if (ADTest <= threshold) { + if (ADTestPValue <= threshold) { rejects.add(x); if (!Double.isNaN(lgp)) { - rejects_LGP_ADTestP.add(Arrays.asList(lgp, ADTest)); + rejects_LGP_ADTestP.add(Arrays.asList(lgp, ADTestPValue)); } if (!Double.isNaN(lgr)) { - rejects_LGR_ADTestP.add(Arrays.asList(lgr, ADTest)); + rejects_LGR_ADTestP.add(Arrays.asList(lgr, ADTestPValue)); } } else { accepts.add(x); if (!Double.isNaN(lgp)) { - accepts_LGP_ADTestP.add(Arrays.asList(lgp, ADTest)); + accepts_LGP_ADTestP.add(Arrays.asList(lgp, ADTestPValue)); } if (!Double.isNaN(lgr)) { - accepts_LGR_ADTestP.add(Arrays.asList(lgr, ADTest)); + accepts_LGR_ADTestP.add(Arrays.asList(lgr, ADTestPValue)); } } } From 58f01edf16d6daa04bb17a4ffabaa372f4da043a Mon Sep 17 00:00:00 2001 From: vbcwonderland Date: Thu, 23 May 2024 16:48:25 -0400 Subject: [PATCH 4/4] method documentation --- .../edu/cmu/tetrad/search/MarkovCheck.java | 22 +++++++++++++++++++ .../edu/cmu/tetrad/test/TestCheckMarkov.java | 14 ------------ 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java index e0f40149a7..11ef450002 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java @@ -340,6 +340,17 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(Ind return accepts_rejects; } + /** + * Get accepts and rejects nodes for all nodes from Anderson-Darling test and generate the plot data for confusion statistics. + * + * Confusion statistics were calculated using Adjacency (AdjacencyPrecision, AdjacencyRecall) and Arrowhead (ArrowheadPrecision, ArrowheadRecall) + * @param independenceTest + * @param estimatedCpdag + * @param trueGraph + * @param threshold + * @param shuffleThreshold + * @return + */ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData(IndependenceTest independenceTest, Graph estimatedCpdag, Graph trueGraph, Double threshold, Double shuffleThreshold) { // When calling, default reject null as <=0.05 List> accepts_rejects = new ArrayList<>(); @@ -479,6 +490,17 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot return accepts_rejects; } + /** + * Get accepts and rejects nodes for all nodes from Anderson-Darling test and generate the plot data for confusion statistics. + * + * Confusion statistics were calculated using Local Graph Precision and Recall (LocalGraphPrecision, LocalGraphRecall). + * @param independenceTest + * @param estimatedCpdag + * @param trueGraph + * @param threshold + * @param shuffleThreshold + * @return + */ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData2(IndependenceTest independenceTest, Graph estimatedCpdag, Graph trueGraph, Double threshold, Double shuffleThreshold) { // When calling, default reject null as <=0.05 List> accepts_rejects = new ArrayList<>(); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java index 27584cd642..eb896081cf 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java @@ -468,20 +468,6 @@ public void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket2() { List rejects = accepts_rejects.get(1); System.out.println("Accepts size: " + accepts.size()); System.out.println("Rejects size: " + rejects.size()); - -// List acceptsPrecision = new ArrayList<>(); -// List acceptsRecall = new ArrayList<>(); -// for(Node a: accepts) { -// System.out.println("====================="); -// markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph2(a, estimatedCpdag, trueGraph); -// System.out.println("====================="); -// -// } -// for (Node a: rejects) { -// System.out.println("====================="); -// markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph2(a, estimatedCpdag, trueGraph); -// System.out.println("====================="); -// } } }