From 7c88361b78b06089bd58229d2ac408b9dbf26ac7 Mon Sep 17 00:00:00 2001 From: Joseph Ramsey Date: Wed, 6 Jan 2016 16:43:46 -0500 Subject: [PATCH 1/2] Fiddling. --- .../main/java/edu/cmu/tetrad/search/Fgs.java | 4 +- .../tetrad/search/FindTwoFactorClusters.java | 367 +++++++++--------- 2 files changed, 187 insertions(+), 184 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fgs.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fgs.java index bb2996665b..29a37dcc6a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fgs.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fgs.java @@ -461,9 +461,9 @@ public void setCycleBound(int cycleBound) { } /** - * Creates a new processors pool with the specified number of threads + * Creates a new processors pool with the specified number of threads. */ - public void setNumProcessors(int numProcessors) { + public void setParallelism(int numProcessors) { this.pool = new ForkJoinPool(numProcessors); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FindTwoFactorClusters.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FindTwoFactorClusters.java index 04c1a5d417..6c7dcfa706 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FindTwoFactorClusters.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FindTwoFactorClusters.java @@ -34,7 +34,7 @@ /** * Implements FindOneFactorCluster by Erich Kummerfeld (adaptation of a two factor - * quartet algorithm to a one factor IntSextad algorithm). + * sextet algorithm to a one factor IntSextad algorithm). * * @author Joseph Ramsey */ @@ -214,12 +214,12 @@ public Graph search() { private Set> estimateClustersGAP() { List _variables = allVariables(); - Set> pentads = findPurepentads(_variables); - Set> combined = combinePurePentads(pentads, _variables); + Set> pentads = findPurepentads(_variables); + Set> combined = combinePurePentads(pentads, _variables); Set> _combined = new HashSet<>(); - for (Set c : combined) { + for (List c : combined) { List a = new ArrayList<>(c); Collections.sort(a); _combined.add(a); @@ -246,7 +246,7 @@ private Set> estimateClustersSAG() { } - private Set> findPurepentads(List variables) { + private Set> findPurepentads(List variables) { if (variables.size() < 6) { return new HashSet<>(); } @@ -255,7 +255,7 @@ private Set> findPurepentads(List variables) { ChoiceGenerator gen = new ChoiceGenerator(variables.size(), 5); int[] choice; - Set> purePentads = new HashSet<>(); + Set> purePentads = new HashSet<>(); CHOICE: while ((choice = gen.next()) != null) { int n1 = variables.get(choice[0]); @@ -275,6 +275,8 @@ private Set> findPurepentads(List variables) { List sextet = sextet(n1, n2, n3, n4, n5, o); + Collections.sort(sextet); + boolean vanishes = vanishes(sextet); if (!vanishes) { @@ -282,7 +284,7 @@ private Set> findPurepentads(List variables) { } } - HashSet _cluster = new HashSet<>(pentad); + List _cluster = new ArrayList<>(pentad); if (verbose) { System.out.println(variablesForIndices(pentad)); @@ -295,13 +297,13 @@ private Set> findPurepentads(List variables) { return purePentads; } - private Set> combinePurePentads(Set> purePentads, List _variables) { + private Set> combinePurePentads(Set> purePentads, List _variables) { log("Growing pure pentads.", true); - Set> grown = new HashSet<>(); + Set> grown = new HashSet<>(); // Lax grow phase with speedup. - if (true) { - Set t = new HashSet<>(); + if (false) { + List t = new ArrayList<>(); int count = 0; int total = purePentads.size(); @@ -310,57 +312,49 @@ private Set> combinePurePentads(Set> purePentads, List break; } - Set cluster = purePentads.iterator().next(); - Set _cluster = new HashSet<>(cluster); - boolean changed = true; + List cluster = purePentads.iterator().next(); + List _cluster = new ArrayList<>(cluster); - while (changed) { - changed = false; - - for (int o : _variables) { - if (_cluster.contains(o)) continue; + for (int o : _variables) { + if (_cluster.contains(o)) continue; - List _cluster2 = new ArrayList<>(_cluster); - int rejected = 0; - int accepted = 0; + List _cluster2 = new ArrayList<>(_cluster); + int rejected = 0; + int accepted = 0; - ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 4); - int[] choice; + ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 4); + int[] choice; - while ((choice = gen.next()) != null) { - t.clear(); - t.add(_cluster2.get(choice[0])); - t.add(_cluster2.get(choice[1])); - t.add(_cluster2.get(choice[2])); - t.add(_cluster2.get(choice[3])); - t.add(o); + while ((choice = gen.next()) != null) { + t.clear(); + t.add(_cluster2.get(choice[0])); + t.add(_cluster2.get(choice[1])); + t.add(_cluster2.get(choice[2])); + t.add(_cluster2.get(choice[3])); + t.add(o); - if (!purePentads.contains(t)) { - rejected++; - } else { - accepted++; - } + if (!purePentads.contains(t)) { + rejected++; + } else { + accepted++; } + } - System.out.println("accepted = " + accepted + " rejected = " + rejected); - - if (rejected > accepted) { - continue; - } + if (rejected > accepted) { + continue; + } -// if (rejected > 0) { -// continue; -// } + _cluster.add(o); - _cluster.add(o); - changed = true; - } +// if (!(avgSumLnP(new ArrayList(_cluster)) > -10)) { +// _cluster.remove(o); +// } } // This takes out all pure clusters that are subsets of _cluster. - ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 5); + ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 3); int[] choice2; - List _cluster3 = new ArrayList<>(_cluster); + List _cluster3 = new ArrayList(_cluster); while ((choice2 = gen2.next()) != null) { int n1 = _cluster3.get(choice2[0]); @@ -380,7 +374,7 @@ private Set> combinePurePentads(Set> purePentads, List } if (verbose) { - System.out.println("Grown " + (++count) + " of " + total + ": " + variablesForIndices(new ArrayList<>(_cluster))); + System.out.println("Grown " + (++count) + " of " + total + ": " + variablesForIndices(new ArrayList(_cluster))); } grown.add(_cluster); } while (!purePentads.isEmpty()); @@ -392,8 +386,8 @@ private Set> combinePurePentads(Set> purePentads, List int total = purePentads.size(); // Optimized lax version of grow phase. - for (Set cluster : new HashSet<>(purePentads)) { - Set _cluster = new HashSet<>(cluster); + for (List cluster : new HashSet<>(purePentads)) { + List _cluster = new ArrayList<>(cluster); for (int o : _variables) { if (_cluster.contains(o)) continue; @@ -413,7 +407,9 @@ private Set> combinePurePentads(Set> purePentads, List List pentad = pentad(n1, n2, n3, n4, o); - Set t = new HashSet<>(pentad); + List t = new ArrayList<>(pentad); + + Collections.sort(t); if (!purePentads.contains(t)) { rejected++; @@ -429,7 +425,7 @@ private Set> combinePurePentads(Set> purePentads, List _cluster.add(o); } - for (Set c : new HashSet<>(purePentads)) { + for (List c : new HashSet<>(purePentads)) { if (_cluster.containsAll(c)) { purePentads.remove(c); } @@ -444,8 +440,8 @@ private Set> combinePurePentads(Set> purePentads, List } // Strict grow phase. - if (false) { - Set t = new HashSet<>(); + if (true) { + List t = new ArrayList<>(); int count = 0; int total = purePentads.size(); @@ -454,8 +450,8 @@ private Set> combinePurePentads(Set> purePentads, List break; } - Set cluster = purePentads.iterator().next(); - Set _cluster = new HashSet<>(cluster); + List cluster = purePentads.iterator().next(); + List _cluster = new ArrayList<>(cluster); VARIABLES: for (int o : _variables) { @@ -479,6 +475,8 @@ private Set> combinePurePentads(Set> purePentads, List t.add(n4); t.add(o); + Collections.sort(t); + if (!purePentads.contains(t)) { continue VARIABLES; } @@ -487,6 +485,18 @@ private Set> combinePurePentads(Set> purePentads, List _cluster.add(o); } +// for (Set c : new HashSet<>(purePentads)) { +//// for (Integer d : c) { +//// if (_cluster.contains(d)) { +//// purePentads.remove(c); +//// } +//// } +// +// if (_cluster.containsAll(c)) { +// purePentads.remove(c); +// } +// } + // This takes out all pure clusters that are subsets of _cluster. ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 5); int[] choice2; @@ -506,85 +516,23 @@ private Set> combinePurePentads(Set> purePentads, List t.add(n4); t.add(n5); + Collections.sort(t); + purePentads.remove(t); } if (verbose) { System.out.println("Grown " + (++count) + " of " + total + ": " + _cluster); } + grown.add(_cluster); } while (!purePentads.isEmpty()); } - if (false) { - System.out.println("# pure pentads = " + purePentads.size()); - - List> clusters = new LinkedList<>(purePentads); - Set t = new HashSet<>(); - - I: - for (int i = 0; i < clusters.size(); i++) { - System.out.println("I = " + i); - - // remove "i" clusters that intersect with previous clusters. - for (int k = 0; k < i - 1; k++) { - Set ck = clusters.get(k); - Set ci = clusters.get(i); - - if (ck == null) continue; - if (ci == null) continue; - - Set cm = new HashSet(ck); - cm.retainAll(ci); - - if (!cm.isEmpty()) { - clusters.remove(i); - i--; - continue I; - } - } - - J: - for (int j = i + 1; j < clusters.size(); j++) { - Set ci = clusters.get(i); - Set cj = clusters.get(j); - - if (ci == null) continue; - if (cj == null) continue; - - Set ck = new HashSet<>(ci); - ck.addAll(cj); - - List cm = new ArrayList<>(ck); - - ChoiceGenerator gen = new ChoiceGenerator(cm.size(), 5); - int[] choice; - - while ((choice = gen.next()) != null) { - t.clear(); - t.add(cm.get(choice[0])); - t.add(cm.get(choice[1])); - t.add(cm.get(choice[2])); - - if (!purePentads.contains(t)) { - continue J; - } - } - - clusters.set(i, ck); - clusters.remove(j); - j--; - System.out.println("Removing " + ci + ", " + cj + ", adding " + ck); - } - } - - grown = new HashSet<>(clusters); - } - // Optimized pick phase. log("Choosing among grown clusters.", true); - for (Set l : grown) { + for (List l : grown) { ArrayList _l = new ArrayList(l); Collections.sort(_l); if (verbose) { @@ -592,50 +540,32 @@ private Set> combinePurePentads(Set> purePentads, List } } - Set> out = new HashSet<>(); + Set> out = new HashSet<>(); - List> list = new ArrayList<>(grown); + List> list = new ArrayList<>(grown); - while (!list.isEmpty()) { - Collections.sort(list, new Comparator>() { - @Override - public int compare(Set o1, Set o2) { - return o2.size() - o1.size(); - } - }); + Collections.sort(list, new Comparator>() { + @Override + public int compare(List o1, List o2) { + return o2.size() - o1.size(); + } + }); - Set first = list.get(0); - out.add(first); - list.remove(first); + List all = new ArrayList<>(); - for (Set s : new ArrayList<>(list)) { - s.removeAll(first); - if (s.size() < 5) list.remove(s); + CLUSTER: + for (List cluster : list) { + for (Integer i : cluster) { + if (all.contains(i)) continue CLUSTER; } - } -// Collections.sort(list, new Comparator>() { -// @Override -// public int compare(Set o1, Set o2) { -// return o2.size() - o1.size(); -// } -// }); -// -// Set all = new HashSet<>(); -// -// CLUSTER: -// for (Set cluster : list) { -// for (Integer i : cluster) { -// if (all.contains(i)) continue CLUSTER; -// } -// -// out.add(cluster); -// all.addAll(cluster); -// } + out.add(cluster); + all.addAll(cluster); + } boolean significanceCalculated = false; if (significanceCalculated) { - for (Set _out : out) { + for (List _out : out) { try { double p = significance(new ArrayList<>(_out)); log("OUT: " + variablesForIndices(new ArrayList<>(_out)) + " p = " + p, true); @@ -644,20 +574,45 @@ public int compare(Set o1, Set o2) { } } } else { - for (Set _out : out) { + for (List _out : out) { log("OUT: " + variablesForIndices(new ArrayList<>(_out)), true); } } +// C: +// for (List cluster : new HashSet<>(out)) { +// if (cluster.size() >= 6) { +// ChoiceGenerator gen = new ChoiceGenerator(cluster.size(), 6); +// int[] choice; +// +// while ((choice = gen.next()) != null) { +// int n1 = cluster.get(choice[0]); +// int n2 = cluster.get(choice[1]); +// int n3 = cluster.get(choice[2]); +// int n4 = cluster.get(choice[3]); +// int n5 = cluster.get(choice[4]); +// int n6 = cluster.get(choice[5]); +// +// List _cluster = sextet(n1, n2, n3, n4, n5, n6); +// +// // Note that purity needs to be assessed with respect to all of the variables in order to +// // remove all latent-measure impurities between pairs of latents. +// if (!pure(_cluster)) { +// out.remove(cluster); +// continue C; +// } +// } +// } +// } + return out; } // Finds clusters of size 6 or higher for the IntSextad first algorithm. private Set> findPureClusters(List _variables) { Set> clusters = new HashSet<>(); - List allVariables = allVariables(); - for (int k = 7; k >= 6; k--) { + for (int k = 6; k >= 6; k--) { VARIABLES: while (!_variables.isEmpty()) { if (verbose) { @@ -693,6 +648,7 @@ private Set> findPureClusters(List _variables) { log("Cluster found: " + variablesForIndices(cluster), true); System.out.println("Indices for cluster = " + cluster); } + clusters.add(cluster); _variables.removeAll(cluster); @@ -702,31 +658,58 @@ private Set> findPureClusters(List _variables) { break; } + +// C: +// for (List cluster : new HashSet<>(clusters)) { +// if (cluster.size() >= 6) { +// ChoiceGenerator gen = new ChoiceGenerator(cluster.size(), 6); +// int[] choice; +// +// while ((choice = gen.next()) != null) { +// int n1 = cluster.get(choice[0]); +// int n2 = cluster.get(choice[1]); +// int n3 = cluster.get(choice[2]); +// int n4 = cluster.get(choice[3]); +// int n5 = cluster.get(choice[4]); +// int n6 = cluster.get(choice[5]); +// +// List _cluster = sextet(n1, n2, n3, n4, n5, n6); +// +// // Note that purity needs to be assessed with respect to all of the variables in order to +// // remove all latent-measure impurities between pairs of latents. +// if (!pure(_cluster)) { +// clusters.remove(cluster); +// continue C; +// } +// } +// } +// } } return clusters; } private void addOtherVariables(List _variables, List cluster) { + O: for (int o : _variables) { if (cluster.contains(o)) continue; - List _cluster = new ArrayList(cluster); + List _cluster = new ArrayList<>(cluster); - ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 5); - int[] choice2; + ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 6); + int[] choice; - while ((choice2 = gen2.next()) != null) { - int t1 = _cluster.get(choice2[0]); - int t2 = _cluster.get(choice2[1]); - int t3 = _cluster.get(choice2[2]); - int t4 = _cluster.get(choice2[3]); - int t5 = _cluster.get(choice2[4]); + while ((choice = gen2.next()) != null) { + int t1 = _cluster.get(choice[0]); + int t2 = _cluster.get(choice[1]); + int t3 = _cluster.get(choice[2]); + int t4 = _cluster.get(choice[3]); + int t5 = _cluster.get(choice[4]); - List quartet = pentad(t1, t2, t3, t4, t5); - quartet.add(o); + List sextad = pentad(t1, t2, t3, t4, t5); + sextad.add(o); - if (!pure(quartet)) { + if (!pure(sextad)) { continue O; } } @@ -736,7 +719,7 @@ private void addOtherVariables(List _variables, List cluster) } } - // Finds clusters of size 5 for the quartet first algorithm. + // Finds clusters of size 5 for the sextet first algorithm. private Set> findMixedClusters(Set> clusters, List remaining, Set unionPure) { Set> pentads = new HashSet<>(); Set> _clusters = new HashSet<>(clusters); @@ -865,11 +848,11 @@ private boolean pure(List sextet) { if (sextet.contains(o)) continue; for (int i = 0; i < sextet.size(); i++) { - List _quartet = new ArrayList<>(sextet); - _quartet.remove(sextet.get(i)); - _quartet.add(i, o); + List _sextet = new ArrayList<>(sextet); + _sextet.remove(sextet.get(i)); + _sextet.add(i, o); - if (!(vanishes(_quartet))) { + if (!(vanishes(_sextet))) { return false; } } @@ -888,7 +871,7 @@ private double getClusterChiSquare(List cluster) { return im.getChiSquare(); } - private SemIm estimateClusterModel(List quartet) { + private SemIm estimateClusterModel(List sextet) { Graph g = new EdgeListGraph(); Node l1 = new GraphNode("L1"); l1.setNodeType(NodeType.LATENT); @@ -897,7 +880,7 @@ private SemIm estimateClusterModel(List quartet) { g.addNode(l1); g.addNode(l2); - for (Integer aQuartet : quartet) { + for (Integer aQuartet : sextet) { Node n = this.variables.get(aQuartet); g.addNode(n); g.addDirectedEdge(l1, n); @@ -1016,6 +999,23 @@ private List pentad(int n1, int n2, int n3, int n4, int n5) { } private boolean vanishes(List sextet) { + + PermutationGenerator gen = new PermutationGenerator(6); + int[] perm; + +// while ((perm = gen.next()) != null) { +// int n1 = sextet.get(perm[0]); +// int n2 = sextet.get(perm[1]); +// int n3 = sextet.get(perm[2]); +// int n4 = sextet.get(perm[3]); +// int n5 = sextet.get(perm[4]); +// int n6 = sextet.get(perm[5); +// +// if (!vanishes(n1, n2, n3, n4, n5, n6)) return false; +// } +// +// return true; + int n1 = sextet.get(0); int n2 = sextet.get(1); int n3 = sextet.get(2); @@ -1023,7 +1023,10 @@ private boolean vanishes(List sextet) { int n5 = sextet.get(4); int n6 = sextet.get(5); - return vanishes(n1, n2, n3, n4, n5, n6); + return vanishes(n1, n2, n3, n4, n5, n6) + && vanishes(n3, n2, n1, n6, n5, n4) + && vanishes(n4, n5, n6, n1, n2, n3) + && vanishes(n6, n5, n4, n3, n2, n1); } private boolean zeroCorr(List cluster, int n) { From 05ed507c99e3a341044de6329c212cae2a783b1a Mon Sep 17 00:00:00 2001 From: Joseph Ramsey Date: Wed, 6 Jan 2016 16:53:40 -0500 Subject: [PATCH 2/2] Changed name of setNumProcessors method to setParallelism in FGS. Fixed the zero degrees of freedom problem in the delta sextad test. --- .../src/main/java/edu/cmu/tetrad/search/DeltaSextadTest.java | 4 ++-- .../test/java/edu/cmu/tetrad/test/TestDeltaSextadTest.java | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DeltaSextadTest.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DeltaSextadTest.java index c22ade5900..f8762ad67d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DeltaSextadTest.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/DeltaSextadTest.java @@ -364,13 +364,13 @@ private double r(double array1[], double array2[], int N) { private int dofDrton(int n) { int dof = ((n - 2) * (n - 3)) / 2 - 2; - if (dof < 0) dof = 0; + if (dof < 1) dof = 1; return dof; } private int dofHarman(int n) { int dof = n * (n - 5) / 2 + 1; - if (dof < 0) dof = 0; + if (dof < 1) dof = 1; return dof; } diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDeltaSextadTest.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDeltaSextadTest.java index 85402f9960..207f8d0c24 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDeltaSextadTest.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDeltaSextadTest.java @@ -144,7 +144,7 @@ public void testBollenExampleb() { IntSextad[] _sextads = {t2, t5, t10, t3, t6}; double p = test.getPValue(_sextads); - assertEquals(0.90, p, 0.01); + assertEquals(0.21, p, 0.01); _sextads = new IntSextad[] {t10}; p = test.getPValue(_sextads);