Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed shuffle function and combine all shuffled independence test p vals for a target node into a flat list to feed into Anderson Darling Test #1781

Merged
merged 2 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 78 additions & 74 deletions tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.lang.reflect.Array;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.stream.Collectors;

/**
* Checks whether a graph is Markov given a data set. First, a list of m-separation predictions are made for each pair
Expand Down Expand Up @@ -271,27 +273,27 @@ public List<Double> getLocalPValues(IndependenceTest independenceTest, List<Inde
* @return
*/
public List<List<Double>> getLocalPValues(IndependenceTest independenceTest, List<IndependenceFact> facts, Double shuffleThreshold) {
// Call pvalue function on each item, only include the non-null ones.
// Shuffle to generate more data from the same graph.
int shuffleTimes = (int) Math.ceil(1 / shuffleThreshold);
// pVals is a list of lists of the p values for each shuffled results.
List<List<Double>> pVals_list = new ArrayList<>();
for (IndependenceFact f : facts) {
Double pV;
// For now, check if the test is FisherZ test.
if (independenceTest instanceof IndTestFisherZ) {
// Shuffle to generate more data from the same graph.
int shuffleTimes = (int) Math.ceil(1 / shuffleThreshold);
List<Double> pVals = new ArrayList<>();
for (int i = 0; i < shuffleTimes; i++) {
List<Integer> rows = getSubsampleRows(shuffleThreshold); // Default as 0.5
((RowsSettable) independenceTest).setRows(rows); // FisherZ will only calc pvalues to those rows
for (int i = 0; i < shuffleTimes; i++) {
List<Integer> rows = getSubsampleRows(shuffleThreshold); // Default as 0.5
((RowsSettable) independenceTest).setRows(rows); // the test will only calc pvalues to those rows
// call pvalue function on each item, only include the non-null ones
List<Double> pVals = new ArrayList<>();
for (IndependenceFact f : facts) {
Double pV;
// For now, check if the test is FisherZ test.
if (independenceTest instanceof IndTestFisherZ) {
pV = ((IndTestFisherZ) independenceTest).getPValue(f.getX(), f.getY(), f.getZ());
pVals.add(pV);
} else if (independenceTest instanceof IndTestChiSquare) {
pV = ((IndTestChiSquare) independenceTest).getPValue(f.getX(), f.getY(), f.getZ());
if (pV != null) pVals.add(pV);
}
pVals_list.add(pVals);
} else if (independenceTest instanceof IndTestChiSquare) {
pV = ((IndTestChiSquare) independenceTest).getPValue(f.getX(), f.getY(), f.getZ());
if (pV != null) pVals_list.add(Arrays.asList(pV));
}
pVals_list.add(pVals);
}
return pVals_list;
}
Expand Down Expand Up @@ -326,13 +328,15 @@ public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(Ind
List<IndependenceFact> localIndependenceFacts = getLocalIndependenceFacts(x);
// All local nodes' p-values for node x
List<List<Double>> shuffledlocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold);
for (List<Double> localPValues: shuffledlocalPValues) {
Double ADTest = checkAgainstAndersonDarlingTest(localPValues); // P value obtained from AD test
if (ADTest <= threshold) {
rejects.add(x);
} else {
accepts.add(x);
}
// TODO VBC: what should we do for cases when ADTest is NaN and ∞ ?
List<Double> flatList = shuffledlocalPValues.stream()
.flatMap(List::stream)
.collect(Collectors.toList());
Double ADTestPValue = checkAgainstAndersonDarlingTest(flatList);
if (ADTestPValue <= threshold) {
rejects.add(x);
} else {
accepts.add(x);
}
}
accepts_rejects.add(accepts);
Expand Down Expand Up @@ -390,38 +394,38 @@ public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot
Double ahr = ap_ar_ahp_ahr.get(3);
// All local nodes' p-values for node x.
List<List<Double>> shuffledlocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold); // shuffleThreshold default to be 0.5
for (List<Double> localPValues: shuffledlocalPValues) {
// P value obtained from AD test
Double ADTestPValue = checkAgainstAndersonDarlingTest(localPValues);
// TODO VBC: what should we do for cases when ADTest is NaN and ∞ ?
if (ADTestPValue <= threshold) {
rejects.add(x);
if (!Double.isNaN(ap)) {
rejects_AdjP_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ar)) {
rejects_AdjR_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ahp)) {
rejects_AHP_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ahr)) {
rejects_AHR_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
} else {
accepts.add(x);
if (!Double.isNaN(ap)) {
accepts_AdjP_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ar)) {
accepts_AdjR_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ahp)) {
accepts_AHP_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ahr)) {
accepts_AHR_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
List<Double> flatList = shuffledlocalPValues.stream()
.flatMap(List::stream)
.collect(Collectors.toList());
Double ADTestPValue = checkAgainstAndersonDarlingTest(flatList);
// TODO VBC: what should we do for cases when ADTest is NaN and ∞ ?
if (ADTestPValue <= threshold) {
rejects.add(x);
if (!Double.isNaN(ap)) {
rejects_AdjP_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ar)) {
rejects_AdjR_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ahp)) {
rejects_AHP_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ahr)) {
rejects_AHR_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
} else {
accepts.add(x);
if (!Double.isNaN(ap)) {
accepts_AdjP_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ar)) {
accepts_AdjR_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ahp)) {
accepts_AHP_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ahr)) {
accepts_AHR_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
}
}
Expand Down Expand Up @@ -531,26 +535,26 @@ public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot
Double lgr = lgp_lgr.get(1);
// All local nodes' p-values for node x.
List<List<Double>> shuffledlocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold); // shuffleThreshold default to be 0.5
for (List<Double> localPValues: shuffledlocalPValues) {
// P value obtained from AD test
Double ADTestPValue = checkAgainstAndersonDarlingTest(localPValues);
// TODO VBC: what should we do for cases when ADTest is NaN and ∞ ?
if (ADTestPValue <= threshold) {
rejects.add(x);
if (!Double.isNaN(lgp)) {
rejects_LGP_ADTestP.add(Arrays.asList(lgp, ADTestPValue));
}
if (!Double.isNaN(lgr)) {
rejects_LGR_ADTestP.add(Arrays.asList(lgr, ADTestPValue));
}
} else {
accepts.add(x);
if (!Double.isNaN(lgp)) {
accepts_LGP_ADTestP.add(Arrays.asList(lgp, ADTestPValue));
}
if (!Double.isNaN(lgr)) {
accepts_LGR_ADTestP.add(Arrays.asList(lgr, ADTestPValue));
}
List<Double> flatList = shuffledlocalPValues.stream()
.flatMap(List::stream)
.collect(Collectors.toList());
Double ADTestPValue = checkAgainstAndersonDarlingTest(flatList);
// TODO VBC: what should we do for cases when ADTest is NaN and ∞ ?
if (ADTestPValue <= threshold) {
rejects.add(x);
if (!Double.isNaN(lgp)) {
rejects_LGP_ADTestP.add(Arrays.asList(lgp, ADTestPValue));
}
if (!Double.isNaN(lgr)) {
rejects_LGR_ADTestP.add(Arrays.asList(lgr, ADTestPValue));
}
} else {
accepts.add(x);
if (!Double.isNaN(lgp)) {
accepts_LGP_ADTestP.add(Arrays.asList(lgp, ADTestPValue));
}
if (!Double.isNaN(lgr)) {
accepts_LGR_ADTestP.add(Arrays.asList(lgr, ADTestPValue));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ public void test2() {

@Test
public void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket() {
Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
// Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
// TODO VBC: Also check different dense graph.
// Graph trueGraph = RandomGraph.randomDag(20, 0, 40, 100, 100, 100, false);
Graph trueGraph = RandomGraph.randomDag(20, 0, 40, 100, 100, 100, false);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph size: " + trueGraph.getNodes().size());

Expand All @@ -133,7 +133,7 @@ public void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket() {
IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05);
MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET);
// List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05);
List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData(fisherZTest, estimatedCpdag, trueGraph, 0.05, 0.3);
List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData(fisherZTest, estimatedCpdag, trueGraph, 0.05, 0.5);
List<Node> accepts = accepts_rejects.get(0);
List<Node> rejects = accepts_rejects.get(1);
System.out.println("Accepts size: " + accepts.size());
Expand Down