diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java index 2181efe713..11ef450002 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java @@ -340,6 +340,17 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(Ind return accepts_rejects; } + /** + * Get accepts and rejects nodes for all nodes from Anderson-Darling test and generate the plot data for confusion statistics. + * + * Confusion statistics were calculated using Adjacency (AdjacencyPrecision, AdjacencyRecall) and Arrowhead (ArrowheadPrecision, ArrowheadRecall) + * @param independenceTest + * @param estimatedCpdag + * @param trueGraph + * @param threshold + * @param shuffleThreshold + * @return + */ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData(IndependenceTest independenceTest, Graph estimatedCpdag, Graph trueGraph, Double threshold, Double shuffleThreshold) { // When calling, default reject null as <=0.05 List> accepts_rejects = new ArrayList<>(); @@ -381,35 +392,35 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot List> shuffledlocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold); // shuffleThreshold default to be 0.5 for (List localPValues: shuffledlocalPValues) { // P value obtained from AD test - Double ADTest = checkAgainstAndersonDarlingTest(localPValues); + Double ADTestPValue = checkAgainstAndersonDarlingTest(localPValues); // TODO VBC: what should we do for cases when ADTest is NaN and ∞ ? - if (ADTest <= threshold) { + if (ADTestPValue <= threshold) { rejects.add(x); if (!Double.isNaN(ap)) { - rejects_AdjP_ADTestP.add(Arrays.asList(ap, ADTest)); + rejects_AdjP_ADTestP.add(Arrays.asList(ap, ADTestPValue)); } if (!Double.isNaN(ar)) { - rejects_AdjR_ADTestP.add(Arrays.asList(ap, ADTest)); + rejects_AdjR_ADTestP.add(Arrays.asList(ap, ADTestPValue)); } if (!Double.isNaN(ahp)) { - rejects_AHP_ADTestP.add(Arrays.asList(ap, ADTest)); + rejects_AHP_ADTestP.add(Arrays.asList(ap, ADTestPValue)); } if (!Double.isNaN(ahr)) { - rejects_AHR_ADTestP.add(Arrays.asList(ap, ADTest)); + rejects_AHR_ADTestP.add(Arrays.asList(ap, ADTestPValue)); } } else { accepts.add(x); if (!Double.isNaN(ap)) { - accepts_AdjP_ADTestP.add(Arrays.asList(ap, ADTest)); + accepts_AdjP_ADTestP.add(Arrays.asList(ap, ADTestPValue)); } if (!Double.isNaN(ar)) { - accepts_AdjR_ADTestP.add(Arrays.asList(ap, ADTest)); + accepts_AdjR_ADTestP.add(Arrays.asList(ap, ADTestPValue)); } if (!Double.isNaN(ahp)) { - accepts_AHP_ADTestP.add(Arrays.asList(ap, ADTest)); + accepts_AHP_ADTestP.add(Arrays.asList(ap, ADTestPValue)); } if (!Double.isNaN(ahr)) { - accepts_AHR_ADTestP.add(Arrays.asList(ap, ADTest)); + accepts_AHR_ADTestP.add(Arrays.asList(ap, ADTestPValue)); } } } @@ -421,7 +432,7 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot try (BufferedWriter writer = new BufferedWriter(new FileWriter(entry.getKey()))) { writer.write(entry.getValue()); switch (entry.getKey()) { - case "acceptsAdjP_ADTestP_data.csv": + case "accepts_AdjP_ADTestP_data.csv": for (List AdjP_ADTestP_pair : accepts_AdjP_ADTestP) { writer.write(nf.format(AdjP_ADTestP_pair.get(0)) + "," + nf.format(AdjP_ADTestP_pair.get(1)) + "\n"); } @@ -479,6 +490,112 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot return accepts_rejects; } + /** + * Get accepts and rejects nodes for all nodes from Anderson-Darling test and generate the plot data for confusion statistics. + * + * Confusion statistics were calculated using Local Graph Precision and Recall (LocalGraphPrecision, LocalGraphRecall). + * @param independenceTest + * @param estimatedCpdag + * @param trueGraph + * @param threshold + * @param shuffleThreshold + * @return + */ + public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData2(IndependenceTest independenceTest, Graph estimatedCpdag, Graph trueGraph, Double threshold, Double shuffleThreshold) { + // When calling, default reject null as <=0.05 + List> accepts_rejects = new ArrayList<>(); + List accepts = new ArrayList<>(); + List rejects = new ArrayList<>(); + List allNodes = graph.getNodes(); + + // Confusion stats lists for data processing. + Map fileContentMap = new HashMap<>(); + + // Using Local Graph Precision and Recall to calculate Confusion statistics. + List> accepts_LGP_ADTestP = new ArrayList<>(); + List> accepts_LGR_ADTestP = new ArrayList<>(); + fileContentMap.put("accepts_LGP_ADTestP_data.csv", ""); + fileContentMap.put("accepts_LGR_ADTestP_data.csv", ""); + + List> rejects_LGP_ADTestP = new ArrayList<>(); + List> rejects_LGR_ADTestP = new ArrayList<>(); + fileContentMap.put("rejects_LGP_ADTestP_data.csv", ""); + fileContentMap.put("rejects_LGR_ADTestP_data.csv", ""); + + NumberFormat nf = new DecimalFormat("0.00"); + // Classify nodes into accepts and rejects base on ADTest result, and update confusion stats lists accordingly. + for (Node x : allNodes) { + List localIndependenceFacts = getLocalIndependenceFacts(x); + List lgp_lgr = getPrecisionAndRecallOnMarkovBlanketGraphPlotData2(x, estimatedCpdag, trueGraph); + Double lgp = lgp_lgr.get(0); + Double lgr = lgp_lgr.get(1); + // All local nodes' p-values for node x. + List> shuffledlocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold); // shuffleThreshold default to be 0.5 + for (List localPValues: shuffledlocalPValues) { + // P value obtained from AD test + Double ADTestPValue = checkAgainstAndersonDarlingTest(localPValues); + // TODO VBC: what should we do for cases when ADTest is NaN and ∞ ? + if (ADTestPValue <= threshold) { + rejects.add(x); + if (!Double.isNaN(lgp)) { + rejects_LGP_ADTestP.add(Arrays.asList(lgp, ADTestPValue)); + } + if (!Double.isNaN(lgr)) { + rejects_LGR_ADTestP.add(Arrays.asList(lgr, ADTestPValue)); + } + } else { + accepts.add(x); + if (!Double.isNaN(lgp)) { + accepts_LGP_ADTestP.add(Arrays.asList(lgp, ADTestPValue)); + } + if (!Double.isNaN(lgr)) { + accepts_LGR_ADTestP.add(Arrays.asList(lgr, ADTestPValue)); + } + } + } + } + accepts_rejects.add(accepts); + accepts_rejects.add(rejects); + // Write into data files. + for (Map.Entry entry : fileContentMap.entrySet()) { + try (BufferedWriter writer = new BufferedWriter(new FileWriter(entry.getKey()))) { + writer.write(entry.getValue()); + switch (entry.getKey()) { + case "accepts_LGP_ADTestP_data.csv": + for (List LGP_ADTestP_pair : accepts_LGP_ADTestP) { + writer.write(nf.format(LGP_ADTestP_pair.get(0)) + "," + nf.format(LGP_ADTestP_pair.get(1)) + "\n"); + } + break; + + case "accepts_LGR_ADTestP_data.csv": + for (List LGR_ADTestP_pair : accepts_LGR_ADTestP) { + writer.write(nf.format(LGR_ADTestP_pair.get(0)) + "," + nf.format(LGR_ADTestP_pair.get(1)) + "\n"); + } + break; + + case "rejects_LGP_ADTestP_data.csv": + for (List LGP_ADTestP_pair : rejects_LGP_ADTestP) { + writer.write(nf.format(LGP_ADTestP_pair.get(0)) + "," + nf.format(LGP_ADTestP_pair.get(1)) + "\n"); + } + break; + + case "rejects_LGR_ADTestP_data.csv": + for (List LGR_ADTestP_pair : rejects_LGR_ADTestP) { + writer.write(nf.format(LGR_ADTestP_pair.get(0)) + "," + nf.format(LGR_ADTestP_pair.get(1)) + "\n"); + } + break; + + default: + break; + } + System.out.println("Successfully written to " + entry.getKey()); + } catch (IOException e) { + e.printStackTrace(); + } + } + return accepts_rejects; + } + /** * Calculates the precision and recall on the Markov Blanket graph for a given node. Prints the statistics to the * console. @@ -547,6 +664,19 @@ public void getPrecisionAndRecallOnMarkovBlanketGraph2(Node x, Graph estimatedGr " LocalGraphPrecision = " + nf.format(lgp) + " LocalGraphRecall = " + nf.format(lgr) + " \n"); } + public List getPrecisionAndRecallOnMarkovBlanketGraphPlotData2(Node x, Graph estimatedGraph, Graph trueGraph) { + // Lookup graph is the same structure as trueGraph's structure but node objects replaced by estimated graph nodes. + Graph lookupGraph = GraphUtils.replaceNodes(trueGraph, estimatedGraph.getNodes()); + Graph xMBLookupGraph = GraphUtils.getMarkovBlanketSubgraphWithTargetNode(lookupGraph, x); + System.out.println("xMBLookupGraph:" + xMBLookupGraph); + Graph xMBEstimatedGraph = GraphUtils.getMarkovBlanketSubgraphWithTargetNode(estimatedGraph, x); + System.out.println("xMBEstimatedGraph:" + xMBEstimatedGraph); + + double lgp = new LocalGraphPrecision().getValue(xMBLookupGraph, xMBEstimatedGraph, null); + double lgr = new LocalGraphRecall().getValue(xMBLookupGraph, xMBEstimatedGraph, null); + return Arrays.asList(lgp, lgr); + } + /** * Returns the variables of the independence test. * diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java index 7e07a6cf6f..eb896081cf 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java @@ -444,7 +444,7 @@ public void testNonGaussianCPDAGPrecisionRecallForLocalOnParents() { } @Test - public void testDAGPrecisionRecall2ForLocalOnMarkovBlanket() { + public void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket2() { Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false); System.out.println("Test True Graph: " + trueGraph); System.out.println("Test True Graph size: " + trueGraph.getNodes().size()); @@ -461,25 +461,13 @@ public void testDAGPrecisionRecall2ForLocalOnMarkovBlanket() { IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05); MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET); // ADTest pass/fail threshold default to be 0.05. shuffleThreshold default to be 0.5 - List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05, 0.5); + // List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05, 0.5); + List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData2(fisherZTest, estimatedCpdag, trueGraph, 0.05, 0.3); + List accepts = accepts_rejects.get(0); List rejects = accepts_rejects.get(1); System.out.println("Accepts size: " + accepts.size()); System.out.println("Rejects size: " + rejects.size()); - - List acceptsPrecision = new ArrayList<>(); - List acceptsRecall = new ArrayList<>(); - for(Node a: accepts) { - System.out.println("====================="); - markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph2(a, estimatedCpdag, trueGraph); - System.out.println("====================="); - - } - for (Node a: rejects) { - System.out.println("====================="); - markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph2(a, estimatedCpdag, trueGraph); - System.out.println("====================="); - } } }