diff --git a/docs/manual/flowchart.html b/docs/manual/flowchart.html new file mode 100644 index 0000000000..2cee6d72bf --- /dev/null +++ b/docs/manual/flowchart.html @@ -0,0 +1,13 @@ + + + + Redirecting... + + + +

If you are not redirected automatically, follow this link + to the new page.

+ + diff --git a/tetrad-gui/dependency-reduced-pom.xml b/tetrad-gui/dependency-reduced-pom.xml new file mode 100644 index 0000000000..b7acb089ca --- /dev/null +++ b/tetrad-gui/dependency-reduced-pom.xml @@ -0,0 +1,74 @@ + + + + tetrad + io.github.cmu-phil + 7.6.4-SNAPSHOT + + 4.0.0 + tetrad-gui + + + + org.apache.maven.wagon + wagon-ssh + 2.10 + + + + + true + src/main/resources + + resources/version + + + + src/main/resources + + resources/version + + + + + + maven-compiler-plugin + 3.11.0 + + 17 + 17 + + + + maven-shade-plugin + 3.5.1 + + + package + + shade + + + + + + edu.cmu.tetradapp.Tetrad + all-permissions + ${project.name} + ${project.version} + + + + true + launch + + + + + + + + 1.8 + UTF-8 + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/LaunchFlowchartAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/LaunchFlowchartAction.java index 3d8531f265..99c518c7c8 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/LaunchFlowchartAction.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/LaunchFlowchartAction.java @@ -50,8 +50,8 @@ public LaunchFlowchartAction() { public void actionPerformed(ActionEvent e) { Desktop d = Desktop.getDesktop(); try { - d.browse(new URI("https://htmlpreview.github.io/?https://raw.githubusercontent.com/cmu-phil/tetrad/" + - "development/docs/manual/flowchart.html")); + d.browse(new URI("https://htmlpreview.github.io/?https:///github.com/cmu-phil/" + + "tetrad/blob/development/tetrad-lib/src/main/resources/docs/manual/flowchart.html")); } catch (IOException | URISyntaxException e2) { e2.printStackTrace(); } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/TetradMenuBar.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/TetradMenuBar.java index 4cf27cbd6d..915beb8598 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/TetradMenuBar.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/app/TetradMenuBar.java @@ -186,7 +186,6 @@ private void buildEditMenu(JMenu editMenu) { editMenu.add(cut); editMenu.add(copy); editMenu.add(paste); - editMenu.addSeparator(); } /** @@ -241,13 +240,13 @@ public SuggestionDialog(JComponent parent, String url) { // Create a clickable link JLabel label = new JLabel("" + - "

Please submit any issues you may have,

" + - "

whether bug reports, general encouragement,

" + - "

or feature requests, to our issues list. We'd

" + - "

love to hear from you as we continue to

" + - "

improve the Tetrad tools!

" + - "

" + url + "
" + - ""); + "

Please submit any issues you may have,

" + + "

whether bug reports, general encouragement,

" + + "

or feature requests, to our issues list. We'd

" + + "

love to hear from you as we continue to

" + + "

improve the Tetrad tools!

" + + "

" + url + "
" + + ""); label.setCursor(Cursor.getPredefinedCursor(Cursor.HAND_CURSOR)); label.setFont(label.getFont().deriveFont(Font.PLAIN, 14)); label.addMouseListener(new MouseAdapter() { diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/AlgcomparisonEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/AlgcomparisonEditor.java index 1f03d7672b..f7d7a535a4 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/AlgcomparisonEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/AlgcomparisonEditor.java @@ -757,8 +757,19 @@ public static StringTextField getStringField(String parameter, Parameters parame * @throws IllegalAccessException If the graph or simulation constructor or class is inaccessible. */ @NotNull - private static edu.cmu.tetrad.algcomparison.simulation.Simulation getSimulation(Class graphClazz, Class simulationClazz) throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException { - RandomGraph randomGraph = graphClazz.getConstructor().newInstance(); + private edu.cmu.tetrad.algcomparison.simulation.Simulation getSimulation(Class graphClazz, Class simulationClazz) throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException { + RandomGraph randomGraph; + + if (graphClazz == SingleGraph.class) { + if (model.getSuppliedGraph() == null) { + throw new IllegalArgumentException("No graph supplied."); + } + + randomGraph = new SingleGraph(model.getSuppliedGraph()); + } else { + randomGraph = graphClazz.getConstructor().newInstance(); + } + return simulationClazz.getConstructor(RandomGraph.class).newInstance(randomGraph); } @@ -821,8 +832,12 @@ public static void scrollToWord(JTextArea textArea, JScrollPane scrollPane, Stri } @NotNull - private static Class getGraphClazz(String graphString) { - List graphTypeStrings = Arrays.asList(ParameterTab.GRAPH_TYPE_ITEMS); + private Class getGraphClazz(String graphString) { + List graphTypeStrings = new ArrayList<>(Arrays.asList(ParameterTab.GRAPH_TYPE_ITEMS)); + + if (model.getSuppliedGraph() != null) { + graphTypeStrings.add("User Supplied Graph"); + } return switch (graphTypeStrings.indexOf(graphString)) { case 0: @@ -831,12 +846,14 @@ private static Class getGraphClazz(String graphString) { yield ErdosRenyi.class; case 2: yield ScaleFree.class; - case 4: + case 3: yield Cyclic.class; - case 5: + case 4: yield RandomSingleFactorMim.class; - case 6: + case 5: yield RandomTwoFactorMim.class; + case 6: + yield SingleGraph.class; default: throw new IllegalArgumentException("Unexpected value: " + graphString); }; @@ -1441,6 +1458,11 @@ private void addAddSimulationListener() { JComboBox graphsDropdown = getGraphsDropdown(); Arrays.stream(ParameterTab.GRAPH_TYPE_ITEMS).forEach(graphsDropdown::addItem); + + if (model.getSuppliedGraph() != null) { + graphsDropdown.addItem("User Supplied Graph"); + } + graphsDropdown.setMaximumSize(graphsDropdown.getPreferredSize()); graphsDropdown.setSelectedItem(model.getLastGraphChoice()); @@ -1507,6 +1529,7 @@ private JComboBox getGraphsDropdown() { model.setLastGraphChoice(selectedItem); } }); + return graphsDropdown; } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/AllPathsAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/AllPathsAction.java index 7456efeb13..503034ae33 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/AllPathsAction.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/AllPathsAction.java @@ -102,7 +102,7 @@ public void watch() { } private void addTreks(Node node1, Node node2, Graph graph, JTextArea textArea) { - List> treks = graph.paths().allPathsFromTo(node1, node2, 8); + List> treks = graph.paths().allPaths(node1, node2, 8); if (treks.isEmpty()) { return; diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/ApplyFinalFciRules.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/ApplyFinalFciRules.java new file mode 100644 index 0000000000..e823ee4757 --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/ApplyFinalFciRules.java @@ -0,0 +1,95 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.EdgeListGraph; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.utils.DagSepsets; +import edu.cmu.tetrad.search.utils.FciOrient; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.datatransfer.Clipboard; +import java.awt.datatransfer.ClipboardOwner; +import java.awt.datatransfer.Transferable; +import java.awt.event.ActionEvent; + +/** + * This class represents an action to run the final FCI (Fast Causal Inference) rules on a graph in a GraphWorkbench. + * It extends the AbstractAction class and implements the ClipboardOwner interface. + */ +public class ApplyFinalFciRules extends AbstractAction implements ClipboardOwner { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * Runs the final FCI (Fast Causal Inference) rules on a graph in a GraphWorkbench. + * This action is triggered by clicking a button or selecting a menu option. + * + * @param workbench the GraphWorkbench instance containing the graph to run final FCI rules on. + * @throws NullPointerException if workbench is null. + */ + public ApplyFinalFciRules(GraphWorkbench workbench) { + super("Apply Final FCI Rules"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * Performs an action when an event occurs. + * + * @param e the event that triggered the action. + */ + public void actionPerformed(ActionEvent e) { + this.workbench.deselectAll(); + Graph graph = this.workbench.getGraph(); + + if (graph == null) { + JOptionPane.showMessageDialog(this.workbench, "No graph to apply final FCI rules to."); + return; + } + + Graph __g = new EdgeListGraph(graph); + FciOrient finalFciRules = new FciOrient(new DagSepsets(__g)); + finalFciRules.zhangFinalOrientation(__g); + workbench.setGraph(__g); + } + + /** + * Called when ownership of the clipboard contents is lost. + * + * @param clipboard the clipboard that lost ownership + * @param contents the contents that were lost + */ + public void lostOwnership(Clipboard clipboard, Transferable contents) { + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/ApplyMeekRules.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/ApplyMeekRules.java new file mode 100644 index 0000000000..a5f533bdaf --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/ApplyMeekRules.java @@ -0,0 +1,105 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.EdgeListGraph; +import edu.cmu.tetrad.graph.Edges; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.utils.MeekRules; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.datatransfer.Clipboard; +import java.awt.datatransfer.ClipboardOwner; +import java.awt.datatransfer.Transferable; +import java.awt.event.ActionEvent; + +/** + * Selects all directed edges in the given display graph. + * + * @author josephramsey + * @version $Id: $Id + */ +public class ApplyMeekRules extends AbstractAction implements ClipboardOwner { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * Creates a new copy subsession action for the given desktop and clipboard. + * + * @param workbench the given workbench. + */ + public ApplyMeekRules(GraphWorkbench workbench) { + super("Apply Meek Rules"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * {@inheritDoc} + *

+ * Selects all directed edges in the given display graph. + */ + public void actionPerformed(ActionEvent e) { + this.workbench.deselectAll(); + Graph graph = this.workbench.getGraph(); + + if (graph == null) { + JOptionPane.showMessageDialog(this.workbench, "No graph to apply Meek rules to."); + return; + } + + // check to make sure the edges in the graph are all directed or undirected + for (Edge edge : graph.getEdges()) { + if (!Edges.isDirectedEdge(edge) && !Edges.isUndirectedEdge(edge)) { + JOptionPane.showMessageDialog(this.workbench, + "To apply Meek rules, the graph must contain only directed or undirected edges."); + return; + } + } + + graph = new EdgeListGraph(graph); + MeekRules meekRules = new MeekRules(); + meekRules.setRevertToUnshieldedColliders(false); + meekRules.orientImplied(graph); + workbench.setGraph(graph); + } + + /** + * {@inheritDoc} + *

+ * Required by the AbstractAction interface; does nothing. + */ + public void lostOwnership(Clipboard clipboard, Transferable contents) { + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CheckGraphFoDagAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CheckGraphFoDagAction.java new file mode 100644 index 0000000000..e52fa1c9f9 --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CheckGraphFoDagAction.java @@ -0,0 +1,80 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetradapp.util.GraphUtils; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.event.ActionEvent; + +/** + * CheckGraphForCpdagAction is an action class that checks if a given graph is a legal CPDAG + * (Completed Partially Directed Acyclic Graph) and displays a message to indicate the result. + */ +public class CheckGraphFoDagAction extends AbstractAction { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * Highlights all latent variables in the given display graph. + * + * @param workbench the given workbench. + */ + public CheckGraphFoDagAction(GraphWorkbench workbench) { + super("Check to see if Graph is a DAG"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * This method is used to perform an action when an event is triggered, specifically when the user clicks on a button or menu item associated with it. It checks if a graph is + * a legal CPDAG (Completed Partially Directed Acyclic Graph). + * + * @param e The ActionEvent object that represents the event generated by the user action. + */ + public void actionPerformed(ActionEvent e) { + Graph graph = workbench.getGraph(); + + if (graph == null) { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "No graph to check for DAGness."); + return; + } + + if (graph.paths().isLegalDag()) { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "Graph is a legal DAG."); + } else { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "Graph is not a legal DAG."); + } + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CheckGraphForCpdagAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CheckGraphForCpdagAction.java new file mode 100644 index 0000000000..d0be2926a5 --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CheckGraphForCpdagAction.java @@ -0,0 +1,80 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetradapp.util.GraphUtils; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.event.ActionEvent; + +/** + * CheckGraphForCpdagAction is an action class that checks if a given graph is a legal CPDAG + * (Completed Partially Directed Acyclic Graph) and displays a message to indicate the result. + */ +public class CheckGraphForCpdagAction extends AbstractAction { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * Highlights all latent variables in the given display graph. + * + * @param workbench the given workbench. + */ + public CheckGraphForCpdagAction(GraphWorkbench workbench) { + super("Check to see if Graph is a CPDAG"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * This method is used to perform an action when an event is triggered, specifically when the user clicks on a button or menu item associated with it. It checks if a graph is + * a legal CPDAG (Completed Partially Directed Acyclic Graph). + * + * @param e The ActionEvent object that represents the event generated by the user action. + */ + public void actionPerformed(ActionEvent e) { + Graph graph = workbench.getGraph(); + + if (graph == null) { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "No graph to check for CPDAGness."); + return; + } + + if (graph.paths().isLegalCpdag()) { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "Graph is a legal CPDAG."); + } else { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "Graph is not a legal CPDAG."); + } + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CheckGraphForMagAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CheckGraphForMagAction.java new file mode 100644 index 0000000000..7871e1ea20 --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CheckGraphForMagAction.java @@ -0,0 +1,109 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.EdgeListGraph; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.utils.GraphSearchUtils; +import edu.cmu.tetradapp.util.GraphUtils; +import edu.cmu.tetradapp.util.WatchedProcess; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.event.ActionEvent; + +/** + * CheckGraphForMpdagAction is an action class that checks if a given graph is a legal MPDAG (Mixed Ancestral Graph) and + * displays a message to indicate the result. + */ +public class CheckGraphForMagAction extends AbstractAction { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + private volatile GraphSearchUtils.LegalMagRet legalMag = null; + + /** + * Highlights all latent variables in the given display graph. + * + * @param workbench the given workbench. + */ + public CheckGraphForMagAction(GraphWorkbench workbench) { + super("Check to see if Graph is a MAG"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * This method is used to perform an action when an event is triggered, specifically when the user clicks on a + * button or menu item associated with it. It checks if a graph is a legal MAG (Mixed Ancestral Graph). + * + * @param e The ActionEvent object that represents the event generated by the user action. + */ + public void actionPerformed(ActionEvent e) { + Graph graph = workbench.getGraph(); + + if (graph == null) { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "No graph to check for MAGness."); + return; + } + + class MyWatchedProcess extends WatchedProcess { + @Override + public void watch() { + Graph _graph = new EdgeListGraph(workbench.getGraph()); + legalMag = GraphSearchUtils.isLegalMag(_graph); + } + } + + new MyWatchedProcess(); + + while (legalMag == null) { + try { + Thread.sleep(100); // Sleep a bit to prevent tight loop + } catch (InterruptedException e2) { + Thread.currentThread().interrupt(); + } + } + + String reason = GraphUtils.breakDown(legalMag.getReason(), 60); + + if (!legalMag.isLegalMag()) { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), + "This is not a legal MAG--one reason is as follows:" + + "\n\n" + reason + ".", + "Legal MAG check", + JOptionPane.WARNING_MESSAGE); + } else { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), reason); + } + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CheckGraphForMpagAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CheckGraphForMpagAction.java new file mode 100644 index 0000000000..98fe9e62cb --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CheckGraphForMpagAction.java @@ -0,0 +1,81 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetradapp.util.GraphUtils; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.event.ActionEvent; + +/** + * CheckGraphForMpdagAction is an action class that checks if a given graph is a legal MPDAG (Maximal Partially Directed + * Acyclic Graph) and displays a message to indicate the result. + */ +public class CheckGraphForMpagAction extends AbstractAction { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * Highlights all latent variables in the given display graph. + * + * @param workbench the given workbench. + */ + public CheckGraphForMpagAction(GraphWorkbench workbench) { + super("Check to see if Graph is a MPAG"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * This method is used to perform an action when an event is triggered, specifically when the user clicks on a + * button or menu item associated with it. It checks if a graph is a legal MPDAG (Maximal Partially Directed + * Acyclic Graph). + * + * @param e The ActionEvent object that represents the event generated by the user action. + */ + public void actionPerformed(ActionEvent e) { + Graph graph = workbench.getGraph(); + + if (graph == null) { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "No graph to check for MPAGness."); + return; + } + + if (graph.paths().isLegalMpag()) { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "Graph is a legal MPAG."); + } else { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "Graph is not a legal MPAG."); + } + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CheckGraphForMpdagAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CheckGraphForMpdagAction.java new file mode 100644 index 0000000000..f845a5d4c4 --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CheckGraphForMpdagAction.java @@ -0,0 +1,81 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetradapp.util.GraphUtils; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.event.ActionEvent; + +/** + * CheckGraphForMpdagAction is an action class that checks if a given graph is a legal MPDAG (Maximal Partially Directed + * Acyclic Graph) and displays a message to indicate the result. + */ +public class CheckGraphForMpdagAction extends AbstractAction { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * Highlights all latent variables in the given display graph. + * + * @param workbench the given workbench. + */ + public CheckGraphForMpdagAction(GraphWorkbench workbench) { + super("Check to see if Graph is a MPDAG"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * This method is used to perform an action when an event is triggered, specifically when the user clicks on a + * button or menu item associated with it. It checks if a graph is a legal MPDAG (Maximal Partially Directed + * Acyclic Graph). + * + * @param e The ActionEvent object that represents the event generated by the user action. + */ + public void actionPerformed(ActionEvent e) { + Graph graph = workbench.getGraph(); + + if (graph == null) { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "No graph to check for MPDAGness."); + return; + } + + if (graph.paths().isLegalMpdag()) { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "Graph is a legal MPDAG."); + } else { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "Graph is not a legal MPDAG."); + } + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CheckGraphForPagAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CheckGraphForPagAction.java new file mode 100644 index 0000000000..734cffec31 --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CheckGraphForPagAction.java @@ -0,0 +1,110 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.EdgeListGraph; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.utils.GraphSearchUtils; +import edu.cmu.tetradapp.util.GraphUtils; +import edu.cmu.tetradapp.util.WatchedProcess; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.event.ActionEvent; + +/** + * CheckGraphForMpdagAction is an action class that checks if a given graph is a legal PAG (Mixed Ancesgral Graph) and + * displays a message to indicate the result. + */ +public class CheckGraphForPagAction extends AbstractAction { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * Highlights all latent variables in the given display graph. + * + * @param workbench the given workbench. + */ + public CheckGraphForPagAction(GraphWorkbench workbench) { + super("Check to see if Graph is a PAG"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + private volatile GraphSearchUtils.LegalPagRet legalPag = null; + + /** + * This method is used to perform an action when an event is triggered, specifically when the user clicks on a + * button or menu item associated with it. It checks if a graph is a legal DAG (Partial Ancestral Graph). + * + * @param e The ActionEvent object that represents the event generated by the user action. + */ + public void actionPerformed(ActionEvent e) { + Graph graph = workbench.getGraph(); + + if (graph == null) { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "No graph to check for PAGness."); + return; + } + + class MyWatchedProcess extends WatchedProcess { + @Override + public void watch() { + Graph _graph = new EdgeListGraph(workbench.getGraph()); + legalPag = GraphSearchUtils.isLegalPag(_graph); + } + } + + new MyWatchedProcess(); + + while (legalPag == null) { + try { + Thread.sleep(100); // Sleep a bit to prevent tight loop + } catch (InterruptedException e2) { + Thread.currentThread().interrupt(); + } + } + + String reason = GraphUtils.breakDown(legalPag.getReason(), 60); + + if (!legalPag.isLegalPag()) { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), + "This is not a legal PAG--one reason is as follows:" + + "\n\n" + reason + ".", + "Legal PAG check", + JOptionPane.WARNING_MESSAGE); + } else { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), reason); + } + } + +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CopySubgraphAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CopySubgraphAction.java index 62720fc192..77332a6997 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CopySubgraphAction.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CopySubgraphAction.java @@ -49,7 +49,7 @@ public class CopySubgraphAction extends AbstractAction implements ClipboardOwner * @param graphEditor a {@link edu.cmu.tetradapp.editor.GraphEditable} object */ public CopySubgraphAction(GraphEditable graphEditor) { - super("Copy Selected Graph"); + super("Copy Selected Items"); if (graphEditor == null) { throw new NullPointerException("Desktop must not be null."); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CutSubgraphAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CutSubgraphAction.java new file mode 100644 index 0000000000..2508d6e28c --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/CutSubgraphAction.java @@ -0,0 +1,85 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetradapp.util.InternalClipboard; + +import javax.swing.*; +import java.awt.datatransfer.Clipboard; +import java.awt.datatransfer.ClipboardOwner; +import java.awt.datatransfer.Transferable; +import java.awt.event.ActionEvent; +import java.util.List; + +/** + * Copies a selection of session nodes in the frontmost session editor, to the clipboard. + * + * @author josephramsey + * @version $Id: $Id + */ +public class CutSubgraphAction extends AbstractAction implements ClipboardOwner { + + /** + * The desktop containing the target session editor. + */ + private final GraphEditable graphEditor; + + /** + * Creates a new copy subsession action for the given desktop and clipboard. + * + * @param graphEditor a {@link GraphEditable} object + */ + public CutSubgraphAction(GraphEditable graphEditor) { + super("Cut Selected Items"); + + if (graphEditor == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.graphEditor = graphEditor; + } + + /** + * {@inheritDoc} + *

+ * Copies a parentally closed selection of session nodes in the frontmost session editor to the clipboard. + */ + public void actionPerformed(ActionEvent e) { + List modelComponents = this.graphEditor.getSelectedModelComponents(); + SubgraphSelection selection = new SubgraphSelection(modelComponents); + InternalClipboard.getInstance().setContents(selection, this); + graphEditor.getWorkbench().deleteSelectedObjects(); + } + + /** + * {@inheritDoc} + *

+ * Required by the AbstractAction interface; does nothing. + */ + public void lostOwnership(Clipboard clipboard, Transferable contents) { + } +} + + + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/DagEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/DagEditor.java index c0e14fb7ec..6ed64e831f 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/DagEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/DagEditor.java @@ -31,6 +31,7 @@ import edu.cmu.tetradapp.session.DelegatesEditing; import edu.cmu.tetradapp.ui.PaddingPanel; import edu.cmu.tetradapp.util.DesktopController; +import edu.cmu.tetradapp.util.GraphUtils; import edu.cmu.tetradapp.util.LayoutEditable; import edu.cmu.tetradapp.workbench.DisplayEdge; import edu.cmu.tetradapp.workbench.DisplayNode; @@ -451,16 +452,34 @@ private JMenu createEditMenu() { JMenu edit = new JMenu("Edit"); + JMenuItem cut = new JMenuItem(new CutSubgraphAction(this)); JMenuItem copy = new JMenuItem(new CopySubgraphAction(this)); JMenuItem paste = new JMenuItem(new PasteSubgraphAction(this)); + JMenuItem undoLast = new JMenuItem(new UndoLastAction(workbench)); + JMenuItem redoLast = new JMenuItem(new RedoLastAction(workbench)); + JMenuItem setToOriginal = new JMenuItem(new ResetGraph(workbench)); + cut.setAccelerator( + KeyStroke.getKeyStroke(KeyEvent.VK_X, InputEvent.CTRL_DOWN_MASK)); copy.setAccelerator( KeyStroke.getKeyStroke(KeyEvent.VK_C, InputEvent.CTRL_DOWN_MASK)); paste.setAccelerator( KeyStroke.getKeyStroke(KeyEvent.VK_V, InputEvent.CTRL_DOWN_MASK)); - + undoLast.setAccelerator( + KeyStroke.getKeyStroke(KeyEvent.VK_Z, InputEvent.CTRL_DOWN_MASK)); + redoLast.setAccelerator( + KeyStroke.getKeyStroke(KeyEvent.VK_Y, InputEvent.CTRL_DOWN_MASK)); + setToOriginal.setAccelerator( + KeyStroke.getKeyStroke(KeyEvent.VK_R, InputEvent.CTRL_DOWN_MASK)); + + edit.add(cut); edit.add(copy); edit.add(paste); + edit.addSeparator(); + + edit.add(undoLast); + edit.add(redoLast); + edit.add(setToOriginal); return edit; } @@ -475,14 +494,26 @@ private JMenu createGraphMenu() { graph.add(new GraphPropertiesAction(this.workbench)); graph.add(new PathsAction(this.workbench)); graph.add(new UnderliningsAction(this.workbench)); + graph.addSeparator(); - graph.add(new JMenuItem(new SelectDirectedAction(this.workbench))); - graph.add(new JMenuItem(new SelectBidirectedAction(this.workbench))); - graph.add(new JMenuItem(new SelectUndirectedAction(this.workbench))); - graph.add(new JMenuItem(new SelectTrianglesAction(this.workbench))); - graph.add(new JMenuItem(new SelectLatentsAction(this.workbench))); -// graph.add(new PagTypeSetter(getWorkbench())); - + graph.add(GraphUtils.getHighlightMenu(this.workbench)); + graph.add(GraphUtils.getCheckGraphMenu(this.workbench)); + +// JMenu revert = new JMenu("Revert Graph"); +// graph.add(revert); +// JMenuItem undoLast = new JMenuItem(new UndoLastAction(this.workbench)); +// JMenuItem redoLast = new JMenuItem(new RedoLastAction(this.workbench)); +// JMenuItem setToOriginal = new JMenuItem(new SetToOriginalAction(this.workbench)); +// revert.add(undoLast); +// revert.add(redoLast); +// revert.add(setToOriginal); + +// undoLast.setAccelerator( +// KeyStroke.getKeyStroke(KeyEvent.VK_Z, InputEvent.CTRL_DOWN_MASK)); +// redoLast.setAccelerator( +// KeyStroke.getKeyStroke(KeyEvent.VK_Y, InputEvent.CTRL_DOWN_MASK)); +// setToOriginal.setAccelerator( +// KeyStroke.getKeyStroke(KeyEvent.VK_O, InputEvent.CTRL_DOWN_MASK)); randomGraph.addActionListener(e -> { GraphParamsEditor editor = new GraphParamsEditor(); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphEditor.java index 1a5ae126fc..e49ad84cc3 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphEditor.java @@ -24,13 +24,13 @@ import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.search.IndependenceTest; import edu.cmu.tetrad.search.test.MsepTest; -import edu.cmu.tetrad.util.JOptionUtils; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.TetradSerializable; import edu.cmu.tetradapp.model.GraphWrapper; import edu.cmu.tetradapp.model.IndTestProducer; import edu.cmu.tetradapp.ui.PaddingPanel; import edu.cmu.tetradapp.util.DesktopController; +import edu.cmu.tetradapp.util.GraphUtils; import edu.cmu.tetradapp.util.LayoutEditable; import edu.cmu.tetradapp.workbench.DisplayEdge; import edu.cmu.tetradapp.workbench.DisplayNode; @@ -168,8 +168,7 @@ public void pasteSubsession(List sessionElements, Point upperLeft) { getWorkbench().deselectAll(); sessionElements.forEach(o -> { - if (o instanceof GraphNode) { - Node modelNode = (Node) o; + if (o instanceof GraphNode modelNode) { getWorkbench().selectNode(modelNode); } }); @@ -465,32 +464,6 @@ private JMenuBar createGraphMenuBar() { return menuBar; } - JMenuBar createGraphMenuBarNoEditing() { - JMenuBar menuBar = new JMenuBar(); - JMenu file = new JMenu("File"); - file.add(new SaveComponentImage(this.workbench, "Save Graph Image...")); - - menuBar.add(file); - - JMenu graph = new JMenu("Graph"); - - graph.add(new GraphPropertiesAction(this.workbench)); - graph.add(new PathsAction(this.workbench)); - graph.add(new UnderliningsAction(this.workbench)); - - graph.add(new JMenuItem(new SelectDirectedAction(this.workbench))); - graph.add(new JMenuItem(new SelectBidirectedAction(this.workbench))); - graph.add(new JMenuItem(new SelectUndirectedAction(this.workbench))); - graph.add(new JMenuItem(new SelectTrianglesAction(this.workbench))); - graph.add(new JMenuItem(new SelectLatentsAction(this.workbench))); -// graph.addSeparator(); - graph.add(new PagColorer(getWorkbench())); - - menuBar.add(graph); - - return menuBar; - } - /** * Creates the "file" menu, which allows the user to load, save, and post workbench models. @@ -500,16 +473,34 @@ JMenuBar createGraphMenuBarNoEditing() { private JMenu createEditMenu() { JMenu edit = new JMenu("Edit"); + JMenuItem cut = new JMenuItem(new CutSubgraphAction(this)); JMenuItem copy = new JMenuItem(new CopySubgraphAction(this)); JMenuItem paste = new JMenuItem(new PasteSubgraphAction(this)); + JMenuItem undoLast = new JMenuItem(new UndoLastAction(workbench)); + JMenuItem redoLast = new JMenuItem(new RedoLastAction(workbench)); + JMenuItem setToOriginal = new JMenuItem(new ResetGraph(workbench)); + cut.setAccelerator( + KeyStroke.getKeyStroke(KeyEvent.VK_X, InputEvent.CTRL_DOWN_MASK)); copy.setAccelerator( KeyStroke.getKeyStroke(KeyEvent.VK_C, InputEvent.CTRL_DOWN_MASK)); paste.setAccelerator( KeyStroke.getKeyStroke(KeyEvent.VK_V, InputEvent.CTRL_DOWN_MASK)); - + undoLast.setAccelerator( + KeyStroke.getKeyStroke(KeyEvent.VK_Z, InputEvent.CTRL_DOWN_MASK)); + redoLast.setAccelerator( + KeyStroke.getKeyStroke(KeyEvent.VK_Y, InputEvent.CTRL_DOWN_MASK)); + setToOriginal.setAccelerator( + KeyStroke.getKeyStroke(KeyEvent.VK_R, InputEvent.CTRL_DOWN_MASK)); + + edit.add(cut); edit.add(copy); edit.add(paste); + edit.addSeparator(); + + edit.add(undoLast); + edit.add(redoLast); + edit.add(setToOriginal); return edit; } @@ -525,28 +516,8 @@ private JMenu createGraphMenu() { graph.add(new GraphPropertiesAction(getWorkbench())); graph.add(new PathsAction(getWorkbench())); graph.add(new UnderliningsAction(getWorkbench())); - - graph.addSeparator(); - - JMenuItem correlateExogenous = new JMenuItem("Correlate Exogenous Variables"); - JMenuItem uncorrelateExogenous = new JMenuItem("Uncorrelate Exogenous Variables"); - graph.add(correlateExogenous); - graph.add(uncorrelateExogenous); graph.addSeparator(); - correlateExogenous.addActionListener(e -> { - correlateExogenousVariables(); - getWorkbench().invalidate(); - getWorkbench().repaint(); - }); - - uncorrelateExogenous.addActionListener(e -> { - uncorrelationExogenousVariables(); - getWorkbench().invalidate(); - getWorkbench().repaint(); - }); - - randomGraph.addActionListener(e -> { GraphParamsEditor editor = new GraphParamsEditor(); editor.setParams(this.parameters); @@ -583,12 +554,11 @@ public void internalFrameClosed(InternalFrameEvent e1) { }); }); - graph.add(new JMenuItem(new SelectDirectedAction(getWorkbench()))); - graph.add(new JMenuItem(new SelectBidirectedAction(getWorkbench()))); - graph.add(new JMenuItem(new SelectUndirectedAction(getWorkbench()))); - graph.add(new JMenuItem(new SelectTrianglesAction(getWorkbench()))); - graph.add(new JMenuItem(new SelectLatentsAction(getWorkbench()))); - graph.add(new PagColorer(getWorkbench())); + graph.add(GraphUtils.getHighlightMenu(this.workbench)); + graph.add(GraphUtils.getCheckGraphMenu(this.workbench)); + GraphUtils.addGraphManipItems(graph, this.workbench); + graph.addSeparator(); + graph.add(GraphUtils.addPagEdgeSpecializationsItems(this.workbench)); // Only show these menu options for graph that has interventional nodes - Zhou if (isHasInterventional()) { @@ -596,66 +566,11 @@ public void internalFrameClosed(InternalFrameEvent e1) { graph.add(new JMenuItem(new HideShowInterventionalAction(getWorkbench()))); } - graph.addSeparator(); - graph.add(new JMenuItem(new HideShowNoConnectionNodesAction(getWorkbench()))); +// graph.add(new JMenuItem(new HideShowNoConnectionNodesAction(getWorkbench()))); return graph; } - private void correlateExogenousVariables() { - Graph graph = getWorkbench().getGraph(); - - if (graph instanceof Dag) { - JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), - "Cannot add bidirected edges to DAG's."); - return; - } - - List nodes = graph.getNodes(); - - List exoNodes = new LinkedList<>(); - - for (Node node : nodes) { - if (graph.isExogenous(node)) { - exoNodes.add(node); - } - } - - for (int i = 0; i < exoNodes.size(); i++) { - - loop: - for (int j = i + 1; j < exoNodes.size(); j++) { - Node node1 = exoNodes.get(i); - Node node2 = exoNodes.get(j); - List edges = graph.getEdges(node1, node2); - - for (Edge edge : edges) { - if (Edges.isBidirectedEdge(edge)) { - continue loop; - } - } - - graph.addBidirectedEdge(node1, node2); - } - } - } - - private void uncorrelationExogenousVariables() { - Graph graph = getWorkbench().getGraph(); - - Set edges = graph.getEdges(); - - for (Edge edge : edges) { - if (Edges.isBidirectedEdge(edge)) { - try { - graph.removeEdge(edge); - } catch (Exception e) { - // Ignore. - } - } - } - } - /** * {@inheritDoc} */ diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphFileMenu.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphFileMenu.java index ad4893ba4d..a1df25daa8 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphFileMenu.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphFileMenu.java @@ -49,8 +49,8 @@ public GraphFileMenu(GraphEditable editable, JComponent comp, boolean saveOnly) JMenu load = new JMenu("Load..."); add(load); - load.add(new LoadGraph(editable, "XML...")); load.add(new LoadGraphTxt(editable, "Text...")); + load.add(new LoadGraph(editable, "XML...")); load.add(new LoadGraphJson(editable, "Json...")); load.add(new LoadGraphAmatCpdag(editable, "amat.cpdag...")); load.add(new LoadGraphAmatPag(editable, "amat.pag...")); @@ -59,8 +59,8 @@ public GraphFileMenu(GraphEditable editable, JComponent comp, boolean saveOnly) JMenu save = new JMenu("Save..."); add(save); - save.add(new SaveGraph(editable, "XML...", SaveGraph.Type.xml)); save.add(new SaveGraph(editable, "Text...", SaveGraph.Type.text)); + save.add(new SaveGraph(editable, "XML...", SaveGraph.Type.xml)); save.add(new SaveGraph(editable, "Json...", SaveGraph.Type.json)); save.add(new SaveGraph(editable, "R...", SaveGraph.Type.r)); save.add(new SaveGraph(editable, "Dot...", SaveGraph.Type.dot)); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphParamsEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphParamsEditor.java index 7bd4b4f981..e1a1420da2 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphParamsEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphParamsEditor.java @@ -91,10 +91,10 @@ public void setup() { tabs.add("MIM", randomMimEditor); tabs.add("Scale Free", randomScaleFreeEditor); - String type = this.params.getString("randomGraphType", "Uniform"); + String type = this.params.getString("randomGraphType", "Dag"); switch (type) { - case "Uniform": + case "Dag": tabs.setSelectedIndex(0); break; case "Mim": @@ -111,7 +111,7 @@ public void setup() { JTabbedPane pane = (JTabbedPane) changeEvent.getSource(); if (pane.getSelectedIndex() == 0) { - GraphParamsEditor.this.params.set("randomGraphType", "Uniform"); + GraphParamsEditor.this.params.set("randomGraphType", "Dag"); } else if (pane.getSelectedIndex() == 1) { GraphParamsEditor.this.params.set("randomGraphType", "Mim"); } else if (pane.getSelectedIndex() == 2) { diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphSelectionEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphSelectionEditor.java index 8cc9b31c06..8dbd3db304 100755 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphSelectionEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GraphSelectionEditor.java @@ -409,8 +409,8 @@ private void tabbedPaneGraphs(GraphSelectionWrapper wrapper) { private JMenu createSaveMenu(GraphEditable editable) { JMenu save = new JMenu("Save As"); - save.add(new SaveGraph(editable, "Graph XML...", SaveGraph.Type.xml)); save.add(new SaveGraph(editable, "Graph Text...", SaveGraph.Type.text)); + save.add(new SaveGraph(editable, "Graph XML...", SaveGraph.Type.xml)); save.add(new SaveGraph(editable, "Graph Json...", SaveGraph.Type.json)); save.add(new SaveGraph(editable, "R...", SaveGraph.Type.r)); save.add(new SaveGraph(editable, "Dot...", SaveGraph.Type.dot)); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/HideShowNoConnectionNodesAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/HideShowNoConnectionNodesAction.java index 389bfef3b0..fd6364868c 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/HideShowNoConnectionNodesAction.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/HideShowNoConnectionNodesAction.java @@ -35,7 +35,7 @@ public class HideShowNoConnectionNodesAction extends AbstractAction implements C * @param workbench a {@link edu.cmu.tetradapp.workbench.GraphWorkbench} object */ public HideShowNoConnectionNodesAction(GraphWorkbench workbench) { - super("Hide/Show No Connections Node"); + super("Hide/Show nodes with no connections"); if (workbench == null) { throw new NullPointerException("Desktop must not be null."); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/LoadGraphAmatPag.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/LoadGraphAmatPag.java index 68a2bb51a8..6c4394ec3e 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/LoadGraphAmatPag.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/LoadGraphAmatPag.java @@ -39,15 +39,17 @@ class LoadGraphAmatPag extends AbstractAction { /** - * The component whose image is to be saved. + * The {@code GraphEditable} variable represents an interface for graph editors. It is used to load a graph into a + * {@code GraphEditable} object. The variable is of type {@code GraphEditable} and is final, meaning it cannot be + * reassigned once initialized. */ private final GraphEditable graphEditable; /** - *

Constructor for LoadGraphPcalg.

+ * Loads a graph in the "amat.pag" format used by PCALG. * - * @param graphEditable a {@link GraphEditable} object - * @param title a {@link String} object + * @param graphEditable The GraphEditable object to load the graph into. + * @param title The title of the action. */ public LoadGraphAmatPag(GraphEditable graphEditable, String title) { super(title); @@ -59,6 +61,11 @@ public LoadGraphAmatPag(GraphEditable graphEditable, String title) { this.graphEditable = graphEditable; } + /** + * Returns a JFileChooser object with specific configurations. + * + * @return a JFileChooser object + */ private static JFileChooser getJFileChooser() { JFileChooser chooser = new JFileChooser(); String sessionSaveLocation = @@ -70,9 +77,9 @@ private static JFileChooser getJFileChooser() { } /** - * {@inheritDoc} - *

- * Performs the action of loading a session from a file. + * Performs an action in response to an event. + * + * @param e the ActionEvent that triggered the action */ public void actionPerformed(ActionEvent e) { JFileChooser chooser = LoadGraphAmatPag.getJFileChooser(); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PagColorer.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PagEdgeSpecialization.java similarity index 58% rename from tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PagColorer.java rename to tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PagEdgeSpecialization.java index f2a4d44f2e..35cef63c41 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PagColorer.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PagEdgeSpecialization.java @@ -21,29 +21,25 @@ package edu.cmu.tetradapp.editor; -import edu.cmu.tetrad.graph.EdgeListGraph; -import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.search.utils.GraphSearchUtils; -import edu.cmu.tetradapp.util.WatchedProcess; import edu.cmu.tetradapp.workbench.GraphWorkbench; import javax.swing.*; /** - * Colors a graph using the PAG coloring. Optionally checks to make sure it's legal PAG. + * Markos up a graph using the PAG edge specialization algorithm. * * @author josephramsey * @version $Id: $Id */ -public class PagColorer extends JCheckBoxMenuItem { +public class PagEdgeSpecialization extends JCheckBoxMenuItem { /** * Creates a new copy subsession action for the given desktop and clipboard. * * @param workbench a {@link edu.cmu.tetradapp.workbench.GraphWorkbench} object */ - public PagColorer(GraphWorkbench workbench) { - super("Add/Remove PAG Coloring"); + public PagEdgeSpecialization(GraphWorkbench workbench) { + super("Add/Remove PAG Specialization Markups"); if (workbench == null) { throw new NullPointerException("Desktop must not be null."); @@ -51,40 +47,11 @@ public PagColorer(GraphWorkbench workbench) { final GraphWorkbench _workbench = workbench; - _workbench.setDoPagColoring(workbench.isDoPagColoring()); - setSelected(workbench.isDoPagColoring()); + _workbench.markPagEdgeSpecializations(workbench.isPagEdgeSpecializationMarked()); + setSelected(workbench.isPagEdgeSpecializationMarked()); addItemListener(e -> { - _workbench.setDoPagColoring(isSelected()); - - if (isSelected()) { - int ret = JOptionPane.showConfirmDialog(workbench, - breakDown("Would you like to verify that this is a legal PAG?", 60), - "Legal PAG check", JOptionPane.YES_NO_OPTION, JOptionPane.WARNING_MESSAGE); - if (ret == JOptionPane.YES_OPTION) { - class MyWatchedProcess extends WatchedProcess { - @Override - public void watch() { - Graph graph = new EdgeListGraph(workbench.getGraph()); - - GraphSearchUtils.LegalPagRet legalPag = GraphSearchUtils.isLegalPag(graph); - String reason = breakDown(legalPag.getReason(), 60); - - if (!legalPag.isLegalPag()) { - JOptionPane.showMessageDialog(workbench, - "This is not a legal PAG--one reason is as follows:" + - "\n\n" + reason + ".", - "Legal PAG check", - JOptionPane.WARNING_MESSAGE); - } else { - JOptionPane.showMessageDialog(workbench, reason); - } - } - } - - new MyWatchedProcess(); - } - } + _workbench.markPagEdgeSpecializations(isSelected()); }); } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PagEdgeTypeInstructions.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PagEdgeTypeInstructions.java new file mode 100644 index 0000000000..632a1a4380 --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PagEdgeTypeInstructions.java @@ -0,0 +1,86 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.help.CSH; +import javax.help.HelpBroker; +import javax.help.HelpSet; +import javax.swing.*; +import java.awt.datatransfer.Clipboard; +import java.awt.datatransfer.ClipboardOwner; +import java.awt.datatransfer.Transferable; +import java.awt.event.ActionEvent; +import java.awt.event.ActionListener; +import java.net.URL; + +/** + * Represents an action to display PAG Edge Type Instructions in a GraphWorkbench. This class extends AbstractAction and + * implements ClipboardOwner. + */ +public class PagEdgeTypeInstructions extends AbstractAction implements ClipboardOwner { + + /** + * Represents an action to display PAG Edge Type Instructions in a GraphWorkbench. + */ + public PagEdgeTypeInstructions() { + super("PAG Edge Type Instructions"); + } + + /** + * Performs an action when an event occurs. + * + * @param e the event that triggered the action. + */ + public void actionPerformed(ActionEvent e) { + // Initialize helpSet + final String helpHS = "/docs/javahelp/TetradHelp.hs"; + + try { + URL url = this.getClass().getResource(helpHS); + HelpSet helpSet = new HelpSet(null, url); + helpSet.setHomeID("graph_edge_types"); + HelpBroker broker = helpSet.createHelpBroker(); + broker.setCurrentView("Index"); + ActionListener listener = new CSH.DisplayHelpFromSource(broker); + listener.actionPerformed(e); + } catch (Exception ee) { + System.out.println("HelpSet " + ee.getMessage()); + System.out.println("HelpSet " + helpHS + " not found"); + throw new IllegalArgumentException(); + } + + } + + /** + * Called when ownership of the clipboard contents is lost. + * + * @param clipboard the clipboard that lost ownership + * @param contents the contents that were lost + */ + public void lostOwnership(Clipboard clipboard, Transferable contents) { + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PasteSubgraphAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PasteSubgraphAction.java index be4afc5cb5..54f59f8856 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PasteSubgraphAction.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PasteSubgraphAction.java @@ -50,7 +50,7 @@ class PasteSubgraphAction extends AbstractAction implements ClipboardOwner { * @param graphEditor a {@link edu.cmu.tetradapp.editor.GraphEditable} object */ public PasteSubgraphAction(GraphEditable graphEditor) { - super("Paste Selected Graph"); + super("Paste Selected Items"); if (graphEditor == null) { throw new NullPointerException("Desktop must not be null."); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PathsAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PathsAction.java index b9136d1683..6ea7d02a37 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PathsAction.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PathsAction.java @@ -234,7 +234,7 @@ private void allDirectedPaths(Graph graph, JTextArea textArea, List nodes1 for (Node node1 : nodes1) { for (Node node2 : nodes2) { - List> paths = graph.paths().directedPathsFromTo(node1, node2, + List> paths = graph.paths().directedPaths(node1, node2, Preferences.userRoot().getInt("pathMaxLength", 3)); if (paths.isEmpty()) { @@ -261,7 +261,7 @@ private void allSemidirectedPaths(Graph graph, JTextArea textArea, List no for (Node node1 : nodes1) { for (Node node2 : nodes2) { - List> paths = graph.paths().semidirectedPathsFromTo(node1, node2, + List> paths = graph.paths().semidirectedPaths(node1, node2, Preferences.userRoot().getInt("pathMaxLength", 3)); if (paths.isEmpty()) { diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PickRandomDagInCpdagAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PickRandomDagInCpdagAction.java new file mode 100644 index 0000000000..466288461e --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PickRandomDagInCpdagAction.java @@ -0,0 +1,82 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphTransforms; +import edu.cmu.tetradapp.util.GraphUtils; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.event.ActionEvent; + +/** + * This class represents an action to convert a CPDAG (Completed Partially Directed Acyclic Graph) to a random DAG + * (Directed Acyclic Graph). + */ +public class PickRandomDagInCpdagAction extends AbstractAction { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * This class represents an action to convert a CPDAG (Completed Partially Directed Acyclic Graph) to a random DAG + * (Directed Acyclic Graph). + */ + public PickRandomDagInCpdagAction(GraphWorkbench workbench) { + super("Pick Random DAG in CPDAG"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * This method is called when the user performs an action to convert a CPDAG (Completed Partially Directed Acyclic + * Graph) to a random DAG (Directed Acyclic Graph). + * + * @param e the action event generated by the user's action + */ + public void actionPerformed(ActionEvent e) { + Graph graph = workbench.getGraph(); + + if (graph == null) { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "No graph to convert."); + return; + } + +// if (!graph.paths().isLegalMpdag()) { +// JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "I can only convert CPDAGs, or CPDAG with additional oriented edges, with Meek rules applied."); +// return; +// } + + graph = GraphTransforms.dagFromCpdag(graph); + workbench.setGraph(graph); + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PickRandomMagInPagAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PickRandomMagInPagAction.java new file mode 100644 index 0000000000..8b87fb0714 --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PickRandomMagInPagAction.java @@ -0,0 +1,79 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphTransforms; +import edu.cmu.tetradapp.util.GraphUtils; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.event.ActionEvent; + +/** + * The PickRandomMagInPagAction class represents an action to pick a random MAG (Maximal Ancestral Graph) in PAG + * (Partially Directed Acyclic Graph). + */ +public class PickRandomMagInPagAction extends AbstractAction { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * This class represents an action to pick a random MAG (Maximal Ancestral Graph) in PAG (Partially Directed Acyclic + * Graph). + * + * @param workbench the GraphWorkbench containing the target session editor (must not be null) + */ + public PickRandomMagInPagAction(GraphWorkbench workbench) { + super("Pick Random MAG in PAG"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * This method is called when the user performs an action to convert a CPDAG (Completed Partially Directed Acyclic + * Graph) to a random DAG (Directed Acyclic Graph). + * + * @param e the ActionEvent that triggered the action + */ + public void actionPerformed(ActionEvent e) { + Graph graph = workbench.getGraph(); + + if (graph == null) { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "No graph to convert."); + return; + } + + graph = GraphTransforms.magFromPag(graph); + workbench.setGraph(graph); + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PickZhangMagInPagAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PickZhangMagInPagAction.java new file mode 100644 index 0000000000..389a5bd9b4 --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/PickZhangMagInPagAction.java @@ -0,0 +1,80 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphTransforms; +import edu.cmu.tetradapp.util.GraphUtils; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.event.ActionEvent; + +/** + * This class represents an action to convert a CPDAG (Completed Partially Directed Acyclic Graph) + * to a random DAG (Directed Acyclic Graph). + */ +public class PickZhangMagInPagAction extends AbstractAction { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /***/ + public PickZhangMagInPagAction(GraphWorkbench workbench) { + super("Pick Zhang MAG in PAG"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * This method is called when the user performs an action to convert a CPDAG (Completed Partially Directed Acyclic Graph) + * to a random DAG (Directed Acyclic Graph). + * + * @param e the action event generated by the user's action + */ + public void actionPerformed(ActionEvent e) { + Graph graph = workbench.getGraph(); + + if (graph == null) { + JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "No graph to convert."); + return; + } + + // Commenting this out because the PAG algorithms are not always returning legal PAGs +// if (!graph.paths().isLegalPag()) { +// JOptionPane.showMessageDialog(GraphUtils.getContainingScrollPane(workbench), "I can only convert PAGs."); +// return; +// } + + graph = GraphTransforms.zhangMagFromPag(graph); + workbench.setGraph(graph); + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomDagScaleFreeEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomDagScaleFreeEditor.java index b23edee4af..aec48a04d9 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomDagScaleFreeEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomDagScaleFreeEditor.java @@ -132,7 +132,7 @@ public RandomDagScaleFreeEditor() { b1.add(b10); Box b11 = Box.createHorizontalBox(); - b11.add(new JLabel("Max # latent confounders:")); + b11.add(new JLabel("Number of additional latent confounders:")); b11.add(Box.createHorizontalGlue()); b11.add(this.numLatentsField); b1.add(b11); @@ -217,7 +217,7 @@ public int getNumLatents() { private void setNumLatents(int numLatentNodes) { if (numLatentNodes < 0) { throw new IllegalArgumentException( - "Max # latent confounders must be" + " >= 0: " + + "Number of additional latent confounders must be" + " >= 0: " + numLatentNodes); } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomGraphEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomGraphEditor.java index da15ea8743..950e7d6f76 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomGraphEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomGraphEditor.java @@ -46,7 +46,7 @@ class RandomGraphEditor extends JPanel { private final IntTextField maxIndegreeField; private final IntTextField maxOutdegreeField; private final IntTextField maxDegreeField; - private final JRadioButton chooseUniform; + // private final JRadioButton chooseUniform; private final JRadioButton chooseFixed; private final JComboBox connectedBox; private final IntTextField numTwoCyclesField; @@ -95,7 +95,7 @@ public RandomGraphEditor(Graph oldGraph, boolean cyclicAllowed, Parameters param int oldNumNodes = oldNumMeasured + oldNumLatents; if (oldNumNodes > 1 && oldNumMeasured == getNumMeasuredNodes() && - oldNumLatents == getNumLatents()) { + oldNumLatents == getNumLatents()) { setNumMeasuredNodes(oldNumMeasured); setNumLatents(oldNumLatents); setMaxEdges( @@ -108,21 +108,21 @@ public RandomGraphEditor(Graph oldGraph, boolean cyclicAllowed, Parameters param this.maxIndegreeField = new IntTextField(getMaxIndegree(), 4); this.maxOutdegreeField = new IntTextField(getMaxOutdegree(), 4); this.maxDegreeField = new IntTextField(getMaxDegree(), 4); - JRadioButton randomForward = new JRadioButton("Add random forward edges"); - this.chooseUniform = new JRadioButton("Draw uniformly from all such DAGs"); +// JRadioButton randomForward = new JRadioButton("Add random forward edges"); +// this.chooseUniform = new JRadioButton("Draw uniformly from all such DAGs"); this.chooseFixed = new JRadioButton("Guarantee maximum number of edges"); this.connectedBox = new JComboBox<>(new String[]{"No", "Yes"}); JComboBox addCyclesBox = new JComboBox<>(new String[]{"No", "Yes"}); this.numTwoCyclesField = new IntTextField(getMinNumCycles(), 4); this.minCycleLengthField = new IntTextField(getMinCycleLength(), 4); - ButtonGroup group = new ButtonGroup(); - group.add(randomForward); - group.add(this.chooseUniform); - group.add(this.chooseFixed); - randomForward.setSelected(isRandomForward()); - this.chooseUniform.setSelected(isUniformlySelected()); - this.chooseFixed.setSelected(isChooseFixed()); +// ButtonGroup group = new ButtonGroup(); +// group.add(randomForward); +//// group.add(this.chooseUniform); +// group.add(this.chooseFixed); +// randomForward.setSelected(true); +//// this.chooseUniform.setSelected(isUniformlySelected()); +// this.chooseFixed.setSelected(isChooseFixed()); // set up text and ties them to the parameters object being edited. this.numNodesField.setFilter((value, oldValue) -> { @@ -222,17 +222,17 @@ public RandomGraphEditor(Graph oldGraph, boolean cyclicAllowed, Parameters param connectedBox.setSelectedItem("No"); } - if (this.isUniformlySelected() || this.isChooseFixed()) { - maxIndegreeField.setEnabled(true); - maxOutdegreeField.setEnabled(true); - maxDegreeField.setEnabled(true); - connectedBox.setEnabled(true); - } else { - maxIndegreeField.setEnabled(false); - maxOutdegreeField.setEnabled(false); - maxDegreeField.setEnabled(false); - connectedBox.setEnabled(false); - } +// if (this.isUniformlySelected() || this.isChooseFixed()) { + maxIndegreeField.setEnabled(true); + maxOutdegreeField.setEnabled(true); + maxDegreeField.setEnabled(true); + connectedBox.setEnabled(true); +// } else { +// maxIndegreeField.setEnabled(false); +// maxOutdegreeField.setEnabled(false); +// maxDegreeField.setEnabled(false); +// connectedBox.setEnabled(false); +// } minCycleLengthField.setEnabled(this.isAddCycles()); @@ -253,44 +253,44 @@ public RandomGraphEditor(Graph oldGraph, boolean cyclicAllowed, Parameters param maxEdgesField.setValue(RandomGraphEditor.this.getMaxEdges()); }); - randomForward.addActionListener(e -> { - JRadioButton button = (JRadioButton) e.getSource(); - button.setSelected(true); - RandomGraphEditor.this.setRandomForward(true); - RandomGraphEditor.this.setUniformlySelected(false); - RandomGraphEditor.this.setChooseFixed(false); - - maxIndegreeField.setEnabled(true); - maxOutdegreeField.setEnabled(true); - maxDegreeField.setEnabled(true); - connectedBox.setEnabled(true); - }); - - chooseUniform.addActionListener(e -> { - JRadioButton button = (JRadioButton) e.getSource(); - button.setSelected(true); - RandomGraphEditor.this.setRandomForward(false); - RandomGraphEditor.this.setUniformlySelected(true); - RandomGraphEditor.this.setChooseFixed(false); - - maxIndegreeField.setEnabled(true); - maxOutdegreeField.setEnabled(true); - maxDegreeField.setEnabled(true); - connectedBox.setEnabled(true); - }); - - chooseFixed.addActionListener(e -> { - JRadioButton button = (JRadioButton) e.getSource(); - button.setSelected(true); - RandomGraphEditor.this.setRandomForward(false); - RandomGraphEditor.this.setUniformlySelected(false); - RandomGraphEditor.this.setChooseFixed(true); - - maxIndegreeField.setEnabled(false); - maxOutdegreeField.setEnabled(false); - maxDegreeField.setEnabled(false); - connectedBox.setEnabled(false); - }); +// randomForward.addActionListener(e -> { +// JRadioButton button = (JRadioButton) e.getSource(); +// button.setSelected(true); + RandomGraphEditor.this.setRandomForward(true); +// RandomGraphEditor.this.setUniformlySelected(false); +// RandomGraphEditor.this.setChooseFixed(false); + + maxIndegreeField.setEnabled(true); + maxOutdegreeField.setEnabled(true); + maxDegreeField.setEnabled(true); + connectedBox.setEnabled(true); +// }); + +// chooseUniform.addActionListener(e -> { +// JRadioButton button = (JRadioButton) e.getSource(); +// button.setSelected(true); +// RandomGraphEditor.this.setRandomForward(false); +// RandomGraphEditor.this.setUniformlySelected(true); +// RandomGraphEditor.this.setChooseFixed(false); +// +// maxIndegreeField.setEnabled(true); +// maxOutdegreeField.setEnabled(true); +// maxDegreeField.setEnabled(true); +// connectedBox.setEnabled(true); +// }); + +// chooseFixed.addActionListener(e -> { +// JRadioButton button = (JRadioButton) e.getSource(); +// button.setSelected(true); +// RandomGraphEditor.this.setRandomForward(false); +// RandomGraphEditor.this.setUniformlySelected(false); +// RandomGraphEditor.this.setChooseFixed(true); +// +// maxIndegreeField.setEnabled(false); +// maxOutdegreeField.setEnabled(false); +// maxDegreeField.setEnabled(false); +// connectedBox.setEnabled(false); +// }); if (this.isAddCycles()) { addCyclesBox.setSelectedItem("Yes"); @@ -361,7 +361,7 @@ public RandomGraphEditor(Graph oldGraph, boolean cyclicAllowed, Parameters param b1.add(b10); Box b11 = Box.createHorizontalBox(); - b11.add(new JLabel("Max # latent confounders:")); + b11.add(new JLabel("Number of additional latent confounders:")); b11.add(Box.createHorizontalStrut(25)); b11.add(Box.createHorizontalGlue()); b11.add(this.numLatentsField); @@ -369,7 +369,7 @@ public RandomGraphEditor(Graph oldGraph, boolean cyclicAllowed, Parameters param b1.add(Box.createVerticalStrut(5)); Box b12 = Box.createHorizontalBox(); - b12.add(new JLabel("Maximum number of edges:")); + b12.add(new JLabel("Number of edges:")); b12.add(Box.createHorizontalGlue()); b12.add(this.maxEdgesField); b1.add(b12); @@ -401,20 +401,20 @@ public RandomGraphEditor(Graph oldGraph, boolean cyclicAllowed, Parameters param b1.add(b16); b1.add(Box.createVerticalStrut(5)); - Box b17a = Box.createHorizontalBox(); - b17a.add(randomForward); - b17a.add(Box.createHorizontalGlue()); - b1.add(b17a); +// Box b17a = Box.createHorizontalBox(); +// b17a.add(randomForward); +// b17a.add(Box.createHorizontalGlue()); +// b1.add(b17a); - Box b17 = Box.createHorizontalBox(); - b17.add(this.chooseUniform); - b17.add(Box.createHorizontalGlue()); - b1.add(b17); +// Box b17 = Box.createHorizontalBox(); +// b17.add(this.chooseUniform); +// b17.add(Box.createHorizontalGlue()); +// b1.add(b17); - Box b18 = Box.createHorizontalBox(); - b18.add(this.chooseFixed); - b18.add(Box.createHorizontalGlue()); - b1.add(b18); +// Box b18 = Box.createHorizontalBox(); +// b18.add(this.chooseFixed); +// b18.add(Box.createHorizontalGlue()); +// b1.add(b18); Box d = Box.createVerticalBox(); b1.setBorder(new TitledBorder("")); @@ -460,7 +460,7 @@ public void setEnabled(boolean enabled) { this.maxOutdegreeField.setEnabled(false); this.maxDegreeField.setEnabled(false); this.connectedBox.setEnabled(false); - this.chooseUniform.setEnabled(enabled); +// this.chooseUniform.setEnabled(enabled); this.chooseFixed.setEnabled(enabled); } else { this.numNodesField.setEnabled(enabled); @@ -470,7 +470,7 @@ public void setEnabled(boolean enabled) { this.maxOutdegreeField.setEnabled(enabled); this.maxDegreeField.setEnabled(enabled); this.connectedBox.setEnabled(enabled); - this.chooseUniform.setEnabled(enabled); +// this.chooseUniform.setEnabled(enabled); this.chooseFixed.setEnabled(enabled); } } @@ -488,18 +488,18 @@ private void setRandomForward(boolean randomFoward) { this.parameters.set("graphRandomFoward", randomFoward); } - /** - *

isUniformlySelected.

- * - * @return a boolean - */ - public boolean isUniformlySelected() { - return this.parameters.getBoolean("graphUniformlySelected", true); - } +// /** +// *

isUniformlySelected.

+// * +// * @return a boolean +// */ +// public boolean isUniformlySelected() { +// return this.parameters.getBoolean("graphUniformlySelected", true); +// } - private void setUniformlySelected(boolean uniformlySelected) { - this.parameters.set("graphUniformlySelected", uniformlySelected); - } +// private void setUniformlySelected(boolean uniformlySelected) { +// this.parameters.set("graphUniformlySelected", uniformlySelected); +// } /** *

isChooseFixed.

@@ -551,8 +551,8 @@ public int getNumLatents() { private void setNumLatents(int numLatentNodes) { if (numLatentNodes < 0) { throw new IllegalArgumentException( - "Max # latent confounders must be" + " >= 0: " + - numLatentNodes); + "Number of additional latent confounders must be" + " >= 0: " + + numLatentNodes); } this.parameters.set("newGraphNumLatents", numLatentNodes); @@ -589,7 +589,7 @@ private void setMaxEdges(int numEdges) { * @return a int */ public int getMaxDegree() { - return this.parameters.getInt("randomGraphMaxDegree", 6); + return this.parameters.getInt("randomGraphMaxDegree", 100); } private void setMaxDegree(int maxDegree) { @@ -612,7 +612,7 @@ private void setMaxDegree(int maxDegree) { * @return a int */ public int getMaxIndegree() { - return this.parameters.getInt("randomGraphMaxIndegree", 3); + return this.parameters.getInt("randomGraphMaxIndegree", 100); } private void setMaxIndegree(int maxIndegree) { @@ -635,7 +635,7 @@ private void setMaxIndegree(int maxIndegree) { * @return a int */ public int getMaxOutdegree() { - return this.parameters.getInt("randomGraphMaxOutdegree", 3); + return this.parameters.getInt("randomGraphMaxOutdegree", 100); } private void setMaxOutdegree(int maxOutDegree) { diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RedoLastAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RedoLastAction.java new file mode 100644 index 0000000000..cb27c15360 --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RedoLastAction.java @@ -0,0 +1,77 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.datatransfer.Clipboard; +import java.awt.datatransfer.ClipboardOwner; +import java.awt.datatransfer.Transferable; +import java.awt.event.ActionEvent; + +/** + * Represents an action to redo the last graph change in a GraphWorkbench. Extends AbstractAction and implements + * ClipboardOwner. + */ +public class RedoLastAction extends AbstractAction implements ClipboardOwner { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * Represents an action to undo the last graph change in a GraphWorkbench. Extends AbstractAction and implements + * ClipboardOwner. + */ + public RedoLastAction(GraphWorkbench workbench) { + super("Redo Last Graph Change"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * Performs an action when an event occurs. + * + * @param e the event that triggered the action. + */ + public void actionPerformed(ActionEvent e) { + this.workbench.redo(); + } + + /** + * Called when ownership of the clipboard contents is lost. + * + * @param clipboard the clipboard that lost ownership + * @param contents the contents that were lost + */ + public void lostOwnership(Clipboard clipboard, Transferable contents) { + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/ResetGraph.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/ResetGraph.java new file mode 100644 index 0000000000..2ca1ce64ba --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/ResetGraph.java @@ -0,0 +1,80 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.datatransfer.Clipboard; +import java.awt.datatransfer.ClipboardOwner; +import java.awt.datatransfer.Transferable; +import java.awt.event.ActionEvent; + +/** + * This class represents an action to reset a graph to its original state in a GraphWorkbench. It implements the + * ActionListener interface to respond to events triggered by clicking a button or selecting a menu option. It also + * implements the ClipboardOwner interface to handle clipboard ownership changes. + */ +public class ResetGraph extends AbstractAction implements ClipboardOwner { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * This class represents an action to reset a graph to its original state in a GraphWorkbench. It implements the + * ActionListener interface to respond to events triggered by clicking a button or selecting a menu option. It also + * implements the ClipboardOwner interface to handle clipboard ownership changes. + */ + public ResetGraph(GraphWorkbench workbench) { + super("Reset Graph"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * Performs an action when an event occurs. + * + * @param e the event that triggered the action. + */ + public void actionPerformed(ActionEvent e) { + this.workbench.deselectAll(); + this.workbench.setToOriginal(); + } + + /** + * Called when ownership of the clipboard contents is lost. + * + * @param clipboard the clipboard that lost ownership + * @param contents the contents that were lost + */ + public void lostOwnership(Clipboard clipboard, Transferable contents) { + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RevertToCpdag.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RevertToCpdag.java new file mode 100644 index 0000000000..2f50e1eabc --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RevertToCpdag.java @@ -0,0 +1,105 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.EdgeListGraph; +import edu.cmu.tetrad.graph.Edges; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.search.utils.MeekRules; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.datatransfer.Clipboard; +import java.awt.datatransfer.ClipboardOwner; +import java.awt.datatransfer.Transferable; +import java.awt.event.ActionEvent; + +/** + * Selects all directed edges in the given display graph. + * + * @author josephramsey + * @version $Id: $Id + */ +public class RevertToCpdag extends AbstractAction implements ClipboardOwner { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * Creates a new copy subsession action for the given desktop and clipboard. + * + * @param workbench the given workbench. + */ + public RevertToCpdag(GraphWorkbench workbench) { + super("Revert to CPDAG"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * {@inheritDoc} + *

+ * Selects all directed edges in the given display graph. + */ + public void actionPerformed(ActionEvent e) { + this.workbench.deselectAll(); + Graph graph = this.workbench.getGraph(); + + if (graph == null) { + JOptionPane.showMessageDialog(this.workbench, "No graph to run Meek rules on."); + return; + } + + // check to make sure the edges in the graph are all directed or undirected + for (Edge edge : graph.getEdges()) { + if (!Edges.isDirectedEdge(edge) && !Edges.isUndirectedEdge(edge)) { + JOptionPane.showMessageDialog(this.workbench, + "To revert to CPDAG, the graph must contain only directed or undirected edges."); + return; + } + } + + graph = new EdgeListGraph(graph); + MeekRules meekRules = new MeekRules(); + meekRules.setRevertToUnshieldedColliders(true); + meekRules.orientImplied(graph); + workbench.setGraph(graph); + } + + /** + * {@inheritDoc} + *

+ * Required by the AbstractAction interface; does nothing. + */ + public void lostOwnership(Clipboard clipboard, Transferable contents) { + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RevertToPag.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RevertToPag.java new file mode 100644 index 0000000000..e10c47ea0a --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RevertToPag.java @@ -0,0 +1,99 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.*; +import edu.cmu.tetrad.search.utils.DagToPag; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.datatransfer.Clipboard; +import java.awt.datatransfer.ClipboardOwner; +import java.awt.datatransfer.Transferable; +import java.awt.event.ActionEvent; + +/** + * Reverts the given graph to a PAG + * + * @author josephramsey + * @version $Id: $Id + */ +public class RevertToPag extends AbstractAction implements ClipboardOwner { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * Creates a new copy subsession action for the given desktop and clipboard. + * + * @param workbench the given workbench. + */ + public RevertToPag(GraphWorkbench workbench) { + super("Revert to PAG"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * Perform an action when an event occurs. + * + * @param e the action event + */ + public void actionPerformed(ActionEvent e) { + this.workbench.deselectAll(); + Graph graph = this.workbench.getGraph(); + + if (graph == null) { + JOptionPane.showMessageDialog(this.workbench, "No graph to run Meek rules on."); + return; + } + + workbench.setGraph(new DagToPag(graph).convert()); + +// if (graph.paths().isLegalDag() || graph.paths().isLegalCpdag() || graph.paths().isLegalMpdag()) { +// workbench.setGraph(new DagToPag(graph).convert()); +// } else if (graph.paths().isLegalMpag()) { +// workbench.setGraph(new DagToPag(graph).convert()); +// } else if (graph.paths().isLegalPag()) { +// JOptionPane.showMessageDialog(this.workbench, "Graph is already a PAG."); +// } else { +// JOptionPane.showMessageDialog(this.workbench, "Graph is not a legal DAG, CPDAG, MPDAG, MAG or PAG."); +// } + } + + /** + * {@inheritDoc} + *

+ * Required by the AbstractAction interface; does nothing. + */ + public void lostOwnership(Clipboard clipboard, Transferable contents) { + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SaveGraph.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SaveGraph.java index 4bf2e0710e..ec8fc61eb1 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SaveGraph.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SaveGraph.java @@ -202,7 +202,7 @@ public void actionPerformed(ActionEvent e) { // } // } else if (this.type == Type.amatCpdag) { - File file = EditorUtils.getSaveFile("graph", "amagpag.txt", parent, false, this.title); + File file = EditorUtils.getSaveFile("graph", "amat.cpag.txt", parent, false, this.title); if (file == null) { System.out.println("File was null."); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectCliquesAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectCliquesAction.java new file mode 100644 index 0000000000..28184baf72 --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectCliquesAction.java @@ -0,0 +1,133 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphUtils; +import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.datatransfer.Clipboard; +import java.awt.datatransfer.ClipboardOwner; +import java.awt.datatransfer.Transferable; +import java.awt.event.ActionEvent; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static edu.cmu.tetrad.graph.GraphUtils.maximalCliques; + +/** + * An action to highlight edges in node cliques in the GraphWorkbench of a certain minimum size (input by the user). + */ +public class SelectCliquesAction extends AbstractAction implements ClipboardOwner { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * Constructs a new SelectCliquesAction. + * + * @param workbench the GraphWorkbench to highlight the cliques in + * @throws NullPointerException if the workbench is null + */ + public SelectCliquesAction(GraphWorkbench workbench) { + super("Highlight Maximal Cliques"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * Performs the action of highlighting all edges in cliques in the given display graph. Inputs the minimum size of the + * cliques to highlight by popping up a dialog box. + * + * @param e the {@link ActionEvent} object + */ + public void actionPerformed(ActionEvent e) { + this.workbench.deselectAll(); + + final Graph graph = this.workbench.getGraph(); + + String s = JOptionPane.showInputDialog("Enter the minimum size of the (maximal) clique: "); + + int minSize ; + + while (true) { + if (s == null) { + return; + } + + try { + minSize = Integer.parseInt(s); + + if (minSize < 2) { + JOptionPane.showMessageDialog(this.workbench, "Cliques must have at least 2 nodes"); + } else { + break; + } + } catch (NumberFormatException ex) { + JOptionPane.showMessageDialog(this.workbench, "Please enter a valid integer."); + s = JOptionPane.showInputDialog("Enter the minimum size of the (maximal) clique: "); + } + } + + Set> cliques = GraphUtils.maximalCliques(graph, graph.getNodes()); + + for (Set clique : cliques) { + if (clique.size() < minSize) { + continue; + } + + for (Node n1 : clique) { + for (Node n2 : clique) { + if (n1 == n2) { + continue; + } + + if (graph.isAdjacentTo(n1, n2)) { + this.workbench.selectEdge(graph.getEdge(n1, n2)); + } + } + } + } + } + + /** + * Invoked when ownership of the clipboard contents is lost. + * + * @param clipboard the clipboard that lost the ownership + * @param contents the transferred contents that were previously on the clipboard + */ + public void lostOwnership(Clipboard clipboard, Transferable contents) { + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectEdgesInAlmostCyclicPaths.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectEdgesInAlmostCyclicPaths.java new file mode 100644 index 0000000000..9285d3127d --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectEdgesInAlmostCyclicPaths.java @@ -0,0 +1,143 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.Edges; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetradapp.workbench.DisplayEdge; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.*; +import java.awt.datatransfer.Clipboard; +import java.awt.datatransfer.ClipboardOwner; +import java.awt.datatransfer.Transferable; +import java.awt.event.ActionEvent; +import java.util.ArrayList; +import java.util.HashSet; + +/** + * Selects all directed edges in the given display graph. + * + * @author josephramsey + * @version $Id: $Id + */ +public class SelectEdgesInAlmostCyclicPaths extends AbstractAction implements ClipboardOwner { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * Creates a new copy subsession action for the given desktop and clipboard. + * + * @param workbench the given workbench. + */ + public SelectEdgesInAlmostCyclicPaths(GraphWorkbench workbench) { + super("Highlight Edges on Almost Cyclic Paths"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * {@inheritDoc} + *

+ * Selects all directed edges in the given display graph. + */ + public void actionPerformed(ActionEvent e) { + this.workbench.deselectAll(); + Graph graph = this.workbench.getGraph(); + + if (graph == null) { + JOptionPane.showMessageDialog(this.workbench, "No graph to check for almost cyclic paths."); + return; + } + + // Make a list of the bidirected edges in the graph. + java.util.List bidirectedEdges = new ArrayList<>(); + + for (Edge edge : graph.getEdges()) { + if (Edges.isBidirectedEdge(edge)) { + bidirectedEdges.add(edge); + } + } + + java.util.Set almostCyclicEdges = new HashSet<>(); + + for (Edge edge : bidirectedEdges) { + Node x = edge.getNode1(); + Node y = edge.getNode2(); + + { + java.util.List> directedPaths = graph.paths().directedPaths(x, y, 1000); + + for (java.util.List path : directedPaths) { + for (int i = 0; i < path.size() - 1; i++) { + Node node1 = path.get(i); + Node node2 = path.get(i + 1); + + Edge _edge = graph.getEdge(node1, node2); + almostCyclicEdges.add(_edge); + almostCyclicEdges.add(edge); + } + } + } + + { + java.util.List> directedPaths = graph.paths().directedPaths(y, x, 1000); + + for (java.util.List path : directedPaths) { + for (int i = 0; i < path.size() - 1; i++) { + Node node1 = path.get(i); + Node node2 = path.get(i + 1); + + Edge _edge = graph.getEdge(node1, node2); + almostCyclicEdges.add(_edge); + almostCyclicEdges.add(edge); + } + } + } + } + + for (Edge edge : almostCyclicEdges) { + this.workbench.selectEdge(edge); + } + } + + /** + * {@inheritDoc} + *

+ * Required by the AbstractAction interface; does nothing. + */ + public void lostOwnership(Clipboard clipboard, Transferable contents) { + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectEdgesInCyclicPaths.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectEdgesInCyclicPaths.java new file mode 100644 index 0000000000..65c1123bad --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectEdgesInCyclicPaths.java @@ -0,0 +1,106 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.Edges; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetradapp.workbench.DisplayEdge; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.*; +import java.awt.datatransfer.Clipboard; +import java.awt.datatransfer.ClipboardOwner; +import java.awt.datatransfer.Transferable; +import java.awt.event.ActionEvent; + +/** + * Selects all directed edges in the given display graph. + * + * @author josephramsey + * @version $Id: $Id + */ +public class SelectEdgesInCyclicPaths extends AbstractAction implements ClipboardOwner { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * Creates a new copy subsession action for the given desktop and clipboard. + * + * @param workbench the given workbench. + */ + public SelectEdgesInCyclicPaths(GraphWorkbench workbench) { + super("Highlight Edges on Cyclic Paths"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * {@inheritDoc} + *

+ * Selects all directed edges in the given display graph. + */ + public void actionPerformed(ActionEvent e) { + this.workbench.deselectAll(); + Graph graph = this.workbench.getGraph(); + + if (graph == null) { + JOptionPane.showMessageDialog(this.workbench, "No graph to check for cycles."); + return; + } + + for (Component comp : this.workbench.getComponents()) { + if (comp instanceof DisplayEdge) { + Edge edge = ((DisplayEdge) comp).getModelEdge(); + + if (Edges.isDirectedEdge(edge)) { + Node x = edge.getNode1(); + Node y = edge.getNode2(); + + if (graph.paths().existsDirectedPath(y, x)) { + this.workbench.selectEdge(edge); + } + } + } + } + } + + /** + * {@inheritDoc} + *

+ * Required by the AbstractAction interface; does nothing. + */ + public void lostOwnership(Clipboard clipboard, Transferable contents) { + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectLatentsAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectLatentsAction.java index e789306646..927a563925 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectLatentsAction.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectLatentsAction.java @@ -36,10 +36,8 @@ import java.awt.event.ActionEvent; /** - * Highlights all latent variables in the given display graph. - * - * @author josephramsey - * @version $Id: $Id + * The SelectLatentsAction class is an implementation of the AbstractAction class and ClipboardOwner interface. It + * provides functionality to highlight all latent variables in a given display graph. */ public class SelectLatentsAction extends AbstractAction implements ClipboardOwner { @@ -49,9 +47,10 @@ public class SelectLatentsAction extends AbstractAction implements ClipboardOwne private final GraphWorkbench workbench; /** - * Highlights all latent variables in the given display graph. + * The SelectLatentsAction class is an implementation of the AbstractAction class and ClipboardOwner interface. It + * provides functionality to highlight all latent variables in a given display graph. * - * @param workbench the given workbench. + * @param workbench the GraphWorkbench containing the target session editor (must not be null) */ public SelectLatentsAction(GraphWorkbench workbench) { super("Highlight Latent Nodes"); @@ -64,9 +63,9 @@ public SelectLatentsAction(GraphWorkbench workbench) { } /** - * {@inheritDoc} - *

- * Highlights all latent variables in the given display graph. + * This method is called when an action event occurs. It highlights all latent nodes and edges in the workbench. + * + * @param e the action event that triggered the method */ public void actionPerformed(ActionEvent e) { this.workbench.deselectAll(); @@ -93,9 +92,10 @@ public void actionPerformed(ActionEvent e) { } /** - * {@inheritDoc} - *

- * Required by the AbstractAction interface; does nothing. + * This method is called when the application no longer owns the contents of the clipboard. + * + * @param clipboard The clipboard that lost ownership of the contents + * @param contents The contents that were lost */ public void lostOwnership(Clipboard clipboard, Transferable contents) { } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectMeasuredNodesAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectMeasuredNodesAction.java new file mode 100644 index 0000000000..3a71a73649 --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectMeasuredNodesAction.java @@ -0,0 +1,103 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.graph.NodeType; +import edu.cmu.tetradapp.workbench.DisplayEdge; +import edu.cmu.tetradapp.workbench.DisplayNode; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.*; +import java.awt.datatransfer.Clipboard; +import java.awt.datatransfer.ClipboardOwner; +import java.awt.datatransfer.Transferable; +import java.awt.event.ActionEvent; + +/** + * The SelectMeasuredNodesAction class highlights all measured nodes and edges in a GraphWorkbench instance. + */ +public class SelectMeasuredNodesAction extends AbstractAction implements ClipboardOwner { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * Highlights all measured nodes and edges in the workbench. + * + * @param workbench the GraphWorkbench containing the target session editor (must not be null) + */ + public SelectMeasuredNodesAction(GraphWorkbench workbench) { + super("Highlight Measured Nodes"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * Selects all measured nodes and edges in the workbench. This method is called when an action occurs. + * + * @param e the action event + */ + public void actionPerformed(ActionEvent e) { + this.workbench.deselectAll(); + + for (Component comp : this.workbench.getComponents()) { + if (comp instanceof DisplayNode) { + Node node = ((DisplayNode) comp).getModelNode(); + if (node.getNodeType() == NodeType.MEASURED) { + this.workbench.selectNode(node); + } + } + } + + for (Component comp : this.workbench.getComponents()) { + if (comp instanceof DisplayEdge) { + Edge edge = ((DisplayEdge) comp).getModelEdge(); + + if (edge.getNode1().getNodeType() == NodeType.MEASURED + && edge.getNode2().getNodeType() == NodeType.MEASURED) { + this.workbench.selectEdge(edge); + } + } + } + } + + /** + * This method is called when ownership of the clipboard contents is lost. + * + * @param clipboard the clipboard that lost ownership (not null) + * @param contents the contents that were lost (not null) + */ + public void lostOwnership(Clipboard clipboard, Transferable contents) { + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectNondirectedAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectNondirectedAction.java new file mode 100644 index 0000000000..60b2d9aa1f --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectNondirectedAction.java @@ -0,0 +1,92 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.Edges; +import edu.cmu.tetradapp.workbench.DisplayEdge; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.*; +import java.awt.datatransfer.Clipboard; +import java.awt.datatransfer.ClipboardOwner; +import java.awt.datatransfer.Transferable; +import java.awt.event.ActionEvent; + +/** + * Selects all directed edges in the given display graph. + * + * @author josephramsey + * @version $Id: $Id + */ +public class SelectNondirectedAction extends AbstractAction implements ClipboardOwner { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * Creates a new copy subsession action for the given desktop and clipboard. + * + * @param workbench the given workbench. + */ + public SelectNondirectedAction(GraphWorkbench workbench) { + super("Highlight Nondirected Edges"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * {@inheritDoc} + *

+ * Selects all directed edges in the given display graph. + */ + public void actionPerformed(ActionEvent e) { + this.workbench.deselectAll(); + + for (Component comp : this.workbench.getComponents()) { + if (comp instanceof DisplayEdge) { + Edge edge = ((DisplayEdge) comp).getModelEdge(); + if (Edges.isNondirectedEdge(edge)) { + this.workbench.selectEdge(edge); + } + } + } + } + + /** + * {@inheritDoc} + *

+ * Required by the AbstractAction interface; does nothing. + */ + public void lostOwnership(Clipboard clipboard, Transferable contents) { + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectPartiallyOrientedAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectPartiallyOrientedAction.java new file mode 100644 index 0000000000..fc3b70fcaf --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SelectPartiallyOrientedAction.java @@ -0,0 +1,92 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.Edges; +import edu.cmu.tetradapp.workbench.DisplayEdge; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.*; +import java.awt.datatransfer.Clipboard; +import java.awt.datatransfer.ClipboardOwner; +import java.awt.datatransfer.Transferable; +import java.awt.event.ActionEvent; + +/** + * Selects all directed edges in the given display graph. + * + * @author josephramsey + * @version $Id: $Id + */ +public class SelectPartiallyOrientedAction extends AbstractAction implements ClipboardOwner { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * Creates a new copy subsession action for the given desktop and clipboard. + * + * @param workbench the given workbench. + */ + public SelectPartiallyOrientedAction(GraphWorkbench workbench) { + super("Highlight Partially Oriented Edges"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * {@inheritDoc} + *

+ * Selects all directed edges in the given display graph. + */ + public void actionPerformed(ActionEvent e) { + this.workbench.deselectAll(); + + for (Component comp : this.workbench.getComponents()) { + if (comp instanceof DisplayEdge) { + Edge edge = ((DisplayEdge) comp).getModelEdge(); + if (Edges.isPartiallyOrientedEdge(edge)) { + this.workbench.selectEdge(edge); + } + } + } + } + + /** + * {@inheritDoc} + *

+ * Required by the AbstractAction interface; does nothing. + */ + public void lostOwnership(Clipboard clipboard, Transferable contents) { + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SemGraphEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SemGraphEditor.java index 4d103f280f..ba93bd45bc 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SemGraphEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/SemGraphEditor.java @@ -24,7 +24,6 @@ import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.search.IndependenceTest; import edu.cmu.tetrad.search.test.MsepTest; -import edu.cmu.tetrad.util.JOptionUtils; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.TetradSerializable; import edu.cmu.tetradapp.model.IndTestProducer; @@ -32,6 +31,7 @@ import edu.cmu.tetradapp.session.DelegatesEditing; import edu.cmu.tetradapp.ui.PaddingPanel; import edu.cmu.tetradapp.util.DesktopController; +import edu.cmu.tetradapp.util.GraphUtils; import edu.cmu.tetradapp.util.LayoutEditable; import edu.cmu.tetradapp.workbench.DisplayEdge; import edu.cmu.tetradapp.workbench.DisplayNode; @@ -50,6 +50,8 @@ import java.util.List; import java.util.*; +import static edu.cmu.tetradapp.util.GraphUtils.addGraphManipItems; + /** * Displays a workbench editing workbench area together with a toolbench for editing tetrad-style graphs. * @@ -276,6 +278,11 @@ private void initUI(SemGraphWrapper semGraphWrapper) { // Update the semGraphWrapper semGraphWrapper.setGraph(targetGraph); + + if (getWorkbench().getGraph() != targetGraph) { + getWorkbench().setGraph(targetGraph); + } + // Also need to update the UI // updateBootstrapTable(targetGraph); } @@ -313,37 +320,8 @@ private void initUI(SemGraphWrapper semGraphWrapper) { JLabel label = new JLabel("Double click variable/node to change name."); label.setFont(new Font("SansSerif", Font.PLAIN, 12)); - - // Info button added by Zhou to show edge types -// JButton infoBtn = new JButton(new ImageIcon(ImageUtils.getImage(this, "info.png"))); -// infoBtn.setBorder(new EmptyBorder(0, 0, 0, 0)); - - // Clock info button to show edge types instructions - Zhou -// infoBtn.addActionListener(new ActionListener() { -// @Override -// public void actionPerformed(ActionEvent e) { -// // Initialize helpSet -// final String helpHS = "/docs/javahelp/TetradHelp.hs"; -// -// try { -// URL url = this.getClass().getResource(helpHS); -// HelpSet helpSet = new HelpSet(null, url); -// -// helpSet.setHomeID("graph_edge_types"); -// HelpBroker broker = helpSet.createHelpBroker(); -// ActionListener listener = new CSH.DisplayHelpFromSource(broker); -// listener.actionPerformed(e); -// } catch (Exception ee) { -// System.out.println("HelpSet " + ee.getMessage()); -// System.out.println("HelpSet " + helpHS + " not found"); -// throw new IllegalArgumentException(); -// } -// } -// }); - instructionBox.add(label); instructionBox.add(Box.createHorizontalStrut(2)); -// instructionBox.add(infoBtn); // Add to topBox topBox.add(topGraphBox); @@ -351,10 +329,6 @@ private void initUI(SemGraphWrapper semGraphWrapper) { this.edgeTypeTable.setPreferredSize(new Dimension(820, 150)); -// //Use JSplitPane to allow resize the bottom box - Zhou -// JSplitPane splitPane = new JSplitPane(JSplitPane.VERTICAL_SPLIT, new PaddingPanel(topBox), new PaddingPanel(edgeTypeTable)); -// splitPane.setDividerLocation((int) (splitPane.getPreferredSize().getHeight() - 150)); - // Switching to tabbed pane because of resizing problems with the split pane... jdramsey 2021.08.25 JTabbedPane tabbedPane = new JTabbedPane(SwingConstants.RIGHT); tabbedPane.addTab("Graph", new PaddingPanel(topBox)); @@ -455,16 +429,34 @@ private JMenuBar createGraphMenuBar() { private JMenu createEditMenu() { JMenu edit = new JMenu("Edit"); + JMenuItem cut = new JMenuItem(new CutSubgraphAction(this)); JMenuItem copy = new JMenuItem(new CopySubgraphAction(this)); JMenuItem paste = new JMenuItem(new PasteSubgraphAction(this)); + JMenuItem undoLast = new JMenuItem(new UndoLastAction(workbench)); + JMenuItem redoLast = new JMenuItem(new RedoLastAction(workbench)); + JMenuItem setToOriginal = new JMenuItem(new ResetGraph(workbench)); + cut.setAccelerator( + KeyStroke.getKeyStroke(KeyEvent.VK_X, InputEvent.CTRL_DOWN_MASK)); copy.setAccelerator( KeyStroke.getKeyStroke(KeyEvent.VK_C, InputEvent.CTRL_DOWN_MASK)); paste.setAccelerator( KeyStroke.getKeyStroke(KeyEvent.VK_V, InputEvent.CTRL_DOWN_MASK)); - + undoLast.setAccelerator( + KeyStroke.getKeyStroke(KeyEvent.VK_Z, InputEvent.CTRL_DOWN_MASK)); + redoLast.setAccelerator( + KeyStroke.getKeyStroke(KeyEvent.VK_Y, InputEvent.CTRL_DOWN_MASK)); + setToOriginal.setAccelerator( + KeyStroke.getKeyStroke(KeyEvent.VK_R, InputEvent.CTRL_DOWN_MASK)); + + edit.add(cut); edit.add(copy); edit.add(paste); + edit.addSeparator(); + + edit.add(undoLast); + edit.add(redoLast); + edit.add(setToOriginal); return edit; } @@ -478,6 +470,7 @@ private JMenu createGraphMenu() { graph.add(new GraphPropertiesAction(getWorkbench())); graph.add(new PathsAction(getWorkbench())); + graph.add(new UnderliningsAction(this.workbench)); graph.addSeparator(); JMenuItem errorTerms = new JMenuItem(); @@ -502,26 +495,12 @@ private JMenu createGraphMenu() { graph.add(errorTerms); graph.addSeparator(); - JMenuItem correlateExogenous - = new JMenuItem("Correlate Exogenous Variables"); - JMenuItem uncorrelateExogenous - = new JMenuItem("Uncorrelate Exogenous Variables"); - graph.add(correlateExogenous); - graph.add(uncorrelateExogenous); + graph.add(GraphUtils.getHighlightMenu(this.workbench)); + graph.add(GraphUtils.getCheckGraphMenu(this.workbench)); + addGraphManipItems(graph, this.workbench); graph.addSeparator(); - correlateExogenous.addActionListener(e -> { - correlationExogenousVariables(); - getWorkbench().invalidate(); - getWorkbench().repaint(); - }); - - uncorrelateExogenous.addActionListener(e -> { - uncorrelateExogenousVariables(); - getWorkbench().invalidate(); - getWorkbench().repaint(); - }); - + graph.add(GraphUtils.addPagEdgeSpecializationsItems(this.workbench)); randomGraph.addActionListener(e -> { GraphParamsEditor editor = new GraphParamsEditor(); @@ -559,14 +538,6 @@ public void internalFrameClosed(InternalFrameEvent e1) { }); }); - graph.add(new JMenuItem(new SelectDirectedAction(this.workbench))); - graph.add(new JMenuItem(new SelectBidirectedAction(this.workbench))); - graph.add(new JMenuItem(new SelectUndirectedAction(getWorkbench()))); - graph.add(new JMenuItem(new SelectTrianglesAction(getWorkbench()))); - graph.add(new JMenuItem(new SelectUndirectedAction(getWorkbench()))); - graph.add(new JMenuItem(new SelectLatentsAction(this.workbench))); - graph.add(new PagColorer(getWorkbench())); - return graph; } @@ -574,59 +545,6 @@ private SemGraph getSemGraph() { return (SemGraph) this.semGraphWrapper.getGraph(); } - private void correlationExogenousVariables() { - Graph graph = getWorkbench().getGraph(); - - if (graph instanceof Dag) { - JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), - "Cannot add bidirected edges to DAG's."); - return; - } - - List nodes = graph.getNodes(); - - List exoNodes = new LinkedList<>(); - - for (Node node : nodes) { - if (graph.isExogenous(node)) { - exoNodes.add(node); - } - } - - for (int i = 0; i < exoNodes.size(); i++) { - - loop: - for (int j = i + 1; j < exoNodes.size(); j++) { - Node node1 = exoNodes.get(i); - Node node2 = exoNodes.get(j); - List edges = graph.getEdges(node1, node2); - - for (Edge edge : edges) { - if (Edges.isBidirectedEdge(edge)) { - continue loop; - } - } - - graph.addBidirectedEdge(node1, node2); - } - } - } - - private void uncorrelateExogenousVariables() { - Graph graph = getWorkbench().getGraph(); - - Set edges = graph.getEdges(); - - for (Edge edge : edges) { - if (Edges.isBidirectedEdge(edge)) { - try { - graph.removeEdge(edge); - } catch (Exception ignored) { - } - } - } - } - /** * {@inheritDoc} */ diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/UndoLastAction.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/UndoLastAction.java new file mode 100644 index 0000000000..a947eb1fea --- /dev/null +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/UndoLastAction.java @@ -0,0 +1,78 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetradapp.editor; + +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetradapp.workbench.GraphWorkbench; + +import javax.swing.*; +import java.awt.datatransfer.Clipboard; +import java.awt.datatransfer.ClipboardOwner; +import java.awt.datatransfer.Transferable; +import java.awt.event.ActionEvent; + +/** + * Represents an action to undo the last graph change in a GraphWorkbench. Extends AbstractAction and implements + * ClipboardOwner. + */ +public class UndoLastAction extends AbstractAction implements ClipboardOwner { + + /** + * The desktop containing the target session editor. + */ + private final GraphWorkbench workbench; + + /** + * Represents an action to undo the last graph change in a GraphWorkbench. + * Extends AbstractAction and implements ClipboardOwner. + */ + public UndoLastAction(GraphWorkbench workbench) { + super("Undo Last Graph Change"); + + if (workbench == null) { + throw new NullPointerException("Desktop must not be null."); + } + + this.workbench = workbench; + } + + /** + * Performs an action when an event occurs. + * + * @param e the event that triggered the action. + */ + public void actionPerformed(ActionEvent e) { + this.workbench.undo(); + } + + /** + * Called when ownership of the clipboard contents is lost. + * + * @param clipboard the clipboard that lost ownership + * @param contents the contents that were lost + */ + public void lostOwnership(Clipboard clipboard, Transferable contents) { + } +} + + + diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/search/GraphCard.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/search/GraphCard.java index ad27d4b3a7..5edf4ed30c 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/search/GraphCard.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/search/GraphCard.java @@ -24,6 +24,7 @@ import edu.cmu.tetradapp.editor.*; import edu.cmu.tetradapp.model.GeneralAlgorithmRunner; import edu.cmu.tetradapp.ui.PaddingPanel; +import edu.cmu.tetradapp.util.GraphUtils; import edu.cmu.tetradapp.util.ImageUtils; import edu.cmu.tetradapp.workbench.GraphWorkbench; @@ -125,13 +126,14 @@ JMenuBar menuBar() { graph.add(new GraphPropertiesAction(this.workbench)); graph.add(new PathsAction(this.workbench)); graph.add(new UnderliningsAction(this.workbench)); + graph.addSeparator(); - graph.add(new JMenuItem(new SelectDirectedAction(this.workbench))); - graph.add(new JMenuItem(new SelectBidirectedAction(this.workbench))); - graph.add(new JMenuItem(new SelectUndirectedAction(this.workbench))); - graph.add(new JMenuItem(new SelectTrianglesAction(this.workbench))); - graph.add(new JMenuItem(new SelectLatentsAction(this.workbench))); - graph.add(new PagColorer(this.workbench)); + graph.add(GraphUtils.getHighlightMenu(this.workbench)); + graph.add(GraphUtils.getCheckGraphMenu(this.workbench)); +// addGraphManipItems(graph, this.workbench); + graph.addSeparator(); + + graph.add(GraphUtils.addPagEdgeSpecializationsItems(this.workbench)); menuBar.add(graph); @@ -151,9 +153,9 @@ private JPanel createGraphPanel(Graph graph) { graphWorkbench.setKnowledge(knowledge); graphWorkbench.enableEditing(false); - // If the algorithm is a latent variable algorithm, then set the graph workbench to do PAG coloring. + // If the algorithm is a latent variable algorithm, then set the graph workbench to do PAG edge specialization markups. // This is to show the edge types in the graph. - jdramsey 2024/03/13 - graphWorkbench.setDoPagColoring(GraphSearchUtils.isLatentVariableAlgorithmByAnnotation(this.algorithmRunner.getAlgorithm())); + graphWorkbench.markPagEdgeSpecializations(GraphSearchUtils.isLatentVariableAlgorithmByAnnotation(this.algorithmRunner.getAlgorithm())); this.workbench = graphWorkbench; diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/knowledge_editor/KnowledgeGraph.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/knowledge_editor/KnowledgeGraph.java index d4fcd613ae..772c26fd17 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/knowledge_editor/KnowledgeGraph.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/knowledge_editor/KnowledgeGraph.java @@ -43,7 +43,14 @@ public class KnowledgeGraph implements Graph, TetradSerializableExcluded { private static final long serialVersionUID = 23L; /** - * @serial + * Represents a graph data structure. + *

+ * The graph can be of any type, allowing different implementations of the graph interface. In this case, the + * {@link EdgeListGraph} implementation is used. + *

+ * The graph variable is marked as private and final to restrict external modifications. + * + * @see EdgeListGraph */ private final Graph graph = new EdgeListGraph(); @@ -65,17 +72,17 @@ public class KnowledgeGraph implements Graph, TetradSerializableExcluded { /** * The underline triples. */ - private Set underLineTriples; + private final Set underLineTriples = new HashSet<>(); /** * The dotted underline triples. */ - private Set dottedUnderLineTriples; + private final Set dottedUnderLineTriples = new HashSet<>(); /** * The ambiguous triples. */ - private Set ambiguousTriples; + private final Set ambiguousTriples = new HashSet<>(); //============================CONSTRUCTORS=============================// @@ -106,7 +113,10 @@ public static KnowledgeGraph serializableInstance() { //=============================PUBLIC METHODS==========================// /** - * {@inheritDoc} + * Transfer nodes and edges from the given graph to the current graph. + * + * @param graph the graph from which to transfer nodes and edges + * @throws IllegalArgumentException if the provided graph is null */ public final void transferNodesAndEdges(Graph graph) throws IllegalArgumentException { @@ -117,7 +127,10 @@ public final void transferNodesAndEdges(Graph graph) } /** - * {@inheritDoc} + * Transfers the attributes from the given graph to this graph. + * + * @param graph The graph from which the attribute values should be transferred. + * @throws IllegalArgumentException If the given graph is null. */ public final void transferAttributes(Graph graph) throws IllegalArgumentException { @@ -125,7 +138,9 @@ public final void transferAttributes(Graph graph) } /** - * {@inheritDoc} + * Returns the Paths object associated with this instance. + * + * @return the Paths object. */ @Override public Paths paths() { @@ -133,32 +148,39 @@ public Paths paths() { } /** - * {@inheritDoc} + * Checks whether the given Node is parameterizable. + * + * @param node The Node to check. + * @return true if the Node is parameterizable, false otherwise. */ public boolean isParameterizable(Node node) { return false; } /** - *

isTimeLagModel.

+ * Checks if the model is a time lag model. * - * @return a boolean + * @return true if the model is a time lag model, false otherwise. */ public boolean isTimeLagModel() { return false; } /** - *

getTimeLagGraph.

+ * Retrieves the TimeLagGraph object. * - * @return a {@link edu.cmu.tetrad.graph.TimeLagGraph} object + * @return The TimeLagGraph object. */ public TimeLagGraph getTimeLagGraph() { return null; } /** - * {@inheritDoc} + * Returns the set of nodes that form the separator set for the given two nodes in the graph. + * + * @param n1 the first node + * @param n2 the second node + * @return the set of nodes that form the separator set */ @Override public Set getSepset(Node n1, Node n2) { @@ -166,60 +188,79 @@ public Set getSepset(Node n1, Node n2) { } /** - *

getNodeNames.

+ * Retrieves the names of all the nodes in the graph * - * @return a {@link java.util.List} object + * @return The list of node names */ public List getNodeNames() { return getGraph().getNodeNames(); } /** - * {@inheritDoc} + * Connects the specified endpoint to all other endpoints in the graph. + * + * @param endpoint the endpoint to be fully connected */ public void fullyConnect(Endpoint endpoint) { getGraph().fullyConnect(endpoint); } /** - * {@inheritDoc} + * Reorients all endpoints in the graph with the specified endpoint. + * + * @param endpoint the endpoint to reorient all endpoints in the graph with */ public void reorientAllWith(Endpoint endpoint) { getGraph().reorientAllWith(endpoint); } /** - * {@inheritDoc} + * Returns a list of adjacent nodes to the given node in the graph. + * + * @param node the node for which to find adjacent nodes + * @return a list of adjacent nodes */ public List getAdjacentNodes(Node node) { return getGraph().getAdjacentNodes(node); } /** - * {@inheritDoc} + * Get the list of nodes in the graph that have an edge pointing into the given node and connected to the given + * endpoint. + * + * @param node The node for which to get the incoming nodes. + * @param endpoint The endpoint that connects the nodes. + * @return The list of nodes in the graph that have an edge pointing into the given node and connected to the given + * endpoint. */ public List getNodesInTo(Node node, Endpoint endpoint) { return getGraph().getNodesInTo(node, endpoint); } /** - * {@inheritDoc} + * Retrieves the list of nodes that have outgoing edges to the specified destination node. + * + * @param node the source node from which the edges originate + * @param n the destination endpoint node + * @return the list of nodes that have outgoing edges to the specified destination node */ public List getNodesOutTo(Node node, Endpoint n) { return getGraph().getNodesOutTo(node, n); } /** - *

getNodes.

+ * Retrieves the list of nodes in the graph. * - * @return a {@link java.util.List} object + * @return the list of nodes in the graph */ public List getNodes() { return getGraph().getNodes(); } /** - * {@inheritDoc} + * Sets the list of nodes in the graph. + * + * @param nodes the list of nodes to be set */ @Override public void setNodes(List nodes) { @@ -227,42 +268,66 @@ public void setNodes(List nodes) { } /** - * {@inheritDoc} + * Removes the edge between two nodes. + * + * @param node1 the first node + * @param node2 the second node + * @return true if the edge is successfully removed, false if the edge does not exist */ public boolean removeEdge(Node node1, Node node2) { return removeEdge(getEdge(node1, node2)); } /** - * {@inheritDoc} + * Removes the edges between two nodes in the graph. + * + * @param node1 the first node + * @param node2 the second node + * @return true if the edges are successfully removed, false otherwise */ public boolean removeEdges(Node node1, Node node2) { return getGraph().removeEdges(node1, node2); } /** - * {@inheritDoc} + * Checks if two nodes are adjacent in the graph. + * + * @param nodeX the first node to check adjacency + * @param nodeY the second node to check adjacency + * @return true if nodeX is adjacent to nodeY, otherwise false */ public boolean isAdjacentTo(Node nodeX, Node nodeY) { return getGraph().isAdjacentTo(nodeX, nodeY); } /** - * {@inheritDoc} + * Sets the endpoint of a given graph's edge between the specified nodes. + * + * @param node1 The starting node of the edge. + * @param node2 The ending node of the edge. + * @param endpoint The desired endpoint for the edge. + * @return true if the endpoint was successfully set, false otherwise. */ public boolean setEndpoint(Node node1, Node node2, Endpoint endpoint) { return getGraph().setEndpoint(node1, node2, endpoint); } /** - * {@inheritDoc} + * Retrieves the endpoint of a given pair of nodes in the graph. + * + * @param node1 the first node + * @param node2 the second node + * @return the endpoint of the nodes in the graph */ public Endpoint getEndpoint(Node node1, Node node2) { return getGraph().getEndpoint(node1, node2); } /** - * {@inheritDoc} + * Compares this KnowledgeGraph with the specified Object for equality. + * + * @param o the Object to be compared for equality + * @return true if the specified Object is equal to this KnowledgeGraph, false otherwise */ public boolean equals(Object o) { if (!(o instanceof KnowledgeGraph)) return false; @@ -270,49 +335,76 @@ public boolean equals(Object o) { } /** - * {@inheritDoc} + * Returns a subgraph of the graph, containing only the nodes specified in the input list. + * + * @param nodes the list of nodes to include in the subgraph + * @return a subgraph containing only the specified nodes */ public Graph subgraph(List nodes) { return getGraph().subgraph(nodes); } /** - * {@inheritDoc} + * Adds a directed edge from the source node to the destination node. + * + * @param nodeA the source node + * @param nodeB the destination node + * @return true if the directed edge is successfully added, false otherwise */ public boolean addDirectedEdge(Node nodeA, Node nodeB) { throw new UnsupportedOperationException(); } /** - * {@inheritDoc} + * Adds an undirected edge between two nodes. + * + * @param nodeA the first node to connect + * @param nodeB the second node to connect + * @return {@code true} if the edge between the two nodes is successfully added, {@code false} otherwise + * @throws UnsupportedOperationException if the method is called on an unsupported operation */ public boolean addUndirectedEdge(Node nodeA, Node nodeB) { throw new UnsupportedOperationException(); } /** - * {@inheritDoc} + * Adds a nondirected edge between two nodes. + * + * @param nodeA the first node + * @param nodeB the second node + * @return true if the edge was successfully added, false otherwise */ public boolean addNondirectedEdge(Node nodeA, Node nodeB) { throw new UnsupportedOperationException(); } /** - * {@inheritDoc} + * Adds a partially oriented edge between {@code nodeA} and {@code nodeB}. + * + * @param nodeA the origin node of the partially oriented edge + * @param nodeB the destination node of the partially oriented edge + * @return {@code true} if the partially oriented edge was added successfully, otherwise {@code false} */ public boolean addPartiallyOrientedEdge(Node nodeA, Node nodeB) { throw new UnsupportedOperationException(); } /** - * {@inheritDoc} + * Adds a bidirectional edge between two nodes. + * + * @param nodeA the first node + * @param nodeB the second node + * @return true if the bidirectional edge is added successfully, false otherwise */ public boolean addBidirectedEdge(Node nodeA, Node nodeB) { throw new UnsupportedOperationException(); } /** - * {@inheritDoc} + * Adds the specified edge to the graph. + * + * @param edge the edge to be added to the graph + * @return true if the edge is successfully added, false otherwise */ public boolean addEdge(Edge edge) { if (!(edge instanceof KnowledgeModelEdge _edge)) { @@ -352,90 +444,119 @@ public boolean addEdge(Edge edge) { } /** - * {@inheritDoc} + * Adds a node to the graph. + * + * @param node the node to be added + * @return true if the node was added successfully, false otherwise */ public boolean addNode(Node node) { return getGraph().addNode(node); } /** - * {@inheritDoc} + * Adds a PropertyChangeListener to the Graph. The PropertyChangeListener will be notified of any changes to the + * properties of the Graph. + * + * @param l the PropertyChangeListener to be added */ public void addPropertyChangeListener(PropertyChangeListener l) { getGraph().addPropertyChangeListener(l); } /** - * {@inheritDoc} + * Checks if the graph contains the specified edge. + * + * @param edge the edge to check for + * @return {@code true} if the graph contains the edge, otherwise {@code false} */ public boolean containsEdge(Edge edge) { return getGraph().containsEdge(edge); } /** - * {@inheritDoc} + * Checks if a specific node is present in the graph. + * + * @param node The node to check for presence in the graph. + * @return {@code true} if the node is present in the graph, otherwise {@code false}. */ public boolean containsNode(Node node) { return getGraph().containsNode(node); } /** - *

getEdges.

+ * Returns the set of edges in the graph. * - * @return a {@link java.util.Set} object + * @return a Set of Edge objects representing the edges in the graph */ public Set getEdges() { return getGraph().getEdges(); } /** - * {@inheritDoc} + * Retrieves the list of edges connected to the given node in the graph. + * + * @param node the node for which to retrieve the edges + * @return the list of edges connected to the given node */ public List getEdges(Node node) { return getGraph().getEdges(node); } /** - * {@inheritDoc} + * Returns a list of edges between two nodes in the graph. + * + * @param node1 the first node + * @param node2 the second node + * @return a list of edges between node1 and node2 */ public List getEdges(Node node1, Node node2) { return getGraph().getEdges(node1, node2); } /** - * {@inheritDoc} + * Retrieves a node from the graph with the specified name. + * + * @param name the name of the node to retrieve + * @return the node with the specified name, or null if not found */ public Node getNode(String name) { return getGraph().getNode(name); } /** - *

getNumEdges.

+ * Returns the number of edges in the graph. * - * @return a int + * @return the number of edges in the graph. */ public int getNumEdges() { return getGraph().getNumEdges(); } /** - *

getNumNodes.

+ * Retrieves the number of nodes in the graph. * - * @return a int + * @return the number of nodes in the graph. */ public int getNumNodes() { return getGraph().getNumNodes(); } /** - * {@inheritDoc} + * Retrieves the number of edges for a given node in the graph. This method uses the getGraph() method to access the + * graph and uses the getNumEdges() method of the graph to retrieve the number of edges for the given node. + * + * @param node the node for which to retrieve the number of edges + * @return the number of edges for the given node in the graph */ public int getNumEdges(Node node) { return getGraph().getNumEdges(node); } /** - * {@inheritDoc} + * Removes an edge from the knowledge graph. + * + * @param edge the edge to be removed + * @return true if the edge was successfully removed, false otherwise */ public boolean removeEdge(Edge edge) { KnowledgeModelEdge _edge = (KnowledgeModelEdge) edge; @@ -464,7 +585,10 @@ public boolean removeEdge(Edge edge) { } /** - * {@inheritDoc} + * Removes a collection of edges from the graph. + * + * @param edges the collection of edges to be removed + * @return {@code true} if any edge is successfully removed, {@code false} otherwise */ public boolean removeEdges(Collection edges) { boolean removed = false; @@ -477,86 +601,122 @@ public boolean removeEdges(Collection edges) { } /** - * {@inheritDoc} + * Removes a given node from the graph. + * + * @param node the node to be removed + * @return true if the node was successfully removed, false otherwise */ public boolean removeNode(Node node) { return getGraph().removeNode(node); } /** - *

clear.

+ * Clears the graph by removing all its elements. */ public void clear() { getGraph().clear(); } /** - * {@inheritDoc} + * Removes the given nodes from the graph. + * + * @param nodes The list of nodes to be removed. + * @return True if the nodes were successfully removed, false otherwise. */ public boolean removeNodes(List nodes) { return getGraph().removeNodes(nodes); } /** - * {@inheritDoc} + * Checks if the given nodes form a default noncollider in the graph. + * + * @param node1 the first node in the potential noncollider + * @param node2 the second node in the potential noncollider + * @param node3 the third node in the potential noncollider + * @return true if the nodes form a default noncollider, false otherwise */ public boolean isDefNoncollider(Node node1, Node node2, Node node3) { return getGraph().isDefNoncollider(node1, node2, node3); } /** - * {@inheritDoc} + * Determines if there is a default collider between three nodes. + * + * @param node1 the first node + * @param node2 the second node + * @param node3 the third node + * @return true if there is a default collider, false otherwise */ public boolean isDefCollider(Node node1, Node node2, Node node3) { return getGraph().isDefCollider(node1, node2, node3); } /** - * {@inheritDoc} + * Returns a list of child nodes for the given node. + * + * @param node the node for which to retrieve the child nodes. + * @return a list of child nodes for the given node. */ public List getChildren(Node node) { return getGraph().getChildren(node); } /** - *

getDegree.

+ * Returns the degree of the graph. * - * @return a int + * @return the degree of the graph */ public int getDegree() { return getGraph().getDegree(); } /** - * {@inheritDoc} + * Retrieves the edge between two nodes in the graph. + * + * @param node1 the first node + * @param node2 the second node + * @return the edge between node1 and node2 */ public Edge getEdge(Node node1, Node node2) { return getGraph().getEdge(node1, node2); } /** - * {@inheritDoc} + * Returns the directed edge between two nodes. + * + * @param node1 the first node + * @param node2 the second node + * @return the directed edge between the two nodes */ public Edge getDirectedEdge(Node node1, Node node2) { return getGraph().getDirectedEdge(node1, node2); } /** - * {@inheritDoc} + * Returns the list of parent nodes for the given node. + * + * @param node The node for which parents need to be retrieved. + * @return The list of parent nodes for the given node. */ public List getParents(Node node) { return getGraph().getParents(node); } /** - * {@inheritDoc} + * Returns the indegree of the specified node in the graph. + * + * @param node the node to get the indegree for + * @return the indegree of the specified node */ public int getIndegree(Node node) { return getGraph().getIndegree(node); } /** - * {@inheritDoc} + * Retrieves the degree of the given node in the graph. + * + * @param node the node for which to retrieve the degree + * @return the degree of the specified node in the graph */ @Override public int getDegree(Node node) { @@ -564,57 +724,79 @@ public int getDegree(Node node) { } /** - * {@inheritDoc} + * Returns the outdegree of a given node in the graph. + * + * @param node The node for which to determine the outdegree. + * @return The outdegree of the given node. */ public int getOutdegree(Node node) { return getGraph().getOutdegree(node); } /** - * {@inheritDoc} + * Checks if a given Node is a child of another Node. + * + * @param node1 the Node to be checked + * @param node2 the potential parent Node + * @return true if node1 is a child of node2, false otherwise */ public boolean isChildOf(Node node1, Node node2) { return getGraph().isChildOf(node1, node2); } /** - * {@inheritDoc} + * Returns true if the first node is a parent of the second node in the graph. + * + * @param node1 The first node. + * @param node2 The second node. + * @return True if the first node is a parent of the second node, otherwise false. */ public boolean isParentOf(Node node1, Node node2) { return getGraph().isParentOf(node1, node2); } /** - * {@inheritDoc} + * Determines if a given node is exogenous. + * + * @param node the node to check + * @return true if the node is exogenous, false otherwise */ public boolean isExogenous(Node node) { return getGraph().isExogenous(node); } /** - *

toString.

+ * Returns a string representation of the object. The returned string is obtained by calling the toString method of + * the underlying graph object. * - * @return a {@link java.lang.String} object + * @return a string representation of the object. */ public String toString() { return getGraph().toString(); } /** - *

Getter for the field knowledge.

+ * Retrieves the knowledge object. * - * @return a {@link edu.cmu.tetrad.data.Knowledge} object + * @return The knowledge object. */ public Knowledge getKnowledge() { return this.knowledge; } + /** + * Retrieves the graph object. + * + * @return The graph object. + */ private Graph getGraph() { return this.graph; } /** - * {@inheritDoc} + * Retrieves all attributes stored in the object. + * + * @return A Map representing the attributes stored in the object. */ @Override public Map getAllAttributes() { @@ -622,7 +804,10 @@ public Map getAllAttributes() { } /** - * {@inheritDoc} + * Retrieves the value associated with the specified key from this object's attributes. + * + * @param key the key whose associated value is to be retrieved + * @return the value to which the specified key is mapped, or null if this object contains no mapping for the key */ @Override public Object getAttribute(String key) { @@ -630,7 +815,9 @@ public Object getAttribute(String key) { } /** - * {@inheritDoc} + * Removes the attribute with the specified key from the object. + * + * @param key the key associated with the attribute to be removed */ @Override public void removeAttribute(String key) { @@ -638,7 +825,10 @@ public void removeAttribute(String key) { } /** - * {@inheritDoc} + * Adds an attribute to the internal attribute map. + * + * @param key the key of the attribute + * @param value the value of the attribute */ @Override public void addAttribute(String key, Object value) { @@ -646,16 +836,18 @@ public void addAttribute(String key, Object value) { } /** - *

Getter for the field ambiguousTriples.

+ * Retrieves a set of ambiguous triples. * - * @return a {@link java.util.Set} object + * @return the set of ambiguous triples */ public Set getAmbiguousTriples() { return new HashSet<>(this.ambiguousTriples); } /** - * {@inheritDoc} + * Sets the ambiguous triples. + * + * @param triples - the set of triples to be set as ambiguous */ public void setAmbiguousTriples(Set triples) { this.ambiguousTriples.clear(); @@ -666,50 +858,64 @@ public void setAmbiguousTriples(Set triples) { } /** - *

getUnderLines.

+ * Retrieves the set of underlines. * - * @return a {@link java.util.Set} object + * @return the set of underlines as a new HashSet. */ public Set getUnderLines() { return new HashSet<>(this.underLineTriples); } /** - *

getDottedUnderlines.

+ * Returns a set of Triple objects representing the dotted underlines. * - * @return a {@link java.util.Set} object + * @return a set of Triple objects representing the dotted underlines */ public Set getDottedUnderlines() { return new HashSet<>(this.dottedUnderLineTriples); } /** - * {@inheritDoc} - *

- * States whether r-s-r is an underline triple or not. + * Determines if a triple of nodes is ambiguous. + * + * @param x the first node + * @param y the second node + * @param z the third node + * @return true if the triple is ambiguous, false otherwise */ public boolean isAmbiguousTriple(Node x, Node y, Node z) { return this.ambiguousTriples.contains(new Triple(x, y, z)); } /** - * {@inheritDoc} - *

- * States whether r-s-r is an underline triple or not. + * Checks if a given triple of nodes is an underline triple. + * + * @param x the first node in the triple + * @param y the second node in the triple + * @param z the third node in the triple + * @return true if the triple is an underline triple, false otherwise */ public boolean isUnderlineTriple(Node x, Node y, Node z) { return this.underLineTriples.contains(new Triple(x, y, z)); } /** - * {@inheritDoc} + * Adds an ambiguous triple to the collection. + * + * @param x - the first node of the triple + * @param y - the second node of the triple + * @param z - the third node of the triple */ public void addAmbiguousTriple(Node x, Node y, Node z) { this.ambiguousTriples.add(new Triple(x, y, z)); } /** - * {@inheritDoc} + * Adds the given triple to the collection of underline triples if it exists along a path in the current node. + * + * @param x The first node of the triple. + * @param y The second node of the triple. + * @param z The third node of the triple. */ public void addUnderlineTriple(Node x, Node y, Node z) { Triple triple = new Triple(x, y, z); @@ -722,7 +928,11 @@ public void addUnderlineTriple(Node x, Node y, Node z) { } /** - * {@inheritDoc} + * Adds a triple with dotted underline to the collection of dotted underline triples. + * + * @param x The first node of the triple. + * @param y The second node of the triple. + * @param z The third node of the triple. */ public void addDottedUnderlineTriple(Node x, Node y, Node z) { Triple triple = new Triple(x, y, z); @@ -735,28 +945,42 @@ public void addDottedUnderlineTriple(Node x, Node y, Node z) { } /** - * {@inheritDoc} + * Removes the specified triple from the list of ambiguous triples. + * + * @param x the first node of the triple to be removed + * @param y the second node of the triple to be removed + * @param z the third node of the triple to be removed */ public void removeAmbiguousTriple(Node x, Node y, Node z) { this.ambiguousTriples.remove(new Triple(x, y, z)); } /** - * {@inheritDoc} + * Removes an underline triple from the collection. + * + * @param x the first node of the triple to be removed + * @param y the second node of the triple to be removed + * @param z the third node of the triple to be removed */ public void removeUnderlineTriple(Node x, Node y, Node z) { this.underLineTriples.remove(new Triple(x, y, z)); } /** - * {@inheritDoc} + * Removes the specified triple (x, y, z) from the list of dotted underline triples. + * + * @param x The first node of the triple to be removed. + * @param y The second node of the triple to be removed. + * @param z The third node of the triple to be removed. */ public void removeDottedUnderlineTriple(Node x, Node y, Node z) { this.dottedUnderLineTriples.remove(new Triple(x, y, z)); } /** - * {@inheritDoc} + * Sets the underline triples. + * + * @param triples the set of triples to be set as underline triples */ public void setUnderLineTriples(Set triples) { this.underLineTriples.clear(); @@ -767,7 +991,9 @@ public void setUnderLineTriples(Set triples) { } /** - * {@inheritDoc} + * Clears the existing collection of dotted underlined triples and adds new triples to it. + * + * @param triples The collection of triples to add. */ public void setDottedUnderLineTriples(Set triples) { this.dottedUnderLineTriples.clear(); @@ -778,7 +1004,8 @@ public void setDottedUnderLineTriples(Set triples) { } /** - *

removeTriplesNotInGraph.

+ * Removes triples from the lists ("ambiguousTriples", "underLineTriples", and "dottedUnderLineTriples") that do not + * have all three nodes present in the graph or are not adjacent to each other. */ public void removeTriplesNotInGraph() { for (Triple triple : new HashSet<>(this.ambiguousTriples)) { diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/AlgcomparisonModel.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/AlgcomparisonModel.java index 91d4abba3b..13b75e8086 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/AlgcomparisonModel.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/AlgcomparisonModel.java @@ -33,6 +33,7 @@ import edu.cmu.tetrad.algcomparison.statistic.Statistics; import edu.cmu.tetrad.algcomparison.utils.TakesIndependenceWrapper; import edu.cmu.tetrad.algcomparison.utils.UsesScoreWrapper; +import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.util.ParamDescription; import edu.cmu.tetrad.util.ParamDescriptions; import edu.cmu.tetrad.util.Parameters; @@ -72,6 +73,11 @@ public class AlgcomparisonModel implements SessionModel { * The results path for the AlgcomparisonModel. */ private final String resultsRoot = System.getProperty("user.home"); + /** + * The suppliedGraph variable represents a graph that can be supplied by the user. + * This graph will be given as an option in the user interface. + */ + private Graph suppliedGraph = null; /** * The list of statistic names. */ @@ -130,6 +136,12 @@ public AlgcomparisonModel(Parameters parameters) { initializeIfNull(); } + public AlgcomparisonModel(GraphSource graphSource, Parameters parameters) { + this.parameters = new Parameters(); + this.suppliedGraph = graphSource.getGraph(); + initializeIfNull(); + } + /** * Finds and returns a list of algorithm classes that implement the Algorithm interface. * @@ -720,7 +732,7 @@ private boolean paramSetByUser(String columnName) { public List getLastStatisticsUsed() { String[] lastStatisticsUsed = Preferences.userRoot().get("lastAlgcomparisonStatisticsUsed", "").split(";"); List list = Arrays.asList(lastStatisticsUsed); - System.out.println("Getting last statistics used: " + list); +// System.out.println("Getting last statistics used: " + list); return list; } @@ -730,7 +742,7 @@ public void setLastStatisticsUsed(List lastStatisticsUsed) { sb.append(statistic.getAbbreviation()).append(";"); } - System.out.println("Setting last statistics used: " + sb); +// System.out.println("Setting last statistics used: " + sb); Preferences.userRoot().put("lastAlgcomparisonStatisticsUsed", sb.toString()); } @@ -800,6 +812,13 @@ public List getSelectedAlgorithmModels() { return new ArrayList<>(selectedAlgorithmModels); } + /** + * The user may supply a graph, which will be given as an option in the UI. + */ + public Graph getSuppliedGraph() { + return suppliedGraph; + } + public static class MyTableColumn { private final String columnName; private final String description; diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/CPDAGFitModel.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/CPDAGFitModel.java index 49173a5b32..24392e0d41 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/CPDAGFitModel.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/CPDAGFitModel.java @@ -139,7 +139,7 @@ public CPDAGFitModel(Simulation simulation, GeneralAlgorithmRunner algorithmRunn } catch (Exception e) { e.printStackTrace(); - Graph mag = GraphTransforms.pagToMag(graphs.get(0)); + Graph mag = GraphTransforms.zhangMagFromPag(graphs.get(0)); // Ricf.RicfResult result = estimatePag(dataSet, mag); SemGraph graph = new SemGraph(mag); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/CPDAGFromDagGraphWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/CPDAGFromDagGraphWrapper.java index bf1c680608..5b2ea14f14 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/CPDAGFromDagGraphWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/CPDAGFromDagGraphWrapper.java @@ -21,7 +21,6 @@ package edu.cmu.tetradapp.model; -import edu.cmu.tetrad.graph.Dag; import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.GraphTransforms; @@ -29,6 +28,8 @@ import edu.cmu.tetrad.util.TetradLogger; import edu.cmu.tetradapp.session.DoNotAddOldModel; +import javax.swing.*; + /** *

CpdagFromDagGraphWrapper class.

* @@ -58,11 +59,9 @@ public CPDAGFromDagGraphWrapper(GraphSource source, Parameters parameters) { public CPDAGFromDagGraphWrapper(Graph graph) { super(new EdgeListGraph()); - // make sure the given graph is a dag. - try { - new Dag(graph); - } catch (Exception e) { - throw new IllegalArgumentException("The source graph is not a DAG."); + if (!graph.paths().isLegalDag()) { + JOptionPane.showMessageDialog(null, "The source graph is not a DAG.", + null, JOptionPane.WARNING_MESSAGE); } Graph cpdag = CPDAGFromDagGraphWrapper.getCpdag(new EdgeListGraph(graph)); @@ -85,7 +84,7 @@ public static CPDAGFromDagGraphWrapper serializableInstance() { private static Graph getCpdag(Graph graph) { - return GraphTransforms.cpdagForDag(graph); + return GraphTransforms.dagToCpdag(graph); } /** diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/DagFromCPDAGWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/DagFromCPDAGWrapper.java index 7e8496a141..59f8aaf1ae 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/DagFromCPDAGWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/DagFromCPDAGWrapper.java @@ -57,7 +57,7 @@ public DagFromCPDAGWrapper(GraphSource source, Parameters parameters) { * @param graph a {@link edu.cmu.tetrad.graph.Graph} object */ public DagFromCPDAGWrapper(Graph graph) { - super(DagFromCPDAGWrapper.getGraph(graph), "Choose DAG in CPDAG."); + super(DagFromCPDAGWrapper.getGraph(graph), "Choose Random DAG in CPDAG."); String message = getGraph() + ""; TetradLogger.getInstance().forceLogMessage(message); } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/EdgewiseComparisonModel.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/EdgewiseComparisonModel.java index b1931b998e..a10137cbd9 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/EdgewiseComparisonModel.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/EdgewiseComparisonModel.java @@ -126,7 +126,7 @@ public static Graph getComparisonGraph(Graph graph, Parameters params) { return new EdgeListGraph(graph); } else if ("CPDAG".equals(type)) { params.set("graphComparisonType", "CPDAG"); - return GraphTransforms.cpdagForDag(graph); + return GraphTransforms.dagToCpdag(graph); } else if ("PAG".equals(type)) { params.set("graphComparisonType", "PAG"); return GraphTransforms.dagToPag(graph); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GraphSelectionWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GraphSelectionWrapper.java index dd95900bd5..1c307b3602 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GraphSelectionWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GraphSelectionWrapper.java @@ -456,7 +456,7 @@ private Graph calculateSelectionGraph(int k) { for (int j = i + 1; j < selectedVariables.size(); j++) { Node x = selectedVariables.get(i); Node y = selectedVariables.get(j); - List> paths = getGraphAtIndex(k).paths().allPathsFromTo(x, y, getN()); + List> paths = getGraphAtIndex(k).paths().allPaths(x, y, getN()); if (this.params.getString("nType", "atLeast").equals(nType.atMost.toString()) && !paths.isEmpty()) { for (List path : paths) { @@ -494,21 +494,21 @@ private Graph calculateSelectionGraph(int k) { Node y = selectedVariables.get(j); if (this.params.getString("nType", "atLeast").equals(nType.atMost.toString())) { - List> paths = getGraphAtIndex(k).paths().allPathsFromTo(x, y, getN()); + List> paths = getGraphAtIndex(k).paths().allPaths(x, y, getN()); for (List path : paths) { if (path.size() <= getN() + 1) { edges.addAll(getEdgesFromPath(path, getGraphAtIndex(k))); } } } else if (this.params.getString("nType", "atLeast").equals(nType.atLeast.toString())) { - List> paths = getGraphAtIndex(k).paths().allPathsFromTo(x, y, -1); + List> paths = getGraphAtIndex(k).paths().allPaths(x, y, -1); for (List path : paths) { if (path.size() >= getN() + 1) { edges.addAll(getEdgesFromPath(path, getGraphAtIndex(k))); } } } else if (this.params.getString("nType", "atLeast").equals(nType.equals.toString())) { - List> paths = getGraphAtIndex(k).paths().allPathsFromTo(x, y, getN()); + List> paths = getGraphAtIndex(k).paths().allPaths(x, y, getN()); for (List path : paths) { if (path.size() == getN() + 1) { edges.addAll(getEdgesFromPath(path, getGraphAtIndex(k))); @@ -531,7 +531,7 @@ private Graph calculateSelectionGraph(int k) { Node y = selectedVariables.get(j); if (this.params.getString("nType", "atLeast").equals(nType.atMost.toString())) { - List> paths = getGraphAtIndex(k).paths().allDirectedPathsFromTo(x, y, getN()); + List> paths = getGraphAtIndex(k).paths().allDirectedPaths(x, y, getN()); for (List path : paths) { if (path.size() <= getN() + 1) { g.addDirectedEdge(x, y); @@ -539,7 +539,7 @@ private Graph calculateSelectionGraph(int k) { } } } else if (this.params.getString("nType", "atLeast").equals(nType.atLeast.toString())) { - List> paths = getGraphAtIndex(k).paths().allDirectedPathsFromTo(x, y, -1); + List> paths = getGraphAtIndex(k).paths().allDirectedPaths(x, y, -1); for (List path : paths) { if (path.size() >= getN() + 1) { g.addDirectedEdge(x, y); @@ -547,7 +547,7 @@ private Graph calculateSelectionGraph(int k) { } } } else if (this.params.getString("nType", "atLeast").equals(nType.equals.toString())) { - List> paths = getGraphAtIndex(k).paths().allDirectedPathsFromTo(x, y, getN()); + List> paths = getGraphAtIndex(k).paths().allDirectedPaths(x, y, getN()); for (List path : paths) { if (path.size() == getN() + 1) { g.addDirectedEdge(x, y); @@ -569,7 +569,7 @@ private Graph calculateSelectionGraph(int k) { Node x = selectedVariables.get(i); Node y = selectedVariables.get(j); - List> paths = getGraphAtIndex(k).paths().allDirectedPathsFromTo(x, y, getN()); + List> paths = getGraphAtIndex(k).paths().allDirectedPaths(x, y, getN()); if (this.params.getString("nType", "atLeast").equals(nType.atMost.toString()) && !paths.isEmpty()) { for (List path : paths) { diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GraphWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GraphWrapper.java index 5004c0a2bf..796268172e 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GraphWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GraphWrapper.java @@ -307,7 +307,7 @@ public Graph getGraph() { */ public void setGraph(Graph graph) { this.graphs = new ArrayList<>(); - this.graphs.add(new EdgeListGraph(graph)); + this.graphs.add(graph); // log(); } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MagInPagWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MagInPagWrapper.java index a5fcf7252d..52d508392b 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MagInPagWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/MagInPagWrapper.java @@ -54,13 +54,13 @@ public MagInPagWrapper(GraphSource source, Parameters parameters) { * @param graph a {@link edu.cmu.tetrad.graph.Graph} object */ public MagInPagWrapper(Graph graph) { - super(MagInPagWrapper.getGraph(graph), "Choose DAG in CPDAG."); + super(MagInPagWrapper.getGraph(graph), "Choose Zhang MAG in PAG."); String message = getGraph() + ""; TetradLogger.getInstance().forceLogMessage(message); } private static Graph getGraph(Graph graph) { - return GraphTransforms.pagToMag(graph); + return GraphTransforms.zhangMagFromPag(graph); } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/Misclassifications.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/Misclassifications.java index 7c17116aed..543801652f 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/Misclassifications.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/Misclassifications.java @@ -128,7 +128,7 @@ public static Graph getComparisonGraph(Graph graph, Parameters params) { return new EdgeListGraph(graph); } else if ("CPDAG".equals(type)) { params.set("graphComparisonType", "CPDAG"); - return GraphTransforms.cpdagForDag(graph); + return GraphTransforms.dagToCpdag(graph); } else if ("PAG".equals(type)) { params.set("graphComparisonType", "PAG"); return GraphTransforms.dagToPag(graph); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/PValueImproverWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/PValueImproverWrapper.java index 5e101e8960..d6914aab40 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/PValueImproverWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/PValueImproverWrapper.java @@ -379,7 +379,7 @@ public void execute() { LayoutUtil.defaultLayout(this.graph); } - setResultGraph(GraphTransforms.cpdagForDag(this.graph)); + setResultGraph(GraphTransforms.dagToCpdag(this.graph)); } /** diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/SemGraphWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/SemGraphWrapper.java index e63e7e3a9d..fa8b2c59a1 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/SemGraphWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/SemGraphWrapper.java @@ -409,7 +409,14 @@ public Graph getGraph() { */ public void setGraph(Graph graph) { this.graphs = new ArrayList<>(); - this.graphs.add(new SemGraph(graph)); + + if (graph instanceof SemGraph) { + this.graphs.add(graph); + } else { + this.graphs.add(new SemGraph(graph)); + } + +// this.graphs.add(new SemGraph(graph)); log(); } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/ui/model/AlgorithmModel.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/ui/model/AlgorithmModel.java index caa4c60381..06eeda555c 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/ui/model/AlgorithmModel.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/ui/model/AlgorithmModel.java @@ -24,6 +24,7 @@ import edu.cmu.tetrad.annotation.AnnotatedClass; import edu.cmu.tetrad.util.AlgorithmDescriptions; +import java.io.Serial; import java.io.Serializable; /** @@ -34,6 +35,7 @@ */ public class AlgorithmModel implements Serializable, Comparable { + @Serial private static final long serialVersionUID = 8599854464475682558L; /** diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/util/GraphUtils.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/util/GraphUtils.java index 04355670d2..b4ecf678e8 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/util/GraphUtils.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/util/GraphUtils.java @@ -2,11 +2,18 @@ import edu.cmu.tetrad.data.DataGraphUtils; import edu.cmu.tetrad.graph.*; +import edu.cmu.tetrad.util.JOptionUtils; import edu.cmu.tetrad.util.Parameters; import edu.cmu.tetrad.util.PointXy; +import edu.cmu.tetradapp.editor.*; +import edu.cmu.tetradapp.workbench.GraphWorkbench; +import org.jetbrains.annotations.NotNull; -import java.util.ArrayList; -import java.util.HashMap; +import javax.swing.*; +import java.awt.*; +import java.awt.event.InputEvent; +import java.awt.event.KeyEvent; +import java.util.*; import java.util.List; /** @@ -48,32 +55,26 @@ public static Graph makeRandomGraph(Graph graph, Parameters parameters) { double deltaOut = parameters.getDouble("scaleFreeDeltaOut", 0.2); int numFactors = parameters.getInt("randomMimNumFactors", 1); - String type = parameters.getString("randomGraphType", "ScaleFree"); - - switch (type) { - case "Uniform": - return GraphUtils.makeRandomDag(graph, - newGraphNumMeasuredNodes, - newGraphNumLatents, - newGraphNumEdges, - randomGraphMaxDegree, - randomGraphMaxIndegree, - randomGraphMaxOutdegree, - graphRandomFoward, - graphUniformlySelected, - randomGraphConnected, - graphChooseFixed, - addCycles, parameters); - case "Mim": - return GraphUtils.makeRandomMim(numFactors, numStructuralNodes, maxStructuralEdges, measurementModelDegree, - numLatentMeasuredImpureParents, numMeasuredMeasuredImpureParents, - numMeasuredMeasuredImpureAssociations); - case "ScaleFree": - return GraphUtils.makeRandomScaleFree(newGraphNumMeasuredNodes, - newGraphNumLatents, alpha, beta, deltaIn, deltaOut); - } + String type = parameters.getString("randomGraphType", "Dag"); + + return switch (type) { + case "Dag" -> RandomGraph.randomGraph( + newGraphNumMeasuredNodes, + newGraphNumLatents, + newGraphNumEdges, + randomGraphMaxDegree, + randomGraphMaxIndegree, + randomGraphMaxOutdegree, + false); + case "Mim" -> + GraphUtils.makeRandomMim(numFactors, numStructuralNodes, maxStructuralEdges, measurementModelDegree, + numLatentMeasuredImpureParents, numMeasuredMeasuredImpureParents, + numMeasuredMeasuredImpureAssociations); + case "ScaleFree" -> GraphUtils.makeRandomScaleFree(newGraphNumMeasuredNodes, + newGraphNumLatents, alpha, beta, deltaIn, deltaOut); + default -> throw new IllegalStateException("Unrecognized graph type: " + type); + }; - throw new IllegalStateException("Unrecognized graph type: " + type); } private static Graph makeRandomDag(Graph _graph, int newGraphNumMeasuredNodes, @@ -81,10 +82,10 @@ private static Graph makeRandomDag(Graph _graph, int newGraphNumMeasuredNodes, int newGraphNumEdges, int randomGraphMaxDegree, int randomGraphMaxIndegree, int randomGraphMaxOutdegree, - boolean graphRandomFoward, - boolean graphUniformlySelected, +// boolean graphRandomFoward, +// boolean graphUniformlySelected, boolean randomGraphConnected, - boolean graphChooseFixed, +// boolean graphChooseFixed, boolean addCycles, Parameters parameters) { Graph graph = null; @@ -99,43 +100,43 @@ private static Graph makeRandomDag(Graph _graph, int newGraphNumMeasuredNodes, nodes.add(new GraphNode("X" + (i + 1))); } - if (graphRandomFoward) { - graph = RandomGraph.randomGraphRandomForwardEdges(nodes, newGraphNumLatents, +// if (true) { + graph = RandomGraph.randomGraph(nodes, newGraphNumLatents, newGraphNumEdges, randomGraphMaxDegree, randomGraphMaxIndegree, randomGraphMaxOutdegree, - randomGraphConnected, true); + randomGraphConnected); LayoutUtil.arrangeBySourceGraph(graph, _graph); HashMap layout = GraphSaveLoadUtils.grabLayout(nodes); LayoutUtil.arrangeByLayout(graph, layout); - } else { - if (graphUniformlySelected) { - - graph = RandomGraph.randomGraphUniform(nodes, - newGraphNumLatents, - newGraphNumEdges, - randomGraphMaxDegree, - randomGraphMaxIndegree, - randomGraphMaxOutdegree, - randomGraphConnected, 50000); - LayoutUtil.arrangeBySourceGraph(graph, _graph); - HashMap layout = GraphSaveLoadUtils.grabLayout(nodes); - LayoutUtil.arrangeByLayout(graph, layout); - } else { - if (graphChooseFixed) { - do { - graph = RandomGraph.randomGraph(nodes, - newGraphNumLatents, - newGraphNumEdges, - randomGraphMaxDegree, - randomGraphMaxIndegree, - randomGraphMaxOutdegree, - randomGraphConnected); - LayoutUtil.arrangeBySourceGraph(graph, _graph); - HashMap layout = GraphSaveLoadUtils.grabLayout(nodes); - LayoutUtil.arrangeByLayout(graph, layout); - } while (graph.getNumEdges() < newGraphNumEdges); - } - } - } +// } else { +// if (graphUniformlySelected) { +// +// graph = RandomGraph.randomGraphUniform(nodes, +// newGraphNumLatents, +// newGraphNumEdges, +// randomGraphMaxDegree, +// randomGraphMaxIndegree, +// randomGraphMaxOutdegree, +// randomGraphConnected, 50000); +// LayoutUtil.arrangeBySourceGraph(graph, _graph); +// HashMap layout = GraphSaveLoadUtils.grabLayout(nodes); +// LayoutUtil.arrangeByLayout(graph, layout); +// } else { +// if (graphChooseFixed) { +// do { +// graph = RandomGraph.randomGraph(nodes, +// newGraphNumLatents, +// newGraphNumEdges, +// randomGraphMaxDegree, +// randomGraphMaxIndegree, +// randomGraphMaxOutdegree, +// randomGraphConnected); +// LayoutUtil.arrangeBySourceGraph(graph, _graph); +// HashMap layout = GraphSaveLoadUtils.grabLayout(nodes); +// LayoutUtil.arrangeByLayout(graph, layout); +// } while (graph.getNumEdges() < newGraphNumEdges); +// } +// } +// } if (addCycles) { graph = RandomGraph.randomCyclicGraph2(numNodes, newGraphNumEdges, 8); @@ -182,4 +183,208 @@ private static Graph makeRandomScaleFree(int numNodes, int numLatents, double al alpha, beta, deltaIn, deltaOut); } + public static @NotNull JMenu getCheckGraphMenu(GraphWorkbench workbench) { + JMenu checkGraph = new JMenu("Check Graph Type"); + JMenuItem checkGraphForDag = new JMenuItem(new CheckGraphFoDagAction(workbench)); + JMenuItem checkGraphForCpdag = new JMenuItem(new CheckGraphForCpdagAction(workbench)); + JMenuItem checkGraphForMpdag = new JMenuItem(new CheckGraphForMpdagAction(workbench)); + JMenuItem checkGraphForMag = new JMenuItem(new CheckGraphForMagAction(workbench)); + JMenuItem checkGraphForPag = new JMenuItem(new CheckGraphForPagAction(workbench)); +// JMenuItem checkGraphForMpag = new JMenuItem(new CheckGraphForMpagAction(workbench)); + + checkGraph.add(checkGraphForDag); + checkGraph.add(checkGraphForCpdag); + checkGraph.add(checkGraphForMpdag); + checkGraph.add(checkGraphForMag); + checkGraph.add(checkGraphForPag); +// checkGraph.add(checkGraphForMpag); + return checkGraph; + } + + public static @NotNull JMenu getHighlightMenu(GraphWorkbench workbench) { + JMenu highlightMenu = new JMenu("Highlight"); + highlightMenu.add(new SelectDirectedAction(workbench)); + highlightMenu.add(new SelectBidirectedAction(workbench)); + highlightMenu.add(new SelectUndirectedAction(workbench)); + highlightMenu.add(new SelectPartiallyOrientedAction(workbench)); + highlightMenu.add(new SelectNondirectedAction(workbench)); + highlightMenu.addSeparator(); + + highlightMenu.add(new SelectTrianglesAction(workbench)); + highlightMenu.add(new SelectCliquesAction(workbench)); + highlightMenu.add(new SelectEdgesInCyclicPaths(workbench)); + highlightMenu.add(new SelectEdgesInAlmostCyclicPaths(workbench)); + highlightMenu.addSeparator();; + + highlightMenu.add(new SelectLatentsAction(workbench)); + highlightMenu.add(new SelectMeasuredNodesAction(workbench)); + return highlightMenu; + } + + /** + * Breaks down a given reason into multiple lines with a maximum number of columns. + * + * @param reason the reason to be broken down + * @param maxColumns the maximum number of columns in a line + * @return a string with the reason broken down into multiple lines + */ + public static String breakDown(String reason, int maxColumns) { + StringBuilder buf1 = new StringBuilder(); + StringBuilder buf2 = new StringBuilder(); + + String[] tokens = reason.split(" "); + + for (String token : tokens) { + if (buf1.length() + token.length() > maxColumns) { + buf2.append(buf1); + buf2.append("\n"); + buf1 = new StringBuilder(); + buf1.append(token); + } else { + buf1.append(" ").append(token); + } + } + + if (!buf1.isEmpty()) { + buf2.append(buf1); + } + + return buf2.toString().trim(); + } + + /** + * Adds graph manipulation items to the given graph menu. + * + * @param graph the graph menu to add the items to. + */ + public static void addGraphManipItems(JMenu graph, GraphWorkbench workbench) { + + JMenu transformGraph = new JMenu("Manipulate Graph"); + JMenuItem runMeekRules = new JMenuItem(new ApplyMeekRules(workbench)); + JMenuItem runFinalFciRules = new JMenuItem(new ApplyFinalFciRules(workbench)); + JMenuItem revertToCpdag = new JMenuItem(new RevertToCpdag(workbench)); + JMenuItem revertToPag = new JMenuItem(new RevertToPag(workbench)); + JMenuItem randomDagInCpdag = new JMenuItem(new PickRandomDagInCpdagAction(workbench)); +// JMenuItem randomMagInPag = new JMenuItem(new PickRandomMagInPagAction(workbench)); + JMenuItem zhangMagInPag = new JMenuItem(new PickZhangMagInPagAction(workbench)); + JMenuItem correlateExogenous = new JMenuItem("Correlate Exogenous Variables"); + JMenuItem uncorrelateExogenous = new JMenuItem("Uncorrelate Exogenous Variables"); + + correlateExogenous.addActionListener(e -> { + correlateExogenousVariables(workbench); + workbench.invalidate(); + workbench.repaint(); + }); + + uncorrelateExogenous.addActionListener(e -> { + uncorrelateExogenousVariables(workbench); + workbench.invalidate(); + workbench.repaint(); + }); + + transformGraph.add(runMeekRules); + transformGraph.add(revertToCpdag); + transformGraph.add(randomDagInCpdag); + transformGraph.addSeparator(); + + transformGraph.add(runFinalFciRules); + transformGraph.add(revertToPag); +// transformGraph.add(randomMagInPag); + transformGraph.add(zhangMagInPag); + transformGraph.addSeparator(); + + transformGraph.add(correlateExogenous); + transformGraph.add(uncorrelateExogenous); + + graph.add(transformGraph); + + runMeekRules.setAccelerator( + KeyStroke.getKeyStroke(KeyEvent.VK_M, InputEvent.ALT_DOWN_MASK)); + revertToCpdag.setAccelerator( + KeyStroke.getKeyStroke(KeyEvent.VK_C, InputEvent.ALT_DOWN_MASK)); + runFinalFciRules.setAccelerator( + KeyStroke.getKeyStroke(KeyEvent.VK_F, InputEvent.ALT_DOWN_MASK)); + revertToPag.setAccelerator( + KeyStroke.getKeyStroke(KeyEvent.VK_P, InputEvent.ALT_DOWN_MASK)); + randomDagInCpdag.setAccelerator( + KeyStroke.getKeyStroke(KeyEvent.VK_D, InputEvent.ALT_DOWN_MASK)); +// randomMagInPag.setAccelerator( +// KeyStroke.getKeyStroke(KeyEvent.VK_R, InputEvent.ALT_DOWN_MASK)); + zhangMagInPag.setAccelerator( + KeyStroke.getKeyStroke(KeyEvent.VK_Z, InputEvent.ALT_DOWN_MASK)); + } + + private static void correlateExogenousVariables(GraphWorkbench workbench) { + Graph graph = workbench.getGraph(); + + if (graph instanceof Dag) { + JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), + "Cannot add bidirected edges to DAG's."); + return; + } + + List nodes = graph.getNodes(); + + List exoNodes = new LinkedList<>(); + + for (Node node : nodes) { + if (graph.isExogenous(node)) { + exoNodes.add(node); + } + } + + for (int i = 0; i < exoNodes.size(); i++) { + + loop: + for (int j = i + 1; j < exoNodes.size(); j++) { + Node node1 = exoNodes.get(i); + Node node2 = exoNodes.get(j); + List edges = graph.getEdges(node1, node2); + + for (Edge edge : edges) { + if (Edges.isBidirectedEdge(edge)) { + continue loop; + } + } + + graph.addBidirectedEdge(node1, node2); + } + } + } + + private static void uncorrelateExogenousVariables(GraphWorkbench workbench) { + Graph graph = workbench.getGraph(); + + Set edges = graph.getEdges(); + + for (Edge edge : edges) { + if (Edges.isBidirectedEdge(edge)) { + try { + graph.removeEdge(edge); + } catch (Exception e) { + // Ignore. + } + } + } + } + + public static @NotNull JMenu addPagEdgeSpecializationsItems(GraphWorkbench workbench) { + JMenu pagEdgeSpecializations = new JMenu("PAG Edge Specialization Markups"); + pagEdgeSpecializations.add(new PagEdgeSpecialization(workbench)); + pagEdgeSpecializations.add(new PagEdgeTypeInstructions()); + return pagEdgeSpecializations; + } + + /** + * Returns the JScrollPane containing the given component, or null if no such JScrollPane exists. + * + * @param component the component to search for a containing JScrollPane + * @return the JScrollPane containing the given component, or null if no such JScrollPane exists + */ + public static JScrollPane getContainingScrollPane(Component component) { + while (component != null && !(component instanceof JScrollPane)) { + component = component.getParent(); + } + return (JScrollPane) component; + } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/util/WatchedProcess.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/util/WatchedProcess.java index ed7fb5197a..2978eeab9e 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/util/WatchedProcess.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/util/WatchedProcess.java @@ -80,9 +80,11 @@ private void startLongRunningThread() { try { watch(); } catch (InterruptedException e) { - TetradLogger.getInstance().forceLogMessage("Thread was interrupted while watching. Stopping..."); + TetradLogger.getInstance().forceLogMessage("Thread was interrupted while watching. Stopping; see console for stack trace."); + e.printStackTrace(); } catch (Exception e) { - TetradLogger.getInstance().forceLogMessage("Exception while watching: " + e.getMessage()); + TetradLogger.getInstance().forceLogMessage("Exception while watching; see console for stack trace."); + e.printStackTrace(); } if (dialog != null) { diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/AbstractWorkbench.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/AbstractWorkbench.java index 232c0f12b6..b1fbc6d105 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/AbstractWorkbench.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/AbstractWorkbench.java @@ -24,6 +24,7 @@ import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.search.utils.GraphSearchUtils; import edu.cmu.tetrad.util.JOptionUtils; +import edu.cmu.tetradapp.model.SessionWrapper; import edu.cmu.tetradapp.util.LayoutEditable; import edu.cmu.tetradapp.util.PasteLayoutAction; import org.apache.commons.math3.util.FastMath; @@ -89,6 +90,8 @@ public abstract class AbstractWorkbench extends JComponent implements WorkbenchM * Handler for PropertyChangeEvents. */ private final PropertyChangeHandler propChangeHandler = new PropertyChangeHandler(this); + private final LinkedList graphStack = new LinkedList<>(); + private final LinkedList redoStack = new LinkedList<>(); /** * The workbench which this workbench displays. */ @@ -151,48 +154,39 @@ public abstract class AbstractWorkbench extends JComponent implements WorkbenchM * Maximum x value (for dragging). */ private int maxX = 10000; - /** * Maximum y value (for dragging). */ private int maxY = 10000; - /** * True iff node/edge adding/removing errors should be reported to the user. */ private boolean nodeEdgeErrorsReported; - /** * True iff layout is permitted using a right click popup. */ private boolean rightClickPopupAllowed; - /** * A key dispatcher to allow pressing the control key to control whether edges will be drawn in the workbench. */ private KeyEventDispatcher controlDispatcher; - /** * The current displayed mouseover equation label. Null if none is displayed. Used for removing the label. */ private Point currentMouseLocation; - /** * Returns the current displayed mouseover equation label. Returns null if none is displayed. Used for removing the * label. */ private boolean enableEditing = true; - /** - * Whether to do pag coloring. + * Whether to do pag edge specialization markup. */ - private boolean doPagColoring = false; - + private boolean pagEdgeSpecializationMarked = false; /** * The graph to be used for sampling. */ private Graph samplingGraph; - /** * The knowledge. */ @@ -280,12 +274,14 @@ public final void deleteSelectedObjects() { for (DisplayNode graphNode : graphNodes) { removeNode(graphNode); + modelNodesToDisplay.remove(graphNode.getModelNode()); } for (IDisplayEdge displayEdge : graphEdges) { try { removeEdge(displayEdge); resetEdgeOffsets(displayEdge); + modelEdgesToDisplay.remove(displayEdge.getModelEdge()); } catch (Exception e) { if (isNodeEdgeErrorsReported()) { JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), e.getMessage()); @@ -376,6 +372,71 @@ public final void setGraph(Graph graph) { firePropertyChange("modelChanged", null, null); } + public void undo() { + if (graph == null) { + return; + } + + if (graph instanceof SemGraph) { + return; + } + + Graph oldGraph = new EdgeListGraph(graph); + + do { + if (graphStack.isEmpty()) { + break; + } + + Graph graph = graphStack.removeLast(); + + if (pagEdgeSpecializationMarked) { + GraphUtils.addEdgeSpecializationMarkup(new EdgeListGraph(graph)); + } + + setGraph(graph); + redoStack.add(graph); + } while (graph.equals(oldGraph)); + } + + public void redo() { + if (graph == null) { + return; + } + + if (graph instanceof SemGraph) { + return; + } + + Graph oldGraph = new EdgeListGraph(graph); + + do { + if (redoStack.isEmpty()) { + break; + } + + Graph graph = redoStack.removeLast(); + + if (pagEdgeSpecializationMarked) { + GraphUtils.addEdgeSpecializationMarkup(new EdgeListGraph(graph)); + } + + setGraph(graph); + } while (graph.equals(oldGraph)); + } + + public void setToOriginal() { + if (graphStack.size() == 1) { + return; + } + + Graph graph = graphStack.get(0); + for (int i = 1; i < new LinkedList<>(graphStack).size(); i++) { + graphStack.remove(graphStack.get(i)); + } + setGraph(graph); + } + /** * Returns the currently selected nodes as a list. * @@ -1024,7 +1085,7 @@ public Component getComponent(Node node) { * Returns a new tracking edge for the given display node and mouse location. * * @param displayNode The display node to create the tracking edge for. Must not be null. - * @param mouseLoc The location of the mouse pointer. Must not be null. + * @param mouseLoc The location of the mouse pointer. Must not be null. * @return The new tracking edge for the given display node and mouse location. */ public abstract IDisplayEdge getNewTrackingEdge(DisplayNode displayNode, Point mouseLoc); @@ -1038,8 +1099,16 @@ private void setGraphWithoutNotify(Graph graph) { throw new IllegalArgumentException("Graph model cannot be null."); } + if (!graph.equals(getGraph())) { + this.graphStack.addLast(new EdgeListGraph(graph)); + } + this.graph = graph; + if (pagEdgeSpecializationMarked) { + GraphUtils.addEdgeSpecializationMarkup(new EdgeListGraph(graph)); + } + this.modelEdgesToDisplay = new HashMap<>(); this.modelNodesToDisplay = new HashMap<>(); this.displayToModel = new HashMap<>(); @@ -1090,6 +1159,14 @@ private void setGraphWithoutNotify(Graph graph) { repaint(); } + private void addLast(Graph graph) { + if (graph instanceof SessionWrapper) { + return; + } + + this.graphStack.addLast(new EdgeListGraph(graph)); + } + /** * @return the maximum x value (for dragging). */ @@ -1322,7 +1399,7 @@ private void addEdge(Edge modelEdge) { displayEdge.setHighlighted(true); } - if (doPagColoring) { + if (pagEdgeSpecializationMarked) { // visible edges. boolean solid = modelEdge.getProperties().contains(Edge.Property.nl); @@ -1887,10 +1964,6 @@ private void edgeClicked(Object source, MouseEvent e) { graphEdge.setSelected(true); } } - -// if (SearchGraphUtils.isLegalPag(graph).isLegalPag() || doPagColoring) { -// GraphUtils.addPagColoring(new EdgeListGraph(graph)); -// } } private void nodeClicked(Object source, MouseEvent e) { @@ -1913,10 +1986,6 @@ private void nodeClicked(Object source, MouseEvent e) { selectConnectingEdges(); fireNodeSelection(); } - -// if (SearchGraphUtils.isLegalPag(graph).isLegalPag() || doPagColoring) { -// GraphUtils.addPagColoring(new EdgeListGraph(graph)); -// } } private void reorientEdge(Object source, MouseEvent e) { @@ -1945,10 +2014,6 @@ private void reorientEdge(Object source, MouseEvent e) { firePropertyChange("modelChanged", null, null); } } - -// if (SearchGraphUtils.isLegalPag(graph).isLegalPag() || doPagColoring) { -// GraphUtils.addPagColoring(new EdgeListGraph(graph)); -// } } private void fireModelChanged() { @@ -2005,10 +2070,6 @@ private void handleMousePressed(MouseEvent e) { break; } - -// if (SearchGraphUtils.isLegalPag(graph).isLegalPag() || doPagColoring) { -// GraphUtils.addPagColoring(new EdgeListGraph(graph)); -// } } private void launchPopup(MouseEvent e) { @@ -2053,10 +2114,6 @@ private void handleMouseReleased(MouseEvent e) { finishEdge(); break; } - -// if (SearchGraphUtils.isLegalPag(graph).isLegalPag() || doPagColoring) { -// GraphUtils.addPagColoring(new EdgeListGraph(graph)); -// } } private void handleMouseDragged(MouseEvent e) { @@ -2082,10 +2139,6 @@ private void handleMouseDragged(MouseEvent e) { dragNewEdge(source, newPoint); break; } - -// if (SearchGraphUtils.isLegalPag(graph).isLegalPag() || doPagColoring) { -// GraphUtils.addPagColoring(new EdgeListGraph(graph)); -// } } private void handleMouseEntered(MouseEvent e) { @@ -2379,8 +2432,8 @@ private void directEdge(IDisplayEdge graphEdge, int endpoint) { } catch (IllegalArgumentException e) { getGraph().addEdge(edge); - if (doPagColoring) { - GraphUtils.addPagColoring(new EdgeListGraph(graph)); + if (pagEdgeSpecializationMarked) { + GraphUtils.addEdgeSpecializationMarkup(new EdgeListGraph(graph)); } JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), @@ -2400,11 +2453,11 @@ private void toggleEndpoint(IDisplayEdge graphEdge, int endpointNumber) { Endpoint nextEndpoint; if (endpoint == Endpoint.TAIL) { - nextEndpoint = Endpoint.ARROW; - } else if (endpoint == Endpoint.ARROW) { nextEndpoint = Endpoint.CIRCLE; - } else { + } else if (endpoint == Endpoint.ARROW) { nextEndpoint = Endpoint.TAIL; + } else { + nextEndpoint = Endpoint.ARROW; } newEdge = new Edge(edge.getNode1(), edge.getNode2(), nextEndpoint, edge.getEndpoint2()); @@ -2413,11 +2466,11 @@ private void toggleEndpoint(IDisplayEdge graphEdge, int endpointNumber) { Endpoint nextEndpoint; if (endpoint == Endpoint.TAIL) { - nextEndpoint = Endpoint.ARROW; - } else if (endpoint == Endpoint.ARROW) { nextEndpoint = Endpoint.CIRCLE; - } else { + } else if (endpoint == Endpoint.ARROW) { nextEndpoint = Endpoint.TAIL; + } else { + nextEndpoint = Endpoint.ARROW; } newEdge = new Edge(edge.getNode1(), edge.getNode2(), edge.getEndpoint1(), nextEndpoint); @@ -2432,8 +2485,8 @@ private void toggleEndpoint(IDisplayEdge graphEdge, int endpointNumber) { if (!added) { getGraph().addEdge(edge); - if (doPagColoring) { - GraphUtils.addPagColoring(new EdgeListGraph(graph)); + if (pagEdgeSpecializationMarked) { + GraphUtils.addEdgeSpecializationMarkup(new EdgeListGraph(graph)); } return; @@ -2443,8 +2496,8 @@ private void toggleEndpoint(IDisplayEdge graphEdge, int endpointNumber) { return; } - if (doPagColoring) { - GraphUtils.addPagColoring(new EdgeListGraph(graph)); + if (pagEdgeSpecializationMarked) { + GraphUtils.addEdgeSpecializationMarkup(new EdgeListGraph(graph)); } revalidate(); @@ -2467,45 +2520,45 @@ private void setMouseDragging() { } /** - * True if the user is allowed to add measured variables. + * Checks whether adding measured variables is allowed. + * + * @return true if adding measured variables is allowed, false otherwise. */ private boolean isAddMeasuredVarsAllowed() { - /** - * True iff the user is allowed to add measured variables. - */ return true; } /** - * @return true if the user is allowed to edit existing meausred variables. + * Returns a boolean value indicating whether editing existing measured variables is allowed. + * + * @return true if editing existing measured variables is allowed, false otherwise. */ boolean isEditExistingMeasuredVarsAllowed() { return true; } /** - * @return true iff the user is allowed to delete variables. + * Checks if deleting variables is allowed. + * + * @return {@code true} if deleting variables is allowed, {@code false} otherwise */ private boolean isDeleteVariablesAllowed() { - /** - * True iff the user is allowed to delete variables. - */ return true; } /** - *

isEnableEditing.

+ * Checks if editing is enabled. * - * @return a boolean + * @return true if editing is enabled, false otherwise. */ public boolean isEnableEditing() { return this.enableEditing; } /** - *

enableEditing.

+ * Enables or disables editing for the software. * - * @param enableEditing a boolean + * @param enableEditing true to enable editing, false to disable editing */ public void enableEditing(boolean enableEditing) { this.enableEditing = enableEditing; @@ -2513,23 +2566,25 @@ public void enableEditing(boolean enableEditing) { } /** - *

isDoPagColoring.

+ * Checks if pagEdgeSpecializationMarked is true or false. * - * @return a boolean + * @return True if pagEdgeSpecializationsMarked is true, false otherwise. */ - public boolean isDoPagColoring() { - return this.doPagColoring; + public boolean isPagEdgeSpecializationMarked() { + return this.pagEdgeSpecializationMarked; } /** - *

Setter for the field doPagColoring.

+ * Marks the pag edge specializations based on the given flag. If the flag is set to true, the method applies + * special coloring to the page edges. If the flag is set to false, all special markings on page edges are removed. * - * @param doPagColoring a boolean + * @param pagEdgeSpecializationsMarked a boolean value indicating whether to mark the page edge specializations or + * not */ - public void setDoPagColoring(boolean doPagColoring) { - this.doPagColoring = doPagColoring; - if (doPagColoring) { - GraphUtils.addPagColoring(graph); + public void markPagEdgeSpecializations(boolean pagEdgeSpecializationsMarked) { + this.pagEdgeSpecializationMarked = pagEdgeSpecializationsMarked; + if (pagEdgeSpecializationsMarked) { + GraphUtils.addEdgeSpecializationMarkup(graph); } else { for (Edge edge : graph.getEdges()) { edge.getProperties().clear(); @@ -2887,15 +2942,25 @@ public void propertyChange(PropertyChangeEvent e) { if ("nodeAdded".equals(propName)) { this.workbench.addNode((Node) newValue); + addLast(workbench.getGraph()); + redoStack.clear(); } else if ("nodeRemoved".equals(propName)) { this.workbench.removeNode((Node) oldValue); + addLast(workbench.getGraph()); + redoStack.clear(); } else if ("edgeAdded".equals(propName)) { this.workbench.addEdge((Edge) newValue); + addLast(workbench.getGraph()); + redoStack.clear(); } else if ("edgeRemoved".equals(propName)) { this.workbench.removeEdge((Edge) oldValue); + addLast(workbench.getGraph()); + redoStack.clear(); } else if ("edgeLaunch".equals(propName)) { System.out.println("Attempt to launch edge."); } else if ("deleteNode".equals(propName)) { + addLast(workbench.getGraph()); + Object node = e.getSource(); if (node instanceof DisplayNode) { @@ -2907,6 +2972,9 @@ public void propertyChange(PropertyChangeEvent e) { this.workbench.selectNode((GraphNode) node); this.workbench.deleteSelectedObjects(); } + + addLast(workbench.getGraph()); + redoStack.clear(); } else if ("cloneMe".equals(propName)) { AbstractWorkbench.this.firePropertyChange("cloneMe", e.getOldValue(), e.getNewValue()); } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/GraphWorkbench.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/GraphWorkbench.java index 94ac429f2b..aeac24112d 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/GraphWorkbench.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/workbench/GraphWorkbench.java @@ -319,7 +319,6 @@ public String nextVariableName(String base) { String name = base + (++i); for (Node node1 : getGraph().getNodes()) { - if (node1.getName().equals(name)) { continue loop; } diff --git a/tetrad-gui/src/main/resources/config/prodConfig.xml b/tetrad-gui/src/main/resources/config/prodConfig.xml index 908d63a2f4..e71f150557 100644 --- a/tetrad-gui/src/main/resources/config/prodConfig.xml +++ b/tetrad-gui/src/main/resources/config/prodConfig.xml @@ -72,7 +72,7 @@ edu.cmu.tetradapp.editor.GraphSelectionEditor - diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/CompareTwoGraphs.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/CompareTwoGraphs.java index 5efdd7e43d..4a340c9642 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/CompareTwoGraphs.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/CompareTwoGraphs.java @@ -263,7 +263,7 @@ private static List statistics() { statistics.add(new MathewsCorrArrow()); statistics.add(new NumberOfEdgesEst()); statistics.add(new NumberOfEdgesTrue()); - statistics.add(new NumCorrectVisibleAncestors()); + statistics.add(new NumCorrectVisibleEdges()); statistics.add(new PercentBidirectedEdges()); statistics.add(new TailPrecision()); statistics.add(new TailRecall()); @@ -278,8 +278,7 @@ private static List statistics() { statistics.add(new DensityTrue()); statistics.add(new StructuralHammingDistance()); - - // Joe table. + // Stats for PAGs. statistics.add(new NumDirectedEdges()); statistics.add(new NumUndirectedEdges()); statistics.add(new NumPartiallyOrientedEdges()); @@ -288,17 +287,8 @@ private static List statistics() { statistics.add(new TrueDagPrecisionTails()); statistics.add(new TrueDagPrecisionArrow()); statistics.add(new BidirectedLatentPrecision()); - - // Greg table -// statistics.add(new AncestorPrecision()); -// statistics.add(new AncestorRecall()); -// statistics.add(new AncestorF1()); -// statistics.add(new SemidirectedPrecision()); -// statistics.add(new SemidirectedRecall()); -// statistics.add(new SemidirectedPathF1()); -// statistics.add(new NoSemidirectedPrecision()); -// statistics.add(new NoSemidirectedRecall()); -// statistics.add(new NoSemidirectedF1()); + statistics.add(new LegalPag()); + statistics.add(new Maximal()); return statistics; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/Comparison.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/Comparison.java index cbc9ede96a..be932d13c5 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/Comparison.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/Comparison.java @@ -705,7 +705,7 @@ public void saveToFiles(String dataPath, Simulation simulation, Parameters param if (isSaveCPDAGs()) { File file3 = new File(dir3, "cpdag." + (j + 1) + ".txt"); - GraphSaveLoadUtils.saveGraph(GraphTransforms.cpdagForDag(graph), file3, false); + GraphSaveLoadUtils.saveGraph(GraphTransforms.dagToCpdag(graph), file3, false); } if (isSavePags()) { @@ -806,7 +806,7 @@ public void saveToFilesSingleSimulation(String dataPath, Simulation simulation, if (isSaveCPDAGs()) { File file3 = new File(dir3, "cpdag." + (j + 1) + ".txt"); - GraphSaveLoadUtils.saveGraph(GraphTransforms.cpdagForDag(graph), file3, false); + GraphSaveLoadUtils.saveGraph(GraphTransforms.dagToCpdag(graph), file3, false); } if (isSavePags()) { @@ -1398,7 +1398,7 @@ private void doRun(List algorithmSimulationWrappers, if (this.comparisonGraph == ComparisonGraph.true_DAG) { comparisonGraph = new EdgeListGraph(trueGraph); } else if (this.comparisonGraph == ComparisonGraph.CPDAG_of_the_true_DAG) { - comparisonGraph = GraphTransforms.cpdagForDag(trueGraph); + comparisonGraph = GraphTransforms.dagToCpdag(trueGraph); } else if (this.comparisonGraph == ComparisonGraph.PAG_of_the_true_DAG) { comparisonGraph = GraphTransforms.dagToPag(trueGraph); } else { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/TimeoutComparison.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/TimeoutComparison.java index be34946638..dcfa186a5b 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/TimeoutComparison.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/TimeoutComparison.java @@ -608,7 +608,7 @@ public void saveToFiles(String dataPath, Simulation simulation, Parameters param if (isSaveCPDAGs()) { File file3 = new File(dir3, "pattern." + (j + 1) + ".txt"); - GraphSaveLoadUtils.saveGraph(GraphTransforms.cpdagForDag(graph), file3, false); + GraphSaveLoadUtils.saveGraph(GraphTransforms.dagToCpdag(graph), file3, false); } if (isSavePags()) { @@ -1277,7 +1277,7 @@ private void doRun(List algorithmSimulationWrappers, comparisonGraph = new EdgeListGraph(trueGraph); } else if (this.comparisonGraph == ComparisonGraph.CPDAG_of_the_true_DAG) { Graph dag = new EdgeListGraph(trueGraph); - comparisonGraph = GraphTransforms.cpdagForDag(dag); + comparisonGraph = GraphTransforms.dagToCpdag(dag); } else if (this.comparisonGraph == ComparisonGraph.PAG_of_the_true_DAG) { Graph trueGraph1 = new EdgeListGraph(trueGraph); comparisonGraph = GraphTransforms.dagToPag(trueGraph1); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Bpc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Bpc.java index 9d80408148..95f35b31dd 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Bpc.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Bpc.java @@ -128,7 +128,7 @@ public Graph search(DataModel dataModel, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { Graph dag = new EdgeListGraph(graph); - return GraphTransforms.cpdagForDag(dag); + return GraphTransforms.dagToCpdag(dag); } /** diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Fofc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Fofc.java index e31ef06c0c..928cb77bd3 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Fofc.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Fofc.java @@ -150,7 +150,7 @@ public Graph runSearch(DataModel dataModel, Parameters parameters) { */ @Override public Graph getComparisonGraph(Graph graph) { - return GraphTransforms.cpdagForDag(graph); + return GraphTransforms.dagToCpdag(graph); } /** diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Ftfc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Ftfc.java index 03771149a0..339049d9df 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Ftfc.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/cluster/Ftfc.java @@ -94,7 +94,7 @@ public Graph runSearch(DataModel dataSet, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { Graph dag = new EdgeListGraph(graph); - return GraphTransforms.cpdagForDag(dag); + return GraphTransforms.dagToCpdag(dag); } /** diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/FgesConcatenated.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/FgesConcatenated.java index c22cc86216..4564470943 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/FgesConcatenated.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/multi/FgesConcatenated.java @@ -151,7 +151,7 @@ public Graph getComparisonGraph(Graph graph) { return new EdgeListGraph(graph); } else { Graph dag = new EdgeListGraph(graph); - return GraphTransforms.cpdagForDag(dag); + return GraphTransforms.dagToCpdag(dag); } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Cpc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Cpc.java index 0ba8a9c9c7..35bc0568fd 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Cpc.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Cpc.java @@ -131,7 +131,7 @@ protected Graph runSearch(DataModel dataModel, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { Graph dag = new EdgeListGraph(graph); - return GraphTransforms.cpdagForDag(dag); + return GraphTransforms.dagToCpdag(dag); } /** diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fas.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fas.java index 1b310ecd09..4046960bcc 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fas.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Fas.java @@ -103,7 +103,7 @@ protected Graph runSearch(DataModel dataModel, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { Graph dag = new EdgeListGraph(graph); - return GraphTransforms.cpdagForDag(dag); + return GraphTransforms.dagToCpdag(dag); } /** diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/FgesMeasurement.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/FgesMeasurement.java index 2806c4841e..f5da8a961a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/FgesMeasurement.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/FgesMeasurement.java @@ -92,7 +92,7 @@ protected Graph runSearch(DataModel dataModel, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { Graph dag = new EdgeListGraph(graph); - return GraphTransforms.cpdagForDag(dag); + return GraphTransforms.dagToCpdag(dag); } /** diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pc.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pc.java index 4e09157fab..dcf7573295 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pc.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pc.java @@ -123,7 +123,7 @@ protected Graph runSearch(DataModel dataModel, Parameters parameters) { @Override public Graph getComparisonGraph(Graph graph) { Graph dag = new EdgeListGraph(graph); - return GraphTransforms.cpdagForDag(dag); + return GraphTransforms.dagToCpdag(dag); } /** diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pcd.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pcd.java index 0f70156f79..c8111ec037 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pcd.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/cpdag/Pcd.java @@ -76,7 +76,7 @@ protected Graph runSearch(DataModel dataModel, Parameters parameters) { */ @Override public Graph getComparisonGraph(Graph graph) { - return GraphTransforms.cpdagForDag(graph); + return GraphTransforms.dagToCpdag(graph); } /** diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Bfci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Bfci.java index 230173c75f..6f0a1ffbec 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Bfci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Bfci.java @@ -114,6 +114,7 @@ public Graph runSearch(DataModel dataModel, Parameters parameters) { search.setBossUseBes(parameters.getBoolean(Params.USE_BES)); search.setMaxPathLength(parameters.getInt(Params.MAX_PATH_LENGTH)); search.setCompleteRuleSetUsed(parameters.getBoolean(Params.COMPLETE_RULE_SET_USED)); + search.setResolveAlmostCyclicPaths(parameters.getBoolean(Params.RESOLVE_ALMOST_CYCLIC_PATHS)); search.setDoDiscriminatingPathRule(parameters.getBoolean(Params.DO_DISCRIMINATING_PATH_RULE)); search.setDepth(parameters.getInt(Params.DEPTH)); search.setNumThreads(parameters.getInt(Params.NUM_THREADS)); @@ -178,6 +179,7 @@ public List getParameters() { params.add(Params.SEED); params.add(Params.NUM_THREADS); params.add(Params.VERBOSE); + params.add(Params.RESOLVE_ALMOST_CYCLIC_PATHS); // Parameters params.add(Params.NUM_STARTS); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Cfci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Cfci.java index d59aaa8ad8..fa5f58f144 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Cfci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Cfci.java @@ -99,6 +99,7 @@ public Graph runSearch(DataModel dataModel, Parameters parameters) { search.setCompleteRuleSetUsed(parameters.getBoolean(Params.COMPLETE_RULE_SET_USED)); search.setPossibleMsepSearchDone(parameters.getBoolean(Params.POSSIBLE_MSEP_DONE)); search.setDoDiscriminatingPathRule(parameters.getBoolean(Params.DO_DISCRIMINATING_PATH_RULE)); + search.setResolveAlmostCyclicPaths(parameters.getBoolean(Params.RESOLVE_ALMOST_CYCLIC_PATHS)); search.setVerbose(parameters.getBoolean(Params.VERBOSE)); return search.search(); @@ -123,7 +124,7 @@ public Graph getComparisonGraph(Graph graph) { * @return The description of the algorithm. */ public String getDescription() { - return "FCI (Fast Causal Inference) using " + this.test.getDescription(); + return "CFCI (Conservative Fast Causal Inference) using " + this.test.getDescription(); } /** @@ -147,6 +148,7 @@ public List getParameters() { parameters.add(Params.DEPTH); parameters.add(Params.POSSIBLE_MSEP_DONE); parameters.add(Params.DO_DISCRIMINATING_PATH_RULE); + parameters.add(Params.RESOLVE_ALMOST_CYCLIC_PATHS); parameters.add(Params.COMPLETE_RULE_SET_USED); parameters.add(Params.TIME_LAG); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Fci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Fci.java index 9a4472bf35..9e44853976 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Fci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Fci.java @@ -106,6 +106,7 @@ public Graph runSearch(DataModel dataModel, Parameters parameters) { search.setCompleteRuleSetUsed(parameters.getBoolean(Params.COMPLETE_RULE_SET_USED)); search.setPossibleMsepSearchDone(parameters.getBoolean(Params.POSSIBLE_MSEP_DONE)); search.setDoDiscriminatingPathRule(parameters.getBoolean(Params.DO_DISCRIMINATING_PATH_RULE)); + search.setResolveAlmostCyclicPaths(parameters.getBoolean(Params.RESOLVE_ALMOST_CYCLIC_PATHS)); search.setVerbose(parameters.getBoolean(Params.VERBOSE)); search.setPcHeuristicType(pcHeuristicType); search.setStable(parameters.getBoolean(Params.STABLE_FAS)); @@ -159,6 +160,7 @@ public List getParameters() { parameters.add(Params.MAX_PATH_LENGTH); parameters.add(Params.POSSIBLE_MSEP_DONE); parameters.add(Params.DO_DISCRIMINATING_PATH_RULE); + parameters.add(Params.RESOLVE_ALMOST_CYCLIC_PATHS); parameters.add(Params.COMPLETE_RULE_SET_USED); parameters.add(Params.TIME_LAG); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/FciMax.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/FciMax.java index d25ab8527d..d63df4e34f 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/FciMax.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/FciMax.java @@ -105,6 +105,7 @@ public Graph runSearch(DataModel dataModel, Parameters parameters) { search.setMaxPathLength(parameters.getInt(Params.MAX_PATH_LENGTH)); search.setCompleteRuleSetUsed(parameters.getBoolean(Params.COMPLETE_RULE_SET_USED)); search.setDoDiscriminatingPathRule(parameters.getBoolean(Params.DO_DISCRIMINATING_PATH_RULE)); + search.setResolveAlmostCyclicPaths(parameters.getBoolean(Params.RESOLVE_ALMOST_CYCLIC_PATHS)); search.setPossibleMsepSearchDone(parameters.getBoolean(Params.POSSIBLE_MSEP_DONE)); search.setPcHeuristicType(pcHeuristicType); search.setVerbose(parameters.getBoolean(Params.VERBOSE)); @@ -156,6 +157,7 @@ public List getParameters() { parameters.add(Params.MAX_PATH_LENGTH); parameters.add(Params.COMPLETE_RULE_SET_USED); parameters.add(Params.DO_DISCRIMINATING_PATH_RULE); + parameters.add(Params.RESOLVE_ALMOST_CYCLIC_PATHS); parameters.add(Params.POSSIBLE_MSEP_DONE); // parameters.add(Params.PC_HEURISTIC); parameters.add(Params.TIME_LAG); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Gfci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Gfci.java index ed69c4a933..6d62fe87f4 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Gfci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Gfci.java @@ -103,6 +103,7 @@ public Graph runSearch(DataModel dataModel, Parameters parameters) { search.setMaxPathLength(parameters.getInt(Params.MAX_PATH_LENGTH)); search.setCompleteRuleSetUsed(parameters.getBoolean(Params.COMPLETE_RULE_SET_USED)); search.setDoDiscriminatingPathRule(parameters.getBoolean(Params.DO_DISCRIMINATING_PATH_RULE)); + search.setResolveAlmostCyclicPaths(parameters.getBoolean(Params.RESOLVE_ALMOST_CYCLIC_PATHS)); search.setPossibleMsepSearchDone(parameters.getBoolean((Params.POSSIBLE_MSEP_DONE))); search.setNumThreads(parameters.getInt(Params.NUM_THREADS)); @@ -163,6 +164,7 @@ public List getParameters() { parameters.add(Params.COMPLETE_RULE_SET_USED); parameters.add(Params.DO_DISCRIMINATING_PATH_RULE); parameters.add(Params.POSSIBLE_MSEP_DONE); + parameters.add(Params.RESOLVE_ALMOST_CYCLIC_PATHS); parameters.add(Params.TIME_LAG); parameters.add(Params.NUM_THREADS); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/GraspFci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/GraspFci.java index 4dbd40398b..8f5815fb7c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/GraspFci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/GraspFci.java @@ -128,6 +128,7 @@ public Graph runSearch(DataModel dataModel, Parameters parameters) { search.setMaxPathLength(parameters.getInt(Params.MAX_PATH_LENGTH)); search.setCompleteRuleSetUsed(parameters.getBoolean(Params.COMPLETE_RULE_SET_USED)); search.setDoDiscriminatingPathRule(parameters.getBoolean(Params.DO_DISCRIMINATING_PATH_RULE)); + search.setResolveAlmostCyclicPaths(parameters.getBoolean(Params.RESOLVE_ALMOST_CYCLIC_PATHS)); // General search.setVerbose(parameters.getBoolean(Params.VERBOSE)); @@ -193,6 +194,7 @@ public List getParameters() { params.add(Params.COMPLETE_RULE_SET_USED); params.add(Params.DO_DISCRIMINATING_PATH_RULE); params.add(Params.POSSIBLE_MSEP_DONE); + params.add(Params.RESOLVE_ALMOST_CYCLIC_PATHS); // General params.add(Params.TIME_LAG); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/LvLite.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/LvLite.java new file mode 100644 index 0000000000..87987cc8d0 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/LvLite.java @@ -0,0 +1,240 @@ +package edu.cmu.tetrad.algcomparison.algorithm.oracle.pag; + +import edu.cmu.tetrad.algcomparison.algorithm.AbstractBootstrapAlgorithm; +import edu.cmu.tetrad.algcomparison.algorithm.Algorithm; +import edu.cmu.tetrad.algcomparison.algorithm.ReturnsBootstrapGraphs; +import edu.cmu.tetrad.algcomparison.algorithm.TakesCovarianceMatrix; +import edu.cmu.tetrad.algcomparison.independence.IndependenceWrapper; +import edu.cmu.tetrad.algcomparison.score.ScoreWrapper; +import edu.cmu.tetrad.algcomparison.utils.HasKnowledge; +import edu.cmu.tetrad.algcomparison.utils.TakesIndependenceWrapper; +import edu.cmu.tetrad.algcomparison.utils.UsesScoreWrapper; +import edu.cmu.tetrad.annotation.AlgType; +import edu.cmu.tetrad.annotation.Bootstrapping; +import edu.cmu.tetrad.annotation.Experimental; +import edu.cmu.tetrad.data.DataModel; +import edu.cmu.tetrad.data.DataSet; +import edu.cmu.tetrad.data.DataType; +import edu.cmu.tetrad.data.Knowledge; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphTransforms; +import edu.cmu.tetrad.search.IndependenceTest; +import edu.cmu.tetrad.search.score.Score; +import edu.cmu.tetrad.search.utils.TsUtils; +import edu.cmu.tetrad.util.Parameters; +import edu.cmu.tetrad.util.Params; + +import java.io.Serial; +import java.util.ArrayList; +import java.util.List; + + +/** + * This class represents the LV-Lite algorithm, which is an implementation of the LV algorithm for learning causal structures + * from observational data. It uses a combination of independence tests and scores to search for the best graph structure given + * a data set and parameters. + * + * @author josephramsey + */ +@edu.cmu.tetrad.annotation.Algorithm( + name = "LV-Lite", + command = "lv-lite", + algoType = AlgType.allow_latent_common_causes +) +@Bootstrapping +@Experimental +public class LvLite extends AbstractBootstrapAlgorithm implements Algorithm, UsesScoreWrapper, + HasKnowledge, ReturnsBootstrapGraphs, TakesCovarianceMatrix { + + @Serial + private static final long serialVersionUID = 23L; + + /** + * The score to use. + */ + private ScoreWrapper score; + + /** + * The knowledge. + */ + private Knowledge knowledge = new Knowledge(); + + /** + * This class represents a LvLite algorithm. + * + *

+ * The LvLite algorithm is a bootstrap algorithm that runs a search algorithm to find a graph structure based on a given data set and parameters. It is a subclass of the Abstract + * BootstrapAlgorithm class and implements the Algorithm interface. + *

+ * + * @see AbstractBootstrapAlgorithm + * @see Algorithm + */ + public LvLite() { + // Used for reflection; do not delete. + } + + /** + * LvLite is a class that represents a LvLite algorithm. + * + *

+ * The LvLite algorithm is a bootstrap algorithm that runs a search algorithm to find a graph structure based on a given data set and parameters. + * It is a subclass of the AbstractBootstrapAlgorithm class and implements the Algorithm interface. + *

+ * + * @see AbstractBootstrapAlgorithm + * @see Algorithm + */ + public LvLite(ScoreWrapper score) { + this.score = score; + } + + /** + * Runs the search algorithm to find a graph structure based on a given data model and parameters. + * + * @param dataModel The data model to use for the search algorithm. + * @param parameters The parameters to configure the search algorithm. + * @return The resulting graph structure. + * @throws IllegalArgumentException if the time lag is greater than 0 and the data model is not an instance of DataSet. + */ + @Override + public Graph runSearch(DataModel dataModel, Parameters parameters) { + if (parameters.getInt(Params.TIME_LAG) > 0) { + if (!(dataModel instanceof DataSet dataSet)) { + throw new IllegalArgumentException("Expecting a dataset for time lagging."); + } + + DataSet timeSeries = TsUtils.createLagData(dataSet, parameters.getInt(Params.TIME_LAG)); + if (dataSet.getName() != null) { + timeSeries.setName(dataSet.getName()); + } + dataModel = timeSeries; + knowledge = timeSeries.getKnowledge(); + } + + Score score = this.score.getScore(dataModel, parameters); + edu.cmu.tetrad.search.LvLite search = new edu.cmu.tetrad.search.LvLite(score); + + // BOSS + search.setUseDataOrder(parameters.getBoolean(Params.USE_DATA_ORDER)); + search.setNumStarts(parameters.getInt(Params.NUM_STARTS)); + search.setUseBes(parameters.getBoolean(Params.USE_BES)); + + // FCI-ORIENT + search.setCompleteRuleSetUsed(parameters.getBoolean(Params.COMPLETE_RULE_SET_USED)); + boolean aBoolean = parameters.getBoolean(Params.DO_DISCRIMINATING_PATH_RULE); + search.setDoDiscriminatingPathRule(aBoolean); + + // LV-Lite + search.setResolveAlmostCyclicPaths(parameters.getBoolean(Params.RESOLVE_ALMOST_CYCLIC_PATHS)); + + // General + search.setVerbose(parameters.getBoolean(Params.VERBOSE)); + search.setKnowledge(this.knowledge); + + return search.search(); + } + + /** + * Retrieves a comparison graph by transforming a true directed graph into a partially directed graph (PAG). + * + * @param graph The true directed graph, if there is one. + * @return The comparison graph. + */ + @Override + public Graph getComparisonGraph(Graph graph) { + return GraphTransforms.dagToPag(graph); + } + + /** + * Returns a short, one-line description of this algorithm. The description is generated by concatenating the + * descriptions of the test and score objects associated with this algorithm. + * + * @return The description of this algorithm. + */ + @Override + public String getDescription() { + return "LV-Lite using " + this.score.getDescription(); + } + + /** + * Retrieves the data type required by the search algorithm. + * + * @return The data type required by the search algorithm. + */ + @Override + public DataType getDataType() { + return this.score.getDataType(); + } + + /** + * Retrieves the list of parameters used by the algorithm. + * + * @return The list of parameters used by the algorithm. + */ + @Override + public List getParameters() { + List params = new ArrayList<>(); + + // BOSS + params.add(Params.DEPTH); + params.add(Params.USE_BES); + params.add(Params.USE_DATA_ORDER); + params.add(Params.NUM_STARTS); + + // FCI-ORIENT + params.add(Params.DEPTH); + params.add(Params.COMPLETE_RULE_SET_USED); + params.add(Params.DO_DISCRIMINATING_PATH_RULE); + + // LV-Lite + params.add(Params.RESOLVE_ALMOST_CYCLIC_PATHS); + + + // General + params.add(Params.TIME_LAG); + params.add(Params.VERBOSE); + + return params; + } + + /** + * Retrieves the knowledge object associated with this method. + * + * @return The knowledge object. + */ + @Override + public Knowledge getKnowledge() { + return this.knowledge; + } + + /** + * Sets the knowledge object associated with this method. + * + * @param knowledge the knowledge object to be set + */ + @Override + public void setKnowledge(Knowledge knowledge) { + this.knowledge = new Knowledge(knowledge); + } + + /** + * Retrieves the ScoreWrapper object associated with this method. + * + * @return The ScoreWrapper object associated with this method. + */ + @Override + public ScoreWrapper getScoreWrapper() { + return this.score; + } + + /** + * Sets the score wrapper for the algorithm. + * + * @param score the score wrapper. + */ + @Override + public void setScoreWrapper(ScoreWrapper score) { + this.score = score; + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Rfci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Rfci.java index bd34edc9c3..a741becc5f 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Rfci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/Rfci.java @@ -96,6 +96,7 @@ public Graph runSearch(DataModel dataModel, Parameters parameters) { search.setKnowledge(this.knowledge); search.setDepth(parameters.getInt(Params.DEPTH)); search.setMaxPathLength(parameters.getInt(Params.MAX_PATH_LENGTH)); + search.setResolveAlmostCyclicPaths(parameters.getBoolean(Params.RESOLVE_ALMOST_CYCLIC_PATHS)); search.setVerbose(parameters.getBoolean(Params.VERBOSE)); return search.search(); } @@ -138,6 +139,7 @@ public List getParameters() { parameters.add(Params.DEPTH); parameters.add(Params.MAX_PATH_LENGTH); + parameters.add(Params.RESOLVE_ALMOST_CYCLIC_PATHS); parameters.add(Params.TIME_LAG); parameters.add(Params.VERBOSE); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/SpFci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/SpFci.java index b1364098bb..91aa364468 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/SpFci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/algorithm/oracle/pag/SpFci.java @@ -111,6 +111,8 @@ public Graph runSearch(DataModel dataModel, Parameters parameters) { search.setKnowledge(this.knowledge); search.setMaxPathLength(parameters.getInt(Params.MAX_PATH_LENGTH)); search.setCompleteRuleSetUsed(parameters.getBoolean(Params.COMPLETE_RULE_SET_USED)); + search.setDoDiscriminatingPathRule(parameters.getBoolean(Params.DO_DISCRIMINATING_PATH_RULE)); + search.setResolveAlmostCyclicPaths(parameters.getBoolean(Params.RESOLVE_ALMOST_CYCLIC_PATHS)); search.setVerbose(parameters.getBoolean(Params.VERBOSE)); Object obj = parameters.get(Params.PRINT_STREAM); @@ -167,6 +169,7 @@ public List getParameters() { params.add(Params.MAX_PATH_LENGTH); params.add(Params.COMPLETE_RULE_SET_USED); params.add(Params.DO_DISCRIMINATING_PATH_RULE); + params.add(Params.RESOLVE_ALMOST_CYCLIC_PATHS); params.add(Params.DEPTH); params.add(Params.TIME_LAG); params.add(Params.VERBOSE); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicDiff.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicDiff.java index 1fb99bf977..5087af40d0 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicDiff.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicDiff.java @@ -1,6 +1,7 @@ package edu.cmu.tetrad.algcomparison.statistic; import edu.cmu.tetrad.data.DataModel; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.GraphTransforms; import edu.cmu.tetrad.search.score.SemBicScorer; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicDiffPerRecord.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicDiffPerRecord.java index c97ad49523..0f7b10d337 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicDiffPerRecord.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BicDiffPerRecord.java @@ -2,6 +2,7 @@ import edu.cmu.tetrad.data.DataModel; import edu.cmu.tetrad.data.DataSet; +import edu.cmu.tetrad.graph.EdgeListGraph; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.GraphTransforms; import edu.cmu.tetrad.search.score.SemBicScorer; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BidirectedEst.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BidirectedEst.java index 83dc697027..1779efb288 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BidirectedEst.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BidirectedEst.java @@ -30,7 +30,7 @@ public BidirectedEst() { */ @Override public String getAbbreviation() { - return "#X<->Y"; + return "#X<->Y (E)"; } /** @@ -38,7 +38,7 @@ public String getAbbreviation() { */ @Override public String getDescription() { - return "Number of True Bidirected Edges"; + return "Number of bidirected edges in estimated PAG"; } /** diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BidirectedLatentPrecision.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BidirectedLatentPrecision.java index 9e9f267dbd..9e30e4cd31 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BidirectedLatentPrecision.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BidirectedLatentPrecision.java @@ -1,33 +1,31 @@ package edu.cmu.tetrad.algcomparison.statistic; import edu.cmu.tetrad.data.DataModel; -import edu.cmu.tetrad.graph.Edge; -import edu.cmu.tetrad.graph.Edges; -import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.graph.GraphUtils; +import edu.cmu.tetrad.graph.*; import java.io.Serial; - -import static edu.cmu.tetrad.algcomparison.statistic.LatentCommonAncestorTruePositiveBidirected.existsLatentCommonAncestor; +import java.util.List; /** - * The bidirected true positives. - * - * @author josephramsey - * @version $Id: $Id + * The BidirectedLatentPrecision class implements the Statistic interface and represents a statistic that calculates + * the percentage of bidirected edges in an estimated graph for which a latent confounder exists in the true graph. */ public class BidirectedLatentPrecision implements Statistic { @Serial private static final long serialVersionUID = 23L; /** - * Constructs a new instance of the statistic. + * The BidirectedLatentPrecision class implements the Statistic interface and represents a statistic that calculates + * the percentage of bidirected edges in an estimated graph for which a latent confounder exists in the true graph. */ public BidirectedLatentPrecision() { } /** - * {@inheritDoc} + * Returns the abbreviation for the statistic. The abbreviation is a short string that represents the statistic. + * For this statistic, the abbreviation is "<->-Lat-Prec". + * + * @return The abbreviation for the statistic. */ @Override public String getAbbreviation() { @@ -35,7 +33,9 @@ public String getAbbreviation() { } /** - * {@inheritDoc} + * Returns a short description of the statistic, which is the percentage of bidirected edges for which a latent confounder exists. + * + * @return The description of the statistic. */ @Override public String getDescription() { @@ -43,31 +43,39 @@ public String getDescription() { } /** - * {@inheritDoc} + * Calculates the percentage of correctly identified bidirected edges in an estimated graph + * for which a latent confounder exists in the true graph. + * + * @param trueGraph The true graph (DAG, CPDAG, PAG_of_the_true_DAG). + * @param estGraph The estimated graph (same type). + * @param dataModel The data model. + * @return The percentage of correctly identified bidirected edges. */ @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { int tp = 0; - int fp = 0; + int pos = 0; estGraph = GraphUtils.replaceNodes(estGraph, trueGraph.getNodes()); for (Edge edge : estGraph.getEdges()) { if (Edges.isBidirectedEdge(edge)) { - if (existsLatentCommonAncestor(trueGraph, edge)) { + if (GraphUtils.isCorrectBidirectedEdge(edge, trueGraph)) { tp++; - } else { - fp++; } + + pos++; } } - return tp / (double) (tp + fp); + return tp / (double) pos; } - /** - * {@inheritDoc} + * Calculates the normalized value of a given statistic value. + * + * @param value The value of the statistic. + * @return The normalized value of the statistic. */ @Override public double getNormValue(double value) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BidirectedTrue.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BidirectedTrue.java index 5a7f0c0482..926be9ebe3 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BidirectedTrue.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/BidirectedTrue.java @@ -29,7 +29,7 @@ public BidirectedTrue() { */ @Override public String getAbbreviation() { - return "BT"; + return "#X<->Y (T)"; } /** @@ -37,7 +37,7 @@ public String getAbbreviation() { */ @Override public String getDescription() { - return "Number of estimated bidirected edges"; + return "Number of bidirected edges in true PAG"; } /** @@ -53,8 +53,6 @@ public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { if (Edges.isBidirectedEdge(edge)) t++; } - System.out.println("True # bidirected edges = " + t); - return t; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/DefiniteDirectedPathPrecision.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/DefiniteDirectedPathPrecision.java index 5dfc0dde93..667cdf20b1 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/DefiniteDirectedPathPrecision.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/DefiniteDirectedPathPrecision.java @@ -46,9 +46,9 @@ public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { int tp = 0, fp = 0; List nodes = trueGraph.getNodes(); - Graph cpdag = GraphTransforms.cpdagForDag(trueGraph); + Graph cpdag = GraphTransforms.dagToCpdag(trueGraph); - GraphUtils.addPagColoring(estGraph); + GraphUtils.addEdgeSpecializationMarkup(estGraph); for (Node x : nodes) { for (Node y : nodes) { @@ -59,7 +59,7 @@ public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { if (e != null && e.pointsTowards(y) && e.getProperties().contains(Edge.Property.dd)) { // if (estGraph.existsDirectedPathFromTo(x, y)) { - if (cpdag.paths().existsDirectedPathFromTo(x, y)) { + if (cpdag.paths().existsDirectedPath(x, y)) { tp++; } else { fp++; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/DefiniteDirectedPathRecall.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/DefiniteDirectedPathRecall.java index ea3bd95135..646ee208b3 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/DefiniteDirectedPathRecall.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/DefiniteDirectedPathRecall.java @@ -49,14 +49,14 @@ public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { int tp = 0, fn = 0; List nodes = trueGraph.getNodes(); - Graph cpdag = GraphTransforms.cpdagForDag(trueGraph); + Graph cpdag = GraphTransforms.dagToCpdag(trueGraph); for (Node x : nodes) { for (Node y : nodes) { if (x == y) continue; - if (cpdag.paths().existsDirectedPathFromTo(x, y)) { - if (estGraph.paths().existsDirectedPathFromTo(x, y)) { + if (cpdag.paths().existsDirectedPath(x, y)) { + if (estGraph.paths().existsDirectedPath(x, y)) { tp++; } else { fn++; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/LegalPag.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/LegalPag.java index 5f0179e208..389f2516f1 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/LegalPag.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/LegalPag.java @@ -35,7 +35,7 @@ public String getAbbreviation() { */ @Override public String getDescription() { - return "1 if the estimated graph passes the Legal PAG check, 0 if not"; + return "1 if the estimated graph is Legal PAG, 0 if not"; } /** diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/Maximal.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/Maximal.java new file mode 100644 index 0000000000..1901ac112f --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/Maximal.java @@ -0,0 +1,81 @@ +package edu.cmu.tetrad.algcomparison.statistic; + +import edu.cmu.tetrad.data.DataModel; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphUtils; +import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.util.TetradLogger; + +import java.io.Serial; +import java.util.List; + +/** + * Checks whether a PAG is maximal. + */ +public class Maximal implements Statistic { + @Serial + private static final long serialVersionUID = 23L; + + /** + *

Constructor for LegalPag.

+ */ + public Maximal() { + } + + /** + * {@inheritDoc} + */ + @Override + public String getAbbreviation() { + return "Maximal"; + } + + /** + * Returns a short one-line description of this statistic. This will be printed at the beginning of the report. + * + * @return The description of the statistic. + */ + @Override + public String getDescription() { + return "1 if the estimated graph is maximal, 0 if not"; + } + + /** + * Checks whether a PAG is maximal. + */ + @Override + public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { + List nodes = estGraph.getNodes(); + boolean maximal = true; + + for (int i = 0; i < nodes.size(); i++) { + for (int j = i + 1; j < nodes.size(); j++) { + Node n1 = nodes.get(i); + Node n2 = nodes.get(j); + if (!estGraph.isAdjacentTo(n1, n2)) { + List inducingPath = estGraph.paths().getInducingPath(n1, n2); + + if (inducingPath != null) { + TetradLogger.getInstance().forceLogMessage("Maximality check: Found an inducing path for " + + n1 + "..." + n2 + ": " + + GraphUtils.pathString(estGraph, inducingPath)); + maximal = false; + } + } + } + } + + return maximal ? 1.0 : 0.0; + } + + /** + * Returns the normalized value of the given statistic value. + * + * @param value The value of the statistic. + * @return The normalized value of the statistic, between 0 and 1. + */ + @Override + public double getNormValue(double value) { + return value; + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MaximalityCondition.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MaximalityCondition.java index 66ada017fa..a7fd028897 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MaximalityCondition.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/MaximalityCondition.java @@ -48,7 +48,7 @@ public String getDescription() { public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { Graph pag = estGraph; - Graph mag = GraphTransforms.pagToMag(estGraph); + Graph mag = GraphTransforms.zhangMagFromPag(estGraph); List nodes = pag.getNodes(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoAlmostCyclicPathsCondition.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoAlmostCyclicPathsCondition.java index 18397a902f..19faefce64 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoAlmostCyclicPathsCondition.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoAlmostCyclicPathsCondition.java @@ -53,9 +53,9 @@ public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { Node y = e.getNode2(); if (Edges.isBidirectedEdge(e)) { - if (pag.paths().existsDirectedPathFromTo(x, y)) { + if (pag.paths().existsDirectedPath(x, y)) { return 0; - } else if (pag.paths().existsDirectedPathFromTo(y, x)) { + } else if (pag.paths().existsDirectedPath(y, x)) { return 0; } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoAlmostCyclicPathsInMagCondition.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoAlmostCyclicPathsInMagCondition.java deleted file mode 100644 index 4954692a97..0000000000 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoAlmostCyclicPathsInMagCondition.java +++ /dev/null @@ -1,71 +0,0 @@ -package edu.cmu.tetrad.algcomparison.statistic; - -import edu.cmu.tetrad.data.DataModel; -import edu.cmu.tetrad.graph.*; - -import java.io.Serial; - -/** - * No almost cyclic paths condition in MAG. - * - * @author josephramsey - * @version $Id: $Id - */ -public class NoAlmostCyclicPathsInMagCondition implements Statistic { - @Serial - private static final long serialVersionUID = 23L; - - /** - * Constructs a new instance of the statistic. - */ - public NoAlmostCyclicPathsInMagCondition() { - - } - - /** - * {@inheritDoc} - */ - @Override - public String getAbbreviation() { - return "NoAlmostCyclicInMag"; - } - - /** - * {@inheritDoc} - */ - @Override - public String getDescription() { - return "1 if the no almost cyclic paths condition passes in MAG, 0 if not"; - } - - /** - * {@inheritDoc} - */ - @Override - public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - Graph mag = GraphTransforms.pagToMag(estGraph); - - for (Edge e : mag.getEdges()) { - Node x = e.getNode1(); - Node y = e.getNode2(); - - if (Edges.isBidirectedEdge(e)) { - if (mag.paths().existsDirectedPathFromTo(x, y)) { - return 0; - } else if (mag.paths().existsDirectedPathFromTo(y, x)) { - return 0; - } - } - } - - return 1; - } - - /** - * {@inheritDoc} - */ - @Override - public double getNormValue(double value) { - return value; - } -} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoCyclicPathsCondition.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoCyclicPathsCondition.java index ea1cc889cc..0b0a8b0a1e 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoCyclicPathsCondition.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoCyclicPathsCondition.java @@ -28,7 +28,7 @@ public NoCyclicPathsCondition() { */ @Override public String getAbbreviation() { - return "NoCyclic"; + return "NoCyclicPaths"; } /** @@ -44,10 +44,8 @@ public String getDescription() { */ @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - Graph pag = estGraph; - - for (Node n : pag.getNodes()) { - if (pag.paths().existsDirectedPathFromTo(n, n)) { + for (Node n : estGraph.getNodes()) { + if (estGraph.paths().existsDirectedPath(n, n)) { return 0; } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoCyclicPathsInMagCondition.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoCyclicPathsInMagCondition.java deleted file mode 100644 index 41b487ad90..0000000000 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoCyclicPathsInMagCondition.java +++ /dev/null @@ -1,66 +0,0 @@ -package edu.cmu.tetrad.algcomparison.statistic; - -import edu.cmu.tetrad.data.DataModel; -import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.graph.GraphTransforms; -import edu.cmu.tetrad.graph.Node; - -import java.io.Serial; - -/** - * No cyclic paths condition. - * - * @author josephramsey - * @version $Id: $Id - */ -public class NoCyclicPathsInMagCondition implements Statistic { - @Serial - private static final long serialVersionUID = 23L; - - /** - * Constructs a new instance of the statistic. - */ - public NoCyclicPathsInMagCondition() { - - } - - /** - * {@inheritDoc} - */ - @Override - public String getAbbreviation() { - return "NoCyclicInMag"; - } - - /** - * {@inheritDoc} - */ - @Override - public String getDescription() { - return "1 if the no cyclic paths condition passes in MAG, 0 if not"; - } - - /** - * {@inheritDoc} - */ - @Override - public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - Graph mag = GraphTransforms.pagToMag(estGraph); - - for (Node n : mag.getNodes()) { - if (mag.paths().existsDirectedPathFromTo(n, n)) { - return 0; - } - } - - return 1; - } - - /** - * {@inheritDoc} - */ - @Override - public double getNormValue(double value) { - return value; - } -} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoSemidirectedPrecision.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoSemidirectedPrecision.java index 4a386e816b..842e15491a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoSemidirectedPrecision.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoSemidirectedPrecision.java @@ -48,7 +48,7 @@ public String getDescription() { public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { int tp = 0, fp = 0; - Graph cpdag = GraphTransforms.cpdagForDag(trueGraph); + Graph cpdag = GraphTransforms.dagToCpdag(trueGraph); List nodes = estGraph.getNodes(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoSemidirectedRecall.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoSemidirectedRecall.java index 5bd95ceaf0..0a60abea55 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoSemidirectedRecall.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NoSemidirectedRecall.java @@ -48,7 +48,7 @@ public String getDescription() { public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { int tp = 0, fn = 0; - Graph cpdag = GraphTransforms.cpdagForDag(trueGraph); + Graph cpdag = GraphTransforms.dagToCpdag(trueGraph); List nodes = trueGraph.getNodes(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NodesInCyclesPrecision.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NodesInCyclesPrecision.java index 6f639e8aff..1aeb67641e 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NodesInCyclesPrecision.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NodesInCyclesPrecision.java @@ -65,7 +65,7 @@ private Set getNodesInCycles(Graph graph) { Set inCycle = new HashSet<>(); for (Node x : graph.getNodes()) { - if (graph.paths().existsDirectedPathFromTo(x, x)) { + if (graph.paths().existsDirectedPath(x, x)) { inCycle.add(x); } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NodesInCyclesRecall.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NodesInCyclesRecall.java index b8984ee02b..a26e2028f7 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NodesInCyclesRecall.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NodesInCyclesRecall.java @@ -65,7 +65,7 @@ private Set getNodesInCycles(Graph graph) { Set inCycle = new HashSet<>(); for (Node x : graph.getNodes()) { - if (graph.paths().existsDirectedPathFromTo(x, x)) { + if (graph.paths().existsDirectedPath(x, x)) { inCycle.add(x); } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumColoredDD.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumColoredDD.java index 95573bae2d..d0a0520c13 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumColoredDD.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumColoredDD.java @@ -48,7 +48,7 @@ public String getDescription() { public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { int count = 0; - GraphUtils.addPagColoring(estGraph); + GraphUtils.addEdgeSpecializationMarkup(estGraph); for (Edge edge : estGraph.getEdges()) { if (Edges.isDirectedEdge(edge)) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumColoredNL.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumColoredNL.java index ff142da0f3..29499c830a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumColoredNL.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumColoredNL.java @@ -48,7 +48,7 @@ public String getDescription() { public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { int count = 0; - GraphUtils.addPagColoring(estGraph); + GraphUtils.addEdgeSpecializationMarkup(estGraph); for (Edge edge : estGraph.getEdges()) { if (Edges.isDirectedEdge(edge)) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumColoredPD.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumColoredPD.java index 4d6d11450e..c6eb78e9c9 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumColoredPD.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumColoredPD.java @@ -48,7 +48,7 @@ public String getDescription() { public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { int count = 0; - GraphUtils.addPagColoring(estGraph); + GraphUtils.addEdgeSpecializationMarkup(estGraph); for (Edge edge : estGraph.getEdges()) { if (Edges.isDirectedEdge(edge)) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumColoredPL.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumColoredPL.java index a6cba5f7e6..e558195d5d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumColoredPL.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumColoredPL.java @@ -48,7 +48,7 @@ public String getDescription() { public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { int count = 0; - GraphUtils.addPagColoring(estGraph); + GraphUtils.addEdgeSpecializationMarkup(estGraph); for (Edge edge : estGraph.getEdges()) { if (Edges.isDirectedEdge(edge)) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleDefiniteDirectedEdgeAncestors.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleDefiniteDirectedEdgeAncestors.java index d386187910..1bcca589c1 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleDefiniteDirectedEdgeAncestors.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleDefiniteDirectedEdgeAncestors.java @@ -45,7 +45,7 @@ public String getDescription() { */ @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - GraphUtils.addPagColoring(estGraph); + GraphUtils.addEdgeSpecializationMarkup(estGraph); Graph pag = GraphTransforms.dagToPag(trueGraph); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleDirectedEdgeConfounded.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleDirectedEdgeConfounded.java index 96cdb3e6ff..e49b521e0d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleDirectedEdgeConfounded.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleDirectedEdgeConfounded.java @@ -46,7 +46,7 @@ public String getDescription() { */ @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - GraphUtils.addPagColoring(estGraph); + GraphUtils.addEdgeSpecializationMarkup(estGraph); Graph pag = GraphTransforms.dagToPag(trueGraph); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleDirectedEdgeNonAncestors.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleDirectedEdgeNonAncestors.java index 28585fe462..21f67ae7b4 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleDirectedEdgeNonAncestors.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleDirectedEdgeNonAncestors.java @@ -45,7 +45,7 @@ public String getDescription() { */ @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - GraphUtils.addPagColoring(estGraph); + GraphUtils.addEdgeSpecializationMarkup(estGraph); // Graph pag = SearchGraphUtils.dagToPag(trueGraph); int tp = 0; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleEdges.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleEdges.java index 86c333c47a..a75c338d46 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleEdges.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleEdges.java @@ -48,7 +48,7 @@ public String getDescription() { */ @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - GraphUtils.addPagColoring(estGraph); + GraphUtils.addEdgeSpecializationMarkup(estGraph); Graph pag = GraphTransforms.dagToPag(trueGraph); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatiblePossiblyDirectedEdgeAncestors.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatiblePossiblyDirectedEdgeAncestors.java index 9f4aca5d60..a57c6ce7f8 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatiblePossiblyDirectedEdgeAncestors.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatiblePossiblyDirectedEdgeAncestors.java @@ -45,7 +45,7 @@ public String getDescription() { */ @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - GraphUtils.addPagColoring(estGraph); + GraphUtils.addEdgeSpecializationMarkup(estGraph); Graph pag = GraphTransforms.dagToPag(trueGraph); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatiblePossiblyDirectedEdgeNonAncestors.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatiblePossiblyDirectedEdgeNonAncestors.java index 0e3be56414..1e62b6e209 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatiblePossiblyDirectedEdgeNonAncestors.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatiblePossiblyDirectedEdgeNonAncestors.java @@ -45,7 +45,7 @@ public String getDescription() { */ @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - GraphUtils.addPagColoring(estGraph); + GraphUtils.addEdgeSpecializationMarkup(estGraph); Graph pag = GraphTransforms.dagToPag(trueGraph); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleVisibleAncestors.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleVisibleAncestors.java deleted file mode 100644 index 7ec8ebd739..0000000000 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCompatibleVisibleAncestors.java +++ /dev/null @@ -1,81 +0,0 @@ -package edu.cmu.tetrad.algcomparison.statistic; - -import edu.cmu.tetrad.data.DataModel; -import edu.cmu.tetrad.graph.*; - -import java.io.Serial; - -import static edu.cmu.tetrad.graph.GraphUtils.compatible; - -/** - * The bidirected true positives. - * - * @author josephramsey - * @version $Id: $Id - */ -public class NumCompatibleVisibleAncestors implements Statistic { - @Serial - private static final long serialVersionUID = 23L; - - /** - * Constructs a new instance of the statistic. - */ - public NumCompatibleVisibleAncestors() { - - } - - /** - * {@inheritDoc} - */ - @Override - public String getAbbreviation() { - return "#CVA"; - } - - /** - * {@inheritDoc} - */ - @Override - public String getDescription() { - return "Number compatible visible X-->Y in estimates for which X is an ancestor of Y in true"; - } - - /** - * {@inheritDoc} - */ - @Override - public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - GraphUtils.addPagColoring(estGraph); - - Graph pag = GraphTransforms.dagToPag(trueGraph); - - int tp = 0; - int fp = 0; - - for (Edge edge : estGraph.getEdges()) { - Edge trueEdge = pag.getEdge(edge.getNode1(), edge.getNode2()); - if (!compatible(edge, trueEdge)) continue; - - if (edge.getProperties().contains(Edge.Property.nl)) { - Node x = Edges.getDirectedEdgeTail(edge); - Node y = Edges.getDirectedEdgeHead(edge); - - if (trueGraph.paths().isAncestorOf(x, y)) { - tp++; - } else { - fp++; - } - } - } - - return tp; - } - - /** - * {@inheritDoc} - */ - @Override - public double getNormValue(double value) { - return value; - } -} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCorrectBidirected.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCorrectBidirected.java new file mode 100644 index 0000000000..dbb4a641b4 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCorrectBidirected.java @@ -0,0 +1,82 @@ +package edu.cmu.tetrad.algcomparison.statistic; + +import edu.cmu.tetrad.data.DataModel; +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.Edges; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphUtils; + +import java.io.Serial; + +/** + * Counts the number of X<->Y edges for which a latent confounder of X and Y exists. + * + * @author josephramsey + * @version $Id: $Id + */ +public class NumCorrectBidirected implements Statistic { + @Serial + private static final long serialVersionUID = 23L; + + /** + * Counts the number of bidirectional edges for which a latent confounder of X and Y exists. + */ + public NumCorrectBidirected() { + } + + /** + * Retrieves the abbreviation for the statistic. + * + * @return The abbreviation for the statistic. + */ + @Override + public String getAbbreviation() { + return "<-> Correct"; + } + + /** + * Returns a short one-line description of this statistic. This will be printed at the beginning of the report. + * + * @return The description of the statistics as a String. + */ + @Override + public String getDescription() { + return "Number of bidirected edges for which a latent confounder exists"; + } + + /** + * Returns the number of bidirected edges for which a latent confounder exists. + * + * @param trueGraph The true graph (DAG, CPDAG, PAG_of_the_true_DAG). + * @param estGraph The estimated graph (same type). + * @param dataModel The data model. + * @return The number of bidirected edges with a latent confounder. + */ + @Override + public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { + int tp = 0; + + estGraph = GraphUtils.replaceNodes(estGraph, trueGraph.getNodes()); + + for (Edge edge : estGraph.getEdges()) { + if (Edges.isBidirectedEdge(edge)) { + if (GraphUtils.isCorrectBidirectedEdge(edge, trueGraph)) { + tp++; + } + } + } + + return tp; + } + + /** + * Returns the normalized value of the given statistic. + * + * @param value The value of the statistic. + * @return The normalized value. + */ + @Override + public double getNormValue(double value) { + return value; + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCorrectDDAncestors.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCorrectDDAncestors.java index 7434443496..aec6f25a19 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCorrectDDAncestors.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCorrectDDAncestors.java @@ -43,7 +43,7 @@ public String getDescription() { */ @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - GraphUtils.addPagColoring(estGraph); + GraphUtils.addEdgeSpecializationMarkup(estGraph); int tp = 0; int fp = 0; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCorrectPDAncestors.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCorrectPDAncestors.java index 57df480b65..4cc2a9c52a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCorrectPDAncestors.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCorrectPDAncestors.java @@ -43,7 +43,7 @@ public String getDescription() { */ @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - GraphUtils.addPagColoring(estGraph); + GraphUtils.addEdgeSpecializationMarkup(estGraph); int tp = 0; int fp = 0; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCorrectVisibleAncestors.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCorrectVisibleAncestors.java deleted file mode 100644 index 535ef3681d..0000000000 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCorrectVisibleAncestors.java +++ /dev/null @@ -1,81 +0,0 @@ -package edu.cmu.tetrad.algcomparison.statistic; - -import edu.cmu.tetrad.data.DataModel; -import edu.cmu.tetrad.graph.*; - -import java.io.Serial; - -/** - * The bidirected true positives. - * - * @author josephramsey - * @version $Id: $Id - */ -public class NumCorrectVisibleAncestors implements Statistic { - @Serial - private static final long serialVersionUID = 23L; - - /** - * Constructs a new instance of the statistic. - */ - public NumCorrectVisibleAncestors() { - - } - - /** - * {@inheritDoc} - */ - @Override - public String getAbbreviation() { - return "#CVA"; - } - - /** - * {@inheritDoc} - */ - @Override - public String getDescription() { - return "Number visible X-->Y where X~~>Y in true"; - } - - /** - * {@inheritDoc} - */ - @Override - public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - GraphUtils.addPagColoring(estGraph); - - int tp = 0; - int fp = 0; - - for (Edge edge : estGraph.getEdges()) { - if (edge.getProperties().contains(Edge.Property.nl)) { - Node x = Edges.getDirectedEdgeTail(edge); - Node y = Edges.getDirectedEdgeHead(edge); - - if (/*!existsCommonAncestor(trueGraph, edge) &&*/ trueGraph.paths().isAncestorOf(x, y)) { - tp++; - -// System.out.println("Correct visible edge: " + edge); - } else { - fp++; - -// System.out.println("Incorrect visible edge: " + edge + " x = " + x + " y = " + y); -// System.out.println("\t ancestor = " + trueGraph.isAncestorOf(x, y)); -// System.out.println("\t no common ancestor = " + !existsCommonAncestor(trueGraph, edge)); - - } - } - } - - return tp; - } - - /** - * {@inheritDoc} - */ - @Override - public double getNormValue(double value) { - return value; - } -} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCorrectVisibleEdges.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCorrectVisibleEdges.java new file mode 100644 index 0000000000..f82e06c9b1 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumCorrectVisibleEdges.java @@ -0,0 +1,81 @@ +package edu.cmu.tetrad.algcomparison.statistic; + +import edu.cmu.tetrad.data.DataModel; +import edu.cmu.tetrad.graph.*; + +import java.io.Serial; +import java.util.List; + +/** + * Represents a statistic that calculates the number of correct visible ancestors in the true graph + * that are also visible ancestors in the estimated graph. + */ +public class NumCorrectVisibleEdges implements Statistic { + @Serial + private static final long serialVersionUID = 23L; + + /** + * Constructs a new instance of the statistic. + */ + public NumCorrectVisibleEdges() { + + } + + /** + * {@inheritDoc} + */ + @Override + public String getAbbreviation() { + return "#CorrectVis"; + } + + /** + * {@inheritDoc} + */ + @Override + public String getDescription() { + return "Returns the number of visible edges X->Y in the estimated graph where X and Y have no latent confounder in the true graph."; + } + + /** + * {@inheritDoc} + */ + @Override + public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { + GraphUtils.addEdgeSpecializationMarkup(estGraph); + int tp = 0; + + for (Edge edge : estGraph.getEdges()) { + if (edge.getProperties().contains(Edge.Property.nl)) { + Node x = edge.getNode1(); + Node y = edge.getNode2(); + + boolean existsLatentConfounder = false; + + List> treks = trueGraph.paths().treks(x, y, -1); + + // If there is a trek, x<~~z~~>y, where z is latent, then the edge is not semantically visible. + for (List trek : treks) { + if (GraphUtils.isConfoundingTrek(trueGraph, trek, x, y)) { + existsLatentConfounder = true; + break; + } + } + + if (!existsLatentConfounder) { + tp++; + } + } + } + + return tp; + } + + /** + * {@inheritDoc} + */ + @Override + public double getNormValue(double value) { + return value; + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumDefinitelyDirected.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumDefinitelyDirected.java index 6c6650e20a..4115d94f3d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumDefinitelyDirected.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumDefinitelyDirected.java @@ -45,7 +45,7 @@ public String getDescription() { public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { int count = 0; - Graph cpdag = GraphTransforms.cpdagForDag(trueGraph); + Graph cpdag = GraphTransforms.dagToCpdag(trueGraph); for (Edge edge : estGraph.getEdges()) { if (Edges.isDirectedEdge(edge)) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumDefinitelyNotDirectedPaths.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumDefinitelyNotDirectedPaths.java index 276d1c8fb6..053c97ebfa 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumDefinitelyNotDirectedPaths.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumDefinitelyNotDirectedPaths.java @@ -45,7 +45,7 @@ public String getDescription() { public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { int count = 0; - Graph cpdag = GraphTransforms.cpdagForDag(trueGraph); + Graph cpdag = GraphTransforms.dagToCpdag(trueGraph); for (Edge edge : estGraph.getEdges()) { if (Edges.isDirectedEdge(edge)) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumDirectedEdgeVisible.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumDirectedEdgeVisible.java deleted file mode 100644 index 1ba5230a91..0000000000 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumDirectedEdgeVisible.java +++ /dev/null @@ -1,68 +0,0 @@ -package edu.cmu.tetrad.algcomparison.statistic; - -import edu.cmu.tetrad.data.DataModel; -import edu.cmu.tetrad.graph.Edge; -import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.graph.GraphTransforms; - -import java.io.Serial; - -/** - * Number of X-->Y for which X-->Y visible in true PAG. - * - * @author josephramsey - * @version $Id: $Id - */ -public class NumDirectedEdgeVisible implements Statistic { - @Serial - private static final long serialVersionUID = 23L; - - /** - * Constructs a new instance of the statistic. - */ - public NumDirectedEdgeVisible() { - - } - - /** - * {@inheritDoc} - */ - @Override - public String getAbbreviation() { - return "#X->Y-NL"; - } - - /** - * {@inheritDoc} - */ - @Override - public String getDescription() { - return "Number of X-->Y for which X-->Y visible in true PAG"; - } - - /** - * {@inheritDoc} - */ - @Override - public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - int tp = 0; - - Graph pag = GraphTransforms.dagToPag(trueGraph); - - for (Edge edge : pag.getEdges()) { - if (pag.paths().defVisible(edge)) { - tp++; - } - } - - return tp; - } - - /** - * {@inheritDoc} - */ - @Override - public double getNormValue(double value) { - return value; - } -} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumIncorrectDDAncestors.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumIncorrectDDAncestors.java index b681ecaa7b..e0679d1312 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumIncorrectDDAncestors.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumIncorrectDDAncestors.java @@ -42,7 +42,7 @@ public String getDescription() { */ @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - GraphUtils.addPagColoring(estGraph); + GraphUtils.addEdgeSpecializationMarkup(estGraph); int tp = 0; int fp = 0; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumIncorrectPDAncestors.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumIncorrectPDAncestors.java index 43b785f590..ab74cc9bc9 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumIncorrectPDAncestors.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumIncorrectPDAncestors.java @@ -42,7 +42,7 @@ public String getDescription() { */ @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - GraphUtils.addPagColoring(estGraph); + GraphUtils.addEdgeSpecializationMarkup(estGraph); int tp = 0; int fp = 0; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumIncorrectVisibleAncestors.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumIncorrectVisibleAncestors.java index 7827c093c3..c04c0974ac 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumIncorrectVisibleAncestors.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumIncorrectVisibleAncestors.java @@ -43,7 +43,7 @@ public String getDescription() { */ @Override public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - GraphUtils.addPagColoring(estGraph); + GraphUtils.addEdgeSpecializationMarkup(estGraph); int tp = 0; int fp = 0; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumPossiblyDirected.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumPossiblyDirected.java index 08229e9f51..832c9767dc 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumPossiblyDirected.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumPossiblyDirected.java @@ -45,7 +45,7 @@ public String getDescription() { public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { int count = 0; - Graph cpdag = GraphTransforms.cpdagForDag(trueGraph); + Graph cpdag = GraphTransforms.dagToCpdag(trueGraph); for (Edge edge : estGraph.getEdges()) { if (Edges.isDirectedEdge(edge)) { @@ -53,7 +53,7 @@ public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { Node y = Edges.getDirectedEdgeHead(edge); if (new Paths(cpdag).existsSemiDirectedPath(x, y)) { - if (!new Paths(cpdag).existsDirectedPathFromTo(x, y)) { + if (!new Paths(cpdag).existsDirectedPath(x, y)) { count++; } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumVisibleEdgeEst.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumVisibleEdgeEst.java new file mode 100644 index 0000000000..21b6fdf8ed --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumVisibleEdgeEst.java @@ -0,0 +1,80 @@ +package edu.cmu.tetrad.algcomparison.statistic; + +import edu.cmu.tetrad.data.DataModel; +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphUtils; +import org.apache.commons.math3.util.FastMath; + +import java.io.Serial; +import java.util.ArrayList; + +/** + * NumVisibleEdgeEst is a class that implements the Statistic interface. It calculates the number of X-->Y edges that + * are visible in the estimated PAG. + */ +public class NumVisibleEdgeEst implements Statistic { + @Serial + private static final long serialVersionUID = 23L; + + /** + * Constructs a new instance of the statistic. + */ + public NumVisibleEdgeEst() { + + } + + /** + * Returns the abbreviation for the statistic. This will be printed at the top of each column. + * + * @return The abbreviation for the statistic. + */ + @Override + public String getAbbreviation() { + return "#X->Y visible (E)"; + } + + /** + * Returns a short one-line description of this statistic. This will be printed at the beginning of the report. + * + * @return The description of the statistic. + */ + @Override + public String getDescription() { + return "Number of X-->Y for which X-->Y visible in estimated PAG"; + } + + /** + * Returns the number of X-->Y edges that are visible in the estimated PAG. + * + * @param trueGraph The true graph. + * @param estGraph The estimated graph. + * @param dataModel The data model. + * @return The number of X-->Y edges that are visible in the estimated PAG. + */ + @Override + public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { + int tp = 0; + + GraphUtils.addEdgeSpecializationMarkup(estGraph); + + for (Edge edge : new ArrayList<>(estGraph.getEdges())) { + if (edge.getProperties().contains(Edge.Property.nl)) { + tp++; + } + } + + return tp; + } + + /** + * Returns the normalized value of the given value. + * + * @param value The value to be normalized. + * @return The normalized value. + */ + @Override + public double getNormValue(double value) { + return FastMath.tan(value); + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumVisibleEdgeTrue.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumVisibleEdgeTrue.java new file mode 100644 index 0000000000..3c2370da01 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumVisibleEdgeTrue.java @@ -0,0 +1,82 @@ +package edu.cmu.tetrad.algcomparison.statistic; + +import edu.cmu.tetrad.data.DataModel; +import edu.cmu.tetrad.graph.Edge; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphTransforms; +import edu.cmu.tetrad.graph.GraphUtils; +import org.apache.commons.math3.util.FastMath; + +import java.io.Serial; +import java.util.ArrayList; + +/** + * A class that implements the Statistic interface to calculate the number of visible edges in the true PAG. + */ +public class NumVisibleEdgeTrue implements Statistic { + @Serial + private static final long serialVersionUID = 23L; + + /** + * A class that calculates the number of visible edges in the true PAG. + */ + public NumVisibleEdgeTrue() { + + } + + /** + * Retrieves the abbreviation for the statistic. This will be printed at the top of each column. + * The abbreviation format is "#X->Y visible (T)". + * + * @return The abbreviation string. + */ + @Override + public String getAbbreviation() { + return "#X->Y visible (T)"; + } + + /** + * Retrieves the description of the statistic. This method returns the number of X-->Y edges for which X-->Y is visible in the true PAG. + * + * @return The description of the statistic. + */ + @Override + public String getDescription() { + return "Number of X-->Y for which X-->Y visible in true PAG"; + } + + /** + * Retrieves the number of X-->Y edges for which X-->Y is visible in the true PAG. + * + * @param trueGraph The true PAG graph. + * @param estGraph The estimated PAG graph. + * @param dataModel The data model. + * @return The number of X-->Y edges that are visible in the true PAG. + */ + @Override + public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { + int tp = 0; + + Graph pag = GraphTransforms.dagToPag(trueGraph); + GraphUtils.addEdgeSpecializationMarkup(pag); + + for (Edge edge : new ArrayList<>(pag.getEdges())) { + if (edge.getProperties().contains(Edge.Property.nl)) { + tp++; + } + } + + return tp; + } + + /** + * Returns the normalized value of a given statistic. + * + * @param value The original value of the statistic. + * @return The normalized value of the statistic. + */ + @Override + public double getNormValue(double value) { + return FastMath.tan(value); + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumVisibleEst.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumVisibleEst.java deleted file mode 100644 index d04e8b2b2d..0000000000 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/NumVisibleEst.java +++ /dev/null @@ -1,68 +0,0 @@ -package edu.cmu.tetrad.algcomparison.statistic; - -import edu.cmu.tetrad.data.DataModel; -import edu.cmu.tetrad.graph.Edge; -import edu.cmu.tetrad.graph.Edges; -import edu.cmu.tetrad.graph.Graph; - -import java.io.Serial; - -/** - * Number of X-->Y visible in est. - * - * @author josephramsey - * @version $Id: $Id - */ -public class NumVisibleEst implements Statistic { - @Serial - private static final long serialVersionUID = 23L; - - /** - * Constructs a new instance of the statistic. - */ - public NumVisibleEst() { - - } - - /** - * {@inheritDoc} - */ - @Override - public String getAbbreviation() { - return "#X->Y-NL-Est"; - } - - /** - * {@inheritDoc} - */ - @Override - public String getDescription() { - return "Number of X-->Y visible in est"; - } - - /** - * {@inheritDoc} - */ - @Override - public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { - int tp = 0; - - for (Edge edge : estGraph.getEdges()) { - if (Edges.isDirectedEdge(edge)) { - if (estGraph.paths().defVisible(edge)) { - tp++; - } - } - } - - return tp; - } - - /** - * {@inheritDoc} - */ - @Override - public double getNormValue(double value) { - return value; - } -} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesBifParser.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesBifParser.java index c3845f0abd..a28a75d4e7 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesBifParser.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesBifParser.java @@ -42,6 +42,12 @@ public final class BayesBifParser { private BayesBifParser() { } + /** + * Parses a string in BayesBif format and converts it into a BayesIm object. + * + * @param text the string in BayesBif format + * @return the BayesIm object created from the parsed string + */ public static BayesIm makeBayesIm(String text) { text = text.replace("\n", ""); text = text.replace("\r", ""); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesBifRenderer.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesBifRenderer.java index e4031f1d1d..836e60cacb 100755 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesBifRenderer.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesBifRenderer.java @@ -22,12 +22,7 @@ package edu.cmu.tetrad.bayes; import edu.cmu.tetrad.data.DiscreteVariable; -import edu.cmu.tetrad.graph.Node; -import edu.cmu.tetrad.graph.NodeType; import edu.cmu.tetrad.util.NumberFormatUtil; -import nu.xom.Attribute; -import nu.xom.Element; -import nu.xom.Text; import java.util.ArrayList; import java.util.List; @@ -46,11 +41,13 @@ public final class BayesBifRenderer { private BayesBifRenderer() { } + /** + * Renders the given BayesIm object as a Bayesian network in the BIF (Bayesian Interchange Format) format. + * + * @param bayesIm the BayesIm object representing the Bayesian network + * @return the Bayesian network in BIF format as a string + */ public static String render(BayesIm bayesIm) { - - - - StringBuilder builder = new StringBuilder(); // Write the name @@ -122,7 +119,7 @@ public static String render(BayesIm bayesIm) { builder.append(" ( "); for (int i = 0; i < parentValues.length; i++) { - builder.append(_parents.get(i).getCategory(parentValues[i])) ; + builder.append(_parents.get(i).getCategory(parentValues[i])); if (i < parentValues.length - 1) { builder.append(", "); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ModelGenerator.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ModelGenerator.java index 438d7893e1..9fa5ba4187 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ModelGenerator.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ModelGenerator.java @@ -84,7 +84,7 @@ public static List generate(Graph graph) { Edge newEdge = new Edge(n1, n2, e2, e1); toAdd.removeEdge(allEdge); - if (!toAdd.paths().existsDirectedPathFromTo(n1, n2)) { + if (!toAdd.paths().existsDirectedPath(n1, n2)) { toAdd.addEdge(newEdge); graphs.add(toAdd); } @@ -112,7 +112,7 @@ public static List generate(Graph graph) { Graph toAdd1 = new EdgeListGraph(graph); //Make sure adding this edge won't introduce a cycle. - if (!toAdd1.paths().existsDirectedPathFromTo(node1, node2)) { // + if (!toAdd1.paths().existsDirectedPath(node1, node2)) { // Edge newN2N1 = new Edge(node2, node1, Endpoint.TAIL, Endpoint.ARROW); toAdd1.addEdge(newN2N1); @@ -122,7 +122,7 @@ public static List generate(Graph graph) { //Now create the graph with the edge added in the other direction Graph toAdd2 = new EdgeListGraph(graph); //Make sure adding this edge won't introduce a cycle. - if (!toAdd2.paths().existsDirectedPathFromTo(node2, node1)) { + if (!toAdd2.paths().existsDirectedPath(node2, node1)) { Edge newN1N2 = new Edge(node1, node2, Endpoint.TAIL, Endpoint.ARROW); toAdd2.addEdge(newN1N2); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/UpdatedBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/UpdatedBayesIm.java index 2e4e03000d..b5cb2d0ffa 100755 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/UpdatedBayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/UpdatedBayesIm.java @@ -700,7 +700,7 @@ private boolean[] calcRelevantVars(int nodeIndex) { // Added the condition node == node2 since the updater was corrected to exclude this. // jdramsey 12.13.2014 - if (node == node2 || this.bayesIm.getDag().paths().isMConnectedTo(node, node2, conditionedNodes)) { + if (node == node2 || this.bayesIm.getDag().paths().isMConnectedTo(node, node2, conditionedNodes, false)) { relevantVars[i] = true; } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java index da3e720431..f89cd87b5f 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/EdgeListGraph.java @@ -496,32 +496,7 @@ public Set getSepset(Node x, Node y) { * @return True if the nodes in x are all d-separated from nodes in y given nodes in z, false if not. */ public boolean isMSeparatedFrom(Node x, Node y, Set z) { - return !new Paths(this).isMConnectedTo(x, y, z); - } - - /** - * Determines whether two nodes are d-separated given z. - * - * @param x a {@link java.util.Set} object - * @param y a {@link java.util.Set} object - * @param z a {@link java.util.Set} object - * @return True if the nodes in x are all d-separated from nodes in y given nodes in z, false if not. - */ - public boolean isMSeparatedFrom(Set x, Set y, Set z) { - return !new Paths(this).isMConnectedTo(x, y, z); - } - - /** - * Determines whether two nodes are d-separated given z. - * - * @param ancestors A map of ancestors for each node. - * @param x a {@link java.util.Set} object - * @param y a {@link java.util.Set} object - * @param z a {@link java.util.Set} object - * @return True if the nodes are d-separated given z, false if not. - */ - public boolean isMSeparatedFrom(Set x, Set y, Set z, Map> ancestors) { - return !new Paths(this).isMConnectedTo(x, y, z, ancestors); + return !new Paths(this).isMConnectedTo(x, y, z, false); } /** @@ -749,14 +724,10 @@ public boolean addEdge(Edge edge) { this.edgeLists = new HashMap<>(this.edgeLists); } - if (edgeLists.get(node1) == null) { - // System.out.println("Missing node1 is not in edgeLists: " + node1); - edgeLists.put(node1, new HashSet<>()); - } - if (edgeLists.get(node2) == null) { - // System.out.println("Missing node2 is not in edgeLists: " + node2); - edgeLists.put(node2, new HashSet<>()); - } + // System.out.println("Missing node1 is not in edgeLists: " + node1); + edgeLists.computeIfAbsent(node1, k -> new HashSet<>()); + // System.out.println("Missing node2 is not in edgeLists: " + node2); + edgeLists.computeIfAbsent(node2, k -> new HashSet<>()); this.edgeLists.get(node1).add(edge); this.edgeLists.get(node2).add(edge); this.edgesSet.add(edge); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphSaveLoadUtils.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphSaveLoadUtils.java index 35b474f49b..79c3168075 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphSaveLoadUtils.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphSaveLoadUtils.java @@ -1054,10 +1054,17 @@ public static String graphToPcalg(Graph g) { return table.toString(); } + /** + * Converts a given graph into an adjacency matrix in CPAG format. + * + * @param g the input graph to be converted + * @return the adjacency matrix representation of the graph in CPAG format + * @throws IllegalArgumentException if the graph is not a MPDAG (including CPDAG or DAG) + */ public static String graphToAmatCpag(Graph g) { - if (!(g.paths().isLegalMpdag())) { - throw new IllegalArgumentException("Graph is not a MPDAG (including CPDAG or DAG)."); - } +// if (!(g.paths().isLegalMpdag())) { +// throw new IllegalArgumentException("Graph is not a MPDAG (including CPDAG or DAG)."); +// } List vars = g.getNodes(); @@ -1112,11 +1119,14 @@ public static String graphToAmatCpag(Graph g) { * using write.matrix(mat, path). For the amat.pag format, for a matrix m, endpoints are explicitly represented, as * follows. 1 is a circle endpoint, 2 is an arrow endpoint, 3 is a tail endpoint, and 0 is a null endpoint (i.e., no * edge) + * + * @param g a {@link edu.cmu.tetrad.graph.Graph} object + * @return a {@link java.lang.String} object */ public static String graphToAmatPag(Graph g) { - if (!(g.paths().isLegalPag() || g.paths().isLegalMag())) { - throw new IllegalArgumentException("Graph is not a PAG or MAG."); - } +// if (!(g.paths().isLegalPag() || g.paths().isLegalMag())) { +// throw new IllegalArgumentException("Graph is not a PAG or MAG."); +// } List vars = g.getNodes(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphTransforms.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphTransforms.java index 52f1e7871d..478f90cfab 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphTransforms.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphTransforms.java @@ -1,13 +1,13 @@ package edu.cmu.tetrad.graph; import edu.cmu.tetrad.data.Knowledge; -import edu.cmu.tetrad.search.utils.DagInCpcagIterator; -import edu.cmu.tetrad.search.utils.DagToPag; -import edu.cmu.tetrad.search.utils.MeekRules; +import edu.cmu.tetrad.search.utils.*; import edu.cmu.tetrad.util.CombinationGenerator; +import edu.cmu.tetrad.util.RandomUtil; import org.jetbrains.annotations.NotNull; import java.util.ArrayList; +import java.util.Collections; import java.util.List; /** @@ -35,14 +35,37 @@ public static Graph dagFromCpdag(Graph graph) { } /** - * Returns a DAG from the given CPDAG. If the given CPDAG is not a PDAG, returns null. + * Returns a random DAG from the given CPDAG. If the given CPDAG is not a PDAG, returns null. * - * @param graph the CPDAG + * @param cpdag the CPDAG * @param knowledge the knowledge * @return a DAG from the given CPDAG. If the given CPDAG is not a PDAG, returns null. */ - public static Graph dagFromCpdag(Graph graph, Knowledge knowledge) { - Graph dag = new EdgeListGraph(graph); + public static Graph dagFromCpdag(Graph cpdag, Knowledge knowledge) { + Graph dag = new EdgeListGraph(cpdag); + transformCpdagIntoRandomDag(dag, knowledge); + return dag; + } + + /** + * Transforms a completed partially directed acyclic graph (CPDAG) into a random directed acyclic graph (DAG) by + * randomly orienting the undirected edges in the CPDAG in shuffled order. + * + * @param graph The original graph from which the CPDAG was derived. + * @param knowledge The knowledge available to check if a potential DAG violates any constraints. + */ + public static void transformCpdagIntoRandomDag(Graph graph, Knowledge knowledge) { + List undirectedEdges = new ArrayList<>(); + + for (Edge edge : graph.getEdges()) { + if (Edges.isUndirectedEdge(edge)) { + undirectedEdges.add(edge); + } + } + + Collections.shuffle(undirectedEdges); + + System.out.println(undirectedEdges); MeekRules rules = new MeekRules(); @@ -54,21 +77,79 @@ public static Graph dagFromCpdag(Graph graph, Knowledge knowledge) { NEXT: while (true) { - for (Edge edge : dag.getEdges()) { + for (Edge edge : undirectedEdges) { Node x = edge.getNode1(); Node y = edge.getNode2(); + if (!Edges.isUndirectedEdge(graph.getEdge(x, y))) { + continue; + } + if (Edges.isUndirectedEdge(edge) && !graph.paths().isAncestorOf(y, x)) { - direct(x, y, dag); - rules.orientImplied(dag); + double d = RandomUtil.getInstance().nextDouble(); + + if (d < 0.5) { + direct(x, y, graph); + } else { + direct(y, x, graph); + } + + rules.orientImplied(graph); continue NEXT; } } break; } + } - return dag; + /** + * Picks a random Maximal Ancestral Graph (MAG) from the given Partial Ancestral Graph (PAG) by randomly orienting + * the circle endpoints as either tail or arrow and then applying the final FCI orient algorithm after each change. + * The PAG graph type is not checked. + * + * @param pag The partially ancestral pag to transform. + * @return The maximally ancestral pag obtained from the PAG. + */ + public static Graph magFromPag(Graph pag) { + Graph mag = new EdgeListGraph(pag); + transormPagIntoRandomMag(mag); + return mag; + } + + /** + * Transforms a partially ancestral graph (PAG) into a maximally ancestral graph (MAG) by randomly orienting the + * circle endpoints as either tail or arrow and then applying the final FCI orient algorithm after each change. + * + * @param pag The partially ancestral graph to transform. + */ + public static void transormPagIntoRandomMag(Graph pag) { + for (Edge e : pag.getEdges()) pag.addEdge(new Edge(e)); + + List nodePairs = new ArrayList<>(); + + for (Edge edge : pag.getEdges()) { + if (!pag.isAdjacentTo(edge.getNode1(), edge.getNode2())) continue; + nodePairs.add(new NodePair(edge.getNode1(), edge.getNode2())); + nodePairs.add(new NodePair(edge.getNode2(), edge.getNode1())); + } + + Collections.shuffle(nodePairs); + + for (NodePair edge : new ArrayList<>(nodePairs)) { + if (pag.getEndpoint(edge.getFirst(), edge.getSecond()).equals(Endpoint.CIRCLE)) { + double d = RandomUtil.getInstance().nextDouble(); + + if (d < 0.5) { + pag.setEndpoint(edge.getFirst(), edge.getSecond(), Endpoint.TAIL); + } else { + pag.setEndpoint(edge.getFirst(), edge.getSecond(), Endpoint.ARROW); + } + + FciOrient orient = new FciOrient(new DagSepsets(pag)); + orient.zhangFinalOrientation(pag); + } + } } /** @@ -78,7 +159,7 @@ public static Graph dagFromCpdag(Graph graph, Knowledge knowledge) { * @param pag The partially ancestral graph to transform. * @return The maximally ancestral graph obtained from the PAG. */ - public static Graph pagToMag(Graph pag) { + public static Graph zhangMagFromPag(Graph pag) { Graph mag = new EdgeListGraph(pag.getNodes()); for (Edge e : pag.getEdges()) mag.addEdge(new Edge(e)); @@ -245,7 +326,7 @@ public static List getAllGraphsByDirectingUndirectedEdges(Graph skeleton) * @param dag The input DAG. * @return The CPDAG resulting from applying Meek Rules to the input DAG. */ - public static Graph cpdagForDag(Graph dag) { + public static Graph dagToCpdag(Graph dag) { Graph cpdag = new EdgeListGraph(dag); MeekRules rules = new MeekRules(); rules.setRevertToUnshieldedColliders(true); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java index 4f54e58f7e..b0bfe75f47 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/GraphUtils.java @@ -26,9 +26,7 @@ import edu.cmu.tetrad.search.utils.FciOrient; import edu.cmu.tetrad.search.utils.GraphSearchUtils; import edu.cmu.tetrad.search.utils.SepsetProducer; -import edu.cmu.tetrad.util.ChoiceGenerator; -import edu.cmu.tetrad.util.Parameters; -import edu.cmu.tetrad.util.TextTable; +import edu.cmu.tetrad.util.*; import java.text.DecimalFormat; import java.text.NumberFormat; @@ -38,9 +36,7 @@ import java.util.concurrent.TimeUnit; /** - * Utility class for manipulating graphs. - * - * @author josephramsey + * Utility class for working with graphs. */ public final class GraphUtils { @@ -97,9 +93,9 @@ public static boolean isClique(Collection set, Graph graph) { } /** - * Calculates the subgraph over the Markov blanket of a target node in a given DAG, CPDAG, MAG, or PAG. - * Target Node is not included in the result graph's nodes list. - * Edges including the target node is included in the result graph's edges list. + * Calculates the subgraph over the Markov blanket of a target node in a given DAG, CPDAG, MAG, or PAG. Target Node + * is not included in the result graph's nodes list. Edges including the target node is included in the result + * graph's edges list. * * @param target a node in the given graph. * @param graph a DAG, CPDAG, MAG, or PAG. @@ -129,10 +125,9 @@ public static Graph markovBlanketSubgraph(Node target, Graph graph) { } /** - * Calculates the subgraph over the Markov blanket of a target node for a DAG, CPDAG, MAG, or PAG. - * This is not necessarily minimal (i.e. not necessarily a Markov Boundary). - * Target Node is included in the result graph's nodes list. - * Edges including the target node is included in the result graph's edges list. + * Calculates the subgraph over the Markov blanket of a target node for a DAG, CPDAG, MAG, or PAG. This is not + * necessarily minimal (i.e. not necessarily a Markov Boundary). Target Node is included in the result graph's nodes + * list. Edges including the target node is included in the result graph's edges list. * * @param target a node in the given graph. * @param graph a DAG, CPDAG, MAG, or PAG. @@ -145,26 +140,7 @@ public static Graph getMarkovBlanketSubgraphWithTargetNode(Graph graph, Node tar Graph res = g.subgraph(new ArrayList<>(mbNodes)); // System.out.println( target + " Node's MB Nodes list: " + res.getNodes()); // System.out.println("Graph result: " + res); - return res; - } - - /** - * Calculates the subgraph over the parents of a target node. - * Target Node is included in the result graph's nodes list. - * Edges including the target node is included in the result graph's edges list. - * - * @param target a node in the given graph. - * @param graph - * @return a {@link edu.cmu.tetrad.graph.Graph} object - */ - public static Graph getParentsSubgraphWithTargetNode(Graph graph, Node target) { - EdgeListGraph g = new EdgeListGraph(graph); - List parents = g.getParents(target); - parents.add(target); - Graph res = g.subgraph(new ArrayList<>(parents)); -// System.out.println( target + " Node's Parents list: " + res.getNodes()); -// System.out.println("Graph result: " + res); - return res; + return res; } /** @@ -671,15 +647,6 @@ public static List getAmbiguousTriplesFromGraph(Node node, Graph graph) return ambiguousTriples; } - /** - *

getUnderlinedTriplesFromGraph.

- * - * @param node a {@link edu.cmu.tetrad.graph.Node} object - * @param graph a {@link edu.cmu.tetrad.graph.Graph} object - * @return A list of triples of the form <X, Y, Z>, where <X, Y, Z> is a definite noncollider in the - * given graph. - */ - /** * Retrieves the underlined triples from the given graph that involve the specified node. These are triples that * represent definite noncolliders in the given graph. @@ -713,7 +680,7 @@ public static List getUnderlinedTriplesFromGraph(Node node, Graph graph) } /** - *

getDottedUnderlinedTriplesFromGraph.

+ *

getUnderlinedTriplesFromGraph.

* * @param node a {@link edu.cmu.tetrad.graph.Node} object * @param graph a {@link edu.cmu.tetrad.graph.Graph} object @@ -753,6 +720,15 @@ public static List getDottedUnderlinedTriplesFromGraph(Node node, Graph return dottedUnderlinedTriples; } + /** + *

getDottedUnderlinedTriplesFromGraph.

+ * + * @param node a {@link edu.cmu.tetrad.graph.Node} object + * @param graph a {@link edu.cmu.tetrad.graph.Graph} object + * @return A list of triples of the form <X, Y, Z>, where <X, Y, Z> is a definite noncollider in the + * given graph. + */ + /** * Checks if a given graph contains a bidirected edge. * @@ -771,7 +747,6 @@ public static boolean containsBidirectedEdge(Graph graph) { return containsBidirected; } - /** * Generates a list of triples where a node acts as a collider in a given graph. * @@ -1214,11 +1189,11 @@ public static String edgeMisclassifications(int[][] counts) { } /** - * Adds PAG coloring to the edges in the given graph. + * Adds markups for edge specilizations for the edges in the given graph. * - * @param graph The graph to which PAG coloring will be added. + * @param graph The graph to which PAG edge specialization markups will be added. */ - public static void addPagColoring(Graph graph) { + public static void addEdgeSpecializationMarkup(Graph graph) { for (Edge edge : graph.getEdges()) { edge.getProperties().clear(); @@ -1485,15 +1460,15 @@ private static void brokKerbosh1(Set R, Set P, Set X, Set nodes, - SepsetProducer sepsets) { + public static void gfciExtraEdgeRemovalStep(Graph graph, Graph referenceCpdag, List nodes, SepsetProducer sepsets, boolean verbose) { for (Node b : nodes) { if (Thread.currentThread().isInterrupted()) { break; @@ -1902,10 +1877,16 @@ public static void gfciExtraEdgeRemovalStep(Graph graph, Graph referenceCpdag, L Node a = adjacentNodes.get(combination[0]); Node c = adjacentNodes.get(combination[1]); - if (graph.isAdjacentTo(a, c) && referenceCpdag.isAdjacentTo(a, c)) { + if (graph.isAdjacentTo(a, c)) {// && referenceCpdag.isAdjacentTo(a, c)) { Set sepset = sepsets.getSepset(a, c); if (sepset != null) { graph.removeEdge(a, c); + + if (verbose) { + double pValue = sepsets.getPValue(a, c, sepset); + TetradLogger.getInstance().forceLogMessage("Removed edge " + a + " -- " + c + + " in extra-edge removal step; sepset = " + sepset + ", p-value = " + pValue + "."); + } } } } @@ -2084,6 +2065,269 @@ public static Set district(Node x, Graph G) { return district; } + /** + * Calculates visual-edge adjustments given graph G between two nodes x and y that are subsets of MB(X). + * + * @param G the input graph + * @param x the source node + * @param y the target node + * @param numSmallestSizes the number of smallest adjustment sets to return + * @param graphType the type of the graph + * @return the adjustment sets as a set of sets of nodes + * @throws IllegalArgumentException if the input graph is not a legal MPDAG + */ + public static Set> visibleEdgeAdjustments1(Graph G, Node x, Node y, int numSmallestSizes, GraphType graphType) { + Graph G2 = getGraphWithoutXToY(G, x, y, graphType); + + if (G2 == null) { + return new HashSet<>(); + } + + if (G2.paths().isLegalMpdag() && G.isAdjacentTo(x, y) && !Edges.isDirectedEdge(G.getEdge(x, y))) { + System.out.println("The edge from x to y must be visible: " + G.getEdge(x, y)); + return new HashSet<>(); + } else if (G2.paths().isLegalPag() && G.isAdjacentTo(x, y) && !G.paths().defVisible(G.getEdge(x, y))) { + System.out.println("The edge from x to y must be visible:" + G.getEdge(x, y)); + return new HashSet<>(); + } + + // Get the Markov blanket for x in G2. + Set mbX = markovBlanket(x, G2); + mbX.remove(x); + mbX.remove(y); + mbX.removeAll(G.paths().getDescendants(x)); + return getNMinimalSubsets(getGraphWithoutXToY(G, x, y, graphType), mbX, x, y, numSmallestSizes); + } + + /** + * Calculates visual-edge adjustments of a given graph G between two nodes x and y that are subsets of MB(Y). + * + * @param G the input graph + * @param x the source node + * @param y the target node + * @param numSmallestSizes the number of smallest adjustment sets to return + * @param graphType the type of the graph + * @return the adjustment sets as a set of sets of nodes + * @throws IllegalArgumentException if the input graph is not a legal MPDAG + */ + public static Set> visualEdgeAdjustments2(Graph G, Node x, Node y, int numSmallestSizes, GraphType graphType) { + Graph G2 = getGraphWithoutXToY(G, x, y, graphType); + + if (G2 == null) { + return new HashSet<>(); + } + + if (G2.paths().isLegalMpdag() && G.isAdjacentTo(x, y) && !Edges.isDirectedEdge(G.getEdge(x, y))) { + System.out.println("The edge from x to y must be visible: " + G.getEdge(x, y)); + return new HashSet<>(); + } else if (G2.paths().isLegalPag() && G.isAdjacentTo(x, y) && !G.paths().defVisible(G.getEdge(x, y))) { + System.out.println("The edge from x to y must be visible:" + G.getEdge(x, y)); + return new HashSet<>(); + } + + // Get the Markov blanket for x in G2. + Set mbX = markovBlanket(y, G2); + mbX.remove(x); + mbX.remove(y); + mbX.removeAll(G.paths().getDescendants(x)); + return getNMinimalSubsets(getGraphWithoutXToY(G, x, y, graphType), mbX, x, y, numSmallestSizes); + } + + /** + * This method calculates visible-edge adjustments for a given graph, two nodes, a number of smallest sizes, and a + * graph type. + * + * @param G the input graph + * @param x the first node + * @param y the second node + * @param numSmallestSizes the number of smallest sizes to consider + * @param graphType the type of the graph + * @return a set of subsets of nodes representing visible-edge adjustments + */ + public static Set> visibleEdgeAdjustments3(Graph G, Node x, Node y, int numSmallestSizes, GraphType graphType) { + Graph G2; + + try { + G2 = getGraphWithoutXToY(G, x, y, graphType); + } catch (Exception e) { + return new HashSet<>(); + } + + if (G2 == null) { + return new HashSet<>(); + } + + if (!G.isAdjacentTo(x, y)) { + return new HashSet<>(); + } + + if (G2.paths().isLegalMpdag() && G.isAdjacentTo(x, y) && !Edges.isDirectedEdge(G.getEdge(x, y))) { + System.out.println("The edge from x to y must be visible: " + G.getEdge(x, y)); + return new HashSet<>(); + } else if (G2.paths().isLegalPag() && G.isAdjacentTo(x, y) && !G.paths().defVisible(G.getEdge(x, y))) { + System.out.println("The edge from x to y must be visible:" + G.getEdge(x, y)); + return new HashSet<>(); + } + + Set anteriority = G.paths().anteriority(x, y); + anteriority.remove(x); + anteriority.remove(y); + anteriority.removeAll(G.paths().getDescendants(x)); + return getNMinimalSubsets(getGraphWithoutXToY(G, x, y, graphType), anteriority, x, y, numSmallestSizes); + } + + /** + * Returns a graph that is obtained by removing the edge from node x to node y from the input graph. The type of the + * output graph is determined by the provided graph type. + * + * @param G the input graph + * @param x the starting node of the edge to be removed + * @param y the ending node of the edge to be removed + * @param graphType the type of the output graph (CPDAG, PAG, or MAG) + * @return the resulting graph after removing the edge from node x to node y + * @throws IllegalArgumentException if the input graph type is not legal (must be CPDAG, PAG, or MAG) + */ + public static Graph getGraphWithoutXToY(Graph G, Node x, Node y, GraphType graphType) { + if (graphType == GraphType.CPDAG) { + return getGraphWithoutXToYMpdag(G, x, y); + } else if (graphType == GraphType.PAG) { + return getGraphWithoutXToYPag(G, x, y); + } else { + throw new IllegalArgumentException("Graph must be a legal MPDAG, PAG, or MAG."); + } + } + + /** + * This method returns a graph G2 without the edge between Node x and Node y, creating a Maximum Partially Directed + * Acyclic Graph (MPDAG) representation. + * + * @param G the original graph + * @param x the starting node of the edge + * @param y the ending node of the edge + * @return a graph G2 without the edge between Node x and Node y, in MPDAG representation + * @throws IllegalArgumentException if the edge from x to y does not exist, is not directed, or does not point + * towards y + */ + private static Graph getGraphWithoutXToYMpdag(Graph G, Node x, Node y) { + Graph G2 = new EdgeListGraph(G); + + if (!G2.isAdjacentTo(x, y)) { + throw new IllegalArgumentException("Edge from x to y must exist."); + } else if (Edges.isUndirectedEdge(G2.getEdge(x, y))) { + throw new IllegalArgumentException("Edge from x to y must be directed."); + } else if (G2.getEdge(x, y).pointsTowards(x)) { + throw new IllegalArgumentException("Edge from x to y must point towards y."); + } + + G2.removeEdge(x, y); + return G2; + } + + /** + * Returns a graph without the edge from x to y in the given graph. If the edge is undirected, bidirected, or + * partially oriented, the method returns null. If the edge is directed, the method orients the edge from x to y and + * returns the resulting graph. + * + * @param G the graph in which to remove the edge + * @param x the first node in the edge + * @param y the second node in the edge + * @return a graph without the edge from x to y + * @throws IllegalArgumentException if the edge from x to y does not exist, is not directed, or does not point + * towards + */ + private static Graph getGraphWithoutXToYPag(Graph G, Node x, Node y) throws IllegalArgumentException { + if (!G.isAdjacentTo(x, y)) return null; + + Edge edge = G.getEdge(x, y); + + if (edge == null) { + throw new IllegalArgumentException("Edge from x to y must exist."); + } else if (!Edges.isDirectedEdge(edge)) { + throw new IllegalArgumentException("Edge from x to y must be directed."); + } else if (edge.pointsTowards(x)) { + throw new IllegalArgumentException("Edge from x to y must point towards y."); + } else if (!G.paths().defVisible(edge)) { + throw new IllegalArgumentException("Edge from x to y must be visible."); + } + + Graph G2 = new EdgeListGraph(G); + G2.removeEdge(x, y); + return G2; + } + + /** + * Returns the subsets T of S such that X _||_ Y | T in G and T is a subset of up to the numSmallestSizes smallest + * minimal sizes of subsets for S. + * + * @param G the graph in which to compute the subsets + * @param S the set of nodes for which to compute the subsets + * @param X the first node in the separation + * @param Y the second node in the separation + * @param numSmallestSizes the number of the smallest sizes for the subsets to return + * @return the subsets T of S such that X _||_ Y | T in G and T is a subset of up to the numSmallestSizes minimal + * sizes of subsets for S + */ + private static Set> getNMinimalSubsets(Graph G, Set S, Node X, Node Y, + int numSmallestSizes) { + if (numSmallestSizes < 0) { + throw new IllegalArgumentException("numSmallestSizes must be greater than or equal to 0."); + } + + List _S = new ArrayList<>(S); + Set> nMinimal = new HashSet<>(); + var sublists = new SublistGenerator(_S.size(), _S.size()); + int[] choice; + int _n = 0; + int size = -1; + + while ((choice = sublists.next()) != null) { + List subset = GraphUtils.asList(choice, _S); + HashSet s = new HashSet<>(subset); + if (G.paths().isMSeparatedFrom(X, Y, s, false)) { + + if (choice.length > size) { + size = choice.length; + _n++; + + if (_n > numSmallestSizes) { + break; + } + } + + nMinimal.add(s); + } + } + + return nMinimal; + } + + /** + * Computes the anteriority of the given nodes in a graph. An anterior node is a node that has a directed path to + * any of the given nodes. This method returns a set of anterior nodes. + * + * @param G the graph to compute anteriority on + * @param x the nodes to compute anteriority for + * @return a set of anterior nodes + */ + public static Set anteriority(Graph G, Node... x) { + Set anteriority = new HashSet<>(); + + Z: + for (Node z : G.getNodes()) { + for (Node _x : x) { + if (G.paths().existsDirectedPath(z, _x)) { + anteriority.add(z); + } + } + } + + for (Node _x : x) { + anteriority.remove(_x); + } + + return anteriority; + } + /** * Determines if the given graph is a directed acyclic graph (DAG). * @@ -2124,8 +2368,7 @@ public static Graph convert(String spec) { String var1 = st2.nextToken(); if (var1.startsWith("Latent(")) { - String latentName = - (String) var1.subSequence(7, var1.length() - 1); + String latentName = (String) var1.subSequence(7, var1.length() - 1); GraphNode node = new GraphNode(latentName); node.setNodeType(NodeType.LATENT); graph.addNode(node); @@ -2152,9 +2395,7 @@ public static Graph convert(String spec) { Edge edge = graph.getEdge(nodeA, nodeB); if (edge != null) { - throw new IllegalArgumentException( - "Multiple edges connecting " + - "nodes is not supported."); + throw new IllegalArgumentException("Multiple edges connecting " + "nodes is not supported."); } if (edgeSpec.lastIndexOf("-->") != -1) { @@ -2187,8 +2428,10 @@ public static Graph convert(String spec) { * @param referenceCpdag The reference CPDAG to guide the orientation of edges. * @param sepsets The sepsets used to determine the orientation of edges. * @param knowledge The knowledge used to determine the orientation of edges. + * @param verbose Whether to print verbose output. */ - public static void gfciR0(Graph graph, Graph referenceCpdag, SepsetProducer sepsets, Knowledge knowledge) { + public static void gfciR0(Graph graph, Graph referenceCpdag, SepsetProducer sepsets, Knowledge knowledge, + boolean verbose) { graph.reorientAllWith(Endpoint.CIRCLE); fciOrientbk(knowledge, graph, graph.getNodes()); @@ -2211,17 +2454,64 @@ public static void gfciR0(Graph graph, Graph referenceCpdag, SepsetProducer seps if (referenceCpdag.isDefCollider(a, b, c) && FciOrient.isArrowheadAllowed(a, b, graph, knowledge) - && FciOrient.isArrowheadAllowed(c, b, graph, knowledge)) { + && FciOrient.isArrowheadAllowed(c, b, graph, knowledge) + && !referenceCpdag.isAdjacentTo(a, c) && !graph.isAdjacentTo(a, c)) { + + if (graph.getEndpoint(b, a) == Endpoint.ARROW && (graph.paths().existsDirectedPath(a, b) || graph.paths().existsDirectedPath(b, a))) { + continue; + } + + if (graph.getEndpoint(b, c) == Endpoint.ARROW && (graph.paths().existsDirectedPath(b, c) || graph.paths().existsDirectedPath(c, b))) { + continue; + } + graph.setEndpoint(a, b, Endpoint.ARROW); graph.setEndpoint(c, b, Endpoint.ARROW); - } else if (referenceCpdag.isAdjacentTo(a, c) && !graph.isAdjacentTo(a, c)) { + + if (verbose) { + TetradLogger.getInstance().forceLogMessage("Oriented collider " + a + " *-> " + b + " <-* " + c + " (from score search))."); + + if (Edges.isBidirectedEdge(graph.getEdge(a, b))) { + TetradLogger.getInstance().forceLogMessage("Created bidirected edge: " + graph.getEdge(a, b)); + } + + if (Edges.isBidirectedEdge(graph.getEdge(b, c))) { + TetradLogger.getInstance().forceLogMessage("Created bidirected edge: " + graph.getEdge(b, c)); + } + } + } else if (referenceCpdag.isAdjacentTo(a, c)) {// && !graph.isAdjacentTo(a, c)) { Set sepset = sepsets.getSepset(a, c); - if (sepset != null && !sepset.contains(b) - && FciOrient.isArrowheadAllowed(a, b, graph, knowledge) - && FciOrient.isArrowheadAllowed(c, b, graph, knowledge)) { + if (graph.isAdjacentTo(a, c)) { + graph.removeEdge(a, c); + } + + if (sepset != null && !sepset.contains(b) && FciOrient.isArrowheadAllowed(a, b, graph, knowledge) && FciOrient.isArrowheadAllowed(c, b, graph, knowledge)) { + if (graph.getEndpoint(b, a) == Endpoint.ARROW && (graph.paths().existsDirectedPath(a, b) || graph.paths().existsDirectedPath(b, a))) { + continue; + } + + if (graph.getEndpoint(b, c) == Endpoint.ARROW && (graph.paths().existsDirectedPath(b, c) || graph.paths().existsDirectedPath(c, b))) { + continue; + } + graph.setEndpoint(a, b, Endpoint.ARROW); graph.setEndpoint(c, b, Endpoint.ARROW); + + if (verbose) { + double p = sepsets.getPValue(a, c, sepset); + String _p = p < 0.0001 ? "< 0.0001" : String.format("%.4f", p); + + TetradLogger.getInstance().forceLogMessage("Oriented collider " + a + " *-> " + b + " <-* " + c + " (from test)), p = " + _p + "."); + + if (Edges.isBidirectedEdge(graph.getEdge(a, b))) { + TetradLogger.getInstance().forceLogMessage("Created bidirected edge: " + graph.getEdge(a, b)); + } + + if (Edges.isBidirectedEdge(graph.getEdge(b, c))) { + TetradLogger.getInstance().forceLogMessage("Created bidirected edge: " + graph.getEdge(b, c)); + } + } } } } @@ -2393,6 +2683,115 @@ private static Graph trimSemidirected(List targets, Graph graph) { return _graph; } + /** + * Checks if the given trek in a graph is a confounding trek. This is a trek from measured node x to measured node y + * that has only latent nodes in between. + * + * @param trueGraph the true graph representing the causal relationships between nodes + * @param trek the trek to be checked + * @param x the first node in the trek + * @param y the last node in the trek + * @return true if the trek is a confounding trek, false otherwise + */ + public static boolean isConfoundingTrek(Graph trueGraph, List trek, Node x, Node y) { + if (x.getNodeType() != NodeType.MEASURED || y.getNodeType() != NodeType.MEASURED) { + return false; + } + + Node source = getTrekSource(trueGraph, trek); + + if (source == x || source == y) { + return false; + } + + if (trek.size() < 3) { + return false; + } + + boolean allLatent = true; + + for (int i = 1; i < trek.size() - 1; i++) { + Node z = trek.get(i); + + if (z.getNodeType() != NodeType.LATENT) { + allLatent = false; + break; + } + } + + return allLatent; + } + + /** + * This method returns the source node of a given trek in a graph. + * + * @param graph The graph containing the nodes and edges. + * @param trek The list of nodes representing the trek. + * @return The source node of the trek. + */ + public static Node getTrekSource(Graph graph, List trek) { + Node y = trek.get(trek.size() - 1); + + Node source = y; + + // Find the first node where the direction is left to right. + for (int i = 0; i < trek.size() - 1; i++) { + Node n1 = trek.get(i); + Node n2 = trek.get(i + 1); + + if (graph.getEdge(n1, n2).pointsTowards(n2)) { + source = n1; + break; + } + } + + return source; + } + + /** + * Determines if the given bidirected edge has a latent confounder in the true graph. + * + * @param edge The edge to check. + * @param trueGraph The true graph (DAG, CPDAG, PAG_of_the_true_DAG). + * @return true if the given bidirected has a latent confounder in the true graph, false otherwise. + * @throws IllegalArgumentException if the edge is not bidirected. + */ + public static boolean isCorrectBidirectedEdge(Edge edge, Graph trueGraph) { + if (!Edges.isBidirectedEdge(edge)) { + throw new IllegalArgumentException("The edge is not bidirected: " + edge); + } + + Node x = edge.getNode1(); + Node y = edge.getNode2(); + + List> treks = trueGraph.paths().treks(x, y, -1); + boolean existsLatentConfounder = false; + + for (List trek : treks) { + if (isConfoundingTrek(trueGraph, trek, x, y)) { + existsLatentConfounder = true; + } + } + + return existsLatentConfounder; + } + + /** + * The GraphType enum represents the types of graphs that can be used in the application. + */ + public enum GraphType { + + /** + * The CPDAG graph type. + */ + CPDAG, + + /** + * The PAG graph type. + */ + PAG + } + /** * The Counts class represents a matrix of counts for different edge types. */ @@ -2547,11 +2946,7 @@ public static class GraphComparison { * @param edgesRemoved a {@link java.util.List} object * @param counts a int[][] */ - public GraphComparison(int adjFn, int adjFp, int adjCorrect, int ahdFn, int ahdFp, - int ahdCorrect, double adjPrec, double adjRec, double ahdPrec, - double ahdRec, int shd, - List edgesAdded, List edgesRemoved, - int[][] counts) { + public GraphComparison(int adjFn, int adjFp, int adjCorrect, int ahdFn, int ahdFp, int ahdCorrect, double adjPrec, double adjRec, double ahdPrec, double ahdRec, int shd, List edgesAdded, List edgesRemoved, int[][] counts) { this.adjFn = adjFn; this.adjFp = adjFp; this.adjCorrect = adjCorrect; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java index aa585586f4..c56cffb280 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/Paths.java @@ -1,9 +1,7 @@ package edu.cmu.tetrad.graph; import edu.cmu.tetrad.search.IndependenceTest; -import edu.cmu.tetrad.search.test.MsepTest; -import edu.cmu.tetrad.search.utils.GraphSearchUtils; -import edu.cmu.tetrad.search.utils.SepsetMap; +import edu.cmu.tetrad.search.utils.*; import edu.cmu.tetrad.util.SublistGenerator; import edu.cmu.tetrad.util.TaskManager; import edu.cmu.tetrad.util.TetradLogger; @@ -69,6 +67,59 @@ private static Set getPrefix(List pi, int i) { return prefix; } + /** + * Generates a directed acyclic graph (DAG) based on the given list of nodes using Raskutti and Uhler's method. + * + * @param pi a list of nodes representing the set of vertices in the graph + * @param g the graph + * @param verbose whether to print verbose output + * @return a Graph object representing the generated DAG. + */ + public static Graph getDag(List pi, Graph g, boolean verbose) { + Graph graph = new EdgeListGraph(pi); + + for (int a = 0; a < pi.size(); a++) { + for (Node b : getParents(pi, a, g, verbose, false)) { + graph.addDirectedEdge(b, pi.get(a)); + } + } + + return graph; + } + + /** + * Returns the parents of the node at index p, calculated using Pearl's method. + * + * @param pi The list of nodes. + * @param p The index. + * @param g The graph. + * @param verbose Whether to print verbose output. + * @param allowSelectionBias whether to allow selection bias; if true, then undirected edges X--Y are uniformly + * treated as X->L<-Y. + * @return The parents, as a Pair object (parents + score). + */ + public static Set getParents(List pi, int p, Graph g, boolean verbose, boolean allowSelectionBias) { + Node x = pi.get(p); + Set parents = new HashSet<>(); + Set prefix = getPrefix(pi, p); + + for (Node y : prefix) { + Set minus = new HashSet<>(prefix); + minus.remove(y); + minus.remove(x); + Set z = new HashSet<>(minus); + + if (!g.paths().isMSeparatedFrom(x, y, z, allowSelectionBias)) { + if (verbose) { + System.out.println("Adding " + y + " as a parent of " + x + " with z = " + z); + } + parents.add(y); + } + } + + return parents; + } + /** * Returns a valid causal order for either a DAG or a CPDAG. (bryanandrews) * @@ -119,8 +170,8 @@ public void makeValidOrder(List order) { Node x; do { if (itr.hasNext()) x = itr.next(); - else throw new IllegalArgumentException("The remaining graph does not have valid sink; there " + - "could be a directed cycle or a non-chordal undirected cycle."); + else + throw new IllegalArgumentException("The remaining graph does not have valid sink; there " + "could be a directed cycle or a non-chordal undirected cycle."); } while (invalidSink(x, _graph)); order.add(x); _graph.removeNode(x); @@ -168,7 +219,7 @@ public boolean isLegalDag() { * * @return true if the graph is a legal CPDAG, false otherwise. */ - public boolean isLegalCpdag() { + public synchronized boolean isLegalCpdag() { Graph g = this.graph; for (Edge e : g.getEdges()) { @@ -181,9 +232,8 @@ public boolean isLegalCpdag() { try { g.paths().makeValidOrder(pi); - MsepTest msepTest = new MsepTest(g); - Graph dag = getDag(pi, msepTest); - Graph cpdag = GraphTransforms.cpdagForDag(dag); + Graph dag = getDag(pi, g/*GraphTransforms.dagFromCpdag(g)*/, false); + Graph cpdag = GraphTransforms.dagToCpdag(dag); return g.equals(cpdag); } catch (Exception e) { // There was no valid sink. @@ -193,9 +243,11 @@ public boolean isLegalCpdag() { } /** - * Checks if the given Multi-Parent Directed Acyclic Graph (MPDAG) is legal. A MPDAG is considered legal if it is - * equivalent to a CPDAG where additional edges have been oriented by Knowledge, with Meek rules applied for maximum - * orientation. + * Checks if the given graph is a legal Maximal Partial Directed Acyclic Graph (MPDAG). A MPDAG is considered legal + * if it is equal to a CPDAG where additional edges have been oriented by Knowledge, with Meek rules applied for + * maximum orientation. The test is performed by attemping to convert the graph to a CPDAG using the DAG to CPDAG + * transformation and testing whether that graph is a legal CPDAG. Finally, we test to see whether the obtained + * graph is equal to the original graph. * * @return true if the MPDAG is legal, false otherwise. */ @@ -212,14 +264,59 @@ public boolean isLegalMpdag() { try { g.paths().makeValidOrder(pi); - MsepTest msepTest = new MsepTest(g); - Graph dag = getDag(pi, msepTest); - Graph cpdag = GraphTransforms.cpdagForDag(dag); - + Graph dag = getDag(pi, g, false); + Graph cpdag = GraphTransforms.dagToCpdag(dag); Graph _g = new EdgeListGraph(g); - _g = GraphTransforms.cpdagForDag(_g); + _g = GraphTransforms.dagToCpdag(_g); + + boolean equals = _g.equals(cpdag); + + // Check maximality... + if (equals) { + Graph __g = new EdgeListGraph(g); + MeekRules meekRules = new MeekRules(); + meekRules.setRevertToUnshieldedColliders(false); + meekRules.orientImplied(__g); + return g.equals(__g); + } + + return false; + } catch (Exception e) { + // There was no valid sink. + System.out.println(e.getMessage()); + return false; + } + } + + /** + * Checks if the given Maximal Ancestral Graph (MPAG) is legal. A MPAG is considered legal if it is equal to a PAG + * where additional edges have been oriented by Knowledge, with final FCI rules applied for maximum orientation. The + * test is performed by attemping to convert the graph to a PAG using the DAG to CPDAG transformation and testing + * whether that graph is a legal PAG. Finally, we test to see whether the obtained graph is equal to the original + * graph. + *

+ * The user may choose to use the rules from Zhang (2008) or the rules from Spirtes et al. (2000). + * + * @return true if the MPDAG is legal, false otherwise. + */ + public boolean isLegalMpag() { + Graph g = this.graph; + + try { + Graph pag = GraphTransforms.dagToPag(g); + + if (pag.paths().isLegalPag()) { + Graph __g = new DagToPag(graph).convert(); + + if (__g.paths().isLegalPag()) { + Graph _g = new EdgeListGraph(g); + FciOrient fciOrient = new FciOrient(new DagSepsets(_g)); + fciOrient.zhangFinalOrientation(_g); + return g.equals(_g); + } + } - return _g.equals(cpdag); + return false; } catch (Exception e) { // There was no valid sink. System.out.println(e.getMessage()); @@ -245,49 +342,6 @@ public boolean isLegalPag() { return GraphSearchUtils.isLegalPag(graph).isLegalPag(); } - /** - * Generates a directed acyclic graph (DAG) based on the given list of nodes using Raskutti and Uhler's method. - * - * @param pi a list of nodes representing the set of vertices in the graph - * @param msep the MsepTest instance for determining d-separation relationships - * @return a Graph object representing the generated DAG. - */ - private Graph getDag(List pi, MsepTest msep) { - Graph graph = new EdgeListGraph(pi); - - for (int a = 0; a < pi.size(); a++) { - for (Node b : getParents(pi, a, msep)) { - graph.addDirectedEdge(b, pi.get(a)); - } - } - - return graph; - } - - /** - * Returns the parents of the node at index p, calculated using Pearl's method. - * - * @param p The index. - * @return The parents, as a Pair object (parents + score). - */ - private Set getParents(List pi, int p, MsepTest msep) { - Node x = pi.get(p); - Set parents = new HashSet<>(); - Set prefix = getPrefix(pi, p); - - for (Node y : prefix) { - Set minus = new HashSet<>(prefix); - minus.remove(y); - Set z = new HashSet<>(minus); - - if (msep.checkIndependence(x, y, z).isDependent()) { - parents.add(y); - } - } - - return parents; - } - /** * Returns a set of all maximum cliques in the graph. * @@ -365,13 +419,13 @@ public List> connectedComponents() { * @param maxLength the maximum length of the paths * @return a list of lists containing the directed paths from node1 to node2 */ - public List> directedPathsFromTo(Node node1, Node node2, int maxLength) { + public List> directedPaths(Node node1, Node node2, int maxLength) { List> paths = new LinkedList<>(); - directedPathsFromToVisit(node1, node2, new LinkedList<>(), paths, maxLength); + directedPaths(node1, node2, new LinkedList<>(), paths, maxLength); return paths; } - private void directedPathsFromToVisit(Node node1, Node node2, LinkedList path, List> paths, int maxLength) { + private void directedPaths(Node node1, Node node2, LinkedList path, List> paths, int maxLength) { if (maxLength != -1 && path.size() > maxLength - 2) { return; } @@ -408,27 +462,27 @@ private void directedPathsFromToVisit(Node node1, Node node2, LinkedList p continue; } - directedPathsFromToVisit(child, node2, path, paths, maxLength); + directedPaths(child, node2, path, paths, maxLength); } path.removeLast(); } /** - *

semidirectedPathsFromTo.

+ * Finds all semi-directed paths between two nodes up to a maximum length. * - * @param node1 a {@link edu.cmu.tetrad.graph.Node} object - * @param node2 a {@link edu.cmu.tetrad.graph.Node} object - * @param maxLength a int - * @return a {@link java.util.List} object + * @param node1 the starting node + * @param node2 the ending node + * @param maxLength the maximum path length + * @return a list of all semi-directed paths between the two nodes */ - public List> semidirectedPathsFromTo(Node node1, Node node2, int maxLength) { + public List> semidirectedPaths(Node node1, Node node2, int maxLength) { List> paths = new LinkedList<>(); - semidirectedPathsFromToVisit(node1, node2, new LinkedList<>(), paths, maxLength); + semidirectedPathsVisit(node1, node2, new LinkedList<>(), paths, maxLength); return paths; } - private void semidirectedPathsFromToVisit(Node node1, Node node2, LinkedList path, List> paths, int maxLength) { + private void semidirectedPathsVisit(Node node1, Node node2, LinkedList path, List> paths, int maxLength) { if (maxLength != -1 && path.size() > maxLength - 2) { return; } @@ -465,27 +519,27 @@ private void semidirectedPathsFromToVisit(Node node1, Node node2, LinkedListallPathsFromTo.

+ * Finds all paths from node1 to node2 within a specified maximum length. * - * @param node1 a {@link edu.cmu.tetrad.graph.Node} object - * @param node2 a {@link edu.cmu.tetrad.graph.Node} object - * @param maxLength a int - * @return a {@link java.util.List} object + * @param node1 The starting node. + * @param node2 The target node. + * @param maxLength The maximum length of the paths. + * @return A list of paths, where each path is a list of nodes. */ - public List> allPathsFromTo(Node node1, Node node2, int maxLength) { + public List> allPaths(Node node1, Node node2, int maxLength) { List> paths = new LinkedList<>(); - allPathsFromToVisit(node1, node2, new LinkedList<>(), paths, maxLength); + allPathsVisit(node1, node2, new LinkedList<>(), paths, maxLength); return paths; } - private void allPathsFromToVisit(Node node1, Node node2, LinkedList path, List> paths, int maxLength) { + private void allPathsVisit(Node node1, Node node2, LinkedList path, List> paths, int maxLength) { path.addLast(node1); if (path.size() > (maxLength == -1 ? 1000 : maxLength)) { @@ -510,27 +564,27 @@ private void allPathsFromToVisit(Node node1, Node node2, LinkedList path, continue; } - allPathsFromToVisit(child, node2, path, paths, maxLength); + allPathsVisit(child, node2, path, paths, maxLength); } path.removeLast(); } /** - *

allDirectedPathsFromTo.

+ * Finds all directed paths from node1 to node2 with a maximum length. * - * @param node1 a {@link edu.cmu.tetrad.graph.Node} object - * @param node2 a {@link edu.cmu.tetrad.graph.Node} object - * @param maxLength a int - * @return a {@link java.util.List} object + * @param node1 The starting node. + * @param node2 The target node. + * @param maxLength The maximum length of the paths. + * @return A list of lists of nodes representing the directed paths from node1 to node2. */ - public List> allDirectedPathsFromTo(Node node1, Node node2, int maxLength) { + public List> allDirectedPaths(Node node1, Node node2, int maxLength) { List> paths = new LinkedList<>(); - allDirectedPathsFromToVisit(node1, node2, new LinkedList<>(), paths, maxLength); + allDirectedPathsVisit(node1, node2, new LinkedList<>(), paths, maxLength); return paths; } - private void allDirectedPathsFromToVisit(Node node1, Node node2, LinkedList path, List> paths, int maxLength) { + private void allDirectedPathsVisit(Node node1, Node node2, LinkedList path, List> paths, int maxLength) { path.addLast(node1); if (path.size() > (maxLength == -1 ? 1000 : maxLength)) { @@ -556,7 +610,7 @@ private void allDirectedPathsFromToVisit(Node node1, Node node2, LinkedList p } /** - *

existsDirectedPathFromTo.

+ * Checks if a directed path exists between two nodes within a certain depth. * - * @param node1 a {@link edu.cmu.tetrad.graph.Node} object - * @param node2 a {@link edu.cmu.tetrad.graph.Node} object - * @param depth a int - * @return a boolean + * @param node1 the first node in the path + * @param node2 the second node in the path + * @param depth the maximum depth to search for the path + * @return true if a directed path exists between the two nodes within the given depth, false otherwise */ - public boolean existsDirectedPathFromTo(Node node1, Node node2, int depth) { + public boolean existsDirectedPath(Node node1, Node node2, int depth) { return node1 == node2 || existsDirectedPathVisit(node1, node2, new LinkedList<>(), depth); } @@ -807,138 +861,12 @@ public boolean existsSemiDirectedPath(Node from, Node to) { } /** - *

isMConnectedTo.

+ * Retrieves the set of nodes that are connected to the given node {@code y} and are also present in the set of + * nodes {@code z}. * - * @param x a {@link java.util.Set} object - * @param y a {@link java.util.Set} object - * @param z a {@link java.util.Set} object - * @return a boolean - */ - public boolean isMConnectedTo(Set x, Set y, Set z) { - Set ancestors = ancestorsOf(z); - - Queue> Q = new ArrayDeque<>(); - Set> V = new HashSet<>(); - - for (Node _x : x) { - for (Node node : graph.getAdjacentNodes(_x)) { - if (y.contains(node)) { - return true; - } - OrderedPair edge = new OrderedPair<>(_x, node); - Q.offer(edge); - V.add(edge); - } - } - - while (!Q.isEmpty()) { - OrderedPair t = Q.poll(); - - Node b = t.getFirst(); - Node a = t.getSecond(); - - for (Node c : graph.getAdjacentNodes(b)) { - if (c == a) { - continue; - } - - boolean collider = graph.isDefCollider(a, b, c); - if (!((collider && ancestors.contains(b)) || (!collider && !z.contains(b)))) { - continue; - } - - if (y.contains(c)) { - return true; - } - - OrderedPair u = new OrderedPair<>(b, c); - if (V.contains(u)) { - continue; - } - - V.add(u); - Q.offer(u); - } - } - - return false; - } - - /** - * Checks to see if x and y are d-connected given z. - * - * @param ancestorMap A map of nodes to their ancestors. - * @param x a {@link java.util.Set} object - * @param y a {@link java.util.Set} object - * @param z a {@link java.util.Set} object - * @return True if x and y are d-connected given z. - */ - public boolean isMConnectedTo(Set x, Set y, Set z, Map> ancestorMap) { - if (ancestorMap == null) throw new NullPointerException("Ancestor map cannot be null."); - - Queue> Q = new ArrayDeque<>(); - Set> V = new HashSet<>(); - - for (Node _x : x) { - for (Node node : graph.getAdjacentNodes(_x)) { - if (y.contains(node)) { - return true; - } - OrderedPair edge = new OrderedPair<>(_x, node); - Q.offer(edge); - V.add(edge); - } - } - - while (!Q.isEmpty()) { - OrderedPair t = Q.poll(); - - Node b = t.getFirst(); - Node a = t.getSecond(); - - for (Node c : graph.getAdjacentNodes(b)) { - if (c == a) { - continue; - } - - boolean collider = graph.isDefCollider(a, b, c); - - boolean ancestor = false; - - for (Node _z : z) { - if (ancestorMap.get(_z).contains(b)) { - ancestor = true; - break; - } - } - - if (!((collider && ancestor) || (!collider && !z.contains(b)))) { - continue; - } - - if (y.contains(c)) { - return true; - } - - OrderedPair u = new OrderedPair<>(b, c); - if (V.contains(u)) { - continue; - } - - V.add(u); - Q.offer(u); - } - } - - return false; - } - - /** - *

getMConnectedVars.

- * - * @param y a {@link edu.cmu.tetrad.graph.Node} object - * @param z a {@link java.util.Set} object - * @return a {@link java.util.Set} object + * @param y The node for which to find the connected nodes. + * @param z The set of nodes to be considered for connecting nodes. + * @return The set of nodes that are connected to {@code y} and present in {@code z}. */ public Set getMConnectedVars(Node y, Set z) { Set Y = new HashSet<>(); @@ -1298,11 +1226,12 @@ public boolean existsInducingPathVisit(Node a, Node b, Node x, Node y, LinkedLis } /** - *

getInducingPath.

+ * This method calculates the inducing path between two measured nodes in a graph. * - * @param x a {@link edu.cmu.tetrad.graph.Node} object - * @param y a {@link edu.cmu.tetrad.graph.Node} object - * @return a {@link java.util.List} object + * @param x the first measured node in the graph + * @param y the second measured node in the graph + * @return the inducing path between node x and node y, or null if no inducing path exists + * @throws IllegalArgumentException if either x or y is not of NodeType.MEASURED */ public List getInducingPath(Node x, Node y) { if (x.getNodeType() != NodeType.MEASURED) { @@ -1505,8 +1434,8 @@ private boolean existOnePathWithPossibleParents(Map> previous, N /** - * Check to see if a set of variables Z satisfies the back-door criterion relative to node x and node y. - * (author Kevin V. Bui (March 2020). + * Check to see if a set of variables Z satisfies the back-door criterion relative to node x and node y. (author + * Kevin V. Bui (March 2020). * * @param graph a {@link edu.cmu.tetrad.graph.Graph} object * @param x a {@link edu.cmu.tetrad.graph.Node} object @@ -1524,7 +1453,7 @@ public boolean isSatisfyBackDoorCriterion(Graph graph, Node x, Node y, Set // make sure zNodes bock every path between node x and node y that contains an arrow into node x - List> directedPaths = allDirectedPathsFromTo(x, y, -1); + List> directedPaths = allDirectedPaths(x, y, -1); directedPaths.forEach(nodes -> { // remove all variables that are not on the back-door path nodes.forEach(node -> { @@ -1534,7 +1463,7 @@ public boolean isSatisfyBackDoorCriterion(Graph graph, Node x, Node y, Set }); }); - return dag.paths().isMSeparatedFrom(x, y, z); + return dag.paths().isMSeparatedFrom(x, y, z, false); } // Finds a sepset for x and y, if there is one; otherwise, returns null. @@ -1571,7 +1500,7 @@ private Set getSepsetVisit(Node x, Node y) { Set colliders = new HashSet<>(); for (Node b : graph.getAdjacentNodes(x)) { - if (sepsetPathFound(x, b, y, path, z, colliders, -1)) { + if (sepsetPathFound(x, b, y, path, z, colliders, 8)) { return null; } } @@ -1650,12 +1579,14 @@ private boolean sepsetPathFound(Node a, Node b, Node y, Set path, Set z) { + public boolean isMConnectedTo(Node x, Node y, Set z, boolean allowSelectionBias) { class EdgeNode { private final Edge edge; @@ -1712,6 +1643,25 @@ public boolean equals(Object o) { return true; } + // If in a CPDAG we have X->Y--Z<-W, reachability can't determine that the path should be + // blocked now matter which way Y--Z is oriented, so we need to make a choice. Choosing Y->Z + // works for cyclic directed graphs and for PAGs except where X->Y with no circle at X, + // in which case Y--Z should be interpreted as selection bias. This is a limitation of the + // reachability algorithm here. The problem is that Y--Z is interpreted differently for CPDAGs + // than for PAGs, and we are trying to make an m-connection procedure that works for both. + // Simply knowing whether selection bias is being allowed is sufficient to make the right choice. + // A similar problem can occur in a PAG; we deal with that as well. The idea is to make + // "virtual edges" that are directed in the direction of the arrow, so that the reachability + // algorithm can eventually find any colliders along the path that may be implied. + // jdramsey 2024-04-14 + if (!allowSelectionBias && edge1.getProximalEndpoint(b) == Endpoint.ARROW) { + if (Edges.isUndirectedEdge(edge2)) { + edge2 = Edges.directedEdge(b, edge2.getDistalNode(b)); + } else if (Edges.isNondirectedEdge(edge2)) { + edge2 = Edges.partiallyOrientedEdge(b, edge2.getDistalNode(b)); + } + } + EdgeNode u = new EdgeNode(edge2, b); if (!V.contains(u)) { @@ -1728,13 +1678,15 @@ public boolean equals(Object o) { /** * Detemrmines whether x and y are d-connected given z. * - * @param x a {@link edu.cmu.tetrad.graph.Node} object - * @param y a {@link edu.cmu.tetrad.graph.Node} object - * @param z a {@link java.util.Set} object - * @param ancestors a {@link java.util.Map} object + * @param x a {@link Node} object + * @param y a {@link Node} object + * @param z a {@link Set} object + * @param ancestors a {@link Map} object + * @param allowSelectionBias whether to allow selection bias; if true, then undirected edges X--Y are uniformly + * treated as X->L<-Y. * @return true if x and y are d-connected given z; false otherwise. */ - public boolean isMConnectedTo(Node x, Node y, Set z, Map> ancestors) { + public boolean isMConnectedTo(Node x, Node y, Set z, Map> ancestors, boolean allowSelectionBias) { class EdgeNode { private final Edge edge; @@ -1791,6 +1743,25 @@ public boolean equals(Object o) { return true; } + // If in a CPDAG we have X->Y--Z<-W, reachability can't determine that the path should be + // blocked now matter which way Y--Z is oriented, so we need to make a choice. Choosing Y->Z + // works for cyclic directed graphs and for PAGs except where X->Y with no circle at X, + // in which case Y--Z should be interpreted as selection bias. This is a limitation of the + // reachability algorithm here. The problem is that Y--Z is interpreted differently for CPDAGs + // than for PAGs, and we are trying to make an m-connection procedure that works for both. + // Simply knowing whether selection bias is being allowed is sufficient to make the right choice. + // A similar problem can occur in a PAG; we deal with that as well. The idea is to make + // "virtual edges" that are directed in the direction of the arrow, so that the reachability + // algorithm can eventually find any colliders along the path that may be implied. + // jdramsey 2024-04-14 + if (!allowSelectionBias && edge1.getProximalEndpoint(b) == Endpoint.ARROW) { + if (Edges.isUndirectedEdge(edge2)) { + edge2 = Edges.directedEdge(b, edge2.getDistalNode(b)); + } else if (Edges.isNondirectedEdge(edge2)) { + edge2 = Edges.partiallyOrientedEdge(b, edge2.getDistalNode(b)); + } + } + EdgeNode u = new EdgeNode(edge2, b); if (!V.contains(u)) { @@ -1850,8 +1821,7 @@ public boolean defVisible(Edge edge) { return visibleEdgeHelper(A, B); } else { - throw new IllegalArgumentException( - "Given edge is not in the graph."); + throw new IllegalArgumentException("Given edge is not in the graph."); } } @@ -1923,7 +1893,7 @@ private boolean visibleEdgeHelperVisit(Node c, Node a, Node b, LinkedList */ public boolean existsDirectedCycle() { for (Node node : graph.getNodes()) { - if (existsDirectedPathFromTo(node, node)) { + if (existsDirectedPath(node, node)) { TetradLogger.getInstance().forceLogMessage("Cycle found at node " + node.getName() + "."); return true; } @@ -1932,13 +1902,13 @@ public boolean existsDirectedCycle() { } /** - *

existsDirectedPathFromTo.

+ * Checks if a directed path exists between two nodes in a graph. * - * @param node1 a {@link edu.cmu.tetrad.graph.Node} object - * @param node2 a {@link edu.cmu.tetrad.graph.Node} object - * @return true iff there is a (nonempty) directed path from node1 to node2. a + * @param node1 the starting node of the path + * @param node2 the target node of the path + * @return true if a directed path exists from node1 to node2, false otherwise */ - public boolean existsDirectedPathFromTo(Node node1, Node node2) { + public boolean existsDirectedPath(Node node1, Node node2) { Queue Q = new LinkedList<>(); Set V = new HashSet<>(); @@ -1992,10 +1962,28 @@ public boolean existsTrek(Node node1, Node node2) { } /** - *

getDescendants.

+ * Returns a list of all descendants of the given node. * - * @param nodes a {@link java.util.List} object - * @return a {@link java.util.List} object + * @param node The node for which to find descendants. + * @return A list of all descendant nodes. + */ + public Set getDescendants(Node node) { + Set descendants = new HashSet<>(); + + for (Node n : graph.getNodes()) { + if (isDescendentOf(n, node)) { + descendants.add(n); + } + } + + return descendants; + } + + /** + * Retrieves the descendants of the given list of nodes. + * + * @param nodes The list of nodes to find descendants for. + * @return A list of nodes that are descendants of the given nodes. */ public List getDescendants(List nodes) { Set ancestors = new HashSet<>(); @@ -2019,14 +2007,32 @@ public List getDescendants(List nodes) { * @return a boolean */ public boolean isAncestorOf(Node node1, Node node2) { - return node1 == node2 || existsDirectedPathFromTo(node1, node2); + return node1 == node2 || existsDirectedPath(node1, node2); } /** - *

getAncestors.

+ * Retrieves the ancestors of a specified `Node` in the graph. * - * @param nodes a {@link java.util.List} object - * @return a {@link java.util.List} object + * @param node The node whose ancestors are to be retrieved. + * @return A list of ancestors for the specified `Node`. + */ + public List getAncestors(Node node) { + Set ancestors = new HashSet<>(); + + for (Node n : graph.getNodes()) { + if (isAncestorOf(n, node)) { + ancestors.add(n); + } + } + + return new ArrayList<>(ancestors); + } + + /** + * Returns a list of all ancestors of the given nodes. + * + * @param nodes the list of nodes for which to find ancestors + * @return a list containing all the ancestors of the given nodes */ public List getAncestors(List nodes) { Set ancestors = new HashSet<>(); @@ -2050,7 +2056,7 @@ public List getAncestors(List nodes) { * @return a boolean */ public boolean isDescendentOf(Node node1, Node node2) { - return node1 == node2 || existsDirectedPathFromTo(node2, node1); + return node1 == node2 || existsDirectedPath(node2, node1); } /** @@ -2070,32 +2076,42 @@ public boolean definiteNonDescendent(Node node1, Node node2) { * every collider on U is an ancestor of some element in Z and every non-collider on U is not in Z. Two elements are * d-separated just in case they are not d-connected. A collider is a node which two edges hold in common for which * the endpoints leading into the node are both arrow endpoints. + *

+ * Precondition: This graph is a DAG. Please don't violate this constraint; weird things can happen! * - * @param node1 the first node. - * @param node2 the second node. - * @param z the conditioning set. + * @param node1 the first node. + * @param node2 the second node. + * @param z the conditioning set. + * @param allowSelectionBias whether to allow selection bias; if true, then undirected edges X--Y are uniformly + * treated as X->L<-Y. * @return true if node1 is d-separated from node2 given set t, false if not. - * @see #isMConnectedTo */ - public boolean isMSeparatedFrom(Node node1, Node node2, Set z) { - return !isMConnectedTo(node1, node2, z); + public boolean isMSeparatedFrom(Node node1, Node node2, Set z, boolean allowSelectionBias) { + return !isMConnectedTo(node1, node2, z, allowSelectionBias); } /** - *

isMSeparatedFrom.

+ * Checks if two nodes are M-separated. * - * @param node1 a {@link edu.cmu.tetrad.graph.Node} object - * @param node2 a {@link edu.cmu.tetrad.graph.Node} object - * @param z a {@link java.util.Set} object - * @param ancestors a {@link java.util.Map} object - * @return a boolean + * @param node1 The first node. + * @param node2 The second node. + * @param z The set of nodes to be excluded from the path. + * @param ancestors A map containing the ancestors of each node. + * @param allowSelectionBias whether to allow selection bias; if true, then undirected edges X--Y are uniformly + * treated as X->L<-Y. + * @return {@code true} if the two nodes are M-separated, {@code false} otherwise. */ - public boolean isMSeparatedFrom(Node node1, Node node2, Set z, Map> ancestors) { - return !isMConnectedTo(node1, node2, z, ancestors); + public boolean isMSeparatedFrom(Node node1, Node node2, Set z, Map> ancestors, boolean allowSelectionBias) { + return !isMConnectedTo(node1, node2, z, ancestors, allowSelectionBias); } /** - * @return true iff there is a semi-directed path from node1 to node2 + * Checks if a semi-directed path exists between the given node and any of the nodes in the provided set. + * + * @param node1 The starting node for the path. + * @param nodes2 The set of nodes to check for a path. + * @param path The current path (used for cycle detection). + * @return {@code true} if a semi-directed path exists, {@code false} otherwise. */ private boolean existsSemiDirectedPathVisit(Node node1, Set nodes2, LinkedList path) { path.addLast(node1); @@ -2125,13 +2141,13 @@ private boolean existsSemiDirectedPathVisit(Node node1, Set nodes2, Linked } /** - *

isDirectedFromTo.

+ * Checks if there is a directed edge from node1 to node2 in the graph. * - * @param node1 a {@link edu.cmu.tetrad.graph.Node} object - * @param node2 a {@link edu.cmu.tetrad.graph.Node} object - * @return a boolean + * @param node1 the source node + * @param node2 the destination node + * @return true if there is a directed edge from node1 to node2, false otherwise */ - public boolean isDirectedFromTo(Node node1, Node node2) { + public boolean isDirected(Node node1, Node node2) { List edges = graph.getEdges(node1, node2); if (edges.size() != 1) { return false; @@ -2141,13 +2157,13 @@ public boolean isDirectedFromTo(Node node1, Node node2) { } /** - *

isUndirectedFromTo.

+ * Checks if the edge between two nodes in the graph is undirected. * - * @param node1 a {@link edu.cmu.tetrad.graph.Node} object - * @param node2 a {@link edu.cmu.tetrad.graph.Node} object - * @return a boolean + * @param node1 the first node + * @param node2 the second node + * @return true if the edge is undirected, false otherwise */ - public boolean isUndirectedFromTo(Node node1, Node node2) { + public boolean isUndirected(Node node1, Node node2) { Edge edge = graph.getEdge(node1, node2); return edge != null && edge.getEndpoint1() == Endpoint.TAIL && edge.getEndpoint2() == Endpoint.TAIL; } @@ -2163,6 +2179,16 @@ public boolean possibleAncestor(Node node1, Node node2) { return existsSemiDirectedPath(node1, Collections.singleton(node2)); } + /** + * Returns the set of nodes that are in the anteriority of the given nodes in the graph. + * + * @param X the nodes for which the anteriority needs to be determined + * @return the set of nodes in the anteriority of the given nodes + */ + public Set anteriority(Node... X) { + return GraphUtils.anteriority(graph, X); + } + /** * An algorithm to find all cliques in a graph. */ @@ -2181,13 +2207,7 @@ private AllCliquesAlgorithm() { * @param args the command-line arguments */ public static void main(String[] args) { - int[][] graph = { - {0, 1, 1, 0, 0}, - {1, 0, 1, 1, 0}, - {1, 1, 0, 1, 1}, - {0, 1, 1, 0, 1}, - {0, 0, 1, 1, 0} - }; + int[][] graph = {{0, 1, 1, 0, 0}, {1, 0, 1, 1, 0}, {1, 1, 0, 1, 1}, {0, 1, 1, 0, 1}, {0, 0, 1, 1, 0}}; int n = graph.length; List> cliques = findCliques(graph, n); @@ -2219,9 +2239,7 @@ public static List> findCliques(int[][] graph, int n) { return cliques; } - private static void bronKerbosch(int[][] graph, Set candidates, - Set excluded, Set included, - List> cliques) { + private static void bronKerbosch(int[][] graph, Set candidates, Set excluded, Set included, List> cliques) { if (candidates.isEmpty() && excluded.isEmpty()) { cliques.add(new ArrayList<>(included)); return; @@ -2236,10 +2254,7 @@ private static void bronKerbosch(int[][] graph, Set candidates, } } - bronKerbosch(graph, intersect(candidates, neighbors), - intersect(excluded, neighbors), - union(included, vertex), - cliques); + bronKerbosch(graph, intersect(candidates, neighbors), intersect(excluded, neighbors), union(included, vertex), cliques); candidates.remove(vertex); excluded.add(vertex); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/RandomGraph.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/RandomGraph.java index 7e8916e1fa..3c3159265a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/RandomGraph.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/RandomGraph.java @@ -124,7 +124,7 @@ public static Graph randomGraphUniform(List nodes, int numLatentConfounder } if (numLatentConfounders < 0 || numLatentConfounders > numNodes) { - throw new IllegalArgumentException("Max # latent confounders must be " + "at least 0 and at most the number of nodes: " + numLatentConfounders); + throw new IllegalArgumentException("Number of additional latent confounders must be " + "at least 0 and at most the number of nodes: " + numLatentConfounders); } for (Node node : nodes) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BFci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BFci.java index 436559affc..caa799571e 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BFci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/BFci.java @@ -21,14 +21,13 @@ package edu.cmu.tetrad.search; import edu.cmu.tetrad.data.Knowledge; -import edu.cmu.tetrad.graph.EdgeListGraph; -import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.graph.GraphUtils; -import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.search.score.Score; +import edu.cmu.tetrad.search.test.MsepTest; +import edu.cmu.tetrad.search.utils.DagSepsets; import edu.cmu.tetrad.search.utils.FciOrient; import edu.cmu.tetrad.search.utils.SepsetProducer; -import edu.cmu.tetrad.search.utils.SepsetsGreedy; +import edu.cmu.tetrad.search.utils.SepsetsMinP; import edu.cmu.tetrad.util.RandomUtil; import edu.cmu.tetrad.util.TetradLogger; @@ -147,6 +146,14 @@ public final class BFci implements IGraphSearch { * used for processing. */ private int numThreads = 1; + /** + * Determines whether or not almost cyclic paths should be resolved during the graph search. + * + * Almost cyclic paths are paths that are almost cycles but have a single additional edge + * that prevents them from being cycles. Resolving these paths involves determining if the + * additional edge should be included or not. + */ + private boolean resolveAlmostCyclicPaths; /** * Constructor. The test and score should be for the same data. @@ -176,6 +183,8 @@ public Graph search() { RandomUtil.getInstance().setSeed(seed); } + this.independenceTest.setVerbose(verbose); + List nodes = getIndependenceTest().getVariables(); if (verbose) { @@ -196,22 +205,45 @@ public Graph search() { Graph referenceDag = new EdgeListGraph(graph); // GFCI extra edge removal step... - SepsetProducer sepsets = new SepsetsGreedy(graph, this.independenceTest, null, this.depth, knowledge); - gfciExtraEdgeRemovalStep(graph, referenceDag, nodes, sepsets); - GraphUtils.gfciR0(graph, referenceDag, sepsets, knowledge); +// SepsetProducer sepsets = new SepsetsGreedy(graph, this.independenceTest, null, this.depth, knowledge); + SepsetProducer sepsets; + + if (independenceTest instanceof MsepTest) { + sepsets = new DagSepsets(((MsepTest) independenceTest).getGraph()); + } else { + sepsets = new SepsetsMinP(graph, this.independenceTest, null, this.depth); + } + + gfciExtraEdgeRemovalStep(graph, referenceDag, nodes, sepsets, verbose); + GraphUtils.gfciR0(graph, referenceDag, sepsets, knowledge, verbose); FciOrient fciOrient = new FciOrient(sepsets); - fciOrient.setCompleteRuleSetUsed(this.completeRuleSetUsed); - fciOrient.setMaxPathLength(this.maxPathLength); - fciOrient.setDoDiscriminatingPathColliderRule(this.doDiscriminatingPathRule); - fciOrient.setDoDiscriminatingPathTailRule(this.doDiscriminatingPathRule); + fciOrient.setCompleteRuleSetUsed(completeRuleSetUsed); + fciOrient.setDoDiscriminatingPathColliderRule(doDiscriminatingPathRule); + fciOrient.setDoDiscriminatingPathTailRule(doDiscriminatingPathRule); fciOrient.setVerbose(verbose); fciOrient.setKnowledge(knowledge); - fciOrient.doFinalOrientation(graph); + if (resolveAlmostCyclicPaths) { + for (Edge edge : graph.getEdges()) { + if (Edges.isBidirectedEdge(edge)) { + Node x = edge.getNode1(); + Node y = edge.getNode2(); + + if (graph.paths().existsDirectedPath(x, y)) { + graph.setEndpoint(y, x, Endpoint.TAIL); + } else if (graph.paths().existsDirectedPath(y, x)) { + graph.setEndpoint(x, y, Endpoint.TAIL); + } + } + } + } + GraphUtils.replaceNodes(graph, this.independenceTest.getVariables()); +// graph = GraphTransforms.dagToPag(graph); + return graph; } @@ -321,4 +353,13 @@ public void setNumThreads(int numThreads) { } this.numThreads = numThreads; } + + /** + * Sets whether almost cyclic paths should be resolved during the search. + * + * @param resolveAlmostCyclicPaths True to resolve almost cyclic paths, false otherwise. + */ + public void setResolveAlmostCyclicPaths(boolean resolveAlmostCyclicPaths) { + this.resolveAlmostCyclicPaths = resolveAlmostCyclicPaths; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Cfci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Cfci.java index d82df9737c..bacc7107ab 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Cfci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Cfci.java @@ -75,6 +75,8 @@ public final class Cfci implements IGraphSearch { private boolean verbose; // Whether to do the discriminating path rule. private boolean doDiscriminatingPathRule; + // Whether to resolve almost cyclic paths. + private boolean resolveAlmostCyclicPaths; /** * Constructs a new FCI search for the given independence test and background knowledge. @@ -166,7 +168,7 @@ public Graph search() { // Step CI D. (Zhang's step F4.) - FciOrient fciOrient = new FciOrient(new SepsetsConservative(this.graph, this.independenceTest, + FciOrient fciOrient = new FciOrient(new SepsetsMaxP(this.graph, this.independenceTest, new SepsetMap(), this.depth)); fciOrient.setCompleteRuleSetUsed(this.completeRuleSetUsed); @@ -177,6 +179,21 @@ public Graph search() { fciOrient.ruleR0(this.graph); fciOrient.doFinalOrientation(this.graph); + if (resolveAlmostCyclicPaths) { + for (Edge edge : graph.getEdges()) { + if (Edges.isBidirectedEdge(edge)) { + Node x = edge.getNode1(); + Node y = edge.getNode2(); + + if (graph.paths().existsDirectedPath(x, y)) { + graph.setEndpoint(y, x, Endpoint.TAIL); + } else if (graph.paths().existsDirectedPath(y, x)) { + graph.setEndpoint(x, y, Endpoint.TAIL); + } + } + } + } + long endTime = MillisecondTimes.timeMillis(); this.elapsedTime = endTime - beginTime; @@ -541,6 +558,16 @@ private void fciOrientbk(Knowledge bk, Graph graph, List variables) { } } + /** + * Sets the flag indicating whether to resolve almost cyclic paths. + * + * @param resolveAlmostCyclicPaths If true, almost cyclic paths will be resolved. If false, they will not be + * resolved. + */ + public void setResolveAlmostCyclicPaths(boolean resolveAlmostCyclicPaths) { + this.resolveAlmostCyclicPaths = resolveAlmostCyclicPaths; + } + private enum TripleType { COLLIDER, NONCOLLIDER, AMBIGUOUS } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fci.java index 008cc115df..9719b90c4a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Fci.java @@ -22,13 +22,8 @@ package edu.cmu.tetrad.search; import edu.cmu.tetrad.data.Knowledge; -import edu.cmu.tetrad.graph.Endpoint; -import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.graph.Node; -import edu.cmu.tetrad.search.utils.FciOrient; -import edu.cmu.tetrad.search.utils.PcCommon; -import edu.cmu.tetrad.search.utils.SepsetMap; -import edu.cmu.tetrad.search.utils.SepsetsSet; +import edu.cmu.tetrad.graph.*; +import edu.cmu.tetrad.search.utils.*; import edu.cmu.tetrad.util.MillisecondTimes; import edu.cmu.tetrad.util.TetradLogger; @@ -125,6 +120,11 @@ public final class Fci implements IGraphSearch { * Whether the discriminating path rule should be used. */ private boolean doDiscriminatingPathRule = true; + /** + * Flag indicating whether almost cyclic paths should be resolved during the search. + * Default value is false. + */ + private boolean resolveAlmostCyclicPaths; /** * Constructor. @@ -203,7 +203,8 @@ public Graph search() { // The original FCI, with or without JiJi Zhang's orientation rules // Optional step: Possible Msep. (Needed for correctness but very time-consuming.) - SepsetsSet sepsets1 = new SepsetsSet(this.sepsets, this.independenceTest); +// SepsetProducer sepsets1 = new SepsetsSet(this.sepsets, this.independenceTest); + SepsetProducer sepsets1 = new SepsetsGreedy(graph, this.independenceTest, null, depth, knowledge); if (this.possibleMsepSearchDone) { new FciOrient(sepsets1).ruleR0(graph); @@ -228,8 +229,25 @@ public Graph search() { fciOrient.doFinalOrientation(graph); + if (resolveAlmostCyclicPaths) { + for (Edge edge : graph.getEdges()) { + if (Edges.isBidirectedEdge(edge)) { + Node x = edge.getNode1(); + Node y = edge.getNode2(); + + if (graph.paths().existsDirectedPath(x, y)) { + graph.setEndpoint(y, x, Endpoint.TAIL); + } else if (graph.paths().existsDirectedPath(y, x)) { + graph.setEndpoint(x, y, Endpoint.TAIL); + } + } + } + } + long stop = MillisecondTimes.timeMillis(); +// graph = GraphTransforms.dagToPag(graph); + this.elapsedTime = stop - start; return graph; @@ -368,6 +386,15 @@ public void setStable(boolean stable) { public void setDoDiscriminatingPathRule(boolean doDiscriminatingPathRule) { this.doDiscriminatingPathRule = doDiscriminatingPathRule; } + + /** + * Sets whether to resolve almost cyclic paths during the search. + * + * @param resolveAlmostCyclicPaths True to resolve almost cyclic paths, false otherwise. + */ + public void setResolveAlmostCyclicPaths(boolean resolveAlmostCyclicPaths) { + this.resolveAlmostCyclicPaths = resolveAlmostCyclicPaths; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FciMax.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FciMax.java index b72e766d28..e7b9eb08c8 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FciMax.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/FciMax.java @@ -122,6 +122,10 @@ public final class FciMax implements IGraphSearch { * Whether verbose output should be printed. */ private boolean verbose = false; + /** + * Determines whether the algorithm should resolve almost cyclic paths during the search. + */ + private boolean resolveAlmostCyclicPaths; /** * Constructor. @@ -182,6 +186,21 @@ public Graph search() { addColliders(graph); fciOrient.doFinalOrientation(graph); + if (resolveAlmostCyclicPaths) { + for (Edge edge : graph.getEdges()) { + if (Edges.isBidirectedEdge(edge)) { + Node x = edge.getNode1(); + Node y = edge.getNode2(); + + if (graph.paths().existsDirectedPath(x, y)) { + graph.setEndpoint(y, x, Endpoint.TAIL); + } else if (graph.paths().existsDirectedPath(y, x)) { + graph.setEndpoint(x, y, Endpoint.TAIL); + } + } + } + } + long stop = MillisecondTimes.timeMillis(); this.elapsedTime = stop - start; @@ -473,6 +492,15 @@ private void doNode(Graph graph, Map scores, Node b) { } } } + + /** + * Sets whether to resolve almost cyclic paths during the search. + * + * @param resolveAlmostCyclicPaths True, if almost cyclic paths should be resolved. False, otherwise. + */ + public void setResolveAlmostCyclicPaths(boolean resolveAlmostCyclicPaths) { + this.resolveAlmostCyclicPaths = resolveAlmostCyclicPaths; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GFci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GFci.java index 04f99d750d..e00c84f08c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GFci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GFci.java @@ -21,14 +21,10 @@ package edu.cmu.tetrad.search; import edu.cmu.tetrad.data.Knowledge; -import edu.cmu.tetrad.graph.EdgeListGraph; -import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.graph.GraphUtils; -import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.search.score.Score; -import edu.cmu.tetrad.search.utils.FciOrient; -import edu.cmu.tetrad.search.utils.SepsetProducer; -import edu.cmu.tetrad.search.utils.SepsetsGreedy; +import edu.cmu.tetrad.search.test.MsepTest; +import edu.cmu.tetrad.search.utils.*; import edu.cmu.tetrad.util.TetradLogger; import java.io.PrintStream; @@ -122,6 +118,13 @@ public final class GFci implements IGraphSearch { * The number of threads to use in the search. Must be at least 1. */ private int numThreads = 1; + /** + * Determines whether almost cyclic paths should be resolved. + * If true, the algorithm will attempt to break almost cyclic paths by removing one edge. + * If false, almost cyclic paths will be treated as genuine causal relationships. + * The default value is false. + */ + private boolean resolveAlmostCyclicPaths; /** @@ -164,23 +167,45 @@ public Graph search() { fges.setNumThreads(numThreads); graph = fges.search(); - Knowledge knowledge2 = new Knowledge(knowledge); Graph referenceDag = new EdgeListGraph(graph); // GFCI extra edge removal step... - SepsetProducer sepsets = new SepsetsGreedy(graph, this.independenceTest, null, this.depth, knowledge); - gfciExtraEdgeRemovalStep(graph, referenceDag, nodes, sepsets); - GraphUtils.gfciR0(graph, referenceDag, sepsets, knowledge); + SepsetProducer sepsets; + + if (independenceTest instanceof MsepTest) { + sepsets = new DagSepsets(((MsepTest) independenceTest).getGraph()); + } else { + sepsets = new SepsetsGreedy(graph, this.independenceTest, null, this.depth, knowledge); + } + +// SepsetProducer sepsets = new SepsetsGreedy(graph, this.independenceTest, null, this.depth, knowledge); +// SepsetProducer sepsets = new SepsetsConservative(graph, this.independenceTest, null, this.depth); + gfciExtraEdgeRemovalStep(graph, referenceDag, nodes, sepsets, verbose); + GraphUtils.gfciR0(graph, referenceDag, sepsets, knowledge, verbose); FciOrient fciOrient = new FciOrient(sepsets); - fciOrient.setCompleteRuleSetUsed(this.completeRuleSetUsed); - fciOrient.setMaxPathLength(this.maxPathLength); - fciOrient.setDoDiscriminatingPathColliderRule(this.doDiscriminatingPathRule); - fciOrient.setDoDiscriminatingPathTailRule(this.doDiscriminatingPathRule); + fciOrient.setCompleteRuleSetUsed(completeRuleSetUsed); + fciOrient.setDoDiscriminatingPathColliderRule(doDiscriminatingPathRule); + fciOrient.setDoDiscriminatingPathTailRule(doDiscriminatingPathRule); fciOrient.setVerbose(verbose); - fciOrient.setKnowledge(knowledge2); - + fciOrient.setKnowledge(knowledge); fciOrient.doFinalOrientation(graph); + + if (resolveAlmostCyclicPaths) { + for (Edge edge : graph.getEdges()) { + if (Edges.isBidirectedEdge(edge)) { + Node x = edge.getNode1(); + Node y = edge.getNode2(); + + if (graph.paths().existsDirectedPath(x, y)) { + graph.setEndpoint(y, x, Endpoint.TAIL); + } else if (graph.paths().existsDirectedPath(y, x)) { + graph.setEndpoint(x, y, Endpoint.TAIL); + } + } + } + } + return graph; } @@ -317,4 +342,13 @@ public void setNumThreads(int numThreads) { } this.numThreads = numThreads; } + + /** + * Sets the flag to resolve almost cyclic paths. + * + * @param resolveAlmostCyclicPaths true if almost cyclic paths should be resolved, false otherwise + */ + public void setResolveAlmostCyclicPaths(boolean resolveAlmostCyclicPaths) { + this.resolveAlmostCyclicPaths = resolveAlmostCyclicPaths; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GraspFci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GraspFci.java index 048af9b22b..e26a0c2a6c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GraspFci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/GraspFci.java @@ -21,14 +21,10 @@ package edu.cmu.tetrad.search; import edu.cmu.tetrad.data.Knowledge; -import edu.cmu.tetrad.graph.EdgeListGraph; -import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.graph.GraphUtils; -import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.search.score.Score; -import edu.cmu.tetrad.search.utils.FciOrient; -import edu.cmu.tetrad.search.utils.SepsetProducer; -import edu.cmu.tetrad.search.utils.SepsetsGreedy; +import edu.cmu.tetrad.search.test.MsepTest; +import edu.cmu.tetrad.search.utils.*; import edu.cmu.tetrad.util.TetradLogger; import java.util.List; @@ -133,6 +129,10 @@ public final class GraspFci implements IGraphSearch { * @see GraspFci#setSeed(long) */ private long seed = -1; + /** + * Indicates whether almost cyclic paths should be resolved during the search. + */ + private boolean resolveAlmostCyclicPaths; /** * Constructs a new GraspFci object. @@ -189,22 +189,46 @@ public Graph search() { Graph referenceDag = new EdgeListGraph(graph); - SepsetProducer sepsets = new SepsetsGreedy(graph, this.independenceTest, null, this.depth, knowledge); - gfciExtraEdgeRemovalStep(graph, referenceDag, nodes, sepsets); - GraphUtils.gfciR0(graph, referenceDag, sepsets, knowledge); + // GFCI extra edge removal step... +// SepsetProducer sepsets = new SepsetsGreedy(graph, this.independenceTest, null, this.depth, knowledge); + SepsetProducer sepsets; + + if (independenceTest instanceof MsepTest) { + sepsets = new DagSepsets(((MsepTest) independenceTest).getGraph()); + } else { + sepsets = new SepsetsGreedy(graph, this.independenceTest, null, this.depth, knowledge); + } + + gfciExtraEdgeRemovalStep(graph, referenceDag, nodes, sepsets, verbose); + GraphUtils.gfciR0(graph, referenceDag, sepsets, knowledge, verbose); FciOrient fciOrient = new FciOrient(sepsets); - fciOrient.setCompleteRuleSetUsed(this.completeRuleSetUsed); - fciOrient.setMaxPathLength(this.maxPathLength); - fciOrient.setDoDiscriminatingPathColliderRule(this.doDiscriminatingPathRule); - fciOrient.setDoDiscriminatingPathTailRule(this.doDiscriminatingPathRule); + fciOrient.setCompleteRuleSetUsed(completeRuleSetUsed); + fciOrient.setDoDiscriminatingPathColliderRule(doDiscriminatingPathRule); + fciOrient.setDoDiscriminatingPathTailRule(doDiscriminatingPathRule); fciOrient.setVerbose(verbose); fciOrient.setKnowledge(knowledge); - fciOrient.doFinalOrientation(graph); + if (resolveAlmostCyclicPaths) { + for (Edge edge : graph.getEdges()) { + if (Edges.isBidirectedEdge(edge)) { + Node x = edge.getNode1(); + Node y = edge.getNode2(); + + if (graph.paths().existsDirectedPath(x, y)) { + graph.setEndpoint(y, x, Endpoint.TAIL); + } else if (graph.paths().existsDirectedPath(y, x)) { + graph.setEndpoint(x, y, Endpoint.TAIL); + } + } + } + } + GraphUtils.replaceNodes(graph, this.independenceTest.getVariables()); +// graph = GraphTransforms.dagToPag(graph); + return graph; } @@ -340,4 +364,13 @@ public void setOrdered(boolean ordered) { public void setSeed(long seed) { this.seed = seed; } + + /** + * Sets whether to resolve almost cyclic paths in the search. + * + * @param resolveAlmostCyclicPaths True, if almost cyclic paths should be resolved. False, otherwise. + */ + public void setResolveAlmostCyclicPaths(boolean resolveAlmostCyclicPaths) { + this.resolveAlmostCyclicPaths = resolveAlmostCyclicPaths; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/IcaLingam.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/IcaLingam.java index a170d4b26f..c7ca3c4fde 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/IcaLingam.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/IcaLingam.java @@ -188,7 +188,7 @@ public boolean isAcyclic(Matrix scaledBHat) { private boolean existsDirectedCycle() { for (Node node : new HashSet<>(dummyCyclicNodes)) { - if (dummyGraph.paths().existsDirectedPathFromTo(node, node)) { + if (dummyGraph.paths().existsDirectedPath(node, node)) { return true; } else { dummyCyclicNodes.remove(node); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/LvLite.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/LvLite.java new file mode 100644 index 0000000000..dd6942f175 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/LvLite.java @@ -0,0 +1,567 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. //i +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// +package edu.cmu.tetrad.search; + +import edu.cmu.tetrad.data.Knowledge; +import edu.cmu.tetrad.graph.*; +import edu.cmu.tetrad.search.score.Score; +import edu.cmu.tetrad.search.utils.FciOrient; +import edu.cmu.tetrad.search.utils.TeyssierScorer; +import edu.cmu.tetrad.util.TetradLogger; + +import java.util.*; + +/** + * The LV-Lite algorithm implements the IGraphSearch interface and represents a search algorithm for learning the + * structure of a graphical model from observational data. + *

+ * This class provides methods for running the search algorithm and obtaining the learned pattern as a PAG (Partially + * Annotated Graph). + * + * @author josephramsey + */ +public final class LvLite implements IGraphSearch { + /** + * The score. + */ + private final Score score; + /** + * The background knowledge. + */ + private Knowledge knowledge = new Knowledge(); + /** + * Flag for the complete rule set, true if one should use the complete rule set, false otherwise. + */ + private boolean completeRuleSetUsed = true; + /** + * True iff verbose output should be printed. + */ + private boolean verbose; + /** + * The number of starts for GRaSP. + */ + private int numStarts = 1; + /** + * Whether to use data order. + */ + private boolean useDataOrder = true; + /** + * This flag represents whether the Bes algorithm should be used in the search. + *

+ * If set to true, the Bes algorithm will be used. If set to false, the Bes algorithm will not be used. + *

+ * By default, the value of this flag is false. + */ + private boolean useBes; + /** + * This variable represents whether the discriminating path rule is used in the LvLite class. + *

+ * The discriminating path rule is a rule used in the search algorithm. It determines whether the algorithm + * considers discriminating paths when searching for patterns in the data. + *

+ * By default, the value of this variable is set to false, indicating that the discriminating path rule is not used. + * To enable the use of the discriminating path rule, set the value of this variable to true using the + * {@link #setDoDiscriminatingPathRule(boolean)} method. + */ + private boolean doDiscriminatingPathRule = false; + /** + * Determines whether the search algorithm should resolve almost cyclic paths. + */ + private boolean resolveAlmostCyclicPaths = true; + + /** + * LvLite constructor. Initializes a new object of LvLite search algorithm with the given IndependenceTest and Score + * object. + * + * @param score The Score object to be used for scoring DAGs. + * @throws NullPointerException if score is null. + */ + public LvLite(Score score) { + if (score == null) { + throw new NullPointerException(); + } + + this.score = score; + } + + /** + * Run the search and return s a PAG. + * + * @return The PAG. + */ + public Graph search() { + List nodes = this.score.getVariables(); + + if (nodes == null) { + throw new NullPointerException("Nodes from test were null."); + } + + Boss suborderSearch = new Boss(score); + suborderSearch.setKnowledge(knowledge); + suborderSearch.setResetAfterBM(true); + suborderSearch.setResetAfterRS(true); + suborderSearch.setVerbose(verbose); + suborderSearch.setUseBes(useBes); + suborderSearch.setUseDataOrder(useDataOrder); + suborderSearch.setNumStarts(numStarts); + PermutationSearch permutationSearch = new PermutationSearch(suborderSearch); + permutationSearch.setKnowledge(knowledge); + permutationSearch.search(); + List best = permutationSearch.getOrder(); + + TetradLogger.getInstance().forceLogMessage("Best order: " + best); + + TeyssierScorer teyssierScorer = new TeyssierScorer(null, score); + teyssierScorer.score(best); + Graph dag = teyssierScorer.getGraph(false); + Graph cpdag = teyssierScorer.getGraph(true); + Graph pag = new EdgeListGraph(cpdag); + pag.reorientAllWith(Endpoint.CIRCLE); + + FciOrient fciOrient = new FciOrient(null); + fciOrient.setCompleteRuleSetUsed(completeRuleSetUsed); + fciOrient.setDoDiscriminatingPathColliderRule(false); + fciOrient.setDoDiscriminatingPathTailRule(false); + fciOrient.setVerbose(verbose); + fciOrient.setKnowledge(knowledge); + + fciOrient.fciOrientbk(knowledge, pag, best); + + // Copy unshielded colliders from DAG to PAG + for (int i = 0; i < best.size(); i++) { + for (int j = i + 1; j < best.size(); j++) { + for (int k = j + 1; k < best.size(); k++) { + Node a = best.get(i); + Node b = best.get(j); + Node c = best.get(k); + + if (dag.isAdjacentTo(a, c) && dag.isAdjacentTo(b, c) && !dag.isAdjacentTo(a, b) + && dag.getEdge(a, c).pointsTowards(c) && dag.getEdge(b, c).pointsTowards(c)) { + if (FciOrient.isArrowheadAllowed(a, c, pag, knowledge) && FciOrient.isArrowheadAllowed(b, c, pag, knowledge)) { + pag.setEndpoint(a, c, Endpoint.ARROW); + pag.setEndpoint(b, c, Endpoint.ARROW); + + if (verbose) { + TetradLogger.getInstance().forceLogMessage("Copying unshielded collider " + a + " -> " + c + " <- " + b + + " from CPDAG to PAG"); + } + } + } + } + } + } + + teyssierScorer.bookmark(); + + Set toRemove = new HashSet<>(); + + // Our extra collider orientation step to orient <-> edges: + for (int i = 0; i < best.size(); i++) { + for (int j = 0; j < best.size(); j++) { + for (int k = j + 1; k < best.size(); k++) { + Node a = best.get(i); + Node b = best.get(j); + Node c = best.get(k); + + Edge ab = cpdag.getEdge(a, b); + Edge bc = cpdag.getEdge(b, c); + Edge ac = cpdag.getEdge(a, c); + + Edge _ab = pag.getEdge(a, b); + Edge _bc = pag.getEdge(b, c); + Edge _ac = pag.getEdge(a, c); + + if ((bc != null && bc.pointsTowards(c)) && ab != null && ac != null + && (_bc != null && pag.getEndpoint(b, c) == Endpoint.ARROW) && _ab != null && _ac != null) { + teyssierScorer.goToBookmark(); + teyssierScorer.tuck(c, b); + + if (!teyssierScorer.adjacent(a, c)) { + toRemove.add(new Triple(a, b, c)); + } + } + } + } + } + + for (Triple triple : toRemove) { + Node a = triple.getX(); + Node b = triple.getY(); + Node c = triple.getZ(); + + if (pag.isAdjacentTo(a, c) && pag.isAdjacentTo(c, b)) { + if (FciOrient.isArrowheadAllowed(c, b, pag, knowledge)) { + pag.removeEdge(a, c); + pag.setEndpoint(c, b, Endpoint.ARROW); + + if (verbose) { + TetradLogger.getInstance().forceLogMessage("Orienting " + b + " <-* " + c + " in PAG and removing " + a + " *-* " + c + " from PAG."); + } + } + } + } + + for (Triple triple : toRemove) { + Node b = triple.getY(); + + List nodesInTo = pag.getNodesInTo(b, Endpoint.ARROW); + + if (nodesInTo.size() == 1) { + for (Node node : nodesInTo) { + pag.setEndpoint(node, b, Endpoint.CIRCLE); + + if (verbose) { + TetradLogger.getInstance().forceLogMessage("Orienting " + node + " --o " + b + " in PAG."); + } + } + } + } + + do { + if (completeRuleSetUsed) { + fciOrient.zhangFinalOrientation(pag); + } else { + fciOrient.spirtesFinalOrientation(pag); + } + + fciOrient.zhangFinalOrientation(pag); + } while (discriminatingPathRule(pag, teyssierScorer)); + + // Optional. + if (resolveAlmostCyclicPaths) { + for (Edge edge : pag.getEdges()) { + if (Edges.isBidirectedEdge(edge)) { + Node x = edge.getNode1(); + Node y = edge.getNode2(); + + if (pag.paths().existsDirectedPath(x, y)) { + pag.setEndpoint(y, x, Endpoint.TAIL); + } else if (pag.paths().existsDirectedPath(y, x)) { + pag.setEndpoint(x, y, Endpoint.TAIL); + } + } + } + + do { + if (completeRuleSetUsed) { + fciOrient.zhangFinalOrientation(pag); + } else { + fciOrient.spirtesFinalOrientation(pag); + } + } while (discriminatingPathRule(pag, teyssierScorer)); + } + + GraphUtils.replaceNodes(pag, this.score.getVariables()); + return pag; + } + + /** + * Sets the knowledge used in search. + * + * @param knowledge This knowledge. + */ + public void setKnowledge(Knowledge knowledge) { + this.knowledge = new Knowledge(knowledge); + } + + /** + * Sets whether the complete rule set should be used during the search algorithm. By default, the complete rule set + * is not used. + * + * @param completeRuleSetUsed true if the complete rule set should be used, false otherwise + */ + public void setCompleteRuleSetUsed(boolean completeRuleSetUsed) { + this.completeRuleSetUsed = completeRuleSetUsed; + } + + /** + * Sets the verbosity level of the search algorithm. + * + * @param verbose true to enable verbose mode, false to disable it + */ + public void setVerbose(boolean verbose) { + this.verbose = verbose; + } + + /** + * Sets the number of starts for BOSS. + * + * @param numStarts The number of starts. + */ + public void setNumStarts(int numStarts) { + this.numStarts = numStarts; + } + + /** + * Sets whether the search algorithm should use the order of the data set during the search. + * + * @param useDataOrder true if the algorithm should use the data order, false otherwise + */ + public void setUseDataOrder(boolean useDataOrder) { + this.useDataOrder = useDataOrder; + } + + /** + * Sets whether to use the BES (Backward Elimination Search) algorithm during the search. + * + * @param useBes true to use the BES algorithm, false otherwise + */ + public void setUseBes(boolean useBes) { + this.useBes = useBes; + } + + /** + * Sets whether the search algorithm should use the Discriminating Path Rule. + * + * @param doDiscriminatingPathRule true if the Discriminating Path Rule should be used, false otherwise + */ + public void setDoDiscriminatingPathRule(boolean doDiscriminatingPathRule) { + this.doDiscriminatingPathRule = doDiscriminatingPathRule; + } + + /** + * Sets whether the search algorithm should resolve almost cyclic paths. + * + * @param resolveAlmostCyclicPaths true to resolve almost cyclic paths, false otherwise + */ + public void setResolveAlmostCyclicPaths(boolean resolveAlmostCyclicPaths) { + this.resolveAlmostCyclicPaths = resolveAlmostCyclicPaths; + } + + /** + * This is a score-based discriminating path rule. + *

+ * The triangles that must be oriented this way (won't be done by another rule) all look like the ones below, where + * the dots are a collider path from E to A with each node on the path (except L) a parent of C. + *

+     *          B
+     *         xo           x is either an arrowhead or a circle
+     *        /  \
+     *       v    v
+     * E....A --> C
+     * 
+ *

+ * This is Zhang's rule R4, discriminating paths. + * + * @param graph a {@link edu.cmu.tetrad.graph.Graph} object + */ + private boolean discriminatingPathRule(Graph graph, TeyssierScorer scorer) { + if (!doDiscriminatingPathRule) return false; + + List nodes = graph.getNodes(); + boolean oriented = false; + + for (Node b : nodes) { + if (Thread.currentThread().isInterrupted()) { + break; + } + + // potential A and C candidate pairs are only those + // that look like this: A<-*Bo-*C + List possA = graph.getNodesOutTo(b, Endpoint.ARROW); + List possC = graph.getNodesInTo(b, Endpoint.CIRCLE); + + for (Node a : possA) { + if (Thread.currentThread().isInterrupted()) { + break; + } + + for (Node c : possC) { + if (Thread.currentThread().isInterrupted()) { + break; + } + + if (a == c) continue; + + if (!graph.isParentOf(a, c)) { + continue; + } + + if (graph.getEndpoint(b, c) != Endpoint.ARROW) { + continue; + } + + boolean _oriented = ddpOrient(a, b, c, graph, scorer); + + if (_oriented) oriented = true; + } + } + } + + return oriented; + } + + /** + * A method to search "back from a" to find a DDP. It is called with a reachability list (first consisting only of + * a). This is breadth-first, utilizing "reachability" concept from Geiger, Verma, and Pearl 1990. The body of a DDP + * consists of colliders that are parents of c. + * + * @param a a {@link edu.cmu.tetrad.graph.Node} object + * @param b a {@link edu.cmu.tetrad.graph.Node} object + * @param c a {@link edu.cmu.tetrad.graph.Node} object + * @param graph a {@link edu.cmu.tetrad.graph.Graph} object + */ + private boolean ddpOrient(Node a, Node b, Node c, Graph graph, TeyssierScorer scorer) { + Queue Q = new ArrayDeque<>(20); + Set V = new HashSet<>(); + + Node e = null; + + Map previous = new HashMap<>(); + Set colliderPath = new HashSet<>(); + colliderPath.add(a); + + List cParents = graph.getParents(c); + + Q.offer(a); + V.add(a); + V.add(b); + previous.put(a, b); + + while (!Q.isEmpty()) { + if (Thread.currentThread().isInterrupted()) { + break; + } + + Node t = Q.poll(); + + if (e == null || e == t) { + e = t; + } + + List nodesInTo = graph.getNodesInTo(t, Endpoint.ARROW); + + for (Node d : nodesInTo) { + if (Thread.currentThread().isInterrupted()) { + break; + } + + if (V.contains(d)) { + continue; + } + + previous.put(d, t); + Node p = previous.get(t); + + if (!graph.isDefCollider(d, t, p)) { + continue; + } + + previous.put(d, t); + colliderPath.add(t); + + if (!graph.isAdjacentTo(d, c)) { + if (doDdpOrientation(d, a, b, c, graph, colliderPath, scorer)) { + return true; + } + } + + if (cParents.contains(d)) { + Q.offer(d); + V.add(d); + } + } + } + + return false; + } + + /** + * Determines the orientation for the nodes in a Directed Acyclic Graph (DAG) based on the Discriminating Path Rule + * Here, we insist that the sepset for D and B contain all the nodes along the collider path. + *

+ * Reminder: + *

+     *      The triangles that must be oriented this way (won't be done by another rule) all look like the ones below, where
+     *      the dots are a collider path from E to A with each node on the path (except L) a parent of C.
+     *      
+     *               B
+     *              xo           x is either an arrowhead or a circle
+     *             /  \
+     *            v    v
+     *      E....A --> C
+     *
+     *      This is Zhang's rule R4, discriminating paths. The "collider path" here is all of the collider nodes
+     *      along the E...A path (all parents of C), including A. The idea is that is we know that E is independent
+     *      of C given all of nodes on the collider path plus perhaps some other nodes, then there should be a collider
+     *      at B; otherwise, there should be a noncollider at B.
+     * 
+ * + * @param e the 'e' node + * @param a the 'a' node + * @param b the 'b' node + * @param c the 'c' node + * @param graph the graph representation + * @param colliderPath the list of nodes in the collider path + * @return true if the orientation is determined, false otherwise + * @throws IllegalArgumentException if 'e' is adjacent to 'c' + */ + private boolean doDdpOrientation(Node e, Node a, Node b, Node c, Graph graph, Set colliderPath, TeyssierScorer scorer) { + if (graph.isAdjacentTo(e, c)) { + throw new IllegalArgumentException(); + } + + scorer.goToBookmark(); + + scorer.tuck(e, b); + + for (Node node : colliderPath) { + scorer.tuck(node, e); + } + + boolean collider; + + if (scorer.index(b) < scorer.index(e)) { + collider = false; + } else { + collider = !scorer.adjacent(e, c); + } + + if (collider) { + if (!FciOrient.isArrowheadAllowed(a, b, graph, knowledge)) { + return false; + } + + if (!FciOrient.isArrowheadAllowed(c, b, graph, knowledge)) { + return false; + } + + graph.setEndpoint(a, b, Endpoint.ARROW); + graph.setEndpoint(c, b, Endpoint.ARROW); + + if (this.verbose) { + TetradLogger.getInstance().forceLogMessage( + "R4: Definite discriminating path collider rule e = " + e + " " + GraphUtils.pathString(graph, a, b, c)); + } + + return true; + } else { + graph.setEndpoint(c, b, Endpoint.TAIL); + + if (this.verbose) { + TetradLogger.getInstance().forceLogMessage( + "R4: Definite discriminating path tail rule e = " + e + " " + GraphUtils.pathString(graph, a, b, c)); + } + + return true; + } + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java index 8a8b8fc684..f29e0d7cf6 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java @@ -6,7 +6,10 @@ import edu.cmu.tetrad.algcomparison.statistic.ArrowheadRecall; import edu.cmu.tetrad.data.GeneralAndersonDarlingTest; import edu.cmu.tetrad.data.Knowledge; -import edu.cmu.tetrad.graph.*; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphUtils; +import edu.cmu.tetrad.graph.IndependenceFact; +import edu.cmu.tetrad.graph.Node; import edu.cmu.tetrad.search.test.*; import edu.cmu.tetrad.util.SublistGenerator; import edu.cmu.tetrad.util.TetradLogger; @@ -212,6 +215,12 @@ public AllSubsetsIndependenceFacts getAllSubsetsIndependenceFacts() { return new AllSubsetsIndependenceFacts(msep, mconn); } + /** + * Retrieves the list of local independence facts for a given node. + * + * @param x The node for which to retrieve the local independence facts. + * @return The list of local independence facts for the given node. + */ public List getLocalIndependenceFacts(Node x) { Set parents = new HashSet<>(graph.getParents(x)); @@ -230,6 +239,13 @@ public List getLocalIndependenceFacts(Node x) { return factList; } + /** + * Calculates the local p-values for a given independence test and a list of independence facts. + * + * @param independenceTest The independence test used for calculating the p-values. + * @param facts The list of independence facts. + * @return The list of local p-values. + */ public List getLocalPValues(IndependenceTest independenceTest, List facts) { // call pvalue function on each item, only include the non-null ones List pVals = new ArrayList<>(); @@ -237,21 +253,36 @@ public List getLocalPValues(IndependenceTest independenceTest, List pValues) { GeneralAndersonDarlingTest generalAndersonDarlingTest = new GeneralAndersonDarlingTest(pValues, new UniformRealDistribution(0, 1)); return generalAndersonDarlingTest.getP(); } + /** + * Calculates the Anderson-Darling test and classifies nodes as accepted or rejected based on the given threshold. + * + * @param independenceTest The independence test to be used for calculating p-values. + * @param graph The graph containing the nodes for testing. + * @param threshold The threshold value for classifying nodes. + * @return A list containing two lists: the first list contains the accepted nodes and the second list contains the + * rejected nodes. + */ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(IndependenceTest independenceTest, Graph graph, Double threshold) { // When calling, default reject null as <=0.05 List> accepts_rejects = new ArrayList<>(); @@ -262,7 +293,7 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(Ind List localIndependenceFacts = getLocalIndependenceFacts(x); List localPValues = getLocalPValues(independenceTest, localIndependenceFacts); Double ADTest = checkAgainstAndersonDarlingTest(localPValues); - if (ADTest <= threshold) { + if (ADTest <= threshold) { rejects.add(x); } else { accepts.add(x); @@ -273,6 +304,14 @@ public List> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(Ind return accepts_rejects; } + /** + * Calculates the precision and recall on the Markov Blanket graph for a given node. Prints the statistics to the + * console. + * + * @param x The target node. + * @param estimatedGraph The estimated graph. + * @param trueGraph The true graph. + */ public void getPrecisionAndRecallOnMarkovBlanketGraph(Node x, Graph estimatedGraph, Graph trueGraph) { // Lookup graph is the same structure as trueGraph's structure but node objects replaced by estimated graph nodes. Graph lookupGraph = GraphUtils.replaceNodes(trueGraph, estimatedGraph.getNodes()); @@ -288,29 +327,9 @@ public void getPrecisionAndRecallOnMarkovBlanketGraph(Node x, Graph estimatedGra double ahr = new ArrowheadRecall().getValue(xMBLookupGraph, xMBEstimatedGraph, null); NumberFormat nf = new DecimalFormat("0.00"); - System.out.println( "Node " + x + "'s statistics: " + " \n" + - " AdjPrecision = " + nf.format(ap) + " AdjRecall = " + nf.format(ar) + " \n" + - " ArrowHeadPrecision = " + nf.format(ahp) + " ArrowHeadRecall = " + nf.format(ahr)); - } - - public void getPrecisionAndRecallOnParentsSubGraph(Node x, Graph estimatedGraph, Graph trueGraph) { - // Lookup graph is the same structure as trueGraph's structure but node objects replaced by estimated graph nodes. - Graph lookupGraph = GraphUtils.replaceNodes(trueGraph, estimatedGraph.getNodes()); - Graph xParentsLookupGraph = GraphUtils.getParentsSubgraphWithTargetNode(lookupGraph, x); - System.out.println("xParentsLookupGraph:" + xParentsLookupGraph); - Graph xParentsEstimatedGraph = GraphUtils.getParentsSubgraphWithTargetNode(estimatedGraph, x); - System.out.println("xParentsEstimatedGraph:" + xParentsEstimatedGraph); - - // TODO VBC: validate - double ap = new AdjacencyPrecision().getValue(xParentsLookupGraph, xParentsEstimatedGraph, null); - double ar = new AdjacencyRecall().getValue(xParentsLookupGraph, xParentsEstimatedGraph, null); - double ahp = new ArrowheadPrecision().getValue(xParentsLookupGraph, xParentsEstimatedGraph, null); - double ahr = new ArrowheadRecall().getValue(xParentsLookupGraph, xParentsEstimatedGraph, null); - - NumberFormat nf = new DecimalFormat("0.00"); - System.out.println( "Node " + x + "'s statistics: " + " \n" + - " AdjPrecision = " + nf.format(ap) + " AdjRecall = " + nf.format(ar) + " \n" + - " ArrowHeadPrecision = " + nf.format(ahp) + " ArrowHeadRecall = " + nf.format(ahr)); + System.out.println("Node " + x + "'s statistics: " + " \n" + + " AdjPrecision = " + nf.format(ap) + " AdjRecall = " + nf.format(ar) + " \n" + + " ArrowHeadPrecision = " + nf.format(ahp) + " ArrowHeadRecall = " + nf.format(ahr)); } /** @@ -866,6 +885,7 @@ class IndCheckTask implements Callable, Set, Set> call() { Set resultsIndep = new HashSet<>(); Set resultsDep = new HashSet<>(); + independenceTest.setVerbose(false); IndependenceFact fact = facts.get(index); @@ -1069,7 +1089,8 @@ private List getResultsLocal(boolean indep) { * @param x Node to check for independence along with y. * @param y Node to check for independence along with x. * @param z Set of nodes to check if all are contained within the conditioning nodes. - * @return true if x and y are in the independence nodes and all elements of z are in the conditioning nodes; false otherwise. + * @return true if x and y are in the independence nodes and all elements of z are in the conditioning nodes; false + * otherwise. */ private boolean checkNodeIndependenceAndConditioning(Node x, Node y, Set z) { List independenceNodes = getIndependenceNodes(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/PermutationSearch.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/PermutationSearch.java index 2731d92388..195c0415d1 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/PermutationSearch.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/PermutationSearch.java @@ -12,19 +12,18 @@ import java.util.*; /** - * Implements common elements of a permutation search. The specific parts for each permutation search are implemented as - * a SuborderSearch. - *

- * This class specifically handles an optimization for tiered knowledge, whereby tiers in the knowledge can be searched - * one at a time in order from the lowest to highest, taking all variables from previous tiers as a fixed for a later - * tier. This allows these permutation searches to search over many more variables than otherwise, so long as tiered - * knowledge is available to organize the search. - *

- * This class is configured to respect the knowledge of forbidden and required edges, including knowledge of temporal - * tiers. + *

Implements common elements of a permutation search. The specific parts + * for each permutation search are implemented as a SuborderSearch.

+ * + *

This class specifically handles an optimization for tiered knowledge, whereby + * tiers in the knowledge can be searched one at a time in order from the lowest to highest, taking all variables from + * previous tiers as a fixed for a later tier. This allows these permutation searches to search over many more + * variables than otherwise, so long as tiered knowledge is available to organize the search.

+ * + *

This class is configured to respect the knowledge of forbidden and required + * edges, including knowledge of temporal tiers.

* * @author bryanandrews - * @version $Id: $Id * @see SuborderSearch * @see Boss * @see Sp @@ -70,11 +69,8 @@ public class PermutationSearch { */ private Knowledge knowledge = new Knowledge(); - /** - * The seed variable holds a long value that can be used to initialize the random number generator. It is used for - * generating pseudorandom numbers in various algorithms and simulations . The initial value of the seed is -1, - * indicating that no seed has been set yet. - */ + private boolean cpdag = true; + private long seed = -1; /** @@ -243,7 +239,25 @@ public void setKnowledge(Knowledge knowledge) { } /** - * Sets the seed for the random number generator. + * Retrieves the value of cpdag. + * + * @return The value of the cpdag flag. + */ + public boolean getCpdag() { + return cpdag; + } + + /** + * Sets the flag indicating whether a CPDAG (partially directed acyclic graph) is wanted or not. + * + * @param cpdag The value indicating whether a CPDAG is wanted or not. + */ + public void setCpdag(boolean cpdag) { + this.cpdag = cpdag; + } + + /** + * Sets the seed value used for generating random numbers. * * @param seed The seed value to set. */ diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Rfci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Rfci.java index 6f9e896869..1ae4ef3eb3 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Rfci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/Rfci.java @@ -23,9 +23,7 @@ import edu.cmu.tetrad.data.Knowledge; import edu.cmu.tetrad.graph.*; -import edu.cmu.tetrad.search.utils.FciOrient; -import edu.cmu.tetrad.search.utils.SepsetMap; -import edu.cmu.tetrad.search.utils.SepsetsSet; +import edu.cmu.tetrad.search.utils.*; import edu.cmu.tetrad.util.ChoiceGenerator; import edu.cmu.tetrad.util.MillisecondTimes; import edu.cmu.tetrad.util.TetradLogger; @@ -90,6 +88,12 @@ public final class Rfci implements IGraphSearch { * True iff verbose output should be printed. */ private boolean verbose; + /** + * Flag to indicate whether to resolve almost cyclic paths during the search. + * If true, the search algorithm will attempt to resolve paths that are almost cyclic, meaning that they have a single + * bidirected edge that is causing the cycle. If false, these paths will not be resolved. + */ + private boolean resolveAlmostCyclicPaths; /** * Constructs a new RFCI search for the given independence test and background knowledge. @@ -188,7 +192,8 @@ public Graph search(IFas fas, List nodes) { long stop1 = MillisecondTimes.timeMillis(); long start2 = MillisecondTimes.timeMillis(); - FciOrient orient = new FciOrient(new SepsetsSet(this.sepsets, this.independenceTest)); +// FciOrient orient = new FciOrient(new SepsetsSet(this.sepsets, this.independenceTest)); + FciOrient orient = new FciOrient(new SepsetsMaxP(graph, this.independenceTest, null, this.maxPathLength)); // For RFCI always executes R5-10 orient.setCompleteRuleSetUsed(true); @@ -198,6 +203,21 @@ public Graph search(IFas fas, List nodes) { ruleR0_RFCI(getRTuples()); // RFCI Algorithm 4.4 orient.doFinalOrientation(this.graph); + if (resolveAlmostCyclicPaths) { + for (Edge edge : graph.getEdges()) { + if (Edges.isBidirectedEdge(edge)) { + Node x = edge.getNode1(); + Node y = edge.getNode2(); + + if (graph.paths().existsDirectedPath(x, y)) { + graph.setEndpoint(y, x, Endpoint.TAIL); + } else if (graph.paths().existsDirectedPath(y, x)) { + graph.setEndpoint(x, y, Endpoint.TAIL); + } + } + } + } + long endTime = MillisecondTimes.timeMillis(); this.elapsedTime = endTime - beginTime; @@ -534,6 +554,15 @@ private void setMinSepSet(Set _sepSet, Node x, Node y) { } } } + + /** + * Sets the flag to resolve almost cyclic paths in the RFCI search. + * + * @param resolveAlmostCyclicPaths the flag to resolve almost cyclic paths + */ + public void setResolveAlmostCyclicPaths(boolean resolveAlmostCyclicPaths) { + this.resolveAlmostCyclicPaths = resolveAlmostCyclicPaths; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SpFci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SpFci.java index 3eec99a191..700ebb5e25 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SpFci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SpFci.java @@ -24,6 +24,7 @@ import edu.cmu.tetrad.data.KnowledgeEdge; import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.search.score.Score; +import edu.cmu.tetrad.search.test.MsepTest; import edu.cmu.tetrad.search.utils.*; import edu.cmu.tetrad.search.work_in_progress.MagSemBicScore; import edu.cmu.tetrad.util.ChoiceGenerator; @@ -119,6 +120,10 @@ public final class SpFci implements IGraphSearch { * Setting this variable to false disables the application of the discriminating path rule. */ private boolean doDiscriminatingPathRule = true; + /** + * Whether to resolve almost cyclic paths. + */ + private boolean resolveAlmostCyclicPaths; /** * Constructor; requires by ta test and a score, over the same variables. @@ -162,29 +167,51 @@ public Graph search() { } Knowledge knowledge2 = new Knowledge(knowledge); - addForbiddenReverseEdgesForDirectedEdges(GraphTransforms.cpdagForDag(graph), knowledge2); + addForbiddenReverseEdgesForDirectedEdges(GraphTransforms.dagToCpdag(graph), knowledge2); // Keep a copy of this CPDAG. Graph referenceDag = new EdgeListGraph(this.graph); - SepsetProducer sepsets = new SepsetsGreedy(this.graph, this.independenceTest, null, this.depth, knowledge); - // GFCI extra edge removal step... - gfciExtraEdgeRemovalStep(this.graph, referenceDag, nodes, sepsets); - modifiedR0(referenceDag, sepsets); +// SepsetProducer sepsets = new SepsetsGreedy(graph, this.independenceTest, null, this.depth, knowledge); + SepsetProducer sepsets; + + if (independenceTest instanceof MsepTest) { + sepsets = new DagSepsets(((MsepTest) independenceTest).getGraph()); + } else { + sepsets = new SepsetsGreedy(graph, this.independenceTest, null, this.depth, knowledge); + } + + gfciExtraEdgeRemovalStep(graph, referenceDag, nodes, sepsets, verbose); + GraphUtils.gfciR0(graph, referenceDag, sepsets, knowledge, verbose); FciOrient fciOrient = new FciOrient(sepsets); - fciOrient.setCompleteRuleSetUsed(this.completeRuleSetUsed); - fciOrient.setMaxPathLength(this.maxPathLength); - fciOrient.setDoDiscriminatingPathColliderRule(this.doDiscriminatingPathRule); - fciOrient.setDoDiscriminatingPathTailRule(this.doDiscriminatingPathRule); + fciOrient.setCompleteRuleSetUsed(completeRuleSetUsed); + fciOrient.setDoDiscriminatingPathColliderRule(doDiscriminatingPathRule); + fciOrient.setDoDiscriminatingPathTailRule(doDiscriminatingPathRule); fciOrient.setVerbose(verbose); - fciOrient.setKnowledge(knowledge2); - + fciOrient.setKnowledge(knowledge); fciOrient.doFinalOrientation(graph); + if (resolveAlmostCyclicPaths) { + for (Edge edge : graph.getEdges()) { + if (Edges.isBidirectedEdge(edge)) { + Node x = edge.getNode1(); + Node y = edge.getNode2(); + + if (graph.paths().existsDirectedPath(x, y)) { + graph.setEndpoint(y, x, Endpoint.TAIL); + } else if (graph.paths().existsDirectedPath(y, x)) { + graph.setEndpoint(x, y, Endpoint.TAIL); + } + } + } + } + GraphUtils.replaceNodes(this.graph, this.independenceTest.getVariables()); +// graph = GraphTransforms.dagToPag(graph); + return this.graph; } @@ -315,49 +342,6 @@ public void setDoDiscriminatingPathRule(boolean doDiscriminatingPathRule) { this.doDiscriminatingPathRule = doDiscriminatingPathRule; } - - /** - * Modifies the graph using the Modified R0 algorithm. (Due to Spirtes.) - * - * @param fgesGraph The original graph obtained from FGES algorithm. - * @param sepsets The SepsetProducer for computing the separating sets. - */ - private void modifiedR0(Graph fgesGraph, SepsetProducer sepsets) { - this.graph = new EdgeListGraph(graph); - this.graph.reorientAllWith(Endpoint.CIRCLE); - fciOrientbk(this.knowledge, this.graph, this.graph.getNodes()); - - List nodes = this.graph.getNodes(); - - for (Node b : nodes) { - List adjacentNodes = new ArrayList<>(this.graph.getAdjacentNodes(b)); - - if (adjacentNodes.size() < 2) { - continue; - } - - ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2); - int[] combination; - - while ((combination = cg.next()) != null) { - Node a = adjacentNodes.get(combination[0]); - Node c = adjacentNodes.get(combination[1]); - - if (fgesGraph.isDefCollider(a, b, c)) { - this.graph.setEndpoint(a, b, Endpoint.ARROW); - this.graph.setEndpoint(c, b, Endpoint.ARROW); - } else if (fgesGraph.isAdjacentTo(a, c) && !this.graph.isAdjacentTo(a, c)) { - Set sepset = sepsets.getSepset(a, c); - - if (sepset != null && !sepset.contains(b)) { - this.graph.setEndpoint(a, b, Endpoint.ARROW); - this.graph.setEndpoint(c, b, Endpoint.ARROW); - } - } - } - } - } - /** * Orients edges in the graph based on the knowledge. * @@ -416,4 +400,15 @@ private void fciOrientbk(Knowledge knowledge, Graph graph, List variables) TetradLogger.getInstance().forceLogMessage("Finishing BK Orientation."); } } + + /** + * Sets whether almost cyclic paths should be resolved during the search. + * If resolveAlmostCyclicPaths is set to true, the search algorithm will perform additional steps + * to resolve almost cyclic paths in the graph. + * + * @param resolveAlmostCyclicPaths True, if almost cyclic paths should be resolved. False, otherwise. + */ + public void setResolveAlmostCyclicPaths(boolean resolveAlmostCyclicPaths) { + this.resolveAlmostCyclicPaths = resolveAlmostCyclicPaths; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SvarFci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SvarFci.java index b6824decec..67aadbc66c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SvarFci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SvarFci.java @@ -22,10 +22,7 @@ package edu.cmu.tetrad.search; import edu.cmu.tetrad.data.Knowledge; -import edu.cmu.tetrad.graph.Edge; -import edu.cmu.tetrad.graph.Endpoint; -import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.graph.*; import edu.cmu.tetrad.search.utils.*; import edu.cmu.tetrad.util.MillisecondTimes; import edu.cmu.tetrad.util.TetradLogger; @@ -95,6 +92,10 @@ public final class SvarFci implements IGraphSearch { * True iff verbose output should be printed. */ private boolean verbose; + /** + * Represents whether to resolve almost cyclic paths during the search. + */ + private boolean resolveAlmostCyclicPaths; /** * Constructs a new FCI search for the given independence test and background knowledge. @@ -209,6 +210,21 @@ public Graph search(IFas fas) { fciOrient.ruleR0(this.graph); fciOrient.doFinalOrientation(this.graph); + if (resolveAlmostCyclicPaths) { + for (Edge edge : graph.getEdges()) { + if (Edges.isBidirectedEdge(edge)) { + Node x = edge.getNode1(); + Node y = edge.getNode2(); + + if (graph.paths().existsDirectedPath(x, y)) { + graph.setEndpoint(y, x, Endpoint.TAIL); + } else if (graph.paths().existsDirectedPath(y, x)) { + graph.setEndpoint(x, y, Endpoint.TAIL); + } + } + } + } + return this.graph; } @@ -480,10 +496,16 @@ private void removeSimilarPairs(IndependenceTest test, Node x, Node y, Set * @return The name of the object without any lagging characters. */ public String getNameNoLag(Object obj) { - String tempS = obj.toString(); - if (tempS.indexOf(':') == -1) { - return tempS; - } else return tempS.substring(0, tempS.indexOf(':')); + return TsUtils.getNameNoLag(obj); + } + + /** + * Sets whether almost cyclic paths should be resolved during the search. + * + * @param resolveAlmostCyclicPaths true if almost cyclic paths should be resolved, false otherwise + */ + public void setResolveAlmostCyclicPaths(boolean resolveAlmostCyclicPaths) { + this.resolveAlmostCyclicPaths = resolveAlmostCyclicPaths; } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SvarFges.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SvarFges.java index 54e0b0caa0..f8ee0e9d0f 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SvarFges.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SvarFges.java @@ -745,10 +745,10 @@ protected Boolean compute() { } Node y = nodes.get(i); - Set cond = new HashSet<>(); - Set D = new HashSet<>(SvarFges.this.graph.paths().getMConnectedVars(y, cond)); +// Set cond = new HashSet<>(); + Set D = new HashSet<>(variables);// SvarFges.this.graph.paths().getMConnectedVars(y, cond)); D.remove(y); - SvarFges.this.effectEdgesGraph.getAdjacentNodes(y).forEach(D::remove); +// SvarFges.this.effectEdgesGraph.getAdjacentNodes(y).forEach(D::remove); for (Node x : D) { if (existsKnowledge()) { @@ -1064,9 +1064,10 @@ protected Boolean compute() { adj = new ArrayList<>(g); } else if (SvarFges.this.mode == Mode.allowUnfaithfulness) { - HashSet D = new HashSet<>(SvarFges.this.graph.paths().getMConnectedVars(x, new HashSet<>())); - D.remove(x); - adj = new ArrayList<>(D); +// HashSet D = new HashSet<>(SvarFges.this.graph.paths().getMConnectedVars(x, new HashSet<>())); +// D.remove(x); + adj = new ArrayList<>(variables); + adj.remove(x); } else { throw new IllegalStateException(); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SvarGfci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SvarGfci.java index 80dd33ad30..fb46e43402 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SvarGfci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/SvarGfci.java @@ -79,6 +79,10 @@ public final class SvarGfci implements IGraphSearch { * The sepsets. */ private SepsetProducer sepsets; + /** + * Indicates whether the search algorithm should resolve almost cyclic paths. + */ + private boolean resolveAlmostCyclicPaths; /** @@ -164,6 +168,21 @@ public Graph search() { fciOrient.setMaxPathLength(this.maxPathLength); fciOrient.doFinalOrientation(this.graph); + if (resolveAlmostCyclicPaths) { + for (Edge edge : graph.getEdges()) { + if (Edges.isBidirectedEdge(edge)) { + Node x = edge.getNode1(); + Node y = edge.getNode2(); + + if (graph.paths().existsDirectedPath(x, y)) { + graph.setEndpoint(y, x, Endpoint.TAIL); + } else if (graph.paths().existsDirectedPath(y, x)) { + graph.setEndpoint(x, y, Endpoint.TAIL); + } + } + } + } + GraphUtils.replaceNodes(this.graph, this.independenceTest.getVariables()); return this.graph; @@ -539,6 +558,15 @@ private List> returnSimilarPairs(Node x, Node y) { pairList.add(simListY); return (pairList); } + + /** + * Sets whether to resolve almost cyclic paths during the search. + * + * @param resolveAlmostCyclicPaths True if almost cyclic paths should be resolved, false otherwise. + */ + public void setResolveAlmostCyclicPaths(boolean resolveAlmostCyclicPaths) { + this.resolveAlmostCyclicPaths = resolveAlmostCyclicPaths; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/score/GraphScore.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/score/GraphScore.java index bb2760aae4..73d1c9806a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/score/GraphScore.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/score/GraphScore.java @@ -199,7 +199,7 @@ private double locallyConsistentScoringCriterion(int x, int y, int[] z) { boolean dSeparatedFrom; if (dag != null) { - dSeparatedFrom = dag.paths().isMSeparatedFrom(_x, _y, _z); + dSeparatedFrom = dag.paths().isMSeparatedFrom(_x, _y, _z, false); } else if (facts != null) { dSeparatedFrom = facts.isIndependent(_x, _y, _z); } else { @@ -211,7 +211,7 @@ private double locallyConsistentScoringCriterion(int x, int y, int[] z) { private boolean isMSeparatedFrom(Node x, Node y, Set z) { if (dag != null) { - return dag.paths().isMSeparatedFrom(x, y, z); + return dag.paths().isMSeparatedFrom(x, y, z, false); } else if (facts != null) { return facts.isIndependent(x, y, z); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/test/IndTestChiSquare.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/test/IndTestChiSquare.java index e560f03dbd..22eb277e62 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/test/IndTestChiSquare.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/test/IndTestChiSquare.java @@ -247,9 +247,10 @@ public IndependenceResult checkIndependence(Node x, Node y, Set _z) { /** * Returns the pvalue if the fact of X _||_ Y | Z is within the cache of results for independence fact. - * @param x - * @param y - * @param z + * + * @param x the first node + * @param y the second node + * @param z the set of conditioning nodes * @return the pValue result or null if not within the cache */ public Double getPValue(Node x, Node y, Set z) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/test/IndTestFisherZ.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/test/IndTestFisherZ.java index 65ae297e8e..355caf5605 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/test/IndTestFisherZ.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/test/IndTestFisherZ.java @@ -375,6 +375,9 @@ private IndependenceResult checkIndependencePseudoinverse(Node xVar, Node yVar, /** * Returns the p-value for x _||_ y | z. * + * @param x The first node. + * @param y The second node. + * @param z The set of conditioning variables. * @return The p-value. * @throws SingularMatrixException If a singularity occurs when invering a matrix. */ diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/test/MsepTest.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/test/MsepTest.java index 7df72d375b..b8edf38925 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/test/MsepTest.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/test/MsepTest.java @@ -77,6 +77,10 @@ public class MsepTest implements IndependenceTest { * The "p-value" of the last test (this is 0 or 1). */ private double pvalue = 0; + /** + * Whether there are any latents. + */ + private boolean hasLatents = false; /** * Constructor. @@ -128,6 +132,13 @@ public MsepTest(Graph graph, boolean keepLatents) { this.ancestorMap = graph.paths().getAncestorMap(); this._observedVars = calcVars(graph.getNodes(), keepLatents); this.observedVars = new ArrayList<>(_observedVars); + this.hasLatents = false; + for (Node node : graph.getNodes()) { + if (node.getNodeType() == NodeType.LATENT) { + this.hasLatents = true; + break; + } + } } /** @@ -147,6 +158,13 @@ public MsepTest(IndependenceFacts facts, boolean keepLatents) { this._observedVars = calcVars(facts.getVariables(), keepLatents); this.observedVars = new ArrayList<>(_observedVars); + this.hasLatents = false; + for (Node node : facts.getVariables()) { + if (node.getNodeType() == NodeType.LATENT) { + this.hasLatents = true; + break; + } + } } /** @@ -238,7 +256,7 @@ public IndependenceResult checkIndependence(Node x, Node y, Set z) { boolean mSeparated; if (graph != null) { - mSeparated = !getGraph().paths().isMConnectedTo(x, y, z, ancestorMap); + mSeparated = !getGraph().paths().isMConnectedTo(x, y, z, ancestorMap, false); } else { mSeparated = independenceFacts.isIndependent(x, y, z); } @@ -289,7 +307,7 @@ public boolean isMSeparated(Node x, Node y, Set z) { } } - return getGraph().paths().isMSeparatedFrom(x, y, z, ancestorMap); + return getGraph().paths().isMSeparatedFrom(x, y, z, ancestorMap, hasLatents); } /** diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/DagSepsets.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/DagSepsets.java index fc597ded81..66b7a9aac6 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/DagSepsets.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/DagSepsets.java @@ -58,6 +58,29 @@ public Set getSepset(Node a, Node b) { return this.dag.getSepset(a, b); } + /** + * Returns the sepset containing nodes 'a' and 'b' that also contains all the nodes in the given set 's'. Note + * that for the DAG case, it is expected that any sepset containing 'a' and 'b' will contain all the nodes in 's'; + * otherwise, an exception is thrown. + * + * @param a The first node. + * @param b The second node. + * @param s The set of nodes that must be contained in the sepset. + * @return The sepset containing 'a' and 'b' that also contains all the nodes in 's'. + * @throws IllegalArgumentException If the sepset of 'a' and 'b' does not contain all the nodes in 's'. + */ + @Override + public Set getSepsetContaining(Node a, Node b, Set s) { + Set sepset = this.dag.getSepset(a, b); + + if (sepset != null && !sepset.containsAll(s)) { + throw new IllegalArgumentException("Was expecting the sepset of " + a + " and " + b + " (" + sepset + + ") to contain all the nodes in " + s + "."); + } + + return sepset; + } + /** * {@inheritDoc} *

@@ -87,8 +110,16 @@ public double getScore() { * check. */ @Override - public boolean isIndependent(Node a, Node b, Set c) { - return this.dag.paths().isMSeparatedFrom(a, b, c); + public boolean isIndependent(Node a, Node b, Set sepset) { + return this.dag.paths().isMSeparatedFrom(a, b, sepset, false); + } + + /** + * @throws UnsupportedOperationException if this method is called. + */ + @Override + public double getPValue(Node a, Node b, Set sepset) { + throw new UnsupportedOperationException("This makes no sense for this subclass."); } /** diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/DagToPag.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/DagToPag.java index 040b91413b..ab6a2de3ab 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/DagToPag.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/DagToPag.java @@ -41,7 +41,7 @@ */ public final class DagToPag { - private static final WeakHashMap history = new WeakHashMap<>(); +// private static final WeakHashMap history = new WeakHashMap<>(); private final Graph dag; /** * The logger to use. @@ -107,9 +107,7 @@ public static boolean existsInducingPathInto(Node x, Node y, Graph graph) { * @return Returns the converted PAG. */ public Graph convert() { - TetradLogger.getInstance().forceLogMessage("Starting DAG to PAG_of_the_true_DAG."); - - if (history.get(dag) != null) return history.get(dag); +// if (history.get(dag) != null) return history.get(dag); if (this.verbose) { System.out.println("DAG to PAG_of_the_true_DAG: Starting adjacency search"); @@ -141,7 +139,7 @@ public Graph convert() { System.out.println("Finishing final orientation"); } - history.put(dag, graph); +// history.put(dag, graph); return graph; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/FciOrient.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/FciOrient.java index deb6a38670..177a529d76 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/FciOrient.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/FciOrient.java @@ -67,7 +67,6 @@ public final class FciOrient { private boolean completeRuleSetUsed = true; private int maxPathLength = -1; private boolean verbose; - private Graph truePag; private boolean doDiscriminatingPathColliderRule = true; private boolean doDiscriminatingPathTailRule = true; @@ -75,7 +74,8 @@ public final class FciOrient { /** * Constructs a new FCI search for the given independence test and background knowledge. * - * @param sepsets a {@link edu.cmu.tetrad.search.utils.SepsetProducer} object + * @param sepsets a {@link edu.cmu.tetrad.search.utils.SepsetProducer} object representing the independence test, + * which must be given only if the discriminating path rule is used. Otherwise, it can be null. */ public FciOrient(SepsetProducer sepsets) { this.sepsets = sepsets; @@ -351,8 +351,6 @@ public void ruleR0(Graph graph) { graph.setEndpoint(c, b, Endpoint.ARROW); if (this.verbose) { this.logger.forceLogMessage(LogUtilsSearch.colliderOrientedMsg(a, b, c)); - - printWrongColliderMessage(a, b, c, graph); } } } @@ -613,13 +611,13 @@ public void ruleR3(Graph graph) { /** * The triangles that must be oriented this way (won't be done by another rule) all look like the ones below, where - * the dots are a collider path from L to A with each node on the path (except L) a parent of C. + * the dots are a collider path from E to A with each node on the path (except L) a parent of C. *

      *          B
      *         xo           x is either an arrowhead or a circle
      *        /  \
      *       v    v
-     * L....A --> C
+     * E....A --> C
      * 
*

* This is Zhang's rule R4, discriminating paths. @@ -627,8 +625,12 @@ public void ruleR3(Graph graph) { * @param graph a {@link edu.cmu.tetrad.graph.Graph} object */ public void ruleR4B(Graph graph) { - if (doDiscriminatingPathColliderRule || doDiscriminatingPathTailRule) { + if (sepsets == null) { + throw new NullPointerException("SepsetProducer is null; if you want to use the discriminating path rule " + + "in FciOrient, you must provide a SepsetProducer."); + } + List nodes = graph.getNodes(); for (Node b : nodes) { @@ -669,7 +671,7 @@ public void ruleR4B(Graph graph) { } /** - * a method to search "back from a" to find a DDP. It is called with a reachability list (first consisting only of + * A method to search "back from a" to find a DDP. It is called with a reachability list (first consisting only of * a). This is breadth-first, utilizing "reachability" concept from Geiger, Verma, and Pearl 1990. The body of a DDP * consists of colliders that are parents of c. * @@ -678,7 +680,7 @@ public void ruleR4B(Graph graph) { * @param c a {@link edu.cmu.tetrad.graph.Node} object * @param graph a {@link edu.cmu.tetrad.graph.Graph} object */ - public void ddpOrient(Node a, Node b, Node c, Graph graph) { + private void ddpOrient(Node a, Node b, Node c, Graph graph) { Queue Q = new ArrayDeque<>(20); Set V = new HashSet<>(); @@ -686,6 +688,8 @@ public void ddpOrient(Node a, Node b, Node c, Graph graph) { int distance = 0; Map previous = new HashMap<>(); + Set colliderPath = new HashSet<>(); + colliderPath.add(a); List cParents = graph.getParents(c); @@ -728,9 +732,10 @@ public void ddpOrient(Node a, Node b, Node c, Graph graph) { } previous.put(d, t); + colliderPath.add(t); if (!graph.isAdjacentTo(d, c)) { - if (doDdpOrientation(d, a, b, c, graph)) { + if (doDdpOrientation(d, a, b, c, graph, colliderPath)) { return; } } @@ -915,22 +920,41 @@ public void rulesR8R9R10(Graph graph) { } /** - * Orients the edges inside the definte discriminating path triangle. Takes the left endpoint, and a,b,c as - * arguments. + * Determines the orientation for the nodes in a Directed Acyclic Graph (DAG) based on the Discriminating Path Rule + * Here, we insist that the sepset for D and B contain all the nodes along the collider path. + *

+ * Reminder: + *

+     *      The triangles that must be oriented this way (won't be done by another rule) all look like the ones below, where
+     *      the dots are a collider path from E to A with each node on the path (except L) a parent of C.
+     *      
+     *               B
+     *              xo           x is either an arrowhead or a circle
+     *             /  \
+     *            v    v
+     *      E....A --> C
      *
-     * @param d     a {@link edu.cmu.tetrad.graph.Node} object
-     * @param a     a {@link edu.cmu.tetrad.graph.Node} object
-     * @param b     a {@link edu.cmu.tetrad.graph.Node} object
-     * @param c     a {@link edu.cmu.tetrad.graph.Node} object
-     * @param graph a {@link edu.cmu.tetrad.graph.Graph} object
-     * @return a boolean
+     *      This is Zhang's rule R4, discriminating paths. The "collider path" here is all of the collider nodes
+     *      along the E...A path (all parents of C), including A. The idea is that is we know that E is independent
+     *      of C given all of nodes on the collider path plus perhaps some other nodes, then there should be a collider
+     *      at B; otherwise, there should be a noncollider at B.
+     * 
+ * + * @param d the 'd' node + * @param a the 'a' node + * @param b the 'b' node + * @param c the 'c' node + * @param graph the graph representation + * @param colliderPath the list of nodes in the collider path + * @return true if the orientation is determined, false otherwise + * @throws IllegalArgumentException if 'd' is adjacent to 'c' */ - public boolean doDdpOrientation(Node d, Node a, Node b, Node c, Graph graph) { + private boolean doDdpOrientation(Node d, Node a, Node b, Node c, Graph graph, Set colliderPath) { if (graph.isAdjacentTo(d, c)) { throw new IllegalArgumentException(); } - Set sepset = getSepsets().getSepset(d, c); + Set sepset = getSepsets().getSepsetContaining(d, c, colliderPath); if (this.verbose) { logger.forceLogMessage("Sepset for d = " + d + " and c = " + c + " = " + sepset); @@ -1135,7 +1159,10 @@ public void fciOrientbk(Knowledge bk, Graph graph, List variables) { // Orient to*->from graph.setEndpoint(to, from, Endpoint.ARROW); this.changeFlag = true; - this.logger.forceLogMessage(LogUtilsSearch.edgeOrientedMsg("Knowledge", graph.getEdge(from, to))); + + if (verbose) { + this.logger.forceLogMessage(LogUtilsSearch.edgeOrientedMsg("Knowledge", graph.getEdge(to, from))); + } } for (Iterator it @@ -1165,7 +1192,10 @@ public void fciOrientbk(Knowledge bk, Graph graph, List variables) { graph.setEndpoint(to, from, Endpoint.TAIL); graph.setEndpoint(from, to, Endpoint.ARROW); this.changeFlag = true; - this.logger.forceLogMessage(LogUtilsSearch.edgeOrientedMsg("Knowledge", graph.getEdge(from, to))); + + if (verbose) { + this.logger.forceLogMessage(LogUtilsSearch.edgeOrientedMsg("Knowledge", graph.getEdge(from, to))); + } } if (verbose) { @@ -1204,33 +1234,6 @@ public void setVerbose(boolean verbose) { this.verbose = verbose; } - /** - * The true PAG if available. Can be null. - * - * @return a {@link edu.cmu.tetrad.graph.Graph} object - */ - public Graph getTruePag() { - return this.truePag; - } - - /** - * Sets the true PAG for comparison. - * - * @param truePag This PAG. - */ - public void setTruePag(Graph truePag) { - this.truePag = truePag; - } - - /** - * Change flag for repeat rules - * - * @return True if a change has occurred. - */ - public boolean isChangeFlag() { - return this.changeFlag; - } - /** * Sets the change flag--marks externally that a change has been made. * @@ -1338,10 +1341,4 @@ public void ruleR10(Node a, Node c, Graph graph) { } } - - private void printWrongColliderMessage(Node a, Node b, Node c, Graph graph) { - if (this.truePag != null && graph.isDefCollider(a, b, c) && !this.truePag.isDefCollider(a, b, c)) { - logger.forceLogMessage("R0" + ": Orienting collider by mistake: " + a + "*->;" + b + "<-*" + c); - } - } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/GraphSearchUtils.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/GraphSearchUtils.java index 9f387b236e..751f4018c4 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/GraphSearchUtils.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/GraphSearchUtils.java @@ -409,7 +409,7 @@ public static LegalPagRet isLegalPag(Graph pag) { } } - Graph mag = GraphTransforms.pagToMag(pag); + Graph mag = GraphTransforms.zhangMagFromPag(pag); LegalMagRet legalMag = isLegalMag(mag); @@ -491,7 +491,7 @@ public static LegalMagRet isLegalMag(Graph mag) { } for (Node n : mag.getNodes()) { - if (mag.paths().existsDirectedPathFromTo(n, n)) + if (mag.paths().existsDirectedPath(n, n)) return new LegalMagRet(false, "Acyclicity violated: There is a directed cyclic path from from " + n + " to itself"); } @@ -501,14 +501,14 @@ public static LegalMagRet isLegalMag(Graph mag) { Node y = e.getNode2(); if (Edges.isBidirectedEdge(e)) { - if (mag.paths().existsDirectedPathFromTo(x, y)) { - List path = mag.paths().directedPathsFromTo(x, y, 100).get(0); + if (mag.paths().existsDirectedPath(x, y)) { + List path = mag.paths().directedPaths(x, y, 100).get(0); return new LegalMagRet(false, "Bidirected edge semantics is violated: there is a directed path for " + e + " from " + x + " to " + y + ". This is \"almost cyclic\"; for <-> edges there should not be a path from either endpoint to the other. " + "An example path is " + GraphUtils.pathString(mag, path)); - } else if (mag.paths().existsDirectedPathFromTo(y, x)) { - List path = mag.paths().directedPathsFromTo(y, x, 100).get(0); + } else if (mag.paths().existsDirectedPath(y, x)) { + List path = mag.paths().directedPaths(y, x, 100).get(0); return new LegalMagRet(false, "Bidirected edge semantics is violated: There is an a directed path for " + e + " from " + y + " to " + x + ". This is \"almost cyclic\"; for <-> edges there should not be a path from either endpoint to the other. " @@ -883,8 +883,8 @@ public static int structuralHammingDistance(Graph trueGraph, Graph estGraph) { try { estGraph = GraphUtils.replaceNodes(estGraph, trueGraph.getNodes()); - trueGraph = GraphTransforms.cpdagForDag(trueGraph); - estGraph = GraphTransforms.cpdagForDag(estGraph); + trueGraph = GraphTransforms.dagToCpdag(trueGraph); + estGraph = GraphTransforms.dagToCpdag(estGraph); // Will check mixedness later. if (trueGraph.paths().existsDirectedCycle()) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/MbUtils.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/MbUtils.java index 99284331d2..94acea2d06 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/MbUtils.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/MbUtils.java @@ -73,7 +73,7 @@ public static void trimToMbNodes(Graph graph, Node target, if (graph.isDefCollider(target, v, w)) { parentsOfChildren.add(w); } else if (graph.getNodesInTo(v, Endpoint.ARROW).contains(target) - && graph.paths().isUndirectedFromTo(v, w)) { + && graph.paths().isUndirected(v, w)) { parentsOfChildren.add(w); } } @@ -92,9 +92,9 @@ public static void trimToMbNodes(Graph graph, Node target, List pc = new LinkedList<>(); for (Node node : graph.getAdjacentNodes(target)) { - if (graph.paths().isDirectedFromTo(target, node) || - graph.paths().isDirectedFromTo(node, target) || - graph.paths().isUndirectedFromTo(node, target)) { + if (graph.paths().isDirected(target, node) || + graph.paths().isDirected(node, target) || + graph.paths().isUndirected(node, target)) { pc.add(node); } } @@ -106,7 +106,7 @@ public static void trimToMbNodes(Graph graph, Node target, continue; } - if (graph.paths().isDirectedFromTo(target, v)) { + if (graph.paths().isDirected(target, v)) { children.add(v); } } @@ -125,8 +125,8 @@ public static void trimToMbNodes(Graph graph, Node target, continue; } - if (graph.paths().isDirectedFromTo(target, v) && - graph.paths().isDirectedFromTo(w, v)) { + if (graph.paths().isDirected(target, v) && + graph.paths().isDirected(w, v)) { parentsOfChildren.add(w); } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/MeekRules.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/MeekRules.java index ef73c81660..cec282c1aa 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/MeekRules.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/MeekRules.java @@ -176,16 +176,8 @@ public void setRevertToUnshieldedColliders(boolean revertToUnshieldedColliders) * @param visited The set of nodes visited. */ private void revertToUnshieldedColliders(List nodes, Graph graph, Set visited) { - boolean reverted = true; - - while (reverted) { - reverted = false; - - for (Node node : nodes) { - if (revertToUnshieldedColliders(node, graph, visited)) { - reverted = true; - } - } + for (Node node : nodes) { + revertToUnshieldedColliders(node, graph, visited); } } @@ -213,29 +205,32 @@ private boolean meekR2(Node a, Node c, Graph graph, Set visited) { adjacentNodes.remove(a); Set common = getCommonAdjacents(a, c, graph); + boolean oriented = false; for (Node b : common) { - if (graph.paths().isDirectedFromTo(a, b) && graph.paths().isDirectedFromTo(b, c)) { + if (graph.paths().isDirected(a, b) && graph.paths().isDirected(b, c)) { if (r2Helper(a, b, c, graph, visited)) { - return true; + oriented = true; } } - if (graph.paths().isDirectedFromTo(c, b) && graph.paths().isDirectedFromTo(b, a)) { + if (graph.paths().isDirected(c, b) && graph.paths().isDirected(b, a)) { if (r2Helper(c, b, a, graph, visited)) { - return true; + oriented = true; } } } - return false; + return oriented; } private boolean r2Helper(Node a, Node b, Node c, Graph graph, Set visited) { - boolean directed = direct(a, c, graph, visited); - log(LogUtilsSearch.edgeOrientedMsg( - "Meek R2 triangle (" + a + "-->" + b + "-->" + c + ", " + a + "---" + c + ")", graph.getEdge(a, c))); - return directed; + if (direct(a, c, graph, visited)) { + log(LogUtilsSearch.edgeOrientedMsg( + "Meek R2 triangle (" + a + "-->" + b + "-->" + c + ", " + a + "--" + c + ")", graph.getEdge(a, c))); + return true; + } + return false; } /** @@ -248,6 +243,8 @@ private boolean meekR3(Node d, Node a, Graph graph, Set visited) { return false; } + boolean oriented = false; + for (int i = 0; i < adjacentNodes.size(); i++) { for (int j = i + 1; j < adjacentNodes.size(); j++) { Node b = adjacentNodes.get(i); @@ -255,31 +252,31 @@ private boolean meekR3(Node d, Node a, Graph graph, Set visited) { if (!graph.isAdjacentTo(b, c)) { if (r3Helper(a, d, b, c, graph, visited)) { - return true; + oriented = true; } } } } - return false; + return oriented; } private boolean r3Helper(Node a, Node d, Node b, Node c, Graph graph, Set visited) { - boolean oriented = false; - - boolean b4 = graph.paths().isUndirectedFromTo(d, a); - boolean b5 = graph.paths().isUndirectedFromTo(d, b); - boolean b6 = graph.paths().isUndirectedFromTo(d, c); - boolean b7 = graph.paths().isDirectedFromTo(b, a); - boolean b8 = graph.paths().isDirectedFromTo(c, a); + boolean b4 = graph.paths().isUndirected(d, a); + boolean b5 = graph.paths().isUndirected(d, b); + boolean b6 = graph.paths().isUndirected(d, c); + boolean b7 = graph.paths().isDirected(b, a); + boolean b8 = graph.paths().isDirected(c, a); if (b4 && b5 && b6 && b7 && b8) { - oriented = direct(d, a, graph, visited); - log(LogUtilsSearch.edgeOrientedMsg("Meek R3 " + d + "--" + a + ", " + b + ", " - + c, graph.getEdge(d, a))); + if (direct(d, a, graph, visited)) { + log(LogUtilsSearch.edgeOrientedMsg("Meek R3 " + d + "--" + a + ", " + b + ", " + + c, graph.getEdge(d, a))); + return true; + } } - return oriented; + return false; } private boolean meekR4(Node a, Node b, Graph graph, Set visited) { @@ -287,6 +284,8 @@ private boolean meekR4(Node a, Node b, Graph graph, Set visited) { return false; } + boolean oriented = false; + for (Node c : graph.getParents(b)) { Set adj = getCommonAdjacents(a, c, graph); adj.remove(b); @@ -298,12 +297,12 @@ private boolean meekR4(Node a, Node b, Graph graph, Set visited) { if (graph.getEdge(a, d).isDirected()) continue; if (direct(a, b, graph, visited)) { log(LogUtilsSearch.edgeOrientedMsg("Meek R4 using " + c + ", " + d, graph.getEdge(a, b))); - return true; + oriented = true; } } } - return false; + return oriented; } private boolean direct(Node a, Node c, Graph graph, Set visited) { @@ -313,7 +312,7 @@ private boolean direct(Node a, Node c, Graph graph, Set visited) { Edge before = graph.getEdge(a, c); graph.removeEdge(before); - if (meekPreventCycles && graph.paths().existsDirectedPathFromTo(c, a)) { + if (meekPreventCycles && graph.paths().existsDirectedPath(c, a)) { graph.addEdge(before); return false; } @@ -329,9 +328,7 @@ private boolean direct(Node a, Node c, Graph graph, Set visited) { return true; } - private boolean revertToUnshieldedColliders(Node y, Graph graph, Set visited) { - boolean did = false; - + private void revertToUnshieldedColliders(Node y, Graph graph, Set visited) { List parents = graph.getParents(y); P: @@ -350,11 +347,7 @@ private boolean revertToUnshieldedColliders(Node y, Graph graph, Set visit visited.add(p); visited.add(y); - - did = true; } - - return did; } private void log(String message) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetProducer.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetProducer.java index 53cf68f2e2..40d4cbe3d7 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetProducer.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetProducer.java @@ -43,6 +43,20 @@ public interface SepsetProducer { */ Set getSepset(Node a, Node b); + /** + * Returns the subset for a and b, where this sepset is expected to contain all the nodes in s. The behavior is + * morphed depending on whether sepsets are calculated using an independence test or not. If sepsets are calculated + * using an independence test, and a sepset is not found containing all the nodes in s, then the method will return + * null. Otherwise, if the discovered sepset does not contain all the nodes in s, the method will throw an + * exception. + * + * @param a the first node + * @param b the second node + * @param s the set of nodes + * @return the set of nodes that sepsets for a and b are expected to contain. + */ + Set getSepsetContaining(Node a, Node b, Set s); + /** *

isUnshieldedCollider.

* @@ -77,11 +91,21 @@ public interface SepsetProducer { /** *

isIndependent.

* - * @param d a {@link edu.cmu.tetrad.graph.Node} object - * @param c a {@link edu.cmu.tetrad.graph.Node} object - * @param path a {@link java.util.Set} object + * @param d a {@link edu.cmu.tetrad.graph.Node} object + * @param c a {@link edu.cmu.tetrad.graph.Node} object + * @param sepset a {@link java.util.Set} object * @return a boolean */ - boolean isIndependent(Node d, Node c, Set path); + boolean isIndependent(Node d, Node c, Set sepset); + + /** + * Calculates the p-value for a statistical test a _||_ b | sepset. + * + * @param a the first node + * @param b the second node + * @param sepset the set of nodes + * @return the p-value for the statistical test + */ + double getPValue(Node a, Node b, Set sepset); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsGreedy.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsGreedy.java index 0a951cb3f2..c22b8638cb 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsGreedy.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsGreedy.java @@ -75,19 +75,35 @@ public SepsetsGreedy(Graph graph, IndependenceTest independenceTest, SepsetMap e } /** - * {@inheritDoc} - *

- * Pick out the sepset from among adj(i) or adj(k) with the highest score value. + * Retrieves the sepset (separating set) between two nodes, or null if no such sepset is found. + * + * @param i The first node + * @param k The second node + * @return The sepset between the two nodes */ public Set getSepset(Node i, Node k) { - return getSepsetGreedy(i, k); + return getSepsetGreedyContaining(i, k, null); + } + + /** + * Retrieves a sepset (separating set) between two nodes containing a set of nodes, or null if no such sepset is + * found. If there is no required set of nodes, pass null for the set. + * + * @param i The first node + * @param k The second node + * @param s The set of nodes that must be contained in the sepset, or null if no such set is required. + * @return The sepset between the two nodes + */ + @Override + public Set getSepsetContaining(Node i, Node k, Set s) { + return getSepsetGreedyContaining(i, k, s); } /** * {@inheritDoc} */ public boolean isUnshieldedCollider(Node i, Node j, Node k) { - Set set = getSepsetGreedy(i, k); + Set set = getSepsetGreedyContaining(i, k, null); return set != null && !set.contains(j); } @@ -95,12 +111,26 @@ public boolean isUnshieldedCollider(Node i, Node j, Node k) { * {@inheritDoc} */ @Override - public boolean isIndependent(Node a, Node b, Set c) { - IndependenceResult result = this.independenceTest.checkIndependence(a, b, c); + public boolean isIndependent(Node a, Node b, Set sepset) { + IndependenceResult result = this.independenceTest.checkIndependence(a, b, sepset); this.result = result; return result.isIndependent(); } + /** + * Returns the p-value for the independence test between two nodes, given a set of separator nodes. + * + * @param a the first node + * @param b the second node + * @param sepset the set of separator nodes + * @return the p-value for the independence test + */ + @Override + public double getPValue(Node a, Node b, Set sepset) { + IndependenceResult result = this.independenceTest.checkIndependence(a, b, sepset); + return result.getPValue(); + } + /** * {@inheritDoc} */ @@ -157,7 +187,7 @@ public void setDepth(int depth) { this.depth = depth; } - private Set getSepsetGreedy(Node i, Node k) { + private Set getSepsetGreedyContaining(Node i, Node k, Set s) { if (this.extraSepsets != null) { Set v = this.extraSepsets.get(i, k); @@ -179,6 +209,10 @@ private Set getSepsetGreedy(Node i, Node k) { while ((choice = gen.next()) != null) { Set v = GraphUtils.asSet(choice, adji); + if (s != null && !v.containsAll(s)) { + continue; + } + v = possibleParents(i, v, this.knowledge, k); if (this.independenceTest.checkIndependence(i, k, v).isIndependent()) { @@ -194,6 +228,10 @@ private Set getSepsetGreedy(Node i, Node k) { while ((choice = gen.next()) != null) { Set v = GraphUtils.asSet(choice, adjk); + if (s != null && !v.containsAll(s)) { + continue; + } + v = possibleParents(k, v, this.knowledge, i); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsConservative.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsMaxP.java similarity index 79% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsConservative.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsMaxP.java index dc6e89a290..354422d6a7 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsConservative.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsMaxP.java @@ -45,7 +45,7 @@ * @see SepsetMap * @see Cpc */ -public class SepsetsConservative implements SepsetProducer { +public class SepsetsMaxP implements SepsetProducer { private final Graph graph; private final IndependenceTest independenceTest; private final SepsetMap extraSepsets; @@ -60,7 +60,7 @@ public class SepsetsConservative implements SepsetProducer { * @param extraSepsets a {@link edu.cmu.tetrad.search.utils.SepsetMap} object * @param depth a int */ - public SepsetsConservative(Graph graph, IndependenceTest independenceTest, SepsetMap extraSepsets, int depth) { + public SepsetsMaxP(Graph graph, IndependenceTest independenceTest, SepsetMap extraSepsets, int depth) { this.graph = graph; this.independenceTest = independenceTest; this.extraSepsets = extraSepsets; @@ -68,12 +68,29 @@ public SepsetsConservative(Graph graph, IndependenceTest independenceTest, Sepse } /** - * {@inheritDoc} - *

- * Pick out the sepset from among adj(i) or adj(k) with the highest p value. + * Returns the set of nodes in the sepset between two given nodes, or null if no sepset is found. + * + * @param i the first node + * @param k the second node + * @return a Set of Node objects representing the sepset between the two nodes, or null if no sepset is found. */ public Set getSepset(Node i, Node k) { - double _p = 0.0; + return getSepsetContaining(i, k, null); + } + + /** + * Returns the set of nodes in the sepset between two given nodes containing a given set of separator nodes, or null + * if no sepset is found. If there is no required set of nodes, pass null for the set. + * + * @param i the first node + * @param k the second node + * @param s A set of nodes that must be in the sepset, or null if no such set is required. + * @return a Set of Node objects representing the sepset between the two nodes containing the given set, or null if + * no sepset is found + */ + @Override + public Set getSepsetContaining(Node i, Node k, Set s) { + double _p = -1; Set _v = null; if (this.extraSepsets != null) { @@ -98,6 +115,10 @@ public Set getSepset(Node i, Node k) { while ((choice = gen.next()) != null) { Set v = GraphUtils.asSet(choice, adji); + if (s != null && !v.containsAll(s)) { + continue; + } + IndependenceResult result = getIndependenceTest().checkIndependence(i, k, v); if (result.isIndependent()) { @@ -116,6 +137,11 @@ public Set getSepset(Node i, Node k) { while ((choice = gen.next()) != null) { Set v = GraphUtils.asSet(choice, adjk); + + if (s != null && !v.containsAll(s)) { + continue; + } + IndependenceResult result = getIndependenceTest().checkIndependence(i, k, v); if (result.isIndependent()) { @@ -225,15 +251,34 @@ public List>> getSepsetsLists(Node x, Node y, Node z, /** - * {@inheritDoc} + * Determines if two nodes are independent given a set of separator nodes. + * + * @param a A {@link Node} object representing the first node. + * @param b A {@link Node} object representing the second node. + * @param sepset A {@link Set} object representing the set of separator nodes. + * @return True if the nodes are independent, false otherwise. */ @Override - public boolean isIndependent(Node a, Node b, Set c) { - IndependenceResult result = this.independenceTest.checkIndependence(a, b, c); + public boolean isIndependent(Node a, Node b, Set sepset) { + IndependenceResult result = this.independenceTest.checkIndependence(a, b, sepset); this.lastResult = result; return result.isIndependent(); } + /** + * Returns the p-value for the independence test between two nodes, given a set of separator nodes. + * + * @param a the first node + * @param b the second node + * @param sepset the set of separator nodes + * @return the p-value for the independence test + */ + @Override + public double getPValue(Node a, Node b, Set sepset) { + IndependenceResult result = this.independenceTest.checkIndependence(a, b, sepset); + return result.getPValue(); + } + /** * {@inheritDoc} */ diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsMinP.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsMinP.java new file mode 100644 index 0000000000..eb453302c9 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsMinP.java @@ -0,0 +1,314 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetrad.search.utils; + +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.GraphUtils; +import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.search.Cpc; +import edu.cmu.tetrad.search.IndependenceTest; +import edu.cmu.tetrad.search.test.IndependenceResult; +import edu.cmu.tetrad.util.ChoiceGenerator; +import org.apache.commons.math3.util.FastMath; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; + +/** + *

Provides a SepsetProcuder that selects the first sepset it comes to from + * among the extra sepsets or the adjacents of i or k, or null if none is found. This version uses conservative + * reasoning (see the CPC algorithm).

+ * + * @author josephramsey + * @version $Id: $Id + * @see SepsetProducer + * @see SepsetMap + * @see Cpc + */ +public class SepsetsMinP implements SepsetProducer { + private final Graph graph; + private final IndependenceTest independenceTest; + private final SepsetMap extraSepsets; + private final int depth; + private IndependenceResult lastResult; + + /** + *

Constructor for SepsetsConservative.

+ * + * @param graph a {@link Graph} object + * @param independenceTest a {@link IndependenceTest} object + * @param extraSepsets a {@link SepsetMap} object + * @param depth a int + */ + public SepsetsMinP(Graph graph, IndependenceTest independenceTest, SepsetMap extraSepsets, int depth) { + this.graph = graph; + this.independenceTest = independenceTest; + this.extraSepsets = extraSepsets; + this.depth = depth; + } + + /** + * Returns the set of nodes that form the sepset (separating set) between two given nodes. + * + * @param i a {@link Node} object representing the first node. + * @param k a {@link Node} object representing the second node. + * @return a {@link Set} of nodes that form the sepset between the two given nodes. + */ + public Set getSepset(Node i, Node k) { + return getSepsetContaining(i, k, null); + } + + /** + * Returns the set of nodes that form the sepset (separating set) between two given nodes containing all the + * nodes in the given set. If there is no required set of nodes to include, pass null for s. + * + * @param i a {@link Node} object representing the first node. + * @param k a {@link Node} object representing the second node. + * @param s a {@link Set} of nodes to that must be included in the sepset, or null if there is no such requirement. + * @return a {@link Set} of nodes that form the sepset between the two given nodes. + */ + @Override + public Set getSepsetContaining(Node i, Node k, Set s) { + double _p = 2; + Set _v = null; + + if (this.extraSepsets != null) { + Set possibleMsep = this.extraSepsets.get(i, k); + if (possibleMsep != null) { + IndependenceResult result = this.independenceTest.checkIndependence(i, k, possibleMsep); + _p = result.getPValue(); + _v = possibleMsep; + } + } + + List adji = new ArrayList<>(this.graph.getAdjacentNodes(i)); + List adjk = new ArrayList<>(this.graph.getAdjacentNodes(k)); + adji.remove(k); + adjk.remove(i); + + for (int d = 0; d <= FastMath.min((this.depth == -1 ? 1000 : this.depth), FastMath.max(adji.size(), adjk.size())); d++) { + if (d <= adji.size()) { + ChoiceGenerator gen = new ChoiceGenerator(adji.size(), d); + int[] choice; + + while ((choice = gen.next()) != null) { + Set v = GraphUtils.asSet(choice, adji); + + if (s != null && v.containsAll(s)) { + continue; + } + + IndependenceResult result = getIndependenceTest().checkIndependence(i, k, v); + + if (result.isIndependent()) { + double pValue = result.getPValue(); + if (pValue < _p) { + _p = pValue; + _v = v; + } + } + } + } + + if (d <= adjk.size()) { + ChoiceGenerator gen = new ChoiceGenerator(adjk.size(), d); + int[] choice; + + while ((choice = gen.next()) != null) { + Set v = GraphUtils.asSet(choice, adjk); + + if (s != null && v.containsAll(s)) { + continue; + } + + IndependenceResult result = getIndependenceTest().checkIndependence(i, k, v); + + if (result.isIndependent()) { + double pValue = result.getPValue(); + if (pValue < _p) { + _p = pValue; + _v = v; + } + } + } + } + } + + return _v; + + } + + /** + * {@inheritDoc} + */ + public boolean isUnshieldedCollider(Node i, Node j, Node k) { + List>> ret = getSepsetsLists(i, j, k, this.independenceTest, this.depth, true); + return ret.get(0).isEmpty(); + } + + // The published version. + + /** + *

getSepsetsLists.

+ * + * @param x a {@link Node} object + * @param y a {@link Node} object + * @param z a {@link Node} object + * @param test a {@link IndependenceTest} object + * @param depth a int + * @param verbose a boolean + * @return a {@link List} object + */ + public List>> getSepsetsLists(Node x, Node y, Node z, + IndependenceTest test, int depth, + boolean verbose) { + List> sepsetsContainingY = new ArrayList<>(); + List> sepsetsNotContainingY = new ArrayList<>(); + + List _nodes = new ArrayList<>(this.graph.getAdjacentNodes(x)); + _nodes.remove(z); + + int _depth = depth; + if (_depth == -1) { + _depth = 1000; + } + + _depth = FastMath.min(_depth, _nodes.size()); + + for (int d = 0; d <= _depth; d++) { + ChoiceGenerator cg = new ChoiceGenerator(_nodes.size(), d); + int[] choice; + + while ((choice = cg.next()) != null) { + Set cond = GraphUtils.asSet(choice, _nodes); + + if (test.checkIndependence(x, z, cond).isIndependent()) { + if (verbose) { + System.out.println("Indep: " + x + " _||_ " + z + " | " + cond); + } + + if (cond.contains(y)) { + sepsetsContainingY.add(cond); + } else { + sepsetsNotContainingY.add(cond); + } + } + } + } + + _nodes = new ArrayList<>(this.graph.getAdjacentNodes(z)); + _nodes.remove(x); + + _depth = depth; + if (_depth == -1) { + _depth = 1000; + } + _depth = FastMath.min(_depth, _nodes.size()); + + for (int d = 0; d <= _depth; d++) { + ChoiceGenerator cg = new ChoiceGenerator(_nodes.size(), d); + int[] choice; + + while ((choice = cg.next()) != null) { + Set cond = GraphUtils.asSet(choice, _nodes); + + if (test.checkIndependence(x, z, cond).isIndependent()) { + if (cond.contains(y)) { + sepsetsContainingY.add(cond); + } else { + sepsetsNotContainingY.add(cond); + } + } + } + } + + List>> ret = new ArrayList<>(); + ret.add(sepsetsContainingY); + ret.add(sepsetsNotContainingY); + + return ret; + } + + + /** + * Determines if two nodes are independent given a set of separator nodes. + * + * @param a A {@link Node} object representing the first node. + * @param b A {@link Node} object representing the second node. + * @param sepset A {@link Set} object representing the set of separator nodes. + * @return True if the nodes are independent, false otherwise. + */ + @Override + public boolean isIndependent(Node a, Node b, Set sepset) { + IndependenceResult result = this.independenceTest.checkIndependence(a, b, sepset); + this.lastResult = result; + return result.isIndependent(); + } + + /** + * Returns the p-value for the independence test between two nodes, given a set of separator nodes. + * + * @param a the first node + * @param b the second node + * @param sepset the set of separator nodes + * @return the p-value for the independence test + */ + @Override + public double getPValue(Node a, Node b, Set sepset) { + IndependenceResult result = this.independenceTest.checkIndependence(a, b, sepset); + return result.getPValue(); + } + + /** + * {@inheritDoc} + */ + @Override + public double getScore() { + return -(this.lastResult.getPValue() - this.independenceTest.getAlpha()); + } + + /** + * {@inheritDoc} + */ + @Override + public List getVariables() { + return this.independenceTest.getVariables(); + } + + /** + * {@inheritDoc} + */ + @Override + public void setVerbose(boolean verbose) { + } + + /** + *

Getter for the field independenceTest.

+ * + * @return a {@link IndependenceTest} object + */ + public IndependenceTest getIndependenceTest() { + return this.independenceTest; + } +} + diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsPossibleMsep.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsPossibleMsep.java index 6c2241edc1..73465c767c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsPossibleMsep.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsPossibleMsep.java @@ -71,15 +71,38 @@ public SepsetsPossibleMsep(Graph graph, IndependenceTest test, Knowledge knowled } /** - * {@inheritDoc} - *

- * Pick out the sepset from among adj(i) or adj(k) with the highest p value. + * Retrieves the separation set (sepset) between two nodes. + * + * @param i The first node + * @param k The second node + * @return The set of nodes that form the sepset between node i and node k, or null if no sepset exists */ public Set getSepset(Node i, Node k) { - Set condSet = getCondSet(i, k, this.maxPathLength); + Set condSet = getCondSetContaining(i, k, null, this.maxPathLength); if (condSet == null) { - condSet = getCondSet(k, i, this.maxPathLength); + condSet = getCondSetContaining(k, i, null, this.maxPathLength); + } + + return condSet; + } + + /** + * Retrieves the separation set (sepset) between two nodes i and k that contains a given set of nodes s. If there + * is no required set of nodes, pass null for the set. + * + * @param i The first node + * @param k The second node + * @param s The set of nodes to be contained in the sepset + * @return The set of nodes that form the sepset between node i and node k and contains all nodes from set s, + * or null if no sepset exists + */ + @Override + public Set getSepsetContaining(Node i, Node k, Set s) { + Set condSet = getCondSetContaining(i, k, s, this.maxPathLength); + + if (condSet == null) { + condSet = getCondSetContaining(k, i, s, this.maxPathLength); } return condSet; @@ -129,12 +152,26 @@ public void setVerbose(boolean verbose) { * {@inheritDoc} */ @Override - public boolean isIndependent(Node d, Node c, Set path) { - IndependenceResult result = this.test.checkIndependence(d, c, path); + public boolean isIndependent(Node d, Node c, Set sepset) { + IndependenceResult result = this.test.checkIndependence(d, c, sepset); return result.isIndependent(); } - private Set getCondSet(Node node1, Node node2, int maxPathLength) { + /** + * Returns the p-value for the independence test between two nodes, given a set of separator nodes. + * + * @param a the first node + * @param b the second node + * @param sepset the set of separator nodes + * @return the p-value for the independence test + */ + @Override + public double getPValue(Node a, Node b, Set sepset) { + IndependenceResult result = this.test.checkIndependence(a, b, sepset); + return result.getPValue(); + } + + private Set getCondSetContaining(Node node1, Node node2, Set s, int maxPathLength) { List possibleMsepSet = getPossibleMsep(node1, node2, maxPathLength); List possibleMsep = new ArrayList<>(possibleMsepSet); boolean noEdgeRequired = this.knowledge.noEdgeRequired(node1.getName(), node2.getName()); @@ -154,6 +191,10 @@ private Set getCondSet(Node node1, Node node2, int maxPathLength) { Set condSet = GraphUtils.asSet(choice, possibleMsep); + if (s != null && !condSet.containsAll(s)) { + continue; + } + // check against bk knowledge added by DMalinsky 07/24/17 **/ // if (knowledge.isForbidden(node1.getName(), node2.getName())) continue; boolean flagForbid = false; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsSet.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsSet.java index c85f51c1d9..daf34f3351 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsSet.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/SepsetsSet.java @@ -55,14 +55,46 @@ public SepsetsSet(SepsetMap sepsets, IndependenceTest test) { } /** - * {@inheritDoc} + * Retrieves the sepset between two nodes. + * + * @param a the first node + * @param b the second node + * @return the set of nodes in the sepset between a and b */ @Override public Set getSepset(Node a, Node b) { - //isIndependent(a, b, sepsets.get(a, b)); return this.sepsets.get(a, b); } + /** + * Retrieves the sepset for a and b, where we are expecting this sepset to contain all the nodes in s. + * + * @param a the first node + * @param b the second node + * @param s the set of nodes to check in the sepset of a and b + * @return the set of nodes that the sepset of a and b is expected to contain. + * @throws IllegalArgumentException if the sepset of a and b does not contain all the nodes in s + */ + @Override + public Set getSepsetContaining(Node a, Node b, Set s) { + Set sepset = this.sepsets.get(a, b); + + if (sepset != null && !sepset.containsAll(s)) { + throw new IllegalArgumentException("Was expecting the sepset of " + a + " and " + b + " (" + sepset + + ") to contain all the sepset in " + s + "."); + } + + return sepset; + } + + /** + * @throws UnsupportedOperationException if this method is called + */ + @Override + public double getPValue(Node a, Node b, Set sepset) { + throw new UnsupportedOperationException("This makes no sense for this subclass."); + } + /** * {@inheritDoc} */ @@ -77,8 +109,8 @@ public boolean isUnshieldedCollider(Node i, Node j, Node k) { * {@inheritDoc} */ @Override - public boolean isIndependent(Node a, Node b, Set c) { - IndependenceResult result = this.test.checkIndependence(a, b, c); + public boolean isIndependent(Node a, Node b, Set sepset) { + IndependenceResult result = this.test.checkIndependence(a, b, sepset); this.result = result; return result.isIndependent(); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/TeyssierScorer.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/TeyssierScorer.java index dcff59f8e6..35d92a5430 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/TeyssierScorer.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/utils/TeyssierScorer.java @@ -156,20 +156,26 @@ public void swaptuck(Node x, Node y) { } /** - *

tuck.

+ * Moves j to before k and moves all the ancestors of j betwween k and j to before k. * - * @param k a {@link edu.cmu.tetrad.graph.Node} object - * @param j a int - * @return a boolean + * @param j The node to tuck. + * @param k The node to tuck j before. + * @return true if the tuck made a change. */ - public boolean tuck(Node k, int j) { - if (adjacent(k, get(j))) return false; - if (j >= index(k)) return false; + public boolean tuck(Node j, Node k) { + int jIndex = index(j); + int kIndex = index(k); + + if (jIndex < kIndex) { + return false; + } + + Set ancestors = getAncestors(j); + int _kIndex = kIndex; - Set ancestors = getAncestors(k); - for (int i = j + 1; i <= index(k); i++) { + for (int i = jIndex; i > kIndex; i--) { if (ancestors.contains(get(i))) { - moveTo(get(i), j++); + moveTo(get(i), _kIndex++); } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/Dci.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/Dci.java index fbf672b460..1d78c15329 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/Dci.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/Dci.java @@ -710,7 +710,7 @@ private void awayFromCycle(Graph graph, Node a, Node b, Node c) { if ((graph.isAdjacentTo(a, c)) && (graph.getEndpoint(a, c) == Endpoint.ARROW) && (graph.getEndpoint(c, a) == Endpoint.CIRCLE)) { - if (graph.paths().isDirectedFromTo(a, b) && graph.paths().isDirectedFromTo(b, c)) { + if (graph.paths().isDirected(a, b) && graph.paths().isDirected(b, c)) { graph.setEndpoint(c, a, Endpoint.TAIL); this.changeFlag = true; } @@ -1744,7 +1744,7 @@ private boolean predictsFalseDependence(Graph graph) { continue; } for (Set condSet : sepset.getSet(x, y)) { - if (!graph.paths().isMSeparatedFrom(x, y, condSet)) { + if (!graph.paths().isMSeparatedFrom(x, y, condSet, false)) { return true; } } @@ -1864,7 +1864,7 @@ private void resolveResultingIndependenciesB() { System.out.println("Resolving inconsistencies... " + c + " of " + cs + " (" + p + " of " + pairs.size() + " pairs)"); c++; Set z = new HashSet<>(set); - if (allInd.paths().isMConnectedTo(pair.getFirst(), pair.getSecond(), z)) { + if (allInd.paths().isMConnectedTo(pair.getFirst(), pair.getSecond(), z, false)) { continue; } combinedSepset.set(pair.getFirst(), pair.getSecond(), new HashSet<>(set)); @@ -1937,7 +1937,7 @@ private void resolveResultingIndependenciesC() { for (Set inpset : pset) { Set cond = new HashSet<>(inpset); cond.add(node); - if (fciResult.paths().isMSeparatedFrom(x, y, cond)) { + if (fciResult.paths().isMSeparatedFrom(x, y, cond, false)) { newSepset.set(x, y, cond); } } @@ -1969,7 +1969,7 @@ private void doSepsetClosure(SepsetMapDci sepset, Graph graph) { int ps = (int) FastMath.pow(2, possibleNodes.size()); for (Set condSet : new PowerSet<>(possibleNodes)) { System.out.println("Getting closure set... " + c + " of " + ps + "(" + p + " of " + pairs.size() + " remaining)"); - if (graph.paths().isMSeparatedFrom(x, y, new HashSet<>(condSet))) { + if (graph.paths().isMSeparatedFrom(x, y, new HashSet<>(condSet), false)) { sepset.set(x, y, new HashSet<>(condSet)); } c++; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/HbsmsBeam.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/HbsmsBeam.java index d08554d8a2..1cd1482cde 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/HbsmsBeam.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/HbsmsBeam.java @@ -161,7 +161,7 @@ public Graph search() { if (this.trueModel != null) { this.trueModel = GraphUtils.replaceNodes(this.trueModel, bestGraph.getNodes()); - this.trueModel = GraphTransforms.cpdagForDag(this.trueModel); + this.trueModel = GraphTransforms.dagToCpdag(this.trueModel); } System.out.println("Initial Score = " + this.nf.format(bestScore)); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/HbsmsGes.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/HbsmsGes.java index 784225f733..d2e2cc66c4 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/HbsmsGes.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/HbsmsGes.java @@ -93,7 +93,7 @@ public HbsmsGes(Graph graph, DataSet data) { DagInCpcagIterator iterator = new DagInCpcagIterator(graph, getKnowledge(), allowArbitraryOrientations, allowNewColliders); graph = iterator.next(); - graph = GraphTransforms.cpdagForDag(graph); + graph = GraphTransforms.dagToCpdag(graph); if (GraphUtils.containsBidirectedEdge(graph)) { throw new IllegalArgumentException("Contains bidirected edge."); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/Ion.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/Ion.java index a6f73eb78a..3c56b4c69e 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/Ion.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/work_in_progress/Ion.java @@ -871,7 +871,7 @@ private List> findSepAndAssoc(Graph graph) { for (Node node : subset) { pagSubset.add(pag.getNode(node.getName())); } - if (pag.paths().isMSeparatedFrom(pagX, pagY, new HashSet<>(pagSubset))) { + if (pag.paths().isMSeparatedFrom(pagX, pagY, new HashSet<>(pagSubset), false)) { if (!pag.isAdjacentTo(pagX, pagY)) { addIndep = true; indep.addMoreZ(new HashSet<>(subset)); @@ -918,7 +918,7 @@ private boolean predictsFalseIndependence(Set associations for (IonIndependenceFacts assocFact : associations) for (Set conditioningSet : assocFact.getZ()) if (pag.paths().isMSeparatedFrom( - assocFact.getX(), assocFact.getY(), conditioningSet)) + assocFact.getX(), assocFact.getY(), conditioningSet, false)) return true; return false; } @@ -1355,7 +1355,7 @@ private void awayFromCycle(Graph graph, Node a, Node b, Node c) { if ((graph.isAdjacentTo(a, c)) && (graph.getEndpoint(a, c) == Endpoint.ARROW) && (graph.getEndpoint(c, a) == Endpoint.CIRCLE)) { - if (graph.paths().isDirectedFromTo(a, b) && graph.paths().isDirectedFromTo(b, c)) { + if (graph.paths().isDirected(a, b) && graph.paths().isDirected(b, c)) { graph.setEndpoint(c, a, Endpoint.TAIL); this.changeFlag = true; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceRandom.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceRandom.java index d86d33f45b..7ec08268a4 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceRandom.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/GdistanceRandom.java @@ -67,8 +67,8 @@ private List randomPairSimulation() { //convert those dags to CPDAGs if (this.verbose) System.out.println("converting dags to CPDAGs"); - Graph graph1 = GraphTransforms.cpdagForDag(dag1); - Graph graph2 = GraphTransforms.cpdagForDag(dag2); + Graph graph1 = GraphTransforms.dagToCpdag(dag1); + Graph graph2 = GraphTransforms.dagToCpdag(dag2); //run Gdistance on these two graphs if (this.verbose) System.out.println("running Gdistance on the CPDAGs"); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/RBExperiments.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/RBExperiments.java index d317c4b64a..efda98a00d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/RBExperiments.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/RBExperiments.java @@ -511,7 +511,7 @@ private double getLnProbUsingDepFiltering(Graph pag, Map H) { for (IndependenceFact fact : H.keySet()) { BCInference.OP op; - if (pag.paths().isMSeparatedFrom(fact.getX(), fact.getY(), fact.getZ())) { + if (pag.paths().isMSeparatedFrom(fact.getX(), fact.getY(), fact.getZ(), false)) { op = BCInference.OP.independent; } else { op = BCInference.OP.dependent; diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison.java index 28fb2e581d..aad02d7153 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison.java @@ -198,20 +198,20 @@ public static ComparisonResult compare(ComparisonParameters params) { Pc search = new Pc(test); result.setResultGraph(search.search()); Graph dag = new EdgeListGraph(trueDag); - result.setCorrectResult(GraphTransforms.cpdagForDag(dag)); + result.setCorrectResult(GraphTransforms.dagToCpdag(dag)); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.CPC) { if (test == null) throw new IllegalArgumentException("Test not set."); Cpc search = new Cpc(test); result.setResultGraph(search.search()); Graph dag = new EdgeListGraph(trueDag); - result.setCorrectResult(GraphTransforms.cpdagForDag(dag)); + result.setCorrectResult(GraphTransforms.dagToCpdag(dag)); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES) { if (score == null) throw new IllegalArgumentException("Score not set."); Fges search = new Fges(score); search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed()); result.setResultGraph(search.search()); Graph dag = new EdgeListGraph(trueDag); - result.setCorrectResult(GraphTransforms.cpdagForDag(dag)); + result.setCorrectResult(GraphTransforms.dagToCpdag(dag)); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FCI) { if (test == null) throw new IllegalArgumentException("Test not set."); Fci search = new Fci(test); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison2.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison2.java index aa054ecdef..bce1afa787 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison2.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison2.java @@ -152,18 +152,18 @@ public static ComparisonResult compare(ComparisonParameters params) { Pc search = new Pc(test); result.setResultGraph(search.search()); Graph dag = new EdgeListGraph(trueDag); - result.setCorrectResult(GraphTransforms.cpdagForDag(dag)); + result.setCorrectResult(GraphTransforms.dagToCpdag(dag)); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.CPC) { Cpc search = new Cpc(test); result.setResultGraph(search.search()); Graph dag = new EdgeListGraph(trueDag); - result.setCorrectResult(GraphTransforms.cpdagForDag(dag)); + result.setCorrectResult(GraphTransforms.dagToCpdag(dag)); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES) { Fges search = new Fges(score); //search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed()); result.setResultGraph(search.search()); Graph dag = new EdgeListGraph(trueDag); - result.setCorrectResult(GraphTransforms.cpdagForDag(dag)); + result.setCorrectResult(GraphTransforms.dagToCpdag(dag)); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FCI) { Fci search = new Fci(test); result.setResultGraph(search.search()); @@ -390,7 +390,7 @@ public static ComparisonResult compare(ComparisonParameters params) { Pc search = new Pc(test); result.setResultGraph(search.search()); Graph dag = new EdgeListGraph(trueDag); - result.setCorrectResult(GraphTransforms.cpdagForDag(dag)); + result.setCorrectResult(GraphTransforms.dagToCpdag(dag)); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.CPC) { if (test == null) { throw new IllegalArgumentException("Test not set."); @@ -398,7 +398,7 @@ public static ComparisonResult compare(ComparisonParameters params) { Cpc search = new Cpc(test); result.setResultGraph(search.search()); Graph dag = new EdgeListGraph(trueDag); - result.setCorrectResult(GraphTransforms.cpdagForDag(dag)); + result.setCorrectResult(GraphTransforms.dagToCpdag(dag)); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FGES) { if (score == null) { throw new IllegalArgumentException("Score not set."); @@ -407,7 +407,7 @@ public static ComparisonResult compare(ComparisonParameters params) { //search.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed()); result.setResultGraph(search.search()); Graph dag = new EdgeListGraph(trueDag); - result.setCorrectResult(GraphTransforms.cpdagForDag(dag)); + result.setCorrectResult(GraphTransforms.dagToCpdag(dag)); } else if (params.getAlgorithm() == ComparisonParameters.Algorithm.FCI) { if (test == null) { throw new IllegalArgumentException("Test not set."); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/PerformanceTests.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/PerformanceTests.java index 66d4e42b74..f84aaabe31 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/PerformanceTests.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/PerformanceTests.java @@ -317,7 +317,7 @@ public void testPc(int numVars, double edgeFactor, int numCases, double alpha) { this.out.println("Total elapsed (cov + PC-Stable) " + (time4 - time2) + " ms"); - GraphSearchUtils.graphComparison(GraphTransforms.cpdagForDag(graph), outGraph, this.out); + GraphSearchUtils.graphComparison(GraphTransforms.dagToCpdag(graph), outGraph, this.out); this.out.close(); } @@ -439,7 +439,7 @@ public void testPcStable(int numVars, double edgeFactor, int numCases, double al this.out.println("Total elapsed (cov + PC-Stable) " + (time4 - time2) + " ms"); - Graph trueCPDAG = GraphTransforms.cpdagForDag(dag); + Graph trueCPDAG = GraphTransforms.dagToCpdag(dag); System.out.println("# edges in true CPDAG = " + trueCPDAG.getNumEdges()); System.out.println("# edges in est CPDAG = " + estCPDAG.getNumEdges()); @@ -510,7 +510,7 @@ public void testFges(int numVars, double edgeFactor, int numCases, double penalt this.out.println("Total elapsed (cov + FGES) " + (time4 - time2) + " ms"); - Graph trueCPDAG = GraphTransforms.cpdagForDag(dag); + Graph trueCPDAG = GraphTransforms.dagToCpdag(dag); System.out.println("# edges in true CPDAG = " + trueCPDAG.getNumEdges()); System.out.println("# edges in est CPDAG = " + estCPDAG.getNumEdges()); @@ -605,7 +605,7 @@ public void testCpc(int numVars, double edgeFactor, int numCases) { this.out.println("Total elapsed (cov + PC-Stable) " + (time4 - time2) + " ms"); - GraphSearchUtils.graphComparison(GraphTransforms.cpdagForDag(graph), outGraph, this.out); + GraphSearchUtils.graphComparison(GraphTransforms.dagToCpdag(graph), outGraph, this.out); this.out.close(); } @@ -684,7 +684,7 @@ public void testCpcStable(int numVars, double edgeFactor, int numCases, double a this.out.println("Total elapsed (cov + CPC-Stable) " + (time4 - time2) + " ms"); - Graph trueCPDAG = GraphTransforms.cpdagForDag(graph); + Graph trueCPDAG = GraphTransforms.dagToCpdag(graph); GraphSearchUtils.graphComparison(trueCPDAG, outGraph, this.out); @@ -951,7 +951,7 @@ private void testFges(int numVars, double edgeFactor, int numCases, int numRuns, System.out.println("Calculating CPDAG for DAG"); - Graph CPDAG = GraphTransforms.cpdagForDag(dag); + Graph CPDAG = GraphTransforms.dagToCpdag(dag); List vars = dag.getNodes(); @@ -1171,7 +1171,7 @@ private void testFgesMb(int numVars, double edgeFactor, int numCases, int numRun System.out.println("Calculating CPDAG for DAG"); - Graph CPDAG = GraphTransforms.cpdagForDag(dag); + Graph CPDAG = GraphTransforms.dagToCpdag(dag); int[] tiers = new int[dag.getNumNodes()]; @@ -1598,7 +1598,7 @@ public void testCompareDagToCPDAG(int numLatents) { System.out.println("PC graph = " + left); - Graph top = GraphTransforms.cpdagForDag(dag); + Graph top = GraphTransforms.dagToCpdag(dag); System.out.println("DAG to CPDAG graph = " + top); @@ -1656,7 +1656,7 @@ public void testComparePcVersions(int numVars, double edgeFactor, int numLatents System.out.println("Graph done"); - Graph left = GraphTransforms.cpdagForDag(dag);// pc1.search(); + Graph left = GraphTransforms.dagToCpdag(dag);// pc1.search(); System.out.println("First FAS graph = " + left); @@ -1813,8 +1813,8 @@ private void bidirectedComparison(Graph dag, Graph truePag, Graph estGraph, Set< boolean existsCommonCause = false; for (Node latent : missingNodes) { - if (dag.paths().existsDirectedPathFromTo(latent, edge.getNode1()) - && dag.paths().existsDirectedPathFromTo(latent, edge.getNode2())) { + if (dag.paths().existsDirectedPath(latent, edge.getNode1()) + && dag.paths().existsDirectedPath(latent, edge.getNode2())) { existsCommonCause = true; break; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/ChoiceGenerator.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/ChoiceGenerator.java index e99a5dbad9..fdb6f953df 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/ChoiceGenerator.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/ChoiceGenerator.java @@ -66,7 +66,6 @@ public final class ChoiceGenerator { */ public ChoiceGenerator(int a, int b) { if (a < 0 || b < 0) throw new IllegalArgumentException("ERROR: a and b must be non-negative"); - if (b > a) b = a; this.a = a; this.b = b; @@ -148,6 +147,10 @@ public static double logCombinations(int a, int b) { * @return the next combination in the series, or null if the series is finished. */ public synchronized int[] next() { + if (a < b) { + return null; + } + int i = getB(); // Scan from the right for the first index whose value is less than diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/GraphSampling.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/GraphSampling.java index 42a1cdb087..6aa964874f 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/GraphSampling.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/GraphSampling.java @@ -93,10 +93,10 @@ public static Graph createGraphWithHighProbabilityEdges(List graphs, Resa * @return graph containing edges with edge type of the highest probability */ public static Graph createGraphWithHighProbabilityEdges(List graphs) { - // filter out null graphs and add PAG coloring + // filter out null graphs and add PAG edge specializstion markup graphs = graphs.stream() .filter(Objects::nonNull) - .map(GraphSampling::addPagColorings) + .map(GraphSampling::addEdgeSpecializationMarkups) .collect(Collectors.toList()); if (graphs.isEmpty()) { @@ -332,8 +332,8 @@ private static Graph createNewGraph(List graphNodes) { return new EdgeListGraph(Arrays.asList(nodes)); } - private static Graph addPagColorings(Graph graph) { - GraphUtils.addPagColoring(graph); + private static Graph addEdgeSpecializationMarkups(Graph graph) { + GraphUtils.addEdgeSpecializationMarkup(graph); return graph; } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/Params.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/Params.java index 02bb34abdf..5cbbe56cbd 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/Params.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/Params.java @@ -886,6 +886,12 @@ public final class Params { * Constant USE_PSEUDOINVERSE_FOR_LATENT="usePseudoinverseForLatent" */ public static final String COMPARE_GRAPH_ALGCOMP = "compareGraphAlgcomp"; + + /** + * Constant THRESHOLD_LV_LITE = "thresholdLvLite" + */ + public static final String THRESHOLD_LV_LITE = "thresholdLvLite"; + // All parameters that are found in HTML manual documentation private static final Set ALL_PARAMS_IN_HTML_MANUAL = new HashSet<>(Arrays.asList( Params.ADD_ORIGINAL_DATASET, Params.ALPHA, Params.APPLY_R1, Params.AVG_DEGREE, Params.BASIS_TYPE, @@ -941,6 +947,10 @@ public final class Params { * Constant PC_HEURISTIC="pcHeuristic" */ public static String PC_HEURISTIC = "pcHeuristic"; + /** + * Constant RESOLVE_ALMOST_CYCLIC_PATHS="resolveAlmostCyclicPaths" + */ + public static String RESOLVE_ALMOST_CYCLIC_PATHS = "resolveAlmostCyclicPaths"; private Params() { } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/SublistGenerator.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/SublistGenerator.java index 6f9bdafedc..e67fb5fba5 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/util/SublistGenerator.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/util/SublistGenerator.java @@ -211,7 +211,7 @@ public synchronized int[] next() { * @return a {@link java.lang.String} object */ public String toString() { - return "Depth choice generator: a = " + this.a + " depth = " + this.depth; + return "Sublist generator: a = " + this.a + " depth = " + this.depth; } /** diff --git a/tetrad-lib/src/main/java/edu/pitt/csb/mgm/ExploreIndepTests.java b/tetrad-lib/src/main/java/edu/pitt/csb/mgm/ExploreIndepTests.java index 1d6ce5d18f..1abfbd318e 100644 --- a/tetrad-lib/src/main/java/edu/pitt/csb/mgm/ExploreIndepTests.java +++ b/tetrad-lib/src/main/java/edu/pitt/csb/mgm/ExploreIndepTests.java @@ -21,7 +21,6 @@ package edu.pitt.csb.mgm; -import edu.cmu.tetrad.algcomparison.algorithm.ExternalAlgorithm; import edu.cmu.tetrad.data.DataSet; import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.GraphSaveLoadUtils; @@ -56,7 +55,7 @@ public static void main(String[] args) { try { String path = ExampleMixedSearch.class.getResource("test_data").getPath(); Graph dag3 = GraphSaveLoadUtils.loadGraphTxt(new File(path, "DAG_0_graph.txt")); - Graph trueGraph = GraphTransforms.cpdagForDag(dag3); + Graph trueGraph = GraphTransforms.dagToCpdag(dag3); DataSet ds = MixedUtils.loadDataSet(path, "DAG_0_data.txt"); IndTestMultinomialLogisticRegression indMix = new IndTestMultinomialLogisticRegression(ds, .05); @@ -73,17 +72,17 @@ public static void main(String[] args) { long time = MillisecondTimes.timeMillis(); Graph dag2 = s1.search(); - Graph g1 = GraphTransforms.cpdagForDag(dag2); + Graph g1 = GraphTransforms.dagToCpdag(dag2); System.out.println("Mix Time " + ((MillisecondTimes.timeMillis() - time) / 1000.0)); time = MillisecondTimes.timeMillis(); Graph dag1 = s2.search(); - Graph g2 = GraphTransforms.cpdagForDag(dag1); + Graph g2 = GraphTransforms.dagToCpdag(dag1); System.out.println("Wald lin Time " + ((MillisecondTimes.timeMillis() - time) / 1000.0)); time = MillisecondTimes.timeMillis(); Graph dag = s3.search(); - Graph g3 = GraphTransforms.cpdagForDag(dag); + Graph g3 = GraphTransforms.dagToCpdag(dag); System.out.println("Wald log Time " + ((MillisecondTimes.timeMillis() - time) / 1000.0)); System.out.println(MixedUtils.EdgeStatHeader); diff --git a/tetrad-lib/src/main/java/edu/pitt/dbmi/algo/bayesian/constraint/search/RfciBsc.java b/tetrad-lib/src/main/java/edu/pitt/dbmi/algo/bayesian/constraint/search/RfciBsc.java index 9cc8fac898..d3f1ff5b83 100644 --- a/tetrad-lib/src/main/java/edu/pitt/dbmi/algo/bayesian/constraint/search/RfciBsc.java +++ b/tetrad-lib/src/main/java/edu/pitt/dbmi/algo/bayesian/constraint/search/RfciBsc.java @@ -114,7 +114,7 @@ private static double getLnProbUsingDepFiltering(Graph pag, Map H) { for (IndependenceFact fact : H.keySet()) { BCInference.OP op; - if (pag.paths().isMSeparatedFrom(fact.getX(), fact.getY(), fact.getZ())) { + if (pag.paths().isMSeparatedFrom(fact.getX(), fact.getY(), fact.getZ(), false)) { op = BCInference.OP.independent; } else { op = BCInference.OP.dependent; diff --git a/tetrad-lib/src/main/resources/docs/javahelp/TetradHelp.hs b/tetrad-lib/src/main/resources/docs/javahelp/TetradHelp.hs index 312e004b97..30fd246631 100644 --- a/tetrad-lib/src/main/resources/docs/javahelp/TetradHelp.hs +++ b/tetrad-lib/src/main/resources/docs/javahelp/TetradHelp.hs @@ -64,14 +64,14 @@ --> main window - + Project Tetrad Help - - javax.help.BackAction - javax.help.ForwardAction - javax.help.HomeAction - +-- +-- javax.help.BackAction +-- javax.help.ForwardAction +-- javax.help.HomeAction +-- + + + + + + + + + diff --git a/tetrad-lib/src/main/resources/docs/manual/index.html b/tetrad-lib/src/main/resources/docs/manual/index.html index bbb70db4b4..00efa54ff8 100755 --- a/tetrad-lib/src/main/resources/docs/manual/index.html +++ b/tetrad-lib/src/main/resources/docs/manual/index.html @@ -299,17 +299,18 @@

Display Subgraphs

-

Choose DAG in CPDAG

+

Choose Random DAG in CPDAG

If given a CPDAG as input, this chooses a random DAG from the Markov equivalence class of the CPDAG to display. The resulting DAG functions as a normal graph box.

-

Choose MAG in PAG

+

Choose Zhang MAG in PAG

If given a partial ancestral graph (PAG) as input, this chooses a - random mixed ancestral graph (MAG) from the equivalence class of the PAG - to display. The resulting MAG functions as a normal graph box.

+ mixed ancestral graph (MAG) from the equivalence class of the PAG + to display using Zhang's method. The resulting MAG functions as a + normal graph box.

Show DAGs in CPDAG

@@ -4863,6 +4864,25 @@

Zhang-Shen Bound Score

Double +

thresholdLvLite

+
    +
  • Short Description: Score threshold for judging model score equality
  • +
  • Long Description: Score threshold for judging model score equality +
  • +
  • Default Value: 0.01
  • +
  • Lower + Bound: 0
  • +
  • Upper Bound: Infinity
  • +
  • Value Type: + Double
  • +
+

addOriginalDataset

    coefLow Boolean
+

resolveAlmostCyclicPaths

+
    +
  • Short Description: + True just in case almost cyclic paths should be resolved in the + direction of the cycle. +
  • +
  • Long Description: + If true we resolved <-> edges as --> if there is a directed path x~~>y. + +
  • +
  • Default Value: true
  • +
  • Lower + Bound:
  • +
  • Upper Bound:
  • +
  • Value Type: + Boolean
  • +
+

doDiscriminatingPathColliderRule

    acceptsPrecision = new ArrayList<>(); + List acceptsRecall = new ArrayList<>(); for(Node a: accepts) { System.out.println("====================="); markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraph); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDM.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDM.java index 0bc3e86f8c..9f5a35cb80 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDM.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDM.java @@ -743,8 +743,8 @@ public void rtest11() { graph.addDirectedEdge(X3, X0); - System.out.print(graph.paths().existsDirectedPathFromTo(X0, X3)); - System.out.print(graph.paths().existsDirectedPathFromTo(X3, X0)); + System.out.print(graph.paths().existsDirectedPath(X0, X3)); + System.out.print(graph.paths().existsDirectedPath(X3, X0)); for (Node node : graph.getNodes()) { System.out.println("Nodes adjacent to " + node + ": " + graph.getAdjacentNodes(node) + "\n"); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDag.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDag.java index 4f8384645d..73b6480278 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDag.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDag.java @@ -74,10 +74,10 @@ private void checkAddRemoveNodes(Dag graph) { assertTrue(parents.contains(x3)); assertTrue(parents.contains(x5)); - assertTrue(graph.paths().isMConnectedTo(x1, x3, Collections.EMPTY_SET)); + assertTrue(graph.paths().isMConnectedTo(x1, x3, Collections.EMPTY_SET, false)); - assertTrue(graph.paths().existsDirectedPathFromTo(x1, x4)); - assertFalse(graph.paths().existsDirectedPathFromTo(x1, x5)); + assertTrue(graph.paths().existsDirectedPath(x1, x4)); + assertFalse(graph.paths().existsDirectedPath(x1, x5)); assertTrue(graph.paths().isAncestorOf(x2, x4)); assertFalse(graph.paths().isAncestorOf(x4, x2)); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDagInPatternIterator.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDagInPatternIterator.java index 8f62d049f0..94abc8d6cd 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDagInPatternIterator.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDagInPatternIterator.java @@ -54,7 +54,7 @@ public void test1() { Dag dag = new Dag(graph); - Graph CPDAG = GraphTransforms.cpdagForDag(graph); + Graph CPDAG = GraphTransforms.dagToCpdag(graph); System.out.println(CPDAG); @@ -175,7 +175,7 @@ public void test5() { Dag dag1 = new Dag(RandomGraph.randomGraph(nodes1, 0, 3, 30, 15, 15, false)); - Graph CPDAG = GraphTransforms.cpdagForDag(dag1); + Graph CPDAG = GraphTransforms.dagToCpdag(dag1); List nodes = CPDAG.getNodes(); // Make random knowedge. diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestEdgeListGraph.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestEdgeListGraph.java index c9f7e1758a..70ce6d035f 100755 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestEdgeListGraph.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestEdgeListGraph.java @@ -74,7 +74,7 @@ public void testSequence1() { assertEquals(children, Collections.singletonList(this.x2)); assertEquals(parents, Collections.singletonList(this.x3)); - assertTrue(this.graph.paths().isMConnectedTo(this.x1, this.x3, Collections.EMPTY_SET)); + assertTrue(this.graph.paths().isMConnectedTo(this.x1, this.x3, Collections.EMPTY_SET, false)); this.graph.removeNode(this.x2); // No cycles. diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java index 5730004833..3f7631d912 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java @@ -197,7 +197,7 @@ public void explore1() { alg.setFaithfulnessAssumed(true); Graph estCPDAG = alg.search(); - Graph trueCPDAG = GraphTransforms.cpdagForDag(dag); + Graph trueCPDAG = GraphTransforms.dagToCpdag(dag); estCPDAG = GraphUtils.replaceNodes(estCPDAG, vars); @@ -242,7 +242,7 @@ public void testExplore3() { Graph graph = GraphUtils.convert("A-->B,A-->C,B-->D,C-->D"); edu.cmu.tetrad.search.Fges fges = new edu.cmu.tetrad.search.Fges(new GraphScore(graph)); Graph CPDAG = fges.search(); - assertEquals(GraphTransforms.cpdagForDag(graph), CPDAG); + assertEquals(GraphTransforms.dagToCpdag(graph), CPDAG); } @Test @@ -250,7 +250,7 @@ public void testExplore4() { Graph graph = GraphUtils.convert("A-->B,A-->C,A-->D,B-->E,C-->E,D-->E"); edu.cmu.tetrad.search.Fges fges = new edu.cmu.tetrad.search.Fges(new GraphScore(graph)); Graph CPDAG = fges.search(); - assertEquals(GraphTransforms.cpdagForDag(graph), CPDAG); + assertEquals(GraphTransforms.dagToCpdag(graph), CPDAG); } @Test @@ -259,7 +259,7 @@ public void testExplore5() { edu.cmu.tetrad.search.Fges fges = new edu.cmu.tetrad.search.Fges(new GraphScore(graph)); fges.setFaithfulnessAssumed(true); Graph CPDAG = fges.search(); - assertEquals(GraphTransforms.cpdagForDag(graph), CPDAG); + assertEquals(GraphTransforms.dagToCpdag(graph), CPDAG); } @Test @@ -599,7 +599,7 @@ public void testFromGraph() { fges.setVerbose(true); fges.setNumThreads(1); Graph CPDAG1 = fges.search(); - Graph CPDAG2 = GraphTransforms.cpdagForDag(dag); + Graph CPDAG2 = GraphTransforms.dagToCpdag(dag); assertEquals(CPDAG2, CPDAG1); } } diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java index 4653051cec..c51dfb7a7d 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java @@ -152,6 +152,8 @@ public void test2() { g1.addDirectedEdge(L, x2); g1.addDirectedEdge(L, x3); + System.out.println(g1); + GFci gfci = new GFci(new MsepTest(g1), new GraphScore(g1)); Graph pag = gfci.search(); @@ -167,7 +169,7 @@ public void test2() { truePag.addBidirectedEdge(x2, x3); truePag.addPartiallyOrientedEdge(x4, x3); - assertEquals(pag, truePag); + assertEquals(truePag, pag); } // @Test diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGraph.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGraph.java index 8b483c6587..6757d17c65 100755 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGraph.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGraph.java @@ -213,7 +213,7 @@ public void testLegalCpdag() { assertFalse(g1.paths().isLegalCpdag()); Graph g2 = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false); - g2 = GraphTransforms.cpdagForDag(g2); + g2 = GraphTransforms.dagToCpdag(g2); assertTrue(g2.paths().isLegalCpdag()); @@ -282,7 +282,7 @@ private void checkAddRemoveNodes(Graph graph) { List children = graph.getChildren(x1); List parents = graph.getParents(x4); - assertTrue(graph.paths().isMConnectedTo(x1, x3, new HashSet<>())); + assertTrue(graph.paths().isMConnectedTo(x1, x3, new HashSet<>(), false)); graph.removeNode(x2); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGraphUtils.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGraphUtils.java index 6c5f854e34..086dc50bc8 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGraphUtils.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGraphUtils.java @@ -1,283 +1,450 @@ -/////////////////////////////////////////////////////////////////////////////// -// For information as to what this class does, see the Javadoc, below. // -// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // -// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // -// Scheines, Joseph Ramsey, and Clark Glymour. // -// // -// This program is free software; you can redistribute it and/or modify // -// it under the terms of the GNU General Public License as published by // -// the Free Software Foundation; either version 2 of the License, or // -// (at your option) any later version. // -// // -// This program is distributed in the hope that it will be useful, // -// but WITHOUT ANY WARRANTY; without even the implied warranty of // -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // -// GNU General Public License for more details. // -// // -// You should have received a copy of the GNU General Public License // -// along with this program; if not, write to the Free Software // -// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // -/////////////////////////////////////////////////////////////////////////////// - -package edu.cmu.tetrad.test; - -import edu.cmu.tetrad.data.ContinuousVariable; -import edu.cmu.tetrad.graph.*; -import edu.cmu.tetrad.util.RandomUtil; -import org.junit.Test; - -import java.util.*; - -import static junit.framework.TestCase.assertEquals; -import static org.junit.Assert.*; - -/** - * @author josephramsey - */ -public final class TestGraphUtils { - - @Test - public void testCreateRandomDag() { - List nodes = new ArrayList<>(); - - for (int i = 0; i < 50; i++) { - nodes.add(new ContinuousVariable("X" + (i + 1))); - } + /////////////////////////////////////////////////////////////////////////////// + // For information as to what this class does, see the Javadoc, below. // + // Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // + // 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // + // Scheines, Joseph Ramsey, and Clark Glymour. // + // // + // This program is free software; you can redistribute it and/or modify // + // it under the terms of the GNU General Public License as published by // + // the Free Software Foundation; either version 2 of the License, or // + // (at your option) any later version. // + // // + // This program is distributed in the hope that it will be useful, // + // but WITHOUT ANY WARRANTY; without even the implied warranty of // + // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // + // GNU General Public License for more details. // + // // + // You should have received a copy of the GNU General Public License // + // along with this program; if not, write to the Free Software // + // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // + /////////////////////////////////////////////////////////////////////////////// + + package edu.cmu.tetrad.test; + + import edu.cmu.tetrad.data.ContinuousVariable; + import edu.cmu.tetrad.data.Knowledge; + import edu.cmu.tetrad.graph.*; + import edu.cmu.tetrad.search.utils.DagSepsets; + import edu.cmu.tetrad.search.utils.FciOrient; + import edu.cmu.tetrad.util.RandomUtil; + import org.jetbrains.annotations.Nullable; + import org.junit.Test; + + import java.util.*; + + import static junit.framework.TestCase.assertEquals; + import static org.junit.Assert.*; + + /** + * @author josephramsey + */ + public final class TestGraphUtils { + + @Test + public void testCreateRandomDag() { + List nodes = new ArrayList<>(); + + for (int i = 0; i < 50; i++) { + nodes.add(new ContinuousVariable("X" + (i + 1))); + } - Dag dag = new Dag(RandomGraph.randomGraph(nodes, 0, 50, - 4, 3, 3, false)); + Dag dag = new Dag(RandomGraph.randomGraph(nodes, 0, 50, + 4, 3, 3, false)); - assertEquals(50, dag.getNumNodes()); - assertEquals(50, dag.getNumEdges()); - } + assertEquals(50, dag.getNumNodes()); + assertEquals(50, dag.getNumEdges()); + } - @Test - public void testDirectedPaths() { - List nodes = new ArrayList<>(); + @Test + public void testDirectedPaths() { + List nodes = new ArrayList<>(); - for (int i1 = 0; i1 < 6; i1++) { - nodes.add(new ContinuousVariable("X" + (i1 + 1))); - } + for (int i1 = 0; i1 < 6; i1++) { + nodes.add(new ContinuousVariable("X" + (i1 + 1))); + } - Graph graph = new Dag(RandomGraph.randomGraph(nodes, 0, 6, - 3, 3, 3, false)); + Graph graph = new Dag(RandomGraph.randomGraph(nodes, 0, 6, + 3, 3, 3, false)); - for (int i = 0; i < graph.getNodes().size(); i++) { - for (int j = 0; j < graph.getNodes().size(); j++) { - Node node1 = graph.getNodes().get(i); - Node node2 = graph.getNodes().get(j); + for (int i = 0; i < graph.getNodes().size(); i++) { + for (int j = 0; j < graph.getNodes().size(); j++) { + Node node1 = graph.getNodes().get(i); + Node node2 = graph.getNodes().get(j); - List> directedPaths = graph.paths().directedPathsFromTo(node1, node2, -1); + List> directedPaths = graph.paths().directedPaths(node1, node2, -1); - for (List path : directedPaths) { - assertTrue(graph.paths().isAncestorOf(path.get(0), path.get(path.size() - 1))); + for (List path : directedPaths) { + assertTrue(graph.paths().isAncestorOf(path.get(0), path.get(path.size() - 1))); + } } } } - } - @Test - public void testTreks() { - List nodes = new ArrayList<>(); + @Test + public void testTreks() { + List nodes = new ArrayList<>(); - for (int i1 = 0; i1 < 10; i1++) { - nodes.add(new ContinuousVariable("X" + (i1 + 1))); - } + for (int i1 = 0; i1 < 10; i1++) { + nodes.add(new ContinuousVariable("X" + (i1 + 1))); + } - Graph graph = new Dag(RandomGraph.randomGraph(nodes, 0, 15, - 3, 3, 3, false)); + Graph graph = new Dag(RandomGraph.randomGraph(nodes, 0, 15, + 3, 3, 3, false)); - for (int i = 0; i < graph.getNodes().size(); i++) { - for (int j = 0; j < graph.getNodes().size(); j++) { - Node node1 = graph.getNodes().get(i); - Node node2 = graph.getNodes().get(j); + for (int i = 0; i < graph.getNodes().size(); i++) { + for (int j = 0; j < graph.getNodes().size(); j++) { + Node node1 = graph.getNodes().get(i); + Node node2 = graph.getNodes().get(j); - List> treks = graph.paths().treks(node1, node2, -1); + List> treks = graph.paths().treks(node1, node2, -1); - TREKS: - for (List trek : treks) { - Node m0 = trek.get(0); - Node m1 = trek.get(trek.size() - 1); + TREKS: + for (List trek : treks) { + Node m0 = trek.get(0); + Node m1 = trek.get(trek.size() - 1); - for (Node n : trek) { + for (Node n : trek) { - // Not quite it but good enough for a test. - if (graph.paths().isAncestorOf(n, m0) && graph.paths().isAncestorOf(n, m1)) { - continue TREKS; + // Not quite it but good enough for a test. + if (graph.paths().isAncestorOf(n, m0) && graph.paths().isAncestorOf(n, m1)) { + continue TREKS; + } } - } - fail("Some trek failed."); + fail("Some trek failed."); + } } } } - } - @Test - public void testGraphToDot() { - final long seed = 28583848283L; - RandomUtil.getInstance().setSeed(seed); + @Test + public void testGraphToDot() { + final long seed = 28583848283L; + RandomUtil.getInstance().setSeed(seed); - List nodes = new ArrayList<>(); + List nodes = new ArrayList<>(); + + for (int i = 0; i < 5; i++) { + nodes.add(new ContinuousVariable("X" + (i + 1))); + } + + Graph g = new Dag(RandomGraph.randomGraph(nodes, 0, 5, + 30, 15, 15, false)); + + String x = GraphSaveLoadUtils.graphToDot(g); + String[] tokens = x.split("\n"); + int length = tokens.length; + assertEquals(7, length); - for (int i = 0; i < 5; i++) { - nodes.add(new ContinuousVariable("X" + (i + 1))); } - Graph g = new Dag(RandomGraph.randomGraph(nodes, 0, 5, - 30, 15, 15, false)); + @Test + public void testTwoCycleErrors() { + Node x1 = new GraphNode("X1"); + Node x2 = new GraphNode("X2"); + Node x3 = new GraphNode("X3"); + Node x4 = new GraphNode("X4"); + + Graph trueGraph = new EdgeListGraph(); + trueGraph.addNode(x1); + trueGraph.addNode(x2); + trueGraph.addNode(x3); + trueGraph.addNode(x4); + + Graph estGraph = new EdgeListGraph(); + estGraph.addNode(x1); + estGraph.addNode(x2); + estGraph.addNode(x3); + estGraph.addNode(x4); + + trueGraph.addDirectedEdge(x1, x2); + trueGraph.addDirectedEdge(x2, x1); + trueGraph.addDirectedEdge(x2, x3); + trueGraph.addDirectedEdge(x3, x2); + + estGraph.addDirectedEdge(x1, x2); + estGraph.addDirectedEdge(x2, x1); + estGraph.addDirectedEdge(x3, x4); + estGraph.addDirectedEdge(x4, x3); + estGraph.addDirectedEdge(x4, x1); + estGraph.addDirectedEdge(x1, x4); + + GraphUtils.TwoCycleErrors errors = GraphUtils.getTwoCycleErrors(trueGraph, estGraph); + + assertEquals(1, errors.twoCycCor); + assertEquals(2, errors.twoCycFp); + assertEquals(1, errors.twoCycFn); + } - String x = GraphSaveLoadUtils.graphToDot(g); - String[] tokens = x.split("\n"); - int length = tokens.length; - assertEquals(7, length); + @Test + public void testMsep() { + Node a = new ContinuousVariable("A"); + Node b = new ContinuousVariable("B"); + Node x = new ContinuousVariable("X"); + Node y = new ContinuousVariable("Y"); - } + Graph graph = new EdgeListGraph(); - @Test - public void testTwoCycleErrors() { - Node x1 = new GraphNode("X1"); - Node x2 = new GraphNode("X2"); - Node x3 = new GraphNode("X3"); - Node x4 = new GraphNode("X4"); - - Graph trueGraph = new EdgeListGraph(); - trueGraph.addNode(x1); - trueGraph.addNode(x2); - trueGraph.addNode(x3); - trueGraph.addNode(x4); - - Graph estGraph = new EdgeListGraph(); - estGraph.addNode(x1); - estGraph.addNode(x2); - estGraph.addNode(x3); - estGraph.addNode(x4); - - trueGraph.addDirectedEdge(x1, x2); - trueGraph.addDirectedEdge(x2, x1); - trueGraph.addDirectedEdge(x2, x3); - trueGraph.addDirectedEdge(x3, x2); - - estGraph.addDirectedEdge(x1, x2); - estGraph.addDirectedEdge(x2, x1); - estGraph.addDirectedEdge(x3, x4); - estGraph.addDirectedEdge(x4, x3); - estGraph.addDirectedEdge(x4, x1); - estGraph.addDirectedEdge(x1, x4); - - GraphUtils.TwoCycleErrors errors = GraphUtils.getTwoCycleErrors(trueGraph, estGraph); - - assertEquals(1, errors.twoCycCor); - assertEquals(2, errors.twoCycFp); - assertEquals(1, errors.twoCycFn); - } + graph.addNode(a); + graph.addNode(b); + graph.addNode(x); + graph.addNode(y); - @Test - public void testMsep() { - Node a = new ContinuousVariable("A"); - Node b = new ContinuousVariable("B"); - Node x = new ContinuousVariable("X"); - Node y = new ContinuousVariable("Y"); + graph.addDirectedEdge(a, x); + graph.addDirectedEdge(b, y); + graph.addDirectedEdge(x, y); + graph.addDirectedEdge(y, x); - Graph graph = new EdgeListGraph(); + // System.out.println(graph); - graph.addNode(a); - graph.addNode(b); - graph.addNode(x); - graph.addNode(y); + assertTrue(graph.paths().isAncestorOf(a, a)); + assertTrue(graph.paths().isAncestorOf(b, b)); + assertTrue(graph.paths().isAncestorOf(x, x)); + assertTrue(graph.paths().isAncestorOf(y, y)); - graph.addDirectedEdge(a, x); - graph.addDirectedEdge(b, y); - graph.addDirectedEdge(x, y); - graph.addDirectedEdge(y, x); + assertTrue(graph.paths().isAncestorOf(a, x)); + assertFalse(graph.paths().isAncestorOf(x, a)); + assertTrue(graph.paths().isAncestorOf(a, y)); + assertFalse(graph.paths().isAncestorOf(y, a)); - assertTrue(graph.paths().isAncestorOf(a, a)); - assertTrue(graph.paths().isAncestorOf(b, b)); - assertTrue(graph.paths().isAncestorOf(x, x)); - assertTrue(graph.paths().isAncestorOf(y, y)); + assertTrue(graph.paths().isAncestorOf(a, y)); + assertTrue(graph.paths().isAncestorOf(b, x)); - assertTrue(graph.paths().isAncestorOf(a, x)); - assertFalse(graph.paths().isAncestorOf(x, a)); - assertTrue(graph.paths().isAncestorOf(a, y)); - assertFalse(graph.paths().isAncestorOf(y, a)); + assertFalse(graph.paths().isAncestorOf(a, b)); + assertFalse(graph.paths().isAncestorOf(y, a)); + assertFalse(graph.paths().isAncestorOf(x, b)); - assertTrue(graph.paths().isAncestorOf(a, y)); - assertTrue(graph.paths().isAncestorOf(b, x)); + assertTrue(graph.paths().isMConnectedTo(a, y, new HashSet<>(), false)); + assertTrue(graph.paths().isMConnectedTo(b, x, new HashSet<>(), false)); - assertFalse(graph.paths().isAncestorOf(a, b)); - assertFalse(graph.paths().isAncestorOf(y, a)); - assertFalse(graph.paths().isAncestorOf(x, b)); + // MSEP problem now with 2-cycles. TODO + assertTrue(graph.paths().isMConnectedTo(a, y, Collections.singleton(x), false)); + assertTrue(graph.paths().isMConnectedTo(b, x, Collections.singleton(y), false)); - assertTrue(graph.paths().isMConnectedTo(a, y, new HashSet<>())); - assertTrue(graph.paths().isMConnectedTo(b, x, new HashSet<>())); + assertTrue(graph.paths().isMConnectedTo(a, y, Collections.singleton(b), false)); + assertTrue(graph.paths().isMConnectedTo(b, x, Collections.singleton(a), false)); - assertTrue(graph.paths().isMConnectedTo(a, y, Collections.singleton(x))); - assertTrue(graph.paths().isMConnectedTo(b, x, Collections.singleton(y))); + assertTrue(graph.paths().isMConnectedTo(y, a, Collections.singleton(b), false)); + assertTrue(graph.paths().isMConnectedTo(x, b, Collections.singleton(a), false)); + } - assertTrue(graph.paths().isMConnectedTo(a, y, Collections.singleton(b))); - assertTrue(graph.paths().isMConnectedTo(b, x, Collections.singleton(a))); + @Test + public void testMsep2() { + Node a = new ContinuousVariable("A"); + Node b = new ContinuousVariable("B"); + Node c = new ContinuousVariable("C"); - assertTrue(graph.paths().isMConnectedTo(y, a, Collections.singleton(b))); - assertTrue(graph.paths().isMConnectedTo(x, b, Collections.singleton(a))); - } + Graph graph = new EdgeListGraph(); - @Test - public void testMsep2() { - Node a = new ContinuousVariable("A"); - Node b = new ContinuousVariable("B"); - Node c = new ContinuousVariable("C"); + graph.addNode(a); + graph.addNode(b); + graph.addNode(c); - Graph graph = new EdgeListGraph(); + graph.addDirectedEdge(a, b); + graph.addDirectedEdge(b, c); + graph.addDirectedEdge(c, b); - graph.addNode(a); - graph.addNode(b); - graph.addNode(c); + // System.out.println(graph); - graph.addDirectedEdge(a, b); - graph.addDirectedEdge(b, c); - graph.addDirectedEdge(c, b); + assertTrue(graph.paths().isAncestorOf(a, b)); + assertTrue(graph.paths().isAncestorOf(a, c)); - assertTrue(graph.paths().isAncestorOf(a, b)); - assertTrue(graph.paths().isAncestorOf(a, c)); + // MSEP problem now with 2-cycles. TODO + assertTrue(graph.paths().isMConnectedTo(a, b, Collections.EMPTY_SET, false)); + assertTrue(graph.paths().isMConnectedTo(a, c, Collections.EMPTY_SET, false)); + // + assertTrue(graph.paths().isMConnectedTo(a, c, Collections.singleton(b), false)); + assertTrue(graph.paths().isMConnectedTo(c, a, Collections.singleton(b), false)); + } - assertTrue(graph.paths().isMConnectedTo(a, b, Collections.EMPTY_SET)); - assertTrue(graph.paths().isMConnectedTo(a, c, Collections.EMPTY_SET)); - assertTrue(graph.paths().isMConnectedTo(a, c, Collections.singleton(b))); - assertTrue(graph.paths().isMConnectedTo(c, a, Collections.singleton(b))); - } + public void test8() { + final int numNodes = 5; + for (int i = 0; i < 100; i++) { + Graph graph = RandomGraph.randomGraphRandomForwardEdges(numNodes, 0, numNodes, 10, 10, 10, true); + + List nodes = graph.getNodes(); + Node x = nodes.get(RandomUtil.getInstance().nextInt(numNodes)); + Node y = nodes.get(RandomUtil.getInstance().nextInt(numNodes)); + Node z1 = nodes.get(RandomUtil.getInstance().nextInt(numNodes)); + Node z2 = nodes.get(RandomUtil.getInstance().nextInt(numNodes)); + + if (graph.paths().isMSeparatedFrom(x, y, set(z1), false) && graph.paths().isMSeparatedFrom(x, y, set(z2), false) && + !graph.paths().isMSeparatedFrom(x, y, set(z1, z2), false)) { + System.out.println("x = " + x); + System.out.println("y = " + y); + System.out.println("z1 = " + z1); + System.out.println("z2 = " + z2); + System.out.println(graph); + return; + } + } + } - public void test8() { - final int numNodes = 5; - for (int i = 0; i < 100000; i++) { - Graph graph = RandomGraph.randomGraphRandomForwardEdges(numNodes, 0, numNodes, 10, 10, 10, true); + @Test + public void test9() { + + Graph graph = RandomGraph.randomGraphRandomForwardEdges(20, 0, 50, + 10, 10, 10, false); + graph = GraphTransforms.dagToCpdag(graph); + + int numSmnallestSizes = 2; + + System.out.println(graph); + + System.out.println("Number of smallest sizes printed = " + numSmnallestSizes); List nodes = graph.getNodes(); - Node x = nodes.get(RandomUtil.getInstance().nextInt(numNodes)); - Node y = nodes.get(RandomUtil.getInstance().nextInt(numNodes)); - Node z1 = nodes.get(RandomUtil.getInstance().nextInt(numNodes)); - Node z2 = nodes.get(RandomUtil.getInstance().nextInt(numNodes)); - - if (graph.paths().isMSeparatedFrom(x, y, set(z1)) && graph.paths().isMSeparatedFrom(x, y, set(z2)) && - !graph.paths().isMSeparatedFrom(x, y, set(z1, z2))) { - System.out.println("x = " + x); - System.out.println("y = " + y); - System.out.println("z1 = " + z1); - System.out.println("z2 = " + z2); - System.out.println(graph); - return; + + for (Node x : nodes) { + for (Node y : nodes) { + if (x == y) continue; + Set> sets = GraphUtils.visibleEdgeAdjustments3(graph, x, y, numSmnallestSizes, GraphUtils.GraphType.CPDAG); + + if (sets.isEmpty()) { + continue; + } + + System.out.println(); + + for (Set set : sets) { + System.out.println("For " + x + "-->" + y + ", set = " + set); + } + } } } - } - private Set set(Node... z) { - Set list = new HashSet<>(); - Collections.addAll(list, z); - return list; + private static @Nullable Graph getGraphWithoutXToYPag(Node x, Node y, Graph graph) { + if (!graph.isAdjacentTo(x, y)) return null; + + if (Edges.isBidirectedEdge(graph.getEdge(x, y))) { + return null; + } else if (Edges.isPartiallyOrientedEdge(graph.getEdge(x, y)) && graph.getEdge(x, y).pointsTowards(x)) { + return null; + } else if (Edges.isUndirectedEdge(graph.getEdge(x, y))) { + return null; + } + + Graph _graph = new EdgeListGraph(graph); + + _graph.removeEdge(x, y); + _graph.addDirectedEdge(x, y); + + Knowledge knowledge = new Knowledge(); + knowledge.setRequired(x.getName(), y.getName()); + + FciOrient fciOrientation = new FciOrient(new DagSepsets(graph)); + fciOrientation.setKnowledge(knowledge); + fciOrientation.orient(_graph); + + _graph.removeEdge(x, y); + return _graph; + } + + @Test + public void test10() { + + Graph graph = RandomGraph.randomGraphRandomForwardEdges(10, 2, 10, + 10, 10, 10, false); + graph = GraphTransforms.dagToPag(graph); + + int numSmnallestSizes = 2; + + System.out.println(graph); + + System.out.println("Number of smallest sizes printed = " + numSmnallestSizes); + + List nodes = graph.getNodes(); + + for (Node x : nodes) { + for (Node y : nodes) { + if (x == y) continue; + Set> sets = null; + try { + sets = GraphUtils.visibleEdgeAdjustments1(graph, x, y, numSmnallestSizes, GraphUtils.GraphType.PAG); + } catch (Exception e) { + continue; + } + + if (sets.isEmpty()) { + continue; + } + + System.out.println(); + + for (Set set : sets) { + System.out.println("For " + x + "-->" + y + ", set = " + set); + } + } + } + } + + @Test + public void test11() { +// RandomUtil.getInstance().setSeed(1040404L); + + // 10 times over, make a random DAG + for (int i = 0; i < 1000; i++) { + Graph graph = RandomGraph.randomGraphRandomForwardEdges(5, 0, 5, + 100, 100, 100, false); + + // Construct its CPDAG + Graph cpdag = GraphTransforms.dagToCpdag(graph); + assertTrue(cpdag.paths().isLegalCpdag()); + assertTrue(cpdag.paths().isLegalMpdag()); + +// if (!cpdag.paths().isLegalCpdag()) { +// +// System.out.println("Not legal CPDAG:"); +// +// System.out.println(cpdag); +// +// List pi = new ArrayList<>(cpdag.getNodes()); +// cpdag.paths().makeValidOrder(pi); +// +// System.out.println("Valid order: " + pi); +// +// Graph dag = Paths.getDag(pi, cpdag, true); +// +// System.out.println("DAG: " + dag); +// +// Graph cpdag2 = GraphTransforms.dagToCpdag(dag); +// +// System.out.println("CPDAG for DAG: " + cpdag2); +// +// break; +// } + } + + } + + private Set set(Node... z) { + Set list = new HashSet<>(); + Collections.addAll(list, z); + return list; + } + + /** + * A test of m-connection. We generate 10 random graphs with latents and check that dagToPag + * produces a legal PAG. We then call dagToPag again on the PAG and check that the result is + * also a legal PAG. + */ + @Test + public void test12() { + RandomUtil.getInstance().setSeed(1040404L); + + for (int i = 0; i < 10; i++) { + Graph graph = RandomGraph.randomGraph(10, 3, 10, + 10, 10, 10, false); + Graph pag = GraphTransforms.dagToPag(graph); + assertTrue(pag.paths().isLegalPag()); + Graph pag2 = GraphTransforms.dagToPag(pag); + assertTrue(pag2.paths().isLegalPag()); + } + } } -} diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGrasp.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGrasp.java index e3e68939ff..76453df0ab 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGrasp.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGrasp.java @@ -2364,7 +2364,7 @@ private boolean setPathsCanceling(Node x1, Node x4, StandardizedSemIm imsd, List SemGraph graph = imsd.getSemPm().getGraph(); graph.setShowErrorTerms(false); - List> paths = graph.paths().allDirectedPathsFromTo(x1, x4, -1); + List> paths = graph.paths().allDirectedPaths(x1, x4, -1); if (paths.size() < 2) return false; @@ -2519,8 +2519,8 @@ public void testFciAlgs() { statistics.add(new LegalPag()); // statistics.add(new NoAlmostCyclicPathsCondition()); // statistics.add(new NoCyclicPathsCondition()); - statistics.add(new NoAlmostCyclicPathsInMagCondition()); - statistics.add(new NoCyclicPathsInMagCondition()); + statistics.add(new NoAlmostCyclicPathsCondition()); + statistics.add(new NoCyclicPathsCondition()); statistics.add(new MaximalityCondition()); statistics.add(new ParameterColumn(Params.ALPHA)); @@ -2773,7 +2773,7 @@ public void testDsep() { for (Node y : graph.getNodes()) { if (!graph.paths().isDescendentOf(y, x) && !parents.contains(y)) { - if (!graph.paths().isMSeparatedFrom(x, y, parents)) { + if (!graph.paths().isMSeparatedFrom(x, y, parents, false)) { System.out.println("Failure! " + LogUtilsSearch.dependenceFactMsg(x, y, parents, 1.0)); } } @@ -3149,7 +3149,7 @@ public void testWayne2() { if (g1.equals(g2)) gsCount++; gsShd += GraphSearchUtils.structuralHammingDistance( - GraphTransforms.cpdagForDag(g1), GraphTransforms.cpdagForDag(g2)); + GraphTransforms.dagToCpdag(g1), GraphTransforms.dagToCpdag(g2)); for (int i = 0; i < alpha.length; i++) { // test.setAlpha(alpha[i]); @@ -3164,7 +3164,7 @@ public void testWayne2() { if (g1.equals(g3)) pearlCounts[i]++; pearlShd[i] += GraphSearchUtils.structuralHammingDistance( - GraphTransforms.cpdagForDag(g1), GraphTransforms.cpdagForDag(g3)); + GraphTransforms.dagToCpdag(g1), GraphTransforms.dagToCpdag(g3)); } } @@ -3298,7 +3298,7 @@ public void testAddUnfaithfulIndependencies() { count++; } else { - List> paths = graph.paths().allPathsFromTo(x, y, 4); + List> paths = graph.paths().allPaths(x, y, 4); if (paths.size() >= 1) { List> nonTrekPaths = new ArrayList<>(); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestPc.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestPc.java index a0b7b31cc7..79ffc41a14 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestPc.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestPc.java @@ -217,7 +217,7 @@ public void checknumCPDAGsToStore() { MsepTest test = new MsepTest(graph); Pc pc = new Pc(test); Graph CPDAG = pc.search(); - Graph CPDAG2 = GraphTransforms.cpdagForDag(graph); + Graph CPDAG2 = GraphTransforms.dagToCpdag(graph); assertEquals(CPDAG, CPDAG2); } } @@ -399,7 +399,7 @@ private double[] printStats(String[] algorithms, int t, } if (edge.getEndpoint1() == Endpoint.TAIL) { - if (dag.paths().existsDirectedPathFromTo(edge.getNode1(), edge.getNode2())) { + if (dag.paths().existsDirectedPath(edge.getNode1(), edge.getNode2())) { tailsTp++; } else { tailsFp++; @@ -409,7 +409,7 @@ private double[] printStats(String[] algorithms, int t, } if (edge.getEndpoint2() == Endpoint.TAIL) { - if (dag.paths().existsDirectedPathFromTo(edge.getNode2(), edge.getNode1())) { + if (dag.paths().existsDirectedPath(edge.getNode2(), edge.getNode1())) { tailsTp++; } else { tailsFp++; diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRubenData.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRubenData.java index e1e816251a..5a49a23d35 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRubenData.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRubenData.java @@ -53,7 +53,7 @@ public void test1() { Graph graph = GraphSaveLoadUtils.loadGraphTxt(new File(path2)); - graph = GraphTransforms.cpdagForDag(graph); + graph = GraphTransforms.dagToCpdag(graph); SemBicScore score = new SemBicScore(data, precomputeCovariances); score.setPenaltyDiscount(2); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestSearchGraph.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestSearchGraph.java index 67133fe906..c875ae467a 100755 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestSearchGraph.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestSearchGraph.java @@ -241,6 +241,8 @@ public void testMSeparation() { } if (graph.isMSeparatedFrom(x, y, z) != graph.isMSeparatedFrom(y, x, z)) { + System.out.println(graph); + fail(LogUtilsSearch.independenceFact(x, y, z) + " should have same m-sep result as " + LogUtilsSearch.independenceFact(y, x, z)); } @@ -285,8 +287,8 @@ public void testMSeparation2() { z.add(theRest.get(value)); } - boolean mConnectedTo = graph.paths().isMConnectedTo(x, y, z); - boolean mConnectedTo1 = graph.paths().isMConnectedTo(y, x, z); + boolean mConnectedTo = graph.paths().isMConnectedTo(x, y, z, false); + boolean mConnectedTo1 = graph.paths().isMConnectedTo(y, x, z, false); if (mConnectedTo != mConnectedTo1) { System.out.println(x + " d connected to " + y + " given " + z); @@ -304,10 +306,10 @@ public void testMSeparation2() { // Trying to trip up the breadth first algorithm. public void testMSeparation3() { Graph graph = GraphUtils.convert("x-->s1,x-->s2,s1-->s3,s3-->s2,s3<--y"); - assertTrue(graph.paths().isMSeparatedFrom(graph.getNode("x"), graph.getNode("y"), new HashSet<>())); + assertTrue(graph.paths().isMSeparatedFrom(graph.getNode("x"), graph.getNode("y"), new HashSet<>(), false)); graph = GraphUtils.convert("1-->2,2<--4,2-->7,2-->3"); - assertTrue(graph.paths().isMSeparatedFrom(graph.getNode("4"), graph.getNode("1"), new HashSet<>())); + assertTrue(graph.paths().isMSeparatedFrom(graph.getNode("4"), graph.getNode("1"), new HashSet<>(), false)); graph = GraphUtils.convert("X1-->X5,X1-->X6,X2-->X3,X4-->X6,X5-->X3,X6-->X5,X7-->X3"); assertTrue(mConnected(graph, "X2", "X4", "X3", "X6")); @@ -378,7 +380,7 @@ private boolean mConnected(Graph graph, String x, String y, String... z) { _z.add(graph.getNode(name)); } - return graph.paths().isMConnectedTo(_x, _y, _z); + return graph.paths().isMConnectedTo(_x, _y, _z, false); } public void testAlternativeGraphs() {