Skip to content

Commit

Permalink
Merge pull request #1771 from cmu-phil/vbc-05-16-2
Browse files Browse the repository at this point in the history
Introducing LocalGraphConfusion and its corresponding Precision and Recall classes
  • Loading branch information
jdramsey authored May 17, 2024
2 parents ea4507d + 9f1adaa commit ae16a35
Show file tree
Hide file tree
Showing 5 changed files with 348 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package edu.cmu.tetrad.algcomparison.statistic;

import edu.cmu.tetrad.algcomparison.statistic.utils.LocalGraphConfusion;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.graph.Graph;

/**
 * Precision statistic derived from a {@link LocalGraphConfusion}: TP / (TP + FP).
 */
public class LocalGraphPrecision implements Statistic {

    /**
     * @return the short column label for this statistic, "LGP".
     */
    @Override
    public String getAbbreviation() {
        return "LGP";
    }

    /**
     * @return the human-readable name of this statistic.
     */
    @Override
    public String getDescription() {
        return "Local Graph Precision";
    }

    /**
     * Builds a {@link LocalGraphConfusion} for the two graphs and returns TP / (TP + FP).
     *
     * @param trueGraph The true graph.
     * @param estGraph  The estimated graph.
     * @param dataModel Not used by this statistic.
     * @return the precision; NaN when TP + FP is zero, because the division is carried out in floating point.
     */
    @Override
    public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) {
        final LocalGraphConfusion confusion = new LocalGraphConfusion(trueGraph, estGraph);
        final int truePositives = confusion.getTp();
        final int falsePositives = confusion.getFp();
        final int predictedPositives = truePositives + falsePositives;
        return truePositives / (double) predictedPositives;
    }

    /**
     * @param value A value previously returned by {@link #getValue}.
     * @return the same value, unchanged.
     */
    @Override
    public double getNormValue(double value) {
        return value;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package edu.cmu.tetrad.algcomparison.statistic;

import edu.cmu.tetrad.algcomparison.statistic.utils.LocalGraphConfusion;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.graph.Graph;

/**
 * Recall statistic derived from a {@link LocalGraphConfusion}: TP / (TP + FN).
 */
public class LocalGraphRecall implements Statistic {

    /**
     * @return the short column label for this statistic, "LGR".
     */
    @Override
    public String getAbbreviation() {
        return "LGR";
    }

    /**
     * @return the human-readable name of this statistic.
     */
    @Override
    public String getDescription() {
        return "Local Graph Recall";
    }

    /**
     * Builds a {@link LocalGraphConfusion} for the two graphs and returns TP / (TP + FN).
     *
     * @param trueGraph The true graph.
     * @param estGraph  The estimated graph.
     * @param dataModel Not used by this statistic.
     * @return the recall; NaN when TP + FN is zero, because the division is carried out in floating point.
     */
    @Override
    public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) {
        final LocalGraphConfusion confusion = new LocalGraphConfusion(trueGraph, estGraph);
        final int truePositives = confusion.getTp();
        final int falseNegatives = confusion.getFn();
        final int actualPositives = truePositives + falseNegatives;
        return truePositives / (double) actualPositives;
    }

    /**
     * @param value A value previously returned by {@link #getValue}.
     * @return the same value, unchanged.
     */
    @Override
    public double getNormValue(double value) {
        return value;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
package edu.cmu.tetrad.algcomparison.statistic.utils;

import edu.cmu.tetrad.graph.*;

import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * A confusion matrix for a local graph accuracy check, i.e. TP, FP, TN and FN counts that combine an adjacency
 * comparison and an arrowhead (orientation) comparison between a true graph and an estimated graph.
 */
public class LocalGraphConfusion {
    /**
     * The true positive count.
     */
    private int tp;

    /**
     * The true negative count.
     */
    private int tn;

    /**
     * The false positive count.
     */
    private int fp;

    /**
     * The false negative count.
     */
    private int fn;

    /**
     * Constructs a new LocalGraphConfusion object from the given graphs. Adjacency and orientation agreement
     * are both folded into the same four counters.
     *
     * @param trueGraph The true graph
     * @param estGraph  The estimated graph
     */
    public LocalGraphConfusion(Graph trueGraph, Graph estGraph) {
        this.tp = 0;
        this.tn = 0;
        this.fp = 0;
        this.fn = 0;

        // STEP 0: Create lookups for both the true graph and the estimated graph.
        // trueGraphLookup has the true graph's structure, with node objects replaced by the estimated graph's nodes.
        Graph trueGraphLookup = GraphUtils.replaceNodes(trueGraph, estGraph.getNodes());
        // estGraphLookup has the estimated graph's structure, with node objects replaced by the true graph's nodes.
        Graph estGraphLookup = GraphUtils.replaceNodes(estGraph, trueGraph.getNodes());
        // NOTE(review): some queries below pass nodes taken from one graph to a graph built over the other
        // graph's node objects; presumably this relies on the two graphs sharing node objects or on
        // name-based node equality -- confirm against callers.

        // STEP 1: Check for adjacency.
        /*
         *              True
         *            Y      N
         *         ---------------
         *      Y |   TP     FP
         * Est    | --------------
         *      N |   FN     TN
         *         ---------------
         */
        // STEP 1.1: Collect every adjacency of either graph as an undirected edge.
        Set<Edge> allUnoriented = new HashSet<>();
        for (Edge edge : trueGraphLookup.getEdges()) {
            allUnoriented.add(Edges.undirectedEdge(edge.getNode1(), edge.getNode2()));
        }
        for (Edge edge : estGraph.getEdges()) {
            allUnoriented.add(Edges.undirectedEdge(edge.getNode1(), edge.getNode2()));
        }

        // STEP 1.2: Classify each collected adjacency in the confusion matrix.
        for (Edge u : allUnoriented) {
            Node node1 = u.getNode1();
            Node node2 = u.getNode2();
            boolean inEst = estGraph.isAdjacentTo(node1, node2);
            boolean inTrue = trueGraphLookup.isAdjacentTo(node1, node2);

            if (inEst) {
                if (inTrue) {
                    this.tp++;
                } else {
                    this.fp++;
                }
            } else {
                if (inTrue) {
                    this.fn++;
                } else {
                    // The candidates come from the union of both edge sets, so this branch is
                    // presumably never reached.
                    this.tn++;
                }
            }
        }

        // STEP 2: Check for orientation (i.e. arrowheads), so both endpoints of an edge are compared.
        /*
         *                     True
         *            ->        <-       ... (none)
         *        ---------------------------------
         *    ->  |   TP       FP, FN      /   (not counted again; handled in the adjacency step)
         * Est    | -------------------------------
         *    <-  |  FP, FN     TP         /
         *        | -------------------------------
         *    --  |   FN        FN         /
         *        | -------------------------------
         *    ... |   /         /          /
         *        ---------------------------------
         */
        // STEP 2.1: Walk the true graph's edges to count orientation TPs and FNs.
        for (Edge tle : trueGraphLookup.getEdges()) {
            Node node1 = tle.getNode1();
            Node node2 = tle.getNode2();

            // STEP 2.1.1: Endpoints of the corresponding edge in the estimated graph lookup.
            Edge ele = findEdge(estGraphLookup, node1, node2);
            Endpoint ep1Est = null;
            Endpoint ep2Est = null;
            if (ele != null) {
                ep1Est = ele.getProximalEndpoint(node1);
                ep2Est = ele.getProximalEndpoint(node2);
            }

            // STEP 2.1.2: Endpoints of the corresponding edge in the true graph lookup.
            // (The original looked up (node1, node1) here, which is not the edge under inspection.)
            Edge tle2 = findEdge(trueGraphLookup, node1, node2);
            Endpoint ep1True = null;
            Endpoint ep2True = null;
            if (tle2 != null) {
                ep1True = tle2.getProximalEndpoint(node1);
                ep2True = tle2.getProximalEndpoint(node2);
            }

            // STEP 2.1.3: Compare the endpoints; only edges present in both graphs are of interest.
            boolean connected = trueGraph.isAdjacentTo(node1, node2)
                    && estGraph.isAdjacentTo(node1, node2);
            if (connected) {
                if (ep1True == Endpoint.TAIL && ep2True == Endpoint.ARROW) { // True: ->
                    if (ep1Est == Endpoint.TAIL && ep2Est == Endpoint.ARROW) { // Est: ->
                        this.tp++;
                    } else if (ep1Est == Endpoint.ARROW && ep2Est == Endpoint.TAIL) { // Est: <-
                        // The matching FP for the reversed arrowhead is counted in STEP 2.2.
                        this.fn++;
                    } else if (ep1Est == Endpoint.TAIL && ep2Est == Endpoint.TAIL) { // Est: --
                        this.fn++;
                    }
                } else if (ep1True == Endpoint.ARROW && ep2True == Endpoint.TAIL) { // True: <-
                    if (ep1Est == Endpoint.TAIL && ep2Est == Endpoint.ARROW) { // Est: ->
                        // The matching FP for the reversed arrowhead is counted in STEP 2.2.
                        this.fn++;
                    } else if (ep1Est == Endpoint.ARROW && ep2Est == Endpoint.TAIL) { // Est: <-
                        this.tp++;
                    } else if (ep1Est == Endpoint.TAIL && ep2Est == Endpoint.TAIL) { // Est: --
                        this.fn++;
                    }
                }
            }
        }

        // STEP 2.2: Walk the estimated graph's edges, because the estimated graph can carry an arrowhead
        // that is not in the true graph, which should be counted as an FP.
        for (Edge ele : estGraphLookup.getEdges()) {
            Node node1 = ele.getNode1();
            Node node2 = ele.getNode2();

            Edge ele2 = findEdge(estGraphLookup, node1, node2);
            Endpoint ep1Est = null;
            Endpoint ep2Est = null;
            if (ele2 != null) {
                ep1Est = ele2.getProximalEndpoint(node1);
                ep2Est = ele2.getProximalEndpoint(node2);
            }

            // (The original looked up (node1, node1) here, so a true edge oriented the other way
            // could only be reached through the directed-edge fallback.)
            Edge tle = findEdge(trueGraphLookup, node1, node2);
            Endpoint ep1True = null;
            Endpoint ep2True = null;
            if (tle != null) {
                ep1True = tle.getProximalEndpoint(node1);
                ep2True = tle.getProximalEndpoint(node2);
            }

            boolean connected = trueGraph.isAdjacentTo(node1, node2);
            if (connected) {
                if (ep1True == Endpoint.TAIL && ep2True == Endpoint.ARROW) { // True: ->
                    if (ep1Est == Endpoint.ARROW && ep2Est == Endpoint.TAIL) { // Est: <-
                        this.fp++;
                    }
                    // TODO VBC: Question: it seems the <-> case is never encountered here -- is that right?
                } else if (ep1True == Endpoint.ARROW && ep2True == Endpoint.TAIL) { // True: <-
                    if (ep1Est == Endpoint.TAIL && ep2Est == Endpoint.ARROW) { // Est: ->
                        this.fp++;
                    }
                }
            }
        }
    }

    /**
     * Looks up the edge connecting two nodes in a graph: the only edge between them when there is exactly
     * one, otherwise whatever {@code graph.getDirectedEdge(node1, node2)} returns.
     *
     * @param graph The graph to query.
     * @param node1 The first node.
     * @param node2 The second node.
     * @return the edge found; callers handle a null result.
     */
    private static Edge findEdge(Graph graph, Node node1, Node node2) {
        List<Edge> between = graph.getEdges(node1, node2);
        if (between.size() == 1) {
            return between.iterator().next();
        }
        return graph.getDirectedEdge(node1, node2);
    }

    /**
     * @return the true positive count.
     */
    public int getTp() {
        return tp;
    }

    /**
     * @return the true negative count.
     */
    public int getTn() {
        return tn;
    }

    /**
     * @return the false positive count.
     */
    public int getFp() {
        return fp;
    }

    /**
     * @return the false negative count.
     */
    public int getFn() {
        return fn;
    }
}
30 changes: 26 additions & 4 deletions tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.algcomparison.statistic.AdjacencyPrecision;
import edu.cmu.tetrad.algcomparison.statistic.AdjacencyRecall;
import edu.cmu.tetrad.algcomparison.statistic.ArrowheadPrecision;
import edu.cmu.tetrad.algcomparison.statistic.ArrowheadRecall;
import edu.cmu.tetrad.algcomparison.statistic.*;
import edu.cmu.tetrad.data.GeneralAndersonDarlingTest;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Graph;
Expand Down Expand Up @@ -332,6 +329,31 @@ public void getPrecisionAndRecallOnMarkovBlanketGraph(Node x, Graph estimatedGra
" ArrowHeadPrecision = " + nf.format(ahp) + " ArrowHeadRecall = " + nf.format(ahr));
}

/**
 * Computes precision and recall via LocalGraphConfusion (which combines adjacency and arrowhead checks)
 * on the Markov blanket subgraphs around a given node, and prints the two subgraphs and both statistics
 * to the console.
 *
 * @param x              The target node.
 * @param estimatedGraph The estimated graph.
 * @param trueGraph      The true graph.
 */
public void getPrecisionAndRecallOnMarkovBlanketGraph2(Node x, Graph estimatedGraph, Graph trueGraph) {
    // Same structure as the true graph, but expressed over the estimated graph's node objects.
    Graph trueOverEstNodes = GraphUtils.replaceNodes(trueGraph, estimatedGraph.getNodes());

    Graph trueBlanket = GraphUtils.getMarkovBlanketSubgraphWithTargetNode(trueOverEstNodes, x);
    System.out.println("xMBLookupGraph:" + trueBlanket);

    Graph estimatedBlanket = GraphUtils.getMarkovBlanketSubgraphWithTargetNode(estimatedGraph, x);
    System.out.println("xMBEstimatedGraph:" + estimatedBlanket);

    LocalGraphPrecision precisionStatistic = new LocalGraphPrecision();
    double precision = precisionStatistic.getValue(trueBlanket, estimatedBlanket, null);
    LocalGraphRecall recallStatistic = new LocalGraphRecall();
    double recall = recallStatistic.getValue(trueBlanket, estimatedBlanket, null);

    NumberFormat twoDecimals = new DecimalFormat("0.00");
    StringBuilder report = new StringBuilder();
    report.append("Node ").append(x).append("'s statistics: ").append(" \n");
    report.append(" LocalGraphPrecision = ").append(twoDecimals.format(precision));
    report.append(" LocalGraphRecall = ").append(twoDecimals.format(recall));
    report.append(" \n");
    System.out.println(report);
}

/**
* Returns the variables of the independence test.
*
Expand Down
38 changes: 38 additions & 0 deletions tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java
Original file line number Diff line number Diff line change
Expand Up @@ -446,4 +446,42 @@ public void testNonGaussianCPDAGPrecisionRecallForLocalOnParents() {
}
}

/**
 * Smoke test for the LocalGraphConfusion-based precision/recall report: simulates data from a random
 * DAG, estimates a CPDAG with BOSS, splits the nodes into Anderson-Darling accepts and rejects, and
 * prints the local Markov blanket statistics for every node. It makes no assertions; it only checks
 * that the pipeline runs.
 */
@Test
public void testDAGPrecisionRecall2ForLocalOnMarkovBlanket() {
    Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
    System.out.println("Test True Graph: " + trueGraph);
    System.out.println("Test True Graph size: " + trueGraph.getNodes().size());

    SemPm pm = new SemPm(trueGraph);
    SemIm im = new SemIm(pm, new Parameters());
    DataSet data = im.simulateData(1000, false);
    SemBicScore score = new SemBicScore(data, false);
    score.setPenaltyDiscount(2);
    Graph estimatedCpdag = new PermutationSearch(new Boss(score)).search();
    System.out.println("Test Estimated CPDAG Graph: " + estimatedCpdag);
    System.out.println("~~~~~~~~~~~~~~~~~~~~~~~~~~~~");

    IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05);
    MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET);
    List<List<Node>> acceptsRejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05);
    // Index 0 holds the accepted nodes and index 1 the rejected ones.
    List<Node> accepts = acceptsRejects.get(0);
    List<Node> rejects = acceptsRejects.get(1);
    System.out.println("Accepts size: " + accepts.size());
    System.out.println("Rejects size: " + rejects.size());

    // Report accepted nodes first, then rejected nodes.
    for (Node a : accepts) {
        System.out.println("=====================");
        markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph2(a, estimatedCpdag, trueGraph);
        System.out.println("=====================");
    }
    for (Node r : rejects) {
        System.out.println("=====================");
        markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph2(r, estimatedCpdag, trueGraph);
        System.out.println("=====================");
    }
}

}

0 comments on commit ae16a35

Please sign in to comment.