From de7202a8fb0ced967a9ed0966c66c1e42a4cb2eb Mon Sep 17 00:00:00 2001 From: vbcwonderland Date: Thu, 23 May 2024 14:40:46 -0400 Subject: [PATCH] Introduce plot data collection for different confusion matrix stats from Anderson Darling Test result. --- .../edu/cmu/tetrad/search/MarkovCheck.java | 174 +++++++++++++++++- .../edu/cmu/tetrad/test/TestCheckMarkov.java | 46 +++-- 2 files changed, 189 insertions(+), 31 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java index bec7c71c1b..2181efe713 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java @@ -17,6 +17,9 @@ import org.apache.commons.math3.util.Pair; import org.jetbrains.annotations.NotNull; +import java.io.BufferedWriter; +import java.io.FileWriter; +import java.io.IOException; import java.text.DecimalFormat; import java.text.NumberFormat; import java.util.*; @@ -313,7 +316,7 @@ public Double checkAgainstAndersonDarlingTest(List pValues) { * @return A list containing two lists: the first list contains the accepted nodes and the second list contains the * rejected nodes. 
*/
-    public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(IndependenceTest independenceTest, Graph graph, Double threshold) {
+    public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(IndependenceTest independenceTest, Graph graph, Double threshold, Double shuffleThreshold) {
         // When calling, default reject null as <=0.05
         List<List<Node>> accepts_rejects = new ArrayList<>();
         List<Node> accepts = new ArrayList<>();
@@ -321,16 +324,177 @@ public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(Ind
         List<Node> allNodes = graph.getNodes();
         for (Node x : allNodes) {
             List<IndependenceFact> localIndependenceFacts = getLocalIndependenceFacts(x);
-            List<Double> localPValues = getLocalPValues(independenceTest, localIndependenceFacts);
-            Double ADTest = checkAgainstAndersonDarlingTest(localPValues);
-            if (ADTest <= threshold) {
-                rejects.add(x);
-            } else {
-                accepts.add(x);
+            // All local nodes' p-values for node x
+            List<List<Double>> shuffledlocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold);
+            for (List<Double> localPValues: shuffledlocalPValues) {
+                Double ADTest = checkAgainstAndersonDarlingTest(localPValues); // P value obtained from AD test
+                if (ADTest <= threshold) {
+                    rejects.add(x);
+                } else {
+                    accepts.add(x);
+                }
+            }
+        }
+        accepts_rejects.add(accepts);
+        accepts_rejects.add(rejects);
+        return accepts_rejects;
+    }
+
+    /**
+     * Classifies nodes as accepted or rejected by the Anderson-Darling test on their local p-values and collects, for
+     * each classification, the node's Markov blanket precision/recall statistics paired with the Anderson-Darling
+     * p-value. The pairs are written as "statistic,p-value" rows to eight CSV files (relative paths), named
+     * {accepts,rejects}_{AdjP,AdjR,AHP,AHR}_ADTestP_data.csv. Statistics that are NaN are skipped.
+     * <p>
+     * A node is added to the accepts or rejects list once per p-value list returned by getLocalPValues, so it may
+     * appear more than once.
+     *
+     * @param independenceTest The independence test used to compute the local p-values.
+     * @param estimatedCpdag   The estimated graph whose Markov blanket subgraphs are scored.
+     * @param trueGraph        The true graph that the estimated graph is scored against.
+     * @param threshold        A node is rejected when its Anderson-Darling p-value is at most this value.
+     * @param shuffleThreshold Forwarded unchanged to getLocalPValues.
+     * @return A list containing two lists: the first list contains the accepted nodes and the second list contains the
+     * rejected nodes.
+     */
+    public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData(IndependenceTest independenceTest, Graph estimatedCpdag, Graph trueGraph, Double threshold, Double shuffleThreshold) {
+        // When calling, default reject null as <=0.05
+        List<List<Node>> accepts_rejects = new ArrayList<>();
+        List<Node> accepts = new ArrayList<>();
+        List<Node> rejects = new ArrayList<>();
+        // NOTE(review): iterates the nodes of `graph`, which is not a parameter of this method (presumably the
+        // instance's graph field) -- confirm it is the same graph as estimatedCpdag.
+        List<Node> allNodes = graph.getNodes();
+
+        // Confusion stats lists for data processing.
+        Map<String, String> fileContentMap = new HashMap<>();
+
+        List<List<Double>> accepts_AdjP_ADTestP = new ArrayList<>();
+        List<List<Double>> accepts_AdjR_ADTestP = new ArrayList<>();
+        List<List<Double>> accepts_AHP_ADTestP = new ArrayList<>();
+        List<List<Double>> accepts_AHR_ADTestP = new ArrayList<>();
+        fileContentMap.put("accepts_AdjP_ADTestP_data.csv", "");
+        fileContentMap.put("accepts_AdjR_ADTestP_data.csv", "");
+        fileContentMap.put("accepts_AHP_ADTestP_data.csv", "");
+        fileContentMap.put("accepts_AHR_ADTestP_data.csv", "");
+
+        List<List<Double>> rejects_AdjP_ADTestP = new ArrayList<>();
+        List<List<Double>> rejects_AdjR_ADTestP = new ArrayList<>();
+        List<List<Double>> rejects_AHP_ADTestP = new ArrayList<>();
+        List<List<Double>> rejects_AHR_ADTestP = new ArrayList<>();
+        fileContentMap.put("rejects_AdjP_ADTestP_data.csv", "");
+        fileContentMap.put("rejects_AdjR_ADTestP_data.csv", "");
+        fileContentMap.put("rejects_AHP_ADTestP_data.csv", "");
+        fileContentMap.put("rejects_AHR_ADTestP_data.csv", "");
+
+        NumberFormat nf = new DecimalFormat("0.00");
+        // Classify nodes into accepts and rejects base on ADTest result, and update confusion stats lists accordingly.
+        for (Node x : allNodes) {
+            List<IndependenceFact> localIndependenceFacts = getLocalIndependenceFacts(x);
+            List<Double> ap_ar_ahp_ahr = getPrecisionAndRecallOnMarkovBlanketGraphPlotData(x, estimatedCpdag, trueGraph);
+            Double ap = ap_ar_ahp_ahr.get(0);
+            Double ar = ap_ar_ahp_ahr.get(1);
+            Double ahp = ap_ar_ahp_ahr.get(2);
+            Double ahr = ap_ar_ahp_ahr.get(3);
+            // All local nodes' p-values for node x.
+            List<List<Double>> shuffledlocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold); // shuffleThreshold default to be 0.5
+            for (List<Double> localPValues: shuffledlocalPValues) {
+                // P value obtained from AD test
+                Double ADTest = checkAgainstAndersonDarlingTest(localPValues);
+                // TODO VBC: what should we do for cases when ADTest is NaN and ∞ ?
+                if (ADTest <= threshold) {
+                    rejects.add(x);
+                    if (!Double.isNaN(ap)) {
+                        rejects_AdjP_ADTestP.add(Arrays.asList(ap, ADTest));
+                    }
+                    if (!Double.isNaN(ar)) {
+                        rejects_AdjR_ADTestP.add(Arrays.asList(ar, ADTest));
+                    }
+                    if (!Double.isNaN(ahp)) {
+                        rejects_AHP_ADTestP.add(Arrays.asList(ahp, ADTest));
+                    }
+                    if (!Double.isNaN(ahr)) {
+                        rejects_AHR_ADTestP.add(Arrays.asList(ahr, ADTest));
+                    }
+                } else {
+                    accepts.add(x);
+                    if (!Double.isNaN(ap)) {
+                        accepts_AdjP_ADTestP.add(Arrays.asList(ap, ADTest));
+                    }
+                    if (!Double.isNaN(ar)) {
+                        accepts_AdjR_ADTestP.add(Arrays.asList(ar, ADTest));
+                    }
+                    if (!Double.isNaN(ahp)) {
+                        accepts_AHP_ADTestP.add(Arrays.asList(ahp, ADTest));
+                    }
+                    if (!Double.isNaN(ahr)) {
+                        accepts_AHR_ADTestP.add(Arrays.asList(ahr, ADTest));
+                    }
+                }
             }
         }
         accepts_rejects.add(accepts);
         accepts_rejects.add(rejects);
+        // Write into data files.
+        for (Map.Entry<String, String> entry : fileContentMap.entrySet()) {
+            try (BufferedWriter writer = new BufferedWriter(new FileWriter(entry.getKey()))) {
+                writer.write(entry.getValue());
+                switch (entry.getKey()) {
+                    case "accepts_AdjP_ADTestP_data.csv":
+                        for (List<Double> AdjP_ADTestP_pair : accepts_AdjP_ADTestP) {
+                            writer.write(nf.format(AdjP_ADTestP_pair.get(0)) + "," + nf.format(AdjP_ADTestP_pair.get(1)) + "\n");
+                        }
+                        break;
+
+                    case "accepts_AdjR_ADTestP_data.csv":
+                        for (List<Double> AdjR_ADTestP_pair : accepts_AdjR_ADTestP) {
+                            writer.write(nf.format(AdjR_ADTestP_pair.get(0)) + "," + nf.format(AdjR_ADTestP_pair.get(1)) + "\n");
+                        }
+                        break;
+
+                    case "accepts_AHP_ADTestP_data.csv":
+                        for (List<Double> AHP_ADTestP_pair : accepts_AHP_ADTestP) {
+                            writer.write(nf.format(AHP_ADTestP_pair.get(0)) + "," + nf.format(AHP_ADTestP_pair.get(1)) + "\n");
+                        }
+                        break;
+
+                    case "accepts_AHR_ADTestP_data.csv":
+                        for (List<Double> AHR_ADTestP_pair : accepts_AHR_ADTestP) {
+                            writer.write(nf.format(AHR_ADTestP_pair.get(0)) + "," + nf.format(AHR_ADTestP_pair.get(1)) + "\n");
+                        }
+                        break;
+
+                    case "rejects_AdjP_ADTestP_data.csv":
+                        for (List<Double> AdjP_ADTestP_pair : rejects_AdjP_ADTestP) {
+                            writer.write(nf.format(AdjP_ADTestP_pair.get(0)) + "," + nf.format(AdjP_ADTestP_pair.get(1)) + "\n");
+                        }
+                        break;
+
+                    case "rejects_AdjR_ADTestP_data.csv":
+                        for (List<Double> AdjR_ADTestP_pair : rejects_AdjR_ADTestP) {
+                            writer.write(nf.format(AdjR_ADTestP_pair.get(0)) + "," + nf.format(AdjR_ADTestP_pair.get(1)) + "\n");
+                        }
+                        break;
+
+                    case "rejects_AHP_ADTestP_data.csv":
+                        for (List<Double> AHP_ADTestP_pair : rejects_AHP_ADTestP) {
+                            writer.write(nf.format(AHP_ADTestP_pair.get(0)) + "," + nf.format(AHP_ADTestP_pair.get(1)) + "\n");
+                        }
+                        break;
+
+                    case "rejects_AHR_ADTestP_data.csv":
+                        for (List<Double> AHR_ADTestP_pair : rejects_AHR_ADTestP) {
+                            writer.write(nf.format(AHR_ADTestP_pair.get(0)) + "," + nf.format(AHR_ADTestP_pair.get(1)) + "\n");
+                        }
+                        break;
+                    default:
+                        break;
+                }
+                System.out.println("Successfully written to " + entry.getKey());
+            } catch (IOException e) {
+                e.printStackTrace();
+            }
+        }
         return accepts_rejects;
     }
 
@@ -362,6 +526,31 @@ public void getPrecisionAndRecallOnMarkovBlanketGraph(Node x, Graph estimatedGra
                 " ArrowHeadPrecision = " + nf.format(ahp) + " ArrowHeadRecall = " + nf.format(ahr));
     }
 
+    /**
+     * Computes adjacency and arrowhead precision and recall for the Markov blanket subgraph around a node, comparing
+     * the estimated graph against the true graph.
+     *
+     * @param x              The target node.
+     * @param estimatedGraph The estimated graph.
+     * @param trueGraph      The true graph.
+     * @return A list of four values in the order: adjacency precision, adjacency recall, arrowhead precision,
+     * arrowhead recall.
+     */
+    public List<Double> getPrecisionAndRecallOnMarkovBlanketGraphPlotData(Node x, Graph estimatedGraph, Graph trueGraph) {
+        // Lookup graph is the same structure as trueGraph's structure but node objects replaced by estimated graph nodes.
+        Graph lookupGraph = GraphUtils.replaceNodes(trueGraph, estimatedGraph.getNodes());
+        Graph xMBLookupGraph = GraphUtils.getMarkovBlanketSubgraphWithTargetNode(lookupGraph, x);
+        System.out.println("xMBLookupGraph:" + xMBLookupGraph);
+        Graph xMBEstimatedGraph = GraphUtils.getMarkovBlanketSubgraphWithTargetNode(estimatedGraph, x);
+        System.out.println("xMBEstimatedGraph:" + xMBEstimatedGraph);
+
+        double ap = new AdjacencyPrecision().getValue(xMBLookupGraph, xMBEstimatedGraph, null);
+        double ar = new AdjacencyRecall().getValue(xMBLookupGraph, xMBEstimatedGraph, null);
+        double ahp = new ArrowheadPrecision().getValue(xMBLookupGraph, xMBEstimatedGraph, null);
+        double ahr = new ArrowheadRecall().getValue(xMBLookupGraph, xMBEstimatedGraph, null);
+        return Arrays.asList(ap, ar, ahp, ahr);
+    }
+
     /**
      * Calculates the precision and recall using LocalGraphConfusion
      * (which calculates the combination of Adjacency and ArrowHead) on the Markov Blanket graph for a given node.
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java
index 27fdf9b703..7e07a6cf6f 100644
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java
@@ -113,7 +113,9 @@ public void test2() {
 
     @Test
     public void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket() {
-        Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
+        Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
+//        TODO VBC: Also check different dense graph.
+// Graph trueGraph = RandomGraph.randomDag(20, 0, 40, 100, 100, 100, false); System.out.println("Test True Graph: " + trueGraph); System.out.println("Test True Graph size: " + trueGraph.getNodes().size()); @@ -124,30 +126,18 @@ public void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket() { edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false); score.setPenaltyDiscount(2); Graph estimatedCpdag = new PermutationSearch(new Boss(score)).search(); +// TODO VBC: Next check different search algo to generate estimated graph. e.g. PC System.out.println("Test Estimated CPDAG Graph: " + estimatedCpdag); System.out.println("~~~~~~~~~~~~~~~~~~~~~~~~~~~~"); IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05); MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET); - List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05); +// List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05); + List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData(fisherZTest, estimatedCpdag, trueGraph, 0.05, 0.3); List accepts = accepts_rejects.get(0); List rejects = accepts_rejects.get(1); System.out.println("Accepts size: " + accepts.size()); System.out.println("Rejects size: " + rejects.size()); - - List acceptsPrecision = new ArrayList<>(); - List acceptsRecall = new ArrayList<>(); - for(Node a: accepts) { - System.out.println("====================="); - markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraph); - System.out.println("====================="); - - } - for (Node a: rejects) { - System.out.println("====================="); - markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraph); - System.out.println("====================="); - } } @Test @@ -170,7 +160,8 @@ public 
void testGaussianCPDAGPrecisionRecallForLocalOnMarkovBlanket() { IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05); MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET); - List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05); + // ADTest pass/fail threshold default to be 0.05. shuffleThreshold default to be 0.5 + List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05, 0.5); List accepts = accepts_rejects.get(0); List rejects = accepts_rejects.get(1); System.out.println("Accepts size: " + accepts.size()); @@ -213,7 +204,8 @@ public void testNonGaussianDAGPrecisionRecallForLocalOnMarkovBlanket() { IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05); MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET); - List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05); + // ADTest pass/fail threshold default to be 0.05. shuffleThreshold default to be 0.5 + List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05, 0.5); List accepts = accepts_rejects.get(0); List rejects = accepts_rejects.get(1); System.out.println("Accepts size: " + accepts.size()); @@ -259,7 +251,8 @@ public void testNonGaussianCPDAGPrecisionRecallForLocalOnMarkovBlanket() { IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05); MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET); - List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05); + // ADTest pass/fail threshold default to be 0.05. 
shuffleThreshold default to be 0.5 + List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05, 0.5); List accepts = accepts_rejects.get(0); List rejects = accepts_rejects.get(1); System.out.println("Accepts size: " + accepts.size()); @@ -300,7 +293,8 @@ public void testGaussianDAGPrecisionRecallForLocalOnParents() { IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05); // TODO VBC: confirm on the choice of ConditioningSetType. MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.LOCAL_MARKOV); - List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05); + // ADTest pass/fail threshold default to be 0.05. shuffleThreshold default to be 0.5 + List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05, 0.5); List accepts = accepts_rejects.get(0); List rejects = accepts_rejects.get(1); System.out.println("Accepts size: " + accepts.size()); @@ -339,7 +333,8 @@ public void testGaussianCPDAGPrecisionRecallForLocalOnParents() { IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05); MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET); - List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05); + // ADTest pass/fail threshold default to be 0.05. 
shuffleThreshold default to be 0.5 + List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05, 0.5); List accepts = accepts_rejects.get(0); List rejects = accepts_rejects.get(1); System.out.println("Accepts size: " + accepts.size()); @@ -382,7 +377,8 @@ public void testNonGaussianDAGPrecisionRecallForLocalOnParents() { IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05); // TODO VBC: confirm on the choice of ConditioningSetType. MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.LOCAL_MARKOV); - List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05); + // ADTest pass/fail threshold default to be 0.05. shuffleThreshold default to be 0.5 + List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05, 0.5); List accepts = accepts_rejects.get(0); List rejects = accepts_rejects.get(1); System.out.println("Accepts size: " + accepts.size()); @@ -426,7 +422,8 @@ public void testNonGaussianCPDAGPrecisionRecallForLocalOnParents() { IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05); MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET); - List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05); + // ADTest pass/fail threshold default to be 0.05. 
shuffleThreshold default to be 0.5 + List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05, 0.5); List accepts = accepts_rejects.get(0); List rejects = accepts_rejects.get(1); System.out.println("Accepts size: " + accepts.size()); @@ -463,7 +460,8 @@ public void testDAGPrecisionRecall2ForLocalOnMarkovBlanket() { IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05); MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET); - List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05); + // ADTest pass/fail threshold default to be 0.05. shuffleThreshold default to be 0.5 + List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05, 0.5); List accepts = accepts_rejects.get(0); List rejects = accepts_rejects.get(1); System.out.println("Accepts size: " + accepts.size());