diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/BayesImNodeEditingTable.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/BayesImNodeEditingTable.java
index 2a6f684960..bfe8d1dbd8 100755
--- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/BayesImNodeEditingTable.java
+++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/BayesImNodeEditingTable.java
@@ -23,6 +23,7 @@
import edu.cmu.tetrad.bayes.BayesIm;
import edu.cmu.tetrad.bayes.BayesPm;
+import edu.cmu.tetrad.bayes.MlBayesIm;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.JOptionUtils;
import edu.cmu.tetrad.util.NumberFormatUtil;
@@ -536,8 +537,9 @@ public Object getValueAt(int tableRow, int tableCol) {
int colIndex = tableCol - parentVals.length;
if (colIndex < getBayesIm().getNumColumns(getNodeIndex())) {
- return getBayesIm().getProbability(getNodeIndex(), tableRow,
+ double probability = getBayesIm().getProbability(getNodeIndex(), tableRow,
colIndex);
+ return probability;
}
return "null";
@@ -555,10 +557,14 @@ public boolean isCellEditable(int row, int col) {
* Sets the value of the cell at (row, col) to 'aValue'.
*/
public void setValueAt(Object aValue, int row, int col) {
+ if (getBayesIm().getCptMapType() == MlBayesIm.CptMapType.COUNT_MAP) {
+ return;
+ }
+
int numParents = getBayesIm().getNumParents(getNodeIndex());
int colIndex = col - numParents;
- if ("".equals(aValue) || aValue == null) {
+ if (getBayesIm().getCptMapType() == MlBayesIm.CptMapType.PROB_MAP && ("".equals(aValue) || aValue == null)) {
getBayesIm().setProbability(getNodeIndex(), row, colIndex,
Double.NaN);
fireTableRowsUpdated(row, row);
@@ -752,6 +758,8 @@ public void resetFailedRow() {
public void resetFailedCol() {
this.failedCol = -1;
}
+
+
}
}
diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomGraphEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomGraphEditor.java
index c4e5e55146..da15ea8743 100644
--- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomGraphEditor.java
+++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomGraphEditor.java
@@ -481,7 +481,7 @@ public void setEnabled(boolean enabled) {
* @return a boolean
*/
public boolean isRandomForward() {
- return this.parameters.getBoolean("graphRandomFoward", true);
+ return this.parameters.getBoolean("graphRandomForward", true);
}
private void setRandomForward(boolean randomFoward) {
diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesImWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesImWrapper.java
index 08aa40b5ee..f8f0d8b793 100644
--- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesImWrapper.java
+++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesImWrapper.java
@@ -36,6 +36,7 @@
import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
+import java.util.function.IntBinaryOperator;
/**
* Wraps a Bayes Pm for use in the Tetrad application.
@@ -90,11 +91,11 @@ public BayesImWrapper(BayesPmWrapper bayesPmWrapper, BayesImWrapper oldBayesImwr
BayesIm oldBayesIm = oldBayesImwrapper.getBayesIm();
if (params.getString("initializationMode", "manualRetain").equals("manualRetain")) {
- setBayesIm(bayesPm, oldBayesIm, MlBayesIm.MANUAL);
+ setBayesIm(bayesPm, oldBayesIm, MlBayesIm.InitializationMethod.MANUAL);
} else if (params.getString("initializationMode", "manualRetain").equals("randomRetain")) {
- setBayesIm(bayesPm, oldBayesIm, MlBayesIm.RANDOM);
+ setBayesIm(bayesPm, oldBayesIm, MlBayesIm.InitializationMethod.RANDOM);
} else if (params.getString("initializationMode", "manualRetain").equals("randomOverwrite")) {
- setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.RANDOM));
+ setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM));
}
}
@@ -193,9 +194,9 @@ public BayesImWrapper(BayesPmWrapper bayesPmWrapper, Parameters params) {
if (params.getString("initializationMode", "manualRetain").equals("manualRetain")) {
setBayesIm(new MlBayesIm(bayesPm));
} else if (params.getString("initializationMode", "manualRetain").equals("randomRetain")) {
- setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.RANDOM));
+ setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM));
} else if (params.getString("initializationMode", "manualRetain").equals("randomOverwrite")) {
- setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.RANDOM));
+ setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM));
}
}
@@ -342,9 +343,9 @@ public String getModelSourceName() {
//============================== private methods ============================//
- private void setBayesIm(BayesPm bayesPm, BayesIm oldBayesIm, int manual) {
+ private void setBayesIm(BayesPm bayesPm, BayesIm oldBayesIm, MlBayesIm.InitializationMethod initializationMethod) {
this.bayesIms = new ArrayList<>();
- this.bayesIms.add(new MlBayesIm(bayesPm, oldBayesIm, manual));
+ this.bayesIms.add(new MlBayesIm(bayesPm, oldBayesIm, initializationMethod));
}
/**
diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesPmWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesPmWrapper.java
index 2c7ee8d4b3..6511a3bcc3 100644
--- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesPmWrapper.java
+++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesPmWrapper.java
@@ -206,8 +206,8 @@ public BayesPmWrapper(GraphWrapper graphWrapper, Parameters params) {
lowerBound = upperBound = 3;
setBayesPm(graph, lowerBound, upperBound);
} else if (params.getString("bayesPmInitializationMode", "range").equals("range")) {
- lowerBound = params.getInt("minCategories", 2);
- upperBound = params.getInt("maxCategories", 4);
+ lowerBound = params.getInt("lowerBoundNumVals", 2);
+ upperBound = params.getInt("upperBoundNumVals", 4);
setBayesPm(graph, lowerBound, upperBound);
} else {
throw new IllegalStateException("Unrecognized type.");
diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/session/SessionNode.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/session/SessionNode.java
index 42681673ae..4ccd0fc806 100644
--- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/session/SessionNode.java
+++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/session/SessionNode.java
@@ -1173,6 +1173,13 @@ public int compareTo(Node node) {
* @return a boolean
*/
public boolean existsParameterizedConstructor(Class modelClass) {
+ if (modelClass == null) {
+
+ // If the model class is null, then there is no constructor, so display a dialog to the users by
+ // return false here.
+ return false;
+ }
+
Object param = getParam(modelClass);
List parentModels = listParentModels();
parentModels.add(param);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java
index f973cf76de..2a0718b19f 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java
@@ -238,9 +238,9 @@ private DataSet simulate(Graph graph, Parameters parameters) {
int minCategories = parameters.getInt(Params.MIN_CATEGORIES);
int maxCategories = parameters.getInt(Params.MAX_CATEGORIES);
pm = new BayesPm(graph, minCategories, maxCategories);
- im = new MlBayesIm(pm, MlBayesIm.RANDOM);
+ im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM);
} else {
- im = new MlBayesIm(pm, MlBayesIm.RANDOM);
+ im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM);
this.im = im;
}
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java
index 0af14dfd20..676915f1c8 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java
@@ -320,7 +320,7 @@ private DataSet simulate(Graph G, Parameters parameters) {
}
BayesPm bayesPm = new BayesPm(AG);
- BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
SemPm semPm = new SemPm(XG);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/Statistic.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/Statistic.java
index 147b4c8d73..248d9ad374 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/Statistic.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/Statistic.java
@@ -43,7 +43,7 @@ public interface Statistic extends Serializable {
/**
* Returns a mapping of the statistic to the interval [0, 1], with higher being better. This is used for a
- * calculation of a utility for an algorithm.If the statistic is already between 0 and 1, you can just return the
+ * calculation of a utility for an algorithm. If the statistic is already between 0 and 1, you can just return the
* statistic.
*
* @param value The value of the statistic.
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ApproximateUpdater.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ApproximateUpdater.java
index 5b69d7c030..375cec1846 100755
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ApproximateUpdater.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ApproximateUpdater.java
@@ -344,7 +344,7 @@ private void doUpdate() {
}
private BayesIm createdUpdatedBayesIm(BayesPm updatedBayesPm) {
- return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.RANDOM);
+ return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.InitializationMethod.RANDOM);
}
private BayesPm createUpdatedBayesPm(Dag updatedGraph) {
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesIm.java
index 7b1f1f7b62..a768c73798 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesIm.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesIm.java
@@ -393,4 +393,8 @@ public interface BayesIm extends VariableSource, Im, Simulator {
* @return a string representation for this Bayes net.
*/
String toString();
+
+ default MlBayesIm.CptMapType getCptMapType() {
+ return MlBayesIm.CptMapType.PROB_MAP;
+ }
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptInvariantUpdater.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptInvariantUpdater.java
index c3f9183d1a..312a6ba9d6 100755
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptInvariantUpdater.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptInvariantUpdater.java
@@ -266,7 +266,7 @@ public String toString() {
//==============================PRIVATE METHODS=======================//
private BayesIm createdManipulatedBayesIm(BayesPm updatedBayesPm) {
- return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.MANUAL);
+ return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.InitializationMethod.MANUAL);
}
private BayesPm createManipulatedBayesPm(Dag updatedGraph) {
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java
new file mode 100644
index 0000000000..d8d271c5e2
--- /dev/null
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java
@@ -0,0 +1,41 @@
+package edu.cmu.tetrad.bayes;
+
+import edu.cmu.tetrad.util.TetradSerializable;
+
+/**
+ * An interface representing a map of probabilities or counts for nodes in a Bayesian network. Implementations of this
+ * interface should provide methods to get the probability or count for a node at a given row and column, as well as
+ * methods to retrieve the number of rows and columns in the map.
+ *
+ * This interface extends the TetradSerializable interface, indicating that implementations should be serializable and
+ * follow certain guidelines for compatibility across different versions of Tetrad.
+ *
+ * @author josephramsey
+ * @see CptMapProbs
+ * @see CptMapCounts
+ */
+public interface CptMap extends TetradSerializable {
+
+ /**
+ * Retrieves the value at the specified row and column in the CptMap.
+ *
+ * @param row the row index of the value to retrieve.
+ * @param column the column index of the value to retrieve.
+ * @return the value at the specified row and column in the CptMap.
+ */
+ double get(int row, int column);
+
+ /**
+ * Retrieves the number of rows in the CptMap.
+ *
+ * @return the number of rows in the CptMap.
+ */
+ int getNumRows();
+
+ /**
+ * Retrieves the number of columns in the CptMap.
+ *
+ * @return the number of columns in the CptMap.
+ */
+ int getNumColumns();
+}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java
new file mode 100644
index 0000000000..61f936b2ec
--- /dev/null
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java
@@ -0,0 +1,181 @@
+package edu.cmu.tetrad.bayes;
+
+import edu.cmu.tetrad.data.DataSet;
+
+import java.io.Serial;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Represents a conditional probability table (CPT) in a Bayes net. This represents the CPT as a map from a unique
+ * integer index for a particular node to the cell count for that node, where 0's are not stored. Row counts are also
+ * stored, so that the probability of a cell can be calculated. A prior cell count of 0 is assumed for all cells,
+ * but this may be set by the user to any non-negative count. (A prior count of 0 is equivalent to a maximum
+ * likelihood estimate.)
+ *
+ * @author josephramsey
+ */
+public class CptMapCounts implements CptMap {
+ @Serial
+ private static final long serialVersionUID = 23L;
+ /**
+ * Constructs a new count map, a map from a unique integer index for a particular node to the count for that value,
+ * where 0's are not stored.
+ */
+ private final Map cellCounts = new HashMap<>();
+ /**
+ * Constructs a new row count map, a map from a unique integer index for a particular node to the count for that
+ * row, where 0's are not stored.
+ */
+ private final Map rowCounts = new HashMap<>();
+ /**
+ * The number of rows in the table.
+ */
+ private final int numRows;
+ /**
+ * The number of columns in the table.
+ */
+ private final int numColumns;
+ /**
+ * The prior count for all cells.
+ */
+ private int priorCount = 1;
+
+ /**
+ * Constructs a new probability map, a map from a unique integer index for a particular node to the probability of
+ * that node taking on that value, where NaN's are not stored. This probability map assumes that there is a certain
+ * number of rows and a certain number of columns in the table.
+ *
+ * @param numRows the number of rows in the table
+ * @param numColumns the number of columns in the table
+ */
+ public CptMapCounts(int numRows, int numColumns) {
+ if (numRows < 1 || numColumns < 1) {
+ throw new IllegalArgumentException("Number of rows and columns must be at least 1.");
+ }
+
+ this.numRows = numRows;
+ this.numColumns = numColumns;
+ }
+
+ /**
+ * Constructs a new CptMap based on counts from a given dataset.
+ *
+ * @param data the DataSet object representing the probability matrix
+ * @throws IllegalArgumentException if the data set is null or not discrete
+ */
+ public CptMapCounts(DataSet data) {
+ if (data == null) {
+ throw new IllegalArgumentException("Probability matrix must have at least one row and one column.");
+ }
+
+ if (!data.isDiscrete()) {
+ throw new IllegalArgumentException("Data set must be discrete.");
+
+ }
+
+ numRows = data.getNumRows();
+ numColumns = data.getNumColumns();
+
+ for (int i = 0; i < numRows; i++) {
+ for (int j = 0; j < numColumns; j++) {
+ int key = i * numColumns + j;
+
+ if (data.getInt(i, j) == -1) {
+ continue;
+ }
+
+ if (data.getInt(i, j) == 0) {
+ continue;
+ }
+
+ if (!cellCounts.containsKey(key)) {
+ cellCounts.put(key, 1);
+ } else {
+ cellCounts.put(key, cellCounts.get(key) + 1);
+ }
+
+ if (!rowCounts.containsKey(i)) {
+ rowCounts.put(i, 1);
+ } else {
+ rowCounts.put(i, rowCounts.get(i) + 1);
+ }
+ }
+ }
+ }
+
+ /**
+ * Returns the probability of the node taking on the value specified by the given row and column.
+ *
+ * @param row the row of the node
+ * @param column the column of the node
+ * @return the probability of the node taking on the value specified by the given row and column
+ */
+ @Override
+ public double get(int row, int column) {
+ if (row < 0 || row >= numRows || column < 0 || column >= numColumns) {
+ throw new IllegalArgumentException("Row and column must be within bounds.");
+ }
+
+ int key = row * numColumns + column;
+ int rowCount = rowCounts.getOrDefault(row, 0);
+ int cellCount = cellCounts.getOrDefault(key, 0);
+ rowCount += priorCount * numColumns;
+ cellCount += priorCount;
+ return cellCount / (double) rowCount;
+ }
+
+ public void addCounts(int row, int column, int count) {
+ if (row < 0 || row >= numRows || column < 0 || column >= numColumns) {
+ throw new IllegalArgumentException("Row and column must be within bounds.");
+ }
+
+ int key = row * numColumns + column;
+
+ if (!cellCounts.containsKey(key)) {
+ cellCounts.put(key, count);
+ } else {
+ cellCounts.put(key, cellCounts.get(key) + count);
+ }
+
+ if (!rowCounts.containsKey(row)) {
+ rowCounts.put(row, count);
+ } else {
+ rowCounts.put(row, rowCounts.get(row) + count);
+ }
+ }
+
+ /**
+ * Returns the number of rows in the probability map.
+ *
+ * @return the number of rows in the probability map.
+ */
+ @Override
+ public int getNumRows() {
+ return numRows;
+ }
+
+ /**
+ * Returns the number of columns in the probability map.
+ *
+ * @return the number of columns in the probability map.
+ */
+ @Override
+ public int getNumColumns() {
+ return numColumns;
+ }
+
+ /**
+ * The prior count for all cells.
+ */
+ public int getPriorCount() {
+ return priorCount;
+ }
+
+ /**
+ * The prior count for all cells.
+ */
+ public void setPriorCount(int priorCount) {
+ this.priorCount = priorCount;
+ }
+}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapProbs.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapProbs.java
new file mode 100644
index 0000000000..f2ccf95c53
--- /dev/null
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapProbs.java
@@ -0,0 +1,160 @@
+package edu.cmu.tetrad.bayes;
+
+import edu.cmu.tetrad.util.TetradSerializable;
+import edu.cmu.tetrad.util.Vector;
+
+import java.io.Serial;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Represents a conditional probability table (CPT) in a Bayes net. This represents the CPT as a map from a unique
+ * integer index for a particular node to the probability of that node taking on that value, where NaN's are not
+ * stored.
+ *
+ * @author josephramsey
+ */
+public class CptMapProbs implements CptMap {
+ @Serial
+ private static final long serialVersionUID = 23L;
+ /**
+ * Constructs a new probability map, a map from a unique integer index for a particular node to the probability of
+ * that node taking on that value, where NaN's are not stored.
+ */
+ private final Map map = new HashMap<>();
+ /**
+ * The number of rows in the table.
+ */
+ private final int numRows;
+ /**
+ * The number of columns in the table.
+ */
+ private final int numColumns;
+
+ /**
+ * Constructs a new probability map, a map from a unique integer index for a particular node to the probability of
+ * that node taking on that value, where NaN's are not stored. This probability map assumes that there is a certain
+ * number of rows and a certain number of columns in the table.
+ *
+ * @param numRows the number of rows in the table
+ * @param numColumns the number of columns in the table
+ */
+ public CptMapProbs(int numRows, int numColumns) {
+ if (numRows < 1 || numColumns < 1) {
+ throw new IllegalArgumentException("Number of rows and columns must be at least 1.");
+ }
+
+ this.numRows = numRows;
+ this.numColumns = numColumns;
+ }
+
+ /**
+ * Constructs a new probability map based on the given 2-dimensional array.
+ *
+ * @param probMatrix the 2-dimensional array representing the probability matrix
+ * @throws IllegalArgumentException if the number of columns in any row is different
+ */
+ public CptMapProbs(double[][] probMatrix) {
+ if (probMatrix == null || probMatrix.length == 0 || probMatrix[0].length == 0) {
+ throw new IllegalArgumentException("Probability matrix must have at least one row and one column.");
+ }
+
+ numRows = probMatrix.length;
+ numColumns = probMatrix[0].length;
+
+ for (int i = 0; i < numRows; i++) {
+ if (probMatrix[i].length != numColumns) {
+ throw new IllegalArgumentException("All rows must have the same number of columns.");
+ }
+ }
+
+ for (int i = 0; i < numRows; i++) {
+ for (int j = 0; j < numColumns; j++) {
+ map.put(i * numColumns + j, probMatrix[i][j]);
+ }
+ }
+ }
+
+ /**
+ * Returns the probability of the node taking on the value specified by the given row and column.
+ *
+ * @param row the row of the node
+ * @param column the column of the node
+ * @return the probability of the node taking on the value specified by the given row and column
+ */
+ @Override
+ public double get(int row, int column) {
+ if (row < 0 || row >= numRows || column < 0 || column >= numColumns) {
+ throw new IllegalArgumentException("Row and column must be within bounds.");
+ }
+
+ int key = row * numColumns + column;
+
+ if (!map.containsKey(key)) {
+ return Double.NaN;
+ }
+
+ return map.get(key);
+ }
+
+ /**
+ * Sets the probability of the node taking on the value specified by the given row and column to the given value.
+ *
+ * @param row the row of the node
+ * @param column the column of the node
+ * @param value the probability of the node taking on the value specified by the given row and column (NaN to
+ * remove the value)
+ */
+ public void set(int row, int column, double value) {
+ if (row < 0 || row >= numRows || column < 0 || column >= numColumns) {
+ throw new IllegalArgumentException("Row and column must be within bounds.");
+ }
+
+ int key = row * numColumns + column;
+
+ if (Double.isNaN(value)) {
+ map.remove(key);
+ return;
+ }
+
+ map.put(key, value);
+ }
+
+ /**
+ * Returns the number of rows in the probability map.
+ *
+ * @return the number of rows in the probability map.
+ */
+ @Override
+ public int getNumRows() {
+ return numRows;
+ }
+
+ /**
+ * Returns the number of columns in the probability map.
+ *
+ * @return the number of columns in the probability map.
+ */
+ @Override
+ public int getNumColumns() {
+ return numColumns;
+ }
+
+ /**
+ * Assigns the values in the provided vector to a specific row in the probability map.
+ *
+ * @param rowIndex the index of the row to be assigned
+ * @param vector the vector containing the values to be assigned to the row
+ * @throws IllegalArgumentException if the size of the vector is not equal to the number of columns in the
+ * probability map
+ */
+ public void assignRow(int rowIndex, Vector vector) {
+ if (vector.size() != numColumns) {
+ throw new IllegalArgumentException("Vector must have the same number of columns as the probability map.");
+ }
+
+ for (int i = 0; i < numColumns; i++) {
+ set(rowIndex, i, vector.get(i));
+ }
+ }
+}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/EmBayesEstimator.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/EmBayesEstimator.java
index ba63797b5b..75636a8fbd 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/EmBayesEstimator.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/EmBayesEstimator.java
@@ -490,7 +490,7 @@ private void estimateIM(BayesPm bayesPm, DataSet dataSet) {
BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet);
// Create a new Bayes IM to store the estimated values.
- this.estimatedIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ this.estimatedIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
int numNodes = this.estimatedIm.getNumNodes();
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/Identifiability.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/Identifiability.java
index dfc75e9d5f..0bf7efecd7 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/Identifiability.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/Identifiability.java
@@ -441,7 +441,7 @@ public double getJointMarginal(int[] sVariables, int[] sValues) {
Dag gD = new Dag(this.bayesIm.getDag().subgraph(dNodes));
BayesPm bayesPmD = new BayesPm(gD, this.bayesIm.getBayesPm());
- BayesIm bayesImD = new MlBayesIm(bayesPmD, this.bayesIm, MlBayesIm.RANDOM);
+ BayesIm bayesImD = new MlBayesIm(bayesPmD, this.bayesIm, MlBayesIm.InitializationMethod.RANDOM);
if (this.debug) {
System.out.println("------ bayeIm.getDag() -------------");
@@ -919,7 +919,7 @@ else if (nodesA.containsAll(nodesT)) {
// construct an IM with the dag graphA
BayesPm bayesPmA = new BayesPm(graphA, this.bayesIm.getBayesPm());
- BayesIm bayesImA = new MlBayesIm(bayesPmA, this.bayesIm, MlBayesIm.RANDOM);
+ BayesIm bayesImA = new MlBayesIm(bayesPmA, this.bayesIm, MlBayesIm.InitializationMethod.RANDOM);
// get c-components of graphA
int[] cComponentsA = getCComponents(bayesImA);
@@ -967,7 +967,7 @@ else if (nodesA.containsAll(nodesT)) {
/////////////////////////////////////////////////////////////////
private BayesIm createdUpdatedBayesIm(BayesPm updatedBayesPm) {
- return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.RANDOM);
+ return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.InitializationMethod.RANDOM);
}
private BayesPm createUpdatedBayesPm(Dag updatedGraph) {
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/JunctionTreeUpdater.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/JunctionTreeUpdater.java
index 225861cf40..508e7b3f3a 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/JunctionTreeUpdater.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/JunctionTreeUpdater.java
@@ -300,7 +300,7 @@ private void updateAll() {
}
private BayesIm createdUpdatedBayesIm(BayesPm updatedBayesPm) {
- return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.MANUAL);
+ return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.InitializationMethod.MANUAL);
}
private BayesPm createUpdatedBayesPm(Dag updatedGraph) {
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimator.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimator.java
index a24cd8ed4b..e19d76b85c 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimator.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimator.java
@@ -22,8 +22,11 @@
package edu.cmu.tetrad.bayes;
import edu.cmu.tetrad.data.DataSet;
+import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
+import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
/**
@@ -58,75 +61,63 @@ public BayesIm estimate(BayesPm bayesPm, DataSet dataSet) {
throw new NullPointerException();
}
-// if (DataUtils.containsMissingValue(dataSet)) {
-// throw new IllegalArgumentException("Please remove or impute missing values.");
-// }
-
- // Make sure all of the variables in the PM are in the data set;
- // otherwise, estimation is impossible.
- BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet);
-
- // Create a new Bayes IM to store the estimated values.
- BayesIm estimatedIm = new MlBayesIm(bayesPm);
-
- // Create a subset of the data set with the variables of the IM, in
- // the order of the IM.
- List variables = estimatedIm.getVariables();
- DataSet columnDataSet2 = dataSet.subsetColumns(variables);
- DiscreteProbs discreteProbs = new DataSetProbs(columnDataSet2);
-
- // We will use the same estimation methods as the updaters, to ensure
- // compatibility.
- Proposition assertion = Proposition.tautology(estimatedIm);
- Proposition condition = Proposition.tautology(estimatedIm);
- Evidence evidence2 = Evidence.tautology(estimatedIm);
-
- int numNodes = estimatedIm.getNumNodes();
-
- for (int node = 0; node < numNodes; node++) {
- int numRows = estimatedIm.getNumRows(node);
- int numCols = estimatedIm.getNumColumns(node);
- int[] parents = estimatedIm.getParents(node);
-
- for (int row = 0; row < numRows; row++) {
- int[] parentValues = estimatedIm.getParentValues(node, row);
-
- for (int col = 0; col < numCols; col++) {
-
- // Remove values from the proposition in various ways; if
- // a combination exists in the end, calculate a conditional
- // probability.
- assertion.setToTautology();
- condition.setToTautology();
-
- for (int i = 0; i < numNodes; i++) {
- for (int j = 0; j < evidence2.getNumCategories(i); j++) {
- if (!evidence2.getProposition().isAllowed(i, j)) {
- condition.removeCategory(i, j);
- }
- }
- }
-
- assertion.disallowComplement(node, col);
-
- for (int k = 0; k < parents.length; k++) {
- condition.disallowComplement(parents[k], parentValues[k]);
- }
-
- if (condition.existsCombination()) {
- double p = discreteProbs.getConditionalProb(assertion, condition);
-// if (Double.isNaN(p)) p = 1.0 / numCols;
- estimatedIm.setProbability(node, row, col, p);
- } else {
- estimatedIm.setProbability(node, row, col, Double.NaN);
- }
+ MlBayesIm im = new MlBayesIm(bayesPm, true);
+
+ // Get the nodes from the BayesPm. This fixes the order of the nodes
+ // in the BayesIm, independently of any change to the BayesPm.
+ // (This order must be maintained.)
+ Graph graph = bayesPm.getDag();
+
+ for (int nodeIndex = 0; nodeIndex < im.getNumNodes(); nodeIndex++) {
+ Node node = im.getNode(nodeIndex);
+
+ // Set up parents array. Should store the parents of
+ // each node as ints in a particular order.
+ List parentList = new ArrayList<>(graph.getParents(node));
+ int[] parentArray = new int[parentList.size()];
+
+ for (int i = 0; i < parentList.size(); i++) {
+ parentArray[i] = im.getNodeIndex(parentList.get(i));
+ }
+
+ // Sort the parent array.
+ Arrays.sort(parentArray);
+
+ // Setup dimensions array for parents.
+ int[] dims = new int[parentArray.length];
+
+ for (int i = 0; i < dims.length; i++) {
+ Node parNode = im.getNode(parentArray[i]);
+ dims[i] = bayesPm.getNumCategories(parNode);
+ }
+
+ // Calculate dimensions of table.
+ int numRows = 1;
+
+ for (int dim : dims) {
+ numRows *= dim;
+ }
+
+ int numCols = bayesPm.getNumCategories(node);
+
+ CptMapCounts counts = new CptMapCounts(numRows, numCols);
+
+ for (int row = 0; row < dataSet.getNumRows(); row++) {
+ int[] parentValues = new int[parentArray.length];
+
+ for (int i = 0; i < parentValues.length; i++) {
+ parentValues[i] = dataSet.getInt(row, parentArray[i]);
}
+
+ int value = dataSet.getInt(row, nodeIndex);
+
+ counts.addCounts(im.getRowIndex(nodeIndex, parentValues), value, 1);
}
- }
-// System.out.println(estimatedIm);
+ im.setCountMap(nodeIndex, counts);
+ }
- return estimatedIm;
+ return im;
}
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimatorOld.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimatorOld.java
new file mode 100644
index 0000000000..0717b628a2
--- /dev/null
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimatorOld.java
@@ -0,0 +1,136 @@
+///////////////////////////////////////////////////////////////////////////////
+// For information as to what this class does, see the Javadoc, below. //
+// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, //
+// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard //
+// Scheines, Joseph Ramsey, and Clark Glymour. //
+// //
+// This program is free software; you can redistribute it and/or modify //
+// it under the terms of the GNU General Public License as published by //
+// the Free Software Foundation; either version 2 of the License, or //
+// (at your option) any later version. //
+// //
+// This program is distributed in the hope that it will be useful, //
+// but WITHOUT ANY WARRANTY; without even the implied warranty of //
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //
+// GNU General Public License for more details. //
+// //
+// You should have received a copy of the GNU General Public License //
+// along with this program; if not, write to the Free Software //
+// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA //
+///////////////////////////////////////////////////////////////////////////////
+
+package edu.cmu.tetrad.bayes;
+
+import edu.cmu.tetrad.data.DataSet;
+import edu.cmu.tetrad.graph.Node;
+
+import java.util.List;
+
+/**
+ * Estimates parameters of the given Bayes net from the given data using maximum likelihood method.
+ *
+ * @author Shane Harwood, Joseph Ramsey
+ * @version $Id: $Id
+ */
+public final class MlBayesEstimatorOld {
+
+ /**
+ * Constructor for MlBayesEstimator.
+ */
+ public MlBayesEstimatorOld() {
+
+ }
+
+ /**
+ * 33 Estimates a Bayes IM using the variables, graph, and parameters in the given Bayes PM and the data columns in
+ * the given data set. Each variable in the given Bayes PM must be equal to a variable in the given data set.
+ *
+ * @param bayesPm a {@link BayesPm} object
+ * @param dataSet a {@link DataSet} object
+ * @return a {@link BayesIm} object
+ */
+ public BayesIm estimate(BayesPm bayesPm, DataSet dataSet) {
+ if (bayesPm == null) {
+ throw new NullPointerException();
+ }
+
+ if (dataSet == null) {
+ throw new NullPointerException();
+ }
+
+// if (DataUtils.containsMissingValue(dataSet)) {
+// throw new IllegalArgumentException("Please remove or impute missing values.");
+// }
+
+ // Make sure all of the variables in the PM are in the data set;
+ // otherwise, estimation is impossible.
+ BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet);
+
+ // Create a new Bayes IM to store the estimated values.
+ BayesIm estimatedIm = new MlBayesIm(bayesPm);
+
+ // Create a subset of the data set with the variables of the IM, in
+ // the order of the IM.
+ List variables = estimatedIm.getVariables();
+ DataSet columnDataSet2 = dataSet.subsetColumns(variables);
+ DiscreteProbs discreteProbs = new DataSetProbs(columnDataSet2);
+
+ // We will use the same estimation methods as the updaters, to ensure
+ // compatibility.
+ Proposition assertion = Proposition.tautology(estimatedIm);
+ Proposition condition = Proposition.tautology(estimatedIm);
+ Evidence evidence2 = Evidence.tautology(estimatedIm);
+
+ int numNodes = estimatedIm.getNumNodes();
+
+ for (int node = 0; node < numNodes; node++) {
+ int numRows = estimatedIm.getNumRows(node);
+ int numCols = estimatedIm.getNumColumns(node);
+ int[] parents = estimatedIm.getParents(node);
+
+ for (int row = 0; row < numRows; row++) {
+ int[] parentValues = estimatedIm.getParentValues(node, row);
+
+ for (int col = 0; col < numCols; col++) {
+
+ // Remove values from the proposition in various ways; if
+ // a combination exists in the end, calculate a conditional
+ // probability.
+ assertion.setToTautology();
+ condition.setToTautology();
+
+ for (int i = 0; i < numNodes; i++) {
+ for (int j = 0; j < evidence2.getNumCategories(i); j++) {
+ if (!evidence2.getProposition().isAllowed(i, j)) {
+ condition.removeCategory(i, j);
+ }
+ }
+ }
+
+ assertion.disallowComplement(node, col);
+
+ for (int k = 0; k < parents.length; k++) {
+ condition.disallowComplement(parents[k], parentValues[k]);
+ }
+
+ if (condition.existsCombination()) {
+ double p = discreteProbs.getConditionalProb(assertion, condition);
+// if (Double.isNaN(p)) p = 1.0 / numCols;
+ estimatedIm.setProbability(node, row, col, p);
+ } else {
+ estimatedIm.setProbability(node, row, col, Double.NaN);
+ }
+ }
+ }
+ }
+
+// System.out.println(estimatedIm);
+
+ return estimatedIm;
+ }
+}
+
+
+
+
+
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java
index 90514d7f2a..f828e231af 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java
@@ -27,7 +27,7 @@
import edu.cmu.tetrad.graph.TimeLagGraph;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.RandomUtil;
-import org.apache.commons.math3.distribution.ChiSquaredDistribution;
+import edu.cmu.tetrad.util.Vector;
import java.io.IOException;
import java.io.ObjectInputStream;
@@ -36,12 +36,11 @@
import java.util.*;
import static org.apache.commons.math3.util.FastMath.abs;
-import static org.apache.commons.math3.util.FastMath.pow;
/**
* Stores a table of probabilities for a Bayes net and, together with BayesPm and Dag, provides methods to manipulate
* this table. The division of labor is as follows. The Dag is responsible for manipulating the basic graphical
- * structure of the Bayes net. Dag also stores and manipulates the names of the nodes in the graph; there are no method
+ * structure of the Bayes net. Dag also stores and manipulates the names of the nodes in the graph; there is no method
* in either BayesPm or BayesIm to do this. BayesPm stores and manipulates the *categories* of each node in a DAG,
* considered as a variable in a Bayes net. The number of categories for a variable can be changed there as well as the
* names for those categories. This class, BayesIm, stores the actual probability tables which are implied by the
@@ -50,47 +49,35 @@
* probabilities is organized in this class as a three-dimensional table of double values. The first dimension
* corresponds to the nodes in the Bayes net. For each such node, the second dimension corresponds to a flat list of
* combinations of parent categories for that node. The third dimension corresponds to the list of categories for that
- * node itself. Two methods allow these values to be set and retrieved: - getWordRatio(int nodeIndex, int
- * rowIndex, int colIndex); and,
- setProbability(int nodeIndex, int rowIndex, int colIndex, int probability).
- * To determine the index of the node in question, use the method To determine
- * the index of the row in question, use the method
- * - getRowIndex(int[] parentVals).
To determine the order of the
- * parent values for a given node so that you can build the parentVals[] array,
- * use the method - getParents(int nodeIndex)
To determine the
- * index of a category, use the method - getCategoryIndex(Node node)
- *
in BayesPm. The rest of the methods in this class are easily understood
- * as variants of the methods above.
+ * node itself. Two methods allow these values to be set and retrieved: getWordRatio(int nodeIndex, int rowIndex, int
+ * colIndex); and setProbability(int nodeIndex, int rowIndex, int colIndex, int probability). To determine the index of
+ * the node in question, use the method getNodeIndex(Node node). To determine the index of the row in question, use the
+ * method getRowIndex(int[] parentVals). To determine the order of the parent values for a given node so that you can
+ * build the parentVals[] array, use the method getParents(int nodeIndex). To determine the index of a category, use the
+ * method getCategoryIndex(Node node) in BayesPm. The rest of the methods in this class are easily understood as
+ * variants of the methods above.
*
- * Thanks to Pucktada Treeratpituk, Frank Wimberly, and Willie Wheeler for
- * advice and earlier versions.
+ * This version uses a sparse method for storing the probabilities, where NaNs are not stored. This allows BayesPms with
+ * many categories per variable to be estimated from small samples without overflowing memory. The old method of storing
+ * probabilities is kept here for backward compatibility, with an internal code flag to indicate which should be used.
+ *
+ * Thanks to Pucktada Treeratpituk, Frank Wimberly, and Willie Wheeler for advice and earlier versions.
*
* @author josephramsey
* @version $Id: $Id
*/
public final class MlBayesIm implements BayesIm {
-
- /**
- * Inidicates that new rows in this BayesIm should be initialized as unknowns, forcing them to be specified
- * manually. This is the default.
- */
- public static final int MANUAL = 0;
- /**
- * Indicates that new rows in this BayesIm should be initialized randomly.
- */
public static final int RANDOM = 1;
@Serial
private static final long serialVersionUID = 23L;
-
/**
* Tolerance.
*/
private static final double ALLOWABLE_DIFFERENCE = 1.0e-3;
-
/**
* Random number generator.
*/
static private final Random random = new Random();
-
/**
* The associated Bayes PM model.
*/
@@ -99,6 +86,7 @@ public final class MlBayesIm implements BayesIm {
* The array of nodes from the graph. Order is important.
*/
private final Node[] nodes;
+ private CptMapType cptMapType = null;
/**
* The list of parents for each node from the graph. Order or nodes corresponds to the order of nodes in 'nodes',
* and order in subarrays is important.
@@ -108,8 +96,6 @@ public final class MlBayesIm implements BayesIm {
* The array of dimensionality (number of categories for each node) for each of the subarrays of 'parents'.
*/
private int[][] parentDims;
-
- //===============================CONSTRUCTORS=========================//
/**
* The main data structure; stores the values of all of the conditional probabilities for the Bayes net of the form
* P(N=v0 | P1=v1, P2=v2,...). The first dimension is the node N, in the order of 'nodes'. The second dimension is
@@ -118,10 +104,16 @@ public final class MlBayesIm implements BayesIm {
* for each of the parent values; the order of the values in this array is the same as the order of node in
* 'parents'; the value indices are obtained from the Bayes PM for each node. The column is the index of the value
* of N, where this index is obtained from the Bayes PM.
- *
- * @serial
+ *
+ * This is kept here for backward compatibility. The new way of storing the probabilities is in the probMatrices
+ * array.
*/
private double[][][] probs;
+ /**
+ * The array of CPT maps for each node. The index of the node corresponds to the index of the probability map in
+ * this array. Replaces the probs array.
+ */
+ private CptMap[] probMatrices;
/**
* Constructs a new BayesIm from the given BayesPm, initializing all values as Double.NaN ("?").
@@ -131,7 +123,7 @@ public final class MlBayesIm implements BayesIm {
* contained in the bayes parametric model provided.
*/
public MlBayesIm(BayesPm bayesPm) throws IllegalArgumentException {
- this(bayesPm, null, MlBayesIm.MANUAL);
+ this(bayesPm, null, InitializationMethod.MANUAL);
}
/**
@@ -144,16 +136,90 @@ public MlBayesIm(BayesPm bayesPm) throws IllegalArgumentException {
* @throws java.lang.IllegalArgumentException if the array of nodes provided is not a permutation of the nodes
* contained in the bayes parametric model provided.
*/
- public MlBayesIm(BayesPm bayesPm, int initializationMethod)
+ public MlBayesIm(BayesPm bayesPm, InitializationMethod initializationMethod)
throws IllegalArgumentException {
this(bayesPm, null, initializationMethod);
}
+ /**
+ * Constructs an instance of MlBayesIm.
+ *
+ * @param bayesPm the BayesPm object that represents the Bayesian network.
+ * @param countsOnly should be set to true for this constructor.
+ * @throws IllegalArgumentException if countsOnly is false.
+ * @throws NullPointerException if bayesPm is null.
+ */
+ public MlBayesIm(BayesPm bayesPm, boolean countsOnly) {
+ if (!countsOnly) {
+ throw new IllegalArgumentException("countsOnly must be true for this constructor.");
+ }
+
+ if (bayesPm == null) {
+ throw new NullPointerException("BayesPm must not be null.");
+ }
+
+
+
+ this.bayesPm = new BayesPm(bayesPm);
+
+ // Get the nodes from the BayesPm. This fixes the order of the nodes
+ // in the BayesIm, independently of any change to the BayesPm.
+ // (This order must be maintained.)
+ Graph graph = bayesPm.getDag();
+ this.nodes = graph.getNodes().toArray(new Node[0]);
+
+ this.cptMapType = CptMapType.COUNT_MAP;
+
+ this.parents = new int[this.nodes.length][];
+ this.parentDims = new int[this.nodes.length][];
+ this.probs = new double[this.nodes.length][][];
+ this.probMatrices = new CptMapCounts[this.nodes.length];
+
+ for (int nodeIndex = 0; nodeIndex < this.nodes.length; nodeIndex++) {
+ Node node = this.nodes[nodeIndex];
+
+ // Set up parents array. Should store the parents of
+ // each node as ints in a particular order.
+ List parentList = new ArrayList<>(graph.getParents(node));
+ int[] parentArray = new int[parentList.size()];
+
+ for (int i = 0; i < parentList.size(); i++) {
+ parentArray[i] = getNodeIndex(parentList.get(i));
+ }
+
+ // Sort parent array.
+ Arrays.sort(parentArray);
+
+ this.parents[nodeIndex] = parentArray;
+
+ // Setup dimensions array for parents.
+ int[] dims = new int[parentArray.length];
+
+ for (int i = 0; i < dims.length; i++) {
+ Node parNode = this.nodes[parentArray[i]];
+ dims[i] = getBayesPm().getNumCategories(parNode);
+ }
+
+ // Calculate dimensions of table.
+ int numRows = 1;
+
+ for (int dim : dims) {
+ numRows *= dim;
+ }
+
+ int numCols = getBayesPm().getNumCategories(node);
+
+ this.parentDims[nodeIndex] = dims;
+ this.probs[nodeIndex] = new double[numRows][numCols];
+ this.probMatrices[nodeIndex] = new CptMapCounts(numRows, numCols);
+ }
+ }
+
/**
* Constructs a new BayesIm from the given BayesPm, initializing values either as MANUAL or RANDOM, but using values
* from the old BayesIm provided where posssible. If initialized manually, all values that cannot be retrieved from
* oldBayesIm will be set to Double.NaN ("?") in each such row; if initialized randomly, all values that cannot be
- * retrieved from oldBayesIm will distributed randomly in each such row.
+ * retrieved from oldBayesIm will be distributed randomly in each such row.
*
* @param bayesPm the given Bayes PM. Carries with it the underlying graph model.
* @param oldBayesIm an already-constructed BayesIm whose values may be used where possible to initialize
@@ -163,7 +229,7 @@ public MlBayesIm(BayesPm bayesPm, int initializationMethod)
* contained in the bayes parametric model provided.
*/
public MlBayesIm(BayesPm bayesPm, BayesIm oldBayesIm,
- int initializationMethod) throws IllegalArgumentException {
+ InitializationMethod initializationMethod) throws IllegalArgumentException {
if (bayesPm == null) {
throw new NullPointerException("BayesPm must not be null.");
}
@@ -203,7 +269,7 @@ public MlBayesIm(BayesIm bayesIm) throws IllegalArgumentException {
}
// Copy all the old values over.
- initialize(bayesIm, MlBayesIm.MANUAL);
+ initialize(bayesIm, InitializationMethod.MANUAL);
}
/**
@@ -215,8 +281,6 @@ public static MlBayesIm serializableInstance() {
return new MlBayesIm(BayesPm.serializableInstance());
}
- //===============================PUBLIC METHODS========================//
-
/**
* getParameterNames.
*
@@ -251,6 +315,8 @@ private static double[] getRandomWeights(int size) {
return row;
}
+ //===============================PUBLIC METHODS========================//
+
/**
* Getter for the field bayesPm
.
*
@@ -345,11 +411,11 @@ public List getMeasuredNodes() {
* @return a {@link java.util.List} object
*/
public List getVariableNames() {
+ List nodes = getVariables();
List variableNames = new LinkedList<>();
- for (int i = 0; i < getNumNodes(); i++) {
- Node node = getNode(i);
- variableNames.add(this.bayesPm.getVariable(node).getName());
+ for (Node node : nodes) {
+ variableNames.add(node.getName());
}
return variableNames;
@@ -362,7 +428,11 @@ public List getVariableNames() {
* @return the number of columns.
*/
public int getNumColumns(int nodeIndex) {
- return this.probs[nodeIndex][0].length;
+ if (cptMapType != null) {
+ return probMatrices[nodeIndex].getNumColumns();
+ } else {
+ return this.probs[nodeIndex][0].length;
+ }
}
/**
@@ -372,7 +442,11 @@ public int getNumColumns(int nodeIndex) {
* @return the number of rows in the node.
*/
public int getNumRows(int nodeIndex) {
- return this.probs[nodeIndex].length;
+ if (cptMapType != null) {
+ return probMatrices[nodeIndex].getNumRows();
+ } else {
+ return this.probs[nodeIndex].length;
+ }
}
/**
@@ -474,7 +548,11 @@ public int getParentValue(int nodeIndex, int rowIndex, int colIndex) {
* @return the probability value for the given node.
*/
public double getProbability(int nodeIndex, int rowIndex, int colIndex) {
- return this.probs[nodeIndex][rowIndex][colIndex];
+ if (cptMapType != null) {
+ return probMatrices[nodeIndex].get(rowIndex, colIndex);
+ } else {
+ return this.probs[nodeIndex][rowIndex][colIndex];
+ }
}
/**
@@ -556,8 +634,22 @@ public void normalizeRow(int nodeIndex, int rowIndex) {
*/
@Override
public void setProbability(int nodeIndex, double[][] probMatrix) {
- for (int i = 0; i < probMatrix.length; i++) {
- System.arraycopy(probMatrix[i], 0, this.probs[nodeIndex][i], 0, probMatrix[i].length);
+ if (cptMapType == CptMapType.PROB_MAP) {
+ probMatrices[nodeIndex] = new CptMapProbs(probMatrix);
+ } else if (cptMapType == CptMapType.COUNT_MAP) {
+ throw new IllegalArgumentException("Cannot set probability matrix for an estimated Bayes IM.");
+ } else {
+ for (int i = 0; i < probMatrix.length; i++) {
+ System.arraycopy(probMatrix[i], 0, this.probs[nodeIndex][i], 0, probMatrix[i].length);
+ }
+ }
+ }
+
+ public void setCountMap(int nodeIndex, CptMapCounts countMap) {
+ if (cptMapType == CptMapType.COUNT_MAP) {
+ probMatrices[nodeIndex] = countMap;
+ } else {
+ throw new IllegalArgumentException("Cannot set count map for a non-estimated Bayes IM.");
}
}
@@ -584,7 +676,13 @@ public void setProbability(int nodeIndex, int rowIndex, int colIndex,
+ "between 0.0 and 1.0 or Double.NaN.");
}
- this.probs[nodeIndex][rowIndex][colIndex] = value;
+ if (cptMapType == CptMapType.PROB_MAP) {
+ ((CptMapProbs) probMatrices[nodeIndex]).set(rowIndex, colIndex, value);
+ } else if (cptMapType == CptMapType.COUNT_MAP) {
+ throw new IllegalArgumentException("Cannot set probability value for an estimated Bayes IM.");
+ } else {
+ this.probs[nodeIndex][rowIndex][colIndex] = value;
+ }
}
/**
@@ -620,7 +718,19 @@ public void clearRow(int nodeIndex, int rowIndex) {
*/
public void randomizeRow(int nodeIndex, int rowIndex) {
int size = getNumColumns(nodeIndex);
- this.probs[nodeIndex][rowIndex] = MlBayesIm.getRandomWeights(size);
+ double[] row = getRandomWeights(size);
+
+ if (cptMapType == CptMapType.PROB_MAP) {
+ for (int colIndex = 0; colIndex < size; colIndex++) {
+ ((CptMapProbs) probMatrices[nodeIndex]).set(rowIndex, colIndex, row[colIndex]);
+ }
+ } else if (cptMapType == CptMapType.COUNT_MAP) {
+ throw new IllegalArgumentException("Cannot randomize row for an estimated Bayes IM.");
+ } else {
+ this.probs[nodeIndex][rowIndex] = row;
+ }
+
+// this.probs[nodeIndex][rowIndex] = MlBayesIm.getRandomWeights(size);
}
/**
@@ -647,95 +757,6 @@ public void randomizeTable(int nodeIndex) {
}
}
- private int score(int nodeIndex) {
- double[][] p = new double[getNumRows(nodeIndex)][getNumColumns(nodeIndex)];
- copy(this.probs[nodeIndex], p);
- int num = 0;
-
- int numRows = getNumRows(nodeIndex);
-
- for (int r = 0; r < p.length; r++) {
- for (int c = 0; c < p[0].length; c++) {
- p[r][c] /= numRows;
- }
- }
-
- int[] parents = getParents(nodeIndex);
-
- for (int t = 0; t < parents.length; t++) {
- int numParentValues = getParentDim(nodeIndex, t);
- int numColumns = getNumColumns(nodeIndex);
-
- double[][] table = new double[numParentValues][numColumns];
-
- for (int childCol = 0; childCol < numColumns; childCol++) {
- for (int parentValue = 0; parentValue < numParentValues; parentValue++) {
- for (int row = 0; row < numRows; row++) {
- if (getParentValues(nodeIndex, row)[t] == parentValue) {
- table[parentValue][childCol] += p[row][childCol];
- }
- }
- }
- }
-
- final double N = 1000.0;
-
- for (int r = 0; r < table.length; r++) {
- for (int c = 0; c < table[0].length; c++) {
- table[r][c] *= N;
- }
- }
-
- double chisq = 0.0;
-
- for (int r = 0; r < table.length; r++) {
- for (int c = 0; c < table[0].length; c++) {
- double _sumRow = sumRow(table, r);
- double _sumCol = sumCol(table, c);
- double exp = (_sumRow / N) * (_sumCol / N) * N;
- double obs = table[r][c];
- chisq += pow(obs - exp, 2) / exp;
- }
- }
-
- int dof = (table.length - 1) * (table[0].length - 1);
-
- ChiSquaredDistribution distribution = new ChiSquaredDistribution(dof);
- double prob = 1 - distribution.cumulativeProbability(chisq);
-
- num += prob < 0.0001 ? 1 : 0;
- }
-
-// return num == parents.length ? -score : 0;
- return num;
- }
-
- private double sumCol(double[][] marginals, int j) {
- double sum = 0.0;
-
- for (double[] marginal : marginals) {
- sum += marginal[j];
- }
-
- return sum;
- }
-
- private double sumRow(double[][] marginals, int i) {
- double sum = 0.0;
-
- for (int h = 0; h < marginals[i].length; h++) {
- sum += marginals[i][h];
- }
-
- return sum;
- }
-
- private void copy(double[][] a, double[][] b) {
- for (int r = 0; r < a.length; r++) {
- System.arraycopy(a[r], 0, b[r], 0, a[r].length);
- }
- }
-
/**
* Clears the table by clearing all rows for the given node.
*
@@ -1016,75 +1037,6 @@ private DataSet simulateDataHelper(DataSet dataSet, boolean latentDataSaved, int
private void constructSample(int sampleSize, DataSet dataSet, int[] map, int[] tiers) {
-// //Do the simulation.
-// class SimulationTask extends RecursiveTask {
-// private int chunk;
-// private int from;
-// private int to;
-// private int[] tiers;
-// private DataSet dataSet;
-// private int[] map;
-//
-// public SimulationTask(int chunk, int from, int to, int[] tiers, DataSet dataSet, int[] map) {
-// this.chunk = chunk;
-// this.from = from;
-// this.to = to;
-// this.tiers = tiers;
-// this.dataSet = dataSet;
-// this.map = map;
-// }
-//
-// @Override
-// protected Boolean compute() {
-// if (to - from <= chunk) {
-// RandomGenerator randomGenerator = new Well1024a(++seed[0]);
-//
-// for (int row = from; row < to; row++) {
-// for (int t : tiers) {
-// int[] parentValues = new int[parents[t].length];
-//
-// for (int k = 0; k < parentValues.length; k++) {
-// parentValues[k] = dataSet.getInt(row, parents[t][k]);
-// }
-//
-// int rowIndex = getRowIndex(t, parentValues);
-// double sum = 0.0;
-// double r;
-//
-// r = randomGenerator.nextDouble();
-//
-// for (int k = 0; k < getNumColumns(t); k++) {
-// double probability = getProbability(t, rowIndex, k);
-// sum += probability;
-//
-// if (sum >= r) {
-// dataSet.setInt(row, map[t], k);
-// break;
-// }
-// }
-// }
-// }
-//
-// return true;
-// } else {
-// int mid = (to + from) / 2;
-// SimulationTask left = new SimulationTask(chunk, from, mid, tiers, dataSet, map);
-// SimulationTask right = new SimulationTask(chunk, mid, to, tiers, dataSet, map);
-//
-// left.fork();
-// right.compute();
-// left.join();
-//
-// return true;
-// }
-// }
-// }
-//
-// int chunk = 25;
-//
-// ForkJoinPool pool = ForkJoinPoolInstance.getInstance().getPool();
-// SimulationTask task = new SimulationTask(chunk, 0, sampleSize, tiers, dataSet, map);
-// pool.invoke(task);
// Construct the sample.
for (int i = 0; i < sampleSize; i++) {
for (int t : tiers) {
@@ -1167,8 +1119,6 @@ public boolean equals(Object o) {
return true;
}
- //=============================PRIVATE METHODS=======================//
-
/**
* Prints out the probability table for each variable.
*
@@ -1221,21 +1171,30 @@ public String toString() {
* @see #initializeNode
* @see #randomizeRow
*/
- private void initialize(BayesIm oldBayesIm, int initializationMethod) {
+ private void initialize(BayesIm oldBayesIm, InitializationMethod initializationMethod) {
this.parents = new int[this.nodes.length][];
this.parentDims = new int[this.nodes.length][];
this.probs = new double[this.nodes.length][][];
+ this.probMatrices = new CptMapProbs[this.nodes.length];
+
+ this.cptMapType = CptMapType.PROB_MAP;
for (int nodeIndex = 0; nodeIndex < this.nodes.length; nodeIndex++) {
initializeNode(nodeIndex, oldBayesIm, initializationMethod);
}
}
+ //=============================PRIVATE METHODS=======================//
+
/**
* This method initializes the node indicated.
*/
private void initializeNode(int nodeIndex, BayesIm oldBayesIm,
- int initializationMethod) {
+ InitializationMethod initializationMethod) {
+ if (cptMapType != CptMapType.PROB_MAP) {
+ throw new IllegalArgumentException("Cannot initialize a Bayes IM randomly without a probability map.");
+ }
+
Node node = this.nodes[nodeIndex];
// Set up parents array. Should store the parents of
@@ -1265,15 +1224,6 @@ private void initializeNode(int nodeIndex, BayesIm oldBayesIm,
int numRows = 1;
for (int dim : dims) {
- if (numRows > 1000000 /* Integer.MAX_VALUE / dim*/) {
- throw new IllegalArgumentException(
- "The number of rows in the "
- + "conditional probability table for "
- + this.nodes[nodeIndex]
- + " is greater than 1,000,000 and cannot be "
- + "represented.");
- }
-
numRows *= dim;
}
@@ -1281,9 +1231,10 @@ private void initializeNode(int nodeIndex, BayesIm oldBayesIm,
this.parentDims[nodeIndex] = dims;
this.probs[nodeIndex] = new double[numRows][numCols];
+ this.probMatrices[nodeIndex] = new CptMapProbs(numRows, numCols);
// Initialize each row.
- if (initializationMethod == MlBayesIm.RANDOM) {
+ if (initializationMethod == InitializationMethod.RANDOM) {
randomizeTable(nodeIndex);
} else {
for (int rowIndex = 0; rowIndex < numRows; rowIndex++) {
@@ -1298,10 +1249,10 @@ private void initializeNode(int nodeIndex, BayesIm oldBayesIm,
}
private void overwriteRow(int nodeIndex, int rowIndex,
- int initializationMethod) {
- if (initializationMethod == MlBayesIm.RANDOM) {
+ InitializationMethod initializationMethod) {
+ if (initializationMethod == InitializationMethod.RANDOM) {
randomizeRow(nodeIndex, rowIndex);
- } else if (initializationMethod == MlBayesIm.MANUAL) {
+ } else if (initializationMethod == InitializationMethod.MANUAL) {
initializeRowAsUnknowns(nodeIndex, rowIndex);
} else {
throw new IllegalArgumentException("Unrecognized state.");
@@ -1312,14 +1263,21 @@ private void initializeRowAsUnknowns(int nodeIndex, int rowIndex) {
int size = getNumColumns(nodeIndex);
double[] row = new double[size];
Arrays.fill(row, Double.NaN);
- this.probs[nodeIndex][rowIndex] = row;
+
+ if (cptMapType == CptMapType.PROB_MAP) {
+ ((CptMapProbs) probMatrices[nodeIndex]).assignRow(rowIndex, new Vector(row));
+ } else if (cptMapType == CptMapType.COUNT_MAP) {
+ throw new IllegalArgumentException("Cannot initialize a row as unknowns in an estimated Bayes IM.");
+ } else {
+ this.probs[nodeIndex][rowIndex] = row;
+ }
}
/**
* This method initializes the node indicated.
*/
private void retainOldRowIfPossible(int nodeIndex, int rowIndex,
- BayesIm oldBayesIm, int initializationMethod) {
+ BayesIm oldBayesIm, InitializationMethod initializationMethod) {
int oldNodeIndex = getCorrespondingNodeIndex(nodeIndex, oldBayesIm);
@@ -1396,42 +1354,6 @@ private int getUniqueCompatibleOldRow(int nodeIndex, int rowIndex,
}
}
-// // Go through each parent of the node in the new BayesIm.
-// for (int i = 0; i < oldBayesIm.getNumParents(oldNodeIndex); i++) {
-//
-// // Get the index of the parent in the new graph and in the old
-// // graph. If it's no longer in the new graph, skip to the next
-// // parent.
-// int oldParentNodeIndex = oldBayesIm.getParent(oldNodeIndex, i);
-// int parentNodeIndex =
-// oldBayesIm.getCorrespondingNodeIndex(oldParentNodeIndex, this);
-// int parentIndex = -1;
-//
-// for (int j = 0; j < this.getNumParents(nodeIndex); j++) {
-// if (parentNodeIndex == this.getParent(nodeIndex, j)) {
-// parentIndex = j;
-// break;
-// }
-// }
-//
-// if (parentIndex == -1 ||
-// parentIndex >= this.getNumParents(nodeIndex)) {
-// continue;
-// }
-//
-// // Look up that value index for the new BayesIm for that parent.
-// // If it was a valid value index in the old BayesIm, record
-// // that value in oldParentValues. Otherwise return -1.
-// int parentValue = oldParentValues[i];
-// int parentDim =
-// this.getParentDim(nodeIndex, parentIndex);
-//
-// if (parentValue < parentDim) {
-// oldParentValues[parentIndex] = oldParentValue;
-// } else {
-// return -1;
-// }
-// }
// If there are any -1's in the combination at this point, return -1.
for (int oldParentValue : oldParentValues) {
if (oldParentValue == -1) {
@@ -1492,8 +1414,42 @@ private void readObject(ObjectInputStream s)
throw new NullPointerException();
}
- if (this.probs == null) {
- throw new NullPointerException();
+ copyDataToProbMatrices();
+ }
+
+ /**
+ * Copies data from the `probs` array to the `probMatrices` array. If the lengths of both arrays are equal, the
+ * `probMatrices` array is initialized with `ProbMap` objects, each containing the corresponding `probs` element.
+ * The `probs` array is then set to null and the `useProbMatrices` flag is set to true.
+ *
+ * Note: This method should only be called after the `probs` array has been properly initialized.
+ */
+ private void copyDataToProbMatrices() {
+ if (cptMapType == null && this.probs != null && this.probs.length == this.nodes.length) {
+ this.probMatrices = new CptMapProbs[this.probs.length];
+
+ for (int i = 0; i < this.nodes.length; i++) {
+ probMatrices[i] = new CptMapProbs(this.probs[i]);
+ }
+
+ this.probs = null;
+ this.cptMapType = CptMapType.PROB_MAP;
}
}
+
+ /**
+ * A flag indicating whether to use CptMaps or not. If true, CptMaps are used; if false, the probs array is used.
+ * The CptMap is the new way of storing the probabilities; the probs array is kept here for backward compatibility.
+ */
+ public CptMapType getCptMapType() {
+ return cptMapType;
+ }
+
+ public enum CptMapType {
+ PROB_MAP, COUNT_MAP
+ }
+
+ public enum InitializationMethod {
+ MANUAL, RANDOM
+ }
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/RowSummingExactUpdater.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/RowSummingExactUpdater.java
index bd330396bf..35e0872e22 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/RowSummingExactUpdater.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/RowSummingExactUpdater.java
@@ -346,7 +346,7 @@ private void updateAll() {
private BayesIm createdUpdatedBayesIm(BayesPm updatedBayesPm) {
// Switching this to MANUAL since the initial values don't matter.
- return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.MANUAL);
+ return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.InitializationMethod.MANUAL);
}
private BayesPm createUpdatedBayesPm(Dag updatedGraph) {
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/RandomGraph.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/RandomGraph.java
index c0a8a5c9be..7e8916e1fa 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/RandomGraph.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/RandomGraph.java
@@ -228,7 +228,7 @@ public static Graph randomGraphRandomForwardEdges(List nodes, int numLaten
* negative or exceeds the number of nodes
*/
public static Graph randomGraphRandomForwardEdges(List nodes, int numLatentConfounders, int numEdges, int maxDegree, int maxIndegree, int maxOutdegree, boolean connected, boolean layoutAsCircle) {
- if (nodes.size() == 0) {
+ if (nodes.isEmpty()) {
throw new IllegalArgumentException("NumNodes most be > 0");
}
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRobustCompare.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRobustCompare.java
index 7ff2a15a0a..e1d94a9839 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRobustCompare.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRobustCompare.java
@@ -65,7 +65,7 @@ public static List run(int numVars, double edgesPerNode, int numCases,
Graph odag = RandomGraph.randomGraphRandomForwardEdges(vars, 0, numEdges, 30, 15, 15, false, true);
BayesPm bayesPm = new BayesPm(odag, 2, 2);
- BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
//oData is the original data set, and odag is the original dag.
DataSet oData = bayesIm.simulateData(numCases, false);
//System.out.println(oData);
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison.java
index 1e638dcb6a..28fb2e581d 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison.java
@@ -112,7 +112,7 @@ public static ComparisonResult compare(ComparisonParameters params) {
}
BayesPm pm = new BayesPm(trueDag, 3, 3);
- MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
+ MlBayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM);
dataSet = im.simulateData(params.getSampleSize(), false, tiers);
} else {
throw new IllegalArgumentException("Unrecognized data type.");
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison2.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison2.java
index 938c38e0ce..aa054ecdef 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison2.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison2.java
@@ -303,7 +303,7 @@ public static ComparisonResult compare(ComparisonParameters params) {
}
BayesPm pm = new BayesPm(trueDag, 3, 3);
- MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
+ MlBayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM);
dataSet = im.simulateData(params.getSampleSize(), false, tiers);
} else {
throw new IllegalArgumentException("Unrecognized data type.");
diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/PerformanceTests.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/PerformanceTests.java
index c831818858..66d4e42b74 100644
--- a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/PerformanceTests.java
+++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/PerformanceTests.java
@@ -1026,7 +1026,7 @@ private void testFges(int numVars, double edgeFactor, int numCases, int numRuns,
} else {
BayesPm pm = new BayesPm(dag, 3, 3);
- MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
+ MlBayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM);
DataSet data = im.simulateData(numCases, false, tiers);
@@ -1258,7 +1258,7 @@ private void testFgesMb(int numVars, double edgeFactor, int numCases, int numRun
this.out.println(new Date());
BayesPm pm = new BayesPm(dag, 3, 3);
- MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
+ MlBayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM);
DataSet data = im.simulateData(numCases, false, tiers);
diff --git a/tetrad-lib/src/main/java/edu/pitt/isp/sverchkov/data/AdTreeTest.java b/tetrad-lib/src/main/java/edu/pitt/isp/sverchkov/data/AdTreeTest.java
index 6588f2ccf9..06e0341818 100644
--- a/tetrad-lib/src/main/java/edu/pitt/isp/sverchkov/data/AdTreeTest.java
+++ b/tetrad-lib/src/main/java/edu/pitt/isp/sverchkov/data/AdTreeTest.java
@@ -72,7 +72,7 @@ public static void main(String[] args) throws Exception {
Graph graph = RandomGraph.randomGraphRandomForwardEdges(variables, 0, numEdges, 30, 15, 15, false, true);
BayesPm pm = new BayesPm(graph);
- BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
+ BayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM);
DataSet data = im.simulateData(rows, false);
// This implementation uses a DataTable to represent the data
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesDiscreteBicScorer.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesDiscreteBicScorer.java
index a24ce17cbd..b0cde7abc8 100644
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesDiscreteBicScorer.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesDiscreteBicScorer.java
@@ -53,7 +53,7 @@ public void testPValue() {
final int numCategories = 8;
BayesPm pm = new BayesPm(graph, numCategories, numCategories);
- BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
+ BayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM);
DataSet data = im.simulateData(1000, false);
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesIm.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesIm.java
index 88f949f05d..6439600537 100755
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesIm.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesIm.java
@@ -66,7 +66,7 @@ public void testCopyConstructor() {
Graph graph = GraphUtils.convert("X1-->X2,X1-->X3,X2-->X4,X3-->X4");
Dag dag = new Dag(graph);
BayesPm bayesPm = new BayesPm(dag);
- BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
BayesIm bayesIm2 = new MlBayesIm(bayesIm);
assertEquals(bayesIm, bayesIm2);
}
@@ -102,9 +102,9 @@ public void testAddRemoveParent() {
dag.addDirectedEdge(a, b);
BayesPm bayesPm = new BayesPm(dag);
- BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
- BayesIm bayesIm2 = new MlBayesIm(bayesPm, bayesIm, MlBayesIm.MANUAL);
+ BayesIm bayesIm2 = new MlBayesIm(bayesPm, bayesIm, MlBayesIm.InitializationMethod.MANUAL);
assertEquals(bayesIm, bayesIm2);
@@ -113,7 +113,7 @@ public void testAddRemoveParent() {
dag.addDirectedEdge(c, b);
BayesPm bayesPm3 = new BayesPm(dag, bayesPm);
- BayesIm bayesIm3 = new MlBayesIm(bayesPm3, bayesIm2, MlBayesIm.MANUAL);
+ BayesIm bayesIm3 = new MlBayesIm(bayesPm3, bayesIm2, MlBayesIm.InitializationMethod.MANUAL);
// Make sure the rows got repeated.
// assertTrue(rowsEqual(bayesIm3, bayesIm3.getNodeIndex(b), 0, 1));
@@ -125,7 +125,7 @@ public void testAddRemoveParent() {
dag.removeNode(c);
BayesPm bayesPm4 = new BayesPm(dag, bayesPm3);
- BayesIm bayesIm4 = new MlBayesIm(bayesPm4, bayesIm3, MlBayesIm.MANUAL);
+ BayesIm bayesIm4 = new MlBayesIm(bayesPm4, bayesIm3, MlBayesIm.InitializationMethod.MANUAL);
// Make sure the 'b' node has 2 rows of '?'s'.
assertEquals(2, bayesIm4.getNumRows(bayesIm4.getNodeIndex(b)));
@@ -157,17 +157,17 @@ public void testAddRemoveValues() {
assertTrue(Edges.isDirectedEdge(dag.getEdge(a, b)));
BayesPm bayesPm = new BayesPm(dag, 3, 3);
- BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
bayesPm.setNumCategories(a, 4);
bayesPm.setNumCategories(c, 4);
- BayesIm bayesIm2 = new MlBayesIm(bayesPm, bayesIm, MlBayesIm.MANUAL);
+ BayesIm bayesIm2 = new MlBayesIm(bayesPm, bayesIm, MlBayesIm.InitializationMethod.MANUAL);
bayesPm.setNumCategories(a, 2);
- BayesIm bayesIm3 = new MlBayesIm(bayesPm, bayesIm2, MlBayesIm.MANUAL);
+ BayesIm bayesIm3 = new MlBayesIm(bayesPm, bayesIm2, MlBayesIm.InitializationMethod.MANUAL);
bayesPm.setNumCategories(b, 2);
- BayesIm bayesIm4 = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ BayesIm bayesIm4 = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
// At this point, a has 2 categories, b has 2 categories, and c has 4 categories.
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesXml.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesXml.java
index aa1e746a92..979f547792 100644
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesXml.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesXml.java
@@ -112,7 +112,7 @@ private static BayesIm sampleBayesIm2() {
BayesPm bayesPm = new BayesPm(graph);
bayesPm.setNumCategories(b, 3);
- return new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ return new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
}
private static BayesIm sampleBayesIm3() {
@@ -138,7 +138,7 @@ private static BayesIm sampleBayesIm3() {
BayesPm bayesPm = new BayesPm(graph);
bayesPm.setNumCategories(b, 3);
- return new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ return new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
}
/**
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCellProbabilities.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCellProbabilities.java
index 66aa240bae..afebcb340e 100755
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCellProbabilities.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCellProbabilities.java
@@ -73,7 +73,7 @@ public void testCreateUsingBayesIm() {
Graph graph = GraphUtils.convert("X1-->X2,X1-->X3,X2-->X4,X3-->X4");
Dag dag = new Dag(graph);
BayesPm bayesPm = new BayesPm(dag);
- BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
StoredCellProbs cellProbs = StoredCellProbs.createCellTable(bayesIm);
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCptInvariantUpdater.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCptInvariantUpdater.java
index 19d0b64e12..fde6221b59 100644
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCptInvariantUpdater.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCptInvariantUpdater.java
@@ -176,7 +176,7 @@ public void testUpdate4() {
graph.addDirectedEdge(x2Node, x3Node);
BayesPm bayesPm = new BayesPm(graph, 2, 2);
- MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
int x2 = bayesIm.getNodeIndex(x2Node);
int x3 = bayesIm.getNodeIndex(x3Node);
@@ -220,7 +220,7 @@ public void testUpdate5() {
graph.addDirectedEdge(x4Node, x2Node);
BayesPm bayesPm = new BayesPm(graph);
- MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
int x1 = bayesIm.getNodeIndex(x1Node);
int x2 = bayesIm.getNodeIndex(x2Node);
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDataSetCellProbs.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDataSetCellProbs.java
index 3542df3d69..c4ec3b3ddc 100755
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDataSetCellProbs.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDataSetCellProbs.java
@@ -41,7 +41,7 @@ public void testCreateUsingBayesIm() {
Graph graph = GraphUtils.convert("X1-->X2,X1-->X3,X2-->X4,X3-->X4");
Dag dag = new Dag(graph);
BayesPm bayesPm = new BayesPm(dag);
- BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
BayesImProbs bayesImProbs = new BayesImProbs(bayesIm);
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDiscreteProbs.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDiscreteProbs.java
index 5b470d4272..e160b8d476 100755
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDiscreteProbs.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDiscreteProbs.java
@@ -79,7 +79,7 @@ public void testCreateUsingBayesIm() {
Graph graph = GraphUtils.convert("X1-->X2,X1-->X3,X2-->X4,X3-->X4");
Dag dag = new Dag(graph);
BayesPm bayesPm = new BayesPm(dag);
- BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
StoredCellProbs cellProbs = StoredCellProbs.createCellTable(bayesIm);
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java
index e3e6d25416..5730004833 100644
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java
@@ -227,7 +227,7 @@ public void explore2() {
30, 15, 15, false, true);
BayesPm pm = new BayesPm(dag, 2, 3);
- BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM);
+ BayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM);
DataSet data = im.simulateData(numCases, false);
BdeScore score = new BdeScore(data);
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java
index 445afbc6c3..4653051cec 100644
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java
@@ -247,7 +247,7 @@ public void testRandomDiscreteData() {
Graph g = GraphUtils.convert("X1-->X2,X1-->X3,X1-->X4,X2-->X3,X2-->X4,X3-->X4");
Dag dag = new Dag(g);
BayesPm bayesPm = new BayesPm(dag);
- BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
DataSet data = bayesIm.simulateData(sampleSize, false);
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestHistogram.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestHistogram.java
index ee3403d4a5..f36f1ed4a1 100644
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestHistogram.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestHistogram.java
@@ -90,7 +90,7 @@ public void testHistogram() {
// Discrete
BayesPm bayesPm = new BayesPm(trueGraph);
- BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
DataSet data2 = bayesIm.simulateData(sampleSize, false);
// For some reason these are giving different
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRfciBsc.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRfciBsc.java
index 70ba06d08a..7b10b446cc 100644
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRfciBsc.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRfciBsc.java
@@ -81,7 +81,7 @@ public void testRandomDiscreteData() {
// set a number of latent variables
BayesPm bayesPm = new BayesPm(dag);
- BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
RandomUtil.getInstance().setSeed(seed);
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRowSummingUpdater.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRowSummingUpdater.java
index fd9ec8294f..ac2a354859 100644
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRowSummingUpdater.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRowSummingUpdater.java
@@ -197,7 +197,7 @@ public void testUpdate4() {
graph.addDirectedEdge(x2Node, x3Node);
BayesPm bayesPm = new BayesPm(graph);
- MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
int x2 = bayesIm.getNodeIndex(x2Node);
int x3 = bayesIm.getNodeIndex(x3Node);
@@ -240,7 +240,7 @@ public void testUpdate5() {
graph.addDirectedEdge(x4Node, x2Node);
BayesPm bayesPm = new BayesPm(graph);
- MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
int x1 = bayesIm.getNodeIndex(x1Node);
int x2 = bayesIm.getNodeIndex(x2Node);
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestUpdatedBayesIm.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestUpdatedBayesIm.java
index 74a221e161..dbda137037 100755
--- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestUpdatedBayesIm.java
+++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestUpdatedBayesIm.java
@@ -59,7 +59,7 @@ public void testCompound() {
graph.addDirectedEdge(x4Node, x2Node);
BayesPm bayesPm = new BayesPm(graph);
- MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
+ MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);
UpdatedBayesIm updatedIm1 = new UpdatedBayesIm(bayesIm);
assertEquals(bayesIm, updatedIm1);