From 6e16cc06ec49858c4fe9805cc544b4eba3f018b0 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 28 Mar 2024 01:11:12 -0400 Subject: [PATCH 01/23] Add MlBayesImOld.java for representing directed acyclic graphs Added a new implementation of the BayesIm interface, MlBayesImOld.java, for representing directed acyclic graphs in Bayes nets. This class also supports operations for manipulating the node tables including setting and retrieving probabilities, normalizing node tables, and checking if table rows are incomplete. --- .../java/edu/cmu/tetrad/bayes/MlBayesIm.java | 225 ++- .../edu/cmu/tetrad/bayes/MlBayesImOld.java | 1499 +++++++++++++++++ 2 files changed, 1601 insertions(+), 123 deletions(-) create mode 100644 tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesImOld.java diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java index 71e22172dc..4853819ddb 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java @@ -25,8 +25,10 @@ import edu.cmu.tetrad.graph.Node; import edu.cmu.tetrad.graph.Paths; import edu.cmu.tetrad.graph.TimeLagGraph; +import edu.cmu.tetrad.util.Matrix; import edu.cmu.tetrad.util.NumberFormatUtil; import edu.cmu.tetrad.util.RandomUtil; +import edu.cmu.tetrad.util.Vector; import org.apache.commons.math3.distribution.ChiSquaredDistribution; import java.io.IOException; @@ -50,19 +52,22 @@ * probabilities is organized in this class as a three-dimensional table of double values. The first dimension * corresponds to the nodes in the Bayes net. For each such node, the second dimension corresponds to a flat list of * combinations of parent categories for that node. The third dimension corresponds to the list of categories for that - * node itself. 
Two methods allow these values to be set and retrieved: - * To determine the index of the node in question, use the method To determine - * the index of the row in question, use the method - * To determine the order of the - * parent values for a given node so that you can build the parentVals[] array, - * use the method To determine the - * index of a category, use the method in BayesPm. The rest of the methods in this class are easily understood + * node itself. Two methods allow these values to be set and retrieved: getProbability(int nodeIndex, int rowIndex, int + * colIndex); and,
  • setProbability(int nodeIndex, int rowIndex, int colIndex, double value). To determine the + * index of the node in question, use the method getNodeIndex(Node node). To determine the index of the row in question, + * use the method getRowIndex(int nodeIndex, int[] parentVals). To determine the order of the parent values for a given node so that + * you can build the parentVals[] array, use the method getParents(int nodeIndex). To determine the index of a category, + * use the method getCategoryIndex(Node node) in BayesPm. The rest of the methods in this class are easily understood * as variants of the methods above. *

    - * Thanks to Pucktada Treeratpituk, Frank Wimberly, and Willie Wheeler for - * advice and earlier versions. + * This version uses a different method for storing the probabilities. The previous version stored the probabilities in + * a three-dimensional array, where the first dimension was the node, the second dimension was the row index, and the + * third dimension was the column index. This version stores an array of Matrix objects, where each Matrix object + * represents the conditional probability table for a node. This will allow us in the future to represent this as an array of + * Maps from Integers to Doubles and store only the non-NaN values. This will save space and time in the case of sparse + * tables. + *

    + * Thanks to Pucktada Treeratpituk, Frank Wimberly, and Willie Wheeler for advice and earlier versions. * * @author josephramsey * @version $Id: $Id @@ -123,6 +128,13 @@ public final class MlBayesIm implements BayesIm { */ private double[][][] probs; + /** + * The array of matrices that store the probabilities for each node. + */ + private Matrix[] probMatrices; + + boolean useProbMatrices = true; + /** * Constructs a new BayesIm from the given BayesPm, initializing all values as Double.NaN ("?"). * @@ -362,7 +374,14 @@ public List getVariableNames() { * @return the number of columns. */ public int getNumColumns(int nodeIndex) { - return this.probs[nodeIndex][0].length; + if (useProbMatrices) { + return probMatrices[nodeIndex].getNumColumns(); + } else { + return this.probs[nodeIndex][0].length; + } + +// return this.probs[nodeIndex][0].length; + } /** @@ -372,7 +391,13 @@ public int getNumColumns(int nodeIndex) { * @return the number of rows in the node. */ public int getNumRows(int nodeIndex) { - return this.probs[nodeIndex].length; + if (useProbMatrices) { + return probMatrices[nodeIndex].getNumRows(); + } else { + return this.probs[nodeIndex].length; + } + +// return this.probs[nodeIndex].length; } /** @@ -474,7 +499,13 @@ public int getParentValue(int nodeIndex, int rowIndex, int colIndex) { * @return the probability value for the given node. 
*/ public double getProbability(int nodeIndex, int rowIndex, int colIndex) { - return this.probs[nodeIndex][rowIndex][colIndex]; + if (useProbMatrices) { + return probMatrices[nodeIndex].get(rowIndex, colIndex); + } else { + return this.probs[nodeIndex][rowIndex][colIndex]; + } + +// return this.probs[nodeIndex][rowIndex][colIndex]; } /** @@ -556,9 +587,17 @@ public void normalizeRow(int nodeIndex, int rowIndex) { */ @Override public void setProbability(int nodeIndex, double[][] probMatrix) { - for (int i = 0; i < probMatrix.length; i++) { - System.arraycopy(probMatrix[i], 0, this.probs[nodeIndex][i], 0, probMatrix[i].length); + if (useProbMatrices) { + probMatrices[nodeIndex] = new Matrix(probMatrix); + } else { + for (int i = 0; i < probMatrix.length; i++) { + System.arraycopy(probMatrix[i], 0, this.probs[nodeIndex][i], 0, probMatrix[i].length); + } } + +// for (int i = 0; i < probMatrix.length; i++) { +// System.arraycopy(probMatrix[i], 0, this.probs[nodeIndex][i], 0, probMatrix[i].length); +// } } /** @@ -576,15 +615,21 @@ public void setProbability(int nodeIndex, int rowIndex, int colIndex, double value) { if (colIndex >= getNumColumns(nodeIndex)) { throw new IllegalArgumentException("Column out of range: " - + colIndex + " >= " + getNumColumns(nodeIndex)); + + colIndex + " >= " + getNumColumns(nodeIndex)); } if (!(0.0 <= value && value <= 1.0) && !Double.isNaN(value)) { throw new IllegalArgumentException("Probability value must be " - + "between 0.0 and 1.0 or Double.NaN."); + + "between 0.0 and 1.0 or Double.NaN."); } - this.probs[nodeIndex][rowIndex][colIndex] = value; + if (useProbMatrices) { + probMatrices[nodeIndex].set(rowIndex, colIndex, value); + } else { + this.probs[nodeIndex][rowIndex][colIndex] = value; + } + +// this.probs[nodeIndex][rowIndex][colIndex] = value; } /** @@ -620,7 +665,17 @@ public void clearRow(int nodeIndex, int rowIndex) { */ public void randomizeRow(int nodeIndex, int rowIndex) { int size = getNumColumns(nodeIndex); - 
this.probs[nodeIndex][rowIndex] = MlBayesIm.getRandomWeights(size); + double[] row = getRandomWeights(size); + + if (useProbMatrices) { + for (int colIndex = 0; colIndex < size; colIndex++) { + probMatrices[nodeIndex].set(rowIndex, colIndex, row[colIndex]); + } + } else { + this.probs[nodeIndex][rowIndex] = row; + } + +// this.probs[nodeIndex][rowIndex] = MlBayesIm.getRandomWeights(size); } /** @@ -647,95 +702,6 @@ public void randomizeTable(int nodeIndex) { } } - private int score(int nodeIndex) { - double[][] p = new double[getNumRows(nodeIndex)][getNumColumns(nodeIndex)]; - copy(this.probs[nodeIndex], p); - int num = 0; - - int numRows = getNumRows(nodeIndex); - - for (int r = 0; r < p.length; r++) { - for (int c = 0; c < p[0].length; c++) { - p[r][c] /= numRows; - } - } - - int[] parents = getParents(nodeIndex); - - for (int t = 0; t < parents.length; t++) { - int numParentValues = getParentDim(nodeIndex, t); - int numColumns = getNumColumns(nodeIndex); - - double[][] table = new double[numParentValues][numColumns]; - - for (int childCol = 0; childCol < numColumns; childCol++) { - for (int parentValue = 0; parentValue < numParentValues; parentValue++) { - for (int row = 0; row < numRows; row++) { - if (getParentValues(nodeIndex, row)[t] == parentValue) { - table[parentValue][childCol] += p[row][childCol]; - } - } - } - } - - final double N = 1000.0; - - for (int r = 0; r < table.length; r++) { - for (int c = 0; c < table[0].length; c++) { - table[r][c] *= N; - } - } - - double chisq = 0.0; - - for (int r = 0; r < table.length; r++) { - for (int c = 0; c < table[0].length; c++) { - double _sumRow = sumRow(table, r); - double _sumCol = sumCol(table, c); - double exp = (_sumRow / N) * (_sumCol / N) * N; - double obs = table[r][c]; - chisq += pow(obs - exp, 2) / exp; - } - } - - int dof = (table.length - 1) * (table[0].length - 1); - - ChiSquaredDistribution distribution = new ChiSquaredDistribution(dof); - double prob = 1 - 
distribution.cumulativeProbability(chisq); - - num += prob < 0.0001 ? 1 : 0; - } - -// return num == parents.length ? -score : 0; - return num; - } - - private double sumCol(double[][] marginals, int j) { - double sum = 0.0; - - for (double[] marginal : marginals) { - sum += marginal[j]; - } - - return sum; - } - - private double sumRow(double[][] marginals, int i) { - double sum = 0.0; - - for (int h = 0; h < marginals[i].length; h++) { - sum += marginals[i][h]; - } - - return sum; - } - - private void copy(double[][] a, double[][] b) { - for (int r = 0; r < a.length; r++) { - System.arraycopy(a[r], 0, b[r], 0, a[r].length); - } - } - /** * Clears the table by clearing all rows for the given node. * @@ -911,8 +877,8 @@ private DataSet simulateTimeSeries(int sampleSize) { if (Double.isNaN(probability)) { throw new IllegalStateException("Some probability " - + "values in the BayesIm are not filled in; " - + "cannot simulate data."); + + "values in the BayesIm are not filled in; " + + "cannot simulate data."); } sum += probability; @@ -972,8 +938,8 @@ private DataSet simulateDataHelper(int sampleSize, boolean latentDataSaved, int[ private DataSet simulateDataHelper(DataSet dataSet, boolean latentDataSaved, int[] tiers) { if (dataSet.getNumColumns() != this.nodes.length) { throw new IllegalArgumentException("When rewriting the old data set, " - + "number of variables in data set must equal number of variables " - + "in Bayes net."); + + "number of variables in data set must equal number of variables " + + "in Bayes net."); } int sampleSize = dataSet.getNumRows(); @@ -1225,6 +1191,7 @@ private void initialize(BayesIm oldBayesIm, int initializationMethod) { this.parents = new int[this.nodes.length][]; this.parentDims = new int[this.nodes.length][]; this.probs = new double[this.nodes.length][][]; + this.probMatrices = new Matrix[this.nodes.length]; for (int nodeIndex = 0; nodeIndex < this.nodes.length; nodeIndex++) { initializeNode(nodeIndex, oldBayesIm, 
initializationMethod); @@ -1268,10 +1235,10 @@ private void initializeNode(int nodeIndex, BayesIm oldBayesIm, if (numRows > 1000000 /* Integer.MAX_VALUE / dim*/) { throw new IllegalArgumentException( "The number of rows in the " - + "conditional probability table for " - + this.nodes[nodeIndex] - + " is greater than 1,000,000 and cannot be " - + "represented."); + + "conditional probability table for " + + this.nodes[nodeIndex] + + " is greater than 1,000,000 and cannot be " + + "represented."); } numRows *= dim; @@ -1281,6 +1248,7 @@ private void initializeNode(int nodeIndex, BayesIm oldBayesIm, this.parentDims[nodeIndex] = dims; this.probs[nodeIndex] = new double[numRows][numCols]; + this.probMatrices[nodeIndex] = new Matrix(numRows, numCols); // Initialize each row. if (initializationMethod == MlBayesIm.RANDOM) { @@ -1312,7 +1280,14 @@ private void initializeRowAsUnknowns(int nodeIndex, int rowIndex) { int size = getNumColumns(nodeIndex); double[] row = new double[size]; Arrays.fill(row, Double.NaN); - this.probs[nodeIndex][rowIndex] = row; + + if (useProbMatrices) { + probMatrices[nodeIndex].assignRow(rowIndex, new Vector(row)); + } else { + this.probs[nodeIndex][rowIndex] = row; + } +// this.probs[nodeIndex][rowIndex] = row; +// this.probMatrices[nodeIndex].assignRow(rowIndex, new Vector(row)); } /** @@ -1378,7 +1353,7 @@ private int getUniqueCompatibleOldRow(int nodeIndex, int rowIndex, } if (oldParentIndex == -1 - || oldParentIndex >= oldBayesIm.getNumParents(oldNodeIndex)) { + || oldParentIndex >= oldBayesIm.getNumParents(oldNodeIndex)) { return -1; } @@ -1448,8 +1423,8 @@ private void copyValuesFromOldToNew(int oldNodeIndex, int oldRowIndex, int nodeIndex, int rowIndex, BayesIm oldBayesIm) { if (getNumColumns(nodeIndex) != oldBayesIm.getNumColumns(oldNodeIndex)) { throw new IllegalArgumentException("It's only possible to copy " - + "one row of probability values to another in a Bayes IM " - + "if the number of columns in the table are the same."); + + "one 
row of probability values to another in a Bayes IM " + + "if the number of columns in the table are the same."); } for (int colIndex = 0; colIndex < getNumColumns(nodeIndex); colIndex++) { @@ -1492,8 +1467,12 @@ private void readObject(ObjectInputStream s) throw new NullPointerException(); } - if (this.probs == null) { - throw new NullPointerException(); + if (this.probs != null) { + for (int i = 0; i < this.nodes.length; i++) { + if (useProbMatrices) { + probMatrices[i] = new Matrix(probs[i]); + } + } } } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesImOld.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesImOld.java new file mode 100644 index 0000000000..8dcc09abcb --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesImOld.java @@ -0,0 +1,1499 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. 
// +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// +package edu.cmu.tetrad.bayes; + +import edu.cmu.tetrad.data.*; +import edu.cmu.tetrad.graph.Graph; +import edu.cmu.tetrad.graph.Node; +import edu.cmu.tetrad.graph.Paths; +import edu.cmu.tetrad.graph.TimeLagGraph; +import edu.cmu.tetrad.util.NumberFormatUtil; +import edu.cmu.tetrad.util.RandomUtil; +import org.apache.commons.math3.distribution.ChiSquaredDistribution; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.Serial; +import java.text.NumberFormat; +import java.util.*; + +import static org.apache.commons.math3.util.FastMath.abs; +import static org.apache.commons.math3.util.FastMath.pow; + +/** + * Stores a table of probabilities for a Bayes net and, together with BayesPm and Dag, provides methods to manipulate + * this table. The division of labor is as follows. The Dag is responsible for manipulating the basic graphical + * structure of the Bayes net. Dag also stores and manipulates the names of the nodes in the graph; there are no method + * in either BayesPm or BayesIm to do this. BayesPm stores and manipulates the *categories* of each node in a DAG, + * considered as a variable in a Bayes net. The number of categories for a variable can be changed there as well as the + * names for those categories. This class, BayesIm, stores the actual probability tables which are implied by the + * structures in the other two classes. The implied parameters take the form of conditional probabilities--e.g., + * P(N=v0|P1=v1, P2=v2, ...), for all nodes and all combinations of their parent categories. The set of all such + * probabilities is organized in this class as a three-dimensional table of double values. 
The first dimension + * corresponds to the nodes in the Bayes net. For each such node, the second dimension corresponds to a flat list of + * combinations of parent categories for that node. The third dimension corresponds to the list of categories for that + * node itself. Two methods allow these values to be set and retrieved:

    + * To determine the index of the node in question, use the method To determine + * the index of the row in question, use the method + * To determine the order of the + * parent values for a given node so that you can build the parentVals[] array, + * use the method To determine the + * index of a category, use the method in BayesPm. The rest of the methods in this class are easily understood + * as variants of the methods above. + *

    + * Thanks to Pucktada Treeratpituk, Frank Wimberly, and Willie Wheeler for + * advice and earlier versions. + * + * @author josephramsey + * @version $Id: $Id + */ +public final class MlBayesImOld implements BayesIm { + + /** + * Inidicates that new rows in this BayesIm should be initialized as unknowns, forcing them to be specified + * manually. This is the default. + */ + public static final int MANUAL = 0; + /** + * Indicates that new rows in this BayesIm should be initialized randomly. + */ + public static final int RANDOM = 1; + @Serial + private static final long serialVersionUID = 23L; + + /** + * Tolerance. + */ + private static final double ALLOWABLE_DIFFERENCE = 1.0e-3; + + /** + * Random number generator. + */ + static private final Random random = new Random(); + + /** + * The associated Bayes PM model. + */ + private final BayesPm bayesPm; + /** + * The array of nodes from the graph. Order is important. + */ + private final Node[] nodes; + /** + * The list of parents for each node from the graph. Order or nodes corresponds to the order of nodes in 'nodes', + * and order in subarrays is important. + */ + private int[][] parents; + /** + * The array of dimensionality (number of categories for each node) for each of the subarrays of 'parents'. + */ + private int[][] parentDims; + + //===============================CONSTRUCTORS=========================// + /** + * The main data structure; stores the values of all of the conditional probabilities for the Bayes net of the form + * P(N=v0 | P1=v1, P2=v2,...). The first dimension is the node N, in the order of 'nodes'. The second dimension is + * the row index for the table of parameters associated with node N; the third dimension is the column index. 
The + * row index is calculated by the function getRowIndex(int[] values) where 'values' is an array of numerical indices + * for each of the parent values; the order of the values in this array is the same as the order of node in + * 'parents'; the value indices are obtained from the Bayes PM for each node. The column is the index of the value + * of N, where this index is obtained from the Bayes PM. + * + * @serial + */ + private double[][][] probs; + + /** + * Constructs a new BayesIm from the given BayesPm, initializing all values as Double.NaN ("?"). + * + * @param bayesPm the given Bayes PM. Carries with it the underlying graph model. + * @throws IllegalArgumentException if the array of nodes provided is not a permutation of the nodes + * contained in the bayes parametric model provided. + */ + public MlBayesImOld(BayesPm bayesPm) throws IllegalArgumentException { + this(bayesPm, null, MlBayesImOld.MANUAL); + } + + /** + * Constructs a new BayesIm from the given BayesPm, initializing values either as MANUAL or RANDOM. If initialized + * manually, all values will be set to Double.NaN ("?") in each row; if initialized randomly, all values will + * distribute randomly in each row. + * + * @param bayesPm the given Bayes PM. Carries with it the underlying graph model. + * @param initializationMethod either MANUAL or RANDOM. + * @throws IllegalArgumentException if the array of nodes provided is not a permutation of the nodes + * contained in the bayes parametric model provided. + */ + public MlBayesImOld(BayesPm bayesPm, int initializationMethod) + throws IllegalArgumentException { + this(bayesPm, null, initializationMethod); + } + + /** + * Constructs a new BayesIm from the given BayesPm, initializing values either as MANUAL or RANDOM, but using values + * from the old BayesIm provided where posssible. 
If initialized manually, all values that cannot be retrieved from + * oldBayesIm will be set to Double.NaN ("?") in each such row; if initialized randomly, all values that cannot be + * retrieved from oldBayesIm will distributed randomly in each such row. + * + * @param bayesPm the given Bayes PM. Carries with it the underlying graph model. + * @param oldBayesIm an already-constructed BayesIm whose values may be used where possible to initialize + * this BayesIm. May be null. + * @param initializationMethod either MANUAL or RANDOM. + * @throws IllegalArgumentException if the array of nodes provided is not a permutation of the nodes + * contained in the bayes parametric model provided. + */ + public MlBayesImOld(BayesPm bayesPm, BayesIm oldBayesIm, + int initializationMethod) throws IllegalArgumentException { + if (bayesPm == null) { + throw new NullPointerException("BayesPm must not be null."); + } + + this.bayesPm = new BayesPm(bayesPm); + + // Get the nodes from the BayesPm. This fixes the order of the nodes + // in the BayesIm, independently of any change to the BayesPm. + // (This order must be maintained.) + Graph graph = bayesPm.getDag(); + this.nodes = graph.getNodes().toArray(new Node[0]); + + // Initialize. + initialize(oldBayesIm, initializationMethod); + } + + /** + * Copy constructor. + * + * @param bayesIm a {@link BayesIm} object + * @throws IllegalArgumentException if any. + */ + public MlBayesImOld(BayesIm bayesIm) throws IllegalArgumentException { + if (bayesIm == null) { + throw new NullPointerException("BayesIm must not be null."); + } + + this.bayesPm = bayesIm.getBayesPm(); + + // Get the nodes from the BayesPm, fixing on an order. (This is + // important; the nodes must always be in the same order for this + // BayesIm.) + this.nodes = new Node[bayesIm.getNumNodes()]; + + for (int i = 0; i < bayesIm.getNumNodes(); i++) { + this.nodes[i] = bayesIm.getNode(i); + } + + // Copy all the old values over. 
+ initialize(bayesIm, MlBayesImOld.MANUAL); + } + + /** + * Generates a simple exemplar of this class to test serialization. + * + * @return a {@link MlBayesImOld} object + */ + public static MlBayesImOld serializableInstance() { + return new MlBayesImOld(BayesPm.serializableInstance()); + } + + //===============================PUBLIC METHODS========================// + + /** + *

    getParameterNames.

    + * + * @return a {@link List} object + */ + public static List getParameterNames() { + return new ArrayList<>(); + } + + private static double[] getRandomWeights(int size) { + assert size > 0; + + double[] row = new double[size]; + double sum = 0.0; + + int strong = (int) Math.floor(random.nextDouble() * size); + + for (int i = 0; i < size; i++) { + if (i == strong) { + row[i] = 1.0; + } else { + row[i] = RandomUtil.getInstance().nextDouble() * 0.3; + } + + sum += row[i]; + } + + for (int i = 0; i < size; i++) { + row[i] /= sum; + } + + return row; + } + + /** + *

    Getter for the field bayesPm.

    + * + * @return this PM. + */ + public BayesPm getBayesPm() { + return this.bayesPm; + } + + /** + *

    getDag.

    + * + * @return the DAG. + */ + public Graph getDag() { + return this.bayesPm.getDag(); + } + + /** + *

    getNumNodes.

    + * + * @return the number of nodes in the model. + */ + public int getNumNodes() { + return this.nodes.length; + } + + /** + * Retrieves the node at the specified index. + * + * @param nodeIndex the index of the node. + * @return the node at the specified index. + */ + public Node getNode(int nodeIndex) { + return this.nodes[nodeIndex]; + } + + /** + *

    getNode.

    + * + * @param name the name of the node. + * @return the node. + */ + public Node getNode(String name) { + return getDag().getNode(name); + } + + /** + * Returns the index of the given node in the nodes array. + * + * @param node the given node. + * @return the index of the node in the nodes array, or -1 if the node is not found. + */ + public int getNodeIndex(Node node) { + for (int i = 0; i < this.nodes.length; i++) { + if (node == this.nodes[i]) { + return i; + } + } + + return -1; + } + + /** + *

    getVariables.

    + * + * @return a {@link List} object + */ + public List getVariables() { + List variables = new LinkedList<>(); + + for (int i = 0; i < getNumNodes(); i++) { + Node node = getNode(i); + variables.add(this.bayesPm.getVariable(node)); + } + + return variables; + } + + /** + *

    getMeasuredNodes.

    + * + * @return the list of measured variableNodes. + */ + public List getMeasuredNodes() { + return this.bayesPm.getMeasuredNodes(); + } + + /** + *

    getVariableNames.

    + * + * @return a {@link List} object + */ + public List getVariableNames() { + List variableNames = new LinkedList<>(); + + for (int i = 0; i < getNumNodes(); i++) { + Node node = getNode(i); + variableNames.add(this.bayesPm.getVariable(node).getName()); + } + + return variableNames; + } + + /** + * Returns the number of columns in the specified node. + * + * @param nodeIndex the index of the node. + * @return the number of columns. + */ + public int getNumColumns(int nodeIndex) { + return this.probs[nodeIndex][0].length; + } + + /** + * Retrieves the number of rows in the specified node. + * + * @param nodeIndex the index of the node. + * @return the number of rows in the node. + */ + public int getNumRows(int nodeIndex) { + return this.probs[nodeIndex].length; + } + + /** + * Returns the number of parents for the given node. + * + * @param nodeIndex the index of the node. + * @return the number of parents. + */ + public int getNumParents(int nodeIndex) { + return this.parents[nodeIndex].length; + } + + /** + * Retrieves the parent of a node at the specified index. + * + * @param nodeIndex the index of the node. + * @param parentIndex the index of the parent. + * @return the parent of the node. + */ + public int getParent(int nodeIndex, int parentIndex) { + return this.parents[nodeIndex][parentIndex]; + } + + /** + * Retrieves the value of the parent dimension for a given node and parent index. + * + * @param nodeIndex the index of the node. + * @param parentIndex the index of the parent. + * @return the parent dimension value. + */ + public int getParentDim(int nodeIndex, int parentIndex) { + return this.parentDims[nodeIndex][parentIndex]; + } + + /** + * Returns a copy of the dimensions of the parent node at the specified index. + * + * @param nodeIndex the index of the node. + * @return an array containing the dimensions of the parent node. 
+ */ + public int[] getParentDims(int nodeIndex) { + int[] dims = this.parentDims[nodeIndex]; + int[] copy = new int[dims.length]; + System.arraycopy(dims, 0, copy, 0, dims.length); + return copy; + } + + /** + * Returns an array containing the parents of the specified node. + * + * @param nodeIndex the index of the node. + * @return an array of integers representing the parents of the specified node. + */ + public int[] getParents(int nodeIndex) { + int[] nodeParents = this.parents[nodeIndex]; + int[] copy = new int[nodeParents.length]; + System.arraycopy(nodeParents, 0, copy, 0, nodeParents.length); + return copy; + } + + /** + * Returns an integer array containing the parent values for a given node index and row index. + * + * @param nodeIndex the index of the node. + * @param rowIndex the index of the row in question. + * @return an integer array containing the parent values. + */ + public int[] getParentValues(int nodeIndex, int rowIndex) { + int[] dims = getParentDims(nodeIndex); + int[] values = new int[dims.length]; + + for (int i = dims.length - 1; i >= 0; i--) { + values[i] = rowIndex % dims[i]; + rowIndex /= dims[i]; + } + + return values; + } + + /** + * Retrieves the value of the parent node at the specified row and column index. + * + * @param nodeIndex the index of the node. + * @param rowIndex the index of the row in question. + * @param colIndex the index of the column in question. + * @return the value of the parent node at the specified row and column index. + */ + public int getParentValue(int nodeIndex, int rowIndex, int colIndex) { + return getParentValues(nodeIndex, rowIndex)[colIndex]; + } + + /** + * Returns the probability for a given node in the table. + * + * @param nodeIndex the index of the node in question. + * @param rowIndex the row in the table for this node which represents the combination of parent values in + * question. + * @param colIndex the column in the table for this node which represents the value of the node in question. 
+ * @return the probability value for the given node. + */ + public double getProbability(int nodeIndex, int rowIndex, int colIndex) { + return this.probs[nodeIndex][rowIndex][colIndex]; + } + + /** + * Returns the row index corresponding to the given node index and combination of parent values. + * + * @param nodeIndex the index of the node in question. + * @param values the combination of parent values in question. + * @return the row index corresponding to the given node index and combination of parent values. + */ + public int getRowIndex(int nodeIndex, int[] values) { + int[] dim = getParentDims(nodeIndex); + int rowIndex = 0; + + for (int i = 0; i < dim.length; i++) { + rowIndex *= dim[i]; + rowIndex += values[i]; + } + + return rowIndex; + } + + /** + * Normalizes all rows in the tables associated with each of node in turn. + */ + public void normalizeAll() { + for (int nodeIndex = 0; nodeIndex < this.nodes.length; nodeIndex++) { + normalizeNode(nodeIndex); + } + } + + /** + * Normalizes the specified node by invoking the {@link #normalizeRow(int, int)} method on each row of the node. + * + * @param nodeIndex the index of the node to be normalized. + */ + public void normalizeNode(int nodeIndex) { + for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { + normalizeRow(nodeIndex, rowIndex); + } + } + + /** + * Normalizes the probabilities of a given row in a node. + * + * @param nodeIndex the index of the node in question. + * @param rowIndex the index of the row in question. 
+ */ + public void normalizeRow(int nodeIndex, int rowIndex) { + int numColumns = getNumColumns(nodeIndex); + double total = 0.0; + + for (int colIndex = 0; colIndex < numColumns; colIndex++) { + total += getProbability(nodeIndex, rowIndex, colIndex); + } + + if (total != 0.0) { + for (int colIndex = 0; colIndex < numColumns; colIndex++) { + double probability + = getProbability(nodeIndex, rowIndex, colIndex); + double prob = probability / total; + setProbability(nodeIndex, rowIndex, colIndex, prob); + } + } else { + double prob = 1.0 / numColumns; + + for (int colIndex = 0; colIndex < numColumns; colIndex++) { + setProbability(nodeIndex, rowIndex, colIndex, prob); + } + } + } + + /** + * Sets the probability for the given node. The matrix row represent row index, the row in the table for this for + * node which represents the combination of parent values in question. of the CPT. The matrix column represent + * column index, the column in the table for this node which represents the value of the node in question. + * + * @param nodeIndex The index of the node. + * @param probMatrix The matrix of probabilities. + */ + @Override + public void setProbability(int nodeIndex, double[][] probMatrix) { + for (int i = 0; i < probMatrix.length; i++) { + System.arraycopy(probMatrix[i], 0, this.probs[nodeIndex][i], 0, probMatrix[i].length); + } + } + + /** + * Sets the probability value for a specific node, row, and column in the probability table. + * + * @param nodeIndex the index of the node in question. + * @param rowIndex the row in the table for this node which represents the combination of parent values in + * question. + * @param colIndex the column in the table for this node which represents the value of the node in question. + * @param value the desired probability to be set. Must be between 0.0 and 1.0, or Double.NaN. 
+ * @throws IllegalArgumentException if the column index is out of range for the given node, or if the probability + * value is not between 0.0 and 1.0 or Double.NaN. + */ + public void setProbability(int nodeIndex, int rowIndex, int colIndex, + double value) { + if (colIndex >= getNumColumns(nodeIndex)) { + throw new IllegalArgumentException("Column out of range: " + + colIndex + " >= " + getNumColumns(nodeIndex)); + } + + if (!(0.0 <= value && value <= 1.0) && !Double.isNaN(value)) { + throw new IllegalArgumentException("Probability value must be " + + "between 0.0 and 1.0 or Double.NaN."); + } + + this.probs[nodeIndex][rowIndex][colIndex] = value; + } + + /** + * Returns the corresponding node index in the given BayesIm based on the node index in this BayesIm. + * + * @param nodeIndex the index of the node in this BayesIm. + * @param otherBayesIm the BayesIm in which the node is to be found. + * @return the corresponding node index in the given BayesIm. + */ + public int getCorrespondingNodeIndex(int nodeIndex, BayesIm otherBayesIm) { + String nodeName = getNode(nodeIndex).getName(); + Node oldNode = otherBayesIm.getNode(nodeName); + return otherBayesIm.getNodeIndex(oldNode); + } + + /** + * Clears all values in the specified row of a table. + * + * @param nodeIndex the index of the node for the table that this row belongs to + * @param rowIndex the index of the row to be cleared + */ + public void clearRow(int nodeIndex, int rowIndex) { + for (int colIndex = 0; colIndex < getNumColumns(nodeIndex); colIndex++) { + setProbability(nodeIndex, rowIndex, colIndex, Double.NaN); + } + } + + /** + * Randomizes the values of a row in a table for a given node. + * + * @param nodeIndex the index of the node for the table that this row belongs to. + * @param rowIndex the index of the row to be randomized. 
+ */ + public void randomizeRow(int nodeIndex, int rowIndex) { + int size = getNumColumns(nodeIndex); + this.probs[nodeIndex][rowIndex] = MlBayesImOld.getRandomWeights(size); + } + + /** + * Randomizes the incomplete rows in the specified node's table. + * + * @param nodeIndex the index of the node for the table whose incomplete rows are to be randomized + */ + public void randomizeIncompleteRows(int nodeIndex) { + for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { + if (isIncomplete(nodeIndex, rowIndex)) { + randomizeRow(nodeIndex, rowIndex); + } + } + } + + /** + * Randomizes the table for a given node. + * + * @param nodeIndex the index of the node for the table to be randomized + */ + public void randomizeTable(int nodeIndex) { + for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { + randomizeRow(nodeIndex, rowIndex); + } + } + + private int score(int nodeIndex) { + double[][] p = new double[getNumRows(nodeIndex)][getNumColumns(nodeIndex)]; + copy(this.probs[nodeIndex], p); + int num = 0; + + int numRows = getNumRows(nodeIndex); + + for (int r = 0; r < p.length; r++) { + for (int c = 0; c < p[0].length; c++) { + p[r][c] /= numRows; + } + } + + int[] parents = getParents(nodeIndex); + + for (int t = 0; t < parents.length; t++) { + int numParentValues = getParentDim(nodeIndex, t); + int numColumns = getNumColumns(nodeIndex); + + double[][] table = new double[numParentValues][numColumns]; + + for (int childCol = 0; childCol < numColumns; childCol++) { + for (int parentValue = 0; parentValue < numParentValues; parentValue++) { + for (int row = 0; row < numRows; row++) { + if (getParentValues(nodeIndex, row)[t] == parentValue) { + table[parentValue][childCol] += p[row][childCol]; + } + } + } + } + + final double N = 1000.0; + + for (int r = 0; r < table.length; r++) { + for (int c = 0; c < table[0].length; c++) { + table[r][c] *= N; + } + } + + double chisq = 0.0; + + for (int r = 0; r < table.length; r++) { + for (int c = 0; c 
< table[0].length; c++) { + double _sumRow = sumRow(table, r); + double _sumCol = sumCol(table, c); + double exp = (_sumRow / N) * (_sumCol / N) * N; + double obs = table[r][c]; + chisq += pow(obs - exp, 2) / exp; + } + } + + int dof = (table.length - 1) * (table[0].length - 1); + + ChiSquaredDistribution distribution = new ChiSquaredDistribution(dof); + double prob = 1 - distribution.cumulativeProbability(chisq); + + num += prob < 0.0001 ? 1 : 0; + } + +// return num == parents.length ? -score : 0; + return num; + } + + private double sumCol(double[][] marginals, int j) { + double sum = 0.0; + + for (double[] marginal : marginals) { + sum += marginal[j]; + } + + return sum; + } + + private double sumRow(double[][] marginals, int i) { + double sum = 0.0; + + for (int h = 0; h < marginals[i].length; h++) { + sum += marginals[i][h]; + } + + return sum; + } + + private void copy(double[][] a, double[][] b) { + for (int r = 0; r < a.length; r++) { + System.arraycopy(a[r], 0, b[r], 0, a[r].length); + } + } + + /** + * Clears the table by clearing all rows for the given node. + * + * @param nodeIndex The index of the node for the table to be cleared. + */ + public void clearTable(int nodeIndex) { + for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { + clearRow(nodeIndex, rowIndex); + } + } + + /** + * Checks if the specified row of a table is incomplete, i.e., if any of the columns have a NaN value. + * + * @param nodeIndex the index of the table node to check. + * @param rowIndex the index of the row to check. + * @return true if the row is incomplete, false otherwise. + */ + public boolean isIncomplete(int nodeIndex, int rowIndex) { + for (int colIndex = 0; colIndex < getNumColumns(nodeIndex); colIndex++) { + double p = getProbability(nodeIndex, rowIndex, colIndex); + + if (Double.isNaN(p)) { + return true; + } + } + + return false; + } + + /** + * Checks if the specified table has any incomplete rows. 
+ * + * @param nodeIndex the index of the node for the table + * @return true if the table has any incomplete rows, false otherwise + */ + public boolean isIncomplete(int nodeIndex) { + for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { + if (isIncomplete(nodeIndex, rowIndex)) { + return true; + } + } + + return false; + } + + /** + * Simulates a sample with the given sample size. + * + * @param sampleSize the sample size. + * @param latentDataSaved a boolean + * @param tiers an array of {@link int} objects + * @return the simulated sample as a DataSet. + */ + public DataSet simulateData(int sampleSize, boolean latentDataSaved, int[] tiers) { + if (getBayesPm().getDag().isTimeLagModel()) { + return simulateTimeSeries(sampleSize); + } + + return simulateDataHelper(sampleSize, latentDataSaved, tiers); + } + + /** + * Simulates a data set. + * + * @param sampleSize The number of rows to simulate. + * @param latentDataSaved If set to true, latent variables are saved in the data set. + * @return The simulated data set. + * @throws IllegalArgumentException If the graph contains a directed cycle. + */ + public DataSet simulateData(int sampleSize, boolean latentDataSaved) { + if (getBayesPm().getDag().isTimeLagModel()) { + return simulateTimeSeries(sampleSize); + } + + // Get a tier ordering and convert it to an int array. + Graph graph = getBayesPm().getDag(); + + if (graph.paths().existsDirectedCycle()) { + throw new IllegalArgumentException("Graph must be acyclic to simulate from discrete Bayes net."); + } + + Paths paths = graph.paths(); + List initialOrder = graph.getNodes(); + List tierOrdering = paths.getValidOrder(initialOrder, true); + int[] tiers = new int[tierOrdering.size()]; + + for (int i = 0; i < tierOrdering.size(); i++) { + tiers[i] = getNodeIndex(tierOrdering.get(i)); + } + + return simulateDataHelper(sampleSize, latentDataSaved, tiers); + } + + /** + *

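    Simulates data by overwriting the given, already allocated data set, sampling the nodes in the order given by tiers.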

    + * + * @param dataSet a {@link DataSet} object + * @param latentDataSaved a boolean + * @param tiers an array of {@link int} objects + * @return a {@link DataSet} object + */ + public DataSet simulateData(DataSet dataSet, boolean latentDataSaved, int[] tiers) { + return simulateDataHelper(dataSet, latentDataSaved, tiers); + } + + /** + * Simulates data for the given data set. + * + * @param dataSet The data set to simulate data for. + * @param latentDataSaved Indicates whether latent data should be saved during simulation. + * @return The modified data set after simulating the data. + */ + public DataSet simulateData(DataSet dataSet, boolean latentDataSaved) { + // Get a tier ordering and convert it to an int array. + Graph graph = getBayesPm().getDag(); + Paths paths = graph.paths(); + List initialOrder = graph.getNodes(); + List tierOrdering = paths.getValidOrder(initialOrder, true); + int[] tiers = new int[tierOrdering.size()]; + + for (int i = 0; i < tierOrdering.size(); i++) { + tiers[i] = getNodeIndex(tierOrdering.get(i)); + } + + return simulateDataHelper(dataSet, latentDataSaved, tiers); + } + + private DataSet simulateTimeSeries(int sampleSize) { + TimeLagGraph timeSeriesGraph = getBayesPm().getDag().getTimeLagGraph(); + + List variables = new ArrayList<>(); + + for (Node node : timeSeriesGraph.getLag0Nodes()) { + DiscreteVariable e = new DiscreteVariable(timeSeriesGraph.getNodeId(node).getName()); + e.setNodeType(node.getNodeType()); + variables.add(e); + } + + List lag0Nodes = timeSeriesGraph.getLag0Nodes(); + +// DataSet fullData = new ColtDataSet(sampleSize, variables); + DataSet fullData = new BoxDataSet(new VerticalIntDataBox(sampleSize, variables.size()), variables); + + Graph contemporaneousDag = timeSeriesGraph.subgraph(lag0Nodes); + Paths paths = contemporaneousDag.paths(); + List initialOrder = contemporaneousDag.getNodes(); + List tierOrdering = paths.getValidOrder(initialOrder, true); + int[] tiers = new int[tierOrdering.size()]; + + for 
(int i = 0; i < tierOrdering.size(); i++) { + tiers[i] = getNodeIndex(tierOrdering.get(i)); + } + + // Construct the sample. + int[] combination = new int[tierOrdering.size()]; + + for (int i = 0; i < sampleSize; i++) { + int[] point = new int[this.nodes.length]; + + for (int nodeIndex : tiers) { + double cutoff = RandomUtil.getInstance().nextDouble(); + + for (int k = 0; k < getNumParents(nodeIndex); k++) { + combination[k] = point[getParent(nodeIndex, k)]; + } + + int rowIndex = getRowIndex(nodeIndex, combination); + double sum = 0.0; + + for (int k = 0; k < getNumColumns(nodeIndex); k++) { + double probability = getProbability(nodeIndex, rowIndex, k); + + if (Double.isNaN(probability)) { + throw new IllegalStateException("Some probability " + + "values in the BayesIm are not filled in; " + + "cannot simulate data."); + } + + sum += probability; + + if (sum >= cutoff) { + point[nodeIndex] = k; + break; + } + } + } + } + + return fullData; + } + + /** + * Simulates a sample with the given sample size. + * + * @param sampleSize the sample size. + * @return the simulated sample as a DataSet. 
+ */ + private DataSet simulateDataHelper(int sampleSize, boolean latentDataSaved, int[] tiers) { + int numMeasured = 0; + int[] map = new int[this.nodes.length]; + List variables = new LinkedList<>(); + + for (int j = 0; j < this.nodes.length; j++) { + + int numCategories = this.bayesPm.getNumCategories(this.nodes[j]); + List categories = new LinkedList<>(); + + for (int k = 0; k < numCategories; k++) { + categories.add(this.bayesPm.getCategory(this.nodes[j], k)); + } + + DiscreteVariable var + = new DiscreteVariable(this.nodes[j].getName(), categories); + var.setNodeType(this.nodes[j].getNodeType()); + variables.add(var); + int index = ++numMeasured - 1; + map[index] = j; + } + + DataSet dataSet = new BoxDataSet(new VerticalIntDataBox(sampleSize, variables.size()), variables); + constructSample(sampleSize, dataSet, map, tiers); + + if (!latentDataSaved) { + dataSet = DataTransforms.restrictToMeasured(dataSet); + } + + return dataSet; + } + + /** + * Constructs a random sample using the given already allocated data set, to avoid allocating more memory. 
+ */ + private DataSet simulateDataHelper(DataSet dataSet, boolean latentDataSaved, int[] tiers) { + if (dataSet.getNumColumns() != this.nodes.length) { + throw new IllegalArgumentException("When rewriting the old data set, " + + "number of variables in data set must equal number of variables " + + "in Bayes net."); + } + + int sampleSize = dataSet.getNumRows(); + + int numVars = 0; + int[] map = new int[this.nodes.length]; + List variables = new LinkedList<>(); + + for (int j = 0; j < this.nodes.length; j++) { + + int numCategories = this.bayesPm.getNumCategories(this.nodes[j]); + List categories = new LinkedList<>(); + + for (int k = 0; k < numCategories; k++) { + categories.add(this.bayesPm.getCategory(this.nodes[j], k)); + } + + DiscreteVariable var + = new DiscreteVariable(this.nodes[j].getName(), categories); + var.setNodeType(this.nodes[j].getNodeType()); + variables.add(var); + int index = ++numVars - 1; + map[index] = j; + } + + for (int i = 0; i < variables.size(); i++) { + Node node = dataSet.getVariable(i); + Node _node = variables.get(i); + dataSet.changeVariable(node, _node); + } + + constructSample(sampleSize, dataSet, map, tiers); + + if (latentDataSaved) { + return dataSet; + } else { + return DataTransforms.restrictToMeasured(dataSet); + } + } + + private void constructSample(int sampleSize, DataSet dataSet, int[] map, int[] tiers) { + +// //Do the simulation. 
+// class SimulationTask extends RecursiveTask { +// private int chunk; +// private int from; +// private int to; +// private int[] tiers; +// private DataSet dataSet; +// private int[] map; +// +// public SimulationTask(int chunk, int from, int to, int[] tiers, DataSet dataSet, int[] map) { +// this.chunk = chunk; +// this.from = from; +// this.to = to; +// this.tiers = tiers; +// this.dataSet = dataSet; +// this.map = map; +// } +// +// @Override +// protected Boolean compute() { +// if (to - from <= chunk) { +// RandomGenerator randomGenerator = new Well1024a(++seed[0]); +// +// for (int row = from; row < to; row++) { +// for (int t : tiers) { +// int[] parentValues = new int[parents[t].length]; +// +// for (int k = 0; k < parentValues.length; k++) { +// parentValues[k] = dataSet.getInt(row, parents[t][k]); +// } +// +// int rowIndex = getRowIndex(t, parentValues); +// double sum = 0.0; +// double r; +// +// r = randomGenerator.nextDouble(); +// +// for (int k = 0; k < getNumColumns(t); k++) { +// double probability = getProbability(t, rowIndex, k); +// sum += probability; +// +// if (sum >= r) { +// dataSet.setInt(row, map[t], k); +// break; +// } +// } +// } +// } +// +// return true; +// } else { +// int mid = (to + from) / 2; +// SimulationTask left = new SimulationTask(chunk, from, mid, tiers, dataSet, map); +// SimulationTask right = new SimulationTask(chunk, mid, to, tiers, dataSet, map); +// +// left.fork(); +// right.compute(); +// left.join(); +// +// return true; +// } +// } +// } +// +// int chunk = 25; +// +// ForkJoinPool pool = ForkJoinPoolInstance.getInstance().getPool(); +// SimulationTask task = new SimulationTask(chunk, 0, sampleSize, tiers, dataSet, map); +// pool.invoke(task); + // Construct the sample. 
+ for (int i = 0; i < sampleSize; i++) { + for (int t : tiers) { + int[] parentValues = new int[this.parents[t].length]; + + for (int k = 0; k < parentValues.length; k++) { + parentValues[k] = dataSet.getInt(i, this.parents[t][k]); + } + + int rowIndex = getRowIndex(t, parentValues); + double sum = 0.0; + + double r = RandomUtil.getInstance().nextDouble(); + + for (int k = 0; k < getNumColumns(t); k++) { + double probability = getProbability(t, rowIndex, k); + sum += probability; + + if (sum >= r) { + dataSet.setInt(i, map[t], k); + break; + } + } + } + } + +// System.out.println(dataSet); + } + + /** + * Determines whether the specified object is equal to this Bayes net. + * + * @param o the object to be compared to this Bayes net + * @return true if the specified object is equal to this Bayes net, false otherwise + */ + public boolean equals(Object o) { + if (o == this) { + return true; + } + + if (!(o instanceof BayesIm otherIm)) { + return false; + } + + if (getNumNodes() != otherIm.getNumNodes()) { + return false; + } + + for (int i = 0; i < getNumNodes(); i++) { + int otherIndex = otherIm.getCorrespondingNodeIndex(i, otherIm); + + if (otherIndex == -1) { + return false; + } + + if (getNumColumns(i) != otherIm.getNumColumns(otherIndex)) { + return false; + } + + if (getNumRows(i) != otherIm.getNumRows(otherIndex)) { + return false; + } + + for (int j = 0; j < getNumRows(i); j++) { + for (int k = 0; k < getNumColumns(i); k++) { + double prob = getProbability(i, j, k); + double otherProb = otherIm.getProbability(i, j, k); + + if (Double.isNaN(prob) && Double.isNaN(otherProb)) { + continue; + } + + if (abs(prob - otherProb) > MlBayesImOld.ALLOWABLE_DIFFERENCE) { + return false; + } + } + } + } + + return true; + } + + //=============================PRIVATE METHODS=======================// + + /** + * Prints out the probability table for each variable. 
+ * + * @return a {@link String} object + */ + public String toString() { + StringBuilder buf = new StringBuilder(); + NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat(); + + for (int i = 0; i < getNumNodes(); i++) { + buf.append("\n\nNode: ").append(getNode(i)); + + if (getNumParents(i) == 0) { + buf.append("\n"); + } else { + buf.append("\n\n"); + for (int k = 0; k < getNumParents(i); k++) { + buf.append(getNode(getParent(i, k))).append("\t"); + } + } + + for (int j = 0; j < getNumRows(i); j++) { + buf.append("\n"); + for (int k = 0; k < getNumParents(i); k++) { + buf.append(getParentValue(i, j, k)); + + if (k < getNumParents(i) - 1) { + buf.append("\t"); + } + } + + if (getNumParents(i) > 0) { + buf.append("\t"); + } + + for (int k = 0; k < getNumColumns(i); k++) { + buf.append(nf.format(getProbability(i, j, k))).append("\t"); + } + } + } + + buf.append("\n"); + + return buf.toString(); + } + + /** + * This method initializes the probability tables for all of the nodes in the Bayes net. + * + * @see #initializeNode + * @see #randomizeRow + */ + private void initialize(BayesIm oldBayesIm, int initializationMethod) { + this.parents = new int[this.nodes.length][]; + this.parentDims = new int[this.nodes.length][]; + this.probs = new double[this.nodes.length][][]; + + for (int nodeIndex = 0; nodeIndex < this.nodes.length; nodeIndex++) { + initializeNode(nodeIndex, oldBayesIm, initializationMethod); + } + } + + /** + * This method initializes the node indicated. + */ + private void initializeNode(int nodeIndex, BayesIm oldBayesIm, + int initializationMethod) { + Node node = this.nodes[nodeIndex]; + + // Set up parents array. Should store the parents of + // each node as ints in a particular order. 
+ Graph graph = getBayesPm().getDag(); + List parentList = new ArrayList<>(graph.getParents(node)); + int[] parentArray = new int[parentList.size()]; + + for (int i = 0; i < parentList.size(); i++) { + parentArray[i] = getNodeIndex(parentList.get(i)); + } + + // Sort parent array. + Arrays.sort(parentArray); + + this.parents[nodeIndex] = parentArray; + + // Setup dimensions array for parents. + int[] dims = new int[parentArray.length]; + + for (int i = 0; i < dims.length; i++) { + Node parNode = this.nodes[parentArray[i]]; + dims[i] = getBayesPm().getNumCategories(parNode); + } + + // Calculate dimensions of table. + int numRows = 1; + + for (int dim : dims) { + if (numRows > 1000000 /* Integer.MAX_VALUE / dim*/) { + throw new IllegalArgumentException( + "The number of rows in the " + + "conditional probability table for " + + this.nodes[nodeIndex] + + " is greater than 1,000,000 and cannot be " + + "represented."); + } + + numRows *= dim; + } + + int numCols = getBayesPm().getNumCategories(node); + + this.parentDims[nodeIndex] = dims; + this.probs[nodeIndex] = new double[numRows][numCols]; + + // Initialize each row. 
+ if (initializationMethod == MlBayesImOld.RANDOM) { + randomizeTable(nodeIndex); + } else { + for (int rowIndex = 0; rowIndex < numRows; rowIndex++) { + if (oldBayesIm == null) { + overwriteRow(nodeIndex, rowIndex, initializationMethod); + } else { + retainOldRowIfPossible(nodeIndex, rowIndex, oldBayesIm, + initializationMethod); + } + } + } + } + + private void overwriteRow(int nodeIndex, int rowIndex, + int initializationMethod) { + if (initializationMethod == MlBayesImOld.RANDOM) { + randomizeRow(nodeIndex, rowIndex); + } else if (initializationMethod == MlBayesImOld.MANUAL) { + initializeRowAsUnknowns(nodeIndex, rowIndex); + } else { + throw new IllegalArgumentException("Unrecognized state."); + } + } + + private void initializeRowAsUnknowns(int nodeIndex, int rowIndex) { + int size = getNumColumns(nodeIndex); + double[] row = new double[size]; + Arrays.fill(row, Double.NaN); + this.probs[nodeIndex][rowIndex] = row; + } + + /** + * This method initializes the node indicated. + */ + private void retainOldRowIfPossible(int nodeIndex, int rowIndex, + BayesIm oldBayesIm, int initializationMethod) { + + int oldNodeIndex = getCorrespondingNodeIndex(nodeIndex, oldBayesIm); + + if (oldNodeIndex == -1) { + overwriteRow(nodeIndex, rowIndex, initializationMethod); + } else if (getNumColumns(nodeIndex) != oldBayesIm.getNumColumns(oldNodeIndex)) { + overwriteRow(nodeIndex, rowIndex, initializationMethod); +// } else if (parentsChanged(nodeIndex, this, oldBayesIm)) { +// overwriteRow(nodeIndex, rowIndex, initializationMethod); + } else { + int oldRowIndex = getUniqueCompatibleOldRow(nodeIndex, rowIndex, oldBayesIm); + + if (oldRowIndex >= 0) { + copyValuesFromOldToNew(oldNodeIndex, oldRowIndex, nodeIndex, + rowIndex, oldBayesIm); + } else { + overwriteRow(nodeIndex, rowIndex, initializationMethod); + } + } + } + + /** + * @return the unique rowIndex in the old BayesIm for the given node that is compatible with the given rowIndex in + * the new BayesIm for that node, if one 
exists. Otherwise, returns -1. A compatible rowIndex is one in which all + * the parents that the given node has in common between the old BayesIm and the new BayesIm are assigned the values + * they have in the new rowIndex. If a parent node is removed in the new BayesIm, there may be more than one such + * compatible rowIndex in the old BayesIm, in which case -1 is returned. Likewise, there may be no compatible rows, + * in which case -1 is returned. + */ + private int getUniqueCompatibleOldRow(int nodeIndex, int rowIndex, + BayesIm oldBayesIm) { + int oldNodeIndex = getCorrespondingNodeIndex(nodeIndex, oldBayesIm); + int oldNumParents = oldBayesIm.getNumParents(oldNodeIndex); + + int[] oldParentValues = new int[oldNumParents]; + Arrays.fill(oldParentValues, -1); + + int[] parentValues = getParentValues(nodeIndex, rowIndex); + + // Go through each parent of the node in the new BayesIm. + for (int i = 0; i < getNumParents(nodeIndex); i++) { + + // Get the index of the parent in the new graph and in the old + // graph. If it's no longer in the new graph, skip to the next + // parent. + int parentNodeIndex = getParent(nodeIndex, i); + int oldParentNodeIndex + = getCorrespondingNodeIndex(parentNodeIndex, oldBayesIm); + int oldParentIndex = -1; + + for (int j = 0; j < oldBayesIm.getNumParents(oldNodeIndex); j++) { + if (oldParentNodeIndex == oldBayesIm.getParent(oldNodeIndex, j)) { + oldParentIndex = j; + break; + } + } + + if (oldParentIndex == -1 + || oldParentIndex >= oldBayesIm.getNumParents(oldNodeIndex)) { + return -1; + } + + // Look up that value index for the new BayesIm for that parent. + // If it was a valid value index in the old BayesIm, record + // that value in oldParentValues. Otherwise return -1. 
+ int newParentValue = parentValues[i]; + int oldParentDim + = oldBayesIm.getParentDim(oldNodeIndex, oldParentIndex); + + if (newParentValue < oldParentDim) { + oldParentValues[oldParentIndex] = newParentValue; + } else { + return -1; + } + } + +// // Go through each parent of the node in the new BayesIm. +// for (int i = 0; i < oldBayesIm.getNumParents(oldNodeIndex); i++) { +// +// // Get the index of the parent in the new graph and in the old +// // graph. If it's no longer in the new graph, skip to the next +// // parent. +// int oldParentNodeIndex = oldBayesIm.getParent(oldNodeIndex, i); +// int parentNodeIndex = +// oldBayesIm.getCorrespondingNodeIndex(oldParentNodeIndex, this); +// int parentIndex = -1; +// +// for (int j = 0; j < this.getNumParents(nodeIndex); j++) { +// if (parentNodeIndex == this.getParent(nodeIndex, j)) { +// parentIndex = j; +// break; +// } +// } +// +// if (parentIndex == -1 || +// parentIndex >= this.getNumParents(nodeIndex)) { +// continue; +// } +// +// // Look up that value index for the new BayesIm for that parent. +// // If it was a valid value index in the old BayesIm, record +// // that value in oldParentValues. Otherwise return -1. +// int parentValue = oldParentValues[i]; +// int parentDim = +// this.getParentDim(nodeIndex, parentIndex); +// +// if (parentValue < parentDim) { +// oldParentValues[parentIndex] = oldParentValue; +// } else { +// return -1; +// } +// } + // If there are any -1's in the combination at this point, return -1. + for (int oldParentValue : oldParentValues) { + if (oldParentValue == -1) { + return -1; + } + } + + // Otherwise, return the combination, which will be a row in the + // old BayesIm. 
+ return oldBayesIm.getRowIndex(oldNodeIndex, oldParentValues); + } + + private void copyValuesFromOldToNew(int oldNodeIndex, int oldRowIndex, + int nodeIndex, int rowIndex, BayesIm oldBayesIm) { + if (getNumColumns(nodeIndex) != oldBayesIm.getNumColumns(oldNodeIndex)) { + throw new IllegalArgumentException("It's only possible to copy " + + "one row of probability values to another in a Bayes IM " + + "if the number of columns in the table are the same."); + } + + for (int colIndex = 0; colIndex < getNumColumns(nodeIndex); colIndex++) { + double prob = oldBayesIm.getProbability(oldNodeIndex, oldRowIndex, + colIndex); + setProbability(nodeIndex, rowIndex, colIndex, prob); + } + } + + /** + * Adds semantic checks to the default deserialization method. This method must have the standard signature for a + * readObject method, and the body of the method must begin with "s.defaultReadObject();". Other than that, any + * semantic checks can be specified and do not need to stay the same from version to version. A readObject method of + * this form may be added to any class, even if Tetrad sessions were previously saved out using a version of the + * class that didn't include it. (That's what the "s.defaultReadObject();" is for. See J. Bloch, Effective Java, for + * help. + * + * @param s The object input stream. + * @throws IOException If any. + * @throws ClassNotFoundException If any. 
* + */ + @Serial + private void readObject(ObjectInputStream s) + throws IOException, ClassNotFoundException { + s.defaultReadObject(); + + if (this.bayesPm == null) { + throw new NullPointerException(); + } + + if (this.nodes == null) { + throw new NullPointerException(); + } + + if (this.parents == null) { + throw new NullPointerException(); + } + + if (this.parentDims == null) { + throw new NullPointerException(); + } + + if (this.probs == null) { + throw new NullPointerException(); + } + } +} From 681f3aa768847cf34896d733ec96993436f661a5 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 28 Mar 2024 01:26:34 -0400 Subject: [PATCH 02/23] Introduce ProbMap and implement it in MlBayesIm Implemented a new class ProbMap.java in bayes package, as a map between a unique integer index for a particular node and its probability. This mapping omits NaN values and it's implemented in MlBayesIm class. MlBayesIm uses either traditional probability matrices or these new probability maps, based on a flag setting. 
--- .../java/edu/cmu/tetrad/bayes/MlBayesIm.java | 28 ++--- .../java/edu/cmu/tetrad/bayes/ProbMap.java | 102 ++++++++++++++++++ 2 files changed, 116 insertions(+), 14 deletions(-) create mode 100644 tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ProbMap.java diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java index 4853819ddb..2ba4009c84 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java @@ -25,11 +25,9 @@ import edu.cmu.tetrad.graph.Node; import edu.cmu.tetrad.graph.Paths; import edu.cmu.tetrad.graph.TimeLagGraph; -import edu.cmu.tetrad.util.Matrix; import edu.cmu.tetrad.util.NumberFormatUtil; import edu.cmu.tetrad.util.RandomUtil; import edu.cmu.tetrad.util.Vector; -import org.apache.commons.math3.distribution.ChiSquaredDistribution; import java.io.IOException; import java.io.ObjectInputStream; @@ -38,7 +36,6 @@ import java.util.*; import static org.apache.commons.math3.util.FastMath.abs; -import static org.apache.commons.math3.util.FastMath.pow; /** * Stores a table of probabilities for a Bayes net and, together with BayesPm and Dag, provides methods to manipulate @@ -104,17 +101,21 @@ public final class MlBayesIm implements BayesIm { * The array of nodes from the graph. Order is important. */ private final Node[] nodes; + /** + * A flag indicating whether to use probability matrices or not. + */ + boolean useProbMatrices = true; /** * The list of parents for each node from the graph. Order or nodes corresponds to the order of nodes in 'nodes', * and order in subarrays is important. */ private int[][] parents; + + //===============================CONSTRUCTORS=========================// /** * The array of dimensionality (number of categories for each node) for each of the subarrays of 'parents'. 
*/ private int[][] parentDims; - - //===============================CONSTRUCTORS=========================// /** * The main data structure; stores the values of all of the conditional probabilities for the Bayes net of the form * P(N=v0 | P1=v1, P2=v2,...). The first dimension is the node N, in the order of 'nodes'. The second dimension is @@ -127,13 +128,12 @@ public final class MlBayesIm implements BayesIm { * @serial */ private double[][][] probs; - /** - * The array of matrices that store the probabilities for each node. + * The array of probability maps for each node. The index of the node corresponds to the index of the probability + * map in this array. The probability map is a map from a unique integer index for a particular node to the + * probability of that node taking on that value, where NaN's are not stored. */ - private Matrix[] probMatrices; - - boolean useProbMatrices = true; + private ProbMap[] probMatrices; /** * Constructs a new BayesIm from the given BayesPm, initializing all values as Double.NaN ("?"). 
@@ -588,7 +588,7 @@ public void normalizeRow(int nodeIndex, int rowIndex) { @Override public void setProbability(int nodeIndex, double[][] probMatrix) { if (useProbMatrices) { - probMatrices[nodeIndex] = new Matrix(probMatrix); + probMatrices[nodeIndex] = new ProbMap(probMatrix); } else { for (int i = 0; i < probMatrix.length; i++) { System.arraycopy(probMatrix[i], 0, this.probs[nodeIndex][i], 0, probMatrix[i].length); @@ -1191,7 +1191,7 @@ private void initialize(BayesIm oldBayesIm, int initializationMethod) { this.parents = new int[this.nodes.length][]; this.parentDims = new int[this.nodes.length][]; this.probs = new double[this.nodes.length][][]; - this.probMatrices = new Matrix[this.nodes.length]; + this.probMatrices = new ProbMap[this.nodes.length]; for (int nodeIndex = 0; nodeIndex < this.nodes.length; nodeIndex++) { initializeNode(nodeIndex, oldBayesIm, initializationMethod); @@ -1248,7 +1248,7 @@ private void initializeNode(int nodeIndex, BayesIm oldBayesIm, this.parentDims[nodeIndex] = dims; this.probs[nodeIndex] = new double[numRows][numCols]; - this.probMatrices[nodeIndex] = new Matrix(numRows, numCols); + this.probMatrices[nodeIndex] = new ProbMap(numRows, numCols); // Initialize each row. if (initializationMethod == MlBayesIm.RANDOM) { @@ -1470,7 +1470,7 @@ private void readObject(ObjectInputStream s) if (this.probs != null) { for (int i = 0; i < this.nodes.length; i++) { if (useProbMatrices) { - probMatrices[i] = new Matrix(probs[i]); + probMatrices[i] = new ProbMap(probs[i]); } } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ProbMap.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ProbMap.java new file mode 100644 index 0000000000..7d4c4394fd --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ProbMap.java @@ -0,0 +1,102 @@ +package edu.cmu.tetrad.bayes; + +import edu.cmu.tetrad.util.Vector; + +import java.util.HashMap; +import java.util.Map; + +/** + * Represents a probability map. 
A probability map is a map from a unique integer index for a particular node to the * + * probability of that node taking on that value, where NaN's are not stored. + */ +public class ProbMap { + + /** + * Constructs a new probability map, a map from a unique integer index for a particular node to the probability of + * that node taking on that value, where NaN's are not stored. + */ + private final Map map = new HashMap<>(); + private final int numRows; + private final int numColumns; + + /** + * Constructs a new probability map, a map from a unique integer index for a particular node to the probability of + * that node taking on that value, where NaN's are not stored. This probability map assumes that there is a certain + * number of rows and a certain number of columns in the table. + */ + public ProbMap(int numRows, int numColumns) { + if (numRows < 1 || numColumns < 1) { + throw new IllegalArgumentException("Number of rows and columns must be at least 1."); + } + + this.numRows = numRows; + this.numColumns = numColumns; + } + + public ProbMap(double[][] probMatrix) { + numRows = probMatrix.length; + numColumns = probMatrix[0].length; + + for (int i = 0; i < numRows; i++) { + if (probMatrix[i].length != numColumns) { + throw new IllegalArgumentException("All rows must have the same number of columns."); + } + } + + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numColumns; j++) { + map.put(i * numColumns + j, probMatrix[i][j]); + } + } + } + + /** + * Returns the probability of the node taking on the value specified by the given row and column. + */ + public double get(int row, int column) { + if (row < 0 || row >= numRows || column < 0 || column >= numColumns) { + throw new IllegalArgumentException("Row and column must be within bounds."); + } + + return map.get(row * numColumns + column); + } + + /** + * Sets the probability of the node taking on the value specified by the given row and column to the given value. 
+ */ + public void set(int row, int column, double value) { + if (row < 0 || row >= numRows || column < 0 || column >= numColumns) { + throw new IllegalArgumentException("Row and column must be within bounds."); + } + + map.put(row * numColumns + column, value); + } + + /** + * Returns the number of rows in the probability map. + * + * @return the number of rows in the probability map. + */ + public int getNumRows() { + return numRows; + } + + /** + * Returns the number of columns in the probability map. + * + * @return the number of columns in the probability map. + */ + public int getNumColumns() { + return numColumns; + } + + public void assignRow(int rowIndex, Vector vector) { + if (vector.size() != numColumns) { + throw new IllegalArgumentException("Vector must have the same number of columns as the probability map."); + } + + for (int i = 0; i < numColumns; i++) { + set(rowIndex, i, vector.get(i)); + } + } +} From dc4e455257a8d816047ba12e07fd083a735a8edb Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 28 Mar 2024 02:52:34 -0400 Subject: [PATCH 03/23] Implement ProbMap for efficient probability storage Introduced ProbMap class, representing an efficient mapping between a unique index and node probability, neglecting NaN values. Integrated this into MlBayesIm, replacing the old matrix-based approach, governed by useProbMatrices flag. --- .../java/edu/cmu/tetrad/bayes/MlBayesIm.java | 66 ++++++++----------- .../java/edu/cmu/tetrad/bayes/ProbMap.java | 47 +++++++++++-- 2 files changed, 69 insertions(+), 44 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java index 2ba4009c84..e0907b5c03 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java @@ -50,12 +50,12 @@ * corresponds to the nodes in the Bayes net. 
For each such node, the second dimension corresponds to a flat list of * combinations of parent categories for that node. The third dimension corresponds to the list of categories for that * node itself. Two methods allow these values to be set and retrieved: getWordRatio(int nodeIndex, int rowIndex, int - * colIndex); and,
  • setProbability(int nodeIndex, int rowIndex, int colIndex, int probability). To determine the - * index of the node in question, use the method getNodeIndex(Node node). To determine the index of the row in question, - * use the method getRowIndex(int[] parentVals). To determine the order of the parent values for a given node so that - * you can build the parentVals[] array, use the method getParents(int nodeIndex) To determine the index of a category, - * use the method getCategoryIndex(Node node) in BayesPm. The rest of the methods in this class are easily understood - * as variants of the methods above. + * colIndex); and, setProbability(int nodeIndex, int rowIndex, int colIndex, int probability). To determine the index of + * the node in question, use the method getNodeIndex(Node node). To determine the index of the row in question, use the + * method getRowIndex(int[] parentVals). To determine the order of the parent values for a given node so that you can + * build the parentVals[] array, use the method getParents(int nodeIndex) To determine the index of a category, use the + * method getCategoryIndex(Node node) in BayesPm. The rest of the methods in this class are easily understood as + * variants of the methods above. *

    * This version uses a different method for storing the probabilities. The previous version stored the probabilities in * a three-dimensional array, where the first dimension was the node, the second dimension was the row index, and the @@ -70,7 +70,6 @@ * @version $Id: $Id */ public final class MlBayesIm implements BayesIm { - /** * Inidicates that new rows in this BayesIm should be initialized as unknowns, forcing them to be specified * manually. This is the default. @@ -82,17 +81,14 @@ public final class MlBayesIm implements BayesIm { public static final int RANDOM = 1; @Serial private static final long serialVersionUID = 23L; - /** * Tolerance. */ private static final double ALLOWABLE_DIFFERENCE = 1.0e-3; - /** * Random number generator. */ static private final Random random = new Random(); - /** * The associated Bayes PM model. */ @@ -110,8 +106,6 @@ public final class MlBayesIm implements BayesIm { * and order in subarrays is important. */ private int[][] parents; - - //===============================CONSTRUCTORS=========================// /** * The array of dimensionality (number of categories for each node) for each of the subarrays of 'parents'. */ @@ -131,7 +125,7 @@ public final class MlBayesIm implements BayesIm { /** * The array of probability maps for each node. The index of the node corresponds to the index of the probability * map in this array. The probability map is a map from a unique integer index for a particular node to the - * probability of that node taking on that value, where NaN's are not stored. + * probability of that node taking on that value, where NaN's are not stored. Replaces the probs array. 
*/ private ProbMap[] probMatrices; @@ -379,9 +373,6 @@ public int getNumColumns(int nodeIndex) { } else { return this.probs[nodeIndex][0].length; } - -// return this.probs[nodeIndex][0].length; - } /** @@ -396,8 +387,6 @@ public int getNumRows(int nodeIndex) { } else { return this.probs[nodeIndex].length; } - -// return this.probs[nodeIndex].length; } /** @@ -504,8 +493,6 @@ public double getProbability(int nodeIndex, int rowIndex, int colIndex) { } else { return this.probs[nodeIndex][rowIndex][colIndex]; } - -// return this.probs[nodeIndex][rowIndex][colIndex]; } /** @@ -594,10 +581,6 @@ public void setProbability(int nodeIndex, double[][] probMatrix) { System.arraycopy(probMatrix[i], 0, this.probs[nodeIndex][i], 0, probMatrix[i].length); } } - -// for (int i = 0; i < probMatrix.length; i++) { -// System.arraycopy(probMatrix[i], 0, this.probs[nodeIndex][i], 0, probMatrix[i].length); -// } } /** @@ -628,8 +611,6 @@ public void setProbability(int nodeIndex, int rowIndex, int colIndex, } else { this.probs[nodeIndex][rowIndex][colIndex] = value; } - -// this.probs[nodeIndex][rowIndex][colIndex] = value; } /** @@ -1232,14 +1213,14 @@ private void initializeNode(int nodeIndex, BayesIm oldBayesIm, int numRows = 1; for (int dim : dims) { - if (numRows > 1000000 /* Integer.MAX_VALUE / dim*/) { - throw new IllegalArgumentException( - "The number of rows in the " - + "conditional probability table for " - + this.nodes[nodeIndex] - + " is greater than 1,000,000 and cannot be " - + "represented."); - } +// if (numRows > 1000000 /* Integer.MAX_VALUE / dim*/) { +// throw new IllegalArgumentException( +// "The number of rows in the " +// + "conditional probability table for " +// + this.nodes[nodeIndex] +// + " is greater than 1,000,000 and cannot be " +// + "represented."); +// } numRows *= dim; } @@ -1286,8 +1267,6 @@ private void initializeRowAsUnknowns(int nodeIndex, int rowIndex) { } else { this.probs[nodeIndex][rowIndex] = row; } -// this.probs[nodeIndex][rowIndex] = row; 
-// this.probMatrices[nodeIndex].assignRow(rowIndex, new Vector(row)); } /** @@ -1467,12 +1446,19 @@ private void readObject(ObjectInputStream s) throw new NullPointerException(); } - if (this.probs != null) { + copyDataToProbMatrices(); + } + + private void copyDataToProbMatrices() { + if (this.probs != null && this.probs.length == this.nodes.length) { + this.probMatrices = new ProbMap[this.probs.length]; + for (int i = 0; i < this.nodes.length; i++) { - if (useProbMatrices) { - probMatrices[i] = new ProbMap(probs[i]); - } + probMatrices[i] = new ProbMap(this.probs[i]); } + + this.probs = null; + this.useProbMatrices = true; } } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ProbMap.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ProbMap.java index 7d4c4394fd..5742fdf9a7 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ProbMap.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ProbMap.java @@ -1,7 +1,9 @@ package edu.cmu.tetrad.bayes; +import edu.cmu.tetrad.util.TetradSerializable; import edu.cmu.tetrad.util.Vector; +import java.io.Serial; import java.util.HashMap; import java.util.Map; @@ -9,14 +11,21 @@ * Represents a probability map. A probability map is a map from a unique integer index for a particular node to the * * probability of that node taking on that value, where NaN's are not stored. */ -public class ProbMap { - +public class ProbMap implements TetradSerializable { + @Serial + private static final long serialVersionUID = 23L; /** * Constructs a new probability map, a map from a unique integer index for a particular node to the probability of * that node taking on that value, where NaN's are not stored. */ private final Map map = new HashMap<>(); + /** + * The number of rows in the table. + */ private final int numRows; + /** + * The number of columns in the table. 
+ */ private final int numColumns; /** @@ -33,7 +42,17 @@ public ProbMap(int numRows, int numColumns) { this.numColumns = numColumns; } + /** + * Constructs a new probability map based on the given 2-dimensional array. + * + * @param probMatrix the 2-dimensional array representing the probability matrix + * @throws IllegalArgumentException if the number of columns in any row is different + */ public ProbMap(double[][] probMatrix) { + if (probMatrix == null || probMatrix.length == 0 || probMatrix[0].length == 0) { + throw new IllegalArgumentException("Probability matrix must have at least one row and one column."); + } + numRows = probMatrix.length; numColumns = probMatrix[0].length; @@ -58,7 +77,13 @@ public double get(int row, int column) { throw new IllegalArgumentException("Row and column must be within bounds."); } - return map.get(row * numColumns + column); + int key = row * numColumns + column; + + if (!map.containsKey(key)) { + return Double.NaN; + } + + return map.get(key); } /** @@ -69,7 +94,14 @@ public void set(int row, int column, double value) { throw new IllegalArgumentException("Row and column must be within bounds."); } - map.put(row * numColumns + column, value); + int key = row * numColumns + column; + + if (Double.isNaN(value)) { + map.remove(key); + return; + } + + map.put(key, value); } /** @@ -90,6 +122,13 @@ public int getNumColumns() { return numColumns; } + /** + * Assigns the values in the provided vector to a specific row in the probability map. 
+ * + * @param rowIndex the index of the row to be assigned + * @param vector the vector containing the values to be assigned to the row + * @throws IllegalArgumentException if the size of the vector is not equal to the number of columns in the probability map + */ public void assignRow(int rowIndex, Vector vector) { if (vector.size() != numColumns) { throw new IllegalArgumentException("Vector must have the same number of columns as the probability map."); From a9da8ce62826a24a942e2f76fa7dc4f440270d85 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 28 Mar 2024 03:07:48 -0400 Subject: [PATCH 04/23] Add Javadoc comments and update matrix data managing method Updated ProbMap.java and MlBayesIm.java with detailed Javadoc comments for better understanding of the code. Also, refined the copyDataToProbMatrices() method in MlBayesIm.java to consider the 'useProbMatrices' flag before copying data from 'probs' array to 'probMatrices' array. --- .../main/java/edu/cmu/tetrad/bayes/MlBayesIm.java | 12 +++++++++++- .../src/main/java/edu/cmu/tetrad/bayes/ProbMap.java | 12 ++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java index e0907b5c03..bb74846582 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java @@ -1449,8 +1449,18 @@ private void readObject(ObjectInputStream s) copyDataToProbMatrices(); } + /** + * Copies data from the `probs` array to the `probMatrices` array. + * If the lengths of both arrays are equal, the `probMatrices` array is + * initialized with `ProbMap` objects, each containing the corresponding + * `probs` element. The `probs` array is then set to null and the + * `useProbMatrices` flag is set to true. + * + * Note: This method should only be called after the `probs` array has + * been properly initialized. 
+ */ private void copyDataToProbMatrices() { - if (this.probs != null && this.probs.length == this.nodes.length) { + if (!this.useProbMatrices && this.probs != null && this.probs.length == this.nodes.length) { this.probMatrices = new ProbMap[this.probs.length]; for (int i = 0; i < this.nodes.length; i++) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ProbMap.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ProbMap.java index 5742fdf9a7..724a5cc51a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ProbMap.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ProbMap.java @@ -32,6 +32,9 @@ public class ProbMap implements TetradSerializable { * Constructs a new probability map, a map from a unique integer index for a particular node to the probability of * that node taking on that value, where NaN's are not stored. This probability map assumes that there is a certain * number of rows and a certain number of columns in the table. + * + * @param numRows the number of rows in the table + * @param numColumns the number of columns in the table */ public ProbMap(int numRows, int numColumns) { if (numRows < 1 || numColumns < 1) { @@ -71,6 +74,10 @@ public ProbMap(double[][] probMatrix) { /** * Returns the probability of the node taking on the value specified by the given row and column. + * + * @param row the row of the node + * @param column the column of the node + * @return the probability of the node taking on the value specified by the given row and column */ public double get(int row, int column) { if (row < 0 || row >= numRows || column < 0 || column >= numColumns) { @@ -88,6 +95,11 @@ public double get(int row, int column) { /** * Sets the probability of the node taking on the value specified by the given row and column to the given value. 
+ * + * @param row the row of the node + * @param column the column of the node + * @param value the probability of the node taking on the value specified by the given row and column + * (NaN to remove the value) */ public void set(int row, int column, double value) { if (row < 0 || row >= numRows || column < 0 || column >= numColumns) { From 14ba1fe50646f80fff3fcf93511b6a580956f0a2 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 28 Mar 2024 03:21:35 -0400 Subject: [PATCH 05/23] Rename class ProbMap to CptMap and update methods The ProbMap class and its instances have been renamed to CptMap to better reflect their purpose of handling Conditional Probability Tables (CPTs). Changes were also executed in MlBayesIm class to reflect this modification. Added more detailed Javadoc comments for improved readability and understanding. --- .../bayes/{ProbMap.java => CptMap.java} | 28 +++++++++++-------- .../java/edu/cmu/tetrad/bayes/MlBayesIm.java | 12 ++++---- 2 files changed, 22 insertions(+), 18 deletions(-) rename tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/{ProbMap.java => CptMap.java} (82%) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ProbMap.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java similarity index 82% rename from tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ProbMap.java rename to tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java index 724a5cc51a..a40efa224a 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ProbMap.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java @@ -8,10 +8,13 @@ import java.util.Map; /** - * Represents a probability map. A probability map is a map from a unique integer index for a particular node to the * - * probability of that node taking on that value, where NaN's are not stored. + * Represents a conditional probability table (CPT) in a Bayes net. 
This represents the CPT as a map from a unique + * integer index for a particular node to the probability of that node taking on that value, where NaN's are not + * stored. + *

    + * The goal of this is to allow huge conditional probability tables to be stored in a compact way. */ -public class ProbMap implements TetradSerializable { +public class CptMap implements TetradSerializable { @Serial private static final long serialVersionUID = 23L; /** @@ -33,10 +36,10 @@ public class ProbMap implements TetradSerializable { * that node taking on that value, where NaN's are not stored. This probability map assumes that there is a certain * number of rows and a certain number of columns in the table. * - * @param numRows the number of rows in the table + * @param numRows the number of rows in the table * @param numColumns the number of columns in the table */ - public ProbMap(int numRows, int numColumns) { + public CptMap(int numRows, int numColumns) { if (numRows < 1 || numColumns < 1) { throw new IllegalArgumentException("Number of rows and columns must be at least 1."); } @@ -51,7 +54,7 @@ public ProbMap(int numRows, int numColumns) { * @param probMatrix the 2-dimensional array representing the probability matrix * @throws IllegalArgumentException if the number of columns in any row is different */ - public ProbMap(double[][] probMatrix) { + public CptMap(double[][] probMatrix) { if (probMatrix == null || probMatrix.length == 0 || probMatrix[0].length == 0) { throw new IllegalArgumentException("Probability matrix must have at least one row and one column."); } @@ -75,7 +78,7 @@ public ProbMap(double[][] probMatrix) { /** * Returns the probability of the node taking on the value specified by the given row and column. * - * @param row the row of the node + * @param row the row of the node * @param column the column of the node * @return the probability of the node taking on the value specified by the given row and column */ @@ -96,10 +99,10 @@ public double get(int row, int column) { /** * Sets the probability of the node taking on the value specified by the given row and column to the given value. 
* - * @param row the row of the node + * @param row the row of the node * @param column the column of the node - * @param value the probability of the node taking on the value specified by the given row and column - * (NaN to remove the value) + * @param value the probability of the node taking on the value specified by the given row and column (NaN to + * remove the value) */ public void set(int row, int column, double value) { if (row < 0 || row >= numRows || column < 0 || column >= numColumns) { @@ -138,8 +141,9 @@ public int getNumColumns() { * Assigns the values in the provided vector to a specific row in the probability map. * * @param rowIndex the index of the row to be assigned - * @param vector the vector containing the values to be assigned to the row - * @throws IllegalArgumentException if the size of the vector is not equal to the number of columns in the probability map + * @param vector the vector containing the values to be assigned to the row + * @throws IllegalArgumentException if the size of the vector is not equal to the number of columns in the + * probability map */ public void assignRow(int rowIndex, Vector vector) { if (vector.size() != numColumns) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java index bb74846582..0fa3919f98 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java @@ -127,7 +127,7 @@ public final class MlBayesIm implements BayesIm { * map in this array. The probability map is a map from a unique integer index for a particular node to the * probability of that node taking on that value, where NaN's are not stored. Replaces the probs array. */ - private ProbMap[] probMatrices; + private CptMap[] probMatrices; /** * Constructs a new BayesIm from the given BayesPm, initializing all values as Double.NaN ("?"). 
@@ -575,7 +575,7 @@ public void normalizeRow(int nodeIndex, int rowIndex) { @Override public void setProbability(int nodeIndex, double[][] probMatrix) { if (useProbMatrices) { - probMatrices[nodeIndex] = new ProbMap(probMatrix); + probMatrices[nodeIndex] = new CptMap(probMatrix); } else { for (int i = 0; i < probMatrix.length; i++) { System.arraycopy(probMatrix[i], 0, this.probs[nodeIndex][i], 0, probMatrix[i].length); @@ -1172,7 +1172,7 @@ private void initialize(BayesIm oldBayesIm, int initializationMethod) { this.parents = new int[this.nodes.length][]; this.parentDims = new int[this.nodes.length][]; this.probs = new double[this.nodes.length][][]; - this.probMatrices = new ProbMap[this.nodes.length]; + this.probMatrices = new CptMap[this.nodes.length]; for (int nodeIndex = 0; nodeIndex < this.nodes.length; nodeIndex++) { initializeNode(nodeIndex, oldBayesIm, initializationMethod); @@ -1229,7 +1229,7 @@ private void initializeNode(int nodeIndex, BayesIm oldBayesIm, this.parentDims[nodeIndex] = dims; this.probs[nodeIndex] = new double[numRows][numCols]; - this.probMatrices[nodeIndex] = new ProbMap(numRows, numCols); + this.probMatrices[nodeIndex] = new CptMap(numRows, numCols); // Initialize each row. 
if (initializationMethod == MlBayesIm.RANDOM) { @@ -1461,10 +1461,10 @@ private void readObject(ObjectInputStream s) */ private void copyDataToProbMatrices() { if (!this.useProbMatrices && this.probs != null && this.probs.length == this.nodes.length) { - this.probMatrices = new ProbMap[this.probs.length]; + this.probMatrices = new CptMap[this.probs.length]; for (int i = 0; i < this.nodes.length; i++) { - probMatrices[i] = new ProbMap(this.probs[i]); + probMatrices[i] = new CptMap(this.probs[i]); } this.probs = null; From 9833529808d0c33ba991e2c753f6b16f3e3f233f Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 28 Mar 2024 03:23:38 -0400 Subject: [PATCH 06/23] Rename class ProbMap to CptMap and update methods The ProbMap class and its instances have been renamed to CptMap to better reflect their purpose of handling Conditional Probability Tables (CPTs). Changes were also executed in MlBayesIm class to reflect this modification. Added more detailed Javadoc comments for improved readability and understanding. --- .../java/edu/cmu/tetrad/bayes/MlBayesIm.java | 51 +++---------------- 1 file changed, 6 insertions(+), 45 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java index 0fa3919f98..9041245b66 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java @@ -53,7 +53,7 @@ * colIndex); and, setProbability(int nodeIndex, int rowIndex, int colIndex, int probability). To determine the index of * the node in question, use the method getNodeIndex(Node node). To determine the index of the row in question, use the * method getRowIndex(int[] parentVals). 
To determine the order of the parent values for a given node so that you can - * build the parentVals[] array, use the method getParents(int nodeIndex) To determine the index of a category, use the + * build the parentVals[] array, use the method getParents(int nodeIndex). To determine the index of a category, use the * method getCategoryIndex(Node node) in BayesPm. The rest of the methods in this class are easily understood as * variants of the methods above. *

    @@ -1350,42 +1350,6 @@ private int getUniqueCompatibleOldRow(int nodeIndex, int rowIndex, } } -// // Go through each parent of the node in the new BayesIm. -// for (int i = 0; i < oldBayesIm.getNumParents(oldNodeIndex); i++) { -// -// // Get the index of the parent in the new graph and in the old -// // graph. If it's no longer in the new graph, skip to the next -// // parent. -// int oldParentNodeIndex = oldBayesIm.getParent(oldNodeIndex, i); -// int parentNodeIndex = -// oldBayesIm.getCorrespondingNodeIndex(oldParentNodeIndex, this); -// int parentIndex = -1; -// -// for (int j = 0; j < this.getNumParents(nodeIndex); j++) { -// if (parentNodeIndex == this.getParent(nodeIndex, j)) { -// parentIndex = j; -// break; -// } -// } -// -// if (parentIndex == -1 || -// parentIndex >= this.getNumParents(nodeIndex)) { -// continue; -// } -// -// // Look up that value index for the new BayesIm for that parent. -// // If it was a valid value index in the old BayesIm, record -// // that value in oldParentValues. Otherwise return -1. -// int parentValue = oldParentValues[i]; -// int parentDim = -// this.getParentDim(nodeIndex, parentIndex); -// -// if (parentValue < parentDim) { -// oldParentValues[parentIndex] = oldParentValue; -// } else { -// return -1; -// } -// } // If there are any -1's in the combination at this point, return -1. for (int oldParentValue : oldParentValues) { if (oldParentValue == -1) { @@ -1450,14 +1414,11 @@ private void readObject(ObjectInputStream s) } /** - * Copies data from the `probs` array to the `probMatrices` array. - * If the lengths of both arrays are equal, the `probMatrices` array is - * initialized with `ProbMap` objects, each containing the corresponding - * `probs` element. The `probs` array is then set to null and the - * `useProbMatrices` flag is set to true. - * - * Note: This method should only be called after the `probs` array has - * been properly initialized. + * Copies data from the `probs` array to the `probMatrices` array. 
If the lengths of both arrays are equal, the + * `probMatrices` array is initialized with `CptMap` objects, each containing the corresponding `probs` element. + * The `probs` array is then set to null and the `useProbMatrices` flag is set to true. + *

    + * Note: This method should only be called after the `probs` array has been properly initialized. */ private void copyDataToProbMatrices() { if (!this.useProbMatrices && this.probs != null && this.probs.length == this.nodes.length) { From 56185087d90ba0de375cef072db3ec39999893c8 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 28 Mar 2024 03:25:09 -0400 Subject: [PATCH 07/23] Rename class ProbMap to CptMap and update methods The ProbMap class and its instances have been renamed to CptMap to better reflect their purpose of handling Conditional Probability Tables (CPTs). Changes were also executed in MlBayesIm class to reflect this modification. Added more detailed Javadoc comments for improved readability and understanding. --- .../src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java index 9041245b66..fbd3d1d99b 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java @@ -98,7 +98,9 @@ public final class MlBayesIm implements BayesIm { */ private final Node[] nodes; /** - * A flag indicating whether to use probability matrices or not. + * A flag indicating whether to use probability matrices or not. If true, the probMatrices array is used; if false, + * the probs array is used. The probMatrices array is the new way of storing the probabilities; the probs array is + * kept here for backward compatibility. */ boolean useProbMatrices = true; /** @@ -118,8 +120,9 @@ public final class MlBayesIm implements BayesIm { * for each of the parent values; the order of the values in this array is the same as the order of node in * 'parents'; the value indices are obtained from the Bayes PM for each node. The column is the index of the value * of N, where this index is obtained from the Bayes PM. 
- * - * @serial + *

    + * This is kept here for backward compatibility. The new way of storing the probabilities is in the probMatrices + * array. */ private double[][][] probs; /** From acd6268c4694625a84c60c4f5ddea24aae4f0247 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 28 Mar 2024 03:26:41 -0400 Subject: [PATCH 08/23] Correct grammar in MlBayesIm.java comments Fixed grammar in the explanatory comments of MlBayesIm.java, mainly focusing on the description of the division of labour among different classes and the purposes of the different dimensions in the data arrays. Also, improved sentence flow in a few places to enhance readability. --- tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java index fbd3d1d99b..3d67fd3180 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java @@ -40,7 +40,7 @@ /** * Stores a table of probabilities for a Bayes net and, together with BayesPm and Dag, provides methods to manipulate * this table. The division of labor is as follows. The Dag is responsible for manipulating the basic graphical - * structure of the Bayes net. Dag also stores and manipulates the names of the nodes in the graph; there are no method + * structure of the Bayes net. Dag also stores and manipulates the names of the nodes in the graph; there is no method * in either BayesPm or BayesIm to do this. BayesPm stores and manipulates the *categories* of each node in a DAG, * considered as a variable in a Bayes net. The number of categories for a variable can be changed there as well as the * names for those categories. This class, BayesIm, stores the actual probability tables which are implied by the @@ -50,7 +50,7 @@ * corresponds to the nodes in the Bayes net. 
For each such node, the second dimension corresponds to a flat list of * combinations of parent categories for that node. The third dimension corresponds to the list of categories for that * node itself. Two methods allow these values to be set and retrieved: getWordRatio(int nodeIndex, int rowIndex, int - * colIndex); and, setProbability(int nodeIndex, int rowIndex, int colIndex, int probability). To determine the index of + * colIndex); and setProbability(int nodeIndex, int rowIndex, int colIndex, int probability). To determine the index of * the node in question, use the method getNodeIndex(Node node). To determine the index of the row in question, use the * method getRowIndex(int[] parentVals). To determine the order of the parent values for a given node so that you can * build the parentVals[] array, use the method getParents(int nodeIndex). To determine the index of a category, use the From a6552a597845509f0aa3259c84cfa4f52a5a6337 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 28 Mar 2024 03:55:27 -0400 Subject: [PATCH 09/23] Refactor storage method in MlBayesIm.java and update comments Updated MlBayesIm storage strategy from a three-dimensional array to an array of CptMap objects, which offer a more efficient representation for large conditional probability tables. Enhanced code readability by removing unnecessary checks and revising comments for clarity. This new method omits the storage of NaNs allowing space savings for sparse tables. --- .../java/edu/cmu/tetrad/bayes/MlBayesIm.java | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java index 3d67fd3180..c0c7e76fe7 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java @@ -59,10 +59,9 @@ *

    * This version uses a different method for storing the probabilities. The previous version stored the probabilities in * a three-dimensional array, where the first dimension was the node, the second dimension was the row index, and the - * third dimension was the column index. This version stores and array of Matrix objects, where each Matrix object - * represents the conditional probability table for a node. This will allow us in the future to represent this as an of - * Maps from Integers to Doubles and store only the non-NaN values. This will save space and time in the case of sparse - * tables. + * third dimension was the column index. This version stores and array of CptMap objects, where each CptMap object + * represents the conditional probability table for a node. NaNs in these maps are not stored, allowing for a more + * compact representation so that huge conditional probability tables can be estimated from finite samples. *

    * Thanks to Pucktada Treeratpituk, Frank Wimberly, and Willie Wheeler for advice and earlier versions. * @@ -1216,15 +1215,6 @@ private void initializeNode(int nodeIndex, BayesIm oldBayesIm, int numRows = 1; for (int dim : dims) { -// if (numRows > 1000000 /* Integer.MAX_VALUE / dim*/) { -// throw new IllegalArgumentException( -// "The number of rows in the " -// + "conditional probability table for " -// + this.nodes[nodeIndex] -// + " is greater than 1,000,000 and cannot be " -// + "represented."); -// } - numRows *= dim; } From 27a6334de93f6681316c1fb010a8b1ffdf3050ee Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 28 Mar 2024 03:56:29 -0400 Subject: [PATCH 10/23] Enhance storage strategy and comments in CptMap.java Transitioned to an efficient storage method for large conditional probability tables in the CptMap class. This change expands the class Javadoc to explain the goal of compact storage for tables estimated from finite samples and how the unique integer index of each entry is calculated. --- tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java index a40efa224a..29be162b38 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java @@ -12,7 +12,12 @@ * integer index for a particular node to the probability of that node taking on that value, where NaN's are not * stored. *

    - * The goal of this is to allow huge conditional probability tables to be stored in a compact way. + * The goal of this is to allow huge conditional probability tables to be stored in a compact way when estimated from + * finite samples. The idea is that the CPT is stored as a map from a unique integer index for a particular node to the + * probability of that node taking on that value, where NaN's are not stored. This is useful because the CPTs can be + * huge and sparse, and this allows them to be stored in a compact way. The unique integer index for a particular node + * is calculated as follows: row * numColumns + column, where row is the row of the node and column is the column of the + * node. */ public class CptMap implements TetradSerializable { @Serial From 39a4afd42942d16113fc4ba97ce953a16092ad32 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 28 Mar 2024 12:21:47 -0400 Subject: [PATCH 11/23] Simplify CptMap class and its comments Removed duplicative comments and unnecessary explanations about unique integer index calculation in CptMap class. Refocused on the key functionality: storage of large conditional probability tables in a compact form excluding NaN values. --- .../src/main/java/edu/cmu/tetrad/bayes/CptMap.java | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java index 29be162b38..ab3cc3a0e8 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java @@ -8,16 +8,9 @@ import java.util.Map; /** - * Represents a conditional probability table (CPT) in a Bayes net. This represents the CPT as a map from a unique + * Represents a conditional probability table (CPT) in a Bayes net. 
This represents the CPT as a map from a unique * integer index for a particular node to the probability of that node taking on that value, where NaN's are not * stored. - *

    - * The goal of this is to allow huge conditional probability tables to be stored in a compact way when estimated from - * finite samples. The idea is that the CPT is stored as a map from a unique integer index for a particular node to the - * probability of that node taking on that value, where NaN's are not stored. This is useful because the CPTs can be - * huge and sparse, and this allows them to be stored in a compact way. The unique integer index for a particular node - * is calculated as follows: row * numColumns + column, where row is the row of the node and column is the column of the - * node. */ public class CptMap implements TetradSerializable { @Serial From a7b9166f76dbbbf06e6b88587b5cc971eb07e1f9 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 28 Mar 2024 12:25:29 -0400 Subject: [PATCH 12/23] Refactor MlBayesIm storage method explanation Modified the explanation of MlBayesIm's probability storage method from a three-dimensional array to a sparse method that does not store NaNs. This change highlights the method's efficiency when working with large Bayesian probabilistic models. The old storage method's description remains for backward compatibility. --- .../src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java index c0c7e76fe7..bdad051965 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java @@ -57,11 +57,9 @@ * method getCategoryIndex(Node node) in BayesPm. The rest of the methods in this class are easily understood as * variants of the methods above. *

    - * This version uses a different method for storing the probabilities. The previous version stored the probabilities in - * a three-dimensional array, where the first dimension was the node, the second dimension was the row index, and the - * third dimension was the column index. This version stores and array of CptMap objects, where each CptMap object - * represents the conditional probability table for a node. NaNs in these maps are not stored, allowing for a more - * compact representation so that huge conditional probability tables can be estimated from finite samples. + * This version uses a sparse method for storing the probabilities, where NaNs are not stored. This allows BayesPms with + * many categories per variable to be estimated from small samples without overflowing memory. The old method of storing + * probabilities is kept here for backward compatibility. *

    * Thanks to Pucktada Treeratpituk, Frank Wimberly, and Willie Wheeler for advice and earlier versions. * From 701d15e27e7f5118013d9f1b5cdb3ccfdb9130e7 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 28 Mar 2024 12:26:51 -0400 Subject: [PATCH 13/23] Remove simulation task from MlBayesIm Removed an embedded class called SimulationTask along with related code lines in MlBayesIm.java file. The modification simplifies the constructSample function, as the simulation now utilizes a sequential approach rather than a fork/join parallel computation, making it more straightforward. --- .../java/edu/cmu/tetrad/bayes/MlBayesIm.java | 69 ------------------- 1 file changed, 69 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java index bdad051965..fdc7aaf8b9 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java @@ -963,75 +963,6 @@ private DataSet simulateDataHelper(DataSet dataSet, boolean latentDataSaved, int private void constructSample(int sampleSize, DataSet dataSet, int[] map, int[] tiers) { -// //Do the simulation. 
-// class SimulationTask extends RecursiveTask { -// private int chunk; -// private int from; -// private int to; -// private int[] tiers; -// private DataSet dataSet; -// private int[] map; -// -// public SimulationTask(int chunk, int from, int to, int[] tiers, DataSet dataSet, int[] map) { -// this.chunk = chunk; -// this.from = from; -// this.to = to; -// this.tiers = tiers; -// this.dataSet = dataSet; -// this.map = map; -// } -// -// @Override -// protected Boolean compute() { -// if (to - from <= chunk) { -// RandomGenerator randomGenerator = new Well1024a(++seed[0]); -// -// for (int row = from; row < to; row++) { -// for (int t : tiers) { -// int[] parentValues = new int[parents[t].length]; -// -// for (int k = 0; k < parentValues.length; k++) { -// parentValues[k] = dataSet.getInt(row, parents[t][k]); -// } -// -// int rowIndex = getRowIndex(t, parentValues); -// double sum = 0.0; -// double r; -// -// r = randomGenerator.nextDouble(); -// -// for (int k = 0; k < getNumColumns(t); k++) { -// double probability = getProbability(t, rowIndex, k); -// sum += probability; -// -// if (sum >= r) { -// dataSet.setInt(row, map[t], k); -// break; -// } -// } -// } -// } -// -// return true; -// } else { -// int mid = (to + from) / 2; -// SimulationTask left = new SimulationTask(chunk, from, mid, tiers, dataSet, map); -// SimulationTask right = new SimulationTask(chunk, mid, to, tiers, dataSet, map); -// -// left.fork(); -// right.compute(); -// left.join(); -// -// return true; -// } -// } -// } -// -// int chunk = 25; -// -// ForkJoinPool pool = ForkJoinPoolInstance.getInstance().getPool(); -// SimulationTask task = new SimulationTask(chunk, 0, sampleSize, tiers, dataSet, map); -// pool.invoke(task); // Construct the sample. 
for (int i = 0; i < sampleSize; i++) { for (int t : tiers) { From 3a1b0ddb5d67b692bb400d41582658bba92422c7 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 28 Mar 2024 12:30:05 -0400 Subject: [PATCH 14/23] Replace probability matrices with CptMaps in MlBayesIm Refactored the MlBayesIm.java file to replace the use of probability matrices with CptMaps for storing probabilities. This change introduces more efficient data structures and methods for maintaining and accessing probabilities. Backward compatibility is maintained through the fallback to a "probs" array if CptMaps are not used. --- .../java/edu/cmu/tetrad/bayes/MlBayesIm.java | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java index fdc7aaf8b9..fc454801d3 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java @@ -95,11 +95,10 @@ public final class MlBayesIm implements BayesIm { */ private final Node[] nodes; /** - * A flag indicating whether to use probability matrices or not. If true, the probMatrices array is used; if false, - * the probs array is used. The probMatrices array is the new way of storing the probabilities; the probs array is - * kept here for backward compatibility. + * A flag indicating whether to use CptMaps or not. If true, CptMaps are used; if false, the probs array is used. + * The CptMap is the new way of storing the probabilities; the probs array is kept here for backward compatibility. */ - boolean useProbMatrices = true; + boolean useCptMaps = true; /** * The list of parents for each node from the graph. Order or nodes corresponds to the order of nodes in 'nodes', * and order in subarrays is important. 
@@ -123,9 +122,8 @@ public final class MlBayesIm implements BayesIm { */ private double[][][] probs; /** - * The array of probability maps for each node. The index of the node corresponds to the index of the probability - * map in this array. The probability map is a map from a unique integer index for a particular node to the - * probability of that node taking on that value, where NaN's are not stored. Replaces the probs array. + * The array of CPT maps for each node. The index of the node corresponds to the index of the probability map in + * this array. Replaces the probs array. */ private CptMap[] probMatrices; @@ -368,7 +366,7 @@ public List getVariableNames() { * @return the number of columns. */ public int getNumColumns(int nodeIndex) { - if (useProbMatrices) { + if (useCptMaps) { return probMatrices[nodeIndex].getNumColumns(); } else { return this.probs[nodeIndex][0].length; @@ -382,7 +380,7 @@ public int getNumColumns(int nodeIndex) { * @return the number of rows in the node. */ public int getNumRows(int nodeIndex) { - if (useProbMatrices) { + if (useCptMaps) { return probMatrices[nodeIndex].getNumRows(); } else { return this.probs[nodeIndex].length; @@ -488,7 +486,7 @@ public int getParentValue(int nodeIndex, int rowIndex, int colIndex) { * @return the probability value for the given node. 
*/ public double getProbability(int nodeIndex, int rowIndex, int colIndex) { - if (useProbMatrices) { + if (useCptMaps) { return probMatrices[nodeIndex].get(rowIndex, colIndex); } else { return this.probs[nodeIndex][rowIndex][colIndex]; @@ -574,7 +572,7 @@ public void normalizeRow(int nodeIndex, int rowIndex) { */ @Override public void setProbability(int nodeIndex, double[][] probMatrix) { - if (useProbMatrices) { + if (useCptMaps) { probMatrices[nodeIndex] = new CptMap(probMatrix); } else { for (int i = 0; i < probMatrix.length; i++) { @@ -606,7 +604,7 @@ public void setProbability(int nodeIndex, int rowIndex, int colIndex, + "between 0.0 and 1.0 or Double.NaN."); } - if (useProbMatrices) { + if (useCptMaps) { probMatrices[nodeIndex].set(rowIndex, colIndex, value); } else { this.probs[nodeIndex][rowIndex][colIndex] = value; @@ -648,7 +646,7 @@ public void randomizeRow(int nodeIndex, int rowIndex) { int size = getNumColumns(nodeIndex); double[] row = getRandomWeights(size); - if (useProbMatrices) { + if (useCptMaps) { for (int colIndex = 0; colIndex < size; colIndex++) { probMatrices[nodeIndex].set(rowIndex, colIndex, row[colIndex]); } @@ -1184,7 +1182,7 @@ private void initializeRowAsUnknowns(int nodeIndex, int rowIndex) { double[] row = new double[size]; Arrays.fill(row, Double.NaN); - if (useProbMatrices) { + if (useCptMaps) { probMatrices[nodeIndex].assignRow(rowIndex, new Vector(row)); } else { this.probs[nodeIndex][rowIndex] = row; @@ -1343,7 +1341,7 @@ private void readObject(ObjectInputStream s) * Note: This method should only be called after the `probs` array has been properly initialized. 
*/ private void copyDataToProbMatrices() { - if (!this.useProbMatrices && this.probs != null && this.probs.length == this.nodes.length) { + if (!this.useCptMaps && this.probs != null && this.probs.length == this.nodes.length) { this.probMatrices = new CptMap[this.probs.length]; for (int i = 0; i < this.nodes.length; i++) { @@ -1351,7 +1349,7 @@ private void copyDataToProbMatrices() { } this.probs = null; - this.useProbMatrices = true; + this.useCptMaps = true; } } } From 279b5f8841716a4ce6f51cd9f2b38cdd6d2614f7 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 28 Mar 2024 12:31:15 -0400 Subject: [PATCH 15/23] Update code for efficient probability storage in MlBayesIm Modified MlBayesIm.java to implement CptMaps instead of probability matrices for better memory management. While this update optimizes probability storage and access methods, it also ensures backward compatibility by keeping the old method and introducing a flag to select the method. --- tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java index fc454801d3..c86c463aa1 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java @@ -59,7 +59,7 @@ *

    * This version uses a sparse method for storing the probabilities, where NaNs are not stored. This allows BayesPms with * many categories per variable to be estimated from small samples without overflowing memory. The old method of storing - * probabilities is kept here for backward compatibility. + * probabilities is kept here for backward compatibility, with an internal code flag to indicate which should be used. *

    * Thanks to Pucktada Treeratpituk, Frank Wimberly, and Willie Wheeler for advice and earlier versions. * From 27a564dbe4db53012803a30ecd505527e23218fd Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 28 Mar 2024 12:33:44 -0400 Subject: [PATCH 16/23] Removing the original MlBayesIm code. --- .../edu/cmu/tetrad/bayes/MlBayesImOld.java | 1499 ----------------- 1 file changed, 1499 deletions(-) delete mode 100644 tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesImOld.java diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesImOld.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesImOld.java deleted file mode 100644 index 8dcc09abcb..0000000000 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesImOld.java +++ /dev/null @@ -1,1499 +0,0 @@ -/////////////////////////////////////////////////////////////////////////////// -// For information as to what this class does, see the Javadoc, below. // -// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // -// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // -// Scheines, Joseph Ramsey, and Clark Glymour. // -// // -// This program is free software; you can redistribute it and/or modify // -// it under the terms of the GNU General Public License as published by // -// the Free Software Foundation; either version 2 of the License, or // -// (at your option) any later version. // -// // -// This program is distributed in the hope that it will be useful, // -// but WITHOUT ANY WARRANTY; without even the implied warranty of // -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // -// GNU General Public License for more details. 
// -// // -// You should have received a copy of the GNU General Public License // -// along with this program; if not, write to the Free Software // -// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // -/////////////////////////////////////////////////////////////////////////////// -package edu.cmu.tetrad.bayes; - -import edu.cmu.tetrad.data.*; -import edu.cmu.tetrad.graph.Graph; -import edu.cmu.tetrad.graph.Node; -import edu.cmu.tetrad.graph.Paths; -import edu.cmu.tetrad.graph.TimeLagGraph; -import edu.cmu.tetrad.util.NumberFormatUtil; -import edu.cmu.tetrad.util.RandomUtil; -import org.apache.commons.math3.distribution.ChiSquaredDistribution; - -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.Serial; -import java.text.NumberFormat; -import java.util.*; - -import static org.apache.commons.math3.util.FastMath.abs; -import static org.apache.commons.math3.util.FastMath.pow; - -/** - * Stores a table of probabilities for a Bayes net and, together with BayesPm and Dag, provides methods to manipulate - * this table. The division of labor is as follows. The Dag is responsible for manipulating the basic graphical - * structure of the Bayes net. Dag also stores and manipulates the names of the nodes in the graph; there are no method - * in either BayesPm or BayesIm to do this. BayesPm stores and manipulates the *categories* of each node in a DAG, - * considered as a variable in a Bayes net. The number of categories for a variable can be changed there as well as the - * names for those categories. This class, BayesIm, stores the actual probability tables which are implied by the - * structures in the other two classes. The implied parameters take the form of conditional probabilities--e.g., - * P(N=v0|P1=v1, P2=v2, ...), for all nodes and all combinations of their parent categories. The set of all such - * probabilities is organized in this class as a three-dimensional table of double values. 
The first dimension - * corresponds to the nodes in the Bayes net. For each such node, the second dimension corresponds to a flat list of - * combinations of parent categories for that node. The third dimension corresponds to the list of categories for that - * node itself. Two methods allow these values to be set and retrieved:

    - * To determine the index of the node in question, use the method To determine - * the index of the row in question, use the method - * To determine the order of the - * parent values for a given node so that you can build the parentVals[] array, - * use the method To determine the - * index of a category, use the method in BayesPm. The rest of the methods in this class are easily understood - * as variants of the methods above. - *

    - * Thanks to Pucktada Treeratpituk, Frank Wimberly, and Willie Wheeler for - * advice and earlier versions. - * - * @author josephramsey - * @version $Id: $Id - */ -public final class MlBayesImOld implements BayesIm { - - /** - * Inidicates that new rows in this BayesIm should be initialized as unknowns, forcing them to be specified - * manually. This is the default. - */ - public static final int MANUAL = 0; - /** - * Indicates that new rows in this BayesIm should be initialized randomly. - */ - public static final int RANDOM = 1; - @Serial - private static final long serialVersionUID = 23L; - - /** - * Tolerance. - */ - private static final double ALLOWABLE_DIFFERENCE = 1.0e-3; - - /** - * Random number generator. - */ - static private final Random random = new Random(); - - /** - * The associated Bayes PM model. - */ - private final BayesPm bayesPm; - /** - * The array of nodes from the graph. Order is important. - */ - private final Node[] nodes; - /** - * The list of parents for each node from the graph. Order or nodes corresponds to the order of nodes in 'nodes', - * and order in subarrays is important. - */ - private int[][] parents; - /** - * The array of dimensionality (number of categories for each node) for each of the subarrays of 'parents'. - */ - private int[][] parentDims; - - //===============================CONSTRUCTORS=========================// - /** - * The main data structure; stores the values of all of the conditional probabilities for the Bayes net of the form - * P(N=v0 | P1=v1, P2=v2,...). The first dimension is the node N, in the order of 'nodes'. The second dimension is - * the row index for the table of parameters associated with node N; the third dimension is the column index. 
The - * row index is calculated by the function getRowIndex(int[] values) where 'values' is an array of numerical indices - * for each of the parent values; the order of the values in this array is the same as the order of node in - * 'parents'; the value indices are obtained from the Bayes PM for each node. The column is the index of the value - * of N, where this index is obtained from the Bayes PM. - * - * @serial - */ - private double[][][] probs; - - /** - * Constructs a new BayesIm from the given BayesPm, initializing all values as Double.NaN ("?"). - * - * @param bayesPm the given Bayes PM. Carries with it the underlying graph model. - * @throws IllegalArgumentException if the array of nodes provided is not a permutation of the nodes - * contained in the bayes parametric model provided. - */ - public MlBayesImOld(BayesPm bayesPm) throws IllegalArgumentException { - this(bayesPm, null, MlBayesImOld.MANUAL); - } - - /** - * Constructs a new BayesIm from the given BayesPm, initializing values either as MANUAL or RANDOM. If initialized - * manually, all values will be set to Double.NaN ("?") in each row; if initialized randomly, all values will - * distribute randomly in each row. - * - * @param bayesPm the given Bayes PM. Carries with it the underlying graph model. - * @param initializationMethod either MANUAL or RANDOM. - * @throws IllegalArgumentException if the array of nodes provided is not a permutation of the nodes - * contained in the bayes parametric model provided. - */ - public MlBayesImOld(BayesPm bayesPm, int initializationMethod) - throws IllegalArgumentException { - this(bayesPm, null, initializationMethod); - } - - /** - * Constructs a new BayesIm from the given BayesPm, initializing values either as MANUAL or RANDOM, but using values - * from the old BayesIm provided where posssible. 
If initialized manually, all values that cannot be retrieved from - * oldBayesIm will be set to Double.NaN ("?") in each such row; if initialized randomly, all values that cannot be - * retrieved from oldBayesIm will distributed randomly in each such row. - * - * @param bayesPm the given Bayes PM. Carries with it the underlying graph model. - * @param oldBayesIm an already-constructed BayesIm whose values may be used where possible to initialize - * this BayesIm. May be null. - * @param initializationMethod either MANUAL or RANDOM. - * @throws IllegalArgumentException if the array of nodes provided is not a permutation of the nodes - * contained in the bayes parametric model provided. - */ - public MlBayesImOld(BayesPm bayesPm, BayesIm oldBayesIm, - int initializationMethod) throws IllegalArgumentException { - if (bayesPm == null) { - throw new NullPointerException("BayesPm must not be null."); - } - - this.bayesPm = new BayesPm(bayesPm); - - // Get the nodes from the BayesPm. This fixes the order of the nodes - // in the BayesIm, independently of any change to the BayesPm. - // (This order must be maintained.) - Graph graph = bayesPm.getDag(); - this.nodes = graph.getNodes().toArray(new Node[0]); - - // Initialize. - initialize(oldBayesIm, initializationMethod); - } - - /** - * Copy constructor. - * - * @param bayesIm a {@link BayesIm} object - * @throws IllegalArgumentException if any. - */ - public MlBayesImOld(BayesIm bayesIm) throws IllegalArgumentException { - if (bayesIm == null) { - throw new NullPointerException("BayesIm must not be null."); - } - - this.bayesPm = bayesIm.getBayesPm(); - - // Get the nodes from the BayesPm, fixing on an order. (This is - // important; the nodes must always be in the same order for this - // BayesIm.) - this.nodes = new Node[bayesIm.getNumNodes()]; - - for (int i = 0; i < bayesIm.getNumNodes(); i++) { - this.nodes[i] = bayesIm.getNode(i); - } - - // Copy all the old values over. 
- initialize(bayesIm, MlBayesImOld.MANUAL); - } - - /** - * Generates a simple exemplar of this class to test serialization. - * - * @return a {@link MlBayesImOld} object - */ - public static MlBayesImOld serializableInstance() { - return new MlBayesImOld(BayesPm.serializableInstance()); - } - - //===============================PUBLIC METHODS========================// - - /** - *

    getParameterNames.

    - * - * @return a {@link List} object - */ - public static List getParameterNames() { - return new ArrayList<>(); - } - - private static double[] getRandomWeights(int size) { - assert size > 0; - - double[] row = new double[size]; - double sum = 0.0; - - int strong = (int) Math.floor(random.nextDouble() * size); - - for (int i = 0; i < size; i++) { - if (i == strong) { - row[i] = 1.0; - } else { - row[i] = RandomUtil.getInstance().nextDouble() * 0.3; - } - - sum += row[i]; - } - - for (int i = 0; i < size; i++) { - row[i] /= sum; - } - - return row; - } - - /** - *

    Getter for the field bayesPm.

    - * - * @return this PM. - */ - public BayesPm getBayesPm() { - return this.bayesPm; - } - - /** - *

    getDag.

    - * - * @return the DAG. - */ - public Graph getDag() { - return this.bayesPm.getDag(); - } - - /** - *

    getNumNodes.

    - * - * @return the number of nodes in the model. - */ - public int getNumNodes() { - return this.nodes.length; - } - - /** - * Retrieves the node at the specified index. - * - * @param nodeIndex the index of the node. - * @return the node at the specified index. - */ - public Node getNode(int nodeIndex) { - return this.nodes[nodeIndex]; - } - - /** - *

    getNode.

    - * - * @param name the name of the node. - * @return the node. - */ - public Node getNode(String name) { - return getDag().getNode(name); - } - - /** - * Returns the index of the given node in the nodes array. - * - * @param node the given node. - * @return the index of the node in the nodes array, or -1 if the node is not found. - */ - public int getNodeIndex(Node node) { - for (int i = 0; i < this.nodes.length; i++) { - if (node == this.nodes[i]) { - return i; - } - } - - return -1; - } - - /** - *

    getVariables.

    - * - * @return a {@link List} object - */ - public List getVariables() { - List variables = new LinkedList<>(); - - for (int i = 0; i < getNumNodes(); i++) { - Node node = getNode(i); - variables.add(this.bayesPm.getVariable(node)); - } - - return variables; - } - - /** - *

    getMeasuredNodes.

    - * - * @return the list of measured variableNodes. - */ - public List getMeasuredNodes() { - return this.bayesPm.getMeasuredNodes(); - } - - /** - *

    getVariableNames.

    - * - * @return a {@link List} object - */ - public List getVariableNames() { - List variableNames = new LinkedList<>(); - - for (int i = 0; i < getNumNodes(); i++) { - Node node = getNode(i); - variableNames.add(this.bayesPm.getVariable(node).getName()); - } - - return variableNames; - } - - /** - * Returns the number of columns in the specified node. - * - * @param nodeIndex the index of the node. - * @return the number of columns. - */ - public int getNumColumns(int nodeIndex) { - return this.probs[nodeIndex][0].length; - } - - /** - * Retrieves the number of rows in the specified node. - * - * @param nodeIndex the index of the node. - * @return the number of rows in the node. - */ - public int getNumRows(int nodeIndex) { - return this.probs[nodeIndex].length; - } - - /** - * Returns the number of parents for the given node. - * - * @param nodeIndex the index of the node. - * @return the number of parents. - */ - public int getNumParents(int nodeIndex) { - return this.parents[nodeIndex].length; - } - - /** - * Retrieves the parent of a node at the specified index. - * - * @param nodeIndex the index of the node. - * @param parentIndex the index of the parent. - * @return the parent of the node. - */ - public int getParent(int nodeIndex, int parentIndex) { - return this.parents[nodeIndex][parentIndex]; - } - - /** - * Retrieves the value of the parent dimension for a given node and parent index. - * - * @param nodeIndex the index of the node. - * @param parentIndex the index of the parent. - * @return the parent dimension value. - */ - public int getParentDim(int nodeIndex, int parentIndex) { - return this.parentDims[nodeIndex][parentIndex]; - } - - /** - * Returns a copy of the dimensions of the parent node at the specified index. - * - * @param nodeIndex the index of the node. - * @return an array containing the dimensions of the parent node. 
- */ - public int[] getParentDims(int nodeIndex) { - int[] dims = this.parentDims[nodeIndex]; - int[] copy = new int[dims.length]; - System.arraycopy(dims, 0, copy, 0, dims.length); - return copy; - } - - /** - * Returns an array containing the parents of the specified node. - * - * @param nodeIndex the index of the node. - * @return an array of integers representing the parents of the specified node. - */ - public int[] getParents(int nodeIndex) { - int[] nodeParents = this.parents[nodeIndex]; - int[] copy = new int[nodeParents.length]; - System.arraycopy(nodeParents, 0, copy, 0, nodeParents.length); - return copy; - } - - /** - * Returns an integer array containing the parent values for a given node index and row index. - * - * @param nodeIndex the index of the node. - * @param rowIndex the index of the row in question. - * @return an integer array containing the parent values. - */ - public int[] getParentValues(int nodeIndex, int rowIndex) { - int[] dims = getParentDims(nodeIndex); - int[] values = new int[dims.length]; - - for (int i = dims.length - 1; i >= 0; i--) { - values[i] = rowIndex % dims[i]; - rowIndex /= dims[i]; - } - - return values; - } - - /** - * Retrieves the value of the parent node at the specified row and column index. - * - * @param nodeIndex the index of the node. - * @param rowIndex the index of the row in question. - * @param colIndex the index of the column in question. - * @return the value of the parent node at the specified row and column index. - */ - public int getParentValue(int nodeIndex, int rowIndex, int colIndex) { - return getParentValues(nodeIndex, rowIndex)[colIndex]; - } - - /** - * Returns the probability for a given node in the table. - * - * @param nodeIndex the index of the node in question. - * @param rowIndex the row in the table for this node which represents the combination of parent values in - * question. - * @param colIndex the column in the table for this node which represents the value of the node in question. 
- * @return the probability value for the given node. - */ - public double getProbability(int nodeIndex, int rowIndex, int colIndex) { - return this.probs[nodeIndex][rowIndex][colIndex]; - } - - /** - * Returns the row index corresponding to the given node index and combination of parent values. - * - * @param nodeIndex the index of the node in question. - * @param values the combination of parent values in question. - * @return the row index corresponding to the given node index and combination of parent values. - */ - public int getRowIndex(int nodeIndex, int[] values) { - int[] dim = getParentDims(nodeIndex); - int rowIndex = 0; - - for (int i = 0; i < dim.length; i++) { - rowIndex *= dim[i]; - rowIndex += values[i]; - } - - return rowIndex; - } - - /** - * Normalizes all rows in the tables associated with each of node in turn. - */ - public void normalizeAll() { - for (int nodeIndex = 0; nodeIndex < this.nodes.length; nodeIndex++) { - normalizeNode(nodeIndex); - } - } - - /** - * Normalizes the specified node by invoking the {@link #normalizeRow(int, int)} method on each row of the node. - * - * @param nodeIndex the index of the node to be normalized. - */ - public void normalizeNode(int nodeIndex) { - for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { - normalizeRow(nodeIndex, rowIndex); - } - } - - /** - * Normalizes the probabilities of a given row in a node. - * - * @param nodeIndex the index of the node in question. - * @param rowIndex the index of the row in question. 
- */ - public void normalizeRow(int nodeIndex, int rowIndex) { - int numColumns = getNumColumns(nodeIndex); - double total = 0.0; - - for (int colIndex = 0; colIndex < numColumns; colIndex++) { - total += getProbability(nodeIndex, rowIndex, colIndex); - } - - if (total != 0.0) { - for (int colIndex = 0; colIndex < numColumns; colIndex++) { - double probability - = getProbability(nodeIndex, rowIndex, colIndex); - double prob = probability / total; - setProbability(nodeIndex, rowIndex, colIndex, prob); - } - } else { - double prob = 1.0 / numColumns; - - for (int colIndex = 0; colIndex < numColumns; colIndex++) { - setProbability(nodeIndex, rowIndex, colIndex, prob); - } - } - } - - /** - * Sets the probability for the given node. The matrix row represent row index, the row in the table for this for - * node which represents the combination of parent values in question. of the CPT. The matrix column represent - * column index, the column in the table for this node which represents the value of the node in question. - * - * @param nodeIndex The index of the node. - * @param probMatrix The matrix of probabilities. - */ - @Override - public void setProbability(int nodeIndex, double[][] probMatrix) { - for (int i = 0; i < probMatrix.length; i++) { - System.arraycopy(probMatrix[i], 0, this.probs[nodeIndex][i], 0, probMatrix[i].length); - } - } - - /** - * Sets the probability value for a specific node, row, and column in the probability table. - * - * @param nodeIndex the index of the node in question. - * @param rowIndex the row in the table for this node which represents the combination of parent values in - * question. - * @param colIndex the column in the table for this node which represents the value of the node in question. - * @param value the desired probability to be set. Must be between 0.0 and 1.0, or Double.NaN. 
- * @throws IllegalArgumentException if the column index is out of range for the given node, or if the probability - * value is not between 0.0 and 1.0 or Double.NaN. - */ - public void setProbability(int nodeIndex, int rowIndex, int colIndex, - double value) { - if (colIndex >= getNumColumns(nodeIndex)) { - throw new IllegalArgumentException("Column out of range: " - + colIndex + " >= " + getNumColumns(nodeIndex)); - } - - if (!(0.0 <= value && value <= 1.0) && !Double.isNaN(value)) { - throw new IllegalArgumentException("Probability value must be " - + "between 0.0 and 1.0 or Double.NaN."); - } - - this.probs[nodeIndex][rowIndex][colIndex] = value; - } - - /** - * Returns the corresponding node index in the given BayesIm based on the node index in this BayesIm. - * - * @param nodeIndex the index of the node in this BayesIm. - * @param otherBayesIm the BayesIm in which the node is to be found. - * @return the corresponding node index in the given BayesIm. - */ - public int getCorrespondingNodeIndex(int nodeIndex, BayesIm otherBayesIm) { - String nodeName = getNode(nodeIndex).getName(); - Node oldNode = otherBayesIm.getNode(nodeName); - return otherBayesIm.getNodeIndex(oldNode); - } - - /** - * Clears all values in the specified row of a table. - * - * @param nodeIndex the index of the node for the table that this row belongs to - * @param rowIndex the index of the row to be cleared - */ - public void clearRow(int nodeIndex, int rowIndex) { - for (int colIndex = 0; colIndex < getNumColumns(nodeIndex); colIndex++) { - setProbability(nodeIndex, rowIndex, colIndex, Double.NaN); - } - } - - /** - * Randomizes the values of a row in a table for a given node. - * - * @param nodeIndex the index of the node for the table that this row belongs to. - * @param rowIndex the index of the row to be randomized. 
- */ - public void randomizeRow(int nodeIndex, int rowIndex) { - int size = getNumColumns(nodeIndex); - this.probs[nodeIndex][rowIndex] = MlBayesImOld.getRandomWeights(size); - } - - /** - * Randomizes the incomplete rows in the specified node's table. - * - * @param nodeIndex the index of the node for the table whose incomplete rows are to be randomized - */ - public void randomizeIncompleteRows(int nodeIndex) { - for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { - if (isIncomplete(nodeIndex, rowIndex)) { - randomizeRow(nodeIndex, rowIndex); - } - } - } - - /** - * Randomizes the table for a given node. - * - * @param nodeIndex the index of the node for the table to be randomized - */ - public void randomizeTable(int nodeIndex) { - for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { - randomizeRow(nodeIndex, rowIndex); - } - } - - private int score(int nodeIndex) { - double[][] p = new double[getNumRows(nodeIndex)][getNumColumns(nodeIndex)]; - copy(this.probs[nodeIndex], p); - int num = 0; - - int numRows = getNumRows(nodeIndex); - - for (int r = 0; r < p.length; r++) { - for (int c = 0; c < p[0].length; c++) { - p[r][c] /= numRows; - } - } - - int[] parents = getParents(nodeIndex); - - for (int t = 0; t < parents.length; t++) { - int numParentValues = getParentDim(nodeIndex, t); - int numColumns = getNumColumns(nodeIndex); - - double[][] table = new double[numParentValues][numColumns]; - - for (int childCol = 0; childCol < numColumns; childCol++) { - for (int parentValue = 0; parentValue < numParentValues; parentValue++) { - for (int row = 0; row < numRows; row++) { - if (getParentValues(nodeIndex, row)[t] == parentValue) { - table[parentValue][childCol] += p[row][childCol]; - } - } - } - } - - final double N = 1000.0; - - for (int r = 0; r < table.length; r++) { - for (int c = 0; c < table[0].length; c++) { - table[r][c] *= N; - } - } - - double chisq = 0.0; - - for (int r = 0; r < table.length; r++) { - for (int c = 0; c 
< table[0].length; c++) { - double _sumRow = sumRow(table, r); - double _sumCol = sumCol(table, c); - double exp = (_sumRow / N) * (_sumCol / N) * N; - double obs = table[r][c]; - chisq += pow(obs - exp, 2) / exp; - } - } - - int dof = (table.length - 1) * (table[0].length - 1); - - ChiSquaredDistribution distribution = new ChiSquaredDistribution(dof); - double prob = 1 - distribution.cumulativeProbability(chisq); - - num += prob < 0.0001 ? 1 : 0; - } - -// return num == parents.length ? -score : 0; - return num; - } - - private double sumCol(double[][] marginals, int j) { - double sum = 0.0; - - for (double[] marginal : marginals) { - sum += marginal[j]; - } - - return sum; - } - - private double sumRow(double[][] marginals, int i) { - double sum = 0.0; - - for (int h = 0; h < marginals[i].length; h++) { - sum += marginals[i][h]; - } - - return sum; - } - - private void copy(double[][] a, double[][] b) { - for (int r = 0; r < a.length; r++) { - System.arraycopy(a[r], 0, b[r], 0, a[r].length); - } - } - - /** - * Clears the table by clearing all rows for the given node. - * - * @param nodeIndex The index of the node for the table to be cleared. - */ - public void clearTable(int nodeIndex) { - for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { - clearRow(nodeIndex, rowIndex); - } - } - - /** - * Checks if the specified row of a table is incomplete, i.e., if any of the columns have a NaN value. - * - * @param nodeIndex the index of the table node to check. - * @param rowIndex the index of the row to check. - * @return true if the row is incomplete, false otherwise. - */ - public boolean isIncomplete(int nodeIndex, int rowIndex) { - for (int colIndex = 0; colIndex < getNumColumns(nodeIndex); colIndex++) { - double p = getProbability(nodeIndex, rowIndex, colIndex); - - if (Double.isNaN(p)) { - return true; - } - } - - return false; - } - - /** - * Checks if the specified table has any incomplete rows. 
- * - * @param nodeIndex the index of the node for the table - * @return true if the table has any incomplete rows, false otherwise - */ - public boolean isIncomplete(int nodeIndex) { - for (int rowIndex = 0; rowIndex < getNumRows(nodeIndex); rowIndex++) { - if (isIncomplete(nodeIndex, rowIndex)) { - return true; - } - } - - return false; - } - - /** - * Simulates a sample with the given sample size. - * - * @param sampleSize the sample size. - * @param latentDataSaved a boolean - * @param tiers an array of {@link int} objects - * @return the simulated sample as a DataSet. - */ - public DataSet simulateData(int sampleSize, boolean latentDataSaved, int[] tiers) { - if (getBayesPm().getDag().isTimeLagModel()) { - return simulateTimeSeries(sampleSize); - } - - return simulateDataHelper(sampleSize, latentDataSaved, tiers); - } - - /** - * Simulates a data set. - * - * @param sampleSize The number of rows to simulate. - * @param latentDataSaved If set to true, latent variables are saved in the data set. - * @return The simulated data set. - * @throws IllegalArgumentException If the graph contains a directed cycle. - */ - public DataSet simulateData(int sampleSize, boolean latentDataSaved) { - if (getBayesPm().getDag().isTimeLagModel()) { - return simulateTimeSeries(sampleSize); - } - - // Get a tier ordering and convert it to an int array. - Graph graph = getBayesPm().getDag(); - - if (graph.paths().existsDirectedCycle()) { - throw new IllegalArgumentException("Graph must be acyclic to simulate from discrete Bayes net."); - } - - Paths paths = graph.paths(); - List initialOrder = graph.getNodes(); - List tierOrdering = paths.getValidOrder(initialOrder, true); - int[] tiers = new int[tierOrdering.size()]; - - for (int i = 0; i < tierOrdering.size(); i++) { - tiers[i] = getNodeIndex(tierOrdering.get(i)); - } - - return simulateDataHelper(sampleSize, latentDataSaved, tiers); - } - - /** - *

    simulateData.

    - * - * @param dataSet a {@link DataSet} object - * @param latentDataSaved a boolean - * @param tiers an array of {@link int} objects - * @return a {@link DataSet} object - */ - public DataSet simulateData(DataSet dataSet, boolean latentDataSaved, int[] tiers) { - return simulateDataHelper(dataSet, latentDataSaved, tiers); - } - - /** - * Simulates data for the given data set. - * - * @param dataSet The data set to simulate data for. - * @param latentDataSaved Indicates whether latent data should be saved during simulation. - * @return The modified data set after simulating the data. - */ - public DataSet simulateData(DataSet dataSet, boolean latentDataSaved) { - // Get a tier ordering and convert it to an int array. - Graph graph = getBayesPm().getDag(); - Paths paths = graph.paths(); - List initialOrder = graph.getNodes(); - List tierOrdering = paths.getValidOrder(initialOrder, true); - int[] tiers = new int[tierOrdering.size()]; - - for (int i = 0; i < tierOrdering.size(); i++) { - tiers[i] = getNodeIndex(tierOrdering.get(i)); - } - - return simulateDataHelper(dataSet, latentDataSaved, tiers); - } - - private DataSet simulateTimeSeries(int sampleSize) { - TimeLagGraph timeSeriesGraph = getBayesPm().getDag().getTimeLagGraph(); - - List variables = new ArrayList<>(); - - for (Node node : timeSeriesGraph.getLag0Nodes()) { - DiscreteVariable e = new DiscreteVariable(timeSeriesGraph.getNodeId(node).getName()); - e.setNodeType(node.getNodeType()); - variables.add(e); - } - - List lag0Nodes = timeSeriesGraph.getLag0Nodes(); - -// DataSet fullData = new ColtDataSet(sampleSize, variables); - DataSet fullData = new BoxDataSet(new VerticalIntDataBox(sampleSize, variables.size()), variables); - - Graph contemporaneousDag = timeSeriesGraph.subgraph(lag0Nodes); - Paths paths = contemporaneousDag.paths(); - List initialOrder = contemporaneousDag.getNodes(); - List tierOrdering = paths.getValidOrder(initialOrder, true); - int[] tiers = new int[tierOrdering.size()]; - - for 
(int i = 0; i < tierOrdering.size(); i++) { - tiers[i] = getNodeIndex(tierOrdering.get(i)); - } - - // Construct the sample. - int[] combination = new int[tierOrdering.size()]; - - for (int i = 0; i < sampleSize; i++) { - int[] point = new int[this.nodes.length]; - - for (int nodeIndex : tiers) { - double cutoff = RandomUtil.getInstance().nextDouble(); - - for (int k = 0; k < getNumParents(nodeIndex); k++) { - combination[k] = point[getParent(nodeIndex, k)]; - } - - int rowIndex = getRowIndex(nodeIndex, combination); - double sum = 0.0; - - for (int k = 0; k < getNumColumns(nodeIndex); k++) { - double probability = getProbability(nodeIndex, rowIndex, k); - - if (Double.isNaN(probability)) { - throw new IllegalStateException("Some probability " - + "values in the BayesIm are not filled in; " - + "cannot simulate data."); - } - - sum += probability; - - if (sum >= cutoff) { - point[nodeIndex] = k; - break; - } - } - } - } - - return fullData; - } - - /** - * Simulates a sample with the given sample size. - * - * @param sampleSize the sample size. - * @return the simulated sample as a DataSet. 
- */ - private DataSet simulateDataHelper(int sampleSize, boolean latentDataSaved, int[] tiers) { - int numMeasured = 0; - int[] map = new int[this.nodes.length]; - List variables = new LinkedList<>(); - - for (int j = 0; j < this.nodes.length; j++) { - - int numCategories = this.bayesPm.getNumCategories(this.nodes[j]); - List categories = new LinkedList<>(); - - for (int k = 0; k < numCategories; k++) { - categories.add(this.bayesPm.getCategory(this.nodes[j], k)); - } - - DiscreteVariable var - = new DiscreteVariable(this.nodes[j].getName(), categories); - var.setNodeType(this.nodes[j].getNodeType()); - variables.add(var); - int index = ++numMeasured - 1; - map[index] = j; - } - - DataSet dataSet = new BoxDataSet(new VerticalIntDataBox(sampleSize, variables.size()), variables); - constructSample(sampleSize, dataSet, map, tiers); - - if (!latentDataSaved) { - dataSet = DataTransforms.restrictToMeasured(dataSet); - } - - return dataSet; - } - - /** - * Constructs a random sample using the given already allocated data set, to avoid allocating more memory. 
- */ - private DataSet simulateDataHelper(DataSet dataSet, boolean latentDataSaved, int[] tiers) { - if (dataSet.getNumColumns() != this.nodes.length) { - throw new IllegalArgumentException("When rewriting the old data set, " - + "number of variables in data set must equal number of variables " - + "in Bayes net."); - } - - int sampleSize = dataSet.getNumRows(); - - int numVars = 0; - int[] map = new int[this.nodes.length]; - List variables = new LinkedList<>(); - - for (int j = 0; j < this.nodes.length; j++) { - - int numCategories = this.bayesPm.getNumCategories(this.nodes[j]); - List categories = new LinkedList<>(); - - for (int k = 0; k < numCategories; k++) { - categories.add(this.bayesPm.getCategory(this.nodes[j], k)); - } - - DiscreteVariable var - = new DiscreteVariable(this.nodes[j].getName(), categories); - var.setNodeType(this.nodes[j].getNodeType()); - variables.add(var); - int index = ++numVars - 1; - map[index] = j; - } - - for (int i = 0; i < variables.size(); i++) { - Node node = dataSet.getVariable(i); - Node _node = variables.get(i); - dataSet.changeVariable(node, _node); - } - - constructSample(sampleSize, dataSet, map, tiers); - - if (latentDataSaved) { - return dataSet; - } else { - return DataTransforms.restrictToMeasured(dataSet); - } - } - - private void constructSample(int sampleSize, DataSet dataSet, int[] map, int[] tiers) { - -// //Do the simulation. 
-// class SimulationTask extends RecursiveTask { -// private int chunk; -// private int from; -// private int to; -// private int[] tiers; -// private DataSet dataSet; -// private int[] map; -// -// public SimulationTask(int chunk, int from, int to, int[] tiers, DataSet dataSet, int[] map) { -// this.chunk = chunk; -// this.from = from; -// this.to = to; -// this.tiers = tiers; -// this.dataSet = dataSet; -// this.map = map; -// } -// -// @Override -// protected Boolean compute() { -// if (to - from <= chunk) { -// RandomGenerator randomGenerator = new Well1024a(++seed[0]); -// -// for (int row = from; row < to; row++) { -// for (int t : tiers) { -// int[] parentValues = new int[parents[t].length]; -// -// for (int k = 0; k < parentValues.length; k++) { -// parentValues[k] = dataSet.getInt(row, parents[t][k]); -// } -// -// int rowIndex = getRowIndex(t, parentValues); -// double sum = 0.0; -// double r; -// -// r = randomGenerator.nextDouble(); -// -// for (int k = 0; k < getNumColumns(t); k++) { -// double probability = getProbability(t, rowIndex, k); -// sum += probability; -// -// if (sum >= r) { -// dataSet.setInt(row, map[t], k); -// break; -// } -// } -// } -// } -// -// return true; -// } else { -// int mid = (to + from) / 2; -// SimulationTask left = new SimulationTask(chunk, from, mid, tiers, dataSet, map); -// SimulationTask right = new SimulationTask(chunk, mid, to, tiers, dataSet, map); -// -// left.fork(); -// right.compute(); -// left.join(); -// -// return true; -// } -// } -// } -// -// int chunk = 25; -// -// ForkJoinPool pool = ForkJoinPoolInstance.getInstance().getPool(); -// SimulationTask task = new SimulationTask(chunk, 0, sampleSize, tiers, dataSet, map); -// pool.invoke(task); - // Construct the sample. 
- for (int i = 0; i < sampleSize; i++) { - for (int t : tiers) { - int[] parentValues = new int[this.parents[t].length]; - - for (int k = 0; k < parentValues.length; k++) { - parentValues[k] = dataSet.getInt(i, this.parents[t][k]); - } - - int rowIndex = getRowIndex(t, parentValues); - double sum = 0.0; - - double r = RandomUtil.getInstance().nextDouble(); - - for (int k = 0; k < getNumColumns(t); k++) { - double probability = getProbability(t, rowIndex, k); - sum += probability; - - if (sum >= r) { - dataSet.setInt(i, map[t], k); - break; - } - } - } - } - -// System.out.println(dataSet); - } - - /** - * Determines whether the specified object is equal to this Bayes net. - * - * @param o the object to be compared to this Bayes net - * @return true if the specified object is equal to this Bayes net, false otherwise - */ - public boolean equals(Object o) { - if (o == this) { - return true; - } - - if (!(o instanceof BayesIm otherIm)) { - return false; - } - - if (getNumNodes() != otherIm.getNumNodes()) { - return false; - } - - for (int i = 0; i < getNumNodes(); i++) { - int otherIndex = otherIm.getCorrespondingNodeIndex(i, otherIm); - - if (otherIndex == -1) { - return false; - } - - if (getNumColumns(i) != otherIm.getNumColumns(otherIndex)) { - return false; - } - - if (getNumRows(i) != otherIm.getNumRows(otherIndex)) { - return false; - } - - for (int j = 0; j < getNumRows(i); j++) { - for (int k = 0; k < getNumColumns(i); k++) { - double prob = getProbability(i, j, k); - double otherProb = otherIm.getProbability(i, j, k); - - if (Double.isNaN(prob) && Double.isNaN(otherProb)) { - continue; - } - - if (abs(prob - otherProb) > MlBayesImOld.ALLOWABLE_DIFFERENCE) { - return false; - } - } - } - } - - return true; - } - - //=============================PRIVATE METHODS=======================// - - /** - * Prints out the probability table for each variable. 
- * - * @return a {@link String} object - */ - public String toString() { - StringBuilder buf = new StringBuilder(); - NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat(); - - for (int i = 0; i < getNumNodes(); i++) { - buf.append("\n\nNode: ").append(getNode(i)); - - if (getNumParents(i) == 0) { - buf.append("\n"); - } else { - buf.append("\n\n"); - for (int k = 0; k < getNumParents(i); k++) { - buf.append(getNode(getParent(i, k))).append("\t"); - } - } - - for (int j = 0; j < getNumRows(i); j++) { - buf.append("\n"); - for (int k = 0; k < getNumParents(i); k++) { - buf.append(getParentValue(i, j, k)); - - if (k < getNumParents(i) - 1) { - buf.append("\t"); - } - } - - if (getNumParents(i) > 0) { - buf.append("\t"); - } - - for (int k = 0; k < getNumColumns(i); k++) { - buf.append(nf.format(getProbability(i, j, k))).append("\t"); - } - } - } - - buf.append("\n"); - - return buf.toString(); - } - - /** - * This method initializes the probability tables for all of the nodes in the Bayes net. - * - * @see #initializeNode - * @see #randomizeRow - */ - private void initialize(BayesIm oldBayesIm, int initializationMethod) { - this.parents = new int[this.nodes.length][]; - this.parentDims = new int[this.nodes.length][]; - this.probs = new double[this.nodes.length][][]; - - for (int nodeIndex = 0; nodeIndex < this.nodes.length; nodeIndex++) { - initializeNode(nodeIndex, oldBayesIm, initializationMethod); - } - } - - /** - * This method initializes the node indicated. - */ - private void initializeNode(int nodeIndex, BayesIm oldBayesIm, - int initializationMethod) { - Node node = this.nodes[nodeIndex]; - - // Set up parents array. Should store the parents of - // each node as ints in a particular order. 
- Graph graph = getBayesPm().getDag(); - List parentList = new ArrayList<>(graph.getParents(node)); - int[] parentArray = new int[parentList.size()]; - - for (int i = 0; i < parentList.size(); i++) { - parentArray[i] = getNodeIndex(parentList.get(i)); - } - - // Sort parent array. - Arrays.sort(parentArray); - - this.parents[nodeIndex] = parentArray; - - // Setup dimensions array for parents. - int[] dims = new int[parentArray.length]; - - for (int i = 0; i < dims.length; i++) { - Node parNode = this.nodes[parentArray[i]]; - dims[i] = getBayesPm().getNumCategories(parNode); - } - - // Calculate dimensions of table. - int numRows = 1; - - for (int dim : dims) { - if (numRows > 1000000 /* Integer.MAX_VALUE / dim*/) { - throw new IllegalArgumentException( - "The number of rows in the " - + "conditional probability table for " - + this.nodes[nodeIndex] - + " is greater than 1,000,000 and cannot be " - + "represented."); - } - - numRows *= dim; - } - - int numCols = getBayesPm().getNumCategories(node); - - this.parentDims[nodeIndex] = dims; - this.probs[nodeIndex] = new double[numRows][numCols]; - - // Initialize each row. 
- if (initializationMethod == MlBayesImOld.RANDOM) { - randomizeTable(nodeIndex); - } else { - for (int rowIndex = 0; rowIndex < numRows; rowIndex++) { - if (oldBayesIm == null) { - overwriteRow(nodeIndex, rowIndex, initializationMethod); - } else { - retainOldRowIfPossible(nodeIndex, rowIndex, oldBayesIm, - initializationMethod); - } - } - } - } - - private void overwriteRow(int nodeIndex, int rowIndex, - int initializationMethod) { - if (initializationMethod == MlBayesImOld.RANDOM) { - randomizeRow(nodeIndex, rowIndex); - } else if (initializationMethod == MlBayesImOld.MANUAL) { - initializeRowAsUnknowns(nodeIndex, rowIndex); - } else { - throw new IllegalArgumentException("Unrecognized state."); - } - } - - private void initializeRowAsUnknowns(int nodeIndex, int rowIndex) { - int size = getNumColumns(nodeIndex); - double[] row = new double[size]; - Arrays.fill(row, Double.NaN); - this.probs[nodeIndex][rowIndex] = row; - } - - /** - * This method initializes the node indicated. - */ - private void retainOldRowIfPossible(int nodeIndex, int rowIndex, - BayesIm oldBayesIm, int initializationMethod) { - - int oldNodeIndex = getCorrespondingNodeIndex(nodeIndex, oldBayesIm); - - if (oldNodeIndex == -1) { - overwriteRow(nodeIndex, rowIndex, initializationMethod); - } else if (getNumColumns(nodeIndex) != oldBayesIm.getNumColumns(oldNodeIndex)) { - overwriteRow(nodeIndex, rowIndex, initializationMethod); -// } else if (parentsChanged(nodeIndex, this, oldBayesIm)) { -// overwriteRow(nodeIndex, rowIndex, initializationMethod); - } else { - int oldRowIndex = getUniqueCompatibleOldRow(nodeIndex, rowIndex, oldBayesIm); - - if (oldRowIndex >= 0) { - copyValuesFromOldToNew(oldNodeIndex, oldRowIndex, nodeIndex, - rowIndex, oldBayesIm); - } else { - overwriteRow(nodeIndex, rowIndex, initializationMethod); - } - } - } - - /** - * @return the unique rowIndex in the old BayesIm for the given node that is compatible with the given rowIndex in - * the new BayesIm for that node, if one 
exists. Otherwise, returns -1. A compatible rowIndex is one in which all - * the parents that the given node has in common between the old BayesIm and the new BayesIm are assigned the values - * they have in the new rowIndex. If a parent node is removed in the new BayesIm, there may be more than one such - * compatible rowIndex in the old BayesIm, in which case -1 is returned. Likewise, there may be no compatible rows, - * in which case -1 is returned. - */ - private int getUniqueCompatibleOldRow(int nodeIndex, int rowIndex, - BayesIm oldBayesIm) { - int oldNodeIndex = getCorrespondingNodeIndex(nodeIndex, oldBayesIm); - int oldNumParents = oldBayesIm.getNumParents(oldNodeIndex); - - int[] oldParentValues = new int[oldNumParents]; - Arrays.fill(oldParentValues, -1); - - int[] parentValues = getParentValues(nodeIndex, rowIndex); - - // Go through each parent of the node in the new BayesIm. - for (int i = 0; i < getNumParents(nodeIndex); i++) { - - // Get the index of the parent in the new graph and in the old - // graph. If it's no longer in the new graph, skip to the next - // parent. - int parentNodeIndex = getParent(nodeIndex, i); - int oldParentNodeIndex - = getCorrespondingNodeIndex(parentNodeIndex, oldBayesIm); - int oldParentIndex = -1; - - for (int j = 0; j < oldBayesIm.getNumParents(oldNodeIndex); j++) { - if (oldParentNodeIndex == oldBayesIm.getParent(oldNodeIndex, j)) { - oldParentIndex = j; - break; - } - } - - if (oldParentIndex == -1 - || oldParentIndex >= oldBayesIm.getNumParents(oldNodeIndex)) { - return -1; - } - - // Look up that value index for the new BayesIm for that parent. - // If it was a valid value index in the old BayesIm, record - // that value in oldParentValues. Otherwise return -1. 
- int newParentValue = parentValues[i]; - int oldParentDim - = oldBayesIm.getParentDim(oldNodeIndex, oldParentIndex); - - if (newParentValue < oldParentDim) { - oldParentValues[oldParentIndex] = newParentValue; - } else { - return -1; - } - } - -// // Go through each parent of the node in the new BayesIm. -// for (int i = 0; i < oldBayesIm.getNumParents(oldNodeIndex); i++) { -// -// // Get the index of the parent in the new graph and in the old -// // graph. If it's no longer in the new graph, skip to the next -// // parent. -// int oldParentNodeIndex = oldBayesIm.getParent(oldNodeIndex, i); -// int parentNodeIndex = -// oldBayesIm.getCorrespondingNodeIndex(oldParentNodeIndex, this); -// int parentIndex = -1; -// -// for (int j = 0; j < this.getNumParents(nodeIndex); j++) { -// if (parentNodeIndex == this.getParent(nodeIndex, j)) { -// parentIndex = j; -// break; -// } -// } -// -// if (parentIndex == -1 || -// parentIndex >= this.getNumParents(nodeIndex)) { -// continue; -// } -// -// // Look up that value index for the new BayesIm for that parent. -// // If it was a valid value index in the old BayesIm, record -// // that value in oldParentValues. Otherwise return -1. -// int parentValue = oldParentValues[i]; -// int parentDim = -// this.getParentDim(nodeIndex, parentIndex); -// -// if (parentValue < parentDim) { -// oldParentValues[parentIndex] = oldParentValue; -// } else { -// return -1; -// } -// } - // If there are any -1's in the combination at this point, return -1. - for (int oldParentValue : oldParentValues) { - if (oldParentValue == -1) { - return -1; - } - } - - // Otherwise, return the combination, which will be a row in the - // old BayesIm. 
- return oldBayesIm.getRowIndex(oldNodeIndex, oldParentValues); - } - - private void copyValuesFromOldToNew(int oldNodeIndex, int oldRowIndex, - int nodeIndex, int rowIndex, BayesIm oldBayesIm) { - if (getNumColumns(nodeIndex) != oldBayesIm.getNumColumns(oldNodeIndex)) { - throw new IllegalArgumentException("It's only possible to copy " - + "one row of probability values to another in a Bayes IM " - + "if the number of columns in the table are the same."); - } - - for (int colIndex = 0; colIndex < getNumColumns(nodeIndex); colIndex++) { - double prob = oldBayesIm.getProbability(oldNodeIndex, oldRowIndex, - colIndex); - setProbability(nodeIndex, rowIndex, colIndex, prob); - } - } - - /** - * Adds semantic checks to the default deserialization method. This method must have the standard signature for a - * readObject method, and the body of the method must begin with "s.defaultReadObject();". Other than that, any - * semantic checks can be specified and do not need to stay the same from version to version. A readObject method of - * this form may be added to any class, even if Tetrad sessions were previously saved out using a version of the - * class that didn't include it. (That's what the "s.defaultReadObject();" is for. See J. Bloch, Effective Java, for - * help. - * - * @param s The object input stream. - * @throws IOException If any. - * @throws ClassNotFoundException If any. 
* - */ - @Serial - private void readObject(ObjectInputStream s) - throws IOException, ClassNotFoundException { - s.defaultReadObject(); - - if (this.bayesPm == null) { - throw new NullPointerException(); - } - - if (this.nodes == null) { - throw new NullPointerException(); - } - - if (this.parents == null) { - throw new NullPointerException(); - } - - if (this.parentDims == null) { - throw new NullPointerException(); - } - - if (this.probs == null) { - throw new NullPointerException(); - } - } -} From 86447cbbbf909ecf3f665fd122dcf12ea83578f0 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Thu, 28 Mar 2024 12:42:18 -0400 Subject: [PATCH 17/23] Refactor getVariableNames method in MlBayesIm This commit simplifies the getVariableNames method in the MlBayesIm class. It replaces the for loop that used indices with a more straightforward enhanced for loop that directly iterates over the nodes. Consequently, this makes the code cleaner and easier to understand without affecting functionality. --- .../src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java index c86c463aa1..c49473f068 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java @@ -349,11 +349,11 @@ public List getMeasuredNodes() { * @return a {@link java.util.List} object */ public List getVariableNames() { + List nodes = getVariables(); List variableNames = new LinkedList<>(); - for (int i = 0; i < getNumNodes(); i++) { - Node node = getNode(i); - variableNames.add(this.bayesPm.getVariable(node).getName()); + for (Node node : nodes) { + variableNames.add(node.getName()); } return variableNames; From ed5622c20eec99346ec11bc11861a9eafb64b2d4 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Sat, 30 Mar 2024 13:02:48 -0400 Subject: [PATCH 18/23] Add 
null check in existsParameterizedConstructor method Null check has been added in the existsParameterizedConstructor method in SessionNode class. If the model class is null, then the method will return false indicating that no constructor exists. This added validation is intended to prevent potential issues related to null object handling. --- .../main/java/edu/cmu/tetradapp/session/SessionNode.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/session/SessionNode.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/session/SessionNode.java index 6b17851e62..d560f21321 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/session/SessionNode.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/session/SessionNode.java @@ -1173,6 +1173,13 @@ public int compareTo(Node node) { * @return a boolean */ public boolean existsParameterizedConstructor(Class modelClass) { + if (modelClass == null) { + + // If the model class is null, then there is no constructor, so display a dialog to the users by + // return false here. + return false; + } + Object param = getParam(modelClass); List parentModels = listParentModels(); parentModels.add(param); From e9f274c62f492e666864e37e0fa12b111660cf1b Mon Sep 17 00:00:00 2001 From: jdramsey Date: Sat, 30 Mar 2024 13:47:28 -0400 Subject: [PATCH 19/23] Refactor conditional checks and parameter names Refactored the conditional check in `RandomGraph.java` to use `isEmpty()` function instead of `size() == 0`. Adjusted parameter names in `BayesPmWrapper.java` for improved clarity. Corrected a typographical error in `RandomGraphEditor.java` from "graphRandomFoward" to "graphRandomForward". These changes improve code readability and maintainability. 
--- .../main/java/edu/cmu/tetradapp/editor/RandomGraphEditor.java | 2 +- .../src/main/java/edu/cmu/tetradapp/model/BayesPmWrapper.java | 4 ++-- .../src/main/java/edu/cmu/tetrad/graph/RandomGraph.java | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomGraphEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomGraphEditor.java index 712da8f520..b10d03f8b2 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomGraphEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/RandomGraphEditor.java @@ -481,7 +481,7 @@ public void setEnabled(boolean enabled) { * @return a boolean */ public boolean isRandomForward() { - return this.parameters.getBoolean("graphRandomFoward", true); + return this.parameters.getBoolean("graphRandomForward", true); } private void setRandomForward(boolean randomFoward) { diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesPmWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesPmWrapper.java index 2c7ee8d4b3..6511a3bcc3 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesPmWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesPmWrapper.java @@ -206,8 +206,8 @@ public BayesPmWrapper(GraphWrapper graphWrapper, Parameters params) { lowerBound = upperBound = 3; setBayesPm(graph, lowerBound, upperBound); } else if (params.getString("bayesPmInitializationMode", "range").equals("range")) { - lowerBound = params.getInt("minCategories", 2); - upperBound = params.getInt("maxCategories", 4); + lowerBound = params.getInt("lowerBoundNumVals", 2); + upperBound = params.getInt("upperBoundNumVals", 4); setBayesPm(graph, lowerBound, upperBound); } else { throw new IllegalStateException("Unrecognized type."); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/RandomGraph.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/RandomGraph.java index 1d0322d5ba..c056d6f148 100644 --- 
a/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/RandomGraph.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/graph/RandomGraph.java @@ -228,7 +228,7 @@ public static Graph randomGraphRandomForwardEdges(List nodes, int numLaten * negative or exceeds the number of nodes */ public static Graph randomGraphRandomForwardEdges(List nodes, int numLatentConfounders, int numEdges, int maxDegree, int maxIndegree, int maxOutdegree, boolean connected, boolean layoutAsCircle) { - if (nodes.size() == 0) { + if (nodes.isEmpty()) { throw new IllegalArgumentException("NumNodes most be > 0"); } From c10116ccca114452fccbfb0cf73feec9c16f41f7 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Sun, 31 Mar 2024 09:04:09 -0400 Subject: [PATCH 20/23] Fix typographical error in Statistic.java comment Corrected a typographical error in a comment within the `Statistic.java` file. This edit improves the readability of the codebase for future developers by ensuring comments accurately describe the intended functionality and behavior of the code. --- .../java/edu/cmu/tetrad/algcomparison/statistic/Statistic.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/Statistic.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/Statistic.java index 147b4c8d73..248d9ad374 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/Statistic.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/Statistic.java @@ -43,7 +43,7 @@ public interface Statistic extends Serializable { /** * Returns a mapping of the statistic to the interval [0, 1], with higher being better. This is used for a - * calculation of a utility for an algorithm.If the statistic is already between 0 and 1, you can just return the + * calculation of a utility for an algorithm. If the statistic is already between 0 and 1, you can just return the * statistic. * * @param value The value of the statistic. 
From 2cd7ebb46b022cc4c2ae6458b0ef8c8c380b66e8 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Sun, 31 Mar 2024 10:43:13 -0400 Subject: [PATCH 21/23] Refactor initialization methods in Bayes' methods Updated the usage of initialization methods in various Bayes related classes, specifically changing from "MlBayesIm.RANDOM" and "MlBayesIm.MANUAL" to "MlBayesIm.InitializationMethod.RANDOM" and "MlBayesIm.InitializationMethod.MANUAL". Created new files for CptMapProbs and CptMapCounts as part of the refactoring. This update helps make the code more readable and consistent across different parts of the application. --- .../cmu/tetradapp/model/BayesImWrapper.java | 15 +- .../simulation/BayesNetSimulation.java | 4 +- .../ConditionalGaussianSimulation.java | 2 +- .../cmu/tetrad/bayes/ApproximateUpdater.java | 2 +- .../cmu/tetrad/bayes/CptInvariantUpdater.java | 2 +- .../java/edu/cmu/tetrad/bayes/CptMap.java | 158 +++-------------- .../edu/cmu/tetrad/bayes/CptMapCounts.java | 159 +++++++++++++++++ .../edu/cmu/tetrad/bayes/CptMapProbs.java | 160 ++++++++++++++++++ .../cmu/tetrad/bayes/EmBayesEstimator.java | 2 +- .../edu/cmu/tetrad/bayes/Identifiability.java | 6 +- .../cmu/tetrad/bayes/JunctionTreeUpdater.java | 2 +- .../cmu/tetrad/bayes/MlBayesEstimatorOld.java | 136 +++++++++++++++ .../java/edu/cmu/tetrad/bayes/MlBayesIm.java | 68 ++++---- .../tetrad/bayes/RowSummingExactUpdater.java | 2 +- .../tetrad/simulation/HsimRobustCompare.java | 2 +- .../tetrad/study/performance/Comparison.java | 2 +- .../tetrad/study/performance/Comparison2.java | 2 +- .../study/performance/PerformanceTests.java | 4 +- .../pitt/isp/sverchkov/data/AdTreeTest.java | 2 +- .../test/TestBayesDiscreteBicScorer.java | 2 +- .../java/edu/cmu/tetrad/test/TestBayesIm.java | 18 +- .../edu/cmu/tetrad/test/TestBayesXml.java | 4 +- .../tetrad/test/TestCellProbabilities.java | 2 +- .../tetrad/test/TestCptInvariantUpdater.java | 4 +- .../cmu/tetrad/test/TestDataSetCellProbs.java | 2 +- 
.../cmu/tetrad/test/TestDiscreteProbs.java | 2 +- .../java/edu/cmu/tetrad/test/TestFges.java | 2 +- .../java/edu/cmu/tetrad/test/TestGFci.java | 2 +- .../edu/cmu/tetrad/test/TestHistogram.java | 2 +- .../java/edu/cmu/tetrad/test/TestRfciBsc.java | 2 +- .../tetrad/test/TestRowSummingUpdater.java | 4 +- .../cmu/tetrad/test/TestUpdatedBayesIm.java | 2 +- 32 files changed, 559 insertions(+), 219 deletions(-) create mode 100644 tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java create mode 100644 tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapProbs.java create mode 100644 tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimatorOld.java diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesImWrapper.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesImWrapper.java index 08aa40b5ee..f8f0d8b793 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesImWrapper.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/BayesImWrapper.java @@ -36,6 +36,7 @@ import java.io.Serial; import java.util.ArrayList; import java.util.List; +import java.util.function.IntBinaryOperator; /** * Wraps a Bayes Pm for use in the Tetrad application. 
@@ -90,11 +91,11 @@ public BayesImWrapper(BayesPmWrapper bayesPmWrapper, BayesImWrapper oldBayesImwr BayesIm oldBayesIm = oldBayesImwrapper.getBayesIm(); if (params.getString("initializationMode", "manualRetain").equals("manualRetain")) { - setBayesIm(bayesPm, oldBayesIm, MlBayesIm.MANUAL); + setBayesIm(bayesPm, oldBayesIm, MlBayesIm.InitializationMethod.MANUAL); } else if (params.getString("initializationMode", "manualRetain").equals("randomRetain")) { - setBayesIm(bayesPm, oldBayesIm, MlBayesIm.RANDOM); + setBayesIm(bayesPm, oldBayesIm, MlBayesIm.InitializationMethod.RANDOM); } else if (params.getString("initializationMode", "manualRetain").equals("randomOverwrite")) { - setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.RANDOM)); + setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM)); } } @@ -193,9 +194,9 @@ public BayesImWrapper(BayesPmWrapper bayesPmWrapper, Parameters params) { if (params.getString("initializationMode", "manualRetain").equals("manualRetain")) { setBayesIm(new MlBayesIm(bayesPm)); } else if (params.getString("initializationMode", "manualRetain").equals("randomRetain")) { - setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.RANDOM)); + setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM)); } else if (params.getString("initializationMode", "manualRetain").equals("randomOverwrite")) { - setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.RANDOM)); + setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM)); } } @@ -342,9 +343,9 @@ public String getModelSourceName() { //============================== private methods ============================// - private void setBayesIm(BayesPm bayesPm, BayesIm oldBayesIm, int manual) { + private void setBayesIm(BayesPm bayesPm, BayesIm oldBayesIm, MlBayesIm.InitializationMethod initializationMethod) { this.bayesIms = new ArrayList<>(); - this.bayesIms.add(new MlBayesIm(bayesPm, oldBayesIm, manual)); + this.bayesIms.add(new MlBayesIm(bayesPm, oldBayesIm, initializationMethod)); 
} /** diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java index 46c1831baa..7ce861dbaf 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/BayesNetSimulation.java @@ -238,9 +238,9 @@ private DataSet simulate(Graph graph, Parameters parameters) { int minCategories = parameters.getInt(Params.MIN_CATEGORIES); int maxCategories = parameters.getInt(Params.MAX_CATEGORIES); pm = new BayesPm(graph, minCategories, maxCategories); - im = new MlBayesIm(pm, MlBayesIm.RANDOM); + im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM); } else { - im = new MlBayesIm(pm, MlBayesIm.RANDOM); + im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM); this.im = im; } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java index 0af14dfd20..676915f1c8 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/simulation/ConditionalGaussianSimulation.java @@ -320,7 +320,7 @@ private DataSet simulate(Graph G, Parameters parameters) { } BayesPm bayesPm = new BayesPm(AG); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); SemPm semPm = new SemPm(XG); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ApproximateUpdater.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ApproximateUpdater.java index 479746f5f0..da17e8c887 100755 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ApproximateUpdater.java +++ 
b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/ApproximateUpdater.java @@ -344,7 +344,7 @@ private void doUpdate() { } private BayesIm createdUpdatedBayesIm(BayesPm updatedBayesPm) { - return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.RANDOM); + return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.InitializationMethod.RANDOM); } private BayesPm createUpdatedBayesPm(Dag updatedGraph) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptInvariantUpdater.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptInvariantUpdater.java index 171f92d69e..a47ca9d293 100755 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptInvariantUpdater.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptInvariantUpdater.java @@ -266,7 +266,7 @@ public String toString() { //==============================PRIVATE METHODS=======================// private BayesIm createdManipulatedBayesIm(BayesPm updatedBayesPm) { - return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.MANUAL); + return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.InitializationMethod.MANUAL); } private BayesPm createManipulatedBayesPm(Dag updatedGraph) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java index ab3cc3a0e8..d8d271c5e2 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java @@ -1,155 +1,41 @@ package edu.cmu.tetrad.bayes; import edu.cmu.tetrad.util.TetradSerializable; -import edu.cmu.tetrad.util.Vector; - -import java.io.Serial; -import java.util.HashMap; -import java.util.Map; /** - * Represents a conditional probability table (CPT) in a Bayes net. This represents the CPT as a map from a unique - * integer index for a particular node to the probability of that node taking on that value, where NaN's are not - * stored. 
+ * An interface representing a map of probabilities or counts for nodes in a Bayesian network. Implementations of this + * interface should provide methods to get the probability or count for a node at a given row and column, as well as + * methods to retrieve the number of rows and columns in the map. + * <p>

    + * This interface extends the TetradSerializable interface, indicating that implementations should be serializable and + * follow certain guidelines for compatibility across different versions of Tetrad. + * + * @author josephramsey + * @see CptMapProbs + * @see CptMapCounts */ -public class CptMap implements TetradSerializable { - @Serial - private static final long serialVersionUID = 23L; - /** - * Constructs a new probability map, a map from a unique integer index for a particular node to the probability of - * that node taking on that value, where NaN's are not stored. - */ - private final Map map = new HashMap<>(); - /** - * The number of rows in the table. - */ - private final int numRows; - /** - * The number of columns in the table. - */ - private final int numColumns; - - /** - * Constructs a new probability map, a map from a unique integer index for a particular node to the probability of - * that node taking on that value, where NaN's are not stored. This probability map assumes that there is a certain - * number of rows and a certain number of columns in the table. - * - * @param numRows the number of rows in the table - * @param numColumns the number of columns in the table - */ - public CptMap(int numRows, int numColumns) { - if (numRows < 1 || numColumns < 1) { - throw new IllegalArgumentException("Number of rows and columns must be at least 1."); - } - - this.numRows = numRows; - this.numColumns = numColumns; - } - - /** - * Constructs a new probability map based on the given 2-dimensional array. 
- * - * @param probMatrix the 2-dimensional array representing the probability matrix - * @throws IllegalArgumentException if the number of columns in any row is different - */ - public CptMap(double[][] probMatrix) { - if (probMatrix == null || probMatrix.length == 0 || probMatrix[0].length == 0) { - throw new IllegalArgumentException("Probability matrix must have at least one row and one column."); - } - - numRows = probMatrix.length; - numColumns = probMatrix[0].length; - - for (int i = 0; i < numRows; i++) { - if (probMatrix[i].length != numColumns) { - throw new IllegalArgumentException("All rows must have the same number of columns."); - } - } - - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numColumns; j++) { - map.put(i * numColumns + j, probMatrix[i][j]); - } - } - } +public interface CptMap extends TetradSerializable { /** - * Returns the probability of the node taking on the value specified by the given row and column. + * Retrieves the value at the specified row and column in the CptMap. * - * @param row the row of the node - * @param column the column of the node - * @return the probability of the node taking on the value specified by the given row and column + * @param row the row index of the value to retrieve. + * @param column the column index of the value to retrieve. + * @return the value at the specified row and column in the CptMap. */ - public double get(int row, int column) { - if (row < 0 || row >= numRows || column < 0 || column >= numColumns) { - throw new IllegalArgumentException("Row and column must be within bounds."); - } - - int key = row * numColumns + column; - - if (!map.containsKey(key)) { - return Double.NaN; - } - - return map.get(key); - } + double get(int row, int column); /** - * Sets the probability of the node taking on the value specified by the given row and column to the given value. + * Retrieves the number of rows in the CptMap. 
* - * @param row the row of the node - * @param column the column of the node - * @param value the probability of the node taking on the value specified by the given row and column (NaN to - * remove the value) + * @return the number of rows in the CptMap. */ - public void set(int row, int column, double value) { - if (row < 0 || row >= numRows || column < 0 || column >= numColumns) { - throw new IllegalArgumentException("Row and column must be within bounds."); - } - - int key = row * numColumns + column; - - if (Double.isNaN(value)) { - map.remove(key); - return; - } - - map.put(key, value); - } + int getNumRows(); /** - * Returns the number of rows in the probability map. + * Retrieves the number of columns in the CptMap. * - * @return the number of rows in the probability map. + * @return the number of columns in the CptMap. */ - public int getNumRows() { - return numRows; - } - - /** - * Returns the number of columns in the probability map. - * - * @return the number of columns in the probability map. - */ - public int getNumColumns() { - return numColumns; - } - - /** - * Assigns the values in the provided vector to a specific row in the probability map. 
- * - * @param rowIndex the index of the row to be assigned - * @param vector the vector containing the values to be assigned to the row - * @throws IllegalArgumentException if the size of the vector is not equal to the number of columns in the - * probability map - */ - public void assignRow(int rowIndex, Vector vector) { - if (vector.size() != numColumns) { - throw new IllegalArgumentException("Vector must have the same number of columns as the probability map."); - } - - for (int i = 0; i < numColumns; i++) { - set(rowIndex, i, vector.get(i)); - } - } + int getNumColumns(); } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java new file mode 100644 index 0000000000..ec5f81f431 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java @@ -0,0 +1,159 @@ +package edu.cmu.tetrad.bayes; + +import edu.cmu.tetrad.data.DataSet; + +import java.io.Serial; +import java.util.HashMap; +import java.util.Map; + +/** + * Represents a conditional probability table (CPT) in a Bayes net. This represents the CPT as a map from a unique + * integer index for a particular node to the cell count for that node, where 0's are not stored. Row counts are also + * stored, so that the probability of a cell can be calculated. A prior cell count of 0 is assumed for all cells, + * but this may be set by the user to any non-negative count. (A prior count of 0 is equivalent to a maximum + * likelihood estimate.) + * + * @author josephramsey + */ +public class CptMapCounts implements CptMap { + @Serial + private static final long serialVersionUID = 23L; + /** + * Constructs a new count map, a map from a unique integer index for a particular node to the count for that value, + * where 0's are not stored. 
+ */ + private final Map cellCounts = new HashMap<>(); + /** + * Constructs a new row count map, a map from a unique integer index for a particular node to the count for that + * row, where 0's are not stored. + */ + private final Map rowCounts = new HashMap<>(); + /** + * The number of rows in the table. + */ + private final int numRows; + /** + * The number of columns in the table. + */ + private final int numColumns; + /** + * The prior count for all cells. + */ + private int priorCount = 0; + + /** + * Constructs a new probability map, a map from a unique integer index for a particular node to the probability of + * that node taking on that value, where NaN's are not stored. This probability map assumes that there is a certain + * number of rows and a certain number of columns in the table. + * + * @param numRows the number of rows in the table + * @param numColumns the number of columns in the table + */ + public CptMapCounts(int numRows, int numColumns) { + if (numRows < 1 || numColumns < 1) { + throw new IllegalArgumentException("Number of rows and columns must be at least 1."); + } + + this.numRows = numRows; + this.numColumns = numColumns; + } + + /** + * Constructs a new CptMap based on counts from a given dataset. 
+ * + * @param data the DataSet object representing the probability matrix + * @throws IllegalArgumentException if the data set is null or not discrete + */ + public CptMapCounts(DataSet data) { + if (data == null) { + throw new IllegalArgumentException("Probability matrix must have at least one row and one column."); + } + + if (!data.isDiscrete()) { + throw new IllegalArgumentException("Data set must be discrete."); + + } + + numRows = data.getNumRows(); + numColumns = data.getNumColumns(); + + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numColumns; j++) { + int key = i * numColumns + j; + + if (data.getInt(i, j) == -1) { + continue; + } + + if (data.getInt(i, j) == 0) { + continue; + } + + if (!cellCounts.containsKey(key)) { + cellCounts.put(key, 1); + } else { + cellCounts.put(key, cellCounts.get(key) + 1); + } + + if (!rowCounts.containsKey(i)) { + rowCounts.put(i, 1); + } else { + rowCounts.put(i, rowCounts.get(i) + 1); + } + } + } + } + + /** + * Returns the probability of the node taking on the value specified by the given row and column. + * + * @param row the row of the node + * @param column the column of the node + * @return the probability of the node taking on the value specified by the given row and column + */ + @Override + public double get(int row, int column) { + if (row < 0 || row >= numRows || column < 0 || column >= numColumns) { + throw new IllegalArgumentException("Row and column must be within bounds."); + } + + int key = row * numColumns + column; + int rowCount = rowCounts.getOrDefault(row, 0); + int cellCount = cellCounts.getOrDefault(key, 0); + return cellCount / (double) rowCount; + } + + /** + * Returns the number of rows in the probability map. + * + * @return the number of rows in the probability map. + */ + @Override + public int getNumRows() { + return numRows; + } + + /** + * Returns the number of columns in the probability map. + * + * @return the number of columns in the probability map. 
+ */ + @Override + public int getNumColumns() { + return numColumns; + } + + /** + * The prior count for all cells. + */ + public int getPriorCount() { + return priorCount; + } + + /** + * The prior count for all cells. + */ + public void setPriorCount(int priorCount) { + this.priorCount = priorCount; + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapProbs.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapProbs.java new file mode 100644 index 0000000000..f2ccf95c53 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapProbs.java @@ -0,0 +1,160 @@ +package edu.cmu.tetrad.bayes; + +import edu.cmu.tetrad.util.TetradSerializable; +import edu.cmu.tetrad.util.Vector; + +import java.io.Serial; +import java.util.HashMap; +import java.util.Map; + +/** + * Represents a conditional probability table (CPT) in a Bayes net. This represents the CPT as a map from a unique + * integer index for a particular node to the probability of that node taking on that value, where NaN's are not + * stored. + * + * @author josephramsey + */ +public class CptMapProbs implements CptMap { + @Serial + private static final long serialVersionUID = 23L; + /** + * Constructs a new probability map, a map from a unique integer index for a particular node to the probability of + * that node taking on that value, where NaN's are not stored. + */ + private final Map map = new HashMap<>(); + /** + * The number of rows in the table. + */ + private final int numRows; + /** + * The number of columns in the table. + */ + private final int numColumns; + + /** + * Constructs a new probability map, a map from a unique integer index for a particular node to the probability of + * that node taking on that value, where NaN's are not stored. This probability map assumes that there is a certain + * number of rows and a certain number of columns in the table. 
+ * + * @param numRows the number of rows in the table + * @param numColumns the number of columns in the table + */ + public CptMapProbs(int numRows, int numColumns) { + if (numRows < 1 || numColumns < 1) { + throw new IllegalArgumentException("Number of rows and columns must be at least 1."); + } + + this.numRows = numRows; + this.numColumns = numColumns; + } + + /** + * Constructs a new probability map based on the given 2-dimensional array. + * + * @param probMatrix the 2-dimensional array representing the probability matrix + * @throws IllegalArgumentException if the number of columns in any row is different + */ + public CptMapProbs(double[][] probMatrix) { + if (probMatrix == null || probMatrix.length == 0 || probMatrix[0].length == 0) { + throw new IllegalArgumentException("Probability matrix must have at least one row and one column."); + } + + numRows = probMatrix.length; + numColumns = probMatrix[0].length; + + for (int i = 0; i < numRows; i++) { + if (probMatrix[i].length != numColumns) { + throw new IllegalArgumentException("All rows must have the same number of columns."); + } + } + + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numColumns; j++) { + map.put(i * numColumns + j, probMatrix[i][j]); + } + } + } + + /** + * Returns the probability of the node taking on the value specified by the given row and column. 
+ * + * @param row the row of the node + * @param column the column of the node + * @return the probability of the node taking on the value specified by the given row and column + */ + @Override + public double get(int row, int column) { + if (row < 0 || row >= numRows || column < 0 || column >= numColumns) { + throw new IllegalArgumentException("Row and column must be within bounds."); + } + + int key = row * numColumns + column; + + if (!map.containsKey(key)) { + return Double.NaN; + } + + return map.get(key); + } + + /** + * Sets the probability of the node taking on the value specified by the given row and column to the given value. + * + * @param row the row of the node + * @param column the column of the node + * @param value the probability of the node taking on the value specified by the given row and column (NaN to + * remove the value) + */ + public void set(int row, int column, double value) { + if (row < 0 || row >= numRows || column < 0 || column >= numColumns) { + throw new IllegalArgumentException("Row and column must be within bounds."); + } + + int key = row * numColumns + column; + + if (Double.isNaN(value)) { + map.remove(key); + return; + } + + map.put(key, value); + } + + /** + * Returns the number of rows in the probability map. + * + * @return the number of rows in the probability map. + */ + @Override + public int getNumRows() { + return numRows; + } + + /** + * Returns the number of columns in the probability map. + * + * @return the number of columns in the probability map. + */ + @Override + public int getNumColumns() { + return numColumns; + } + + /** + * Assigns the values in the provided vector to a specific row in the probability map. 
+ * + * @param rowIndex the index of the row to be assigned + * @param vector the vector containing the values to be assigned to the row + * @throws IllegalArgumentException if the size of the vector is not equal to the number of columns in the + * probability map + */ + public void assignRow(int rowIndex, Vector vector) { + if (vector.size() != numColumns) { + throw new IllegalArgumentException("Vector must have the same number of columns as the probability map."); + } + + for (int i = 0; i < numColumns; i++) { + set(rowIndex, i, vector.get(i)); + } + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/EmBayesEstimator.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/EmBayesEstimator.java index 9a915923f9..d29e0eacb7 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/EmBayesEstimator.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/EmBayesEstimator.java @@ -490,7 +490,7 @@ private void estimateIM(BayesPm bayesPm, DataSet dataSet) { BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet); // Create a new Bayes IM to store the estimated values. 
- this.estimatedIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + this.estimatedIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); int numNodes = this.estimatedIm.getNumNodes(); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/Identifiability.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/Identifiability.java index 4d709ebf6b..f57f06348f 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/Identifiability.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/Identifiability.java @@ -441,7 +441,7 @@ public double getJointMarginal(int[] sVariables, int[] sValues) { Dag gD = new Dag(this.bayesIm.getDag().subgraph(dNodes)); BayesPm bayesPmD = new BayesPm(gD, this.bayesIm.getBayesPm()); - BayesIm bayesImD = new MlBayesIm(bayesPmD, this.bayesIm, MlBayesIm.RANDOM); + BayesIm bayesImD = new MlBayesIm(bayesPmD, this.bayesIm, MlBayesIm.InitializationMethod.RANDOM); if (this.debug) { System.out.println("------ bayeIm.getDag() -------------"); @@ -919,7 +919,7 @@ else if (nodesA.containsAll(nodesT)) { // construct an IM with the dag graphA BayesPm bayesPmA = new BayesPm(graphA, this.bayesIm.getBayesPm()); - BayesIm bayesImA = new MlBayesIm(bayesPmA, this.bayesIm, MlBayesIm.RANDOM); + BayesIm bayesImA = new MlBayesIm(bayesPmA, this.bayesIm, MlBayesIm.InitializationMethod.RANDOM); // get c-components of graphA int[] cComponentsA = getCComponents(bayesImA); @@ -967,7 +967,7 @@ else if (nodesA.containsAll(nodesT)) { ///////////////////////////////////////////////////////////////// private BayesIm createdUpdatedBayesIm(BayesPm updatedBayesPm) { - return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.RANDOM); + return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.InitializationMethod.RANDOM); } private BayesPm createUpdatedBayesPm(Dag updatedGraph) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/JunctionTreeUpdater.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/JunctionTreeUpdater.java index 
d1911649f6..be22787be6 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/JunctionTreeUpdater.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/JunctionTreeUpdater.java @@ -300,7 +300,7 @@ private void updateAll() { } private BayesIm createdUpdatedBayesIm(BayesPm updatedBayesPm) { - return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.MANUAL); + return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.InitializationMethod.MANUAL); } private BayesPm createUpdatedBayesPm(Dag updatedGraph) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimatorOld.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimatorOld.java new file mode 100644 index 0000000000..0717b628a2 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimatorOld.java @@ -0,0 +1,136 @@ +/////////////////////////////////////////////////////////////////////////////// +// For information as to what this class does, see the Javadoc, below. // +// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, // +// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard // +// Scheines, Joseph Ramsey, and Clark Glymour. // +// // +// This program is free software; you can redistribute it and/or modify // +// it under the terms of the GNU General Public License as published by // +// the Free Software Foundation; either version 2 of the License, or // +// (at your option) any later version. // +// // +// This program is distributed in the hope that it will be useful, // +// but WITHOUT ANY WARRANTY; without even the implied warranty of // +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // +// GNU General Public License for more details. 
// +// // +// You should have received a copy of the GNU General Public License // +// along with this program; if not, write to the Free Software // +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA // +/////////////////////////////////////////////////////////////////////////////// + +package edu.cmu.tetrad.bayes; + +import edu.cmu.tetrad.data.DataSet; +import edu.cmu.tetrad.graph.Node; + +import java.util.List; + +/** + * Estimates parameters of the given Bayes net from the given data using maximum likelihood method. + * + * @author Shane Harwood, Joseph Ramsey + * @version $Id: $Id + */ +public final class MlBayesEstimatorOld { + + /** + *

    Constructor for MlBayesEstimator.

    + */ + public MlBayesEstimatorOld() { + + } + + /** + * 33 Estimates a Bayes IM using the variables, graph, and parameters in the given Bayes PM and the data columns in + * the given data set. Each variable in the given Bayes PM must be equal to a variable in the given data set. + * + * @param bayesPm a {@link BayesPm} object + * @param dataSet a {@link DataSet} object + * @return a {@link BayesIm} object + */ + public BayesIm estimate(BayesPm bayesPm, DataSet dataSet) { + if (bayesPm == null) { + throw new NullPointerException(); + } + + if (dataSet == null) { + throw new NullPointerException(); + } + +// if (DataUtils.containsMissingValue(dataSet)) { +// throw new IllegalArgumentException("Please remove or impute missing values."); +// } + + // Make sure all of the variables in the PM are in the data set; + // otherwise, estimation is impossible. + BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet); + + // Create a new Bayes IM to store the estimated values. + BayesIm estimatedIm = new MlBayesIm(bayesPm); + + // Create a subset of the data set with the variables of the IM, in + // the order of the IM. + List variables = estimatedIm.getVariables(); + DataSet columnDataSet2 = dataSet.subsetColumns(variables); + DiscreteProbs discreteProbs = new DataSetProbs(columnDataSet2); + + // We will use the same estimation methods as the updaters, to ensure + // compatibility. 
+ Proposition assertion = Proposition.tautology(estimatedIm); + Proposition condition = Proposition.tautology(estimatedIm); + Evidence evidence2 = Evidence.tautology(estimatedIm); + + int numNodes = estimatedIm.getNumNodes(); + + for (int node = 0; node < numNodes; node++) { + int numRows = estimatedIm.getNumRows(node); + int numCols = estimatedIm.getNumColumns(node); + int[] parents = estimatedIm.getParents(node); + + for (int row = 0; row < numRows; row++) { + int[] parentValues = estimatedIm.getParentValues(node, row); + + for (int col = 0; col < numCols; col++) { + + // Remove values from the proposition in various ways; if + // a combination exists in the end, calculate a conditional + // probability. + assertion.setToTautology(); + condition.setToTautology(); + + for (int i = 0; i < numNodes; i++) { + for (int j = 0; j < evidence2.getNumCategories(i); j++) { + if (!evidence2.getProposition().isAllowed(i, j)) { + condition.removeCategory(i, j); + } + } + } + + assertion.disallowComplement(node, col); + + for (int k = 0; k < parents.length; k++) { + condition.disallowComplement(parents[k], parentValues[k]); + } + + if (condition.existsCombination()) { + double p = discreteProbs.getConditionalProb(assertion, condition); +// if (Double.isNaN(p)) p = 1.0 / numCols; + estimatedIm.setProbability(node, row, col, p); + } else { + estimatedIm.setProbability(node, row, col, Double.NaN); + } + } + } + } + +// System.out.println(estimatedIm); + + return estimatedIm; + } +} + + + + + diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java index c49473f068..b662778ded 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java @@ -67,14 +67,12 @@ * @version $Id: $Id */ public final class MlBayesIm implements BayesIm { - /** - * Inidicates that new rows in this BayesIm should be initialized as unknowns, forcing them 
to be specified - * manually. This is the default. - */ - public static final int MANUAL = 0; - /** - * Indicates that new rows in this BayesIm should be initialized randomly. - */ + private enum CptMapType { + COUNT_MAP, PROB_MAP + } + public enum InitializationMethod { + MANUAL, RANDOM + } public static final int RANDOM = 1; @Serial private static final long serialVersionUID = 23L; @@ -98,7 +96,7 @@ public final class MlBayesIm implements BayesIm { * A flag indicating whether to use CptMaps or not. If true, CptMaps are used; if false, the probs array is used. * The CptMap is the new way of storing the probabilities; the probs array is kept here for backward compatibility. */ - boolean useCptMaps = true; + private CptMapType cptMapType = null; /** * The list of parents for each node from the graph. Order or nodes corresponds to the order of nodes in 'nodes', * and order in subarrays is important. @@ -125,7 +123,7 @@ public final class MlBayesIm implements BayesIm { * The array of CPT maps for each node. The index of the node corresponds to the index of the probability map in * this array. Replaces the probs array. */ - private CptMap[] probMatrices; + private CptMapProbs[] probMatrices; /** * Constructs a new BayesIm from the given BayesPm, initializing all values as Double.NaN ("?"). @@ -135,7 +133,7 @@ public final class MlBayesIm implements BayesIm { * contained in the bayes parametric model provided. */ public MlBayesIm(BayesPm bayesPm) throws IllegalArgumentException { - this(bayesPm, null, MlBayesIm.MANUAL); + this(bayesPm, null, InitializationMethod.MANUAL); } /** @@ -148,7 +146,7 @@ public MlBayesIm(BayesPm bayesPm) throws IllegalArgumentException { * @throws java.lang.IllegalArgumentException if the array of nodes provided is not a permutation of the nodes * contained in the bayes parametric model provided. 
*/ - public MlBayesIm(BayesPm bayesPm, int initializationMethod) + public MlBayesIm(BayesPm bayesPm, InitializationMethod initializationMethod) throws IllegalArgumentException { this(bayesPm, null, initializationMethod); } @@ -167,7 +165,7 @@ public MlBayesIm(BayesPm bayesPm, int initializationMethod) * contained in the bayes parametric model provided. */ public MlBayesIm(BayesPm bayesPm, BayesIm oldBayesIm, - int initializationMethod) throws IllegalArgumentException { + InitializationMethod initializationMethod) throws IllegalArgumentException { if (bayesPm == null) { throw new NullPointerException("BayesPm must not be null."); } @@ -207,7 +205,7 @@ public MlBayesIm(BayesIm bayesIm) throws IllegalArgumentException { } // Copy all the old values over. - initialize(bayesIm, MlBayesIm.MANUAL); + initialize(bayesIm, InitializationMethod.MANUAL); } /** @@ -366,7 +364,7 @@ public List getVariableNames() { * @return the number of columns. */ public int getNumColumns(int nodeIndex) { - if (useCptMaps) { + if (cptMapType == CptMapType.PROB_MAP) { return probMatrices[nodeIndex].getNumColumns(); } else { return this.probs[nodeIndex][0].length; @@ -380,7 +378,7 @@ public int getNumColumns(int nodeIndex) { * @return the number of rows in the node. */ public int getNumRows(int nodeIndex) { - if (useCptMaps) { + if (cptMapType == CptMapType.PROB_MAP) { return probMatrices[nodeIndex].getNumRows(); } else { return this.probs[nodeIndex].length; @@ -486,7 +484,7 @@ public int getParentValue(int nodeIndex, int rowIndex, int colIndex) { * @return the probability value for the given node. 
*/ public double getProbability(int nodeIndex, int rowIndex, int colIndex) { - if (useCptMaps) { + if (cptMapType == CptMapType.PROB_MAP) { return probMatrices[nodeIndex].get(rowIndex, colIndex); } else { return this.probs[nodeIndex][rowIndex][colIndex]; @@ -572,8 +570,8 @@ public void normalizeRow(int nodeIndex, int rowIndex) { */ @Override public void setProbability(int nodeIndex, double[][] probMatrix) { - if (useCptMaps) { - probMatrices[nodeIndex] = new CptMap(probMatrix); + if (cptMapType == CptMapType.PROB_MAP) { + probMatrices[nodeIndex] = new CptMapProbs(probMatrix); } else { for (int i = 0; i < probMatrix.length; i++) { System.arraycopy(probMatrix[i], 0, this.probs[nodeIndex][i], 0, probMatrix[i].length); @@ -604,7 +602,7 @@ public void setProbability(int nodeIndex, int rowIndex, int colIndex, + "between 0.0 and 1.0 or Double.NaN."); } - if (useCptMaps) { + if (cptMapType == CptMapType.PROB_MAP) { probMatrices[nodeIndex].set(rowIndex, colIndex, value); } else { this.probs[nodeIndex][rowIndex][colIndex] = value; @@ -646,7 +644,7 @@ public void randomizeRow(int nodeIndex, int rowIndex) { int size = getNumColumns(nodeIndex); double[] row = getRandomWeights(size); - if (useCptMaps) { + if (cptMapType == CptMapType.PROB_MAP) { for (int colIndex = 0; colIndex < size; colIndex++) { probMatrices[nodeIndex].set(rowIndex, colIndex, row[colIndex]); } @@ -1097,11 +1095,11 @@ public String toString() { * @see #initializeNode * @see #randomizeRow */ - private void initialize(BayesIm oldBayesIm, int initializationMethod) { + private void initialize(BayesIm oldBayesIm, InitializationMethod initializationMethod) { this.parents = new int[this.nodes.length][]; this.parentDims = new int[this.nodes.length][]; this.probs = new double[this.nodes.length][][]; - this.probMatrices = new CptMap[this.nodes.length]; + this.probMatrices = new CptMapProbs[this.nodes.length]; for (int nodeIndex = 0; nodeIndex < this.nodes.length; nodeIndex++) { initializeNode(nodeIndex, oldBayesIm, 
initializationMethod); @@ -1112,7 +1110,7 @@ private void initialize(BayesIm oldBayesIm, int initializationMethod) { * This method initializes the node indicated. */ private void initializeNode(int nodeIndex, BayesIm oldBayesIm, - int initializationMethod) { + InitializationMethod initializationMethod) { Node node = this.nodes[nodeIndex]; // Set up parents array. Should store the parents of @@ -1149,10 +1147,10 @@ private void initializeNode(int nodeIndex, BayesIm oldBayesIm, this.parentDims[nodeIndex] = dims; this.probs[nodeIndex] = new double[numRows][numCols]; - this.probMatrices[nodeIndex] = new CptMap(numRows, numCols); + this.probMatrices[nodeIndex] = new CptMapProbs(numRows, numCols); // Initialize each row. - if (initializationMethod == MlBayesIm.RANDOM) { + if (initializationMethod == InitializationMethod.RANDOM) { randomizeTable(nodeIndex); } else { for (int rowIndex = 0; rowIndex < numRows; rowIndex++) { @@ -1167,10 +1165,10 @@ private void initializeNode(int nodeIndex, BayesIm oldBayesIm, } private void overwriteRow(int nodeIndex, int rowIndex, - int initializationMethod) { - if (initializationMethod == MlBayesIm.RANDOM) { + InitializationMethod initializationMethod) { + if (initializationMethod == InitializationMethod.RANDOM) { randomizeRow(nodeIndex, rowIndex); - } else if (initializationMethod == MlBayesIm.MANUAL) { + } else if (initializationMethod == InitializationMethod.MANUAL) { initializeRowAsUnknowns(nodeIndex, rowIndex); } else { throw new IllegalArgumentException("Unrecognized state."); @@ -1182,7 +1180,7 @@ private void initializeRowAsUnknowns(int nodeIndex, int rowIndex) { double[] row = new double[size]; Arrays.fill(row, Double.NaN); - if (useCptMaps) { + if (cptMapType == CptMapType.PROB_MAP) { probMatrices[nodeIndex].assignRow(rowIndex, new Vector(row)); } else { this.probs[nodeIndex][rowIndex] = row; @@ -1193,7 +1191,7 @@ private void initializeRowAsUnknowns(int nodeIndex, int rowIndex) { * This method initializes the node indicated. 
*/ private void retainOldRowIfPossible(int nodeIndex, int rowIndex, - BayesIm oldBayesIm, int initializationMethod) { + BayesIm oldBayesIm, InitializationMethod initializationMethod) { int oldNodeIndex = getCorrespondingNodeIndex(nodeIndex, oldBayesIm); @@ -1341,15 +1339,15 @@ private void readObject(ObjectInputStream s) * Note: This method should only be called after the `probs` array has been properly initialized. */ private void copyDataToProbMatrices() { - if (!this.useCptMaps && this.probs != null && this.probs.length == this.nodes.length) { - this.probMatrices = new CptMap[this.probs.length]; + if (cptMapType == null && this.probs != null && this.probs.length == this.nodes.length) { + this.probMatrices = new CptMapProbs[this.probs.length]; for (int i = 0; i < this.nodes.length; i++) { - probMatrices[i] = new CptMap(this.probs[i]); + probMatrices[i] = new CptMapProbs(this.probs[i]); } this.probs = null; - this.useCptMaps = true; + this.cptMapType = CptMapType.PROB_MAP; } } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/RowSummingExactUpdater.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/RowSummingExactUpdater.java index 1f0bf27624..87862b3101 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/RowSummingExactUpdater.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/RowSummingExactUpdater.java @@ -346,7 +346,7 @@ private void updateAll() { private BayesIm createdUpdatedBayesIm(BayesPm updatedBayesPm) { // Switching this to MANUAL since the initial values don't matter. 
- return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.MANUAL); + return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.InitializationMethod.MANUAL); } private BayesPm createUpdatedBayesPm(Dag updatedGraph) { diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRobustCompare.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRobustCompare.java index eb6a390159..3a875f084d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRobustCompare.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/simulation/HsimRobustCompare.java @@ -65,7 +65,7 @@ public static List run(int numVars, double edgesPerNode, int numCases, Graph odag = RandomGraph.randomGraphRandomForwardEdges(vars, 0, numEdges, 30, 15, 15, false, true); BayesPm bayesPm = new BayesPm(odag, 2, 2); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); //oData is the original data set, and odag is the original dag. 
DataSet oData = bayesIm.simulateData(numCases, false); //System.out.println(oData); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison.java index 1e638dcb6a..28fb2e581d 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison.java @@ -112,7 +112,7 @@ public static ComparisonResult compare(ComparisonParameters params) { } BayesPm pm = new BayesPm(trueDag, 3, 3); - MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM); + MlBayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM); dataSet = im.simulateData(params.getSampleSize(), false, tiers); } else { throw new IllegalArgumentException("Unrecognized data type."); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison2.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison2.java index a56c2b9f14..6d2e122dd2 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison2.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/Comparison2.java @@ -303,7 +303,7 @@ public static ComparisonResult compare(ComparisonParameters params) { } BayesPm pm = new BayesPm(trueDag, 3, 3); - MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM); + MlBayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM); dataSet = im.simulateData(params.getSampleSize(), false, tiers); } else { throw new IllegalArgumentException("Unrecognized data type."); diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/PerformanceTests.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/PerformanceTests.java index 5ff07697bd..d353a1e4cb 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/PerformanceTests.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/study/performance/PerformanceTests.java @@ -1026,7 
+1026,7 @@ private void testFges(int numVars, double edgeFactor, int numCases, int numRuns, } else { BayesPm pm = new BayesPm(dag, 3, 3); - MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM); + MlBayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM); DataSet data = im.simulateData(numCases, false, tiers); @@ -1258,7 +1258,7 @@ private void testFgesMb(int numVars, double edgeFactor, int numCases, int numRun this.out.println(new Date()); BayesPm pm = new BayesPm(dag, 3, 3); - MlBayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM); + MlBayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM); DataSet data = im.simulateData(numCases, false, tiers); diff --git a/tetrad-lib/src/main/java/edu/pitt/isp/sverchkov/data/AdTreeTest.java b/tetrad-lib/src/main/java/edu/pitt/isp/sverchkov/data/AdTreeTest.java index 6588f2ccf9..06e0341818 100644 --- a/tetrad-lib/src/main/java/edu/pitt/isp/sverchkov/data/AdTreeTest.java +++ b/tetrad-lib/src/main/java/edu/pitt/isp/sverchkov/data/AdTreeTest.java @@ -72,7 +72,7 @@ public static void main(String[] args) throws Exception { Graph graph = RandomGraph.randomGraphRandomForwardEdges(variables, 0, numEdges, 30, 15, 15, false, true); BayesPm pm = new BayesPm(graph); - BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM); + BayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM); DataSet data = im.simulateData(rows, false); // This implementation uses a DataTable to represent the data diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesDiscreteBicScorer.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesDiscreteBicScorer.java index a24ce17cbd..b0cde7abc8 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesDiscreteBicScorer.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesDiscreteBicScorer.java @@ -53,7 +53,7 @@ public void testPValue() { final int numCategories = 8; BayesPm pm = new BayesPm(graph, numCategories, numCategories); - BayesIm im = new 
MlBayesIm(pm, MlBayesIm.RANDOM); + BayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM); DataSet data = im.simulateData(1000, false); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesIm.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesIm.java index 699ddfe42d..fa8f6c7778 100755 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesIm.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesIm.java @@ -66,7 +66,7 @@ public void testCopyConstructor() { Graph graph = GraphUtils.convert("X1-->X2,X1-->X3,X2-->X4,X3-->X4"); Dag dag = new Dag(graph); BayesPm bayesPm = new BayesPm(dag); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); BayesIm bayesIm2 = new MlBayesIm(bayesIm); assertEquals(bayesIm, bayesIm2); } @@ -102,9 +102,9 @@ public void testAddRemoveParent() { dag.addDirectedEdge(a, b); BayesPm bayesPm = new BayesPm(dag); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); - BayesIm bayesIm2 = new MlBayesIm(bayesPm, bayesIm, MlBayesIm.MANUAL); + BayesIm bayesIm2 = new MlBayesIm(bayesPm, bayesIm, MlBayesIm.InitializationMethod.MANUAL); assertEquals(bayesIm, bayesIm2); @@ -113,7 +113,7 @@ public void testAddRemoveParent() { dag.addDirectedEdge(c, b); BayesPm bayesPm3 = new BayesPm(dag, bayesPm); - BayesIm bayesIm3 = new MlBayesIm(bayesPm3, bayesIm2, MlBayesIm.MANUAL); + BayesIm bayesIm3 = new MlBayesIm(bayesPm3, bayesIm2, MlBayesIm.InitializationMethod.MANUAL); // Make sure the rows got repeated. 
// assertTrue(rowsEqual(bayesIm3, bayesIm3.getNodeIndex(b), 0, 1)); @@ -125,7 +125,7 @@ public void testAddRemoveParent() { dag.removeNode(c); BayesPm bayesPm4 = new BayesPm(dag, bayesPm3); - BayesIm bayesIm4 = new MlBayesIm(bayesPm4, bayesIm3, MlBayesIm.MANUAL); + BayesIm bayesIm4 = new MlBayesIm(bayesPm4, bayesIm3, MlBayesIm.InitializationMethod.MANUAL); // Make sure the 'b' node has 2 rows of '?'s'. assertTrue(bayesIm4.getNumRows(bayesIm4.getNodeIndex(b)) == 2); @@ -157,17 +157,17 @@ public void testAddRemoveValues() { assertTrue(Edges.isDirectedEdge(dag.getEdge(a, b))); BayesPm bayesPm = new BayesPm(dag, 3, 3); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); bayesPm.setNumCategories(a, 4); bayesPm.setNumCategories(c, 4); - BayesIm bayesIm2 = new MlBayesIm(bayesPm, bayesIm, MlBayesIm.MANUAL); + BayesIm bayesIm2 = new MlBayesIm(bayesPm, bayesIm, MlBayesIm.InitializationMethod.MANUAL); bayesPm.setNumCategories(a, 2); - BayesIm bayesIm3 = new MlBayesIm(bayesPm, bayesIm2, MlBayesIm.MANUAL); + BayesIm bayesIm3 = new MlBayesIm(bayesPm, bayesIm2, MlBayesIm.InitializationMethod.MANUAL); bayesPm.setNumCategories(b, 2); - BayesIm bayesIm4 = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm4 = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); // At this point, a has 2 categories, b has 2 categories, and c has 4 categories. 
diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesXml.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesXml.java index 1a43189e17..e1155ced9d 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesXml.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestBayesXml.java @@ -112,7 +112,7 @@ private static BayesIm sampleBayesIm2() { BayesPm bayesPm = new BayesPm(graph); bayesPm.setNumCategories(b, 3); - return new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + return new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); } private static BayesIm sampleBayesIm3() { @@ -138,7 +138,7 @@ private static BayesIm sampleBayesIm3() { BayesPm bayesPm = new BayesPm(graph); bayesPm.setNumCategories(b, 3); - return new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + return new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); } /** diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCellProbabilities.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCellProbabilities.java index 66aa240bae..afebcb340e 100755 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCellProbabilities.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCellProbabilities.java @@ -73,7 +73,7 @@ public void testCreateUsingBayesIm() { Graph graph = GraphUtils.convert("X1-->X2,X1-->X3,X2-->X4,X3-->X4"); Dag dag = new Dag(graph); BayesPm bayesPm = new BayesPm(dag); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); StoredCellProbs cellProbs = StoredCellProbs.createCellTable(bayesIm); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCptInvariantUpdater.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCptInvariantUpdater.java index 19d0b64e12..fde6221b59 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCptInvariantUpdater.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCptInvariantUpdater.java @@ -176,7 
+176,7 @@ public void testUpdate4() { graph.addDirectedEdge(x2Node, x3Node); BayesPm bayesPm = new BayesPm(graph, 2, 2); - MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); int x2 = bayesIm.getNodeIndex(x2Node); int x3 = bayesIm.getNodeIndex(x3Node); @@ -220,7 +220,7 @@ public void testUpdate5() { graph.addDirectedEdge(x4Node, x2Node); BayesPm bayesPm = new BayesPm(graph); - MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); int x1 = bayesIm.getNodeIndex(x1Node); int x2 = bayesIm.getNodeIndex(x2Node); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDataSetCellProbs.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDataSetCellProbs.java index 3542df3d69..c4ec3b3ddc 100755 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDataSetCellProbs.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDataSetCellProbs.java @@ -41,7 +41,7 @@ public void testCreateUsingBayesIm() { Graph graph = GraphUtils.convert("X1-->X2,X1-->X3,X2-->X4,X3-->X4"); Dag dag = new Dag(graph); BayesPm bayesPm = new BayesPm(dag); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); BayesImProbs bayesImProbs = new BayesImProbs(bayesIm); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDiscreteProbs.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDiscreteProbs.java index 5b470d4272..e160b8d476 100755 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDiscreteProbs.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestDiscreteProbs.java @@ -79,7 +79,7 @@ public void testCreateUsingBayesIm() { Graph graph = GraphUtils.convert("X1-->X2,X1-->X3,X2-->X4,X3-->X4"); Dag dag = new Dag(graph); BayesPm bayesPm = new BayesPm(dag); - BayesIm bayesIm = new 
MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); StoredCellProbs cellProbs = StoredCellProbs.createCellTable(bayesIm); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java index e3e6d25416..5730004833 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFges.java @@ -227,7 +227,7 @@ public void explore2() { 30, 15, 15, false, true); BayesPm pm = new BayesPm(dag, 2, 3); - BayesIm im = new MlBayesIm(pm, MlBayesIm.RANDOM); + BayesIm im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM); DataSet data = im.simulateData(numCases, false); BdeScore score = new BdeScore(data); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java index 46ce808a42..848af1aa3d 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestGFci.java @@ -248,7 +248,7 @@ public void testRandomDiscreteData() { Graph g = GraphUtils.convert("X1-->X2,X1-->X3,X1-->X4,X2-->X3,X2-->X4,X3-->X4"); Dag dag = new Dag(g); BayesPm bayesPm = new BayesPm(dag); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); DataSet data = bayesIm.simulateData(sampleSize, false); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestHistogram.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestHistogram.java index ee3403d4a5..f36f1ed4a1 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestHistogram.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestHistogram.java @@ -90,7 +90,7 @@ public void testHistogram() { // Discrete BayesPm bayesPm = new BayesPm(trueGraph); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm 
bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); DataSet data2 = bayesIm.simulateData(sampleSize, false); // For some reason these are giving different diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRfciBsc.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRfciBsc.java index 70ba06d08a..7b10b446cc 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRfciBsc.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRfciBsc.java @@ -81,7 +81,7 @@ public void testRandomDiscreteData() { // set a number of latent variables BayesPm bayesPm = new BayesPm(dag); - BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); RandomUtil.getInstance().setSeed(seed); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRowSummingUpdater.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRowSummingUpdater.java index fd9ec8294f..ac2a354859 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRowSummingUpdater.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestRowSummingUpdater.java @@ -197,7 +197,7 @@ public void testUpdate4() { graph.addDirectedEdge(x2Node, x3Node); BayesPm bayesPm = new BayesPm(graph); - MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); int x2 = bayesIm.getNodeIndex(x2Node); int x3 = bayesIm.getNodeIndex(x3Node); @@ -240,7 +240,7 @@ public void testUpdate5() { graph.addDirectedEdge(x4Node, x2Node); BayesPm bayesPm = new BayesPm(graph); - MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); int x1 = bayesIm.getNodeIndex(x1Node); int x2 = bayesIm.getNodeIndex(x2Node); diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestUpdatedBayesIm.java 
b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestUpdatedBayesIm.java index 74a221e161..dbda137037 100755 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestUpdatedBayesIm.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestUpdatedBayesIm.java @@ -59,7 +59,7 @@ public void testCompound() { graph.addDirectedEdge(x4Node, x2Node); BayesPm bayesPm = new BayesPm(graph); - MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); + MlBayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM); UpdatedBayesIm updatedIm1 = new UpdatedBayesIm(bayesIm); assertEquals(bayesIm, updatedIm1); From 8085291dc489992a6201eb001d2364f3e28b6404 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Sun, 31 Mar 2024 12:45:53 -0400 Subject: [PATCH 22/23] Refactor initialization methods in Bayes' methods Updated the usage of initialization methods in various Bayes related classes, specifically changing from "MlBayesIm.RANDOM" and "MlBayesIm.MANUAL" to "MlBayesIm.InitializationMethod.RANDOM" and "MlBayesIm.InitializationMethod.MANUAL". Created new files for CptMapProbs and CptMapCounts as part of the refactoring. This update helps make the code more readable and consistent across different parts of the application. 
--- .../editor/BayesImNodeEditingTable.java | 12 +- .../java/edu/cmu/tetrad/bayes/BayesIm.java | 4 + .../edu/cmu/tetrad/bayes/CptMapCounts.java | 20 +++ .../cmu/tetrad/bayes/MlBayesEstimator.java | 121 +++++++-------- .../java/edu/cmu/tetrad/bayes/MlBayesIm.java | 146 +++++++++++++++--- 5 files changed, 214 insertions(+), 89 deletions(-) diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/BayesImNodeEditingTable.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/BayesImNodeEditingTable.java index 10825b1c7b..5999e8f784 100755 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/BayesImNodeEditingTable.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/BayesImNodeEditingTable.java @@ -23,6 +23,7 @@ import edu.cmu.tetrad.bayes.BayesIm; import edu.cmu.tetrad.bayes.BayesPm; +import edu.cmu.tetrad.bayes.MlBayesIm; import edu.cmu.tetrad.graph.Node; import edu.cmu.tetrad.util.JOptionUtils; import edu.cmu.tetrad.util.NumberFormatUtil; @@ -537,8 +538,9 @@ public Object getValueAt(int tableRow, int tableCol) { int colIndex = tableCol - parentVals.length; if (colIndex < getBayesIm().getNumColumns(getNodeIndex())) { - return getBayesIm().getProbability(getNodeIndex(), tableRow, + double probability = getBayesIm().getProbability(getNodeIndex(), tableRow, colIndex); + return probability; } return "null"; @@ -556,10 +558,14 @@ public boolean isCellEditable(int row, int col) { * Sets the value of the cell at (row, col) to 'aValue'. 
*/ public void setValueAt(Object aValue, int row, int col) { + if (getBayesIm().getCptMapType() == MlBayesIm.CptMapType.COUNT_MAP) { + return; + } + int numParents = getBayesIm().getNumParents(getNodeIndex()); int colIndex = col - numParents; - if ("".equals(aValue) || aValue == null) { + if (getBayesIm().getCptMapType() == MlBayesIm.CptMapType.PROB_MAP && ("".equals(aValue) || aValue == null)) { getBayesIm().setProbability(getNodeIndex(), row, colIndex, Double.NaN); fireTableRowsUpdated(row, row); @@ -753,6 +759,8 @@ public void resetFailedRow() { public void resetFailedCol() { this.failedCol = -1; } + + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesIm.java index 7b1f1f7b62..a768c73798 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesIm.java @@ -393,4 +393,8 @@ public interface BayesIm extends VariableSource, Im, Simulator { * @return a string representation for this Bayes net. 
*/ String toString(); + + default MlBayesIm.CptMapType getCptMapType() { + return MlBayesIm.CptMapType.PROB_MAP; + } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java index ec5f81f431..9b4fa5dca7 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java @@ -123,6 +123,26 @@ public double get(int row, int column) { return cellCount / (double) rowCount; } + public void addCounts(int row, int column, int count) { + if (row < 0 || row >= numRows || column < 0 || column >= numColumns) { + throw new IllegalArgumentException("Row and column must be within bounds."); + } + + int key = row * numColumns + column; + + if (!cellCounts.containsKey(key)) { + cellCounts.put(key, count); + } else { + cellCounts.put(key, cellCounts.get(key) + count); + } + + if (!rowCounts.containsKey(row)) { + rowCounts.put(row, count); + } else { + rowCounts.put(row, rowCounts.get(row) + count); + } + } + /** * Returns the number of rows in the probability map. 
* diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimator.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimator.java index a24cd8ed4b..e19d76b85c 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimator.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesEstimator.java @@ -22,8 +22,11 @@ package edu.cmu.tetrad.bayes; import edu.cmu.tetrad.data.DataSet; +import edu.cmu.tetrad.graph.Graph; import edu.cmu.tetrad.graph.Node; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; /** @@ -58,75 +61,63 @@ public BayesIm estimate(BayesPm bayesPm, DataSet dataSet) { throw new NullPointerException(); } -// if (DataUtils.containsMissingValue(dataSet)) { -// throw new IllegalArgumentException("Please remove or impute missing values."); -// } - - // Make sure all of the variables in the PM are in the data set; - // otherwise, estimation is impossible. - BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet); - - // Create a new Bayes IM to store the estimated values. - BayesIm estimatedIm = new MlBayesIm(bayesPm); - - // Create a subset of the data set with the variables of the IM, in - // the order of the IM. - List variables = estimatedIm.getVariables(); - DataSet columnDataSet2 = dataSet.subsetColumns(variables); - DiscreteProbs discreteProbs = new DataSetProbs(columnDataSet2); - - // We will use the same estimation methods as the updaters, to ensure - // compatibility. 
- Proposition assertion = Proposition.tautology(estimatedIm); - Proposition condition = Proposition.tautology(estimatedIm); - Evidence evidence2 = Evidence.tautology(estimatedIm); - - int numNodes = estimatedIm.getNumNodes(); - - for (int node = 0; node < numNodes; node++) { - int numRows = estimatedIm.getNumRows(node); - int numCols = estimatedIm.getNumColumns(node); - int[] parents = estimatedIm.getParents(node); - - for (int row = 0; row < numRows; row++) { - int[] parentValues = estimatedIm.getParentValues(node, row); - - for (int col = 0; col < numCols; col++) { - - // Remove values from the proposition in various ways; if - // a combination exists in the end, calculate a conditional - // probability. - assertion.setToTautology(); - condition.setToTautology(); - - for (int i = 0; i < numNodes; i++) { - for (int j = 0; j < evidence2.getNumCategories(i); j++) { - if (!evidence2.getProposition().isAllowed(i, j)) { - condition.removeCategory(i, j); - } - } - } - - assertion.disallowComplement(node, col); - - for (int k = 0; k < parents.length; k++) { - condition.disallowComplement(parents[k], parentValues[k]); - } - - if (condition.existsCombination()) { - double p = discreteProbs.getConditionalProb(assertion, condition); -// if (Double.isNaN(p)) p = 1.0 / numCols; - estimatedIm.setProbability(node, row, col, p); - } else { - estimatedIm.setProbability(node, row, col, Double.NaN); - } + MlBayesIm im = new MlBayesIm(bayesPm, true); + + // Get the nodes from the BayesPm. This fixes the order of the nodes + // in the BayesIm, independently of any change to the BayesPm. + // (This order must be maintained.) + Graph graph = bayesPm.getDag(); + + for (int nodeIndex = 0; nodeIndex < im.getNumNodes(); nodeIndex++) { + Node node = im.getNode(nodeIndex); + + // Set up parents array. Should store the parents of + // each node as ints in a particular order. 
+ List parentList = new ArrayList<>(graph.getParents(node)); + int[] parentArray = new int[parentList.size()]; + + for (int i = 0; i < parentList.size(); i++) { + parentArray[i] = im.getNodeIndex(parentList.get(i)); + } + + // Sort the parent array. + Arrays.sort(parentArray); + + // Setup dimensions array for parents. + int[] dims = new int[parentArray.length]; + + for (int i = 0; i < dims.length; i++) { + Node parNode = im.getNode(parentArray[i]); + dims[i] = bayesPm.getNumCategories(parNode); + } + + // Calculate dimensions of table. + int numRows = 1; + + for (int dim : dims) { + numRows *= dim; + } + + int numCols = bayesPm.getNumCategories(node); + + CptMapCounts counts = new CptMapCounts(numRows, numCols); + + for (int row = 0; row < dataSet.getNumRows(); row++) { + int[] parentValues = new int[parentArray.length]; + + for (int i = 0; i < parentValues.length; i++) { + parentValues[i] = dataSet.getInt(row, parentArray[i]); } + + int value = dataSet.getInt(row, nodeIndex); + + counts.addCounts(im.getRowIndex(nodeIndex, parentValues), value, 1); } - } -// System.out.println(estimatedIm); + im.setCountMap(nodeIndex, counts); + } - return estimatedIm; + return im; } } diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java index b662778ded..f828e231af 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/MlBayesIm.java @@ -67,12 +67,6 @@ * @version $Id: $Id */ public final class MlBayesIm implements BayesIm { - private enum CptMapType { - COUNT_MAP, PROB_MAP - } - public enum InitializationMethod { - MANUAL, RANDOM - } public static final int RANDOM = 1; @Serial private static final long serialVersionUID = 23L; @@ -92,10 +86,6 @@ public enum InitializationMethod { * The array of nodes from the graph. Order is important. */ private final Node[] nodes; - /** - * A flag indicating whether to use CptMaps or not. 
If true, CptMaps are used; if false, the probs array is used. - * The CptMap is the new way of storing the probabilities; the probs array is kept here for backward compatibility. - */ private CptMapType cptMapType = null; /** * The list of parents for each node from the graph. Order or nodes corresponds to the order of nodes in 'nodes', @@ -123,7 +113,7 @@ public enum InitializationMethod { * The array of CPT maps for each node. The index of the node corresponds to the index of the probability map in * this array. Replaces the probs array. */ - private CptMapProbs[] probMatrices; + private CptMap[] probMatrices; /** * Constructs a new BayesIm from the given BayesPm, initializing all values as Double.NaN ("?"). @@ -151,11 +141,85 @@ public MlBayesIm(BayesPm bayesPm, InitializationMethod initializationMethod) this(bayesPm, null, initializationMethod); } + /** + * Constructs an instance of MlBayesIm. + * + * @param bayesPm the BayesPm object that represents the Bayesian network. + * @param countsOnly should be set to true for this constructor. + * @throws IllegalArgumentException if countsOnly is false. + * @throws NullPointerException if bayesPm is null. + */ + public MlBayesIm(BayesPm bayesPm, boolean countsOnly) { + if (!countsOnly) { + throw new IllegalArgumentException("countsOnly must be true for this constructor."); + } + + if (bayesPm == null) { + throw new NullPointerException("BayesPm must not be null."); + } + + + + this.bayesPm = new BayesPm(bayesPm); + + // Get the nodes from the BayesPm. This fixes the order of the nodes + // in the BayesIm, independently of any change to the BayesPm. + // (This order must be maintained.) 
+ Graph graph = bayesPm.getDag(); + this.nodes = graph.getNodes().toArray(new Node[0]); + + this.cptMapType = CptMapType.COUNT_MAP; + + this.parents = new int[this.nodes.length][]; + this.parentDims = new int[this.nodes.length][]; + this.probs = new double[this.nodes.length][][]; + this.probMatrices = new CptMapCounts[this.nodes.length]; + + for (int nodeIndex = 0; nodeIndex < this.nodes.length; nodeIndex++) { + Node node = this.nodes[nodeIndex]; + + // Set up parents array. Should store the parents of + // each node as ints in a particular order. + List parentList = new ArrayList<>(graph.getParents(node)); + int[] parentArray = new int[parentList.size()]; + + for (int i = 0; i < parentList.size(); i++) { + parentArray[i] = getNodeIndex(parentList.get(i)); + } + + // Sort parent array. + Arrays.sort(parentArray); + + this.parents[nodeIndex] = parentArray; + + // Setup dimensions array for parents. + int[] dims = new int[parentArray.length]; + + for (int i = 0; i < dims.length; i++) { + Node parNode = this.nodes[parentArray[i]]; + dims[i] = getBayesPm().getNumCategories(parNode); + } + + // Calculate dimensions of table. + int numRows = 1; + + for (int dim : dims) { + numRows *= dim; + } + + int numCols = getBayesPm().getNumCategories(node); + + this.parentDims[nodeIndex] = dims; + this.probs[nodeIndex] = new double[numRows][numCols]; + this.probMatrices[nodeIndex] = new CptMapCounts(numRows, numCols); + } + } + /** * Constructs a new BayesIm from the given BayesPm, initializing values either as MANUAL or RANDOM, but using values * from the old BayesIm provided where posssible. If initialized manually, all values that cannot be retrieved from * oldBayesIm will be set to Double.NaN ("?") in each such row; if initialized randomly, all values that cannot be - * retrieved from oldBayesIm will distributed randomly in each such row. + * retrieved from oldBayesIm will be distributed randomly in each such row. * * @param bayesPm the given Bayes PM. 
Carries with it the underlying graph model. * @param oldBayesIm an already-constructed BayesIm whose values may be used where possible to initialize @@ -217,8 +281,6 @@ public static MlBayesIm serializableInstance() { return new MlBayesIm(BayesPm.serializableInstance()); } - //===============================PUBLIC METHODS========================// - /** *

    <p>getParameterNames.</p>

    * @@ -253,6 +315,8 @@ private static double[] getRandomWeights(int size) { return row; } + //===============================PUBLIC METHODS========================// + /** *

    <p>Getter for the field <code>bayesPm</code>.</p>

    * @@ -364,7 +428,7 @@ public List getVariableNames() { * @return the number of columns. */ public int getNumColumns(int nodeIndex) { - if (cptMapType == CptMapType.PROB_MAP) { + if (cptMapType != null) { return probMatrices[nodeIndex].getNumColumns(); } else { return this.probs[nodeIndex][0].length; @@ -378,7 +442,7 @@ public int getNumColumns(int nodeIndex) { * @return the number of rows in the node. */ public int getNumRows(int nodeIndex) { - if (cptMapType == CptMapType.PROB_MAP) { + if (cptMapType != null) { return probMatrices[nodeIndex].getNumRows(); } else { return this.probs[nodeIndex].length; @@ -484,7 +548,7 @@ public int getParentValue(int nodeIndex, int rowIndex, int colIndex) { * @return the probability value for the given node. */ public double getProbability(int nodeIndex, int rowIndex, int colIndex) { - if (cptMapType == CptMapType.PROB_MAP) { + if (cptMapType != null) { return probMatrices[nodeIndex].get(rowIndex, colIndex); } else { return this.probs[nodeIndex][rowIndex][colIndex]; @@ -572,6 +636,8 @@ public void normalizeRow(int nodeIndex, int rowIndex) { public void setProbability(int nodeIndex, double[][] probMatrix) { if (cptMapType == CptMapType.PROB_MAP) { probMatrices[nodeIndex] = new CptMapProbs(probMatrix); + } else if (cptMapType == CptMapType.COUNT_MAP) { + throw new IllegalArgumentException("Cannot set probability matrix for an estimated Bayes IM."); } else { for (int i = 0; i < probMatrix.length; i++) { System.arraycopy(probMatrix[i], 0, this.probs[nodeIndex][i], 0, probMatrix[i].length); @@ -579,6 +645,14 @@ public void setProbability(int nodeIndex, double[][] probMatrix) { } } + public void setCountMap(int nodeIndex, CptMapCounts countMap) { + if (cptMapType == CptMapType.COUNT_MAP) { + probMatrices[nodeIndex] = countMap; + } else { + throw new IllegalArgumentException("Cannot set count map for a non-estimated Bayes IM."); + } + } + /** * Sets the probability value for a specific node, row, and column in the probability table. 
* @@ -603,7 +677,9 @@ public void setProbability(int nodeIndex, int rowIndex, int colIndex, } if (cptMapType == CptMapType.PROB_MAP) { - probMatrices[nodeIndex].set(rowIndex, colIndex, value); + ((CptMapProbs) probMatrices[nodeIndex]).set(rowIndex, colIndex, value); + } else if (cptMapType == CptMapType.COUNT_MAP) { + throw new IllegalArgumentException("Cannot set probability value for an estimated Bayes IM."); } else { this.probs[nodeIndex][rowIndex][colIndex] = value; } @@ -646,8 +722,10 @@ public void randomizeRow(int nodeIndex, int rowIndex) { if (cptMapType == CptMapType.PROB_MAP) { for (int colIndex = 0; colIndex < size; colIndex++) { - probMatrices[nodeIndex].set(rowIndex, colIndex, row[colIndex]); + ((CptMapProbs) probMatrices[nodeIndex]).set(rowIndex, colIndex, row[colIndex]); } + } else if (cptMapType == CptMapType.COUNT_MAP) { + throw new IllegalArgumentException("Cannot randomize row for an estimated Bayes IM."); } else { this.probs[nodeIndex][rowIndex] = row; } @@ -1041,8 +1119,6 @@ public boolean equals(Object o) { return true; } - //=============================PRIVATE METHODS=======================// - /** * Prints out the probability table for each variable. * @@ -1101,16 +1177,24 @@ private void initialize(BayesIm oldBayesIm, InitializationMethod initializationM this.probs = new double[this.nodes.length][][]; this.probMatrices = new CptMapProbs[this.nodes.length]; + this.cptMapType = CptMapType.PROB_MAP; + for (int nodeIndex = 0; nodeIndex < this.nodes.length; nodeIndex++) { initializeNode(nodeIndex, oldBayesIm, initializationMethod); } } + //=============================PRIVATE METHODS=======================// + /** * This method initializes the node indicated. 
*/ private void initializeNode(int nodeIndex, BayesIm oldBayesIm, InitializationMethod initializationMethod) { + if (cptMapType != CptMapType.PROB_MAP) { + throw new IllegalArgumentException("Cannot initialize a Bayes IM randomly without a probability map."); + } + Node node = this.nodes[nodeIndex]; // Set up parents array. Should store the parents of @@ -1181,7 +1265,9 @@ private void initializeRowAsUnknowns(int nodeIndex, int rowIndex) { Arrays.fill(row, Double.NaN); if (cptMapType == CptMapType.PROB_MAP) { - probMatrices[nodeIndex].assignRow(rowIndex, new Vector(row)); + ((CptMapProbs) probMatrices[nodeIndex]).assignRow(rowIndex, new Vector(row)); + } else if (cptMapType == CptMapType.COUNT_MAP) { + throw new IllegalArgumentException("Cannot initialize a row as unknowns in an estimated Bayes IM."); } else { this.probs[nodeIndex][rowIndex] = row; } @@ -1350,4 +1436,20 @@ private void copyDataToProbMatrices() { this.cptMapType = CptMapType.PROB_MAP; } } + + /** + * A flag indicating whether to use CptMaps or not. If true, CptMaps are used; if false, the probs array is used. + * The CptMap is the new way of storing the probabilities; the probs array is kept here for backward compatibility. + */ + public CptMapType getCptMapType() { + return cptMapType; + } + + public enum CptMapType { + PROB_MAP, COUNT_MAP + } + + public enum InitializationMethod { + MANUAL, RANDOM + } } From 3abb9c95f69584f91c61e8a79b2843ec2706d250 Mon Sep 17 00:00:00 2001 From: jdramsey Date: Sun, 31 Mar 2024 13:44:24 -0400 Subject: [PATCH 23/23] Refactor initialization methods in Bayes' methods Updated the usage of initialization methods in various Bayes related classes, specifically changing from "MlBayesIm.RANDOM" and "MlBayesIm.MANUAL" to "MlBayesIm.InitializationMethod.RANDOM" and "MlBayesIm.InitializationMethod.MANUAL". Created new files for CptMapProbs and CptMapCounts as part of the refactoring. 
This update helps make the code more readable and consistent across different parts of the application. --- .../src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java index 9b4fa5dca7..61f936b2ec 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMapCounts.java @@ -39,7 +39,7 @@ public class CptMapCounts implements CptMap { /** * The prior count for all cells. */ - private int priorCount = 0; + private int priorCount = 1; /** * Constructs a new probability map, a map from a unique integer index for a particular node to the probability of @@ -120,6 +120,8 @@ public double get(int row, int column) { int key = row * numColumns + column; int rowCount = rowCounts.getOrDefault(row, 0); int cellCount = cellCounts.getOrDefault(key, 0); + rowCount += priorCount * numColumns; + cellCount += priorCount; return cellCount / (double) rowCount; }