diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java index 11ef450002..1889577375 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java @@ -20,6 +20,7 @@ import java.io.BufferedWriter; import java.io.FileWriter; import java.io.IOException; +import java.lang.reflect.Array; import java.text.DecimalFormat; import java.text.NumberFormat; import java.util.*; @@ -27,6 +28,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.Future; +import java.util.stream.Collectors; /** * Checks whether a graph is Markov given a data set. First, a list of m-separation predictions are made for each pair @@ -271,27 +273,27 @@ public List getLocalPValues(IndependenceTest independenceTest, List> getLocalPValues(IndependenceTest independenceTest, List facts, Double shuffleThreshold) { - // Call pvalue function on each item, only include the non-null ones. + // Shuffle to generate more data from the same graph. + int shuffleTimes = (int) Math.ceil(1 / shuffleThreshold); // pVals is a list of lists of the p values for each shuffled results. List> pVals_list = new ArrayList<>(); - for (IndependenceFact f : facts) { - Double pV; - // For now, check if the test is FisherZ test. - if (independenceTest instanceof IndTestFisherZ) { - // Shuffle to generate more data from the same graph. - int shuffleTimes = (int) Math.ceil(1 / shuffleThreshold); - List pVals = new ArrayList<>(); - for (int i = 0; i < shuffleTimes; i++) { - List rows = getSubsampleRows(shuffleThreshold); // Default as 0.5 - ((RowsSettable) independenceTest).setRows(rows); // FisherZ will only calc pvalues to those rows + for (int i = 0; i < shuffleTimes; i++) { + List rows = getSubsampleRows(shuffleThreshold); // Default as 0.5 + ((RowsSettable) independenceTest).setRows(rows); // the test will only calc pvalues to those rows + // call pvalue function on each item, only include the non-null ones + List pVals = new ArrayList<>(); + for (IndependenceFact f : facts) { + Double pV; + // For now, check if the test is FisherZ test. + if (independenceTest instanceof IndTestFisherZ) { pV = ((IndTestFisherZ) independenceTest).getPValue(f.getX(), f.getY(), f.getZ()); pVals.add(pV); + } else if (independenceTest instanceof IndTestChiSquare) { + pV = ((IndTestChiSquare) independenceTest).getPValue(f.getX(), f.getY(), f.getZ()); + if (pV != null) pVals.add(pV); } - pVals_list.add(pVals); - } else if (independenceTest instanceof IndTestChiSquare) { - pV = ((IndTestChiSquare) independenceTest).getPValue(f.getX(), f.getY(), f.getZ()); - if (pV != null) pVals_list.add(Arrays.asList(pV)); } + pVals_list.add(pVals); } return pVals_list; } @@ -326,13 +328,15 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(Ind List localIndependenceFacts = getLocalIndependenceFacts(x); // All local nodes' p-values for node x List> shuffledlocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold); - for (List localPValues: shuffledlocalPValues) { - Double ADTest = checkAgainstAndersonDarlingTest(localPValues); // P value obtained from AD test - if (ADTest <= threshold) { - rejects.add(x); - } else { - accepts.add(x); - } + // TODO VBC: what should we do for cases when ADTest is NaN and ∞ ? + List flatList = shuffledlocalPValues.stream() + .flatMap(List::stream) + .collect(Collectors.toList()); + Double ADTestPValue = checkAgainstAndersonDarlingTest(flatList); + if (ADTestPValue <= threshold) { + rejects.add(x); + } else { + accepts.add(x); } } accepts_rejects.add(accepts); @@ -390,38 +394,38 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot Double ahr = ap_ar_ahp_ahr.get(3); // All local nodes' p-values for node x. List> shuffledlocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold); // shuffleThreshold default to be 0.5 - for (List localPValues: shuffledlocalPValues) { - // P value obtained from AD test - Double ADTestPValue = checkAgainstAndersonDarlingTest(localPValues); - // TODO VBC: what should we do for cases when ADTest is NaN and ∞ ? - if (ADTestPValue <= threshold) { - rejects.add(x); - if (!Double.isNaN(ap)) { - rejects_AdjP_ADTestP.add(Arrays.asList(ap, ADTestPValue)); - } - if (!Double.isNaN(ar)) { - rejects_AdjR_ADTestP.add(Arrays.asList(ap, ADTestPValue)); - } - if (!Double.isNaN(ahp)) { - rejects_AHP_ADTestP.add(Arrays.asList(ap, ADTestPValue)); - } - if (!Double.isNaN(ahr)) { - rejects_AHR_ADTestP.add(Arrays.asList(ap, ADTestPValue)); - } - } else { - accepts.add(x); - if (!Double.isNaN(ap)) { - accepts_AdjP_ADTestP.add(Arrays.asList(ap, ADTestPValue)); - } - if (!Double.isNaN(ar)) { - accepts_AdjR_ADTestP.add(Arrays.asList(ap, ADTestPValue)); - } - if (!Double.isNaN(ahp)) { - accepts_AHP_ADTestP.add(Arrays.asList(ap, ADTestPValue)); - } - if (!Double.isNaN(ahr)) { - accepts_AHR_ADTestP.add(Arrays.asList(ap, ADTestPValue)); - } + List flatList = shuffledlocalPValues.stream() + .flatMap(List::stream) + .collect(Collectors.toList()); + Double ADTestPValue = checkAgainstAndersonDarlingTest(flatList); + // TODO VBC: what should we do for cases when ADTest is NaN and ∞ ? + if (ADTestPValue <= threshold) { + rejects.add(x); + if (!Double.isNaN(ap)) { + rejects_AdjP_ADTestP.add(Arrays.asList(ap, ADTestPValue)); + } + if (!Double.isNaN(ar)) { + rejects_AdjR_ADTestP.add(Arrays.asList(ap, ADTestPValue)); + } + if (!Double.isNaN(ahp)) { + rejects_AHP_ADTestP.add(Arrays.asList(ap, ADTestPValue)); + } + if (!Double.isNaN(ahr)) { + rejects_AHR_ADTestP.add(Arrays.asList(ap, ADTestPValue)); + } + } else { + accepts.add(x); + if (!Double.isNaN(ap)) { + accepts_AdjP_ADTestP.add(Arrays.asList(ap, ADTestPValue)); + } + if (!Double.isNaN(ar)) { + accepts_AdjR_ADTestP.add(Arrays.asList(ap, ADTestPValue)); + } + if (!Double.isNaN(ahp)) { + accepts_AHP_ADTestP.add(Arrays.asList(ap, ADTestPValue)); + } + if (!Double.isNaN(ahr)) { + accepts_AHR_ADTestP.add(Arrays.asList(ap, ADTestPValue)); } } } @@ -531,26 +535,26 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot Double lgr = lgp_lgr.get(1); // All local nodes' p-values for node x. List> shuffledlocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold); // shuffleThreshold default to be 0.5 - for (List localPValues: shuffledlocalPValues) { - // P value obtained from AD test - Double ADTestPValue = checkAgainstAndersonDarlingTest(localPValues); - // TODO VBC: what should we do for cases when ADTest is NaN and ∞ ? - if (ADTestPValue <= threshold) { - rejects.add(x); - if (!Double.isNaN(lgp)) { - rejects_LGP_ADTestP.add(Arrays.asList(lgp, ADTestPValue)); - } - if (!Double.isNaN(lgr)) { - rejects_LGR_ADTestP.add(Arrays.asList(lgr, ADTestPValue)); - } - } else { - accepts.add(x); - if (!Double.isNaN(lgp)) { - accepts_LGP_ADTestP.add(Arrays.asList(lgp, ADTestPValue)); - } - if (!Double.isNaN(lgr)) { - accepts_LGR_ADTestP.add(Arrays.asList(lgr, ADTestPValue)); - } + List flatList = shuffledlocalPValues.stream() + .flatMap(List::stream) + .collect(Collectors.toList()); + Double ADTestPValue = checkAgainstAndersonDarlingTest(flatList); + // TODO VBC: what should we do for cases when ADTest is NaN and ∞ ? + if (ADTestPValue <= threshold) { + rejects.add(x); + if (!Double.isNaN(lgp)) { + rejects_LGP_ADTestP.add(Arrays.asList(lgp, ADTestPValue)); + } + if (!Double.isNaN(lgr)) { + rejects_LGR_ADTestP.add(Arrays.asList(lgr, ADTestPValue)); + } + } else { + accepts.add(x); + if (!Double.isNaN(lgp)) { + accepts_LGP_ADTestP.add(Arrays.asList(lgp, ADTestPValue)); + } + if (!Double.isNaN(lgr)) { + accepts_LGR_ADTestP.add(Arrays.asList(lgr, ADTestPValue)); } } } diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java index 58b7fbd227..2742a758bf 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java @@ -113,9 +113,9 @@ public void test2() { @Test public void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket() { - Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false); +// Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false); // TODO VBC: Also check different dense graph. -// Graph trueGraph = RandomGraph.randomDag(20, 0, 40, 100, 100, 100, false); + Graph trueGraph = RandomGraph.randomDag(20, 0, 40, 100, 100, 100, false); System.out.println("Test True Graph: " + trueGraph); System.out.println("Test True Graph size: " + trueGraph.getNodes().size()); @@ -133,7 +133,7 @@ public void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket() { IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05); MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET); // List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05); - List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData(fisherZTest, estimatedCpdag, trueGraph, 0.05, 0.3); + List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData(fisherZTest, estimatedCpdag, trueGraph, 0.05, 0.5); List accepts = accepts_rejects.get(0); List rejects = accepts_rejects.get(1); System.out.println("Accepts size: " + accepts.size());