Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made a version of MlBayesim that doesn't store NaNs in the tables so that huge models can be estimated. #1750

Merged
merged 24 commits into from
Mar 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
6e16cc0
Add MlBayesImOld.java for representing directed acyclic graphs
jdramsey Mar 28, 2024
681f3aa
Introduce ProbMap and implement it in MlBayesIm
jdramsey Mar 28, 2024
dc4e455
Implement ProbMap for efficient probability storage
jdramsey Mar 28, 2024
a9da8ce
Add Javadoc comments and update matrix data managing method
jdramsey Mar 28, 2024
14ba1fe
Rename class ProbMap to CptMap and update methods
jdramsey Mar 28, 2024
9833529
Rename class ProbMap to CptMap and update methods
jdramsey Mar 28, 2024
5618508
Rename class ProbMap to CptMap and update methods
jdramsey Mar 28, 2024
acd6268
Correct grammar in MlBayesIm.java comments
jdramsey Mar 28, 2024
a6552a5
Refactor storage method in MlBayesIm.java and update comments
jdramsey Mar 28, 2024
27a6334
Enhance storage strategy and comments in CptMap.java
jdramsey Mar 28, 2024
39a4afd
Simplify CptMap class and its comments
jdramsey Mar 28, 2024
a7b9166
Refactor MlBayesIm storage method explanation
jdramsey Mar 28, 2024
701d15e
Remove simulation task from MlBayesIm
jdramsey Mar 28, 2024
3a1b0dd
Replace probability matrices with CptMaps in MlBayesIm
jdramsey Mar 28, 2024
279b5f8
Update code for efficient probability storage in MlBayesIm
jdramsey Mar 28, 2024
27a564d
Removing the original MlBayesIm code.
jdramsey Mar 28, 2024
86447cb
Refactor getVariableNames method in MlBayesIm
jdramsey Mar 28, 2024
ed5622c
Add null check in existsParameterizedConstructor method
jdramsey Mar 30, 2024
e9f274c
Refactor conditional checks and parameter names
jdramsey Mar 30, 2024
c10116c
Fix typographical error in Statistic.java comment
jdramsey Mar 31, 2024
2cd7ebb
Refactor initialization methods in Bayes' methods
jdramsey Mar 31, 2024
8085291
Refactor initialization methods in Bayes' methods
jdramsey Mar 31, 2024
3abb9c9
Refactor initialization methods in Bayes' methods
jdramsey Mar 31, 2024
854381d
Merge branch 'development' into joe_bayes_im
jdramsey Mar 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import edu.cmu.tetrad.bayes.BayesIm;
import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.MlBayesIm;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.JOptionUtils;
import edu.cmu.tetrad.util.NumberFormatUtil;
Expand Down Expand Up @@ -536,8 +537,9 @@ public Object getValueAt(int tableRow, int tableCol) {
int colIndex = tableCol - parentVals.length;

if (colIndex < getBayesIm().getNumColumns(getNodeIndex())) {
return getBayesIm().getProbability(getNodeIndex(), tableRow,
double probability = getBayesIm().getProbability(getNodeIndex(), tableRow,
colIndex);
return probability;
}

return "null";
Expand All @@ -555,10 +557,14 @@ public boolean isCellEditable(int row, int col) {
* Sets the value of the cell at (row, col) to 'aValue'.
*/
public void setValueAt(Object aValue, int row, int col) {
if (getBayesIm().getCptMapType() == MlBayesIm.CptMapType.COUNT_MAP) {
return;
}

int numParents = getBayesIm().getNumParents(getNodeIndex());
int colIndex = col - numParents;

if ("".equals(aValue) || aValue == null) {
if (getBayesIm().getCptMapType() == MlBayesIm.CptMapType.PROB_MAP && ("".equals(aValue) || aValue == null)) {
getBayesIm().setProbability(getNodeIndex(), row, colIndex,
Double.NaN);
fireTableRowsUpdated(row, row);
Expand Down Expand Up @@ -752,6 +758,8 @@ public void resetFailedRow() {
public void resetFailedCol() {
this.failedCol = -1;
}


}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ public void setEnabled(boolean enabled) {
* @return a boolean
*/
public boolean isRandomForward() {
return this.parameters.getBoolean("graphRandomFoward", true);
return this.parameters.getBoolean("graphRandomForward", true);
}

private void setRandomForward(boolean randomFoward) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.io.Serial;
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntBinaryOperator;

/**
* Wraps a Bayes Pm for use in the Tetrad application.
Expand Down Expand Up @@ -90,11 +91,11 @@ public BayesImWrapper(BayesPmWrapper bayesPmWrapper, BayesImWrapper oldBayesImwr
BayesIm oldBayesIm = oldBayesImwrapper.getBayesIm();

if (params.getString("initializationMode", "manualRetain").equals("manualRetain")) {
setBayesIm(bayesPm, oldBayesIm, MlBayesIm.MANUAL);
setBayesIm(bayesPm, oldBayesIm, MlBayesIm.InitializationMethod.MANUAL);
} else if (params.getString("initializationMode", "manualRetain").equals("randomRetain")) {
setBayesIm(bayesPm, oldBayesIm, MlBayesIm.RANDOM);
setBayesIm(bayesPm, oldBayesIm, MlBayesIm.InitializationMethod.RANDOM);
} else if (params.getString("initializationMode", "manualRetain").equals("randomOverwrite")) {
setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.RANDOM));
setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM));
}
}

Expand Down Expand Up @@ -193,9 +194,9 @@ public BayesImWrapper(BayesPmWrapper bayesPmWrapper, Parameters params) {
if (params.getString("initializationMode", "manualRetain").equals("manualRetain")) {
setBayesIm(new MlBayesIm(bayesPm));
} else if (params.getString("initializationMode", "manualRetain").equals("randomRetain")) {
setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.RANDOM));
setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM));
} else if (params.getString("initializationMode", "manualRetain").equals("randomOverwrite")) {
setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.RANDOM));
setBayesIm(new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM));
}
}

Expand Down Expand Up @@ -342,9 +343,9 @@ public String getModelSourceName() {

//============================== private methods ============================//

private void setBayesIm(BayesPm bayesPm, BayesIm oldBayesIm, int manual) {
private void setBayesIm(BayesPm bayesPm, BayesIm oldBayesIm, MlBayesIm.InitializationMethod initializationMethod) {
this.bayesIms = new ArrayList<>();
this.bayesIms.add(new MlBayesIm(bayesPm, oldBayesIm, manual));
this.bayesIms.add(new MlBayesIm(bayesPm, oldBayesIm, initializationMethod));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ public BayesPmWrapper(GraphWrapper graphWrapper, Parameters params) {
lowerBound = upperBound = 3;
setBayesPm(graph, lowerBound, upperBound);
} else if (params.getString("bayesPmInitializationMode", "range").equals("range")) {
lowerBound = params.getInt("minCategories", 2);
upperBound = params.getInt("maxCategories", 4);
lowerBound = params.getInt("lowerBoundNumVals", 2);
upperBound = params.getInt("upperBoundNumVals", 4);
setBayesPm(graph, lowerBound, upperBound);
} else {
throw new IllegalStateException("Unrecognized type.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1173,6 +1173,13 @@ public int compareTo(Node node) {
* @return a boolean
*/
public boolean existsParameterizedConstructor(Class modelClass) {
if (modelClass == null) {

// If the model class is null, then there is no constructor, so display a dialog to the users by
// return false here.
return false;
}

Object param = getParam(modelClass);
List parentModels = listParentModels();
parentModels.add(param);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,9 @@ private DataSet simulate(Graph graph, Parameters parameters) {
int minCategories = parameters.getInt(Params.MIN_CATEGORIES);
int maxCategories = parameters.getInt(Params.MAX_CATEGORIES);
pm = new BayesPm(graph, minCategories, maxCategories);
im = new MlBayesIm(pm, MlBayesIm.RANDOM);
im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM);
} else {
im = new MlBayesIm(pm, MlBayesIm.RANDOM);
im = new MlBayesIm(pm, MlBayesIm.InitializationMethod.RANDOM);
this.im = im;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ private DataSet simulate(Graph G, Parameters parameters) {
}

BayesPm bayesPm = new BayesPm(AG);
BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);
BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.InitializationMethod.RANDOM);

SemPm semPm = new SemPm(XG);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public interface Statistic extends Serializable {

/**
* Returns a mapping of the statistic to the interval [0, 1], with higher being better. This is used for a
* calculation of a utility for an algorithm.If the statistic is already between 0 and 1, you can just return the
* calculation of a utility for an algorithm. If the statistic is already between 0 and 1, you can just return the
* statistic.
*
* @param value The value of the statistic.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ private void doUpdate() {
}

private BayesIm createdUpdatedBayesIm(BayesPm updatedBayesPm) {
return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.RANDOM);
return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.InitializationMethod.RANDOM);
}

private BayesPm createUpdatedBayesPm(Dag updatedGraph) {
Expand Down
4 changes: 4 additions & 0 deletions tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/BayesIm.java
Original file line number Diff line number Diff line change
Expand Up @@ -393,4 +393,8 @@ public interface BayesIm extends VariableSource, Im, Simulator {
* @return a string representation for this Bayes net.
*/
String toString();

default MlBayesIm.CptMapType getCptMapType() {
return MlBayesIm.CptMapType.PROB_MAP;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ public String toString() {
//==============================PRIVATE METHODS=======================//

private BayesIm createdManipulatedBayesIm(BayesPm updatedBayesPm) {
return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.MANUAL);
return new MlBayesIm(updatedBayesPm, this.bayesIm, MlBayesIm.InitializationMethod.MANUAL);
}

private BayesPm createManipulatedBayesPm(Dag updatedGraph) {
Expand Down
41 changes: 41 additions & 0 deletions tetrad-lib/src/main/java/edu/cmu/tetrad/bayes/CptMap.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package edu.cmu.tetrad.bayes;

import edu.cmu.tetrad.util.TetradSerializable;

/**
* An interface representing a map of probabilities or counts for nodes in a Bayesian network. Implementations of this
* interface should provide methods to get the probability or count for a node at a given row and column, as well as
* methods to retrieve the number of rows and columns in the map.
* <p>
* This interface extends the TetradSerializable interface, indicating that implementations should be serializable and
* follow certain guidelines for compatibility across different versions of Tetrad.
*
* @author josephramsey
* @see CptMapProbs
* @see CptMapCounts
*/
public interface CptMap extends TetradSerializable {

/**
* Retrieves the value at the specified row and column in the CptMap.
*
* @param row the row index of the value to retrieve.
* @param column the column index of the value to retrieve.
* @return the value at the specified row and column in the CptMap.
*/
double get(int row, int column);

/**
* Retrieves the number of rows in the CptMap.
*
* @return the number of rows in the CptMap.
*/
int getNumRows();

/**
* Retrieves the number of columns in the CptMap.
*
* @return the number of columns in the CptMap.
*/
int getNumColumns();
}
Loading