Skip to content

Commit

Permalink
Add RowsSettable interface to IndTest classes
Browse files Browse the repository at this point in the history
The IndTestConditionalGaussianLrt and IndTestDegenerateGaussianLrt classes have been updated to implement the RowsSettable interface. This change allows users to set which rows are used in the test. Also, redundant spaces were removed from IndTestFisherZ, and parameter changes were made in DegenerateGaussianBic class.
  • Loading branch information
jdramsey committed Jun 12, 2024
1 parent 92b0c22 commit 7cda6c2
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ public DegenerateGaussianBicScore() {
public Score getScore(DataModel dataSet, Parameters parameters) {
this.dataSet = dataSet;
boolean precomputeCovariances = parameters.getBoolean(Params.PRECOMPUTE_COVARIANCES);
// DegenerateGaussianScoreOld degenerateGaussianScore = new DegenerateGaussianScoreOld(DataUtils.getMixedDataSet(dataSet));
DegenerateGaussianScore degenerateGaussianScore = new DegenerateGaussianScore(SimpleDataLoader.getMixedDataSet(dataSet), precomputeCovariances);
degenerateGaussianScore.setPenaltyDiscount(parameters.getDouble("penaltyDiscount"));
degenerateGaussianScore.setPenaltyDiscount(parameters.getDouble(Params.PENALTY_DISCOUNT));
degenerateGaussianScore.setUsePseudoInverse(parameters.getBoolean(Params.USE_PSEUDOINVERSE));
return degenerateGaussianScore;
}

Expand Down Expand Up @@ -102,6 +102,7 @@ public List<String> getParameters() {
parameters.add(Params.PENALTY_DISCOUNT);
parameters.add(Params.STRUCTURE_PRIOR);
parameters.add(Params.PRECOMPUTE_COVARIANCES);
parameters.add(Params.USE_PSEUDOINVERSE);
return parameters;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
* @author josephramsey
* @version $Id: $Id
*/
public class IndTestConditionalGaussianLrt implements IndependenceTest {
public class IndTestConditionalGaussianLrt implements IndependenceTest, RowsSettable {
/**
* The data set.
*/
Expand Down Expand Up @@ -79,6 +79,10 @@ public class IndTestConditionalGaussianLrt implements IndependenceTest {
* The minimum sample size per cell for discretization.
*/
private int minSampleSizePerCell = 4;
/**
* The rows used in the test.
*/
private List<Integer> rows = new ArrayList<>();

/**
* Constructor.
Expand Down Expand Up @@ -282,6 +286,10 @@ public void setNumCategoriesToDiscretize(int numCategoriesToDiscretize) {
* @return A list of row indices.
*/
private List<Integer> getRows(List<Node> allVars, Map<Node, Integer> nodeHash) {
if (this.rows != null) {
return this.rows;
}

List<Integer> rows = new ArrayList<>();

K:
Expand All @@ -299,6 +307,42 @@ private List<Integer> getRows(List<Node> allVars, Map<Node, Integer> nodeHash) {
return rows;
}

/**
* Returns the rows used in the test.
*
* @return The rows used in the test.
*/
public List<Integer> getRows() {
return rows;
}

/**
* Allows the user to set which rows are used in the test. Otherwise, all rows are used, except those with missing
* values.
*/
public void setRows(List<Integer> rows) {
if (data == null) {
return;
}

List<Integer> all = new ArrayList<>();
for (int i = 0; i < data.getNumRows(); i++) all.add(i);
Collections.shuffle(all);

List<Integer> _rows = new ArrayList<>();
for (int i = 0; i < data.getNumRows() / 2; i++) {
_rows.add(all.get(i));
}

for (Integer row : _rows) {
if (row < 0 || row >= data.getNumRows()) {
throw new IllegalArgumentException("Row index out of bounds.");
}
}

this.rows = _rows;
}

/**
* Sets the minimum sample size per cell for the independence test.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
* @author Bryan Andrews
* @version $Id: $Id
*/
public class IndTestDegenerateGaussianLrt implements IndependenceTest {
public class IndTestDegenerateGaussianLrt implements IndependenceTest, RowsSettable {

/**
* A constant.
Expand Down Expand Up @@ -96,6 +96,10 @@ public class IndTestDegenerateGaussianLrt implements IndependenceTest {
* True if verbose output should be printed.
*/
private boolean verbose;
/**
* The rows used in the test.
*/
private List<Integer> rows = new ArrayList<>();

/**
* Constructs the score using a covariance matrix.
Expand Down Expand Up @@ -403,6 +407,10 @@ private Ret getlldof(List<Integer> rows, int i, int... parents) {
* @return A list of integers representing the row indices that satisfy the conditions.
*/
private List<Integer> getRows(List<Node> allVars, Map<Node, Integer> nodesHash) {
if (this.rows != null) {
return this.rows;
}

List<Integer> rows = new ArrayList<>();

K:
Expand Down Expand Up @@ -459,6 +467,42 @@ private Matrix getCov(List<Integer> rows, int[] cols) {
return cov;
}

/**
* Returns the rows used in the test.
*
* @return The rows used in the test.
*/
public List<Integer> getRows() {
return rows;
}

/**
* Allows the user to set which rows are used in the test. Otherwise, all rows are used, except those with missing
* values.
*/
public void setRows(List<Integer> rows) {
if (dataSet == null) {
return;
}

List<Integer> all = new ArrayList<>();
for (int i = 0; i < dataSet.getNumRows(); i++) all.add(i);
Collections.shuffle(all);

List<Integer> _rows = new ArrayList<>();
for (int i = 0; i < dataSet.getNumRows() / 2; i++) {
_rows.add(all.get(i));
}

for (Integer row : _rows) {
if (row < 0 || row >= dataSet.getNumRows()) {
throw new IllegalArgumentException("Row index out of bounds.");
}
}

this.rows = _rows;
}

/**
* Stores a return value for a likelihood--i.e., a likelihood value and the degrees of freedom for it.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -841,8 +841,6 @@ public void setRows(List<Integer> rows) {
_rows.add(all.get(i));
}



for (Integer row : _rows) {
if (row < 0 || row >= sampleSize()) {
throw new IllegalArgumentException("Row index out of bounds.");
Expand Down

0 comments on commit 7cda6c2

Please sign in to comment.