From 847853cfd399fc8d42b0d2b98f1b60775767cabd Mon Sep 17 00:00:00 2001 From: chirayukong Date: Wed, 15 Feb 2017 16:45:05 -0500 Subject: [PATCH] Updated tetrad-lib test files --- .../test/ExploreAutisticsNeurotypicals.java | 4 +- .../TestConditionalGaussianSimulation.java | 110 +++++++++++++ .../test/TestConditionalLikelihood.java | 73 --------- .../test/java/edu/cmu/tetrad/test/TestDM.java | 4 +- .../test/{TestFgs.java => TestFges.java} | 147 +++++++++--------- .../java/edu/cmu/tetrad/test/TestGFci.java | 67 +++----- .../cmu/tetrad/test/TestGeneralizedSem.java | 4 +- .../edu/cmu/tetrad/test/TestGraphUtils.java | 67 +++++++- .../cmu/tetrad/test/TestIndTestFisherZ.java | 91 ++++++++++- .../cmu/tetrad/test/TestLingamPattern.java | 2 +- .../test/java/edu/cmu/tetrad/test/TestPc.java | 8 +- .../java/edu/cmu/tetrad/test/TestPurify.java | 4 +- .../edu/cmu/tetrad/test/TestStatUtils.java | 32 ++++ 13 files changed, 402 insertions(+), 211 deletions(-) create mode 100644 tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestConditionalGaussianSimulation.java delete mode 100644 tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestConditionalLikelihood.java rename tetrad-lib/src/test/java/edu/cmu/tetrad/test/{TestFgs.java => TestFges.java} (93%) diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/ExploreAutisticsNeurotypicals.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/ExploreAutisticsNeurotypicals.java index 74d15c2b63..59c9d36e53 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/ExploreAutisticsNeurotypicals.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/ExploreAutisticsNeurotypicals.java @@ -26,7 +26,7 @@ import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.GraphUtils; import edu.cmu.tetrad.graph.Node; -import edu.cmu.tetrad.search.Fgs; +import edu.cmu.tetrad.search.Fges; import edu.cmu.tetrad.search.SemBicScore; import edu.cmu.tetrad.util.RandomUtil; import edu.cmu.tetrad.util.TetradMatrix; @@ -91,7 +91,7 @@ private List> runAlgorithm(String path, List> allDatas SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet)); score.setPenaltyDiscount(penaltyDiscount); - Fgs search = new Fgs(score); + Fges search = new Fges(score); search.setVerbose(false); Graph graph = search.search(); GraphUtils.saveGraph(graph, file, false); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestConditionalGaussianSimulation.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestConditionalGaussianSimulation.java new file mode 100644 index 0000000000..db367993ef --- /dev/null +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestConditionalGaussianSimulation.java @@ -0,0 +1,110 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015 by Peter Spirtes, Richard Scheines, Joseph // +// Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetrad.test; + +import edu.cmu.tetrad.algcomparison.Comparison; +import edu.cmu.tetrad.algcomparison.algorithm.Algorithms; +import edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.*; +import edu.cmu.tetrad.algcomparison.graph.RandomForward; +import edu.cmu.tetrad.algcomparison.score.ConditionalGaussianBicScore; +import edu.cmu.tetrad.algcomparison.simulation.*; +import edu.cmu.tetrad.algcomparison.statistic.*; +import edu.cmu.tetrad.util.Parameters; + +/** + * An example script to simulate data and run a comparison analysis on it. + * + * @author jdramsey + */ +public class TestConditionalGaussianSimulation { + + public void test1() { + Parameters parameters = new Parameters(); + + parameters.set("numRuns", 1); + parameters.set("numMeasures", 100); + parameters.set("avgDegree", 4); + parameters.set("sampleSize", 1000); + parameters.set("penaltyDiscount", 2); + + parameters.set("maxDegree", 8); + + parameters.set("numCategories", 2, 3, 4, 5); + parameters.set("percentDiscrete", 50); + + parameters.set("assumeMixed", false); + + parameters.set("intervalBetweenRecordings", 20); + + parameters.set("varLow", 1.); + parameters.set("varHigh", 3.); + parameters.set("coefLow", .5); + parameters.set("coefHigh", 1.5); + parameters.set("coefSymmetric", true); + parameters.set("meanLow", -1); + parameters.set("meanHigh", 1); + + Statistics statistics = new Statistics(); + + statistics.add(new ParameterColumn("numCategories")); + statistics.add(new ParameterColumn("assumeMixed")); + statistics.add(new AdjacencyPrecision()); + statistics.add(new AdjacencyRecall()); + statistics.add(new ArrowheadPrecision()); + statistics.add(new ArrowheadRecall()); + statistics.add(new ElapsedTime()); + + statistics.setWeight("AP", 1.0); + statistics.setWeight("AR", 0.5); + + Algorithms algorithms = new Algorithms(); + + algorithms.add(new Fges(new ConditionalGaussianBicScore())); +// algorithms.add(new PcMax(new ConditionalGaussianLRT())); + + Simulations simulations = new Simulations(); + + simulations.add(new ConditionalGaussianSimulation(new RandomForward())); +// simulations.add(new LeeHastieSimulation(new RandomForward())); + + Comparison comparison = new Comparison(); + + comparison.setShowAlgorithmIndices(true); + comparison.setShowSimulationIndices(false); + comparison.setSortByUtility(false); + comparison.setShowUtilities(false); + comparison.setParallelized(false); + comparison.setSaveGraphs(true); + + comparison.setTabDelimitedTables(true); + + comparison.compareFromSimulations("comparison", simulations, algorithms, statistics, parameters); + } + + public static void main(String...args) { + new TestConditionalGaussianSimulation().test1(); + } +} + + + + diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestConditionalLikelihood.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestConditionalLikelihood.java deleted file mode 100644 index 51ef7f6c5d..0000000000 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestConditionalLikelihood.java +++ /dev/null @@ -1,73 +0,0 @@ -package edu.cmu.tetrad.test; - -import edu.cmu.tetrad.data.ContinuousVariable; -import edu.cmu.tetrad.data.DataSet; -import edu.cmu.tetrad.data.DiscreteVariable; -import edu.cmu.tetrad.graph.EdgeListGraph; -import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.graph.GraphNode; -import edu.cmu.tetrad.graph.Node; -import edu.cmu.tetrad.search.ConditionalGaussianLikelihood; -import edu.cmu.tetrad.sem.GeneralizedSemIm; -import edu.cmu.tetrad.sem.GeneralizedSemPm; -import edu.pitt.csb.mgm.MixedUtils; -import org.junit.Test; - -import java.util.HashMap; - -import static org.junit.Assert.assertEquals; - -/** - * @author jdramsey - */ -public class TestConditionalLikelihood { - - @Test - public void test1() { - - // Make a DAG 1->2 with 1 discrete and 2 continuous. - Graph dag = new EdgeListGraph(); - - Node n1 = new GraphNode("X0"); - Node n2 = new GraphNode("X1"); - - dag.addNode(n1); - dag.addNode(n2); - - dag.addDirectedEdge(n1, n2); - - // Simulate data from it using Lee & Hastie method. - HashMap nd = new HashMap<>(); - - nd.put(n1.getName(), 3); - nd.put(n2.getName(), 0); - - Graph graph = MixedUtils.makeMixedGraph(dag, nd); - GeneralizedSemPm pm = MixedUtils.GaussianCategoricalPm(graph, "Split(-1.5,-.5,.5,1.5)"); - GeneralizedSemIm im = MixedUtils.GaussianCategoricalIm(pm); - DataSet data = MixedUtils.makeMixedData(im.simulateDataAvoidInfinity(100, false), nd); - - // Calculate lik and dof for 1 | 2, 2, 2 | 1, and 1. - ConditionalGaussianLikelihood lik = new ConditionalGaussianLikelihood(data); - - ConditionalGaussianLikelihood.Ret ret1 = lik.getLikelihoodRatio(0, new int[]{1}); - ConditionalGaussianLikelihood.Ret ret2 = lik.getLikelihoodRatio(1, new int[]{}); - ConditionalGaussianLikelihood.Ret ret3 = lik.getLikelihoodRatio(1, new int[]{0}); - ConditionalGaussianLikelihood.Ret ret4 = lik.getLikelihoodRatio(0, new int[]{}); - -// // Print out these likelihoods. -// System.out.println(ret1); -// System.out.println(ret2); -// System.out.println(ret3); -// System.out.println(ret4); -// -// System.out.println(); -// -// // Print sum of 1 | 2 and 2 and sum of 2 | 1 and 1 -// System.out.println("SUM 1, 2 lik = " + (ret1.getLik() + ret2.getLik()) + " dof = " + (ret1.getDof() + ret2.getDof())); -// System.out.println("SUM 3, 4 lik = " + (ret3.getLik() + ret4.getLik()) + " dof = " + (ret3.getDof() + ret4.getDof())); - - assertEquals(ret1.getLik() + ret2.getLik(), ret3.getLik() + ret4.getLik(), 0.001); - assertEquals(ret1.getDof() + ret2.getDof(), ret3.getDof() + ret4.getDof(), 0.001); - } -} diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDM.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDM.java index 4a69911777..28bc23d1d2 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDM.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDM.java @@ -651,7 +651,7 @@ public void test10() { DMSearch search = new DMSearch(); - search.setUseFgs(false); + search.setUseFges(false); search.setInputs(new int[]{0, 1}); search.setOutputs(new int[]{2, 3, 4}); @@ -1206,7 +1206,7 @@ public DMSearch readAndSearchData(String fileLocation, int[] inputs, int[] outpu if (useGES == false) { search.setAlphaPC(.05); - search.setUseFgs(useGES); + search.setUseFges(useGES); search.setData(data); search.setTrueInputs(trueInputs); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFgs.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java similarity index 93% rename from tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFgs.java rename to tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java index 9827f8ec18..9777b287a9 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFgs.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java @@ -35,7 +35,7 @@ import edu.cmu.tetrad.data.*; import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.search.*; -import edu.cmu.tetrad.search.Fgs; +import edu.cmu.tetrad.search.Fges; import edu.cmu.tetrad.search.Pc; import edu.cmu.tetrad.search.PcStable; import edu.cmu.tetrad.search.SemBicScore; @@ -59,7 +59,7 @@ /** * @author Joseph Ramsey */ -public class TestFgs { +public class TestFges { private PrintStream out = System.out; @@ -100,15 +100,15 @@ public void explore1() { SemBicScore score = new SemBicScore(cov); score.setPenaltyDiscount(penaltyDiscount); - Fgs fgs = new Fgs(score); - fgs.setVerbose(false); - fgs.setNumPatternsToStore(0); - fgs.setOut(out); - fgs.setFaithfulnessAssumed(true); -// fgs.setMaxIndegree(1); - fgs.setCycleBound(5); + Fges fges = new Fges(score); + fges.setVerbose(false); + fges.setNumPatternsToStore(0); + fges.setOut(out); + fges.setFaithfulnessAssumed(true); +// fges.setMaxIndegree(1); + fges.setCycleBound(5); - Graph estPattern = fgs.search(); + Graph estPattern = fges.search(); // printDegreeDistribution(estPattern, out); @@ -167,7 +167,7 @@ public void explore2() { score.setSamplePrior(samplePrior); score.setStructurePrior(structurePrior); - Fgs ges = new Fgs(score); + Fges ges = new Fges(score); ges.setVerbose(false); ges.setNumPatternsToStore(0); ges.setFaithfulnessAssumed(false); @@ -202,31 +202,31 @@ public void explore2() { @Test public void testExplore3() { Graph graph = GraphConverter.convert("A-->B,A-->C,B-->D,C-->D"); - Fgs fgs = new Fgs(new GraphScore(graph)); - Graph pattern = fgs.search(); + Fges fges = new Fges(new GraphScore(graph)); + Graph pattern = fges.search(); assertEquals(SearchGraphUtils.patternForDag(graph), pattern); } @Test public void testExplore4() { Graph graph = GraphConverter.convert("A-->B,A-->C,A-->D,B-->E,C-->E,D-->E"); - Fgs fgs = new Fgs(new GraphScore(graph)); - Graph pattern = fgs.search(); + Fges fges = new Fges(new GraphScore(graph)); + Graph pattern = fges.search(); assertEquals(SearchGraphUtils.patternForDag(graph), pattern); } @Test public void testExplore5() { Graph graph = GraphConverter.convert("A-->B,A-->C,A-->D,A->E,B-->F,C-->F,D-->F,E-->F"); - Fgs fgs = new Fgs(new GraphScore(graph)); - fgs.setFaithfulnessAssumed(false); - Graph pattern = fgs.search(); + Fges fges = new Fges(new GraphScore(graph)); + fges.setFaithfulnessAssumed(false); + Graph pattern = fges.search(); assertEquals(SearchGraphUtils.patternForDag(graph), pattern); } @Test - public void testFromGraphSimpleFgs() { + public void testFromGraphSimpleFges() { // This may fail if faithfulness is assumed but should pass if not. @@ -247,9 +247,9 @@ public void testFromGraphSimpleFgs() { g.addDirectedEdge(x4, x3); Graph pattern1 = new Pc(new IndTestDSep(g)).search(); - Fgs fgs = new Fgs(new GraphScore(g)); - fgs.setFaithfulnessAssumed(true); - Graph pattern2 = fgs.search(); + Fges fges = new Fges(new GraphScore(g)); + fges.setFaithfulnessAssumed(true); + Graph pattern2 = fges.search(); // System.out.println(pattern1); // System.out.println(pattern2); @@ -258,7 +258,7 @@ public void testFromGraphSimpleFgs() { } @Test - public void testFromGraphSimpleFgsMb() { + public void testFromGraphSimpleFgesMb() { // This may fail if faithfulness is assumed but should pass if not. @@ -279,9 +279,9 @@ public void testFromGraphSimpleFgsMb() { g.addDirectedEdge(x4, x3); Graph pattern1 = new Pc(new IndTestDSep(g)).search(); - FgsMb2 fgs = new FgsMb2(new GraphScore(g)); -// fgs.setHeuristicSpeedup(false); - Graph pattern2 = fgs.search(x1); + FgesMb2 fges = new FgesMb2(new GraphScore(g)); +// fges.setHeuristicSpeedup(false); + Graph pattern2 = fges.search(x1); // System.out.println(pattern1); // System.out.println(pattern2); @@ -290,19 +290,19 @@ public void testFromGraphSimpleFgsMb() { } @Test - public void testFgsMbFromGraph() { + public void testFgesMbFromGraph() { int numNodes = 20; int numIterations = 10; for (int i = 0; i < numIterations; i++) { // System.out.println("Iteration " + (i + 1)); Graph dag = GraphUtils.randomDag(numNodes, 0, numNodes, 10, 10, 10, false); - GraphScore fgsScore = new GraphScore(dag); + GraphScore fgesScore = new GraphScore(dag); - Fgs fgs = new Fgs(fgsScore); - Graph pattern1 = fgs.search(); + Fges fges = new Fges(fgesScore); + Graph pattern1 = fges.search(); - Node x1 = fgsScore.getVariable("X1"); + Node x1 = fgesScore.getVariable("X1"); Set mb = new HashSet<>(); mb.add(x1); @@ -315,8 +315,8 @@ public void testFgsMbFromGraph() { Graph mb1 = pattern1.subgraph(new ArrayList<>(mb)); - FgsMb2 fgsMb = new FgsMb2(fgsScore); - Graph mb2 = fgsMb.search(x1); + FgesMb2 fgesMb = new FgesMb2(fgesScore); + Graph mb2 = fgesMb.search(x1); assertEquals(mb1, mb2); } @@ -395,12 +395,12 @@ public void clarkTest() { ScoreWrapper score = new edu.cmu.tetrad.algcomparison.score.SemBicScore(); IndependenceWrapper test = new FisherZ(); - Algorithm fgs = new edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.Fgs(score); + Algorithm fges = new edu.cmu.tetrad.algcomparison.algorithm.oracle.pattern.Fges(score); - Graph fgsGraph = fgs.search(dataSet, parameters); + Graph fgesGraph = fges.search(dataSet, parameters); - clarkTestForAlpha(0.05, parameters, dataSet, trueGraph, fgsGraph, test); - clarkTestForAlpha(0.01, parameters, dataSet, trueGraph, fgsGraph, test); + clarkTestForAlpha(0.05, parameters, dataSet, trueGraph, fgesGraph, test); + clarkTestForAlpha(0.01, parameters, dataSet, trueGraph, fgesGraph, test); } @@ -563,10 +563,10 @@ public void testCites() { SemBicScore score = new SemBicScore(dataSet); score.setPenaltyDiscount(1); - Fgs fgs = new Fgs(score); - fgs.setKnowledge(knowledge); + Fges fges = new Fges(score); + fges.setKnowledge(knowledge); - Graph pattern = fgs.search(); + Graph pattern = fges.search(); // System.out.println(pattern); @@ -608,10 +608,10 @@ private void checkSearch(String inputGraph, String outputGraph) { Graph graph = GraphConverter.convert(inputGraph); // Set up search. - Fgs fgs = new Fgs(new GraphScore(graph)); + Fges fges = new Fges(new GraphScore(graph)); // Run search - Graph resultGraph = fgs.search(); + Graph resultGraph = fges.search(); // Build comparison graph. Graph trueGraph = GraphConverter.convert(outputGraph); @@ -640,13 +640,13 @@ private void checkWithKnowledge(String inputGraph, String answerGraph, Graph input = GraphConverter.convert(inputGraph); // Set up search. - Fgs fgs = new Fgs(new GraphScore(input)); + Fges fges = new Fges(new GraphScore(input)); // Set up search. - fgs.setKnowledge(knowledge); + fges.setKnowledge(knowledge); // Run search - Graph result = fgs.search(); + Graph result = fges.search(); // Build comparison graph. Graph answer = GraphConverter.convert(answerGraph); @@ -701,9 +701,9 @@ public void testFromGraph() { for (int i = 0; i < numIterations; i++) { // System.out.println("Iteration " + (i + 1)); Graph dag = GraphUtils.randomDag(numNodes, 0, numNodes, 10, 10, 10, false); - Fgs fgs = new Fgs(new GraphScore(dag)); - fgs.setFaithfulnessAssumed(true); - Graph pattern1 = fgs.search(); + Fges fges = new Fges(new GraphScore(dag)); + fges.setFaithfulnessAssumed(true); + Graph pattern1 = fges.search(); Graph pattern2 = new Pc(new IndTestDSep(dag)).search(); // System.out.println(pattern2); assertEquals(pattern2, pattern1); @@ -734,21 +734,22 @@ public void testFromData() { DataSet data = semSimulator.simulateDataFisher(sampleSize); - data = DataUtils.restrictToMeasured(data); - SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(data)); score.setPenaltyDiscount(4); - Fgs fgs = new Fgs(score); + Fges fges = new Fges(score); long start = System.currentTimeMillis(); - Graph graph = fgs.search(); + Graph graph = fges.search(); long stop = System.currentTimeMillis(); System.out.println("Elapsed " + (stop - start) + " ms"); Graph pattern = SearchGraphUtils.patternForDag(dag); + + pattern = GraphUtils.replaceNodes(pattern, graph.getNodes()); + System.out.println(MisclassificationUtils.edgeMisclassifications(graph, pattern)); } @@ -799,9 +800,9 @@ public void testAjData() { long start = System.currentTimeMillis(); -// Graph pattern = searchSemFgs(Dk); -// Graph pattern = searchBdeuFgs(Dk, k); - Graph pattern = searchMixedFgs(Dk, penalty); +// Graph pattern = searchSemFges(Dk); +// Graph pattern = searchBdeuFges(Dk, k); + Graph pattern = searchMixedFges(Dk, penalty); long stop = System.currentTimeMillis(); @@ -825,15 +826,15 @@ public void testAjData() { } } - private Graph searchSemFgs(DataSet Dk, double penalty) { + private Graph searchSemFges(DataSet Dk, double penalty) { Dk = DataUtils.convertNumericalDiscreteToContinuous(Dk); SemBicScore score = new SemBicScore(new CovarianceMatrixOnTheFly(Dk)); score.setPenaltyDiscount(penalty); - Fgs fgs = new Fgs(score); - return fgs.search(); + Fges fges = new Fges(score); + return fges.search(); } - private Graph searchBdeuFgs(DataSet Dk, int k) { + private Graph searchBdeuFges(DataSet Dk, int k) { Discretizer discretizer = new Discretizer(Dk); List nodes = Dk.getVariables(); @@ -848,25 +849,25 @@ private Graph searchBdeuFgs(DataSet Dk, int k) { BDeuScore score = new BDeuScore(Dk); score.setSamplePrior(1.0); score.setStructurePrior(1.0); - Fgs fgs = new Fgs(score); - return fgs.search(); + Fges fges = new Fges(score); + return fges.search(); } - private Graph searchMixedFgs(DataSet dk, double penalty) { + private Graph searchMixedFges(DataSet dk, double penalty) { MixedBicScore score = new MixedBicScore(dk); score.setPenaltyDiscount(penalty); - Fgs fgs = new Fgs(score); - return fgs.search(); + Fges fges = new Fges(score); + return fges.search(); } - public Graph searchMGMFgs(DataSet ds, double penalty) { + public Graph searchMGMFges(DataSet ds, double penalty) { MGM m = new MGM(ds, new double[]{0.1, 0.1, 0.1}); //m.setVerbose(this.verbose); Graph gm = m.search(); DataSet dataSet = MixedUtils.makeContinuousData(ds); SemBicScore2 score = new SemBicScore2(new CovarianceMatrixOnTheFly(dataSet)); score.setPenaltyDiscount(penalty); - Fgs fg = new Fgs(score); + Fges fg = new Fges(score); fg.setBoundGraph(gm); fg.setVerbose(false); return fg.search(); @@ -897,13 +898,13 @@ public DataSet getMixedDataAjStyle(Graph g, int k, int samps) { GeneralizedSemIm im = MixedUtils.GaussianCategoricalIm(pm); // System.out.println(im); - DataSet ds = im.simulateDataAvoidInfinity(samps, false); + DataSet ds = im.simulateDataFisher(samps); return MixedUtils.makeMixedData(ds, nd); } // @Test public void testBestAlgorithms() { - String[] algorithms = {"SemFGS", "BDeuFGS", "MixedFGS", "PC", "PCS", "CPC", "MGMFgs", "MGMPcs"}; + String[] algorithms = {"SemFGES", "BDeuFGES", "MixedFGES", "PC", "PCS", "CPC", "MGMFges", "MGMPcs"}; String[] statLabels = {"AP", "AR", "OP", "OR", "SUM", "McAdj", "McOr", "F1Adj", "F1Or", "E"}; int numMeasures = 30; @@ -1007,16 +1008,16 @@ private double[][] printStats(String[] algorithms, int t, int numRuns, switch (t) { case 0: - out = searchSemFgs(data, penalty); + out = searchSemFges(data, penalty); break; case 1: - out = searchBdeuFgs(data, numCategories); + out = searchBdeuFges(data, numCategories); break; case 2: - out = searchMixedFgs(data, penalty); + out = searchMixedFges(data, penalty); break; case 6: - out = searchMGMFgs(data, penalty); + out = searchMGMFges(data, penalty); break; default: throw new IllegalStateException(); @@ -1373,7 +1374,7 @@ public int compare(Pair o1, Pair o2) { } public static void main(String... args) { - new TestFgs().testBestAlgorithms(); + new TestFges().testBestAlgorithms(); } } diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java index 79721106ae..a6f5b303b5 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java @@ -37,7 +37,6 @@ import java.nio.file.Paths; import java.util.ArrayList; import java.util.Arrays; -import java.util.Date; import java.util.List; import static org.junit.Assert.assertEquals; @@ -56,18 +55,10 @@ public void test1() { int numEdges = 10; int sampleSize = 1000; -// int numNodes = 3000; -// int numLatents = 150; -// int numEdges = 4500; -// int sampleSize = 1000; - double alpha = 0.01; double penaltyDiscount = 2; int depth = -1; int maxPathLength = -1; - boolean possibleDsepDone = true; - boolean completeRuleSetUsed = false; - boolean faithfulnessAssumed = true; List vars = new ArrayList<>(); @@ -97,11 +88,10 @@ public void test1() { GFci gFci = new GFci(independenceTest, score); gFci.setVerbose(false); - gFci.setMaxIndegree(depth); + gFci.setMaxDegree(depth); gFci.setMaxPathLength(maxPathLength); -// gFci.setPossibleDsepSearchDone(possibleDsepDone); - gFci.setCompleteRuleSetUsed(completeRuleSetUsed); - gFci.setFaithfulnessAssumed(faithfulnessAssumed); + gFci.setCompleteRuleSetUsed(false); + gFci.setFaithfulnessAssumed(true); Graph outGraph = gFci.search(); final DagToPag dagToPag = new DagToPag(dag); @@ -168,58 +158,41 @@ public void test2() { truePag.addBidirectedEdge(x2, x3); truePag.addPartiallyOrientedEdge(x4, x3); -// System.out.println(pag); - assertEquals(pag, truePag); } @Test public void testFromGraph() { - RandomUtil.getInstance().setSeed(new Date().getTime()); +// RandomUtil.getInstance().setSeed(new Date().getTime()); + RandomUtil.getInstance().setSeed(19444322L); - int numNodes = 20; + int numNodes = 15; int numLatents = 5; - int numIterations = 20; - - boolean completeRuleSetUsed = false; - boolean faithfulnessAssumed = true; + int numIterations = 10; for (int i = 0; i < numIterations; i++) { -// System.out.println("Iteration " + (i + 1)); Graph dag = GraphUtils.randomGraph(numNodes, numLatents, numNodes, 10, 10, 10, false); GFci gfci = new GFci(new IndTestDSep(dag), new GraphScore(dag)); - gfci.setCompleteRuleSetUsed(completeRuleSetUsed); -// GFci gfci = new GFci(new IndTestDSep(dag)); - gfci.setFaithfulnessAssumed(faithfulnessAssumed); + gfci.setCompleteRuleSetUsed(false); + gfci.setFaithfulnessAssumed(true); Graph pag1 = gfci.search(); - - DagToPag dagToPag = new DagToPag(dag); - dagToPag.setCompleteRuleSetUsed(completeRuleSetUsed); + dagToPag.setCompleteRuleSetUsed(false); Graph pag2 = dagToPag.convert(); -// System.out.println(pag1); -// System.out.println(pattern2); -// -// System.out.println(MisclassificationUtils.edgeMisclassifications(pag1, pag2)); assertEquals(pag2, pag1); } } @Test public void testFromData() { - int numNodes = 1000; - int numLatents = 50; - int numEdges = 1000; - int sampleSize = 1000; - -// System.out.println(RandomUtil.getInstance().getSeed()); -// -// RandomUtil.getInstance().setSeed(1461186701390L); - + int numNodes = 20; + int numLatents = 5; + int numEdges = 20; + int sampleSize = 50; List variables = new ArrayList<>(); @@ -243,7 +216,7 @@ public void testFromData() { long start = System.currentTimeMillis(); - Graph graph = gFci.search(); + gFci.search(); long stop = System.currentTimeMillis(); @@ -251,8 +224,6 @@ public void testFromData() { DagToPag dagToPag = new DagToPag(g); dagToPag.setVerbose(false); -// System.out.println(MisclassificationUtils.edgeMisclassifications(graph, dagToPag.convert())); - } @Test @@ -276,7 +247,7 @@ public void testRandomDiscreteData() { long start = System.currentTimeMillis(); - Graph graph = gFci.search(); + gFci.search(); long stop = System.currentTimeMillis(); @@ -291,7 +262,7 @@ public void testDiscreteData() throws IOException { double alpha = 0.05; char delimiter = '\t'; Path dataFile = Paths.get("../causal-cmd/test/data/diff_delim/sim_discrete_data_20vars_100cases.txt"); - // System.out.println(dataFile.toAbsolutePath().toString()); + VerticalTabularDiscreteDataReader dataReader = new VerticalTabularDiscreteDataReader(dataFile, delimiter); DataSet dataSet = dataReader.readInData(); @@ -303,14 +274,14 @@ public void testDiscreteData() throws IOException { GFci gFci = new GFci(indTest, score); gFci.setFaithfulnessAssumed(true); - gFci.setMaxIndegree(-1); + gFci.setMaxDegree(-1); gFci.setMaxPathLength(-1); gFci.setCompleteRuleSetUsed(false); gFci.setVerbose(true); long start = System.currentTimeMillis(); - Graph graph = gFci.search(); + gFci.search(); long stop = System.currentTimeMillis(); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGeneralizedSem.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGeneralizedSem.java index 525e468387..3b7117c1dd 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGeneralizedSem.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGeneralizedSem.java @@ -128,7 +128,7 @@ public void test1() { print(im); - DataSet dataSet = im.simulateDataAvoidInfinity(10, false); + DataSet dataSet = im.simulateDataFisher(10); print(dataSet); @@ -467,7 +467,7 @@ public void test6() { double aSquaredStar = estimator.getaSquaredStar(); - assertEquals(0.59, aSquaredStar, 0.01); + assertEquals(1.04, aSquaredStar, 0.01); } @Test diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGraphUtils.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGraphUtils.java index 1a96c1d1f7..e3aba4e71e 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGraphUtils.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGraphUtils.java @@ -27,9 +27,9 @@ import edu.cmu.tetrad.util.RandomUtil; import org.junit.Test; -import java.io.*; -import java.text.DecimalFormat; +import java.awt.*; import java.util.*; +import java.util.List; import static junit.framework.TestCase.assertEquals; import static org.junit.Assert.assertThat; @@ -279,6 +279,69 @@ public void test8() { } } + @Test + public void testPagColoring() { + Graph dag = GraphUtils.randomGraph(30, 5, 50, 10, 10, 10, false); + Graph pag = new DagToPag(dag).convert(); + + GraphUtils.addPagColoring(pag); + + for (Edge edge : pag.getEdges()) { + Node x1 = edge.getNode1(); + Node x2 = edge.getNode2(); + + if (edge.getLineColor() == Color.green) { + System.out.println("Green"); + + for (Node L : pag.getNodes()) { + if (L == x1 || L == x2) continue; + + if (L.getNodeType() == NodeType.LATENT) { + if (existsLatentPath(dag, L, x1) && existsLatentPath(dag, L, x2)) { + System.out.println("Edge " + edge + " falsely colored green."); + } + } + } + } + + if (edge.isDashed()) { + System.out.println("Dashed"); + + if (!existsLatentPath(dag, x1, x2)) { + System.out.println("Edge " + edge + " is falsely dashed."); + } + } + } + } + + public static boolean existsLatentPath(Graph graph, Node b, Node y) { + if (b == y) return false; + return existsLatentPath(graph, b, y, new LinkedList()); + } + + public static boolean existsLatentPath(Graph graph, Node b, Node y, LinkedList path) { + if (path.contains(b)) { + return false; + } + + path.addLast(b); + + for (Node c : graph.getChildren(b)) { + if (c == y) return true; + + if (c.getNodeType() != NodeType.LATENT) { + continue; + } + + if (!existsLatentPath(graph, c, y, path)) { + return false; + } + } + + path.removeLast(); + return true; + } + private List list(Node... z) { List list = new ArrayList<>(); Collections.addAll(list, z); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestIndTestFisherZ.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestIndTestFisherZ.java index 87f9d2023a..2c5b50d587 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestIndTestFisherZ.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestIndTestFisherZ.java @@ -21,15 +21,21 @@ package edu.cmu.tetrad.test; -import edu.cmu.tetrad.data.DataSet; +import edu.cmu.tetrad.data.*; import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.search.IndTestFisherZ; import edu.cmu.tetrad.search.IndependenceTest; import edu.cmu.tetrad.sem.SemIm; import edu.cmu.tetrad.sem.SemPm; +import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.RandomUtil; +import edu.cmu.tetrad.util.StatUtils; +import edu.cmu.tetrad.util.TetradMatrix; import org.junit.Test; +import java.util.List; + +import static java.lang.Math.*; import static org.junit.Assert.assertEquals; @@ -92,6 +98,89 @@ public void testDirections() { assertEquals(0, p2, 0.01); assertEquals(0, p3, 0.01); } + + @Test + public void test2() { +// for (int p = 0; p < 50; p++) { +// final double low = .2; +// final double high = .5; +// +// double a = RandomUtil.getInstance().nextUniform(low, high); +// double b = RandomUtil.getInstance().nextUniform(low, high); +// double c = RandomUtil.getInstance().nextUniform(low, high); +// double d = RandomUtil.getInstance().nextUniform(low, high); +// +// final double q = d + a * b * c; +// +// double g1 = ((a * b + c * d) - c * q) / (sqrt(1. - c * c) * sqrt(1. - q * q)); +// double g2 = a * b + c * d; +// +// double c1 = c + a * b * d; +// double d1 = d + a * b * c; +// +// double g3 = (a * b - c1 * d1) / ((sqrt(1. - c1 * c1) * sqrt(1. - d1 * d1))); +// double g4 = a * b; +// +// double t = sqrt(1. - c * c) / sqrt(1. - q * q); +// +// System.out.println((g1 < g2) + "\t" + (g3 > g4) + "\t" + t); +// } + + for (int p = 0; p < 50; p++) { + Graph graph = new EdgeListGraph(); + Node x = new ContinuousVariable("X"); + Node y = new ContinuousVariable("Y"); + Node w1 = new ContinuousVariable("W1"); + Node w2 = new ContinuousVariable("W2"); + Node w3 = new ContinuousVariable("W3"); + Node r = new ContinuousVariable("R"); + + graph.addNode(x); + graph.addNode(y); + graph.addNode(w1); + graph.addNode(w2); + graph.addNode(w3); + graph.addNode(r); + + graph.addDirectedEdge(x, w1); + graph.addDirectedEdge(w1, w2); + graph.addDirectedEdge(w2, y); + graph.addDirectedEdge(w3, y); +// graph.addDirectedEdge(x, r); + +// graph.addDirectedEdge(r, y); + graph.addDirectedEdge(y, r); +// + SemPm pm = new SemPm(graph); + + Parameters parameters = new Parameters(); + + parameters.set("coefLow", .3); + parameters.set("coefHigh", .8); + parameters.set("coefSymmetric", false); + + SemIm im = new SemIm(pm, parameters); + + final int N = 1000; + DataSet data = im.simulateData(N, false); + ICovarianceMatrix _cov = new CovarianceMatrix(data); + TetradMatrix cov = _cov.getMatrix(); + + List nodes = _cov.getVariables(); + + final int xi = nodes.indexOf(x); + final int yi = nodes.indexOf(y); + final int ri = nodes.indexOf(r); + + double xy = StatUtils.partialCorrelation(cov, xi, yi); + double xyr = StatUtils.partialCorrelation(cov, xi, yi, ri); + + double f1 = 0.5 * sqrt(N - 3) * log(1. + xy) - log(1. - xy); + double f2 = 0.5 * sqrt(N - 3 - 1) * log(1. + xyr) - log(1. - xyr); + + System.out.println(abs(f1) > abs(f2)); + } + } } diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestLingamPattern.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestLingamPattern.java index fe96a9d02d..211d6dce92 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestLingamPattern.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestLingamPattern.java @@ -75,7 +75,7 @@ public void test1() { DataSet dataSet = simulateDataNonNormal(semIm, sampleSize, variableDistributions); Score score = new SemBicScore(new CovarianceMatrixOnTheFly(dataSet)); - Graph estPattern = new Fgs(score).search(); + Graph estPattern = new Fges(score).search(); LingamPattern lingam = new LingamPattern(estPattern, dataSet); lingam.search(); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestPc.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestPc.java index e94c800876..8999a98da1 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestPc.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestPc.java @@ -238,7 +238,7 @@ public void testPcStable2() { // @Test public void testPcFci() { - String[] algorithms = {"PC", "CPC", "FGS", "FCI", "GFCI", "RFCI", "CFCI"}; + String[] algorithms = {"PC", "CPC", "FGES", "FCI", "GFCI", "RFCI", "CFCI"}; String[] statLabels = {"AP", "TP", "BP", "NA", "NT", "NB", "E"/*, "AP/E"*/}; int numMeasures = 200; @@ -352,7 +352,7 @@ private double[] printStats(String[] algorithms, int t, boolean directed, int nu search = new Cpc(test); break; case 2: - search = new Fgs(score); + search = new Fges(score); break; case 3: search = new Fci(test); @@ -615,7 +615,7 @@ public int compare(Pair o1, Pair o2) { // @Test public void testPcRegression() { - String[] algorithms = {"PC", "CPC", "FGS", "FCI", "GFCI", "RFCI", "CFCI", "Regression"}; + String[] algorithms = {"PC", "CPC", "FGES", "FCI", "GFCI", "RFCI", "CFCI", "Regression"}; String[] statLabels = {"AP", "AR"}; int numMeasures = 10; @@ -736,7 +736,7 @@ private double[] printStatsPcRegression(String[] algorithms, int t, boolean dire out = search.search(); break; case 2: - search = new Fgs(score); + search = new Fges(score); out = search.search(); break; case 3: diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestPurify.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestPurify.java index b77c1f1e6b..93cb8738c9 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestPurify.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestPurify.java @@ -309,9 +309,7 @@ public void test2() { Graph _structuralGraph = _graph.subgraph(_latents); - assertEquals(3, _structuralGraph.getNumEdges()); - - + assertEquals(2, _structuralGraph.getNumEdges()); } } diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestStatUtils.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestStatUtils.java index 7a3555a4b3..0682180a7e 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestStatUtils.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestStatUtils.java @@ -647,6 +647,38 @@ public void test2() { System.out.println(sum / (60 * 60)); } +// @Test + public void test3() { + int count = 0; + int total = 100000; + + for (int i = 0; i < total; i++) { + double v1 = RandomUtil.getInstance().nextUniform(0, 100); + double v2 = RandomUtil.getInstance().nextUniform(0, 100); + double w1 = RandomUtil.getInstance().nextUniform(0, 1); + double w2 = 1.0 - w1; + double m1 = RandomUtil.getInstance().nextUniform(-5.0, 5.0); + double m2 = RandomUtil.getInstance().nextUniform(-5.0, 5.0); + + double left1 = 1.0 / v1; + double left2 = 1.0 / v2; + double m = (left1 + left2) / 2.0; + + double denRight = w1 * v1 + w2 * v2 + w1 * (m1 - m) * (m1 - m) + w1 * (m2 - m) * (m2 - m); + + double right = 1.0 / denRight; + + boolean holds = left1 + left2 > right; + + if (holds) count++; + +// System.out.println( + //x);// + " v1 = " + v1 + " v2 = " + v2 + " m1 = " + m1 + " m2 = " + m2); + } + + System.out.println(count); + } + private String f(double d1) { NumberFormat f = new DecimalFormat("0.000000"); return f.format(d1);