Skip to content

Commit

Permalink
Merge pull request #1775 from cmu-phil/vbc-05-23
Browse files Browse the repository at this point in the history
Introduce plot data collection for different confusion statistics
  • Loading branch information
jdramsey authored May 23, 2024
2 parents 7795487 + de7202a commit 9bc1256
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 31 deletions.
174 changes: 167 additions & 7 deletions tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
import org.apache.commons.math3.util.Pair;
import org.jetbrains.annotations.NotNull;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.*;
Expand Down Expand Up @@ -313,24 +316,166 @@ public Double checkAgainstAndersonDarlingTest(List<Double> pValues) {
* @return A list containing two lists: the first list contains the accepted nodes and the second list contains the
* rejected nodes.
*/
public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(IndependenceTest independenceTest, Graph graph, Double threshold) {
public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(IndependenceTest independenceTest, Graph graph, Double threshold, Double shuffleThreshold) {
// When calling, default reject null as <=0.05
List<List<Node>> accepts_rejects = new ArrayList<>();
List<Node> accepts = new ArrayList<>();
List<Node> rejects = new ArrayList<>();
List<Node> allNodes = graph.getNodes();
for (Node x : allNodes) {
List<IndependenceFact> localIndependenceFacts = getLocalIndependenceFacts(x);
List<Double> localPValues = getLocalPValues(independenceTest, localIndependenceFacts);
Double ADTest = checkAgainstAndersonDarlingTest(localPValues);
if (ADTest <= threshold) {
rejects.add(x);
} else {
accepts.add(x);
// All local nodes' p-values for node x
List<List<Double>> shuffledlocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold);
for (List<Double> localPValues: shuffledlocalPValues) {
Double ADTest = checkAgainstAndersonDarlingTest(localPValues); // P value obtained from AD test
if (ADTest <= threshold) {
rejects.add(x);
} else {
accepts.add(x);
}
}
}
accepts_rejects.add(accepts);
accepts_rejects.add(rejects);
return accepts_rejects;
}

/**
 * Classifies every node as accepted or rejected by the Anderson-Darling test on its local p-values and, as a
 * side effect, writes eight CSV files pairing each Markov blanket confusion statistic with the
 * Anderson-Darling p-value.
 * <p>
 * The files are written to the process working directory and are named
 * {@code <accepts|rejects>_<AdjP|AdjR|AHP|AHR>_ADTestP_data.csv}. Each row is {@code statistic,adTestPValue},
 * both formatted with two decimals. Rows whose statistic is NaN are skipped. A file that cannot be written is
 * reported on standard error and does not abort the method.
 *
 * @param independenceTest the independence test used to compute the local p-values
 * @param estimatedCpdag   the estimated graph whose Markov blankets are scored
 * @param trueGraph        the true graph the estimated graph is scored against
 * @param threshold        an Anderson-Darling p-value less than or equal to this value rejects the node
 * @param shuffleThreshold forwarded unchanged to {@code getLocalPValues}
 * @return a list of two lists: index 0 holds the accepted nodes, index 1 the rejected nodes; a node is added
 * once per shuffled p-value list, so it may occur several times
 */
public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData(IndependenceTest independenceTest, Graph estimatedCpdag, Graph trueGraph, Double threshold, Double shuffleThreshold) {
    // When calling, default reject null as <=0.05
    List<Node> accepts = new ArrayList<>();
    List<Node> rejects = new ArrayList<>();
    // NOTE(review): `graph` is not a parameter of this method, so it is presumably an instance field; the
    // node set is taken from it rather than from estimatedCpdag or trueGraph - confirm this is intended.
    List<Node> allNodes = graph.getNodes();

    // Statistic labels, in the order returned by getPrecisionAndRecallOnMarkovBlanketGraphPlotData.
    String[] statisticLabels = {"AdjP", "AdjR", "AHP", "AHR"};
    String[] outcomes = {"accepts", "rejects"};

    // One list of (statistic, AD test p-value) rows per output file, keyed by file name. Deriving the names
    // from the labels in one place keeps the collection step and the writing step in agreement.
    Map<String, List<List<Double>>> rowsByFileName = new LinkedHashMap<>();
    for (String outcome : outcomes) {
        for (String label : statisticLabels) {
            rowsByFileName.put(plotDataFileName(outcome, label), new ArrayList<>());
        }
    }

    // Classify nodes into accepts and rejects based on the AD test result, and collect plot rows accordingly.
    for (Node x : allNodes) {
        List<IndependenceFact> localIndependenceFacts = getLocalIndependenceFacts(x);
        List<Double> statistics = getPrecisionAndRecallOnMarkovBlanketGraphPlotData(x, estimatedCpdag, trueGraph);
        // All shuffled sets of local p-values for node x.
        List<List<Double>> shuffledLocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold);
        for (List<Double> localPValues : shuffledLocalPValues) {
            // P-value obtained from the Anderson-Darling test.
            Double adTestP = checkAgainstAndersonDarlingTest(localPValues);
            // TODO VBC: what should we do for cases when ADTest is NaN and ∞ ?
            // As written, a NaN p-value compares false against the threshold and is counted as accepted.
            boolean rejected = adTestP <= threshold;
            if (rejected) {
                rejects.add(x);
            } else {
                accepts.add(x);
            }
            String outcome = rejected ? "rejects" : "accepts";
            for (int i = 0; i < statisticLabels.length; i++) {
                Double statistic = statistics.get(i);
                // Each file is paired with its own statistic; undefined (NaN) statistics are not plotted.
                if (!Double.isNaN(statistic)) {
                    rowsByFileName.get(plotDataFileName(outcome, statisticLabels[i])).add(Arrays.asList(statistic, adTestP));
                }
            }
        }
    }

    // Locale-independent symbols so the decimal separator is always '.', never the CSV delimiter ','.
    NumberFormat nf = new DecimalFormat("0.00", java.text.DecimalFormatSymbols.getInstance(Locale.ROOT));

    // Write into data files.
    for (Map.Entry<String, List<List<Double>>> entry : rowsByFileName.entrySet()) {
        String fileName = entry.getKey();
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(fileName))) {
            for (List<Double> row : entry.getValue()) {
                writer.write(nf.format(row.get(0)) + "," + nf.format(row.get(1)) + "\n");
            }
            System.out.println("Successfully written to " + fileName);
        } catch (IOException e) {
            // Plot data is best-effort: report the failing file and continue with the remaining ones.
            System.err.println("Failed to write " + fileName);
            e.printStackTrace();
        }
    }

    // Result layout: index 0 = accepted nodes, index 1 = rejected nodes.
    List<List<Node>> acceptsRejects = new ArrayList<>();
    acceptsRejects.add(accepts);
    acceptsRejects.add(rejects);
    return acceptsRejects;
}

// Builds the CSV file name for one outcome ("accepts"/"rejects") and one statistic label.
private static String plotDataFileName(String outcome, String statisticLabel) {
    return outcome + "_" + statisticLabel + "_ADTestP_data.csv";
}

Expand Down Expand Up @@ -362,6 +507,21 @@ public void getPrecisionAndRecallOnMarkovBlanketGraph(Node x, Graph estimatedGra
" ArrowHeadPrecision = " + nf.format(ahp) + " ArrowHeadRecall = " + nf.format(ahr));
}

/**
 * Computes adjacency and arrowhead precision and recall on the Markov blanket subgraph of a node, and returns
 * the raw values. Both Markov blanket subgraphs are printed to standard output.
 *
 * @param x              the target node whose Markov blanket subgraphs are compared
 * @param estimatedGraph the estimated graph
 * @param trueGraph      the true graph
 * @return a fixed-size list holding, in order: adjacency precision, adjacency recall, arrowhead precision,
 * arrowhead recall
 */
public List<Double> getPrecisionAndRecallOnMarkovBlanketGraphPlotData(Node x, Graph estimatedGraph, Graph trueGraph) {
    // The true graph's structure re-expressed over the estimated graph's node objects, so that both
    // Markov blanket subgraphs refer to the same nodes.
    Graph trueOverEstimatedNodes = GraphUtils.replaceNodes(trueGraph, estimatedGraph.getNodes());

    Graph trueBlanket = GraphUtils.getMarkovBlanketSubgraphWithTargetNode(trueOverEstimatedNodes, x);
    System.out.println("xMBLookupGraph:" + trueBlanket);

    Graph estimatedBlanket = GraphUtils.getMarkovBlanketSubgraphWithTargetNode(estimatedGraph, x);
    System.out.println("xMBEstimatedGraph:" + estimatedBlanket);

    // Slots: 0 = adjacency precision, 1 = adjacency recall, 2 = arrowhead precision, 3 = arrowhead recall.
    double[] scores = new double[4];
    scores[0] = new AdjacencyPrecision().getValue(trueBlanket, estimatedBlanket, null);
    scores[1] = new AdjacencyRecall().getValue(trueBlanket, estimatedBlanket, null);
    scores[2] = new ArrowheadPrecision().getValue(trueBlanket, estimatedBlanket, null);
    scores[3] = new ArrowheadRecall().getValue(trueBlanket, estimatedBlanket, null);

    return Arrays.asList(scores[0], scores[1], scores[2], scores[3]);
}

/**
* Calculates the precision and recall using LocalGraphConfusion
* (which calculates the combination of Adjacency and ArrowHead) on the Markov Blanket graph for a given node.
Expand Down
Loading

0 comments on commit 9bc1256

Please sign in to comment.