diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/BayesImNodeEditingTable.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/BayesImNodeEditingTable.java index 2a6f684960..bfe8d1dbd8 100755 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/BayesImNodeEditingTable.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/BayesImNodeEditingTable.java @@ -23,6 +23,7 @@ import edu.cmu.tetrad.bayes.BayesIm; import edu.cmu.tetrad.bayes.BayesPm; +import edu.cmu.tetrad.bayes.MlBayesIm; import edu.cmu.tetrad.graph.Node; import edu.cmu.tetrad.util.JOptionUtils; import edu.cmu.tetrad.util.NumberFormatUtil; @@ -536,8 +537,9 @@ public Object getValueAt(int tableRow, int tableCol) { int colIndex = tableCol - parentVals.length; if (colIndex < getBayesIm().getNumColumns(getNodeIndex())) { - return getBayesIm().getProbability(getNodeIndex(), tableRow, + double probability = getBayesIm().getProbability(getNodeIndex(), tableRow, colIndex); + return probability; } return "null"; @@ -555,10 +557,14 @@ public boolean isCellEditable(int row, int col) { * Sets the value of the cell at (row, col) to 'aValue'. 
*/ public void setValueAt(Object aValue, int row, int col) { + if (getBayesIm().getCptMapType() == MlBayesIm.CptMapType.COUNT_MAP) { + return; + } + int numParents = getBayesIm().getNumParents(getNodeIndex()); int colIndex = col - numParents; - if ("".equals(aValue) || aValue == null) { + if (getBayesIm().getCptMapType() == MlBayesIm.CptMapType.PROB_MAP && ("".equals(aValue) || aValue == null)) { getBayesIm().setProbability(getNodeIndex(), row, colIndex, Double.NaN); fireTableRowsUpdated(row, row); @@ -752,6 +758,8 @@ public void resetFailedRow() { public void resetFailedCol() { this.failedCol = -1; } + + } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomGraphEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomGraphEditor.java index c4e5e55146..da15ea8743 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomGraphEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomGraphEditor.java @@ -481,7 +481,7 @@ public void setEnabled(boolean enabled) { * @return a boolean */ public boolean isRandomForward() { - return this.parameters.getBoolean("graphRandomFoward", true); + return this.parameters.getBoolean("graphRandomForward", true); } private void setRandomForward(boolean randomFoward) { diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesImWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesImWrapper.java index 08aa40b5ee..f8f0d8b793 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesImWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesImWrapper.java @@ -36,6 +36,7 @@ import java.io.Serial; import java.util.ArrayList; import java.util.List; +import java.util.function.IntBinaryOperator; /** * Wraps a Bayes Pm for use in the Tetrad application. 
@@ -90,11 +91,11 @@ public BayesImWrapper(BayesPmWrapper bayesPmWrapper, BayesImWrapper oldBayesImwr BayesIm oldBayesIm = oldBayesImwrapper.getBayesIm(); if (params.getString("initializationMode", "manualRetain").equals("manualRetain")) { - setBayesIm(bayesPm, oldBayesIm, MlBayesIm.MANUAL); + setBayesIm(bayesPm, oldBayesIm, MlBayesIm.InitializationMethod.MANUAL); } else if (params.getString("initializationMode", "manualRetain").equals("randomRetain")) { - setBayesIm(bayesPm, oldBayesIm, MlBayesIm.RANDOM); + setBayesIm(bayesPm, oldBayesIm, MlBayesIm.InitializationMethod.RANDOM); } else if (params.getString("initializationMode", "manualRetain").equals("randomOverwrite")) { - setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.RANDOM)); + setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM)); } } @@ -193,9 +194,9 @@ public BayesImWrapper(BayesPmWrapper bayesPmWrapper, Parameters params) { if (params.getString("initializationMode", "manualRetain").equals("manualRetain")) { setBayesIm(new MlBayesIm(bayesPm)); } else if (params.getString("initializationMode", "manualRetain").equals("randomRetain")) { - setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.RANDOM)); + setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM)); } else if (params.getString("initializationMode", "manualRetain").equals("randomOverwrite")) { - setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.RANDOM)); + setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM)); } } @@ -342,9 +343,9 @@ public String getModelSourceName() { //============================== private methods ============================// - private void setBayesIm(BayesPm bayesPm, BayesIm oldBayesIm, int manual) { + private void setBayesIm(BayesPm bayesPm, BayesIm oldBayesIm, MlBayesIm.InitializationMethod initializationMethod) { this.bayesIms = new ArrayList<>(); - this.bayesIms.add(new MlBayesIm(bayesPm, oldBayesIm, manual)); + this.bayesIms.add(new MlBayesIm(bayesPm, oldBayesIm, initializationMethod)); 
} /** diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesPmWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesPmWrapper.java index 2c7ee8d4b3..6511a3bcc3 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesPmWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesPmWrapper.java @@ -206,8 +206,8 @@ public BayesPmWrapper(GraphWrapper graphWrapper, Parameters params) { lowerBound = upperBound = 3; setBayesPm(graph, lowerBound, upperBound); } else if (params.getString("bayesPmInitializationMode", "range").equals("range")) { - lowerBound = params.getInt("minCategories", 2); - upperBound = params.getInt("maxCategories", 4); + lowerBound = params.getInt("lowerBoundNumVals", 2); + upperBound = params.getInt("upperBoundNumVals", 4); setBayesPm(graph, lowerBound, upperBound); } else { throw new IllegalStateException("Unrecognized type."); diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/session/SessionNode.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/session/SessionNode.java index 42681673ae..4ccd0fc806 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/session/SessionNode.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/session/SessionNode.java @@ -1173,6 +1173,13 @@ public int compareTo(Node node) { * @return a boolean */ public boolean existsParameterizedConstructor(Class modelClass) { + if (modelClass == null) { + + // If the model class is null, then there is no constructor, so display a dialog to the users by + // return false here. 
+ return false; + } + Object param = getParam(modelClass); List parentModels = listParentModels(); parentModels.add(param); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java index f973cf76de..2a0718b19f 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java @@ -238,9 +238,9 @@ private DataSet simulate(Graph graph, Parameters parameters) { int minCategories = parameters.getInt(Params.MIN_CATEGORIES); int maxCategories = parameters.getInt(Params.MAX_CATEGORIES); pm = new BayesPm(graph, minCategories, maxCategories); - im = new MlBayesIm(pm, MlBayesIm.RANDOM); + im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM); } else { - im = new MlBayesIm(pm, MlBayesIm.RANDOM); + im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM); this.im = im; } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java index 0af14dfd20..676915f1c8 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java @@ -320,7 +320,7 @@ private DataSet simulate(Graph G, Parameters parameters) { } BayesPm bayesPm = new BayesPm(AG); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); SemPm semPm = new SemPm(XG); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/Statistic.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/Statistic.java index 
147b4c8d73..248d9ad374 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/Statistic.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/Statistic.java @@ -43,7 +43,7 @@ public interface Statistic extends Serializable { /** * Returns a mapping of the statistic to the interval [0, 1], with higher being better. This is used for a - * calculation of a utility for an algorithm.If the statistic is already between 0 and 1, you can just return the + * calculation of a utility for an algorithm. If the statistic is already between 0 and 1, you can just return the * statistic. * * @param value The value of the statistic. diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ApproximateUpdater.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ApproximateUpdater.java index 5b69d7c030..375cec1846 100755 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ApproximateUpdater.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ApproximateUpdater.java @@ -344,7 +344,7 @@ private void doUpdate() { } private BayesIm createdUpdatedBayesIm(BayesPm updatedBayesPm) { - return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.RANDOM); + return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.InitializationMethod.RANDOM); } private BayesPm createUpdatedBayesPm(Dag updatedGraph) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesIm.java index 7b1f1f7b62..a768c73798 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesIm.java @@ -393,4 +393,8 @@ public interface BayesIm extends VariableSource, Im, Simulator { * @return a string representation for this Bayes net. 
*/ String toString(); + + default MlBayesIm.CptMapType getCptMapType() { + return MlBayesIm.CptMapType.PROB_MAP; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptInvariantUpdater.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptInvariantUpdater.java index c3f9183d1a..312a6ba9d6 100755 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptInvariantUpdater.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptInvariantUpdater.java @@ -266,7 +266,7 @@ public String toString() { //==============================PRIVATE METHODS=======================// private BayesIm createdManipulatedBayesIm(BayesPm updatedBayesPm) { - return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.MANUAL); + return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.InitializationMethod.MANUAL); } private BayesPm createManipulatedBayesPm(Dag updatedGraph) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java new file mode 100644 index 0000000000..d8d271c5e2 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java @@ -0,0 +1,41 @@ +package edu.cmu.tetrad.bayes; + +import edu.cmu.tetrad.util.TetradSerializable; + +/** + * An interface representing a map of probabilities or counts for nodes in a Bayesian network. Implementations of this + * interface should provide methods to get the probability or count for a node at a given row and column, as well as + * methods to retrieve the number of rows and columns in the map. + *

+ * This interface extends the TetradSerializable interface, indicating that implementations should be serializable and + * follow certain guidelines for compatibility across different versions of Tetrad. + * + * @author josephramsey + * @see CptMapProbs + * @see CptMapCounts + */ +public interface CptMap extends TetradSerializable { + + /** + * Retrieves the value at the specified row and column in the CptMap. + * + * @param row the row index of the value to retrieve. + * @param column the column index of the value to retrieve. + * @return the value at the specified row and column in the CptMap. + */ + double get(int row, int column); + + /** + * Retrieves the number of rows in the CptMap. + * + * @return the number of rows in the CptMap. + */ + int getNumRows(); + + /** + * Retrieves the number of columns in the CptMap. + * + * @return the number of columns in the CptMap. + */ + int getNumColumns(); +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java new file mode 100644 index 0000000000..61f936b2ec --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java @@ -0,0 +1,181 @@ +package edu.cmu.tetrad.bayes; + +import edu.cmu.tetrad.data.DataSet; + +import java.io.Serial; +import java.util.HashMap; +import java.util.Map; + +/** + * Represents a conditional probability table (CPT) in a Bayes net. This represents the CPT as a map from a unique + * integer index for a particular node to the cell count for that node, where 0's are not stored. Row counts are also + * stored, so that the probability of a cell can be calculated. A prior cell count of 0 is assumed for all cells, + * but this may be set by the user to any non-negative count. (A prior count of 0 is equivalent to a maximum + * likelihood estimate.) 
+ * + * @author josephramsey + */ +public class CptMapCounts implements CptMap { + @Serial + private static final long serialVersionUID = 23L; + /** + * Constructs a new count map, a map from a unique integer index for a particular node to the count for that value, + * where 0's are not stored. + */ + private final Map cellCounts = new HashMap<>(); + /** + * Constructs a new row count map, a map from a unique integer index for a particular node to the count for that + * row, where 0's are not stored. + */ + private final Map rowCounts = new HashMap<>(); + /** + * The number of rows in the table. + */ + private final int numRows; + /** + * The number of columns in the table. + */ + private final int numColumns; + /** + * The prior count for all cells. + */ + private int priorCount = 1; + + /** + * Constructs a new probability map, a map from a unique integer index for a particular node to the probability of + * that node taking on that value, where NaN's are not stored. This probability map assumes that there is a certain + * number of rows and a certain number of columns in the table. + * + * @param numRows the number of rows in the table + * @param numColumns the number of columns in the table + */ + public CptMapCounts(int numRows, int numColumns) { + if (numRows < 1 || numColumns < 1) { + throw new IllegalArgumentException("Number of rows and columns must be at least 1."); + } + + this.numRows = numRows; + this.numColumns = numColumns; + } + + /** + * Constructs a new CptMap based on counts from a given dataset. 
+ * + * @param data the DataSet object representing the probability matrix + * @throws IllegalArgumentException if the data set is null or not discrete + */ + public CptMapCounts(DataSet data) { + if (data == null) { + throw new IllegalArgumentException("Probability matrix must have at least one row and one column."); + } + + if (!data.isDiscrete()) { + throw new IllegalArgumentException("Data set must be discrete."); + + } + + numRows = data.getNumRows(); + numColumns = data.getNumColumns(); + + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numColumns; j++) { + int key = i * numColumns + j; + + if (data.getInt(i, j) == -1) { + continue; + } + + if (data.getInt(i, j) == 0) { + continue; + } + + if (!cellCounts.containsKey(key)) { + cellCounts.put(key, 1); + } else { + cellCounts.put(key, cellCounts.get(key) + 1); + } + + if (!rowCounts.containsKey(i)) { + rowCounts.put(i, 1); + } else { + rowCounts.put(i, rowCounts.get(i) + 1); + } + } + } + } + + /** + * Returns the probability of the node taking on the value specified by the given row and column. 
+ * + * @param row the row of the node + * @param column the column of the node + * @return the probability of the node taking on the value specified by the given row and column + */ + @Override + public double get(int row, int column) { + if (row < 0 || row >= numRows || column < 0 || column >= numColumns) { + throw new IllegalArgumentException("Row and column must be within bounds."); + } + + int key = row * numColumns + column; + int rowCount = rowCounts.getOrDefault(row, 0); + int cellCount = cellCounts.getOrDefault(key, 0); + rowCount += priorCount * numColumns; + cellCount += priorCount; + return cellCount / (double) rowCount; + } + + public void addCounts(int row, int column, int count) { + if (row < 0 || row >= numRows || column < 0 || column >= numColumns) { + throw new IllegalArgumentException("Row and column must be within bounds."); + } + + int key = row * numColumns + column; + + if (!cellCounts.containsKey(key)) { + cellCounts.put(key, count); + } else { + cellCounts.put(key, cellCounts.get(key) + count); + } + + if (!rowCounts.containsKey(row)) { + rowCounts.put(row, count); + } else { + rowCounts.put(row, rowCounts.get(row) + count); + } + } + + /** + * Returns the number of rows in the probability map. + * + * @return the number of rows in the probability map. + */ + @Override + public int getNumRows() { + return numRows; + } + + /** + * Returns the number of columns in the probability map. + * + * @return the number of columns in the probability map. + */ + @Override + public int getNumColumns() { + return numColumns; + } + + /** + * The prior count for all cells. + */ + public int getPriorCount() { + return priorCount; + } + + /** + * The prior count for all cells. 
+ */ + public void setPriorCount(int priorCount) { + this.priorCount = priorCount; + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapProbs.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapProbs.java new file mode 100644 index 0000000000..f2ccf95c53 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapProbs.java @@ -0,0 +1,160 @@ +package edu.cmu.tetrad.bayes; + +import edu.cmu.tetrad.util.TetradSerializable; +import edu.cmu.tetrad.util.Vector; + +import java.io.Serial; +import java.util.HashMap; +import java.util.Map; + +/** + * Represents a conditional probability table (CPT) in a Bayes net. This represents the CPT as a map from a unique + * integer index for a particular node to the probability of that node taking on that value, where NaN's are not + * stored. + * + * @author josephramsey + */ +public class CptMapProbs implements CptMap { + @Serial + private static final long serialVersionUID = 23L; + /** + * Constructs a new probability map, a map from a unique integer index for a particular node to the probability of + * that node taking on that value, where NaN's are not stored. + */ + private final Map map = new HashMap<>(); + /** + * The number of rows in the table. + */ + private final int numRows; + /** + * The number of columns in the table. + */ + private final int numColumns; + + /** + * Constructs a new probability map, a map from a unique integer index for a particular node to the probability of + * that node taking on that value, where NaN's are not stored. This probability map assumes that there is a certain + * number of rows and a certain number of columns in the table. 
+ * + * @param numRows the number of rows in the table + * @param numColumns the number of columns in the table + */ + public CptMapProbs(int numRows, int numColumns) { + if (numRows < 1 || numColumns < 1) { + throw new IllegalArgumentException("Number of rows and columns must be at least 1."); + } + + this.numRows = numRows; + this.numColumns = numColumns; + } + + /** + * Constructs a new probability map based on the given 2-dimensional array. + * + * @param probMatrix the 2-dimensional array representing the probability matrix + * @throws IllegalArgumentException if the number of columns in any row is different + */ + public CptMapProbs(double[][] probMatrix) { + if (probMatrix == null || probMatrix.length == 0 || probMatrix[0].length == 0) { + throw new IllegalArgumentException("Probability matrix must have at least one row and one column."); + } + + numRows = probMatrix.length; + numColumns = probMatrix[0].length; + + for (int i = 0; i < numRows; i++) { + if (probMatrix[i].length != numColumns) { + throw new IllegalArgumentException("All rows must have the same number of columns."); + } + } + + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numColumns; j++) { + map.put(i * numColumns + j, probMatrix[i][j]); + } + } + } + + /** + * Returns the probability of the node taking on the value specified by the given row and column. 
+ * + * @param row the row of the node + * @param column the column of the node + * @return the probability of the node taking on the value specified by the given row and column + */ + @Override + public double get(int row, int column) { + if (row < 0 || row >= numRows || column < 0 || column >= numColumns) { + throw new IllegalArgumentException("Row and column must be within bounds."); + } + + int key = row * numColumns + column; + + if (!map.containsKey(key)) { + return Double.NaN; + } + + return map.get(key); + } + + /** + * Sets the probability of the node taking on the value specified by the given row and column to the given value. + * + * @param row the row of the node + * @param column the column of the node + * @param value the probability of the node taking on the value specified by the given row and column (NaN to + * remove the value) + */ + public void set(int row, int column, double value) { + if (row < 0 || row >= numRows || column < 0 || column >= numColumns) { + throw new IllegalArgumentException("Row and column must be within bounds."); + } + + int key = row * numColumns + column; + + if (Double.isNaN(value)) { + map.remove(key); + return; + } + + map.put(key, value); + } + + /** + * Returns the number of rows in the probability map. + * + * @return the number of rows in the probability map. + */ + @Override + public int getNumRows() { + return numRows; + } + + /** + * Returns the number of columns in the probability map. + * + * @return the number of columns in the probability map. + */ + @Override + public int getNumColumns() { + return numColumns; + } + + /** + * Assigns the values in the provided vector to a specific row in the probability map. 
+ * + * @param rowIndex the index of the row to be assigned + * @param vector the vector containing the values to be assigned to the row + * @throws IllegalArgumentException if the size of the vector is not equal to the number of columns in the + * probability map + */ + public void assignRow(int rowIndex, Vector vector) { + if (vector.size() != numColumns) { + throw new IllegalArgumentException("Vector must have the same number of columns as the probability map."); + } + + for (int i = 0; i < numColumns; i++) { + set(rowIndex, i, vector.get(i)); + } + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/EmBayesEstimator.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/EmBayesEstimator.java index ba63797b5b..75636a8fbd 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/EmBayesEstimator.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/EmBayesEstimator.java @@ -490,7 +490,7 @@ private void estimateIM(BayesPm bayesPm, DataSet dataSet) { BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet); // Create a new Bayes IM to store the estimated values. 
- this.estimatedIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + this.estimatedIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); int numNodes = this.estimatedIm.getNumNodes(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/Identifiability.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/Identifiability.java index dfc75e9d5f..0bf7efecd7 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/Identifiability.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/Identifiability.java @@ -441,7 +441,7 @@ public double getJointMarginal(int[] sVariables, int[] sValues) { Dag gD = new Dag(this.bayesIm.getDag().subgraph(dNodes)); BayesPm bayesPmD = new BayesPm(gD, this.bayesIm.getBayesPm()); - BayesIm bayesImD = new MlBayesIm(bayesPmD, this.bayesIm, MlBayesIm.RANDOM); + BayesIm bayesImD = new MlBayesIm(bayesPmD, this.bayesIm, MlBayesIm.InitializationMethod.RANDOM); if (this.debug) { System.out.println("------ bayeIm.getDag() -------------"); @@ -919,7 +919,7 @@ else if (nodesA.containsAll(nodesT)) { // construct an IM with the dag graphA BayesPm bayesPmA = new BayesPm(graphA, this.bayesIm.getBayesPm()); - BayesIm bayesImA = new MlBayesIm(bayesPmA, this.bayesIm, MlBayesIm.RANDOM); + BayesIm bayesImA = new MlBayesIm(bayesPmA, this.bayesIm, MlBayesIm.InitializationMethod.RANDOM); // get c-components of graphA int[] cComponentsA = getCComponents(bayesImA); @@ -967,7 +967,7 @@ else if (nodesA.containsAll(nodesT)) { ///////////////////////////////////////////////////////////////// private BayesIm createdUpdatedBayesIm(BayesPm updatedBayesPm) { - return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.RANDOM); + return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.InitializationMethod.RANDOM); } private BayesPm createUpdatedBayesPm(Dag updatedGraph) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/JunctionTreeUpdater.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/JunctionTreeUpdater.java index 
225861cf40..508e7b3f3a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/JunctionTreeUpdater.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/JunctionTreeUpdater.java @@ -300,7 +300,7 @@ private void updateAll() { } private BayesIm createdUpdatedBayesIm(BayesPm updatedBayesPm) { - return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.MANUAL); + return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.InitializationMethod.MANUAL); } private BayesPm createUpdatedBayesPm(Dag updatedGraph) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimator.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimator.java index a24cd8ed4b..e19d76b85c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimator.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimator.java @@ -22,8 +22,11 @@ package edu.cmu.tetrad.bayes; import edu.cmu.tetrad.data.DataSet; +import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.Node; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; /** @@ -58,75 +61,63 @@ public BayesIm estimate(BayesPm bayesPm, DataSet dataSet) { throw new NullPointerException(); } -// if (DataUtils.containsMissingValue(dataSet)) { -// throw new IllegalArgumentException("Please remove or impute missing values."); -// } - - // Make sure all of the variables in the PM are in the data set; - // otherwise, estimation is impossible. - BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet); - - // Create a new Bayes IM to store the estimated values. - BayesIm estimatedIm = new MlBayesIm(bayesPm); - - // Create a subset of the data set with the variables of the IM, in - // the order of the IM. - List variables = estimatedIm.getVariables(); - DataSet columnDataSet2 = dataSet.subsetColumns(variables); - DiscreteProbs discreteProbs = new DataSetProbs(columnDataSet2); - - // We will use the same estimation methods as the updaters, to ensure - // compatibility. 
- Proposition assertion = Proposition.tautology(estimatedIm); - Proposition condition = Proposition.tautology(estimatedIm); - Evidence evidence2 = Evidence.tautology(estimatedIm); - - int numNodes = estimatedIm.getNumNodes(); - - for (int node = 0; node < numNodes; node++) { - int numRows = estimatedIm.getNumRows(node); - int numCols = estimatedIm.getNumColumns(node); - int[] parents = estimatedIm.getParents(node); - - for (int row = 0; row < numRows; row++) { - int[] parentValues = estimatedIm.getParentValues(node, row); - - for (int col = 0; col < numCols; col++) { - - // Remove values from the proposition in various ways; if - // a combination exists in the end, calculate a conditional - // probability. - assertion.setToTautology(); - condition.setToTautology(); - - for (int i = 0; i < numNodes; i++) { - for (int j = 0; j < evidence2.getNumCategories(i); j++) { - if (!evidence2.getProposition().isAllowed(i, j)) { - condition.removeCategory(i, j); - } - } - } - - assertion.disallowComplement(node, col); - - for (int k = 0; k < parents.length; k++) { - condition.disallowComplement(parents[k], parentValues[k]); - } - - if (condition.existsCombination()) { - double p = discreteProbs.getConditionalProb(assertion, condition); -// if (Double.isNaN(p)) p = 1.0 / numCols; - estimatedIm.setProbability(node, row, col, p); - } else { - estimatedIm.setProbability(node, row, col, Double.NaN); - } + MlBayesIm im = new MlBayesIm(bayesPm, true); + + // Get the nodes from the BayesPm. This fixes the order of the nodes + // in the BayesIm, independently of any change to the BayesPm. + // (This order must be maintained.) + Graph graph = bayesPm.getDag(); + + for (int nodeIndex = 0; nodeIndex < im.getNumNodes(); nodeIndex++) { + Node node = im.getNode(nodeIndex); + + // Set up parents array. Should store the parents of + // each node as ints in a particular order. 
+ List parentList = new ArrayList<>(graph.getParents(node)); + int[] parentArray = new int[parentList.size()]; + + for (int i = 0; i < parentList.size(); i++) { + parentArray[i] = im.getNodeIndex(parentList.get(i)); + } + + // Sort the parent array. + Arrays.sort(parentArray); + + // Setup dimensions array for parents. + int[] dims = new int[parentArray.length]; + + for (int i = 0; i < dims.length; i++) { + Node parNode = im.getNode(parentArray[i]); + dims[i] = bayesPm.getNumCategories(parNode); + } + + // Calculate dimensions of table. + int numRows = 1; + + for (int dim : dims) { + numRows *= dim; + } + + int numCols = bayesPm.getNumCategories(node); + + CptMapCounts counts = new CptMapCounts(numRows, numCols); + + for (int row = 0; row < dataSet.getNumRows(); row++) { + int[] parentValues = new int[parentArray.length]; + + for (int i = 0; i < parentValues.length; i++) { + parentValues[i] = dataSet.getInt(row, parentArray[i]); } + + int value = dataSet.getInt(row, nodeIndex); + + counts.addCounts(im.getRowIndex(nodeIndex, parentValues), value, 1); } - } -// System.out.println(estimatedIm); + im.setCountMap(nodeIndex, counts); + } - return estimatedIm; + return im; } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimatorOld.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimatorOld.java new file mode 100644 index 0000000000..0717b628a2 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimatorOld.java @@ -0,0 +1,136 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. 
// +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. // +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetrad.bayes; + +import edu.cmu.tetrad.data.DataSet; +import edu.cmu.tetrad.graph.Node; + +import java.util.List; + +/** + * Estimates parameters of the given Bayes net from the given data using maximum likelihood method. + * + * @author Shane Harwood, Joseph Ramsey + * @version $Id: $Id + */ +public final class MlBayesEstimatorOld { + + /** + *

Constructor for MlBayesEstimatorOld.

+ */ + public MlBayesEstimatorOld() { + + } + + /** + * Estimates a Bayes IM using the variables, graph, and parameters in the given Bayes PM and the data columns in + * the given data set. Each variable in the given Bayes PM must be equal to a variable in the given data set. + * + * @param bayesPm a {@link BayesPm} object + * @param dataSet a {@link DataSet} object + * @return a {@link BayesIm} object + */ + public BayesIm estimate(BayesPm bayesPm, DataSet dataSet) { + if (bayesPm == null) { + throw new NullPointerException(); + } + + if (dataSet == null) { + throw new NullPointerException(); + } + +// if (DataUtils.containsMissingValue(dataSet)) { +// throw new IllegalArgumentException("Please remove or impute missing values."); +// } + + // Make sure all of the variables in the PM are in the data set; + // otherwise, estimation is impossible. + BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet); + + // Create a new Bayes IM to store the estimated values. + BayesIm estimatedIm = new MlBayesIm(bayesPm); + + // Create a subset of the data set with the variables of the IM, in + // the order of the IM. + List variables = estimatedIm.getVariables(); + DataSet columnDataSet2 = dataSet.subsetColumns(variables); + DiscreteProbs discreteProbs = new DataSetProbs(columnDataSet2); + + // We will use the same estimation methods as the updaters, to ensure + // compatibility. 
+ Proposition assertion = Proposition.tautology(estimatedIm); + Proposition condition = Proposition.tautology(estimatedIm); + Evidence evidence2 = Evidence.tautology(estimatedIm); + + int numNodes = estimatedIm.getNumNodes(); + + for (int node = 0; node < numNodes; node++) { + int numRows = estimatedIm.getNumRows(node); + int numCols = estimatedIm.getNumColumns(node); + int[] parents = estimatedIm.getParents(node); + + for (int row = 0; row < numRows; row++) { + int[] parentValues = estimatedIm.getParentValues(node, row); + + for (int col = 0; col < numCols; col++) { + + // Remove values from the proposition in various ways; if + // a combination exists in the end, calculate a conditional + // probability. + assertion.setToTautology(); + condition.setToTautology(); + + for (int i = 0; i < numNodes; i++) { + for (int j = 0; j < evidence2.getNumCategories(i); j++) { + if (!evidence2.getProposition().isAllowed(i, j)) { + condition.removeCategory(i, j); + } + } + } + + assertion.disallowComplement(node, col); + + for (int k = 0; k < parents.length; k++) { + condition.disallowComplement(parents[k], parentValues[k]); + } + + if (condition.existsCombination()) { + double p = discreteProbs.getConditionalProb(assertion, condition); +// if (Double.isNaN(p)) p = 1.0 / numCols; + estimatedIm.setProbability(node, row, col, p); + } else { + estimatedIm.setProbability(node, row, col, Double.NaN); + } + } + } + } + +// System.out.println(estimatedIm); + + return estimatedIm; + } +} + + + + + diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java index 90514d7f2a..f828e231af 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java @@ -27,7 +27,7 @@ import edu.cmu.tetrad.graph.TimeLagGraph; import edu.cmu.tetrad.util.NumberFormatUtil; import edu.cmu.tetrad.util.RandomUtil; -import 
org.apache.commons.math3.distribution.ChiSquaredDistribution; +import edu.cmu.tetrad.util.Vector; import java.io.IOException; import java.io.ObjectInputStream; @@ -36,12 +36,11 @@ import java.util.*; import static org.apache.commons.math3.util.FastMath.abs; -import static org.apache.commons.math3.util.FastMath.pow; /** * Stores a table of probabilities for a Bayes net and, together with BayesPm and Dag, provides methods to manipulate * this table. The division of labor is as follows. The Dag is responsible for manipulating the basic graphical - * structure of the Bayes net. Dag also stores and manipulates the names of the nodes in the graph; there are no method + * structure of the Bayes net. Dag also stores and manipulates the names of the nodes in the graph; there is no method * in either BayesPm or BayesIm to do this. BayesPm stores and manipulates the *categories* of each node in a DAG, * considered as a variable in a Bayes net. The number of categories for a variable can be changed there as well as the * names for those categories. This class, BayesIm, stores the actual probability tables which are implied by the @@ -50,47 +49,35 @@ * probabilities is organized in this class as a three-dimensional table of double values. The first dimension * corresponds to the nodes in the Bayes net. For each such node, the second dimension corresponds to a flat list of * combinations of parent categories for that node. The third dimension corresponds to the list of categories for that - * node itself. Two methods allow these values to be set and retrieved: - * To determine the index of the node in question, use the method To determine - * the index of the row in question, use the method - * To determine the order of the - * parent values for a given node so that you can build the parentVals[] array, - * use the method To determine the - * index of a category, use the method in BayesPm. 
The rest of the methods in this class are easily understood - * as variants of the methods above. + * node itself. Two methods allow these values to be set and retrieved: getProbability(int nodeIndex, int rowIndex, int + * colIndex); and setProbability(int nodeIndex, int rowIndex, int colIndex, double probability). To determine the index of + * the node in question, use the method getNodeIndex(Node node). To determine the index of the row in question, use the + * method getRowIndex(int nodeIndex, int[] parentVals). To determine the order of the parent values for a given node so that you can + * build the parentVals[] array, use the method getParents(int nodeIndex). To determine the index of a category, use the + * method getCategoryIndex(Node node) in BayesPm. The rest of the methods in this class are easily understood as + * variants of the methods above. *

- * Thanks to Pucktada Treeratpituk, Frank Wimberly, and Willie Wheeler for - * advice and earlier versions. + * This version uses a sparse method for storing the probabilities, where NaNs are not stored. This allows BayesPms with + * many categories per variable to be estimated from small samples without overflowing memory. The old method of storing + * probabilities is kept here for backward compatibility, with an internal code flag to indicate which should be used. + *

+ * Thanks to Pucktada Treeratpituk, Frank Wimberly, and Willie Wheeler for advice and earlier versions. * * @author josephramsey * @version $Id: $Id */ public final class MlBayesIm implements BayesIm { - - /** - * Inidicates that new rows in this BayesIm should be initialized as unknowns, forcing them to be specified - * manually. This is the default. - */ - public static final int MANUAL = 0; - /** - * Indicates that new rows in this BayesIm should be initialized randomly. - */ public static final int RANDOM = 1; @Serial private static final long serialVersionUID = 23L; - /** * Tolerance. */ private static final double ALLOWABLE_DIFFERENCE = 1.0e-3; - /** * Random number generator. */ static private final Random random = new Random(); - /** * The associated Bayes PM model. */ @@ -99,6 +86,7 @@ public final class MlBayesIm implements BayesIm { * The array of nodes from the graph. Order is important. */ private final Node[] nodes; + private CptMapType cptMapType = null; /** * The list of parents for each node from the graph. Order or nodes corresponds to the order of nodes in 'nodes', * and order in subarrays is important. @@ -108,8 +96,6 @@ public final class MlBayesIm implements BayesIm { * The array of dimensionality (number of categories for each node) for each of the subarrays of 'parents'. */ private int[][] parentDims; - - //===============================CONSTRUCTORS=========================// /** * The main data structure; stores the values of all of the conditional probabilities for the Bayes net of the form * P(N=v0 | P1=v1, P2=v2,...). The first dimension is the node N, in the order of 'nodes'. The second dimension is @@ -118,10 +104,16 @@ public final class MlBayesIm implements BayesIm { * for each of the parent values; the order of the values in this array is the same as the order of node in * 'parents'; the value indices are obtained from the Bayes PM for each node. 
The column is the index of the value * of N, where this index is obtained from the Bayes PM. - * - * @serial + *

+ * This is kept here for backward compatibility. The new way of storing the probabilities is in the probMatrices + * array. */ private double[][][] probs; + /** + * The array of CPT maps for each node. The index of the node corresponds to the index of the probability map in + * this array. Replaces the probs array. + */ + private CptMap[] probMatrices; /** * Constructs a new BayesIm from the given BayesPm, initializing all values as Double.NaN ("?"). @@ -131,7 +123,7 @@ public final class MlBayesIm implements BayesIm { * contained in the bayes parametric model provided. */ public MlBayesIm(BayesPm bayesPm) throws IllegalArgumentException { - this(bayesPm, null, MlBayesIm.MANUAL); + this(bayesPm, null, InitializationMethod.MANUAL); } /** @@ -144,16 +136,90 @@ public MlBayesIm(BayesPm bayesPm) throws IllegalArgumentException { * @throws java.lang.IllegalArgumentException if the array of nodes provided is not a permutation of the nodes * contained in the bayes parametric model provided. */ - public MlBayesIm(BayesPm bayesPm, int initializationMethod) + public MlBayesIm(BayesPm bayesPm, InitializationMethod initializationMethod) throws IllegalArgumentException { this(bayesPm, null, initializationMethod); } + /** + * Constructs an instance of MlBayesIm. + * + * @param bayesPm the BayesPm object that represents the Bayesian network. + * @param countsOnly should be set to true for this constructor. + * @throws IllegalArgumentException if countsOnly is false. + * @throws NullPointerException if bayesPm is null. + */ + public MlBayesIm(BayesPm bayesPm, boolean countsOnly) { + if (!countsOnly) { + throw new IllegalArgumentException("countsOnly must be true for this constructor."); + } + + if (bayesPm == null) { + throw new NullPointerException("BayesPm must not be null."); + } + + + + this.bayesPm = new BayesPm(bayesPm); + + // Get the nodes from the BayesPm. This fixes the order of the nodes + // in the BayesIm, independently of any change to the BayesPm. 
+ // (This order must be maintained.) + Graph graph = bayesPm.getDag(); + this.nodes = graph.getNodes().toArray(new Node[0]); + + this.cptMapType = CptMapType.COUNT_MAP; + + this.parents = new int[this.nodes.length][]; + this.parentDims = new int[this.nodes.length][]; + this.probs = new double[this.nodes.length][][]; + this.probMatrices = new CptMapCounts[this.nodes.length]; + + for (int nodeIndex = 0; nodeIndex < this.nodes.length; nodeIndex++) { + Node node = this.nodes[nodeIndex]; + + // Set up parents array. Should store the parents of + // each node as ints in a particular order. + List parentList = new ArrayList<>(graph.getParents(node)); + int[] parentArray = new int[parentList.size()]; + + for (int i = 0; i < parentList.size(); i++) { + parentArray[i] = getNodeIndex(parentList.get(i)); + } + + // Sort parent array. + Arrays.sort(parentArray); + + this.parents[nodeIndex] = parentArray; + + // Setup dimensions array for parents. + int[] dims = new int[parentArray.length]; + + for (int i = 0; i < dims.length; i++) { + Node parNode = this.nodes[parentArray[i]]; + dims[i] = getBayesPm().getNumCategories(parNode); + } + + // Calculate dimensions of table. + int numRows = 1; + + for (int dim : dims) { + numRows *= dim; + } + + int numCols = getBayesPm().getNumCategories(node); + + this.parentDims[nodeIndex] = dims; + this.probs[nodeIndex] = new double[numRows][numCols]; + this.probMatrices[nodeIndex] = new CptMapCounts(numRows, numCols); + } + } + /** * Constructs a new BayesIm from the given BayesPm, initializing values either as MANUAL or RANDOM, but using values * from the old BayesIm provided where posssible. If initialized manually, all values that cannot be retrieved from * oldBayesIm will be set to Double.NaN ("?") in each such row; if initialized randomly, all values that cannot be - * retrieved from oldBayesIm will distributed randomly in each such row. + * retrieved from oldBayesIm will be distributed randomly in each such row. 
* * @param bayesPm the given Bayes PM. Carries with it the underlying graph model. * @param oldBayesIm an already-constructed BayesIm whose values may be used where possible to initialize @@ -163,7 +229,7 @@ public MlBayesIm(BayesPm bayesPm, int initializationMethod) * contained in the bayes parametric model provided. */ public MlBayesIm(BayesPm bayesPm, BayesIm oldBayesIm, - int initializationMethod) throws IllegalArgumentException { + InitializationMethod initializationMethod) throws IllegalArgumentException { if (bayesPm == null) { throw new NullPointerException("BayesPm must not be null."); } @@ -203,7 +269,7 @@ public MlBayesIm(BayesIm bayesIm) throws IllegalArgumentException { } // Copy all the old values over. - initialize(bayesIm, MlBayesIm.MANUAL); + initialize(bayesIm, InitializationMethod.MANUAL); } /** @@ -215,8 +281,6 @@ public static MlBayesIm serializableInstance() { return new MlBayesIm(BayesPm.serializableInstance()); } - //===============================PUBLIC METHODS========================// - /** *

getParameterNames.

* @@ -251,6 +315,8 @@ private static double[] getRandomWeights(int size) { return row; } + //===============================PUBLIC METHODS========================// + /** *

Getter for the field bayesPm.

* @@ -345,11 +411,11 @@ public List getMeasuredNodes() { * @return a {@link java.util.List} object */ public List getVariableNames() { + List nodes = getVariables(); List variableNames = new LinkedList<>(); - for (int i = 0; i < getNumNodes(); i++) { - Node node = getNode(i); - variableNames.add(this.bayesPm.getVariable(node).getName()); + for (Node node : nodes) { + variableNames.add(node.getName()); } return variableNames; @@ -362,7 +428,11 @@ public List getVariableNames() { * @return the number of columns. */ public int getNumColumns(int nodeIndex) { - return this.probs[nodeIndex][0].length; + if (cptMapType != null) { + return probMatrices[nodeIndex].getNumColumns(); + } else { + return this.probs[nodeIndex][0].length; + } } /** @@ -372,7 +442,11 @@ public int getNumColumns(int nodeIndex) { * @return the number of rows in the node. */ public int getNumRows(int nodeIndex) { - return this.probs[nodeIndex].length; + if (cptMapType != null) { + return probMatrices[nodeIndex].getNumRows(); + } else { + return this.probs[nodeIndex].length; + } } /** @@ -474,7 +548,11 @@ public int getParentValue(int nodeIndex, int rowIndex, int colIndex) { * @return the probability value for the given node. 
*/ public double getProbability(int nodeIndex, int rowIndex, int colIndex) { - return this.probs[nodeIndex][rowIndex][colIndex]; + if (cptMapType != null) { + return probMatrices[nodeIndex].get(rowIndex, colIndex); + } else { + return this.probs[nodeIndex][rowIndex][colIndex]; + } } /** @@ -556,8 +634,22 @@ public void normalizeRow(int nodeIndex, int rowIndex) { */ @Override public void setProbability(int nodeIndex, double[][] probMatrix) { - for (int i = 0; i < probMatrix.length; i++) { - System.arraycopy(probMatrix[i], 0, this.probs[nodeIndex][i], 0, probMatrix[i].length); + if (cptMapType == CptMapType.PROB_MAP) { + probMatrices[nodeIndex] = new CptMapProbs(probMatrix); + } else if (cptMapType == CptMapType.COUNT_MAP) { + throw new IllegalArgumentException("Cannot set probability matrix for an estimated Bayes IM."); + } else { + for (int i = 0; i < probMatrix.length; i++) { + System.arraycopy(probMatrix[i], 0, this.probs[nodeIndex][i], 0, probMatrix[i].length); + } + } + } + + public void setCountMap(int nodeIndex, CptMapCounts countMap) { + if (cptMapType == CptMapType.COUNT_MAP) { + probMatrices[nodeIndex] = countMap; + } else { + throw new IllegalArgumentException("Cannot set count map for a non-estimated Bayes IM."); } } @@ -584,7 +676,13 @@ public void setProbability(int nodeIndex, int rowIndex, int colIndex, + "between 0.0 and 1.0 or Double.NaN."); } - this.probs[nodeIndex][rowIndex][colIndex] = value; + if (cptMapType == CptMapType.PROB_MAP) { + ((CptMapProbs) probMatrices[nodeIndex]).set(rowIndex, colIndex, value); + } else if (cptMapType == CptMapType.COUNT_MAP) { + throw new IllegalArgumentException("Cannot set probability value for an estimated Bayes IM."); + } else { + this.probs[nodeIndex][rowIndex][colIndex] = value; + } } /** @@ -620,7 +718,19 @@ public void clearRow(int nodeIndex, int rowIndex) { */ public void randomizeRow(int nodeIndex, int rowIndex) { int size = getNumColumns(nodeIndex); - this.probs[nodeIndex][rowIndex] = 
MlBayesIm.getRandomWeights(size); + double[] row = getRandomWeights(size); + + if (cptMapType == CptMapType.PROB_MAP) { + for (int colIndex = 0; colIndex < size; colIndex++) { + ((CptMapProbs) probMatrices[nodeIndex]).set(rowIndex, colIndex, row[colIndex]); + } + } else if (cptMapType == CptMapType.COUNT_MAP) { + throw new IllegalArgumentException("Cannot randomize row for an estimated Bayes IM."); + } else { + this.probs[nodeIndex][rowIndex] = row; + } + +// this.probs[nodeIndex][rowIndex] = MlBayesIm.getRandomWeights(size); } /** @@ -647,95 +757,6 @@ public void randomizeTable(int nodeIndex) { } } - private int score(int nodeIndex) { - double[][] p = new double[getNumRows(nodeIndex)][getNumColumns(nodeIndex)]; - copy(this.probs[nodeIndex], p); - int num = 0; - - int numRows = getNumRows(nodeIndex); - - for (int r = 0; r < p.length; r++) { - for (int c = 0; c < p[0].length; c++) { - p[r][c] /= numRows; - } - } - - int[] parents = getParents(nodeIndex); - - for (int t = 0; t < parents.length; t++) { - int numParentValues = getParentDim(nodeIndex, t); - int numColumns = getNumColumns(nodeIndex); - - double[][] table = new double[numParentValues][numColumns]; - - for (int childCol = 0; childCol < numColumns; childCol++) { - for (int parentValue = 0; parentValue < numParentValues; parentValue++) { - for (int row = 0; row < numRows; row++) { - if (getParentValues(nodeIndex, row)[t] == parentValue) { - table[parentValue][childCol] += p[row][childCol]; - } - } - } - } - - final double N = 1000.0; - - for (int r = 0; r < table.length; r++) { - for (int c = 0; c < table[0].length; c++) { - table[r][c] *= N; - } - } - - double chisq = 0.0; - - for (int r = 0; r < table.length; r++) { - for (int c = 0; c < table[0].length; c++) { - double _sumRow = sumRow(table, r); - double _sumCol = sumCol(table, c); - double exp = (_sumRow / N) * (_sumCol / N) * N; - double obs = table[r][c]; - chisq += pow(obs - exp, 2) / exp; - } - } - - int dof = (table.length - 1) * (table[0].length - 
1); - - ChiSquaredDistribution distribution = new ChiSquaredDistribution(dof); - double prob = 1 - distribution.cumulativeProbability(chisq); - - num += prob < 0.0001 ? 1 : 0; - } - -// return num == parents.length ? -score : 0; - return num; - } - - private double sumCol(double[][] marginals, int j) { - double sum = 0.0; - - for (double[] marginal : marginals) { - sum += marginal[j]; - } - - return sum; - } - - private double sumRow(double[][] marginals, int i) { - double sum = 0.0; - - for (int h = 0; h < marginals[i].length; h++) { - sum += marginals[i][h]; - } - - return sum; - } - - private void copy(double[][] a, double[][] b) { - for (int r = 0; r < a.length; r++) { - System.arraycopy(a[r], 0, b[r], 0, a[r].length); - } - } - /** * Clears the table by clearing all rows for the given node. * @@ -1016,75 +1037,6 @@ private DataSet simulateDataHelper(DataSet dataSet, boolean latentDataSaved, int private void constructSample(int sampleSize, DataSet dataSet, int[] map, int[] tiers) { -// //Do the simulation. 
-// class SimulationTask extends RecursiveTask { -// private int chunk; -// private int from; -// private int to; -// private int[] tiers; -// private DataSet dataSet; -// private int[] map; -// -// public SimulationTask(int chunk, int from, int to, int[] tiers, DataSet dataSet, int[] map) { -// this.chunk = chunk; -// this.from = from; -// this.to = to; -// this.tiers = tiers; -// this.dataSet = dataSet; -// this.map = map; -// } -// -// @Override -// protected Boolean compute() { -// if (to - from <= chunk) { -// RandomGenerator randomGenerator = new Well1024a(++seed[0]); -// -// for (int row = from; row < to; row++) { -// for (int t : tiers) { -// int[] parentValues = new int[parents[t].length]; -// -// for (int k = 0; k < parentValues.length; k++) { -// parentValues[k] = dataSet.getInt(row, parents[t][k]); -// } -// -// int rowIndex = getRowIndex(t, parentValues); -// double sum = 0.0; -// double r; -// -// r = randomGenerator.nextDouble(); -// -// for (int k = 0; k < getNumColumns(t); k++) { -// double probability = getProbability(t, rowIndex, k); -// sum += probability; -// -// if (sum >= r) { -// dataSet.setInt(row, map[t], k); -// break; -// } -// } -// } -// } -// -// return true; -// } else { -// int mid = (to + from) / 2; -// SimulationTask left = new SimulationTask(chunk, from, mid, tiers, dataSet, map); -// SimulationTask right = new SimulationTask(chunk, mid, to, tiers, dataSet, map); -// -// left.fork(); -// right.compute(); -// left.join(); -// -// return true; -// } -// } -// } -// -// int chunk = 25; -// -// ForkJoinPool pool = ForkJoinPoolInstance.getInstance().getPool(); -// SimulationTask task = new SimulationTask(chunk, 0, sampleSize, tiers, dataSet, map); -// pool.invoke(task); // Construct the sample. 
for (int i = 0; i < sampleSize; i++) { for (int t : tiers) { @@ -1167,8 +1119,6 @@ public boolean equals(Object o) { return true; } - //=============================PRIVATE METHODS=======================// - /** * Prints out the probability table for each variable. * @@ -1221,21 +1171,30 @@ public String toString() { * @see #initializeNode * @see #randomizeRow */ - private void initialize(BayesIm oldBayesIm, int initializationMethod) { + private void initialize(BayesIm oldBayesIm, InitializationMethod initializationMethod) { this.parents = new int[this.nodes.length][]; this.parentDims = new int[this.nodes.length][]; this.probs = new double[this.nodes.length][][]; + this.probMatrices = new CptMapProbs[this.nodes.length]; + + this.cptMapType = CptMapType.PROB_MAP; for (int nodeIndex = 0; nodeIndex < this.nodes.length; nodeIndex++) { initializeNode(nodeIndex, oldBayesIm, initializationMethod); } } + //=============================PRIVATE METHODS=======================// + /** * This method initializes the node indicated. */ private void initializeNode(int nodeIndex, BayesIm oldBayesIm, - int initializationMethod) { + InitializationMethod initializationMethod) { + if (cptMapType != CptMapType.PROB_MAP) { + throw new IllegalArgumentException("Cannot initialize a Bayes IM randomly without a probability map."); + } + Node node = this.nodes[nodeIndex]; // Set up parents array. 
Should store the parents of @@ -1265,15 +1224,6 @@ private void initializeNode(int nodeIndex, BayesIm oldBayesIm, int numRows = 1; for (int dim : dims) { - if (numRows > 1000000 /* Integer.MAX_VALUE / dim*/) { - throw new IllegalArgumentException( - "The number of rows in the " - + "conditional probability table for " - + this.nodes[nodeIndex] - + " is greater than 1,000,000 and cannot be " - + "represented."); - } - numRows *= dim; } @@ -1281,9 +1231,10 @@ private void initializeNode(int nodeIndex, BayesIm oldBayesIm, this.parentDims[nodeIndex] = dims; this.probs[nodeIndex] = new double[numRows][numCols]; + this.probMatrices[nodeIndex] = new CptMapProbs(numRows, numCols); // Initialize each row. - if (initializationMethod == MlBayesIm.RANDOM) { + if (initializationMethod == InitializationMethod.RANDOM) { randomizeTable(nodeIndex); } else { for (int rowIndex = 0; rowIndex < numRows; rowIndex++) { @@ -1298,10 +1249,10 @@ private void initializeNode(int nodeIndex, BayesIm oldBayesIm, } private void overwriteRow(int nodeIndex, int rowIndex, - int initializationMethod) { - if (initializationMethod == MlBayesIm.RANDOM) { + InitializationMethod initializationMethod) { + if (initializationMethod == InitializationMethod.RANDOM) { randomizeRow(nodeIndex, rowIndex); - } else if (initializationMethod == MlBayesIm.MANUAL) { + } else if (initializationMethod == InitializationMethod.MANUAL) { initializeRowAsUnknowns(nodeIndex, rowIndex); } else { throw new IllegalArgumentException("Unrecognized state."); @@ -1312,14 +1263,21 @@ private void initializeRowAsUnknowns(int nodeIndex, int rowIndex) { int size = getNumColumns(nodeIndex); double[] row = new double[size]; Arrays.fill(row, Double.NaN); - this.probs[nodeIndex][rowIndex] = row; + + if (cptMapType == CptMapType.PROB_MAP) { + ((CptMapProbs) probMatrices[nodeIndex]).assignRow(rowIndex, new Vector(row)); + } else if (cptMapType == CptMapType.COUNT_MAP) { + throw new IllegalArgumentException("Cannot initialize a row as unknowns 
in an estimated Bayes IM."); + } else { + this.probs[nodeIndex][rowIndex] = row; + } } /** * This method initializes the node indicated. */ private void retainOldRowIfPossible(int nodeIndex, int rowIndex, - BayesIm oldBayesIm, int initializationMethod) { + BayesIm oldBayesIm, InitializationMethod initializationMethod) { int oldNodeIndex = getCorrespondingNodeIndex(nodeIndex, oldBayesIm); @@ -1396,42 +1354,6 @@ private int getUniqueCompatibleOldRow(int nodeIndex, int rowIndex, } } -// // Go through each parent of the node in the new BayesIm. -// for (int i = 0; i < oldBayesIm.getNumParents(oldNodeIndex); i++) { -// -// // Get the index of the parent in the new graph and in the old -// // graph. If it's no longer in the new graph, skip to the next -// // parent. -// int oldParentNodeIndex = oldBayesIm.getParent(oldNodeIndex, i); -// int parentNodeIndex = -// oldBayesIm.getCorrespondingNodeIndex(oldParentNodeIndex, this); -// int parentIndex = -1; -// -// for (int j = 0; j < this.getNumParents(nodeIndex); j++) { -// if (parentNodeIndex == this.getParent(nodeIndex, j)) { -// parentIndex = j; -// break; -// } -// } -// -// if (parentIndex == -1 || -// parentIndex >= this.getNumParents(nodeIndex)) { -// continue; -// } -// -// // Look up that value index for the new BayesIm for that parent. -// // If it was a valid value index in the old BayesIm, record -// // that value in oldParentValues. Otherwise return -1. -// int parentValue = oldParentValues[i]; -// int parentDim = -// this.getParentDim(nodeIndex, parentIndex); -// -// if (parentValue < parentDim) { -// oldParentValues[parentIndex] = oldParentValue; -// } else { -// return -1; -// } -// } // If there are any -1's in the combination at this point, return -1. 
for (int oldParentValue : oldParentValues) { if (oldParentValue == -1) { @@ -1492,8 +1414,42 @@ private void readObject(ObjectInputStream s) throw new NullPointerException(); } - if (this.probs == null) { - throw new NullPointerException(); + copyDataToProbMatrices(); + } + + /** + * Copies data from the `probs` array to the `probMatrices` array. If the lengths of both arrays are equal, the + * `probMatrices` array is initialized with `ProbMap` objects, each containing the corresponding `probs` element. + * The `probs` array is then set to null and the `useProbMatrices` flag is set to true. + *

+ * Note: This method should only be called after the `probs` array has been properly initialized. + */ + private void copyDataToProbMatrices() { + if (cptMapType == null && this.probs != null && this.probs.length == this.nodes.length) { + this.probMatrices = new CptMapProbs[this.probs.length]; + + for (int i = 0; i < this.nodes.length; i++) { + probMatrices[i] = new CptMapProbs(this.probs[i]); + } + + this.probs = null; + this.cptMapType = CptMapType.PROB_MAP; } } + + /** + * Returns the type of CPT map this Bayes IM uses: PROB_MAP or COUNT_MAP, or null if the probs array is used instead. + * The CptMap is the new way of storing the probabilities; the probs array is kept here for backward compatibility. + */ + public CptMapType getCptMapType() { + return cptMapType; + } + + public enum CptMapType { + PROB_MAP, COUNT_MAP + } + + public enum InitializationMethod { + MANUAL, RANDOM + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/RowSummingExactUpdater.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/RowSummingExactUpdater.java index bd330396bf..35e0872e22 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/RowSummingExactUpdater.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/RowSummingExactUpdater.java @@ -346,7 +346,7 @@ private void updateAll() { private BayesIm createdUpdatedBayesIm(BayesPm updatedBayesPm) { // Switching this to MANUAL since the initial values don't matter. 
- return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.MANUAL); + return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.InitializationMethod.MANUAL); } private BayesPm createUpdatedBayesPm(Dag updatedGraph) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/RandomGraph.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/RandomGraph.java index c0a8a5c9be..7e8916e1fa 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/RandomGraph.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/RandomGraph.java @@ -228,7 +228,7 @@ public static Graph randomGraphRandomForwardEdges(List nodes, int numLaten * negative or exceeds the number of nodes */ public static Graph randomGraphRandomForwardEdges(List nodes, int numLatentConfounders, int numEdges, int maxDegree, int maxIndegree, int maxOutdegree, boolean connected, boolean layoutAsCircle) { - if (nodes.size() == 0) { + if (nodes.isEmpty()) { throw new IllegalArgumentException("NumNodes most be > 0"); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRobustCompare.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRobustCompare.java index 7ff2a15a0a..e1d94a9839 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRobustCompare.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRobustCompare.java @@ -65,7 +65,7 @@ public static List run(int numVars, double edgesPerNode, int numCases, Graph odag = RandomGraph.randomGraphRandomForwardEdges(vars, 0, numEdges, 30, 15, 15, false, true); BayesPm bayesPm = new BayesPm(odag, 2, 2); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); //oData is the original data set, and odag is the original dag. 
DataSet oData = bayesIm.simulateData(numCases, false); //System.out.println(oData); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison.java index 1e638dcb6a..28fb2e581d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison.java @@ -112,7 +112,7 @@ public static ComparisonResult compare(ComparisonParameters params) { } BayesPm pm = new BayesPm(trueDag, 3, 3); - MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM); + MlBayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM); dataSet = im.simulateData(params.getSampleSize(), false, tiers); } else { throw new IllegalArgumentException("Unrecognized data type."); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison2.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison2.java index 938c38e0ce..aa054ecdef 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison2.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison2.java @@ -303,7 +303,7 @@ public static ComparisonResult compare(ComparisonParameters params) { } BayesPm pm = new BayesPm(trueDag, 3, 3); - MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM); + MlBayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM); dataSet = im.simulateData(params.getSampleSize(), false, tiers); } else { throw new IllegalArgumentException("Unrecognized data type."); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/PerformanceTests.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/PerformanceTests.java index c831818858..66d4e42b74 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/PerformanceTests.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/PerformanceTests.java @@ -1026,7 
+1026,7 @@ private void testFges(int numVars, double edgeFactor, int numCases, int numRuns, } else { BayesPm pm = new BayesPm(dag, 3, 3); - MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM); + MlBayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM); DataSet data = im.simulateData(numCases, false, tiers); @@ -1258,7 +1258,7 @@ private void testFgesMb(int numVars, double edgeFactor, int numCases, int numRun this.out.println(new Date()); BayesPm pm = new BayesPm(dag, 3, 3); - MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM); + MlBayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM); DataSet data = im.simulateData(numCases, false, tiers); diff --git a/tetrad-lib/src/main/java/edu/pitt/isp/sverchkov/data/AdTreeTest.java b/tetrad-lib/src/main/java/edu/pitt/isp/sverchkov/data/AdTreeTest.java index 6588f2ccf9..06e0341818 100644 --- a/tetrad-lib/src/main/java/edu/pitt/isp/sverchkov/data/AdTreeTest.java +++ b/tetrad-lib/src/main/java/edu/pitt/isp/sverchkov/data/AdTreeTest.java @@ -72,7 +72,7 @@ public static void main(String[] args) throws Exception { Graph graph = RandomGraph.randomGraphRandomForwardEdges(variables, 0, numEdges, 30, 15, 15, false, true); BayesPm pm = new BayesPm(graph); - BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM); + BayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM); DataSet data = im.simulateData(rows, false); // This implementation uses a DataTable to represent the data diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesDiscreteBicScorer.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesDiscreteBicScorer.java index a24ce17cbd..b0cde7abc8 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesDiscreteBicScorer.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesDiscreteBicScorer.java @@ -53,7 +53,7 @@ public void testPValue() { final int numCategories = 8; BayesPm pm = new BayesPm(graph, numCategories, numCategories); - BayesIm im = new 
MlBayesIm(pm, MlBayesIm.RANDOM); + BayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM); DataSet data = im.simulateData(1000, false); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesIm.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesIm.java index 88f949f05d..6439600537 100755 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesIm.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesIm.java @@ -66,7 +66,7 @@ public void testCopyConstructor() { Graph graph = GraphUtils.convert("X1-->X2,X1-->X3,X2-->X4,X3-->X4"); Dag dag = new Dag(graph); BayesPm bayesPm = new BayesPm(dag); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); BayesIm bayesIm2 = new MlBayesIm(bayesIm); assertEquals(bayesIm, bayesIm2); } @@ -102,9 +102,9 @@ public void testAddRemoveParent() { dag.addDirectedEdge(a, b); BayesPm bayesPm = new BayesPm(dag); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); - BayesIm bayesIm2 = new MlBayesIm(bayesPm, bayesIm, MlBayesIm.MANUAL); + BayesIm bayesIm2 = new MlBayesIm(bayesPm, bayesIm, MlBayesIm.InitializationMethod.MANUAL); assertEquals(bayesIm, bayesIm2); @@ -113,7 +113,7 @@ public void testAddRemoveParent() { dag.addDirectedEdge(c, b); BayesPm bayesPm3 = new BayesPm(dag, bayesPm); - BayesIm bayesIm3 = new MlBayesIm(bayesPm3, bayesIm2, MlBayesIm.MANUAL); + BayesIm bayesIm3 = new MlBayesIm(bayesPm3, bayesIm2, MlBayesIm.InitializationMethod.MANUAL); // Make sure the rows got repeated. 
// assertTrue(rowsEqual(bayesIm3, bayesIm3.getNodeIndex(b), 0, 1)); @@ -125,7 +125,7 @@ public void testAddRemoveParent() { dag.removeNode(c); BayesPm bayesPm4 = new BayesPm(dag, bayesPm3); - BayesIm bayesIm4 = new MlBayesIm(bayesPm4, bayesIm3, MlBayesIm.MANUAL); + BayesIm bayesIm4 = new MlBayesIm(bayesPm4, bayesIm3, MlBayesIm.InitializationMethod.MANUAL); // Make sure the 'b' node has 2 rows of '?'s'. assertEquals(2, bayesIm4.getNumRows(bayesIm4.getNodeIndex(b))); @@ -157,17 +157,17 @@ public void testAddRemoveValues() { assertTrue(Edges.isDirectedEdge(dag.getEdge(a, b))); BayesPm bayesPm = new BayesPm(dag, 3, 3); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); bayesPm.setNumCategories(a, 4); bayesPm.setNumCategories(c, 4); - BayesIm bayesIm2 = new MlBayesIm(bayesPm, bayesIm, MlBayesIm.MANUAL); + BayesIm bayesIm2 = new MlBayesIm(bayesPm, bayesIm, MlBayesIm.InitializationMethod.MANUAL); bayesPm.setNumCategories(a, 2); - BayesIm bayesIm3 = new MlBayesIm(bayesPm, bayesIm2, MlBayesIm.MANUAL); + BayesIm bayesIm3 = new MlBayesIm(bayesPm, bayesIm2, MlBayesIm.InitializationMethod.MANUAL); bayesPm.setNumCategories(b, 2); - BayesIm bayesIm4 = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm4 = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); // At this point, a has 2 categories, b has 2 categories, and c has 4 categories. 
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesXml.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesXml.java index aa1e746a92..979f547792 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesXml.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesXml.java @@ -112,7 +112,7 @@ private static BayesIm sampleBayesIm2() { BayesPm bayesPm = new BayesPm(graph); bayesPm.setNumCategories(b, 3); - return new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + return new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); } private static BayesIm sampleBayesIm3() { @@ -138,7 +138,7 @@ private static BayesIm sampleBayesIm3() { BayesPm bayesPm = new BayesPm(graph); bayesPm.setNumCategories(b, 3); - return new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + return new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); } /** diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCellProbabilities.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCellProbabilities.java index 66aa240bae..afebcb340e 100755 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCellProbabilities.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCellProbabilities.java @@ -73,7 +73,7 @@ public void testCreateUsingBayesIm() { Graph graph = GraphUtils.convert("X1-->X2,X1-->X3,X2-->X4,X3-->X4"); Dag dag = new Dag(graph); BayesPm bayesPm = new BayesPm(dag); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); StoredCellProbs cellProbs = StoredCellProbs.createCellTable(bayesIm); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCptInvariantUpdater.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCptInvariantUpdater.java index 19d0b64e12..fde6221b59 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCptInvariantUpdater.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCptInvariantUpdater.java @@ -176,7 
+176,7 @@ public void testUpdate4() { graph.addDirectedEdge(x2Node, x3Node); BayesPm bayesPm = new BayesPm(graph, 2, 2); - MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); int x2 = bayesIm.getNodeIndex(x2Node); int x3 = bayesIm.getNodeIndex(x3Node); @@ -220,7 +220,7 @@ public void testUpdate5() { graph.addDirectedEdge(x4Node, x2Node); BayesPm bayesPm = new BayesPm(graph); - MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); int x1 = bayesIm.getNodeIndex(x1Node); int x2 = bayesIm.getNodeIndex(x2Node); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDataSetCellProbs.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDataSetCellProbs.java index 3542df3d69..c4ec3b3ddc 100755 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDataSetCellProbs.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDataSetCellProbs.java @@ -41,7 +41,7 @@ public void testCreateUsingBayesIm() { Graph graph = GraphUtils.convert("X1-->X2,X1-->X3,X2-->X4,X3-->X4"); Dag dag = new Dag(graph); BayesPm bayesPm = new BayesPm(dag); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); BayesImProbs bayesImProbs = new BayesImProbs(bayesIm); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDiscreteProbs.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDiscreteProbs.java index 5b470d4272..e160b8d476 100755 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDiscreteProbs.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDiscreteProbs.java @@ -79,7 +79,7 @@ public void testCreateUsingBayesIm() { Graph graph = GraphUtils.convert("X1-->X2,X1-->X3,X2-->X4,X3-->X4"); Dag dag = new Dag(graph); BayesPm bayesPm = new BayesPm(dag); - BayesIm bayesIm = new 
MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); StoredCellProbs cellProbs = StoredCellProbs.createCellTable(bayesIm); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java index e3e6d25416..5730004833 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java @@ -227,7 +227,7 @@ public void explore2() { 30, 15, 15, false, true); BayesPm pm = new BayesPm(dag, 2, 3); - BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM); + BayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM); DataSet data = im.simulateData(numCases, false); BdeScore score = new BdeScore(data); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java index 445afbc6c3..4653051cec 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java @@ -247,7 +247,7 @@ public void testRandomDiscreteData() { Graph g = GraphUtils.convert("X1-->X2,X1-->X3,X1-->X4,X2-->X3,X2-->X4,X3-->X4"); Dag dag = new Dag(g); BayesPm bayesPm = new BayesPm(dag); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); DataSet data = bayesIm.simulateData(sampleSize, false); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestHistogram.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestHistogram.java index ee3403d4a5..f36f1ed4a1 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestHistogram.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestHistogram.java @@ -90,7 +90,7 @@ public void testHistogram() { // Discrete BayesPm bayesPm = new BayesPm(trueGraph); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm 
bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); DataSet data2 = bayesIm.simulateData(sampleSize, false); // For some reason these are giving different diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRfciBsc.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRfciBsc.java index 70ba06d08a..7b10b446cc 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRfciBsc.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRfciBsc.java @@ -81,7 +81,7 @@ public void testRandomDiscreteData() { // set a number of latent variables BayesPm bayesPm = new BayesPm(dag); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); RandomUtil.getInstance().setSeed(seed); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRowSummingUpdater.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRowSummingUpdater.java index fd9ec8294f..ac2a354859 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRowSummingUpdater.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRowSummingUpdater.java @@ -197,7 +197,7 @@ public void testUpdate4() { graph.addDirectedEdge(x2Node, x3Node); BayesPm bayesPm = new BayesPm(graph); - MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); int x2 = bayesIm.getNodeIndex(x2Node); int x3 = bayesIm.getNodeIndex(x3Node); @@ -240,7 +240,7 @@ public void testUpdate5() { graph.addDirectedEdge(x4Node, x2Node); BayesPm bayesPm = new BayesPm(graph); - MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); int x1 = bayesIm.getNodeIndex(x1Node); int x2 = bayesIm.getNodeIndex(x2Node); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestUpdatedBayesIm.java 
b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestUpdatedBayesIm.java index 74a221e161..dbda137037 100755 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestUpdatedBayesIm.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestUpdatedBayesIm.java @@ -59,7 +59,7 @@ public void testCompound() { graph.addDirectedEdge(x4Node, x2Node); BayesPm bayesPm = new BayesPm(graph); - MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); UpdatedBayesIm updatedIm1 = new UpdatedBayesIm(bayesIm); assertEquals(bayesIm, updatedIm1);