Removes Apache Commons Math #241

Merged: 51 commits, Jul 28, 2022
Commits
b3531d3
Adding Cholesky factorization & Multivariate normal for sampling. Mig…
Craigacp May 6, 2022
954c820
Roughing out LU factorization.
Craigacp May 9, 2022
7b2257c
Implementing LU factorization and solver methods.
Craigacp May 10, 2022
ae7689b
Adding a test helper for comparing top feature maps and tidied up the…
Craigacp May 11, 2022
5026ab7
Removing commons math from Regression/SLM. Due to some additional ref…
Craigacp May 11, 2022
56dce9a
Stubbing out eigen decomposition.
Craigacp May 11, 2022
323e83c
Initial eigen decomposition, still buggy.
Craigacp May 13, 2022
60a13e5
Fixing the eigen decomposition.
Craigacp May 14, 2022
43e24e2
Commons math begone!
Craigacp May 14, 2022
accae2e
Adding CholeskyFactorization.solve implementations, improving Cholesk…
Craigacp May 16, 2022
af83956
Promoting DenseMatrix.getColumn to Matrix, implementing it on DenseSp…
Craigacp May 16, 2022
b6d359e
Renaming org.tribuo.math.rng to org.tribuo.math.distributions
Craigacp May 23, 2022
8ba663b
Improving the docs for the la factorization methods.
Craigacp May 23, 2022
d2be6a2
Removing commons math from the third party licenses.
Craigacp May 23, 2022
1a36117
Cleanups.
Craigacp May 23, 2022
799e7ee
Switching to Arrays.fill to zero part of a matrix.
Craigacp Jun 13, 2022
68f1eb1
Removing a dependent load.
Craigacp Jun 13, 2022
c04eda5
Adding an interface for factorizations, migrating the factorizations …
Craigacp Jul 23, 2022
556134f
initial commit of GammaTest
pogren Jul 25, 2022
3cd120f
added testReductionBiFunction and testMeanVariance to DenseVectorTest
pogren Jul 25, 2022
af33f87
added some addition tests/assertions for cholesky and lu factorization.
pogren Jul 25, 2022
63ee5dd
fixes issue with test matrix for cholensky factorization.
pogren Jul 25, 2022
d30a604
adding some python code to help generate a unit test for
pogren Jul 27, 2022
b3ea277
added printMatrixPythonFriendly to DenseMatrix
pogren Jul 27, 2022
72b87c9
added unit testing for eigendecomposition, setColumn, and selectColumns
pogren Jul 27, 2022
6d6dc3f
cleaning up, standardizing, filling out factorization/decomp tests
pogren Jul 27, 2022
ed6aa12
added test for createIdentity and createDiagonal
pogren Jul 27, 2022
9eb0033
added simple test for DenseSparseMatrix.getColumn
pogren Jul 27, 2022
b7c0abb
ClusteringMetrics.adjustedMI produces same values as sklearn
pogren Jul 28, 2022
f8b2966
add delta to unit test for mi
pogren Jul 28, 2022
9595871
comments demonstrating generating test values in numpy/scypy/sklearn
pogren Jul 28, 2022
10818c1
added comment showing how to generate test
pogren Jul 28, 2022
742e304
initial commit of GammaTest
pogren Jul 25, 2022
afc81a6
added testReductionBiFunction and testMeanVariance to DenseVectorTest
pogren Jul 25, 2022
6a6e464
added some addition tests/assertions for cholesky and lu factorization.
pogren Jul 25, 2022
25c7146
fixes issue with test matrix for cholensky factorization.
pogren Jul 25, 2022
d1c6f9c
adding some python code to help generate a unit test for
pogren Jul 27, 2022
5f77d74
added printMatrixPythonFriendly to DenseMatrix
pogren Jul 27, 2022
4c76b61
added unit testing for eigendecomposition, setColumn, and selectColumns
pogren Jul 27, 2022
5a1bd8a
cleaning up, standardizing, filling out factorization/decomp tests
pogren Jul 27, 2022
89addb3
added test for createIdentity and createDiagonal
pogren Jul 27, 2022
c93cdd3
added simple test for DenseSparseMatrix.getColumn
pogren Jul 27, 2022
b0e472b
ClusteringMetrics.adjustedMI produces same values as sklearn
pogren Jul 28, 2022
70e51f9
add delta to unit test for mi
pogren Jul 28, 2022
d91004d
comments demonstrating generating test values in numpy/scypy/sklearn
pogren Jul 28, 2022
c228de1
added comment showing how to generate test
pogren Jul 28, 2022
6fabb6f
fixes compile errors
pogren Jul 28, 2022
4d2aa61
resolves merge conflict
pogren Jul 28, 2022
7108a8b
reverted adjustedMI to using 'min' approach for calculating denominator
pogren Jul 28, 2022
560384f
Merge pull request #2 from pogren/commons-math-removal-pvo-review
Craigacp Jul 28, 2022
717ff81
Fixing licensing information.
Craigacp Jul 28, 2022
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,15 +16,11 @@

package org.tribuo.clustering.evaluation;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import org.tribuo.clustering.ClusterID;
import org.tribuo.evaluation.metrics.MetricTarget;
import org.tribuo.util.infotheory.InformationTheory;
import org.tribuo.util.infotheory.impl.PairDistribution;
import org.apache.commons.math3.special.Gamma;

import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;

/**
@@ -70,14 +66,26 @@ public ClusteringMetric forTarget(MetricTarget<ClusterID> tgt) {
* @return The adjusted normalized mutual information.
*/
public static double adjustedMI(ClusteringMetric.Context context) {
double mi = InformationTheory.mi(context.getPredictedIDs(), context.getTrueIDs());
double predEntropy = InformationTheory.entropy(context.getPredictedIDs());
double trueEntropy = InformationTheory.entropy(context.getTrueIDs());
double expectedMI = expectedMI(context.getPredictedIDs(), context.getTrueIDs());
return adjustedMI(context.getPredictedIDs(), context.getTrueIDs());
}

public static double adjustedMI(List<Integer> predictedIDs, List<Integer> trueIDs) {
double mi = InformationTheory.mi(predictedIDs, trueIDs);
double predEntropy = InformationTheory.entropy(predictedIDs);
double trueEntropy = InformationTheory.entropy(trueIDs);
double expectedMI = InformationTheory.expectedMI(trueIDs, predictedIDs);

double minEntropy = Math.min(predEntropy, trueEntropy);
double denominator = minEntropy - expectedMI;

if (denominator < 0) {
denominator = Math.min(denominator, -2.220446049250313e-16);
} else {
denominator = Math.max(denominator, 2.220446049250313e-16);
}


return (mi - expectedMI) / (minEntropy - expectedMI);
return (mi - expectedMI) / (denominator);
}
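The epsilon guard added to adjustedMI above mirrors scikit-learn: when the denominator min(H(U), H(V)) - E[MI] underflows toward zero, it is clamped to ±machine epsilon (2^-52, i.e. Math.ulp(1.0)) so the division stays finite and sign-correct. A standalone sketch of just that guard (the class and method names here are illustrative, not part of the PR):

```java
public class DenominatorClampDemo {
    // Double-precision machine epsilon, the same constant used in the diff.
    static final double EPS = 2.220446049250313e-16; // == Math.ulp(1.0)

    // Clamp a possibly tiny denominator away from zero, preserving its sign.
    static double clamp(double denominator) {
        if (denominator < 0) {
            return Math.min(denominator, -EPS);
        } else {
            return Math.max(denominator, EPS);
        }
    }

    public static void main(String[] args) {
        System.out.println(clamp(0.0));    // clamped up to EPS
        System.out.println(clamp(-1e-20)); // clamped down to -EPS
        System.out.println(clamp(0.5));    // large values pass through
    }
}
```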

/**
@@ -93,44 +101,4 @@ public static double normalizedMI(ClusteringMetric.Context context) {
return predEntropy < trueEntropy ? mi / predEntropy : mi / trueEntropy;
}

private static double expectedMI(List<Integer> first, List<Integer> second) {
PairDistribution<Integer,Integer> pd = PairDistribution.constructFromLists(first,second);

Map<Integer, MutableLong> firstCount = pd.firstCount;
Map<Integer,MutableLong> secondCount = pd.secondCount;
long count = pd.count;

double output = 0.0;

for (Map.Entry<Integer,MutableLong> f : firstCount.entrySet()) {
for (Map.Entry<Integer,MutableLong> s : secondCount.entrySet()) {
long fVal = f.getValue().longValue();
long sVal = s.getValue().longValue();
long minCount = Math.min(fVal, sVal);

long threshold = fVal + sVal - count;
long start = threshold > 1 ? threshold : 1;

for (long nij = start; nij < minCount; nij++) {
double acc = ((double) nij) / count;
acc *= Math.log(((double) (count * nij)) / (fVal * sVal));
//numerator
double logSpace = Gamma.logGamma(fVal + 1);
logSpace += Gamma.logGamma(sVal + 1);
logSpace += Gamma.logGamma(count - fVal + 1);
logSpace += Gamma.logGamma(count - sVal + 1);
//denominator
logSpace -= Gamma.logGamma(count + 1);
logSpace -= Gamma.logGamma(nij + 1);
logSpace -= Gamma.logGamma(fVal - nij + 1);
logSpace -= Gamma.logGamma(sVal - nij + 1);
logSpace -= Gamma.logGamma(count - fVal - sVal + nij + 1);
acc *= Math.exp(logSpace);
output += acc;
}
}
}
return output;
}

}
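The expectedMI method deleted above was the last caller of Commons Math's Gamma.logGamma, so the PR must supply a log-gamma somewhere in org.tribuo.math (the GammaTest commits suggest exactly that). Tribuo's actual implementation may differ; as a hedged sketch, the classic six-term Lanczos approximation (Numerical Recipes' gammln) is one standard way to compute it:

```java
public class LogGammaSketch {
    // Lanczos coefficients (g = 5, n = 6), per Numerical Recipes' gammln.
    private static final double[] COEF = {
        76.18009172947146, -86.50532032941677, 24.01409824083091,
        -1.231739572450155, 0.1208650973866179e-2, -0.5395239384953e-5};

    // log(Gamma(x)) for x > 0, accurate to roughly 1e-10.
    static double logGamma(double x) {
        double tmp = x + 5.5;
        tmp -= (x + 0.5) * Math.log(tmp);
        double ser = 1.000000000190015;
        double y = x;
        for (double c : COEF) {
            ser += c / ++y; // denominators x+1, x+2, ..., x+6
        }
        return -tmp + Math.log(2.5066282746310005 * ser / x);
    }

    public static void main(String[] args) {
        // Gamma(n+1) = n!, so logGamma(5) should be close to log(24).
        System.out.println(logGamma(5.0));
        System.out.println(logGamma(1.0)); // close to 0
    }
}
```

The expected-MI loop then sums hypergeometric terms of the form exp(logGamma sums), exactly as the removed code did with the Commons Math Gamma class.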
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,15 +17,14 @@
package org.tribuo.clustering.example;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.MutableDataset;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ClusteringFactory;
import org.tribuo.datasource.ListDataSource;
import org.tribuo.impl.ArrayExample;
import org.tribuo.math.distributions.MultivariateNormalDistribution;
import org.tribuo.provenance.SimpleDataSourceProvenance;
import org.tribuo.util.Util;

@@ -63,27 +62,27 @@ public static Dataset<ClusterID> gaussianClusters(long size, long seed) {
String[] featureNames = new String[]{"A","B"};
double[] mixingPMF = new double[]{0.1,0.35,0.05,0.25,0.25};
double[] mixingCDF = Util.generateCDF(mixingPMF);
MultivariateNormalDistribution first = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()),
new double[]{0.0,0.0}, new double[][]{{1.0,0.0},{0.0,1.0}}
MultivariateNormalDistribution first = new MultivariateNormalDistribution(
new double[]{0.0,0.0}, new double[][]{{1.0,0.0},{0.0,1.0}}, rng.nextInt(), true
);
MultivariateNormalDistribution second = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()),
new double[]{5.0,5.0}, new double[][]{{1.0,0.0},{0.0,1.0}}
MultivariateNormalDistribution second = new MultivariateNormalDistribution(
new double[]{5.0,5.0}, new double[][]{{1.0,0.0},{0.0,1.0}}, rng.nextInt(), true
);
MultivariateNormalDistribution third = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()),
new double[]{2.5,2.5}, new double[][]{{1.0,0.5},{0.5,1.0}}
MultivariateNormalDistribution third = new MultivariateNormalDistribution(
new double[]{2.5,2.5}, new double[][]{{1.0,0.5},{0.5,1.0}}, rng.nextInt(), true
);
MultivariateNormalDistribution fourth = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()),
new double[]{10.0,0.0}, new double[][]{{0.1,0.0},{0.0,0.1}}
MultivariateNormalDistribution fourth = new MultivariateNormalDistribution(
new double[]{10.0,0.0}, new double[][]{{0.1,0.0},{0.0,0.1}}, rng.nextInt(), true
);
MultivariateNormalDistribution fifth = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()),
new double[]{-1.0,0.0}, new double[][]{{1.0,0.0},{0.0,0.1}}
MultivariateNormalDistribution fifth = new MultivariateNormalDistribution(
new double[]{-1.0,0.0}, new double[][]{{1.0,0.0},{0.0,0.1}}, rng.nextInt(), true
);
MultivariateNormalDistribution[] gaussians = new MultivariateNormalDistribution[]{first,second,third,fourth,fifth};

List<Example<ClusterID>> trainingData = new ArrayList<>();
for (int i = 0; i < size; i++) {
int centroid = Util.sampleFromCDF(mixingCDF,rng);
double[] sample = gaussians[centroid].sample();
double[] sample = gaussians[centroid].sampleArray();
trainingData.add(new ArrayExample<>(new ClusterID(centroid),featureNames,sample));
}
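The new org.tribuo.math.distributions.MultivariateNormalDistribution used above replaces the Commons Math sampler; given the Cholesky factorization work earlier in this PR, it presumably draws x = mean + Lz, where L is the Cholesky factor of the covariance and z is a vector of standard normals. A self-contained sketch under that assumption (class and method names are illustrative, not Tribuo's API):

```java
import java.util.Random;

public class MvnSampleSketch {
    // Lower-triangular L with L * L^T = cov; cov must be symmetric positive-definite.
    static double[][] cholesky(double[][] cov) {
        int n = cov.length;
        double[][] l = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= i; j++) {
                double sum = cov[i][j];
                for (int k = 0; k < j; k++) {
                    sum -= l[i][k] * l[j][k];
                }
                l[i][j] = (i == j) ? Math.sqrt(sum) : sum / l[j][j];
            }
        }
        return l;
    }

    // Draw x = mean + L z, z ~ N(0, I).
    static double[] sample(double[] mean, double[][] l, Random rng) {
        int n = mean.length;
        double[] z = new double[n];
        for (int i = 0; i < n; i++) {
            z[i] = rng.nextGaussian();
        }
        double[] x = mean.clone();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= i; j++) {
                x[i] += l[i][j] * z[j];
            }
        }
        return x;
    }

    public static void main(String[] args) {
        double[][] l = cholesky(new double[][]{{4.0, 2.0}, {2.0, 3.0}});
        // Factor of [[4,2],[2,3]] is [[2,0],[1,sqrt(2)]].
        System.out.println(l[0][0] + " " + l[1][0] + " " + l[1][1]);
        double[] x = sample(new double[]{0.0, 0.0}, l, new Random(42));
        System.out.println(x.length);
    }
}
```

Seeding each distribution with rng.nextInt(), as the diff does, keeps the whole data generator reproducible from a single seed.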

@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -22,17 +22,15 @@
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.tribuo.ConfigurableDataSource;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.MutableDataset;
import org.tribuo.OutputFactory;
import org.tribuo.Trainer;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ClusteringFactory;
import org.tribuo.impl.ArrayExample;
import org.tribuo.math.distributions.MultivariateNormalDistribution;
import org.tribuo.provenance.ConfiguredDataSourceProvenance;
import org.tribuo.provenance.DataSourceProvenance;
import org.tribuo.util.Util;
@@ -235,26 +233,26 @@ public void postConfig() {
double[] mixingCDF = Util.generateCDF(mixingDistribution);
String[] featureNames = Arrays.copyOf(allFeatureNames, firstMean.length);
Random rng = new Random(seed);
MultivariateNormalDistribution first = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()),
firstMean, reshapeAndValidate(firstVariance, "firstVariance")
MultivariateNormalDistribution first = new MultivariateNormalDistribution(
firstMean, reshapeAndValidate(firstVariance, "firstVariance"), rng.nextInt(), true
);
MultivariateNormalDistribution second = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()),
secondMean, reshapeAndValidate(secondVariance, "secondVariance")
MultivariateNormalDistribution second = new MultivariateNormalDistribution(
secondMean, reshapeAndValidate(secondVariance, "secondVariance"), rng.nextInt(), true
);
MultivariateNormalDistribution third = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()),
thirdMean, reshapeAndValidate(thirdVariance, "thirdVariance")
MultivariateNormalDistribution third = new MultivariateNormalDistribution(
thirdMean, reshapeAndValidate(thirdVariance, "thirdVariance"), rng.nextInt(), true
);
MultivariateNormalDistribution fourth = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()),
fourthMean, reshapeAndValidate(fourthVariance, "fourthVariance")
MultivariateNormalDistribution fourth = new MultivariateNormalDistribution(
fourthMean, reshapeAndValidate(fourthVariance, "fourthVariance"), rng.nextInt(), true
);
MultivariateNormalDistribution fifth = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()),
fifthMean, reshapeAndValidate(fifthVariance, "fifthVariance")
MultivariateNormalDistribution fifth = new MultivariateNormalDistribution(
fifthMean, reshapeAndValidate(fifthVariance, "fifthVariance"), rng.nextInt(), true
);
MultivariateNormalDistribution[] Gaussians = new MultivariateNormalDistribution[]{first, second, third, fourth, fifth};
List<Example<ClusterID>> examples = new ArrayList<>(numSamples);
for (int i = 0; i < numSamples; i++) {
int centroid = Util.sampleFromCDF(mixingCDF, rng);
double[] sample = Gaussians[centroid].sample();
double[] sample = Gaussians[centroid].sampleArray();
examples.add(new ArrayExample<>(new ClusterID(centroid), featureNames, sample));
}
this.examples = Collections.unmodifiableList(examples);
@@ -0,0 +1,120 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tribuo.clustering.evaluation;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.tribuo.clustering.evaluation.ClusteringMetrics.adjustedMI;

import java.util.Arrays;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.tribuo.util.infotheory.InformationTheory;

public class ClusteringMetricsTest {

/*
* import numpy as np
* from sklearn.metrics import adjusted_mutual_info_score
* score = adjusted_mutual_info_score([0,0,1,1], [1,0,1,1])
*
* a = np.random.randint(0,2,500)
* #see printArrayAsJavaDoubles in /tribuo-math/src/test/resources/eigendecomposition-test.py
* print(printArrayAsJavaDoubles(a))
* b = np.random.randint(0,2,500)
* print(printArrayAsJavaDoubles(b))
* score = adjusted_mutual_info_score(a, b)
*/
@Test
void testAdjustedMI() throws Exception {
double logBase = InformationTheory.LOG_BASE;
InformationTheory.LOG_BASE = InformationTheory.LOG_E;
List<Integer> a = Arrays.asList(0, 3, 2, 3, 4, 4, 4, 1, 3, 3, 4, 3, 2, 3, 2, 4, 2, 2, 1, 4, 1,
2, 0, 4, 4, 4, 3, 3, 2, 2, 0, 4, 0, 1, 3, 0, 4, 0, 0, 4, 0, 0, 2, 2, 2, 2, 0, 3, 0, 2, 2, 3,
1, 0, 1, 0, 3, 4, 4, 4, 0, 1, 1, 3, 3, 1, 3, 4, 0, 3, 4, 1, 0, 3, 2, 2, 2, 1, 1, 2, 3, 2, 1,
3, 0, 4, 4, 0, 4, 0, 2, 1, 4, 0, 3, 0, 1, 1, 1, 0);
List<Integer> b = Arrays.asList(4, 2, 4, 0, 4, 4, 3, 3, 3, 2, 2, 0, 1, 3, 2, 1, 2, 0, 0, 4, 3,
3, 0, 1, 1, 1, 1, 4, 4, 4, 3, 1, 0, 0, 0, 1, 4, 1, 1, 1, 3, 3, 1, 2, 3, 0, 4, 0, 2, 3, 4, 2,
3, 2, 1, 0, 2, 4, 2, 2, 4, 1, 2, 4, 3, 1, 1, 1, 3, 0, 2, 3, 2, 0, 1, 0, 0, 4, 0, 3, 0, 0, 0,
1, 3, 2, 3, 4, 2, 4, 1, 0, 3, 3, 0, 2, 1, 0, 4, 1);
assertEquals(0.01454420034676734, adjustedMI(a, b), 1e-14);

a = Arrays.asList(1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1,
0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0,
0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0,
1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0,
1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1,
1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0,
1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1,
1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1,
0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0,
1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1,
1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1,
1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0,
0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0,
1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1,
0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1,
0, 0, 1, 1, 1, 1, 0, 0, 1);
b = Arrays.asList(1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0,
0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0,
1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1,
1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0,
0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1,
0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1,
0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0,
0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0,
1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1,
0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1,
0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0,
0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0,
1, 0, 0, 1, 1, 0, 0, 1, 0);
assertEquals(-0.0014006748276587267, adjustedMI(a, b), 1e-14);

//used to create third example
//Random rng = new Random();
//a = new ArrayList<>();
//for(int i=0; i<100; i++) {
// int v = rng.nextDouble()*i < 20 ? 0 : i < 50 ? 1 : 2;
// a.add(v);
//}
//System.out.println(a);
a = Arrays.asList(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 0, 2, 0, 0, 0, 0,
2, 0, 0, 2, 0, 0, 2, 2, 2, 2, 2, 0, 0, 2, 0, 0, 0, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2);
b = Arrays.asList(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 2, 0, 2, 0, 2, 0, 0,
2, 0, 2, 2, 0, 0, 2, 2, 0, 2, 2, 2, 2, 2, 2, 0, 0, 0, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 0,
2, 0, 2, 0, 2, 2, 2, 2, 2, 2, 2, 0);
assertEquals(0.31766625364399165, adjustedMI(a, b), 1e-14);

assertEquals(1.0, adjustedMI(Arrays.asList(0, 0, 1, 1), Arrays.asList(0, 0, 1, 1)));
assertEquals(1.0, adjustedMI(Arrays.asList(0, 0, 1, 1), Arrays.asList(1, 1, 0, 0)));
assertEquals(0.0, adjustedMI(Arrays.asList(0, 0, 0, 0), Arrays.asList(1, 2, 3, 4)));
assertEquals(0.0, adjustedMI(Arrays.asList(0, 0, 1, 1), Arrays.asList(1, 1, 1, 1)));
assertEquals(0.0834628172282441,
adjustedMI(Arrays.asList(0, 0, 0, 1, 0, 1, 1, 1), Arrays.asList(0, 0, 0, 0, 1, 1, 1, 1)),
1e-15);
assertEquals(0, adjustedMI(Arrays.asList(1, 0, 1, 1), Arrays.asList(0, 0, 1, 1)), 1e-14);

InformationTheory.LOG_BASE = logBase;
}
}