diff --git a/Clustering/Core/src/main/java/org/tribuo/clustering/evaluation/ClusteringMetrics.java b/Clustering/Core/src/main/java/org/tribuo/clustering/evaluation/ClusteringMetrics.java index a452dc970..e112c310e 100644 --- a/Clustering/Core/src/main/java/org/tribuo/clustering/evaluation/ClusteringMetrics.java +++ b/Clustering/Core/src/main/java/org/tribuo/clustering/evaluation/ClusteringMetrics.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,15 +16,11 @@ package org.tribuo.clustering.evaluation; -import com.oracle.labs.mlrg.olcut.util.MutableLong; import org.tribuo.clustering.ClusterID; import org.tribuo.evaluation.metrics.MetricTarget; import org.tribuo.util.infotheory.InformationTheory; -import org.tribuo.util.infotheory.impl.PairDistribution; -import org.apache.commons.math3.special.Gamma; import java.util.List; -import java.util.Map; import java.util.function.BiFunction; /** @@ -70,14 +66,26 @@ public ClusteringMetric forTarget(MetricTarget tgt) { * @return The adjusted normalized mutual information. */ public static double adjustedMI(ClusteringMetric.Context context) { - double mi = InformationTheory.mi(context.getPredictedIDs(), context.getTrueIDs()); - double predEntropy = InformationTheory.entropy(context.getPredictedIDs()); - double trueEntropy = InformationTheory.entropy(context.getTrueIDs()); - double expectedMI = expectedMI(context.getPredictedIDs(), context.getTrueIDs()); + return adjustedMI(context.getPredictedIDs(), context.getTrueIDs()); + } + + public static double adjustedMI(List predictedIDs, List trueIDs) { + double mi = InformationTheory.mi(predictedIDs, trueIDs); + double predEntropy = InformationTheory.entropy(predictedIDs); + double trueEntropy = InformationTheory.entropy(trueIDs); + double expectedMI = InformationTheory.expectedMI(trueIDs, predictedIDs); double minEntropy = Math.min(predEntropy, trueEntropy); + double denominator = minEntropy - expectedMI; + + if (denominator < 0) { + denominator = Math.min(denominator, -2.220446049250313e-16); + } else { + denominator = Math.max(denominator, 2.220446049250313e-16); + } + - return (mi - expectedMI) / (minEntropy - expectedMI); + return (mi - expectedMI) / (denominator); } /** @@ -93,44 +101,4 @@ public static double normalizedMI(ClusteringMetric.Context context) { return predEntropy < trueEntropy ? mi / predEntropy : mi / trueEntropy; } - private static double expectedMI(List first, List second) { - PairDistribution pd = PairDistribution.constructFromLists(first,second); - - Map firstCount = pd.firstCount; - Map secondCount = pd.secondCount; - long count = pd.count; - - double output = 0.0; - - for (Map.Entry f : firstCount.entrySet()) { - for (Map.Entry s : secondCount.entrySet()) { - long fVal = f.getValue().longValue(); - long sVal = s.getValue().longValue(); - long minCount = Math.min(fVal, sVal); - - long threshold = fVal + sVal - count; - long start = threshold > 1 ? 
threshold : 1; - - for (long nij = start; nij < minCount; nij++) { - double acc = ((double) nij) / count; - acc *= Math.log(((double) (count * nij)) / (fVal * sVal)); - //numerator - double logSpace = Gamma.logGamma(fVal + 1); - logSpace += Gamma.logGamma(sVal + 1); - logSpace += Gamma.logGamma(count - fVal + 1); - logSpace += Gamma.logGamma(count - sVal + 1); - //denominator - logSpace -= Gamma.logGamma(count + 1); - logSpace -= Gamma.logGamma(nij + 1); - logSpace -= Gamma.logGamma(fVal - nij + 1); - logSpace -= Gamma.logGamma(sVal - nij + 1); - logSpace -= Gamma.logGamma(count - fVal - sVal + nij + 1); - acc *= Math.exp(logSpace); - output += acc; - } - } - } - return output; - } - } \ No newline at end of file diff --git a/Clustering/Core/src/main/java/org/tribuo/clustering/example/ClusteringDataGenerator.java b/Clustering/Core/src/main/java/org/tribuo/clustering/example/ClusteringDataGenerator.java index a6b59d4bc..40178eb9b 100644 --- a/Clustering/Core/src/main/java/org/tribuo/clustering/example/ClusteringDataGenerator.java +++ b/Clustering/Core/src/main/java/org/tribuo/clustering/example/ClusteringDataGenerator.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,8 +17,6 @@ package org.tribuo.clustering.example; import com.oracle.labs.mlrg.olcut.util.Pair; -import org.apache.commons.math3.distribution.MultivariateNormalDistribution; -import org.apache.commons.math3.random.JDKRandomGenerator; import org.tribuo.Dataset; import org.tribuo.Example; import org.tribuo.MutableDataset; @@ -26,6 +24,7 @@ import org.tribuo.clustering.ClusteringFactory; import org.tribuo.datasource.ListDataSource; import org.tribuo.impl.ArrayExample; +import org.tribuo.math.distributions.MultivariateNormalDistribution; import org.tribuo.provenance.SimpleDataSourceProvenance; import org.tribuo.util.Util; @@ -63,27 +62,27 @@ public static Dataset gaussianClusters(long size, long seed) { String[] featureNames = new String[]{"A","B"}; double[] mixingPMF = new double[]{0.1,0.35,0.05,0.25,0.25}; double[] mixingCDF = Util.generateCDF(mixingPMF); - MultivariateNormalDistribution first = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), - new double[]{0.0,0.0}, new double[][]{{1.0,0.0},{0.0,1.0}} + MultivariateNormalDistribution first = new MultivariateNormalDistribution( + new double[]{0.0,0.0}, new double[][]{{1.0,0.0},{0.0,1.0}}, rng.nextInt(), true ); - MultivariateNormalDistribution second = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), - new double[]{5.0,5.0}, new double[][]{{1.0,0.0},{0.0,1.0}} + MultivariateNormalDistribution second = new MultivariateNormalDistribution( + new double[]{5.0,5.0}, new double[][]{{1.0,0.0},{0.0,1.0}}, rng.nextInt(), true ); - MultivariateNormalDistribution third = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), - new double[]{2.5,2.5}, new double[][]{{1.0,0.5},{0.5,1.0}} + MultivariateNormalDistribution third = new MultivariateNormalDistribution( + new double[]{2.5,2.5}, new double[][]{{1.0,0.5},{0.5,1.0}}, rng.nextInt(), true ); - MultivariateNormalDistribution fourth = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), - new double[]{10.0,0.0}, new double[][]{{0.1,0.0},{0.0,0.1}} + MultivariateNormalDistribution 
fourth = new MultivariateNormalDistribution( + new double[]{10.0,0.0}, new double[][]{{0.1,0.0},{0.0,0.1}}, rng.nextInt(), true ); - MultivariateNormalDistribution fifth = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), - new double[]{-1.0,0.0}, new double[][]{{1.0,0.0},{0.0,0.1}} + MultivariateNormalDistribution fifth = new MultivariateNormalDistribution( + new double[]{-1.0,0.0}, new double[][]{{1.0,0.0},{0.0,0.1}}, rng.nextInt(), true ); MultivariateNormalDistribution[] gaussians = new MultivariateNormalDistribution[]{first,second,third,fourth,fifth}; List> trainingData = new ArrayList<>(); for (int i = 0; i < size; i++) { int centroid = Util.sampleFromCDF(mixingCDF,rng); - double[] sample = gaussians[centroid].sample(); + double[] sample = gaussians[centroid].sampleArray(); trainingData.add(new ArrayExample<>(new ClusterID(centroid),featureNames,sample)); } diff --git a/Clustering/Core/src/main/java/org/tribuo/clustering/example/GaussianClusterDataSource.java b/Clustering/Core/src/main/java/org/tribuo/clustering/example/GaussianClusterDataSource.java index 42a1c2209..b9f793984 100644 --- a/Clustering/Core/src/main/java/org/tribuo/clustering/example/GaussianClusterDataSource.java +++ b/Clustering/Core/src/main/java/org/tribuo/clustering/example/GaussianClusterDataSource.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,10 +22,7 @@ import com.oracle.labs.mlrg.olcut.provenance.Provenance; import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance; import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; -import org.apache.commons.math3.distribution.MultivariateNormalDistribution; -import org.apache.commons.math3.random.JDKRandomGenerator; import org.tribuo.ConfigurableDataSource; -import org.tribuo.Dataset; import org.tribuo.Example; import org.tribuo.MutableDataset; import org.tribuo.OutputFactory; @@ -33,6 +30,7 @@ import org.tribuo.clustering.ClusterID; import org.tribuo.clustering.ClusteringFactory; import org.tribuo.impl.ArrayExample; +import org.tribuo.math.distributions.MultivariateNormalDistribution; import org.tribuo.provenance.ConfiguredDataSourceProvenance; import org.tribuo.provenance.DataSourceProvenance; import org.tribuo.util.Util; @@ -235,26 +233,26 @@ public void postConfig() { double[] mixingCDF = Util.generateCDF(mixingDistribution); String[] featureNames = Arrays.copyOf(allFeatureNames, firstMean.length); Random rng = new Random(seed); - MultivariateNormalDistribution first = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), - firstMean, reshapeAndValidate(firstVariance, "firstVariance") + MultivariateNormalDistribution first = new MultivariateNormalDistribution( + firstMean, reshapeAndValidate(firstVariance, "firstVariance"), rng.nextInt(), true ); - MultivariateNormalDistribution second = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), - secondMean, reshapeAndValidate(secondVariance, "secondVariance") + MultivariateNormalDistribution second = new MultivariateNormalDistribution( + secondMean, reshapeAndValidate(secondVariance, "secondVariance"), rng.nextInt(), true ); - MultivariateNormalDistribution third = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), - 
thirdMean, reshapeAndValidate(thirdVariance, "thirdVariance") + MultivariateNormalDistribution third = new MultivariateNormalDistribution( + thirdMean, reshapeAndValidate(thirdVariance, "thirdVariance"), rng.nextInt(), true ); - MultivariateNormalDistribution fourth = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), - fourthMean, reshapeAndValidate(fourthVariance, "fourthVariance") + MultivariateNormalDistribution fourth = new MultivariateNormalDistribution( + fourthMean, reshapeAndValidate(fourthVariance, "fourthVariance"), rng.nextInt(), true ); - MultivariateNormalDistribution fifth = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), - fifthMean, reshapeAndValidate(fifthVariance, "fifthVariance") + MultivariateNormalDistribution fifth = new MultivariateNormalDistribution( + fifthMean, reshapeAndValidate(fifthVariance, "fifthVariance"), rng.nextInt(), true ); MultivariateNormalDistribution[] Gaussians = new MultivariateNormalDistribution[]{first, second, third, fourth, fifth}; List> examples = new ArrayList<>(numSamples); for (int i = 0; i < numSamples; i++) { int centroid = Util.sampleFromCDF(mixingCDF, rng); - double[] sample = Gaussians[centroid].sample(); + double[] sample = Gaussians[centroid].sampleArray(); examples.add(new ArrayExample<>(new ClusterID(centroid), featureNames, sample)); } this.examples = Collections.unmodifiableList(examples); diff --git a/Clustering/Core/src/test/java/org/tribuo/clustering/evaluation/ClusteringMetricsTest.java b/Clustering/Core/src/test/java/org/tribuo/clustering/evaluation/ClusteringMetricsTest.java new file mode 100644 index 000000000..f66237422 --- /dev/null +++ b/Clustering/Core/src/test/java/org/tribuo/clustering/evaluation/ClusteringMetricsTest.java @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
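The rewritten adjustedMI above now delegates to InformationTheory.expectedMI and evaluates (MI − E[MI]) / (min(H(predicted), H(true)) − E[MI]), clamping the denominator away from zero by one machine epsilon so that the degenerate cases exercised in the new test below (for example a single-cluster prediction) come out as 0 rather than NaN. A minimal sketch of calling the new overload directly, assuming the list arguments are List<Integer> (the generic parameters are not visible in this diff):

```java
import java.util.Arrays;
import java.util.List;

import org.tribuo.clustering.evaluation.ClusteringMetrics;
import org.tribuo.util.infotheory.InformationTheory;

public class AdjustedMIExample {
    public static void main(String[] args) {
        // The reference values in the test were generated by scikit-learn in nats,
        // so the test switches InformationTheory to the natural log base first.
        double oldBase = InformationTheory.LOG_BASE;
        InformationTheory.LOG_BASE = InformationTheory.LOG_E;

        List<Integer> predicted = Arrays.asList(0, 0, 1, 1);
        List<Integer> truth = Arrays.asList(1, 0, 1, 1);

        // Identical clusterings (up to relabelling) score 1.0, unrelated ones score close to 0.
        double ami = ClusteringMetrics.adjustedMI(predicted, truth);
        System.out.println("Adjusted MI = " + ami);

        InformationTheory.LOG_BASE = oldBase;
    }
}
```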
+ */ + +package org.tribuo.clustering.evaluation; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tribuo.clustering.evaluation.ClusteringMetrics.adjustedMI; + +import java.util.Arrays; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.tribuo.util.infotheory.InformationTheory; + +public class ClusteringMetricsTest { + + /* + * import numpy as np + * from sklearn.metrics import adjusted_mutual_info_score + * score = adjusted_mutual_info_score([0,0,1,1], [1,0,1,1]) + * + * a = np.random.randint(0,2,500) + * #see printArrayAsJavaDoubles in /tribuo-math/src/test/resources/eigendecomposition-test.py + * print(printArrayAsJavaDoubles(a)) + * b = np.random.randint(0,2,500) + * print(printArrayAsJavaDoubles(b)) + * score = adjusted_mutual_info_score(a, b) + */ + @Test + void testAdjustedMI() throws Exception { + double logBase = InformationTheory.LOG_BASE; + InformationTheory.LOG_BASE = InformationTheory.LOG_E; + List a = Arrays.asList(0, 3, 2, 3, 4, 4, 4, 1, 3, 3, 4, 3, 2, 3, 2, 4, 2, 2, 1, 4, 1, + 2, 0, 4, 4, 4, 3, 3, 2, 2, 0, 4, 0, 1, 3, 0, 4, 0, 0, 4, 0, 0, 2, 2, 2, 2, 0, 3, 0, 2, 2, 3, + 1, 0, 1, 0, 3, 4, 4, 4, 0, 1, 1, 3, 3, 1, 3, 4, 0, 3, 4, 1, 0, 3, 2, 2, 2, 1, 1, 2, 3, 2, 1, + 3, 0, 4, 4, 0, 4, 0, 2, 1, 4, 0, 3, 0, 1, 1, 1, 0); + List b = Arrays.asList(4, 2, 4, 0, 4, 4, 3, 3, 3, 2, 2, 0, 1, 3, 2, 1, 2, 0, 0, 4, 3, + 3, 0, 1, 1, 1, 1, 4, 4, 4, 3, 1, 0, 0, 0, 1, 4, 1, 1, 1, 3, 3, 1, 2, 3, 0, 4, 0, 2, 3, 4, 2, + 3, 2, 1, 0, 2, 4, 2, 2, 4, 1, 2, 4, 3, 1, 1, 1, 3, 0, 2, 3, 2, 0, 1, 0, 0, 4, 0, 3, 0, 0, 0, + 1, 3, 2, 3, 4, 2, 4, 1, 0, 3, 3, 0, 2, 1, 0, 4, 1); + assertEquals(0.01454420034676734, adjustedMI(a, b), 1e-14); + + a = Arrays.asList(1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, + 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, + 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, + 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, + 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, + 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, + 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, + 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, + 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, + 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, + 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, + 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, + 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, + 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, + 0, 0, 1, 1, 1, 1, 0, 0, 1); + b = Arrays.asList(1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, + 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, + 1, 0, 0, 
0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, + 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, + 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, + 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, + 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, + 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, + 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, + 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, + 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, + 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, + 1, 0, 0, 1, 1, 0, 0, 1, 0); + assertEquals(-0.0014006748276587267, adjustedMI(a, b), 1e-14); + + //used to create third example + //Random rng = new Random(); + //a = new ArrayList<>(); + //for(int i=0; i<100; i++) { + // int v = rng.nextDouble()*i < 20 ? 0 : i < 50 ? 1 : 2; + // a.add(v); + //} + //System.out.println(a); + a = Arrays.asList(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 0, 2, 0, 0, 0, 0, + 2, 0, 0, 2, 0, 0, 2, 2, 2, 2, 2, 0, 0, 2, 0, 0, 0, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2); + b = Arrays.asList(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, + 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 2, 0, 2, 0, 2, 0, 0, + 2, 0, 2, 2, 0, 0, 2, 2, 0, 2, 2, 2, 2, 2, 2, 0, 0, 0, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 0, + 2, 0, 2, 0, 2, 2, 2, 2, 2, 2, 2, 0); + assertEquals(0.31766625364399165, adjustedMI(a, b), 1e-14); + + assertEquals(1.0, adjustedMI(Arrays.asList(0, 0, 1, 1), Arrays.asList(0, 0, 1, 1))); + assertEquals(1.0, adjustedMI(Arrays.asList(0, 0, 1, 1), Arrays.asList(1, 1, 0, 0))); + assertEquals(0.0, adjustedMI(Arrays.asList(0, 0, 0, 0), Arrays.asList(1, 2, 3, 4))); + assertEquals(0.0, adjustedMI(Arrays.asList(0, 0, 1, 1), Arrays.asList(1, 1, 1, 1))); + assertEquals(0.0834628172282441, + adjustedMI(Arrays.asList(0, 0, 0, 1, 0, 1, 1, 1), Arrays.asList(0, 0, 0, 0, 1, 1, 1, 1)), + 1e-15); + assertEquals(0, adjustedMI(Arrays.asList(1, 0, 1, 1), Arrays.asList(0, 0, 1, 1)), 1e-14); + + InformationTheory.LOG_BASE = logBase; + } +} diff --git a/Core/src/test/java/org/tribuo/test/Helpers.java b/Core/src/test/java/org/tribuo/test/Helpers.java index 648630b0a..3a443bebb 100644 --- a/Core/src/test/java/org/tribuo/test/Helpers.java +++ b/Core/src/test/java/org/tribuo/test/Helpers.java @@ -24,6 +24,7 @@ import com.oracle.labs.mlrg.olcut.provenance.Provenancable; import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil; import com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance; +import com.oracle.labs.mlrg.olcut.util.Pair; import org.junit.jupiter.api.Assertions; import org.tribuo.Example; import org.tribuo.Feature; @@ -184,4 +185,44 @@ public static > void 
testSequenceModelSerialization(Sequence Assertions.fail("Failed to deserialize sequence model class " + model.getClass().toString(), ex); } } + + /** + * Compares two top feature lists according to the specified tolerances returning true when the lists have the + * same elements and the difference between the scores is within the tolerance. + *
<p>
+ * Mostly used when refactoring implementations to compare the new and old one. + * @param first The first feature list. + * @param second The second feature list. + * @param tolerance The tolerance for the scores. + * @return True if the feature lists are equal. + */ + public static boolean topFeaturesEqual(Map>> first, Map>> second, double tolerance) { + if (first.size() == second.size() && first.keySet().containsAll(second.keySet())) { + // keys the same, now check lists + for (Map.Entry>> e : first.entrySet()) { + List> firstList = e.getValue(); + List> secondList = second.get(e.getKey()); + if (firstList.size() == secondList.size()) { + // Now compare lists + for (int i = 0; i < firstList.size(); i++) { + Pair firstPair = firstList.get(i); + Pair secondPair = secondList.get(i); + if (firstPair.getA().equals(secondPair.getA())) { + double diff = Math.abs(firstPair.getB() - secondPair.getB()); + if (diff > tolerance) { + return false; + } + } else { + return false; + } + } + } else { + return false; + } + } + return true; + } else { + return false; + } + } } diff --git a/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java b/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java new file mode 100644 index 000000000..91fb0fe97 --- /dev/null +++ b/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.math.distributions; + +import org.tribuo.math.la.DenseMatrix; +import org.tribuo.math.la.DenseSparseMatrix; +import org.tribuo.math.la.DenseVector; + +import java.util.Arrays; +import java.util.Optional; +import java.util.Random; + +/** + * A class for sampling from multivariate normal distributions. + */ +public final class MultivariateNormalDistribution { + + private final long seed; + private final Random rng; + private final DenseVector means; + private final DenseMatrix covariance; + private final DenseMatrix samplingCovariance; + private final boolean eigenDecomposition; + + /** + * Constructs a multivariate normal distribution that can be sampled from. + *
<p>
+ * Throws {@link IllegalArgumentException} if the covariance matrix is not positive definite. + *
<p>
+ * Uses a {@link org.tribuo.math.la.DenseMatrix.CholeskyFactorization} to compute the sampling + * covariance matrix. + * @param means The mean vector. + * @param covariance The covariance matrix. + * @param seed The RNG seed. + */ + public MultivariateNormalDistribution(double[] means, double[][] covariance, long seed) { + this(DenseVector.createDenseVector(means),DenseMatrix.createDenseMatrix(covariance),seed); + } + + /** + * Constructs a multivariate normal distribution that can be sampled from. + *
<p>
+ * Throws {@link IllegalArgumentException} if the covariance matrix is not positive definite. + * @param means The mean vector. + * @param covariance The covariance matrix. + * @param seed The RNG seed. + * @param eigenDecomposition If true use an eigen decomposition to compute the sampling covariance matrix + * rather than a cholesky factorization. + */ + public MultivariateNormalDistribution(double[] means, double[][] covariance, long seed, boolean eigenDecomposition) { + this(DenseVector.createDenseVector(means),DenseMatrix.createDenseMatrix(covariance),seed, eigenDecomposition); + } + + /** + * Constructs a multivariate normal distribution that can be sampled from. + *
<p>
+ * Throws {@link IllegalArgumentException} if the covariance matrix is not positive definite. + *
<p>
+ * Uses a {@link org.tribuo.math.la.DenseMatrix.CholeskyFactorization} to compute the sampling + * covariance matrix. + * @param means The mean vector. + * @param covariance The covariance matrix. + * @param seed The RNG seed. + */ + public MultivariateNormalDistribution(DenseVector means, DenseMatrix covariance, long seed) { + this(means,covariance,seed,false); + } + + /** + * Constructs a multivariate normal distribution that can be sampled from. + *
<p>
+ * Throws {@link IllegalArgumentException} if the covariance matrix is not positive definite. + * @param means The mean vector. + * @param covariance The covariance matrix. + * @param seed The RNG seed. + * @param eigenDecomposition If true use an eigen decomposition to compute the sampling covariance matrix + * rather than a cholesky factorization. + */ + public MultivariateNormalDistribution(DenseVector means, DenseMatrix covariance, long seed, boolean eigenDecomposition) { + this.seed = seed; + this.rng = new Random(seed); + this.means = means.copy(); + this.covariance = covariance.copy(); + if (this.covariance.getDimension1Size() != this.means.size() || this.covariance.getDimension2Size() != this.means.size()) { + throw new IllegalArgumentException("Covariance matrix must be square and the same dimension as the mean vector. Mean vector size = " + means.size() + ", covariance size = " + Arrays.toString(this.covariance.getShape())); + } + this.eigenDecomposition = eigenDecomposition; + if (eigenDecomposition) { + Optional factorization = this.covariance.eigenDecomposition(); + if (factorization.isPresent() && factorization.get().positiveEigenvalues()) { + DenseVector eigenvalues = factorization.get().eigenvalues(); + // rows are eigenvectors + DenseMatrix eigenvectors = new DenseMatrix(factorization.get().eigenvectors()); + // scale eigenvectors by sqrt of eigenvalues + eigenvalues.foreachInPlace(Math::sqrt); + DenseSparseMatrix diagonal = DenseSparseMatrix.createDiagonal(eigenvalues);; + this.samplingCovariance = eigenvectors.matrixMultiply(diagonal).matrixMultiply(eigenvectors,false,true); + } else { + throw new IllegalArgumentException("Covariance matrix is not positive definite."); + } + } else { + Optional factorization = this.covariance.choleskyFactorization(); + if (factorization.isPresent()) { + this.samplingCovariance = factorization.get().lMatrix(); + } else { + throw new IllegalArgumentException("Covariance matrix is not positive definite."); + } + } + } + + /** + * Sample a vector from this multivariate normal distribution. + * @return A sample from this distribution. + */ + public DenseVector sampleVector() { + DenseVector sampled = new DenseVector(means.size()); + for (int i = 0; i < means.size(); i++) { + sampled.set(i,rng.nextGaussian()); + } + + sampled = samplingCovariance.leftMultiply(sampled); + + return means.add(sampled); + } + + /** + * Sample a vector from this multivariate normal distribution. + * @return A sample from this distribution. + */ + public double[] sampleArray() { + return sampleVector().toArray(); + } + + @Override + public String toString() { + return "MultivariateNormal(mean="+means+",covariance="+covariance+",seed="+seed+",useEigenDecomposition="+eigenDecomposition+")"; + } +} diff --git a/Math/src/main/java/org/tribuo/math/distributions/package-info.java b/Math/src/main/java/org/tribuo/math/distributions/package-info.java new file mode 100644 index 000000000..fb2e1dde0 --- /dev/null +++ b/Math/src/main/java/org/tribuo/math/distributions/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
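For reference, a minimal sketch of using the new MultivariateNormalDistribution in the way the data generators above now do, assuming only the constructors and sampling methods introduced in this diff (the trailing boolean selects the eigendecomposition path instead of the Cholesky factorization):

```java
import org.tribuo.math.distributions.MultivariateNormalDistribution;
import org.tribuo.math.la.DenseVector;

public class MultivariateNormalExample {
    public static void main(String[] args) {
        double[] mean = new double[]{2.5, 2.5};
        double[][] covariance = new double[][]{{1.0, 0.5}, {0.5, 1.0}};

        // true selects the eigendecomposition route for building the sampling covariance.
        MultivariateNormalDistribution mvn =
                new MultivariateNormalDistribution(mean, covariance, 12345L, true);

        double[] arraySample = mvn.sampleArray();     // replaces Commons Math sample()
        DenseVector vectorSample = mvn.sampleVector();

        System.out.println(java.util.Arrays.toString(arraySample));
        System.out.println(vectorSample);
    }
}
```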
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * A package of statistical distributions. + */ +package org.tribuo.math.distributions; \ No newline at end of file diff --git a/Math/src/main/java/org/tribuo/math/la/DenseMatrix.java b/Math/src/main/java/org/tribuo/math/la/DenseMatrix.java index 9e4b306ad..b7c2a7b2a 100644 --- a/Math/src/main/java/org/tribuo/math/la/DenseMatrix.java +++ b/Math/src/main/java/org/tribuo/math/la/DenseMatrix.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,18 +16,28 @@ package org.tribuo.math.la; +import com.oracle.labs.mlrg.olcut.util.SortUtil; import org.tribuo.math.util.VectorNormalizer; import java.util.Arrays; +import java.util.List; import java.util.NoSuchElementException; import java.util.Objects; +import java.util.Optional; import java.util.function.DoubleUnaryOperator; +import java.util.logging.Logger; /** * A dense matrix, backed by a primitive array. */ public class DenseMatrix implements Matrix { private static final long serialVersionUID = 1L; + private static final Logger logger = Logger.getLogger(DenseMatrix.class.getName()); + + /** + * Tolerance for non-zero diagonal values in the factorizations. + */ + public static final double FACTORIZATION_TOLERANCE = 1e-14; private static final double DELTA = 1e-10; @@ -78,8 +88,16 @@ public DenseMatrix(Matrix other) { this.dim1 = other.getDimension1Size(); this.dim2 = other.getDimension2Size(); this.values = new double[dim1][dim2]; - for (MatrixTuple t : other) { - this.values[t.i][t.j] = t.value; + if (other instanceof DenseMatrix) { + for (int i = 0; i < dim1; i++) { + for (int j = 0; j < dim2; j++) { + this.values[i][j] = other.get(i,j); + } + } + } else { + for (MatrixTuple t : other) { + this.values[t.i][t.j] = t.value; + } } this.shape = new int[]{dim1,dim2}; this.numElements = dim1*dim2; @@ -119,6 +137,30 @@ public static DenseMatrix createDenseMatrix(double[][] values) { return new DenseMatrix(newValues); } + /** + * Constructs a new DenseMatrix copying the values from the supplied vectors. + *
<p>
+ * Throws {@link IllegalArgumentException} if the supplied vectors are ragged (i.e., are not all the same size). + * @param vectors The vectors to coalesce. + * @return A new dense matrix. + */ + public static DenseMatrix createDenseMatrix(SGDVector[] vectors) { + if (vectors == null || vectors.length == 0) { + throw new IllegalArgumentException("Invalid vector array."); + } + double[][] newValues = new double[vectors.length][]; + + int size = vectors[0].size(); + for (int i = 0; i < vectors.length; i++) { + if (vectors[i].size() != size) { + throw new IllegalArgumentException("Expected size " + size + " but found size " + vectors[i].size() + " at index " + i); + } + newValues[i] = vectors[i].toArray(); + } + + return new DenseMatrix(newValues); + } + @Override public int[] getShape() { return shape; @@ -184,7 +226,7 @@ public DenseVector gatherAcrossDim1(int[] elements) { double[] outputValues = new double[dim2]; for (int i = 0; i < elements.length; i++) { - outputValues[i] = values[elements[i]][i]; + outputValues[i] = get(elements[i],i); } return new DenseVector(outputValues); @@ -202,7 +244,7 @@ public DenseVector gatherAcrossDim2(int[] elements) { double[] outputValues = new double[dim1]; for (int i = 0; i < elements.length; i++) { - outputValues[i] = values[i][elements[i]]; + outputValues[i] = get(i,elements[i]); } return new DenseVector(outputValues); @@ -246,7 +288,7 @@ public boolean equals(Object o) { @Override public int hashCode() { int result = Objects.hash(dim1, dim2, numElements); - result = 31 * result + Arrays.hashCode(values); + result = 31 * result + Arrays.deepHashCode(values); result = 31 * result + Arrays.hashCode(getShape()); return result; } @@ -281,7 +323,7 @@ public DenseVector leftMultiply(SGDVector input) { // If it's sparse we iterate the tuples for (VectorTuple tuple : input) { for (int i = 0; i < output.length; i++) { - output[i] += values[i][tuple.index] * tuple.value; + output[i] += get(i,tuple.index) * tuple.value; } } } @@ -307,7 +349,7 @@ public DenseVector rightMultiply(SGDVector input) { // If it's sparse we iterate the tuples for (VectorTuple tuple : input) { for (int i = 0; i < output.length; i++) { - output[i] += values[tuple.index][i] * tuple.value; + output[i] += get(tuple.index,i) * tuple.value; } } } @@ -501,7 +543,7 @@ public DenseVector rowSum() { for (int i = 0; i < dim1; i++) { double tmp = 0.0; for (int j = 0; j < dim2; j++) { - tmp += values[i][j]; + tmp += get(i,j); } rowSum[i] = tmp; } @@ -676,6 +718,9 @@ public int numActiveElements(int row) { @Override public DenseVector getRow(int i) { + if (i < 0 || i > dim1) { + throw new IllegalArgumentException("Invalid row index, must be [0,"+dim1+"), received " + i); + } return new DenseVector(values[i]); } @@ -684,24 +729,54 @@ public DenseVector getRow(int i) { * @param index The column index. * @return A copy of the column. */ + @Override public DenseVector getColumn(int index) { + if (index < 0 || index > dim2) { + throw new IllegalArgumentException("Invalid column index, must be [0,"+dim2+"), received " + index); + } double[] output = new double[dim1]; for (int i = 0; i < dim1; i++) { - output[i] = values[i][index]; + output[i] = get(i,index); } return new DenseVector(output); } + /** + * Sets the column to the supplied vector value. + * @param index The column to set. + * @param vector The vector to write. 
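A small sketch of the column accessors described here, assuming the getColumn/setColumn signatures in this hunk (setColumn validates both the index and the vector length):

```java
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;

public class ColumnAccessExample {
    public static void main(String[] args) {
        DenseMatrix m = DenseMatrix.createDenseMatrix(new double[][]{{1, 2}, {3, 4}});

        DenseVector secondColumn = m.getColumn(1);  // copy of [2.0, 4.0]
        m.setColumn(0, DenseVector.createDenseVector(new double[]{5, 6}));

        System.out.println(secondColumn);
        System.out.println(m.get(0, 0) + ", " + m.get(1, 0));  // 5.0, 6.0
    }
}
```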
+ */ + public void setColumn(int index, SGDVector vector) { + if (index < 0 || index > dim2) { + throw new IllegalArgumentException("Invalid column index, must be [0,"+dim2+"), received " + index); + } + if (vector.size() == dim1) { + if (vector instanceof DenseVector) { + for (int i = 0; i < dim1; i++) { + values[i][index] = vector.get(i); + } + } else { + for (VectorTuple t : vector) { + values[t.index][index] = t.value; + } + } + } else { + throw new IllegalArgumentException("Vector size mismatch, expected " + dim1 + " found " + vector.size()); + } + } + /** * Calculates the sum of the specified row. * @param rowIndex The index of the row to sum. * @return The row sum. */ public double rowSum(int rowIndex) { - double[] row = values[rowIndex]; + if (rowIndex < 0 || rowIndex > dim1) { + throw new IllegalArgumentException("Invalid row index, must be [0,"+dim1+"), received " + rowIndex); + } double sum = 0d; - for (int i = 0; i < row.length; i++) { - sum += row[i]; + for (int i = 0; i < dim2; i++) { + sum += get(rowIndex,i); } return sum; } @@ -712,9 +787,12 @@ public double rowSum(int rowIndex) { * @return The column sum. */ public double columnSum(int columnIndex) { + if (columnIndex < 0 || columnIndex > dim2) { + throw new IllegalArgumentException("Invalid column index, must be [0,"+dim2+"), received " + columnIndex); + } double sum = 0d; for (int i = 0; i < dim1; i++) { - sum += values[i][columnIndex]; + sum += get(i,columnIndex); } return sum; } @@ -731,6 +809,422 @@ public double twoNorm() { return Math.sqrt(output); } + /** + * Returns a copy of this matrix as a 2d array. + * @return A copy of this matrix. + */ + public double[][] toArray() { + double[][] copy = new double[dim1][]; + for (int i = 0; i < dim1; i++) { + copy[i] = Arrays.copyOf(values[i],dim2); + } + return copy; + } + + /** + * Is this a square matrix? + * @return True if the matrix is square. + */ + public boolean isSquare() { + return dim1 == dim2; + } + + /** + * Returns true if this matrix is square and symmetric. + * @return True if the matrix is symmetric. + */ + public boolean isSymmetric() { + if (!isSquare()) { + return false; + } else { + for (int i = 0; i < dim1; i++) { + for (int j = i + 1; j < dim1; j++) { + if (Double.compare(get(i,j),get(j,i)) != 0) { + return false; + } + } + } + return true; + } + } + + /** + * Computes the Cholesky factorization of a positive definite matrix. + *
<p>
+ * If the matrix is not symmetric or positive definite then it returns an empty optional. + * @return The Cholesky factorization or an empty optional. + */ + public Optional choleskyFactorization() { + if (!isSymmetric()) { + logger.fine("Returning empty optional as matrix is not symmetric"); + return Optional.empty(); + } else { + // Copy the matrix first + DenseMatrix chol = new DenseMatrix(this); + double[][] cholMatrix = chol.values; + + // Compute factorization + for (int i = 0; i < dim1; i++) { + for (int j = i; j < dim1; j++) { + double sum = cholMatrix[i][j]; + for (int k = i - 1; k >= 0; k--) { + sum -= cholMatrix[i][k] * cholMatrix[j][k]; + } + if (i == j) { + if (sum <= FACTORIZATION_TOLERANCE) { + // Matrix is not positive definite as it has a negative diagonal element. + logger.fine("Returning empty optional as matrix is not positive definite"); + return Optional.empty(); + } else { + cholMatrix[i][i] = Math.sqrt(sum); + } + } else { + cholMatrix[j][i] = sum / cholMatrix[i][i]; + } + } + } + + // Zero out the upper triangle + for (int i = 0; i < dim1; i++) { + for (int j = 0; j < i; j++) { + cholMatrix[j][i] = 0.0; + } + } + + return Optional.of(new CholeskyFactorization(chol)); + } + } + + /** + * Computes the LU factorization of a square matrix. + *
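A sketch of how the Cholesky factorization above is intended to be used, relying only on the Optional-returning API and the accessors defined further down in this diff; the factor L satisfies A = L·Lᵀ and can be reused to solve linear systems:

```java
import java.util.Optional;

import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;

public class CholeskyExample {
    public static void main(String[] args) {
        // A symmetric positive definite matrix.
        DenseMatrix a = DenseMatrix.createDenseMatrix(new double[][]{{4, 1}, {1, 3}});

        Optional<DenseMatrix.CholeskyFactorization> chol = a.choleskyFactorization();
        if (chol.isPresent()) {
            DenseMatrix l = chol.get().lMatrix();   // lower triangular factor
            // Solves A x = b by forward then backward substitution.
            DenseVector x = chol.get().solve(DenseVector.createDenseVector(new double[]{1, 2}));
            System.out.println(l);
            System.out.println(x);
        } else {
            System.out.println("Not symmetric positive definite");
        }
    }
}
```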
<p>
+ * If the matrix is singular or not square it returns an empty optional. + * @return The LU factorization or an empty optional. + */ + public Optional luFactorization() { + if (!isSquare()) { + logger.fine("Returning empty optional as matrix is not square"); + return Optional.empty(); + } else { + // Copy the matrix first & init variables + DenseMatrix lu = new DenseMatrix(this); + double[][] luMatrix = lu.values; + int[] permutation = new int[dim1]; + boolean oddSwaps = false; + for (int i = 0; i < dim1; i++) { + permutation[i] = i; + } + + // Decompose matrix + for (int i = 0; i < dim1; i++) { + double max = 0.0; + int maxIdx = i; + + // Find max element + for (int k = i; k < dim1; k++) { + double cur = Math.abs(luMatrix[k][i]); + if (cur > max) { + max = cur; + maxIdx = k; + } + } + + if (max < FACTORIZATION_TOLERANCE) { + // zero diagonal element, matrix is singular + logger.fine("Returning empty optional as matrix is singular"); + return Optional.empty(); + } + + // Pivot matrix if necessary + if (maxIdx != i) { + // Update permutation array + int tmpIdx = permutation[maxIdx]; + permutation[maxIdx] = permutation[i]; + permutation[i] = tmpIdx; + oddSwaps = !oddSwaps; + + // Swap rows + double[] tmpRow = luMatrix[maxIdx]; + luMatrix[maxIdx] = luMatrix[i]; + luMatrix[i] = tmpRow; + } + + // Eliminate row + for (int j = i + 1; j < dim1; j++) { + // Rescale lower triangle + luMatrix[j][i] /= luMatrix[i][i]; + + for (int k = i + 1; k < dim1; k++) { + luMatrix[j][k] -= luMatrix[j][i] * luMatrix[i][k]; + } + } + } + + // Split into two matrices + DenseMatrix l = new DenseMatrix(lu); + DenseMatrix u = new DenseMatrix(lu); + + // Zero lower triangle of u + for (int i = 0; i < dim1; i++) { + Arrays.fill(u.values[i],0,i,0.0); + } + + // Zero upper triangle of l and set diagonal to 1. + for (int i = 0; i < dim1; i++) { + for (int j = 0; j <= i; j++) { + if (i == j) { + l.values[i][j] = 1.0; + } else { + l.values[j][i] = 0.0; + } + } + } + + return Optional.of(new LUFactorization(l,u,permutation,oddSwaps)); + } + } + + /** + * Eigen decomposition of a symmetric matrix. + *
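A hedged sketch of the LU factorization above along the same lines, using only the accessors defined later in this diff; partial pivoting records the row permutation alongside the triangular factors:

```java
import java.util.Arrays;
import java.util.Optional;

import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;

public class LUExample {
    public static void main(String[] args) {
        // Square but not symmetric, so Cholesky does not apply.
        DenseMatrix a = DenseMatrix.createDenseMatrix(new double[][]{{0, 2}, {3, 1}});

        Optional<DenseMatrix.LUFactorization> lu = a.luFactorization();
        if (lu.isPresent()) {
            DenseMatrix lower = lu.get().lower();           // unit lower triangular
            DenseMatrix upper = lu.get().upper();           // upper triangular
            int[] permutation = lu.get().permutationArr();  // row swaps from pivoting

            DenseVector x = lu.get().solve(DenseVector.createDenseVector(new double[]{2, 4}));

            System.out.println(Arrays.toString(permutation));
            System.out.println(lower + "\n" + upper + "\n" + x);
        }
    }
}
```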
<p>
+ * Non-symmetric matrices return an empty Optional as they may have complex eigenvalues, and + * any matrix which exceeds the default number of QL iterations in the decomposition also + * returns an empty Optional. + * @return The eigen decomposition of a symmetric matrix, or an empty optional if it's not symmetric. + */ + public Optional eigenDecomposition() { + if (!isSymmetric()) { + logger.fine("Returning empty optional as matrix is not symmetric"); + return Optional.empty(); + } else { + // Copy the matrix first & init variables + DenseMatrix transform = new DenseMatrix(this); + double[][] transformValues = transform.values; + + // arrays for holding the tridiagonal form. + double[] diagonal = new double[dim1]; + double[] offDiagonal = new double[dim1]; // first element is zero + + // First tridiagonalize the matrix via a Householder reduction + + // Copy last row into diagonal + System.arraycopy(transformValues[dim1 - 1], 0, diagonal, 0, dim1); + + // Iterate up the matrix, reducing it + for (int i = dim1-1; i > 0; i--) { + // Accumulate scale along current diagonal + double scale = 0.0; + for (int k = 0; k < i; k++) { + scale += Math.abs(diagonal[k]); + } + + double diagElement = 0.0; + if (scale == 0.0) { + offDiagonal[i] = 0.0; // if scale is zero then diagonal[0...i-1] = 0 + for (int j = 0; j < i; j++) { + // copy in new row + diagonal[j] = transformValues[i-1][j]; + // zero row & column + transformValues[i][j] = 0.0; + transformValues[j][i] = 0.0; + } + } else { + // Generate Householder vector + for (int k = 0; k < i; k++) { + final double tmp = diagonal[k] / scale; + diagElement += tmp * tmp; + diagonal[k] = tmp; + offDiagonal[k] = 0; + } + final double nextDiag = diagonal[i-1]; + final double offDiag = nextDiag >= 0 ? -Math.sqrt(diagElement) : Math.sqrt(diagElement); + + offDiagonal[i] = scale * offDiag; + diagElement -= offDiag * nextDiag; + diagonal[i-1] = nextDiag - offDiag; + + // Transform the remaining vectors + for (int j = 0; j < i; j++) { + final double transDiag = diagonal[j]; + // Write back to matrix + transformValues[j][i] = transDiag; + double transOffDiag = offDiagonal[j] + transformValues[j][j] * transDiag; + + // Sum remaining column and update off diagonals + for (int k = j + 1; k < i; k++) { + double tmp = transformValues[k][j]; + transOffDiag += tmp * diagonal[k]; + offDiagonal[k] += tmp * transDiag; + } + offDiagonal[j] = transOffDiag; + } + + double scaledElementSum = 0.0; + for (int j = 0; j < i; j++) { + final double tmp = offDiagonal[j] / diagElement; + offDiagonal[j] = tmp; + scaledElementSum += tmp * diagonal[j]; + } + final double offDiagScalingFactor = scaledElementSum / (diagElement + diagElement); + for (int j = 0; j < i; j++) { + offDiagonal[j] -= offDiagScalingFactor * diagonal[j]; + } + + for (int j = 0; j < i; j++) { + final double tmpDiag = diagonal[j]; + final double tmpOffDiag = offDiagonal[j]; + for (int k = j; k < i; k++) { + transformValues[k][j] -= (tmpDiag * offDiagonal[k]) + (tmpOffDiag * diagonal[k]); + } + diagonal[j] = transformValues[i-1][j]; + transformValues[i][j] = 0.0; + } + } + diagonal[i] = diagElement; + } + + // Finish transformation to tridiagonal + int dimMinusOne = dim1-1; + for (int i = 0; i < dimMinusOne; i++) { + transformValues[dimMinusOne][i] = transformValues[i][i]; + transformValues[i][i] = 1.0; + final int nextIdx = i + 1; + final double nextDiag = diagonal[nextIdx]; + if (nextDiag != 0.0) { + // Recompute diagonal and rescale matrix + for (int k = 0; k < nextIdx; k++) { + diagonal[k] = 
transformValues[k][nextIdx] / nextDiag; + } + for (int j = 0; j < nextIdx; j++) { + double scaleAccumulator = 0.0; + for (int k = 0; k < nextIdx; k++) { + scaleAccumulator += transformValues[k][nextIdx] * transformValues[k][j]; + } + for (int k = 0; k < nextIdx; k++) { + transformValues[k][j] -= scaleAccumulator * diagonal[k]; + } + } + // Zero lower column + for (int j = 0; j < nextIdx; j++) { + transformValues[j][nextIdx] = 0.0; + } + } + } + for (int j = 0; j < dim1; j++) { + diagonal[j] = transformValues[dimMinusOne][j]; + transformValues[dimMinusOne][j] = 0.0; + } + transformValues[dimMinusOne][dimMinusOne] = 1.0; + offDiagonal[0] = 0.0; + + // Copy to dense vector/matrix for storage in the returned object as we're going to mutate these arrays + DenseVector diagVector = DenseVector.createDenseVector(diagonal); + DenseVector offDiagVector = DenseVector.createDenseVector(offDiagonal); + DenseMatrix householderMatrix = new DenseMatrix(transform); + + // Then compute eigen vectors & values using an iterated tridiagonal QL algorithm + + // Setup constants + final int maxItr = 35; // Maximum number of QL iterations before giving up and returning empty optional. + final double eps = Double.longBitsToDouble(4372995238176751616L); // Math.pow(2,-52) + + // Copy off diagonal up for ease of use + System.arraycopy(offDiagonal, 1, offDiagonal, 0, dimMinusOne); + offDiagonal[dimMinusOne] = 0.0; + + double diagAccum = 0.0; + double largestDiagSum = 0.0; + for (int i = 0; i < dim1; i++) { + largestDiagSum = Math.max(largestDiagSum, Math.abs(diagonal[i]) + Math.abs(offDiagonal[i])); + final double testVal = largestDiagSum*eps; + // Find small value to partition the matrix + int idx = i; + while (idx < dim1) { + if (Math.abs(offDiagonal[idx]) <= testVal) { + break; + } + idx++; + } + + // if we didn't break out of the loop the diagonal value is an eigenvalue + // otherwise perform QL iterations + if (idx > i) { + int iter = 0; + do { + if (iter > maxItr) { + // Exceeded QL iteration count; + logger.fine("Exceeded QL iteration count in eigenDecomposition"); + return Optional.empty(); + } else { + iter++; + } + + // Compute shift + final double curDiag = diagonal[i]; + final double shift = (diagonal[i+1] - curDiag) / (2 * offDiagonal[i]); + final double shiftLength = shift < 0 ? 
-Math.hypot(shift, 1.0) : Math.hypot(shift, 1.0); + diagonal[i] = offDiagonal[i] / (shift + shiftLength); + diagonal[i+1] = offDiagonal[i] * (shift + shiftLength); + + final double nextDiag = diagonal[i+1]; + final double diagShift = curDiag - diagonal[i]; + for (int j = i + 2; j < dim1; j++) { + diagonal[j] -= diagShift; + } + diagAccum += diagShift; + + // Compute implicit QL + double partitionDiag = diagonal[idx]; + final double oldOffDiag = offDiagonal[i+1]; + double c = 1.0, c2 = 1.0, c3 = 1.0; + double s = 0.0, prevS = 0.0; + for (int j = idx-1; j >= i; j--) { + c3 = c2; + c2 = c; + prevS = s; + final double scaledOffDiag = c * offDiagonal[j]; + final double scaledDiag = c * partitionDiag; + final double dist = Math.hypot(partitionDiag, offDiagonal[j]); + offDiagonal[j+1] = s * dist; + s = offDiagonal[j] / dist; + c = partitionDiag / dist; + partitionDiag = (c * diagonal[j]) - (s * scaledOffDiag); + diagonal[j+1] = scaledDiag + s * ((c * scaledOffDiag) + (s * diagonal[j])); + + // Update eigenvectors + for (int k = 0; k < dim1; k++) { + final double[] row = transformValues[k]; + final double tmp = row[j+1]; + row[j+1] = (s * row[j]) + (c * tmp); + row[j] = (c * row[j]) - (s * tmp); + } + } + partitionDiag = -s * prevS * c3 * oldOffDiag * offDiagonal[i] / nextDiag; + offDiagonal[i] = s * partitionDiag; + diagonal[i] = c * partitionDiag; + } while (Math.abs(offDiagonal[i]) > testVal); + } + + diagonal[i] += diagAccum; + offDiagonal[i] = 0.0; + } + + // Sort eigenvalues and eigenvectors + int[] indices = SortUtil.argsort(diagonal, false); + double[] eigenValues = new double[dim1]; + double[][] eigenVectors = new double[dim1][dim1]; + + for (int i = 0; i < indices.length; i++) { + eigenValues[i] = diagonal[indices[i]]; + for (int j = 0; j < dim1; j++) { + eigenVectors[j][i] = transformValues[j][indices[i]]; + } + } + + return Optional.of(new EigenDecomposition(new DenseVector(eigenValues),new DenseMatrix(eigenVectors),diagVector,offDiagVector,householderMatrix)); + } + } + @Override public String toString() { StringBuilder buffer = new StringBuilder(); @@ -776,26 +1270,78 @@ public void normalizeRows(VectorNormalizer normalizer) { } /** - * Returns the dense vector containing each column sum. + * Returns a dense vector containing each column sum. * @return The column sums. */ public DenseVector columnSum() { double[] columnSum = new double[dim2]; for (int i = 0; i < dim1; i++) { for (int j = 0; j < dim2; j++) { - columnSum[j] += values[i][j]; + columnSum[j] += get(i,j); } } return new DenseVector(columnSum); } + /** + * Returns a new DenseMatrix containing a copy of the selected columns. + *
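A minimal sketch of the eigendecomposition introduced above, checking the defining property A·v = λ·v for the leading eigenpair (the eigenvalues are returned in descending order):

```java
import java.util.Optional;

import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;

public class EigenExample {
    public static void main(String[] args) {
        DenseMatrix a = DenseMatrix.createDenseMatrix(new double[][]{{2, 1}, {1, 2}});

        Optional<DenseMatrix.EigenDecomposition> eig = a.eigenDecomposition();
        if (eig.isPresent()) {
            double lambda = eig.get().eigenvalues().get(0);  // largest eigenvalue, 3.0 here
            DenseVector v = eig.get().getEigenVector(0);
            DenseVector av = a.leftMultiply(v);              // A * v

            // Each entry of A*v should match lambda times the corresponding entry of v.
            for (int i = 0; i < v.size(); i++) {
                System.out.println(av.get(i) + " ~ " + (lambda * v.get(i)));
            }
        }
    }
}
```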
<p>
+ * Throws {@link IllegalArgumentException} if any column index is invalid or the array is null/empty. + * @param columnIndices The column indices + * @return The submatrix comprising the selected columns. + */ + public DenseMatrix selectColumns(int[] columnIndices) { + if (columnIndices == null || columnIndices.length == 0) { + throw new IllegalArgumentException("Invalid column indices."); + } + DenseMatrix returnVal = new DenseMatrix(dim1,columnIndices.length); + + for (int i = 0; i < dim1; i++) { + for (int j = 0; j < columnIndices.length; j++) { + int curIdx = columnIndices[j]; + if (curIdx < 0 || curIdx >= dim2) { + throw new IllegalArgumentException("Invalid column index, expected [0, " + dim2 +"), found " + curIdx); + } + returnVal.values[i][j] = get(i,curIdx); + } + } + + return returnVal; + } + + /** + * Returns a new DenseMatrix containing a copy of the selected columns. + *
<p>
+ * Throws {@link IllegalArgumentException} if any column index is invalid or the list is null/empty. + * @param columnIndices The column indices + * @return The submatrix comprising the selected columns. + */ + public DenseMatrix selectColumns(List columnIndices) { + if (columnIndices == null || columnIndices.isEmpty()) { + throw new IllegalArgumentException("Invalid column indices."); + } + DenseMatrix returnVal = new DenseMatrix(dim1,columnIndices.size()); + + for (int i = 0; i < dim1; i++) { + for (int j = 0; j < columnIndices.size(); j++) { + int curIdx = columnIndices.get(j); + if (curIdx < 0 || curIdx >= dim2) { + throw new IllegalArgumentException("Invalid column index, expected [0, " + dim2 +"), found " + curIdx); + } + returnVal.values[i][j] = get(i,curIdx); + } + } + + return returnVal; + } + private class DenseMatrixIterator implements MatrixIterator { private final DenseMatrix matrix; private final MatrixTuple tuple; private int i; private int j; - public DenseMatrixIterator(DenseMatrix matrix) { + DenseMatrixIterator(DenseMatrix matrix) { this.matrix = matrix; this.tuple = new MatrixTuple(); this.i = 0; @@ -831,4 +1377,481 @@ public MatrixTuple next() { } } + /** + * The output of a successful Cholesky factorization. + *
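A short sketch of the two new selectColumns overloads shown above, assuming the array and List variants behave as documented (both return copies):

```java
import java.util.Arrays;

import org.tribuo.math.la.DenseMatrix;

public class SelectColumnsExample {
    public static void main(String[] args) {
        DenseMatrix m = DenseMatrix.createDenseMatrix(new double[][]{{1, 2, 3}, {4, 5, 6}});

        DenseMatrix outerColumns = m.selectColumns(new int[]{0, 2});  // 2x2 copy of columns 0 and 2
        DenseMatrix middleColumn = m.selectColumns(Arrays.asList(1)); // 2x1 copy of column 1

        System.out.println(outerColumns);
        System.out.println(middleColumn);
    }
}
```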
<p>
+ * Essentially wraps a {@link DenseMatrix}, but has additional + * operations which allow more efficient implementations when the + * matrix is known to be the result of a Cholesky factorization. + *
<p>
+ * Mutating the wrapped matrix will cause undefined behaviour in the methods + * of this class. + *
<p>
+ * May be refactored into a record in the future. + */ + public static final class CholeskyFactorization implements Matrix.Factorization { + private final DenseMatrix lMatrix; + + CholeskyFactorization(DenseMatrix lMatrix) { + this.lMatrix = lMatrix; + } + + /** + * The lower triangular factorized matrix. + * @return The factorization matrix. + */ + public DenseMatrix lMatrix() { + return lMatrix; + } + + @Override + public int dim1() { + return lMatrix.dim1; + } + + @Override + public int dim2() { + return lMatrix.dim2; + } + + /** + * Compute the matrix determinant of the factorized matrix. + * @return The matrix determinant. + */ + @Override + public double determinant() { + double det = 0.0; + for (int i = 0; i < lMatrix.dim1; i++) { + det *= lMatrix.values[i][i] * lMatrix.values[i][i]; + } + return det; + } + + /** + * Solves a system of linear equations A * b = y, where y is the input vector, + * A is the matrix which produced this Cholesky factorization, and b is the returned value. + * @param vector The input vector y. + * @return The vector b. + */ + @Override + public DenseVector solve(SGDVector vector) { + if (vector.size() != lMatrix.dim1) { + throw new IllegalArgumentException("Size mismatch, expected " + lMatrix.dim1 + ", received " + vector.size()); + } + final double[] vectorArr = vector.toArray(); + final double[] output = new double[lMatrix.dim1]; + + // Solve matrix . y = vector + for (int i = 0; i < lMatrix.dim1; i++) { + double sum = vectorArr[i]; + for (int j = i-1; j >= 0; j--) { + sum -= lMatrix.values[i][j] * output[j]; + } + output[i] = sum / lMatrix.values[i][i]; + } + + // Solve matrix^T . output = y + for (int i = lMatrix.dim1-1; i >= 0; i--) { + double sum = output[i]; + for (int j = i+1; j < lMatrix.dim1; j++) { + sum -= lMatrix.values[j][i] * output[j]; + } + output[i] = sum / lMatrix.values[i][i]; + } + + return new DenseVector(output); + } + + /** + * Solves the system A * X = B, where B is the input matrix, and A is the matrix which + * produced this Cholesky factorization. + * @param matrix The input matrix B. + * @return The matrix X. + */ + @Override + public DenseMatrix solve(Matrix matrix) { + if (matrix.getDimension1Size() != lMatrix.dim1) { + throw new IllegalArgumentException("Size mismatch, expected " + lMatrix.dim1 + ", received " + matrix.getDimension1Size()); + } + final int outputDim1 = lMatrix.dim1; + final int outputDim2 = matrix.getDimension2Size(); + final DenseMatrix output = new DenseMatrix(matrix); + final double[][] outputArr = output.values; + + // Solve L.Y = B + for (int i = 0; i < outputDim1; i++) { + for (int j = 0; j < outputDim2; j++) { + for (int k = 0; k < i; k++) { + outputArr[i][j] -= outputArr[k][j] * lMatrix.values[i][k]; + } + // scale by diagonal + outputArr[i][j] /= lMatrix.values[i][i]; + } + } + + // Solve L^T.X = Y + for (int i = outputDim1 - 1; i >= 0; i--) { + for (int j = 0; j < outputDim2; j++) { + for (int k = i + 1; k < outputDim2; k++) { + outputArr[i][j] -= outputArr[k][j] * lMatrix.values[k][i]; + } + // scale by diagonal + outputArr[i][j] /= lMatrix.values[i][i]; + } + } + + return output; + } + } + + /** + * The output of a successful LU factorization. + *
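Because CholeskyFactorization.solve(Matrix) accepts an arbitrary right-hand side, one possible use (a sketch, not part of this diff) is inverting a positive definite matrix by solving against the identity:

```java
import org.tribuo.math.la.DenseMatrix;

public class CholeskyInverseExample {
    public static void main(String[] args) {
        DenseMatrix a = DenseMatrix.createDenseMatrix(new double[][]{{4, 1}, {1, 3}});
        DenseMatrix identity = DenseMatrix.createDenseMatrix(new double[][]{{1, 0}, {0, 1}});

        // Solving A * X = I yields X = A^-1; the matrix here is known to be SPD
        // so the Optional is unwrapped directly.
        DenseMatrix inverse = a.choleskyFactorization().get().solve(identity);
        System.out.println(inverse);
    }
}
```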
<p>
+ * Essentially wraps a pair of {@link DenseMatrix}, but has additional + * operations which allow more efficient implementations when the + * matrices are known to be the result of a LU factorization. + *
<p>
+ * Mutating the wrapped matrices will cause undefined behaviour in the methods + * of this class. + *
<p>
+ * May be refactored into a record in the future. + */ + public static final class LUFactorization implements Matrix.Factorization { + private final DenseMatrix lower; + private final DenseMatrix upper; + private final int[] permutationArr; + private final Matrix permutationMatrix; + private final boolean oddSwaps; + + LUFactorization(DenseMatrix lower, DenseMatrix upper, int[] permutationArr, boolean oddSwaps) { + this.lower = lower; + this.upper = upper; + this.permutationArr = permutationArr; + SparseVector[] vecs = new SparseVector[permutationArr.length]; + for (int i = 0; i < vecs.length; i++) { + vecs[i] = new SparseVector(lower.dim1,new int[]{permutationArr[i]}, new double[]{1.0}); + } + this.permutationMatrix = DenseSparseMatrix.createFromSparseVectors(vecs); + this.oddSwaps = oddSwaps; + } + + /** + * The lower triangular matrix, with ones on the diagonal. + * @return The lower triangular matrix. + */ + public DenseMatrix lower() { + return lower; + } + + /** + * The upper triangular matrix. + * @return The upper triangular matrix. + */ + public DenseMatrix upper() { + return upper; + } + + /** + * The row permutations applied to get this factorization. + * @return The permutations. + */ + public int[] permutationArr() { + return permutationArr; + } + + /** + * The row permutations stored as a sparse matrix of ones. + * @return A sparse matrix version of the permutations. + */ + public Matrix permutationMatrix() { + return permutationMatrix; + } + + /** + * Is there an odd number of row swaps (used to compute the determinant). + * @return True if there is an odd number of swaps. + */ + public boolean oddSwaps() { + return oddSwaps; + } + + @Override + public int dim1() { + return permutationArr.length; + } + + @Override + public int dim2() { + return permutationArr.length; + } + + /** + * Compute the matrix determinant of the factorized matrix. + * @return The matrix determinant. + */ + @Override + public double determinant() { + double det = 0.0; + for (int i = 0; i < upper.dim1; i++) { + det *= upper.values[i][i]; + } + if (oddSwaps) { + return -det; + } else { + return det; + } + } + + /** + * Solves a system of linear equations A * b = y, where y is the input vector, + * A is the matrix which produced this LU factorization, and b is the returned value. + * @param vector The input vector y. + * @return The vector b. + */ + @Override + public DenseVector solve(SGDVector vector) { + if (vector.size() != lower.dim1) { + throw new IllegalArgumentException("Size mismatch, expected " + lower.dim1 + ", received " + vector.size()); + } + // Apply permutation to input + final double[] vectorArr = vector.toArray(); + final double[] output = new double[vectorArr.length]; + for (int i = 0; i < permutationArr.length; i++) { + output[i] = vectorArr[permutationArr[i]]; + + // Solve L * Y = b + for (int k = 0; k < i; k++) { + output[i] -= lower.values[i][k] * output[k]; + } + } + + // Solve U * X = Y + for (int i = permutationArr.length-1; i >= 0; i--) { + for (int k = i + 1; k < permutationArr.length; k++) { + output[i] -= upper.values[i][k] * output[k]; + } + output[i] /= upper.values[i][i]; + } + + return new DenseVector(output); + } + + /** + * Solves the system A * X = Y, where Y is the input matrix, and A is the matrix which + * produced this LU factorization. + * @param matrix The input matrix Y. + * @return The matrix X. 
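For reference, the determinant exposed by an LU factorization with partial pivoting is the product of the diagonal of U, negated when an odd number of row swaps occurred, which is what the oddSwaps flag above records:

```latex
\det(A) \;=\; (-1)^{s}\,\prod_{i} U_{ii},
\qquad s = \text{number of row swaps during pivoting}
```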
+ */ + @Override + public DenseMatrix solve(Matrix matrix) { + if (matrix.getDimension1Size() != lower.dim1) { + throw new IllegalArgumentException("Size mismatch, expected " + lower.dim1 + ", received " + matrix.getDimension1Size()); + } + final int outputDim1 = lower.dim1; + final int outputDim2 = matrix.getDimension2Size(); + final double[][] output = new double[lower.dim1][]; + + // Apply permutation and copy each permuted row over once + for (int i = 0; i < outputDim1; i++) { + output[i] = matrix.getRow(permutationArr[i]).toArray(); + } + + // Solve LY = B + for (int i = 0; i < outputDim1; i++) { + for (int j = i + 1; j < outputDim1; j++) { + for (int k = 0; k < outputDim2; k++) { + output[j][k] -= output[i][k] * lower.values[j][i]; + } + } + } + + // Solve UX = Y + for (int i = outputDim1 - 1; i >= 0; i--) { + // scale by diagonal + for (int j = 0; j < outputDim2; j++) { + output[i][j] /= upper.values[i][i]; + } + for (int j = 0; j < i; j++) { + for (int k = 0; k < outputDim2; k++) { + output[j][k] -= output[i][k] * upper.values[j][i]; + } + } + } + + return new DenseMatrix(output); + } + } + + /** + * The output of a successful eigen decomposition. + *
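// Illustrative sketch, not from the patch: how the LUFactorization above might be used to solve
// A * x = y and to read off the determinant. It relies on the DenseMatrix.luFactorization() accessor
// exercised by the tests later in this diff; solveWithLU is a hypothetical helper name.
// Assumed imports: java.util.Optional, org.tribuo.math.la.DenseMatrix, org.tribuo.math.la.DenseVector.
static DenseVector solveWithLU(DenseMatrix a, DenseVector y) {
    Optional<DenseMatrix.LUFactorization> luOpt = a.luFactorization();
    if (!luOpt.isPresent()) {
        throw new IllegalArgumentException("Matrix could not be LU factorized");
    }
    DenseMatrix.LUFactorization lu = luOpt.get();
    // P * A = L * U, so a solve is one row permutation plus two triangular substitutions.
    // For the 3x3 test matrix {{4,12,-16},{12,37,-43},{-16,-43,98}} used later in this diff, the
    // diagonal of U is {-16, 4.75, 0.4736...} and oddSwaps is true, so lu.determinant() is 36.
    return lu.solve(y);
}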

+ * Wraps a dense vector containing the eigenvalues and a dense matrix containing the eigenvectors as columns. + * Mutating these fields will cause undefined behaviour. + *

+ * Also has fields representing the tridiagonal form used as an intermediate step in the eigen decomposition. + *

+ * May be refactored into a record in the future. + */ + public static final class EigenDecomposition implements Matrix.Factorization { + // Eigen decomposition fields + private final DenseVector eigenvalues; + private final DenseMatrix eigenvectors; + + // Tridiagonal form fields + private final DenseVector diagonal; + private final DenseVector offDiagonal; + private final DenseMatrix householderMatrix; + + EigenDecomposition(DenseVector eigenvalues, DenseMatrix eigenvectors, DenseVector diagonal, DenseVector offDiagonal, DenseMatrix householderMatrix) { + this.eigenvalues = eigenvalues; + this.eigenvectors = eigenvectors; + this.diagonal = diagonal; + this.offDiagonal = offDiagonal; + this.householderMatrix = householderMatrix; + } + + /** + * The vector of eigenvalues, in descending order. + * @return The eigenvalues. + */ + public DenseVector eigenvalues() { + return eigenvalues; + } + + /** + * The eigenvectors for each eigenvalue, stored in the columns of the matrix. + * @return A matrix containing the eigenvectors as columns. + */ + public DenseMatrix eigenvectors() { + return eigenvectors; + } + + /** + * The diagonal vector of the tridiagonal form. + * @return The diagonal vector. + */ + public DenseVector diagonal() { + return diagonal; + } + + /** + * The off diagonal vector, with the first element set to zero. + * @return The off diagonal vector. + */ + public DenseVector offDiagonal() { + return offDiagonal; + } + + /** + * The Householder matrix produced during the tridiagonalisation. + * @return The Householder matrix. + */ + public DenseMatrix householderMatrix() { + return householderMatrix; + } + + @Override + public int dim1() { + return eigenvalues.size(); + } + + @Override + public int dim2() { + return eigenvalues.size(); + } + + /** + * Computes the determinant of the matrix which was decomposed. + *

+ * This is the product of the eigenvalues. + * @return The determinant. + */ + @Override + public double determinant() { + return eigenvalues.reduce(1.0,DoubleUnaryOperator.identity(), (a,b) -> a*b); + } + + /** + * Returns true if all the eigenvalues are positive. + * @return True if the eigenvalues are positive. + */ + public boolean positiveEigenvalues() { + return eigenvalues.reduce(true,DoubleUnaryOperator.identity(),(value, bool) -> bool && value > 0.0); + } + + /** + * Returns true if all the eigenvalues are non-zero. + * @return True if the eigenvalues are non-zero (i.e. the matrix is not singular). + */ + public boolean nonSingular() { + return eigenvalues.reduce(true,DoubleUnaryOperator.identity(),(value, bool) -> bool && value != 0.0); + } + + /** + * Returns the dense vector representing the i'th eigenvector. + * @param i The index. + * @return The i'th eigenvector. + */ + public DenseVector getEigenVector(int i) { + if (i < 0 || i > eigenvectors.dim1) { + throw new IllegalArgumentException("Invalid index, must be [0," + eigenvectors.dim1 + "), found " + i); + } + return eigenvectors.getColumn(i); + } + + /** + * Solves a system of linear equations A * b = y, where y is the input vector, + * A is the matrix which produced this eigen decomposition, and b is the returned value. + * @param vector The input vector y. + * @return The vector b. + */ + @Override + public DenseVector solve(SGDVector vector) { + if (vector.size() != eigenvectors.dim1) { + throw new IllegalArgumentException("Size mismatch, expected " + eigenvectors.dim1 + ", received " + vector.size()); + } + final double[] output = new double[vector.size()]; + for (int i = 0; i < output.length; i++) { + DenseVector eigenVector = getEigenVector(i); + double value = vector.dot(eigenVector) / eigenvalues.get(i); + for (int j = 0; j < output.length; j++) { + output[j] += value * eigenVector.get(j); + } + } + + return new DenseVector(output); + } + + /** + * Solves the system A * X = Y, where Y is the input matrix, and A is the matrix which + * produced this eigen decomposition. + * @param matrix The input matrix Y. + * @return The matrix X. + */ + @Override + public DenseMatrix solve(Matrix matrix) { + if (matrix.getDimension1Size() != eigenvectors.dim1) { + throw new IllegalArgumentException("Size mismatch, expected " + eigenvectors.dim1 + ", received " + matrix.getDimension1Size()); + } + final int outputDim1 = eigenvalues.size(); + final int outputDim2 = matrix.getDimension2Size(); + final double[][] output = new double[outputDim1][outputDim2]; + + for (int k = 0; k < outputDim2; k++) { + SGDVector column = matrix.getColumn(k); + for (int i = 0; i < outputDim1; i++) { + DenseVector eigen = getEigenVector(i); + double value = eigen.dot(column) / eigenvalues.get(i); + for (int j = 0; j < output.length; j++) { + output[j][k] += value * eigen.get(j); + } + } + } + + return new DenseMatrix(output); + } + } } diff --git a/Math/src/main/java/org/tribuo/math/la/DenseSparseMatrix.java b/Math/src/main/java/org/tribuo/math/la/DenseSparseMatrix.java index 58523a3ac..23056dae2 100644 --- a/Math/src/main/java/org/tribuo/math/la/DenseSparseMatrix.java +++ b/Math/src/main/java/org/tribuo/math/la/DenseSparseMatrix.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
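// Illustrative sketch, not from the patch: reading the EigenDecomposition defined above, via the
// DenseMatrix.eigenDecomposition() accessor exercised by the tests later in this diff;
// isPositiveDefinite is a hypothetical helper name.
// Assumed imports: java.util.Optional, org.tribuo.math.la.DenseMatrix, org.tribuo.math.la.DenseVector.
static boolean isPositiveDefinite(DenseMatrix symmetric) {
    Optional<DenseMatrix.EigenDecomposition> eigOpt = symmetric.eigenDecomposition();
    if (!eigOpt.isPresent()) {
        return false;
    }
    DenseMatrix.EigenDecomposition eig = eigOpt.get();
    DenseVector eigenvalues = eig.eigenvalues();   // sorted in descending order
    DenseVector leading = eig.getEigenVector(0);   // eigenvector of the largest eigenvalue
    // The determinant of the decomposed matrix is the product of the eigenvalues, and a symmetric
    // matrix is positive definite exactly when every eigenvalue is positive.
    return eig.positiveEigenvalues();
}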
@@ -16,6 +16,7 @@ package org.tribuo.math.la; +import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.List; @@ -35,6 +36,12 @@ public class DenseSparseMatrix implements Matrix { private final int dim2; private final int[] shape; + /** + * Constructs a DenseSparseMatrix from the supplied vector array. + *

+ * Does not copy the values, used internally by the la package. + * @param values The sparse vectors. + */ DenseSparseMatrix(SparseVector[] values) { this.values = values; this.dim1 = values.length; @@ -93,9 +100,7 @@ public DenseSparseMatrix(int dim1, int dim2) { this.values = new SparseVector[dim1]; this.shape = new int[]{dim1,dim2}; SparseVector emptyVector = new SparseVector(dim2); - for (int i = 0; i < values.length; i++) { - values[i] = emptyVector; - } + Arrays.fill(values, emptyVector); } /** @@ -111,6 +116,33 @@ public static DenseSparseMatrix createFromSparseVectors(SparseVector[] values) { return new DenseSparseMatrix(newValues); } + /** + * Creates an identity matrix of the specified size. + * @param dimension The matrix dimension. + * @return The identity matrix. + */ + public static DenseSparseMatrix createIdentity(int dimension) { + SparseVector[] newValues = new SparseVector[dimension]; + for (int i = 0; i < dimension; i++) { + newValues[i] = new SparseVector(dimension, new int[]{i}, new double[]{1.0}); + } + return new DenseSparseMatrix(newValues); + } + + /** + * Creates a diagonal matrix using the supplied values. + * @param diagonal The values along the diagonal. + * @return A diagonal matrix. + */ + public static DenseSparseMatrix createDiagonal(SGDVector diagonal) { + int dimension = diagonal.size(); + SparseVector[] newValues = new SparseVector[dimension]; + for (int i = 0; i < dimension; i++) { + newValues[i] = new SparseVector(dimension, new int[]{i}, new double[]{diagonal.get(i)}); + } + return new DenseSparseMatrix(newValues); + } + @Override public int[] getShape() { return shape; @@ -260,9 +292,43 @@ public int numActiveElements(int row) { @Override public SparseVector getRow(int i) { + if (i < 0 || i > dim1) { + throw new IllegalArgumentException("Invalid row index, must be [0,"+dim1+"), received " + i); + } return values[i]; } + /** + * Gets a copy of the column. + *
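// Illustrative sketch, not from the patch: the two factory methods added above. createDiagonal
// accepts any SGDVector; a DenseVector is used here purely for illustration.
// Assumed imports: org.tribuo.math.la.DenseSparseMatrix, org.tribuo.math.la.DenseVector.
DenseSparseMatrix identity = DenseSparseMatrix.createIdentity(3);   // 3x3 identity, one non-zero per row
DenseSparseMatrix diagonal = DenseSparseMatrix.createDiagonal(new DenseVector(new double[]{1.0, 2.0, 3.0}));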

+ * This function is O(dim1 * log(dim2)) as it requires searching each vector for the column index. + * @param i The column index. + * @return A copy of the column as a sparse vector. + */ + @Override + public SparseVector getColumn(int i) { + if (i < 0 || i > dim2) { + throw new IllegalArgumentException("Invalid column index, must be [0,"+dim2+"), received " + i); + } + List indexList = new ArrayList<>(); + List valueList = new ArrayList<>(); + for (int j = 0; j < dim1; j++) { + double tmp = values[j].get(i); + if (tmp != 0) { + indexList.add(j); + valueList.add(tmp); + } + } + + int[] indicesArr = new int[valueList.size()]; + double[] valuesArr = new double[valueList.size()]; + for (int j = 0; j < valueList.size(); j++) { + indicesArr[j] = indexList.get(j); + valuesArr[j] = valueList.get(j); + } + return new SparseVector(dim1, indicesArr, valuesArr); + } + @Override public boolean equals(Object other) { if (other instanceof Matrix) { @@ -524,7 +590,7 @@ private static class DenseSparseMatrixIterator implements MatrixIterator { private Iterator itr; private VectorTuple vecTuple; - public DenseSparseMatrixIterator(DenseSparseMatrix matrix) { + DenseSparseMatrixIterator(DenseSparseMatrix matrix) { this.matrix = matrix; this.tuple = new MatrixTuple(); this.i = 0; diff --git a/Math/src/main/java/org/tribuo/math/la/DenseVector.java b/Math/src/main/java/org/tribuo/math/la/DenseVector.java index 29b4128b2..018e23428 100644 --- a/Math/src/main/java/org/tribuo/math/la/DenseVector.java +++ b/Math/src/main/java/org/tribuo/math/la/DenseVector.java @@ -21,12 +21,14 @@ import org.tribuo.ImmutableFeatureMap; import org.tribuo.Output; import org.tribuo.math.util.VectorNormalizer; +import org.tribuo.util.MeanVarianceAccumulator; import org.tribuo.util.Util; import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.NoSuchElementException; +import java.util.function.BiFunction; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; import java.util.function.ToDoubleBiFunction; @@ -182,6 +184,8 @@ public int numActiveElements() { /** * Performs a reduction from left to right of this vector. + *

+ * The first argument to the reducer is the transformed element, the second is the state. * @param initialValue The initial value. * @param op The element wise operation to apply before reducing. * @param reduction The reduction operation (should be commutative). @@ -197,6 +201,24 @@ public double reduce(double initialValue, DoubleUnaryOperator op, DoubleBinaryOp return output; } + /** + * Performs a reduction from left to right of this vector. + *

+ * The first argument to the reducer is the transformed vector element, the second is the state. + * @param initialValue The initial value. + * @param op The element wise operation to apply before reducing. + * @param reduction The reduction operation (should be commutative). + * @return The reduced value. + */ + public T reduce(T initialValue, DoubleUnaryOperator op, BiFunction reduction) { + T output = initialValue; + for (int i = 0; i < elements.length; i++) { + double transformed = op.applyAsDouble(get(i)); + output = reduction.apply(transformed, output); + } + return output; + } + /** * Equals is defined mathematically, that is two SGDVectors are equal iff they have the same indices * and the same values at those indices. @@ -669,6 +691,20 @@ public double l1Distance(SGDVector other) { } } + /** + * Compute the mean and variance of this vector. + * @return The mean and variance. + */ + public MeanVarianceAccumulator meanVariance() { + MeanVarianceAccumulator acc = new MeanVarianceAccumulator(); + + for (int i = 0; i < elements.length; i++) { + acc.observe(get(i)); + } + + return acc; + } + private static class DenseVectorIterator implements VectorIterator { private final DenseVector vector; private final VectorTuple tuple; diff --git a/Math/src/main/java/org/tribuo/math/la/Matrix.java b/Math/src/main/java/org/tribuo/math/la/Matrix.java index 34b4fbd79..c1b4f0f5b 100644 --- a/Math/src/main/java/org/tribuo/math/la/Matrix.java +++ b/Math/src/main/java/org/tribuo/math/la/Matrix.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -135,4 +135,57 @@ public interface Matrix extends Tensor, Iterable { */ public SGDVector getRow(int i); + /** + * Returns a copy of the specified column. + * @param index The column index. + * @return A copy of the column. + */ + public SGDVector getColumn(int index); + + /** + * Interface for matrix factorizations. + */ + public interface Factorization { + /** + * First dimension of the factorized matrix. + * @return First dimension size. + */ + public int dim1(); + + /** + * Second dimension of the factorized matrix. + * @return Second dimension size. + */ + public int dim2(); + + /** + * Compute the matrix determinant of the factorized matrix. + * @return The matrix determinant. + */ + public double determinant(); + + /** + * Solves a system of linear equations A * b = y, where y is the input vector, + * A is the matrix which produced this factorization, and b is the returned value. + * @param vector The input vector y. + * @return The vector b. + */ + public SGDVector solve(SGDVector vector); + + /** + * Solves the system A * X = Y, where Y is the input matrix, and A is the matrix which + * produced this factorization. + * @param matrix The input matrix Y. + * @return The matrix X. + */ + public Matrix solve(Matrix matrix); + + /** + * Generates the inverse of the matrix with this factorization. + * @return The matrix inverse. 
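// Illustrative sketch, not from the patch: the generic reduce and meanVariance methods added above,
// mirroring the way the vector tests later in this diff use them.
// Assumed imports: java.util.function.DoubleUnaryOperator, org.tribuo.math.la.DenseVector,
// org.tribuo.util.MeanVarianceAccumulator.
DenseVector v = new DenseVector(new double[]{0.1, 1.0, 0.2, 2.0});
// The reducer sees the transformed element first and the accumulated state second.
boolean allPositive = v.reduce(true, DoubleUnaryOperator.identity(), (value, state) -> state && value > 0.0);
MeanVarianceAccumulator stats = v.meanVariance();
double mean = stats.getMean();       // 0.825 for this vector
double stdDev = stats.getStdDev();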
+ */ + default public Matrix inverse() { + return solve(DenseSparseMatrix.createIdentity(dim2())); + } + } } diff --git a/Math/src/main/java/org/tribuo/math/la/package-info.java b/Math/src/main/java/org/tribuo/math/la/package-info.java index d40828aaf..466ac2fe6 100644 --- a/Math/src/main/java/org/tribuo/math/la/package-info.java +++ b/Math/src/main/java/org/tribuo/math/la/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,11 +18,10 @@ * Provides a linear algebra system used for numerical operations in Tribuo. *
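// Illustrative sketch, not from the patch: code can be written against the Matrix.Factorization
// interface above, since the Cholesky, LU and eigen decompositions all implement it. The default
// inverse() solves against an identity matrix. invert is a hypothetical helper name, and the
// Optional<DenseMatrix.CholeskyFactorization> return type of choleskyFactorization() is assumed
// from the tests later in this diff.
// Assumed imports: java.util.Optional, org.tribuo.math.la.DenseMatrix, org.tribuo.math.la.Matrix.
static Matrix invert(DenseMatrix symmetricPositiveDefinite) {
    Optional<DenseMatrix.CholeskyFactorization> cholOpt = symmetricPositiveDefinite.choleskyFactorization();
    if (!cholOpt.isPresent()) {
        throw new IllegalArgumentException("Matrix is not positive definite");
    }
    Matrix.Factorization factorization = cholOpt.get();
    return factorization.inverse();   // equivalent to solve(DenseSparseMatrix.createIdentity(dim2()))
}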

* There are Dense and Sparse vectors and Matrices, along with a DenseSparseMatrix which is - * a dense array of sparse row vectors. - *

+ * a dense array of sparse row vectors. The dense matrix provides various factorization methods + * in addition to matrix-vector operations. *

* It's a single threaded implementation in pure Java. We're looking at ways of improving the speed * using new technologies coming in future releases of Java. - *

*/ package org.tribuo.math.la; \ No newline at end of file diff --git a/Math/src/test/java/org/tribuo/math/distributions/MultivariateNormalDistributionTest.java b/Math/src/test/java/org/tribuo/math/distributions/MultivariateNormalDistributionTest.java new file mode 100644 index 000000000..7661ba825 --- /dev/null +++ b/Math/src/test/java/org/tribuo/math/distributions/MultivariateNormalDistributionTest.java @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.math.distributions; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +public class MultivariateNormalDistributionTest { + + @Test + public void testMeanAndVarChol() { + double[] mean = new double[]{2.5, 2.5}; + + double[][] covariance = new double[][]{{1.0, 0.5}, {0.5, 1.0}}; + MultivariateNormalDistribution rng = new MultivariateNormalDistribution(mean, covariance, 12345); + + meanVarComp(rng,mean,covariance,1e-2); + } + + @Test + public void testMeanAndVarEigen() { + double[] mean = new double[]{2.5, 2.5}; + + double[][] covariance = new double[][]{{1.0, 0.5}, {0.5, 1.0}}; + MultivariateNormalDistribution rng = new MultivariateNormalDistribution(mean, covariance, 12345, true); + + meanVarComp(rng,mean,covariance,1e-2); + } + + private static void meanVarComp(MultivariateNormalDistribution rng, double[] mean, double[][]covariance, double tolerance) { + double[][] samples = new double[100000][]; + double[] computedMean = new double[2]; + for (int i = 0; i < samples.length; i++) { + samples[i] = rng.sampleArray(); + computedMean[0] += samples[i][0]; + computedMean[1] += samples[i][1]; + } + computedMean[0] /= samples.length; + computedMean[1] /= samples.length; + + double[][] computedCovariance = new double[2][2]; + for (int i = 0; i < samples.length; i++) { + computedCovariance[0][0] += (samples[i][0] - computedMean[0]) * (samples[i][0] - computedMean[0]); + computedCovariance[0][1] += (samples[i][0] - computedMean[0]) * (samples[i][1] - computedMean[1]); + computedCovariance[1][0] += (samples[i][1] - computedMean[1]) * (samples[i][0] - computedMean[0]); + computedCovariance[1][1] += (samples[i][1] - computedMean[1]) * (samples[i][1] - computedMean[1]); + } + computedCovariance[0][0] /= samples.length-1; + computedCovariance[0][1] /= samples.length-1; + computedCovariance[1][0] /= samples.length-1; + computedCovariance[1][1] /= samples.length-1; + + assertArrayEquals(mean,computedMean,1e-2); + assertArrayEquals(covariance[0],computedCovariance[0],tolerance); + assertArrayEquals(covariance[1],computedCovariance[1],tolerance); + } + +} diff --git a/Math/src/test/java/org/tribuo/math/la/DenseMatrixTest.java b/Math/src/test/java/org/tribuo/math/la/DenseMatrixTest.java index 88e5107f9..396e87b6c 100644 --- a/Math/src/test/java/org/tribuo/math/la/DenseMatrixTest.java +++ b/Math/src/test/java/org/tribuo/math/la/DenseMatrixTest.java @@ -1,5 +1,5 @@ /* - * Copyright 
(c) 2015-2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,15 @@ package org.tribuo.math.la; -import org.junit.jupiter.api.Test; - +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Optional; +import java.util.Random; + +import org.junit.jupiter.api.Test; /** * Matrices used - @@ -35,14 +41,18 @@ */ public class DenseMatrixTest { - public static DenseMatrix identity(int size) { + public static double[][] identityArr(int size) { double[][] values = new double[size][size]; for (int i = 0; i < size; i++) { values[i][i] = 1.0; } - return new DenseMatrix(values); + return values; + } + + public static DenseMatrix identity(int size) { + return new DenseMatrix(identityArr(size)); } // a 4x4 matrix @@ -202,6 +212,74 @@ public static DenseMatrix generateF() { return new DenseMatrix(values); } + public static DenseMatrix generateSymmetric() { + double[][] values = new double[3][3]; + + values[0][0] = 4; + values[0][1] = 12; + values[0][2] = -16; + values[1][0] = 12; + values[1][1] = 37; + values[1][2] = -43; + values[2][0] = -16; + values[2][1] = -43; + values[2][2] = 98; + + return new DenseMatrix(values); + } + + public static DenseMatrix generateCholOutput() { + double[][] values = new double[3][3]; + + values[0][0] = 2; + values[0][1] = 0; + values[0][2] = 0; + values[1][0] = 6; + values[1][1] = 1; + values[1][2] = 0; + values[2][0] = -8; + values[2][1] = 5; + values[2][2] = 3; + + return new DenseMatrix(values); + } + + public static DenseMatrix.LUFactorization generateLUOutput() { + double[][] lValues = new double[3][3]; + + lValues[0][0] = 1; + lValues[0][1] = 0; + lValues[0][2] = 0; + lValues[1][0] = -0.75; + lValues[1][1] = 1; + lValues[1][2] = 0; + lValues[2][0] = -0.25; + lValues[2][1] = 0.263157894736842; + lValues[2][2] = 1; + + DenseMatrix l = new DenseMatrix(lValues); + + double[][] uValues = new double[3][3]; + + uValues[0][0] = -16; + uValues[0][1] = -43; + uValues[0][2] = 98; + uValues[1][0] = 0; + uValues[1][1] = 4.75; + uValues[1][2] = 30.5; + uValues[2][0] = 0; + uValues[2][1] = 0; + uValues[2][2] = 0.473684210526317; + + DenseMatrix u = new DenseMatrix(uValues); + + int[] permutation = new int[]{2,1,0}; + + DenseMatrix.LUFactorization lu = new DenseMatrix.LUFactorization(l,u,permutation,true); + + return lu; + } + public static DenseVector generateVector() { double[] values = new double[4]; @@ -1363,4 +1441,334 @@ public void matrixVectorTest() { assertEquals(matrixMatrixOutput,matrixVectorOutput); } + @Test + public void symmetricTest() { + assertFalse(generateA().isSymmetric()); + assertFalse(generateB().isSymmetric()); + assertTrue(generateSymmetric().isSymmetric()); + } + + @Test + public void choleskyTest() { + DenseMatrix symmetric = generateSymmetric(); + assertTrue(symmetric.isSymmetric()); + Optional cholOpt = symmetric.choleskyFactorization(); + assertTrue(cholOpt.isPresent()); + + // check pre-computed output + DenseMatrix.CholeskyFactorization chol = cholOpt.get(); + assertEquals(generateCholOutput(),chol.lMatrix()); + + // check factorization + Matrix computedSymmetric = 
chol.lMatrix().matrixMultiply(chol.lMatrix(),false,true); + assertEquals(symmetric,computedSymmetric); + + // test factorization + testFactorization(symmetric, chol, 1e-13); + } + + + @Test + public void choleskyTest2() { + DenseMatrix a = new DenseMatrix(new double[][] {new double[] {8,2,3}, new double[] {2,9,3}, new double[] {3,3,6}}); + assertTrue(a.isSymmetric()); + Optional cholOpt = a.choleskyFactorization(); + assertTrue(cholOpt.isPresent()); + DenseMatrix c = cholOpt.get().lMatrix(); + assertEquals(new DenseMatrix(new double[][]{new double[]{2.8284271247461903, 0.0, 0.0}, new double[]{0.7071067811865475, 2.9154759474226504, 0.0}, new double[]{1.0606601717798212, 0.7717436331412897, 2.0686739145418453}}),c); + // check factorization + assertEquals(a,c.matrixMultiply(c,false,true)); + + a = generateSymmetricPositiveDefinite(5, new Random(1234)); + assertTrue(a.isSymmetric()); + assertTrue(a.isSquare()); + c = a.choleskyFactorization().get().lMatrix(); + assertEquals(a, c.matrixMultiply(c, false, true)); + assertEquals(new DenseMatrix(new double[][]{new double[]{14.94547893340321, 0.0, 0.0, 0.0, 0.0}, new double[]{1.0804097707164124, 8.928639781892967, 0.0, 0.0, 0.0}, new double[]{0.9664077730739391, 0.6309279673272034, 3.0262982901105855, 0.0, 0.0}, new double[]{0.7832294202894065, 1.4377257847560159, 2.2605017656943716, 15.960534656196803, 0.0}, new double[]{0.7843088590247412, 0.04136969152859865, 3.556219834851627, 0.5877350448006833, 16.75596860731699}}), c); + + a = generateSymmetricPositiveDefinite(20, new Random(42)); + assertTrue(a.isSymmetric()); + c = a.choleskyFactorization().get().lMatrix(); + assertEquals(a, c.matrixMultiply(c, false, true)); + assertEquals(new DenseMatrix(new double[][]{new double[]{16.75440746002779, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{1.6018808759506862, 13.63602457544494, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.7516483220309336, 1.190157675401124, 5.4036024131544655, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{1.368025817917537, 2.270165339191156, 1.764426591618153, 18.147655254377153, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{1.8923058726188353, 1.0123840323556053, 3.872869835424954, 0.3826508633553371, 17.135975549880694, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{1.4334670987046003, 1.831123954664325, 3.4125124344865716, 0.41829925313682753, -0.47377519666787493, 17.834272696939415, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.9955273667573433, 1.5888606812196417, 1.2931406114626933, 0.7507404663781049, 1.1230328218627708, 0.6982985586204641, 11.729209946273539, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{1.0566045539820537, 1.323704443126565, 2.3405076341478823, 0.8909192341770037, 0.5588640878069887, 0.42167743837570937, 0.6918655860556959, 12.604895816109266, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{1.390901857759653, 1.5244450976350756, 2.1296282171996017, 0.5175951429570781, 0.40949856268727963, 0.29710908111391915, 1.0551485878373852, 0.21347459054055823, 12.185375044419702, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{1.568660236011373, 1.6539549798204274, 2.7054801249873144, 0.5875503505223718, 0.36058080347341825, 
0.9631435157481918, 1.8664848238242542, 0.3528938775896363, 1.0221291272724007, 18.468466444374396, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{1.6747699323818752, 0.9348877369145657, 2.26838953173407, 0.9708056620594314, 0.4648508465963835, 0.4502237005191294, 0.8494944520263819, 1.0245845314750173, 0.9046030178733844, -0.25418850237135154, 16.042987693353492, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.8869180378613991, 0.7737834087685098, 1.4084786652563897, 0.12734000384887173, 0.502736132499496, 0.07316975918621323, 0.971795330127886, 0.22343852351493915, 0.895107379408903, 0.7852625675942757, 0.5856430554056011, 5.226898701059288, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{1.566488974573314, 2.295179090448591, 1.7394746979213342, 0.40853611940626017, 0.582994247658075, 0.7148554103841811, 1.474587799832604, 0.3318139748198812, 1.24859872007004, 0.19733009851870595, 0.5557654573528521, -0.19658626542639288, 16.240209978936896, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{1.1818432493780875, 1.5652452724346404, 3.2857719301330177, 0.9635319388163569, 0.7563151193955898, 0.9287100018154593, 1.1770478441746075, 0.02087442066230491, 0.6706469709304964, 0.2620284517297045, 0.3304383096053042, 2.185137328699002, 0.9479299608640565, 13.840735361297632, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.5789967411201067, 0.9988471923692965, 1.183196829816785, 0.7521697338655636, 0.018668631393969338, 0.5116517301295441, 0.8139711973181936, -0.2094555759887365, 0.39243306130843325, 0.07411895260819369, -0.06757487751692495, 0.29838420849840613, 0.8241581418664375, -0.01620837071104333, 3.2260845557553037, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{1.0104885322567119, 0.9325402930773361, 1.301762888169249, 0.6129186638967622, 0.8976470672020842, 0.3659205085404095, 0.34379283852757003, -0.04079047659367302, 0.43869995114135396, 0.943916377651184, 0.8766579013250082, 1.3128333521551454, 0.37860345884081986, 0.20891699349297296, 1.3466852460402794, 4.431218108245882, 0.0, 0.0, 0.0, 0.0}, new double[]{0.5492881425825883, 1.7200875435056948, 2.6137768716353555, 1.5272404478083912, -0.35532261666960424, 0.7701932271918118, 0.9672343431260562, -0.24698254339887663, 0.6459562358661791, 0.4510727698685872, 0.9702553782551973, 1.424541341254133, 0.4390031008245334, -0.12801559327632064, 2.832720413681546, 0.30850316580459985, 15.2646653384555, 0.0, 0.0, 0.0}, new double[]{1.657019097342754, 1.5610725306412903, 2.363618503717435, 0.8735325067275853, -0.3123830847695933, 0.5853338418588566, 1.4777076819228194, 0.31199933818155606, 0.808967278455822, 0.3956563813584885, 0.5128113149029574, 1.6343353507832152, 0.5845773670802912, -0.44100182190227216, 0.9841283372747659, 0.679495505925546, -0.07558768037348314, 13.611207913547869, 0.0, 0.0}, new double[]{0.8955649049493016, 0.887165103491017, 0.9108045971202429, 0.7785116617751562, 0.2633420108871542, 0.2362331319670587, 0.6655798737488293, 0.8378945325807816, 0.5670482817021022, 0.7767713529738539, 0.7850772851108367, 1.7543617362601482, 0.7582218118032898, 0.22930366774009395, -0.0926515221296133, 0.9599536741417869, 0.1721095656171293, 0.42092281314426966, 11.464109239515773, 0.0}, new double[]{1.2970751226801458, 0.7674164191430094, 1.005264076968263, 1.6432539998603195, 0.7770414629193831, 0.9415726468857499, 0.711083760232971, 0.9179014895793605, 0.32172296524409544, 0.006783534749445878, 1.058649212086019, 2.2596039666799954, 0.09793459293296733, 0.16032662213243493, 2.9599257579515035, 1.9884554445584544, 
0.45819043454085545, 0.0034148678716191716, 0.4436061788222598, 16.119007983152684}}), c); + } + + //another library used this trick to make sure the matrix is positive definite + public static DenseMatrix generateSymmetricPositiveDefinite(int n, Random rng) { + double[][] values = new double[n][n]; + for(int i=0; i luOpt = symmetric.luFactorization(); + assertTrue(luOpt.isPresent()); + DenseMatrix.LUFactorization lu = luOpt.get(); + + // check pre-computed output + DenseMatrix.LUFactorization output = generateLUOutput(); + assertEquals(output.lower(),luOpt.get().lower()); + assertEquals(output.upper(),luOpt.get().upper()); + assertArrayEquals(output.permutationArr(),luOpt.get().permutationArr()); + assertEquals(output.oddSwaps(),luOpt.get().oddSwaps()); + + // check factorization + Matrix computedSymmetric = lu.lower().matrixMultiply(lu.upper()); + assertEquals(lu.permutationMatrix().matrixMultiply(symmetric),computedSymmetric); + + // test factorization + testFactorization(symmetric, lu, 1e-13); + + //lets try a couple of non-symmetrical matrices + DenseMatrix a = generateSquareRandom(10, new Random(42)); + assertFalse(a.isSymmetric()); + luOpt = a.luFactorization(); + assertTrue(luOpt.isPresent()); + lu = luOpt.get(); + lu = a.luFactorization().get(); + Matrix computed = lu.lower().matrixMultiply(lu.upper()); + assertEquals(lu.permutationMatrix().matrixMultiply(a),computed); + + a = generateSquareRandom(20, new Random(42)); + assertFalse(a.isSymmetric()); + luOpt = a.luFactorization(); + assertTrue(luOpt.isPresent()); + lu = luOpt.get(); + lu = a.luFactorization().get(); + computed = lu.lower().matrixMultiply(lu.upper()); + assertEquals(lu.permutationMatrix().matrixMultiply(a),computed); + + //an example computed with another library + a = new DenseMatrix(new double[][] {new double[] {0.44670904, 0.44742455, 0.45204733}, + new double[] {0.71710816, 0.14136726, 0.18301841}, + new double[] {0.40983909, 0.07235836, 0.95855327}}); + DenseVector b = new DenseVector(new double[] {0.63392567, 0.93362273, 0.86074978}); + DenseMatrix.LUFactorization lu_a = a.luFactorization().get(); + DenseVector x2 = lu_a.solve(b); + assertEquals(new DenseVector(new double[] {1.2466263014829564,-0.2127572718468386,0.38102040828578143}), x2); + } + + @Test + public void luTest10() { + DenseMatrix a = generateSquareRandom(10, new Random(42)); + assertFalse(a.isSymmetric()); + Optional luOpt = a.luFactorization(); + luOpt = a.luFactorization(); + assertTrue(luOpt.isPresent()); + DenseMatrix.LUFactorization lu = luOpt.get(); + lu = luOpt.get(); + lu = a.luFactorization().get(); + Matrix computed = lu.lower().matrixMultiply(lu.upper()); + assertEquals(lu.permutationMatrix().matrixMultiply(a),computed); + + DenseMatrix upper = new DenseMatrix(new double[][]{new double[]{0.919327782868717, 0.436490974423287, 0.749906181255448, 0.386566874359349, 0.177378477909378, 0.594349910889684, 0.209767568866332, 0.825965871887821, 0.172217937687852, 0.587427381786296}, new double[]{0.0, 0.6046335192903662, -0.2315554768998931, 0.6366476011582105, 0.36981144543030414, -0.02265336333957818, 0.08894501826509618, -0.2421177129596924, 0.7778246727901739, 0.1868970417505812}, new double[]{0.0, 0.0, 0.582161900514335, 0.5515716470908145, 0.5050928223956577, 0.06080618000984278, 0.24126113352929207, 0.028311954869486233, 0.2856241540086502, 0.14599033226603875}, new double[]{0.0, 0.0, 0.0, 0.9201513995278638, 1.0765116824099081, 0.11333278778720035, 1.0435545524490883, -0.12398318208866327, 1.0420632762843565, 0.6298691400964507}, new 
double[]{0.0, 0.0, 0.0, 0.0, 1.1359028355590344, -0.2293887981353377, 1.423647757155546, -0.4798724675894213, 0.7684694590742605, 0.1621140233457365}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.6387337970258129, -0.429595684654273, 0.041738035239056726, -0.2567574094585644, 0.31066510702634487}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8464215324300727, -0.3385041472325138, 0.5745459425968332, 0.22845107403060855}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0382120340233407, -0.2084728664560298, -0.04915633819475923}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.6513748657999127, -0.28770234443194553}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.4330588997846473}}); + assertEquals(upper, lu.upper()); + DenseMatrix lower = new DenseMatrix(new double[][]{new double[]{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.5101798611502932, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.5227455598663895, 0.10499381292122886, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.5409679787159667, -0.3453554379860352, -0.7715373629846236, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.9724962222164031, 0.28835849803934743, 0.5385521322941219, -0.8471242288141386, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.7914083459574576, 0.5586538967647579, -0.2669415253271348, -0.2578728317292549, 0.6435470362577285, 1.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.8172062465284433, 0.35449109373629945, 0.08465209264493595, 0.17847876514502334, -0.42215148181041795, -0.3788241883348112, 1.0, 0.0, 0.0, 0.0}, new double[]{0.45791252529909143, 0.716128188680375, 0.8971569861596805, -0.8829557983548465, 0.6364353827507656, 0.4799824825668223, 0.8472910553690725, 1.0, 0.0, 0.0}, new double[]{0.3428020747767634, 0.48476626101806697, -0.08308375409609205, 0.38019693902353446, -0.38431061971684766, 0.7615607650134028, 0.6171134330987622, -0.029517947598312787, 1.0, 0.0}, new double[]{0.7584332185567152, 0.9552320656695076, -0.26009512153465214, 0.05570230015791643, 0.1862446242150791, 0.13540648537444386, -0.16638626226876788, 0.17955644250561797, 0.30334486487915147, 1.0}}); + assertEquals(lower, lu.lower()); + } + + @Test + public void luTest20() { + DenseMatrix a = generateSquareRandom(20, new Random(42)); + DenseMatrix.LUFactorization lu = a.luFactorization().get(); + DenseMatrix upper = new DenseMatrix(new double[][]{new double[]{0.91629377516759, 0.202182492969044, 0.56299024004504, 0.561700943667366, 0.080282170345142, 0.416882590906358, 0.560143976505208, 0.100264341467102, 0.610836098745395, 0.920378070753754, 0.03370946135387, 0.179426442633273, 0.997460814518753, 0.741524133735247, 0.063185128546488, 0.318886141572087, 0.631989300813935, 0.727637943878689, 0.028750514440684, 0.812558114652981}, new double[]{0.0, 0.9052275706085295, 0.7311124942979998, 0.17095063660702303, -0.010669320462571885, 0.6812034755884623, 0.3489823054201352, 0.8042763478267299, 0.3953375940092203, 0.4047899102009168, 0.5460004302339985, 0.06049073424230716, 0.2813126916984501, 0.6667691965772725, 0.6070155383157478, 0.28437108835221503, 0.7089863201825783, 0.6096031527180594, 0.25133910642700047, 0.7119862310531005}, new double[]{0.0, 0.0, -0.8418476964092433, 0.2391655310590597, 0.5757389140414536, -0.42180230458380397, -0.4116354501152575, -0.20697949944360516, -0.15932781555894426, -0.9116064716607113, 0.01642904477034135, -0.15965565256018974, -0.957082051438826, -0.6363568431353184, 0.10085272088659036, -0.0839926518392104, -0.1318782939936175, -0.677359059991486, 
0.41479172496712496, -0.44148925159663444}, new double[]{0.0, 0.0, 0.0, 0.679310460740777, 0.30536174388152537, 0.09289890466190279, -0.10901446406792045, 0.6577949827161325, 0.06645844700549122, -0.0020572824488075403, 0.7739619189359576, 0.46160778934827934, -0.20418276695810306, 0.47354251316234486, 0.5206282501079079, 0.8807407912012298, 0.6373368056049682, -0.04212319842740797, 0.5317367928049365, 0.5154946648904738}, new double[]{0.0, 0.0, 0.0, 0.0, 1.004270276574797, -0.41062361258634067, 0.06470483365678648, 0.395375206053751, -0.30879568797903967, -0.8577130542425583, 0.6319041693434939, 0.5732188590079953, -1.2735986438387392, 0.009356547995667985, 0.24898513869846228, 1.0291215968438152, -0.32167040809750125, -1.0019321717583807, 0.34318287299865397, -0.22115588557937782}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.6885304612825144, -0.09878565829551755, 0.11888667536611253, 0.026275391516468516, 0.7797339853049778, 0.7905988710596875, 0.4176897754831432, 0.8325929643882475, 0.13011386047293466, -0.06655370067747732, 0.35944616851482125, -0.07971923964435537, 0.7362704520604607, -0.08009767172937574, 0.2407018974393958}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5635665117604685, -0.6249197850949855, 0.23031320412001716, 0.11247725802992603, -1.2988300632624379, 0.11026088327531813, 0.14024854093374872, -0.3875080862753112, -0.5058502641980529, -0.33893480877287796, -0.23032588232457918, 0.12379631467206165, -0.007968186522634757, -0.972676934413537}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.4070639645891962, -0.12418858944571715, 0.6910032948976171, -1.4404891543646792, 0.2622833083003355, 0.8587138026959753, -0.21585181843661005, -0.3106664976261241, -1.495423265502768, -0.5942855879266988, 0.09640037702399482, -0.10120550590895772, -1.213780372497036}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.6802013374354048, 0.1058075299127755, 1.2379538951016724, 0.8253372563137479, 0.28039553297872594, -0.4691976313711233, -0.14590375747978418, 0.2553780343530524, -0.751028592571138, 0.08371269681886573, -0.44019720447366306, 0.03428601335235071}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.8019709962475664, 0.9482737524483995, -0.12829271956674393, 0.19730015198640594, -0.7755174705264541, -0.29453892749883254, -0.6304960591161828, -0.9226950539319441, -0.23156393988772728, -0.7638380971401721, -0.0526118887077868}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.6349095289112598, 0.9880461582144417, 0.8770199686641961, -0.732470256163468, -0.30405294149385953, 0.3323507498525623, -0.8535567865401182, 0.4991743682669355, -0.6767957549007376, 0.015790757312618864}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.8234358031976985, 0.33820564979310785, 0.8532516474938868, 0.7042310546951711, -0.06423357030263827, 0.28005420631706696, 0.2396254014654069, 0.43221953335735164, 0.528719611205809}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7145335303026192, -0.27218542943627877, 0.3313284160982715, -1.2236380259408746, -0.3517536922205687, 0.45984472966121825, 0.314122632468761, -0.6695645773752903}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.3445205726164846, -1.7642384709526822, -0.11149215277908996, -0.7150759280650528, -0.3805758372617428, -0.5346470409282629, -0.2360072475107876}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.603793754175576, -0.7227592385899313, 0.08933690369120117, 
-0.15647605160230427, -0.6287687377835115, 0.4898330713252278}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.9136570013250238, -0.14654528737983064, -0.044723655495255255, -0.5334655967012899, -0.037340679446577874}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.711745128807139, -0.5079629246932265, 0.8235155496618926, -1.0723815155444778}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.40246596787399197, 0.5083203142718309, -0.24587939289488858}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1995894850960833, -0.09077276170080262}, new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.19718923178294}}); + assertEquals(upper, lu.upper()); + DenseMatrix lower = new DenseMatrix(new double[][]{new double[]{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.3079779583562422, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.7609445225601399, 0.8337847234567108, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.4473141936816384, -0.09494029705026269, 0.1637102253065407, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.975716325972954, 0.44360588537892465, -0.12162520386418879, -0.7621284376905496, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.7940288363301351, 0.5774070069363221, 0.6657504787296222, -0.6283738341691794, 0.4147748029319753, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.47430585732298797, 0.5007279777683581, 0.06697485794746916, 0.5292140082116717, 0.36161236775140626, 0.5908255058790153, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.4197464440148081, 0.844017790432058, 0.4062134998961642, 0.4941441960718356, 0.5766279197366718, 0.2355875305580612, -0.5508060434484056, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.6794156391168887, 0.819068864976729, 0.5181261419263038, 0.4877148304343442, -0.35273920138322346, -0.701843245096469, 0.7092975753114856, -0.2773619520959566, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.5754309984653508, 0.33770322283741516, -0.46620789973611665, 0.993377069976633, 0.7480965060063309, 0.5362313482958221, 0.6321596725234732, 0.4967542962016245, -0.543153895575124, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.5118691552060198, 0.7996126782295147, 0.8573434443825495, 0.3011988820979851, -0.15716572853459956, -0.3026791986783505, -0.006724822479335297, 0.3140663634975763, -0.5196348554321051, -0.5161094309260015, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.5550385681628209, 0.9241643397204394, 0.8565946072603277, -0.04554456415982458, 0.4289197325443804, 0.3930243241875422, 0.8354331063835333, 0.10937713856509156, -0.10748457309978182, 0.07292951498907436, 0.4648891734874218, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.924877489856665, 0.16516986494967228, 0.25925481629627706, -0.8204474315440505, 0.6289897594567927, 0.16257602317783265, -0.17892272534318476, -0.398486864932312, -0.24193475415361745, 
-0.9495701198542961, 0.6442702974197089, -0.2821881020966079, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.791721578036539, 0.7012628674246133, 0.16059038996068944, 0.3640215888193717, 0.7221996807408587, -0.505654501540419, -0.11534483823016972, 0.4294290147986364, 0.42714956440474094, -0.44132530245664997, 0.09598532383257474, 0.9776071987597791, 0.7023757563795182, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.43385229365210265, 0.5366798318339504, 0.44956224031251935, 0.557601842211255, 0.20012354981960018, 0.417780155890305, 0.5891940018045582, 0.10518293529451214, -0.3286571126891387, -0.488165630534509, 0.5645239802111713, 0.3264587527258916, 0.0049180337098236335, 0.26081377697536984, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.7164868938471124, 0.057444333550601474, 0.5215596397477621, -0.5119247376365421, 0.2409973255461667, 0.3429726755092168, 0.5648373399165897, -0.4890784878527986, 0.46028006537941496, -0.9483895811414086, 0.4263033727785258, 0.46854884411888414, -0.22411197565282925, -0.22181876013963414, -0.7237871594156134, 1.0, 0.0, 0.0, 0.0, 0.0}, new double[]{0.8199121582268207, 0.44769787996680405, 0.24813891263154456, 0.22976829035756546, -0.24162333204388178, -0.4426186743049232, 0.5347906413658639, -0.25441672637636326, -0.15102889633903993, -0.34133850661373516, 0.8367235419100765, 0.4416918512441593, -0.3901336283827199, 0.022682755644211405, 0.6304792606838273, 0.22895645987172347, 1.0, 0.0, 0.0, 0.0}, new double[]{0.541059443066477, 0.0734295519671491, -0.16515613607865418, 0.6066543955194328, 0.5177988165076132, 0.24672481378747482, 0.3747518123138858, -0.15685824721396155, -0.2163979041035694, 0.19841180164613537, 0.0034976161062877254, 0.07306529002913684, 0.5540708129952783, 0.21256664205840442, -0.23302696186542915, -0.02299628311671909, -0.6312978639861092, 1.0, 0.0, 0.0}, new double[]{0.563819407121102, 0.4653910509604774, 0.46763282136793294, -0.5144092914239631, 0.5443932918982667, 0.4497527019423619, 0.21657484340462652, -0.09875033890514759, -0.23064080504375747, -0.39510415944934346, 0.4928250345738006, 0.6090843004243195, -0.2972883410070384, 0.24627922745765815, 0.3764349274440289, 0.18953240519184184, -0.032537487902659624, -0.6731924063648309, 1.0, 0.0}, new double[]{0.037674681576550156, 0.9632634503348367, 0.23042129274228612, 0.9604648378866348, -0.3934188736330582, -0.5349584958362864, 0.22053698758116233, 0.6042169929804793, -0.16605117995562096, 0.11964967085605736, 0.8115608242264145, 0.34289275272594144, -0.961334398130144, -0.05509181439223857, 0.32182043765652824, 0.050943936755478154, 0.23846240734990845, 0.2641474881442394, 0.08178900340243297, 1.0}}); + assertEquals(lower, lu.lower()); + } + + + @Test + public void eigenTest() { + DenseMatrix symmetric = generateSymmetric(); + + Optional eigOpt = symmetric.eigenDecomposition(); + assertTrue(eigOpt.isPresent()); + DenseMatrix.EigenDecomposition eig = eigOpt.get(); + + // check factorization + Matrix computed = eig.eigenvectors().matrixMultiply(DenseSparseMatrix.createDiagonal(eig.eigenvalues())).matrixMultiply(eig.eigenvectors(),false,true); + for (int i =0 ; i < symmetric.dim1; i++) { + assertArrayEquals(symmetric.getRow(i).toArray(), computed.getRow(i).toArray(), 1e-12); + } + + // check decomposition + for (int i = 0; i < symmetric.dim1; i++) { + // assert A.x = \lambda.x + double eigenValue = eig.eigenvalues().get(i); + DenseVector eigenVector = eig.getEigenVector(i); + double[] output = symmetric.leftMultiply(eigenVector).toArray(); + double[] expected = 
eigenVector.scale(eigenValue).toArray(); + assertArrayEquals(expected,output,1e-12); + } + + // test factorization + testFactorization(symmetric, eig, 1e-10); + } + + private static void testFactorization(DenseMatrix matrix, Matrix.Factorization factorization, double tolerance) { + // check solve vector method + DenseVector y = new DenseVector(new double[]{5, 6, 34}); + SGDVector b = factorization.solve(y); + DenseVector output = matrix.rightMultiply(b); + assertArrayEquals(y.toArray(), output.toArray(), tolerance); + + // check inverse method + Matrix inv = factorization.inverse(); + double[][] identityArr = identityArr(3); + DenseMatrix outputMatrix = matrix.matrixMultiply(inv); + for (int i = 0; i < identityArr.length; i++) { + assertArrayEquals(identityArr[i], outputMatrix.getRow(i).toArray(), tolerance); + } + } + + @Test + public void eigenTest3() { + DenseMatrix a = new DenseMatrix(new double[][] {new double[] {1.0, 2.0, 3.0}, new double[] {2.0, 1.0, 4.0}, new double[] {3.0, 4.0, 5.0}}); + DenseMatrix.EigenDecomposition eig = a.eigenDecomposition().get(); + assertArrayEquals(new double[] {9.079525367450728,-0.5925456084234676,-1.4869797590272613}, eig.eigenvalues().toArray()); + assertEigenvectorsEquals(new DenseMatrix(new double[][]{ new double[]{0.407376105324850,-0.905523895664965, 0.118622018150518}, + new double[]{0.484198771392754, 0.104025841064094,-0.868752078658035}, + new double[]{0.774336011426632, 0.411345473745188, 0.480831199733589}}), eig.eigenvectors()); + } + + @Test + public void eigenTest8() { + DenseMatrix a = generateSymmetricPositiveDefinite(8, new Random(123)); + DenseMatrix.EigenDecomposition eig = a.eigenDecomposition().get(); + assertArrayEquals(new double[]{459.76211240290263, 293.2666844177924, 216.59270408767136, 170.21352490625335, 166.78909016040728, 157.9347692793759, 143.28347730878713, 6.388464263025822}, eig.eigenvalues().toArray(), 1e-12); + double[][] values = new double[][]{new double[]{0.13866815658629814, -0.8252358100347458, -0.5322979910418807, -0.06874839874131271, 0.05232354573084099, -0.06491424767108263, 0.06112114654771209, 0.03165594801425184}, new double[]{0.10040552986430423, -0.3376082404349135, 0.4375635736101318, -0.11901054624642458, -0.2335674596850905, 0.780423576489013, -0.08180549203857021, 0.002675631411587839}, new double[]{0.07581067854928174, -0.21073685231366956, 0.4962334686590166, -0.6395694525012203, -0.13654456793617276, -0.5017239717927986, 0.15195512526320729, 0.032928059562896524}, new double[]{0.05952863799372147, -0.053657044549330915, 0.03753211518153525, 0.024900409997130044, 0.011629984966546214, -0.046407158488371125, -0.04938822374115087, -0.9933883909062082}, new double[]{0.10148264893220066, -0.2702761476343119, 0.3993474902510659, 0.7416253629655215, -0.2078352622158741, -0.2721705343847198, 0.29563328477316864, 0.04994140515356563}, new double[]{0.07317536757652769, -0.1638812735247339, 0.19609570362266984, 0.14340131278754917, 0.14775784823370017, -0.22004923286199535, -0.9120630704169633, 0.08159498307657408}, new double[]{0.9707144758110352, 0.21915710575752975, -0.08314349908025911, -0.019147960448430115, -0.018914352152997034, 0.00948566942422636, 0.015953634844472334, 0.041253296979782646}, new double[]{0.058813248463141024, -0.09890060504236657, 0.2696217133439837, 0.02241144059690243, 0.9265760530592435, 0.10127275489512025, 0.2110211084939485, 0.015240404666016118}}; + assertEigenvectorsEquals(new DenseMatrix(values), eig.eigenvectors()); + } + + @Test + public void eigenTest20() { + 
DenseMatrix a = generateSymmetricPositiveDefinite(20, new Random(42)); + DenseMatrix.EigenDecomposition eig = a.eigenDecomposition().get(); + for (int i = 0; i < a.dim1; i++) { + // assert A.x = \lambda.x + double eigenValue = eig.eigenvalues().get(i); + DenseVector eigenVector = eig.getEigenVector(i); + double[] outpt = a.leftMultiply(eigenVector).toArray(); + double[] expected = eigenVector.scale(eigenValue).toArray(); + assertArrayEquals(expected,outpt,1e-12); + } + assertArrayEquals(new double[]{(607.0122035912394), (342.23158891508194), (319.4166697533473), (314.3320479150682), (277.0466868867828), (272.8066860877825), (258.25981917694133), (247.40723877187338), (229.2357092782707), (198.22007297732208), (194.0112606156516), (168.66008287924282), (154.86848777133343), (142.54341246616536), (130.89687348036423), (125.88587299805424), (30.17405842512715), (21.994327309233963), (18.44990986009159), (8.326134535013239)}, eig.eigenvalues().toArray(), 1e-12); + double[][] values = new double[][]{new double[]{-0.26715610385542954, -0.01390478703198786, 0.1854110021241962, -0.1854662839160773, -0.16229872855019126, -0.15697766285531115, -0.7864295335461396, 0.16376563037553812, -0.32631164748717006, -0.13389230710603137, 0.1805076758436477, 0.05177064282900672, 0.02963293191284903, 0.031168113842195717, -0.01300327521337336, 0.025592540782747024, 0.030883988897009274, 0.0038945894037057784, 0.008438603064424666, -0.0015324664887223986}, new double[]{-0.21646536805125438, -0.013881077062626332, -0.01884624743066648, -0.007719582182124222, -0.15538791543439368, 0.04399647639053526, 0.01326559640729243, 0.04899467891124426, -0.08245118897883152, 0.19811793138915004, -0.1439106414069904, -0.7424814246902505, -0.48003665485603947, 0.2438294393508589, 0.0562373627082841, -0.08755214897319864, 0.0567217067754628, -0.03615656676219461, 0.015398062015907793, -0.01902975054850682}, new double[]{-0.1066921727252573, 0.00752781363333148, 0.01026475624523088, -0.031229671834569855, 0.0012686806850796797, 0.021181797375686102, 0.02021525202531208, -0.02837784114210916, -0.019720486930612822, 0.05227451308196039, -0.012161251995479868, -0.02483240450746341, 0.016905728457830504, -0.004130760055860076, -0.01904130742958451, -0.04416933391011066, -0.4074617231152516, 0.8676002055968268, 0.17512107945483427, -0.17288787786512805}, new double[]{-0.33920371511087805, -0.46691663525020594, -0.09031745999231369, 0.6321417470530342, -0.020930272891643172, 0.43592985933677975, -0.1452977571167301, 0.004938893921918626, 0.12863584565123845, -0.14245594866051303, 0.03359117860557403, 0.0840023224659054, -0.02300917557737052, -0.01623752607387114, -0.011098541422544728, -0.0033447671673163286, 0.007461812930786365, -0.013528489929949457, 0.02808146342169226, -0.024086648088416315}, new double[]{-0.259951086830275, -0.04997448086930181, 0.6491032077238915, -0.32175219248118364, 0.38418523432722484, 0.3963171032227332, 0.18407833127986634, -0.08492585414572532, -0.08656096384235985, -0.14763563022387532, -0.09947539484383917, 0.05987364971422366, -0.05398743934010043, 0.04417092562419047, 0.029588033891042777, -0.065271962860663, 0.07067220805307384, -0.02878732331263646, 0.01801166565448024, 0.019677109132111653}, new double[]{-0.3160522302492655, 0.1425457927728826, -0.6945965698158629, -0.4412619380518735, 0.25596051294629746, 0.30979369884113817, -0.08332532009062198, -0.049365521208594, 0.0653886986505201, -0.14860384242615965, -0.014372554107587941, 0.03861282802085877, -0.014269848155164286, -0.015802221885064142, 
0.029659945096216545, -0.004361717142820119, 0.022282695747436652, -0.03838055234237704, 0.012884385334153743, -0.014256056678976633}, new double[]{-0.18106927259731728, 0.028748825007578232, 0.04337755020686359, -0.021238387457528943, -0.03847239877274511, 0.011325817965529623, 0.06589137341041316, -0.00047616346151320524, -0.00965129771873849, 0.14532583048801603, -0.10248137823354767, -0.036481870869998274, -8.539265679888359e-05, -0.11504227993486503, -0.17514409047608884, 0.9351610289393385, 0.049619156156582724, 0.021736153494524058, -0.018600022686212497, -0.05780589091592112}, new double[]{-0.15591504015606542, -0.03953307882230713, 0.03301774964306644, -0.021098871447422937, 0.01782536479691528, -0.0035645450588703362, -0.02719654267535835, -0.036528215244615365, 0.10382343154094024, 0.08032026638691166, -0.12974686778592456, -0.44442381390623015, 0.8179812555384617, 0.0383787019925684, -0.22912576306606652, -0.1088944425937944, 0.019102109645195625, -0.045001745479443675, -0.03140350863718813, 0.017645434601391112}, new double[]{-0.1754400162655663, 0.01711452604601059, 0.03965138273897645, -0.030375332105871844, -0.09746918771877507, -0.029798619073213253, 0.018230685423737642, -0.023368576984577207, -0.026283884585206595, 0.1526486186907142, -0.06565525432286132, -0.08417847146265778, -0.10532587646752077, -0.9236993369000231, -0.06003477384997813, -0.20743602856009283, 0.08907302697233448, -0.008322587362255062, -0.022097477444714853, -0.01838009751552626}, new double[]{-0.3429134201976233, 0.795859467843955, 0.0903851269467892, 0.42246126697206393, 0.11051462059291979, -0.08292327038947497, 0.008279250728023886, -0.013891522199128592, 0.10806934045589865, -0.11719353164345055, 0.05599440099466047, 0.03489598840817147, -0.01960623645570973, 0.05539124877786237, -0.040244824792618886, -0.05714197672159361, 0.06762349645206606, 0.0057802638559934115, 0.013432301682123959, 0.01154118392104854}, new double[]{-0.2410770916621225, -0.18877168116635618, 0.05365413952935207, -0.1097097232975626, -0.08063525762356431, -0.40326918223039254, -0.04471858180417749, -0.7253429627986044, 0.37956281631354716, -0.1374996707499872, 0.07298202701198002, 0.026706571009866727, -0.12001278348116623, 0.06892953229099162, -0.04307738212372077, -0.008074525719773884, 0.05648994551633868, 0.0029953794744891665, 0.014857944097380246, 0.023918496469108614}, new double[]{-0.10327213057795478, 0.018867378998055854, 0.031802369587523194, -0.004460665569963404, 0.011462906328509775, -0.04813267037004528, 0.01314423728214291, -0.014666799062726915, -0.01235191188621735, 0.08459453112224846, -0.006622633724362204, 0.009127643356117878, 0.005184779688786686, -0.04825899579638822, 0.03829236978987259, 0.026064083810448932, -0.674375822967597, -0.4376357897852843, 0.569335977575442, 0.05218083896969409}, new double[]{-0.2606735378193528, -0.003758233708587899, 0.05419125719059475, -0.2248451448955618, -0.65236316674233, 0.008383753171361743, 0.35640003262099895, 0.3618580018584963, 0.2158028483568182, -0.3457362410337, 0.05596705513192054, 0.11503286856883942, 0.04174780431192078, 0.05793052973521669, -0.04438856158062291, -0.046584819472861555, 0.0130568010141935, -0.019017023547427202, 0.032082592808699316, -0.04077916554753696}, new double[]{-0.22805448182235571, -0.018505801239088612, 0.02003690469103533, -0.08667775157176884, -0.046736601319101165, 0.10017515395417752, 0.10835233294222103, 0.02107044575192801, 0.06062673398714634, 0.6992915201740049, 0.5898469641747583, 0.18523582853613746, 
0.0381212371745189, 0.13432133571942126, -0.05566996531799591, -0.09171998076376336, 0.09231646788289002, -0.019494880786832258, -0.021149822314972977, 0.007601678513472485}, new double[]{-0.08432944212460246, -0.018377177274517965, -0.020684304470545103, 0.0005125472758766888, -0.024502940800675967, -0.0013914454767143362, 0.03059283209237869, 0.034702131584703415, -0.022772161689467423, 0.0011657672355060958, -0.0148812638752305, -0.002598381845480636, -0.030052756685455772, -0.017280138843647857, -0.01591330228238729, 0.03539313031899835, -0.1933885285046404, 0.13359082373558595, -0.2278019764038912, 0.9374353567274204}, new double[]{-0.1164379288702311, 0.004490960089059218, 0.03185588853917711, -0.0021963858491475747, 0.0213529609396527, -0.03376520623626567, 0.012872677649063656, -0.007495895385014294, 0.004820917932928765, 0.011213906726055664, -0.002696263721788023, 0.013711583567772936, -0.030878577335419606, 0.0026984846273555066, 0.036668884296995795, -0.018462398128311096, -0.535096101610618, -0.16706542306989128, -0.7661932201791547, -0.2816083020712854}, new double[]{-0.23436896326134618, -0.1041723056459737, -0.1690436052266547, 0.12406587574627122, -0.02373498749671781, -0.2050166538149868, 0.3889604947272082, -0.21274058157455994, -0.777598272090964, -0.13129020674110803, 0.08456306139940428, 0.0518245364849499, 0.10362567787979178, 0.04667351366681789, -0.015198233337077252, -0.035249809931340914, 0.06103039843549032, -0.021591485525223473, 0.006679244668796843, -0.02736066086138905}, new double[]{-0.20359229956041464, -0.009877311296230795, -0.034642523916877166, -0.0046217003018542355, -0.12837293581337775, -0.08612684845090862, -0.07579519209698454, 0.013989794218879463, -0.03824305574811956, 0.38749003306445007, -0.7307276143578229, 0.41715892719054265, -0.011133220543453202, 0.17935057457453865, -0.06921439184757223, -0.1561046472327393, 0.07774146868774824, -0.015492185546862169, -0.0029185387766898505, -0.007448075235953968}, new double[]{-0.13527490691546717, -0.002810832081829183, 0.023038953528946464, 0.018015972940972555, -0.04320812223613813, -0.06861644689187882, 0.01776645097549279, -0.0009270836701030808, 0.05221468669802647, 0.08548070887491509, -0.0479150112971065, -0.02858434807813084, 0.21904454141202281, -0.05448484093769856, 0.942966943751532, 0.11783109434024673, 0.07141093203398967, 0.03875927438266021, -0.004812479447754362, 0.012435093376353554}, new double[]{-0.24147461042496196, -0.2738472465888555, -0.03891671477744992, 0.023818955522287515, 0.50802230435199, -0.5400427606883157, 0.11889559592832631, 0.5037751747256916, 0.17583773856009446, -0.04341988104559377, 0.022034850426857166, -0.009191750056862291, -0.07652751987417403, -0.00019623152598396564, -0.04280071917815721, -0.01851662493389959, 0.06051047939350801, 0.024914400680777896, 0.029690601833618248, -0.019701443530991528}}; + assertEigenvectorsEquals(new DenseMatrix(values), eig.eigenvectors()); + } + + /** + * if you generate the expected values using another library, the eigenvectors (i.e. the columns) + * may have the opposite sign as what is produced by DenseMatrix. So, we will negate a column if + * the first value of the expected eigenvector is the negative of the first value of the actual + * eigenvector. 
+ * @param expectedDM + * @param actualDM + */ + private void assertEigenvectorsEquals(DenseMatrix expectedDM, DenseMatrix actualDM) { + assertEquals(expectedDM.dim1, actualDM.dim1, "dim1 differs"); + assertEquals(expectedDM.dim2, actualDM.dim2, "dim2 differs"); + assertArrayEquals(expectedDM.getShape(), actualDM.getShape(), "shape differs"); + + // loop over column indices (i.e. we are going to test eigenvectors one at a time) + // and see if we need to negate the actual eigenvector if it differs by sign from the actual + for(int j=0; j max = Double::max; + BiFunction min = Double::min; + BiFunction sum = Double::sum; + + assertEquals(a.maxValue(),a.reduce(Double.MIN_VALUE, DoubleUnaryOperator.identity(), max)); + assertEquals(b.maxValue(),b.reduce(Double.MIN_VALUE, DoubleUnaryOperator.identity(), max)); + assertEquals(c.maxValue(),c.reduce(Double.MIN_VALUE, DoubleUnaryOperator.identity(), max)); + + assertEquals(a.minValue(),a.reduce(Double.MAX_VALUE, DoubleUnaryOperator.identity(), min)); + assertEquals(b.minValue(),b.reduce(Double.MAX_VALUE, DoubleUnaryOperator.identity(), min)); + assertEquals(c.minValue(),c.reduce(Double.MAX_VALUE, DoubleUnaryOperator.identity(), min)); + + assertEquals(a.sum(),a.reduce(0.0, DoubleUnaryOperator.identity(), sum)); + assertEquals(b.sum(),b.reduce(0.0, DoubleUnaryOperator.identity(), sum)); + assertEquals(c.sum(),c.reduce(0.0, DoubleUnaryOperator.identity(), sum)); + + assertEquals(a.sum(i -> i * i),a.reduce(0.0, i -> i * i, sum)); + assertEquals(b.sum(Math::abs),b.reduce(0.0, Math::abs, sum)); + assertEquals(c.sum(Math::exp),c.reduce(0.0, Math::exp, sum)); + + + + DenseVector d = new DenseVector(new double[] {-1.0, 1.0, -2.0, 2.0}); + assertFalse(d.reduce(true,DoubleUnaryOperator.identity(),(value, bool) -> bool && value > 0.0)); + DenseVector e = new DenseVector(new double[] {0.0, 1.0, 0.0, 2.0}); + assertFalse(e.reduce(true,DoubleUnaryOperator.identity(),(value, bool) -> bool && value > 0.0)); + DenseVector f = new DenseVector(new double[] {0.1, 1.0, 0.2, 2.0}); + assertTrue(f.reduce(true,DoubleUnaryOperator.identity(),(value, bool) -> bool && value > 0.0)); + + } + + + @Test + public void testMeanVariance() { + DenseVector d = new DenseVector(new double[] {1, -2, 3, -4, 5, -5, 4, -3, 2, -1}); + MeanVarianceAccumulator mva = d.meanVariance(); + Assertions.assertEquals(12.222222222, mva.getVariance(), 0.000001); + Assertions.assertEquals(3.4960294939, mva.getStdDev(), 0.000001); + Assertions.assertEquals(0.0,mva.getMean(), 0.000001); + Assertions.assertEquals(5,mva.getMax(), 0.000001); + Assertions.assertEquals(-5,mva.getMin(), 0.000001); + } + + @Test public void size() { DenseVector s = generateVectorA(); diff --git a/Math/src/test/resources/cholesky-test.py b/Math/src/test/resources/cholesky-test.py new file mode 100644 index 000000000..6c1efab1f --- /dev/null +++ b/Math/src/test/resources/cholesky-test.py @@ -0,0 +1,42 @@ +# +# Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+from scipy.linalg import cholesky
+from numpy import ndarray, transpose
+
+def printArrayAsJavaDoubles(nda: ndarray):
+    assert len(nda.shape) == 1
+    return "new double[]{"+(", ".join(map(str, nda)))+"}"
+
+def printMatrixAsJavaDoubles(nda: ndarray):
+    assert len(nda.shape) == 2
+    return "new double[][]{"+(", ".join(map(printArrayAsJavaDoubles, nda))) + "}"
+
+'''
+This code can be useful to generate expected values for the Cholesky factorization tests. Here are the steps:
+1) Generate a DenseMatrix using whatever method/mechanism you choose
+2) Print out the DenseMatrix using org.tribuo.math.la.DenseMatrixTest.printMatrixPythonFriendly(DenseMatrix)
+3) The above will print out a python-friendly matrix defined as arrays which you can paste into the code below.
+4) Run the code below. It will print out the Java-friendly lower triangular Cholesky factor.
+'''
+
+
+if __name__ == '__main__':
+    a = [[ 8.000000000000000, 2.000000000000000, 3.000000000000000],
+         [ 2.000000000000000, 9.000000000000000, 3.000000000000000],
+         [ 3.000000000000000, 3.000000000000000, 6.000000000000000]]
+    c = transpose(cholesky(a))
+    print(printMatrixAsJavaDoubles(c))
diff --git a/Math/src/test/resources/eigendecomposition-test.py b/Math/src/test/resources/eigendecomposition-test.py
new file mode 100644
index 000000000..f89bd8adc
--- /dev/null
+++ b/Math/src/test/resources/eigendecomposition-test.py
@@ -0,0 +1,61 @@
+#
+# Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from numpy import ndarray
+from numpy.linalg import eig
+
+def printArrayAsJavaDoubles(nda: ndarray):
+    assert len(nda.shape) == 1
+    return "new double[]{"+(", ".join(map(str, nda)))+"}"
+
+def printMatrixAsJavaDoubles(nda: ndarray):
+    assert len(nda.shape) == 2
+    return "new double[][]{"+(", ".join(map(printArrayAsJavaDoubles, nda))) + "}"
+
+'''
+This code can be useful to generate expected values for the eigendecomposition tests. Here are the steps:
+1) Generate a DenseMatrix using whatever method/mechanism you choose
+2) Print out the DenseMatrix using org.tribuo.math.la.DenseMatrixTest.printMatrixPythonFriendly(DenseMatrix)
+3) The above will print out a python-friendly matrix defined as arrays which you can paste into the code below.
+4) Run the code below. It will print out Java-friendly eigenvalues and eigenvectors which you can use in your unit test.
+''' + + +if __name__ == '__main__': + a = [[272.397395096041460, 28.297103495120155, 11.396656147497406, 12.106296359324405, 20.736399065920747, 15.614978237684149, 18.543740510632610, 5.236590376749971], + [28.297103495120155, 188.191184447367480, 26.048222455749110, 5.502057905049764, + 25.354185394601018, 13.096921987061375, 17.340008675427296, 11.519052526148316], + [11.396656147497406, 26.048222455749110, 184.808999941204970, 8.842353725435146, + 15.174247776061687, 12.374742652448266, 13.473116083130513, 10.172399441288023], + [12.106296359324405, 5.502057905049764, 8.842353725435146, 9.900676311693298, + 12.602839054563070, 14.619263012533160, 21.880922342272704, 4.917793385489778], + [20.736399065920747, 25.354185394601018, 15.174247776061687, 12.602839054563070, + 185.761164903314860, 17.196093211897440, 19.248597118145410, 9.204408145367047], + [15.614978237684149, 13.096921987061375, 12.374742652448266, 14.619263012533160, + 17.196093211897440, 152.690212322542980, 15.267218155191648, 10.476862371088615], + [18.543740510632610, 17.340008675427296, 13.473116083130513, 21.880922342272704, + 19.248597118145410, 15.267218155191648, 448.994124937120900, 12.678295358033086], + [5.236590376749971, 11.519052526148316, 10.172399441288023, 4.917793385489778, 9.204408145367047, 10.476862371088615, 12.678295358033086, 171.487068866930030]] + + w, vr = eig(a) + idx = w.argsort()[::-1] + w = w[idx] + vr = vr[:, idx] + + print(printArrayAsJavaDoubles(w)) + print(printMatrixAsJavaDoubles(vr)) + \ No newline at end of file diff --git a/Math/src/test/resources/lufactorization-test.py b/Math/src/test/resources/lufactorization-test.py new file mode 100644 index 000000000..09540ccf9 --- /dev/null +++ b/Math/src/test/resources/lufactorization-test.py @@ -0,0 +1,63 @@ +# +# Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +from scipy.linalg import lu_factor +from numpy import ndarray + +def printArrayAsJavaDoubles(nda: ndarray): + assert len(nda.shape) == 1 + return "new double[]{"+(", ".join(map(str, nda)))+"}" + +def printMatrixAsJavaDoubles(nda: ndarray): + assert len(nda.shape) == 2 + return "new double[][]{"+(", ".join(map(printArrayAsJavaDoubles, nda))) + "}" + +''' +This code can be useful to generate expected values for generating expected values for lu factorization. Here are the steps: +1) Generate a DenseMatrix using whatever method/mechanism you choose +2) Print out the DenseMatrix using org.tribuo.math.la.DenseMatrixTest.printMatrixPythonFriendly(DenseMatrix) +3) The above will print out a python-friendly matrix defined as arrays which you can paste into the code below. +4) Run the code below. It will print out Java-friendly Upper and Lower matrices. 
+''' + + +if __name__ == '__main__': + a = [[ 0.727563680032868, 0.683223471759845, 0.308719455332660, 0.277078490074137, 0.665548951794574, 0.903372264672178, 0.368782913411306, 0.275748069441702, 0.463653575809153, 0.782901778790036, 0.919327782868717, 0.436490974423287, 0.749906181255448, 0.386566874359349, 0.177378477909378, 0.594349910889684, 0.209767568866332, 0.825965871887821, 0.172217937687852, 0.587427381786296], + [ 0.751280406767460, 0.571040348414867, 0.580024884502061, 0.752509948590651, 0.031418238826581, 0.357919919477129, 0.817796930835639, 0.417687546752919, 0.974035681495881, 0.713406257823229, 0.480574516556434, 0.291656497411804, 0.949860134659467, 0.820491823386347, 0.636644547856282, 0.369121493941897, 0.360254875366135, 0.434661085140610, 0.457317094444769, 0.472688420875855], + [ 0.469022520615569, 0.827322424014995, 0.151031554528758, 0.833866235444166, 0.460306372661161, 0.280571991672810, 0.195964207423156, 0.179273440874917, 0.865686796327352, 0.486590661826194, 0.420971706631029, 0.632869791353279, 0.699858645067172, 0.315328450057673, 0.571620305529977, 0.371012289683761, 0.871814595964839, 0.805730942661998, 0.616213685035179, 0.374359356089004], + [ 0.697248729269730, 0.908614580207571, 0.196147071881852, 0.809124816727739, 0.627933275413473, 0.463399271028345, 0.305579155667449, 0.539909434259369, 0.635111014456388, 0.126257823298765, 0.497326892475921, 0.027314166285966, 0.036484516690250, 0.483843854954305, 0.655053381109821, 0.395766280171242, 0.940172465685381, 0.384610843917291, 0.646231978797643, 0.770465637773941], + [ 0.894042795818409, 0.598837037145018, 0.976034471618408, 0.077085112935252, 0.775120695927176, 0.278822302498768, 0.899205329729558, 0.373836143620542, 0.431309555135461, 0.332324278384749, 0.315147471367318, 0.442735942086220, 0.096450915880824, 0.745153306215386, 0.170859737882898, 0.805390719921382, 0.139789595286861, 0.092946816941456, 0.027029866882133, 0.548334691731752], + [ 0.282198286130648, 0.967495322008511, 0.904501079001562, 0.343942146444473, 0.014055818452733, 0.809594124810063, 0.521494303689756, 0.835155555007701, 0.583461648591119, 0.688246069347515, 0.556382201319052, 0.115750123719626, 0.588508636894290, 0.895142285356935, 0.626475165203972, 0.382580991181686, 0.903625094750243, 0.833699601096352, 0.260193631166134, 0.962236220249723], + [ 0.397536156015603, 0.573534718676561, 0.258165001039166, 0.821745557953679, 0.659183474319186, 0.614106754792385, 0.488197374456925, 0.361470612892181, 0.737981632291887, 0.855361490981633, 0.339747108594258, 0.758029864634980, 0.726078204940908, 0.508562234873031, 0.775409918592465, 0.376087287934175, 0.862824158797762, 0.697093313962580, 0.160130754288181, 0.800060876007507], + [ 0.725449553620742, 0.794874724221567, 0.823240915575137, 0.850283450238451, 0.984979205346519, 0.129126686504577, 0.614094824198923, 0.542874681181350, 0.152711743081139, 0.534576009896194, 0.548993732829397, 0.328559964445490, 0.720124675401226, 0.387796158038202, 0.015246621455439, 0.102898367034462, 0.103856897241301, 0.173225648073560, 0.897633823530948, 0.418858419146666], + [ 0.847459486748288, 0.336510352176303, 0.423201682090854, 0.052406745473978, 0.602894469596566, 0.166167160521423, 0.482231721337996, 0.572751749772856, 0.517300271039096, 0.711167603253355, 0.676391229848706, 0.851110551751301, 0.784331937730285, 0.289892235477594, 0.369486380358754, 0.505344107970708, 0.284422759606104, 0.975007583649560, 0.571903178903554, 0.122532257951209], + [ 0.434603504590581, 0.549169111514354, 
0.576735419379642, 0.727536503429085, 0.596054173365293, 0.818055527024611, 0.883763059118550, 0.372825589895905, 0.646352552511846, 0.840091222384709, 0.096858386069135, 0.913314494523744, 0.613422506208544, 0.586314598328913, 0.161061316693670, 0.999695346069685, 0.589474508656968, 0.779202691211620, 0.517479473521841, 0.074713999257383], + [ 0.495769399677456, 0.175863201981718, 0.497332761330185, 0.689073450342942, 0.652826173333953, 0.358856625356691, 0.550875603692069, 0.767121644424121, 0.525732565377378, 0.177013501790267, 0.712595067250474, 0.547207516726220, 0.460507804784482, 0.338962509756818, 0.040996406939037, 0.797565866036582, 0.705004072943589, 0.303686967973348, 0.639451838428449, 0.565818461430073], + [ 0.516624213063747, 0.535279223759976, 0.264002422014894, 0.158654828767362, 0.699170735689853, 0.393163530337159, 0.454667182622594, 0.267983090284141, 0.482431089553878, 0.414409509862082, 0.588727323011403, 0.160817772323823, 0.260643008608848, 0.511360607452316, 0.247617535534881, 0.632397980724370, 0.339772817279730, 0.655652397024369, 0.813470907378378, 0.985505725848116], + [ 0.508578384985525, 0.948798121493297, 0.267027195399829, 0.643681174997653, 0.944718115229204, 0.589869906162092, 0.745528957003539, 0.132007862985256, 0.694705669040006, 0.142405761138598, 0.537393927725167, 0.066367400295788, 0.725299650088167, 0.675524011492800, 0.839865909254752, 0.480540012405865, 0.333956520643007, 0.808328728852761, 0.786510616200610, 0.290253712447647], + [ 0.527263841932272, 0.422040341797175, 0.963337127977142, 0.944261290952574, 0.828810195034958, 0.820888510620017, 0.875489401452617, 0.344751271430859, 0.861737213352584, 0.420654751742865, 0.600866554883861, 0.932875128327571, 0.966258030054034, 0.622854420162981, 0.172694559333108, 0.429899839858112, 0.058715605414863, 0.392899165367936, 0.069986240760362, 0.100471268888463], + [ 0.916293775167590, 0.202182492969044, 0.562990240045040, 0.561700943667366, 0.080282170345142, 0.416882590906358, 0.560143976505208, 0.100264341467102, 0.610836098745395, 0.920378070753754, 0.033709461353870, 0.179426442633273, 0.997460814518753, 0.741524133735247, 0.063185128546488, 0.318886141572087, 0.631989300813935, 0.727637943878689, 0.028750514440684, 0.812558114652981], + [ 0.384611053799500, 0.848893556448996, 0.511415204130695, 0.812886941218117, 0.988548992413256, 0.549928653123423, 0.012207078788491, 0.155015083379985, 0.135270634431162, 0.674816309697567, 0.689649109462087, 0.920102844083035, 0.409658210002090, 0.883157153605221, 0.932938256811512, 0.144324939028378, 0.453352172100663, 0.147898709453929, 0.737653484628361, 0.268544896116506], + [ 0.034521076210014, 0.879589794045556, 0.531484787366069, 0.893395046645043, 0.023600080418727, 0.457130087490391, 0.309388398665635, 0.155466514815749, 0.627079557338775, 0.461676268662487, 0.680737820216252, 0.572540137541269, 0.615213870073979, 0.708217434266601, 0.705454587532873, 0.519265491184666, 0.694559974426116, 0.248434532140912, 0.201766115712206, 0.589422611126140], + [ 0.622544320874315, 0.878809666482779, 0.545152555601672, 0.976876449588937, 0.138795391968933, 0.329553074189390, 0.846208536041941, 0.664561531074459, 0.296768764393705, 0.232766896939920, 0.794559626573988, 0.649310518712321, 0.319230116436018, 0.172366521783839, 0.386637431544191, 0.650066810126769, 0.672451256731913, 0.603631652696010, 0.217031435797061, 0.747998180783144], + [ 0.656512480821271, 0.196861300882893, 0.006299617079889, 0.189254768619929, 0.442894631613962, 0.207457828941940, 
0.562533957338732, 0.144591463680319, 0.155582437940280, 0.803867184154219, 0.429641056039816, 0.564801166165025, 0.288190176083906, 0.895153914804218, 0.062212252318568, 0.396881986980274, 0.477159937555487, 0.891591046714639, 0.485021313686658, 0.333088241748358], + [ 0.409871211214595, 0.004496524367330, 0.044602411837619, 0.953491004202395, 0.436540393965985, 0.145649394030820, 0.041024471103732, 0.592401749913592, 0.276077042794300, 0.221970716968770, 0.739892799045843, 0.509987512730102, 0.058603584225627, 0.637755395387549, 0.507772241073650, 0.982634356937440, 0.831133393048532, 0.114593073004391, 0.588640743315033, 0.739090953644765]] + + lu, piv = lu_factor(a) + L, U = np.tril(lu, k=-1) + np.eye(20), np.triu(lu) + print(printMatrixAsJavaDoubles(U)) + print(printMatrixAsJavaDoubles(L)) diff --git a/Regression/SLM/pom.xml b/Regression/SLM/pom.xml index b456f015f..b95a8faf6 100644 --- a/Regression/SLM/pom.xml +++ b/Regression/SLM/pom.xml @@ -47,11 +47,6 @@ tribuo-math ${project.version} - - org.apache.commons - commons-math3 - ${commonsmath.version} - ${project.groupId} diff --git a/Regression/SLM/src/main/java/org/tribuo/regression/slm/LARSLassoTrainer.java b/Regression/SLM/src/main/java/org/tribuo/regression/slm/LARSLassoTrainer.java index d24f65173..9491d8414 100644 --- a/Regression/SLM/src/main/java/org/tribuo/regression/slm/LARSLassoTrainer.java +++ b/Regression/SLM/src/main/java/org/tribuo/regression/slm/LARSLassoTrainer.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,9 @@ package org.tribuo.regression.slm; -import org.apache.commons.math3.linear.RealVector; +import com.oracle.labs.mlrg.olcut.util.Pair; +import org.tribuo.math.la.DenseMatrix; +import org.tribuo.math.la.DenseVector; import java.util.ArrayList; import java.util.Collections; @@ -53,32 +55,33 @@ public LARSLassoTrainer() { } @Override - protected RealVector newWeights(SLMState state) { + protected DenseVector newWeights(SLMState state) { if (state.last) { return super.newWeights(state); } - RealVector deltapi = SLMTrainer.ordinaryLeastSquares(state.xpi,state.r); + Pair deltapi = SLMTrainer.ordinaryLeastSquares(state.xpi,state.r); if (deltapi == null) { return null; } - RealVector delta = state.unpack(deltapi); + DenseVector delta = state.unpack(deltapi.getA()); + DenseMatrix xpiInv = deltapi.getB(); // Computing gamma List candidates = new ArrayList<>(); - double AA = SLMTrainer.sumInverted(state.xpi); + double AA = xpiInv.rowSum().sum(); double CC = state.C; - RealVector wa = SLMTrainer.getwa(state.xpi,AA); - RealVector ar = SLMTrainer.getA(state.X, state.xpi,wa); + DenseVector wa = SLMTrainer.getWA(xpiInv,AA); + DenseVector ar = SLMTrainer.getA(state.X, state.xpi, wa); for (int i = 0; i < state.numFeatures; ++i) { if (!state.activeSet.contains(i)) { - double c = state.corr.getEntry(i); - double a = ar.getEntry(i); + double c = state.corr.get(i); + double a = ar.get(i); double v1 = (CC - c) / (AA - a); double v2 = (CC + c) / (AA + a); @@ -94,46 +97,22 @@ protected RealVector newWeights(SLMState state) { double gamma = Collections.min(candidates); -// // The lasso modification -// if (active.size() >= 2) { -// int min = active.get(0); -// double min_gamma = - beta.getEntry(min) / (wa.getEntry(active.indexOf(new Integer(min))) * 
(corr.getEntry(min) >= 0 ? +1 : -1)); -// -// for (int i = 1; i < active.size()-1; ++i) { -// int idx = active.get(i); -// double gamma_i = - beta.getEntry(idx) / (wa.getEntry(active.indexOf(new Integer(idx))) * (corr.getEntry(idx) >= 0 ? +1 : -1)); -// if (gamma_i < 0) continue; -// if (gamma_i < min) { -// min = i; -// min_gamma = gamma_i; -// } -// } -// -// if (min_gamma < gamma) { -// active.remove(new Integer(min)); -// beta.setEntry(min,0.0); -// return beta.add(delta.mapMultiplyToSelf(min_gamma)); -// } -// } -// -// return beta.add(delta.mapMultiplyToSelf(gamma)); - - RealVector other = delta.mapMultiplyToSelf(gamma); + delta.scaleInPlace(gamma); for (int i = 0; i < state.numFeatures; ++i) { - double betaElement = state.beta.getEntry(i); - double otherElement = other.getEntry(i); + double betaElement = state.beta.get(i); + double otherElement = delta.get(i); if ((betaElement > 0 && betaElement + otherElement < 0) || (betaElement < 0 && betaElement + otherElement > 0)) { - state.beta.setEntry(i,0.0); - other.setEntry(i,0.0); + state.beta.set(i,0.0); + delta.set(i,0.0); Integer integer = i; state.active.remove(integer); state.activeSet.remove(integer); } } - return state.beta.add(other); + return state.beta.add(delta); } @Override diff --git a/Regression/SLM/src/main/java/org/tribuo/regression/slm/LARSTrainer.java b/Regression/SLM/src/main/java/org/tribuo/regression/slm/LARSTrainer.java index ac93c2fab..e156fe3e3 100644 --- a/Regression/SLM/src/main/java/org/tribuo/regression/slm/LARSTrainer.java +++ b/Regression/SLM/src/main/java/org/tribuo/regression/slm/LARSTrainer.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,7 +16,9 @@ package org.tribuo.regression.slm; -import org.apache.commons.math3.linear.RealVector; +import com.oracle.labs.mlrg.olcut.util.Pair; +import org.tribuo.math.la.DenseMatrix; +import org.tribuo.math.la.DenseVector; import java.util.ArrayList; import java.util.Collections; @@ -53,32 +55,33 @@ public LARSTrainer() { } @Override - protected RealVector newWeights(SLMState state) { + protected DenseVector newWeights(SLMState state) { if (state.last) { return super.newWeights(state); } - RealVector deltapi = SLMTrainer.ordinaryLeastSquares(state.xpi,state.r); + Pair deltapi = SLMTrainer.ordinaryLeastSquares(state.xpi,state.r); if (deltapi == null) { return null; } - RealVector delta = state.unpack(deltapi); + DenseVector delta = state.unpack(deltapi.getA()); + DenseMatrix xpiInv = deltapi.getB(); // Computing gamma List candidates = new ArrayList<>(); - double AA = SLMTrainer.sumInverted(state.xpi); + double AA = xpiInv.rowSum().sum(); double CC = state.C; - RealVector wa = SLMTrainer.getwa(state.xpi,AA); - RealVector ar = SLMTrainer.getA(state.X, state.xpi,wa); + DenseVector wa = SLMTrainer.getWA(xpiInv,AA); + DenseVector ar = SLMTrainer.getA(state.X, state.xpi, wa); for (int i = 0; i < state.numFeatures; ++i) { if (!state.activeSet.contains(i)) { - double c = state.corr.getEntry(i); - double a = ar.getEntry(i); + double c = state.corr.get(i); + double a = ar.get(i); double v1 = (CC - c) / (AA - a); double v2 = (CC + c) / (AA + a); @@ -94,7 +97,9 @@ protected RealVector newWeights(SLMState state) { double gamma = Collections.min(candidates); - return state.beta.add(delta.mapMultiplyToSelf(gamma)); + delta.scaleInPlace(gamma); + + return state.beta.add(delta); } @Override diff --git a/Regression/SLM/src/main/java/org/tribuo/regression/slm/SLMTrainer.java b/Regression/SLM/src/main/java/org/tribuo/regression/slm/SLMTrainer.java index aad79fa4f..c7cbdad57 100644 --- a/Regression/SLM/src/main/java/org/tribuo/regression/slm/SLMTrainer.java +++ b/Regression/SLM/src/main/java/org/tribuo/regression/slm/SLMTrainer.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -18,26 +18,20 @@ import com.oracle.labs.mlrg.olcut.config.Config; import com.oracle.labs.mlrg.olcut.provenance.Provenance; +import com.oracle.labs.mlrg.olcut.util.Pair; import org.tribuo.Dataset; import org.tribuo.Example; import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; import org.tribuo.SparseTrainer; import org.tribuo.WeightedExamples; +import org.tribuo.math.la.DenseMatrix; import org.tribuo.math.la.DenseVector; import org.tribuo.math.la.SparseVector; -import org.tribuo.math.la.VectorTuple; import org.tribuo.provenance.ModelProvenance; import org.tribuo.provenance.TrainerProvenance; import org.tribuo.provenance.impl.TrainerProvenanceImpl; import org.tribuo.regression.Regressor; -import org.tribuo.util.Util; -import org.apache.commons.math3.linear.Array2DRowRealMatrix; -import org.apache.commons.math3.linear.ArrayRealVector; -import org.apache.commons.math3.linear.LUDecomposition; -import org.apache.commons.math3.linear.RealMatrix; -import org.apache.commons.math3.linear.RealVector; -import org.apache.commons.math3.linear.SingularMatrixException; import java.time.OffsetDateTime; import java.util.ArrayList; @@ -46,6 +40,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; @@ -94,13 +89,20 @@ public SLMTrainer(boolean normalize) { */ protected SLMTrainer() {} - protected RealVector newWeights(SLMState state) { - RealVector result = SLMTrainer.ordinaryLeastSquares(state.xpi,state.y); + /** + * Computes the new feature weights. + *
<p>
+ * In this version it returns the ordinary least squares solution for the current state. + * @param state The SLM state to operate on. + * @return The new feature weights. + */ + protected DenseVector newWeights(SLMState state) { + Pair result = SLMTrainer.ordinaryLeastSquares(state.xpi,state.y); if (result == null) { return null; } else { - return state.unpack(result); + return state.unpack(result.getA()); } } @@ -141,7 +143,7 @@ public SparseLinearModel train(Dataset examples, Map e : examples) { @@ -150,72 +152,44 @@ public SparseLinearModel train(Dataset examples, Map a - colMean, (a,b) -> b + a*a)); + col.foreachInPlace(a -> (a - colMean) / colNorm); + featureMatrix.setColumn(i,col); + featureMeans[i] = colMean; + featureNorms[i] = colNorm; } for (int i = 0; i < numOutputs; i++) { - outputMeans[i] = Util.mean(outputs[i]); - // Remove mean and aggregate variance - double sum = 0.0; - for (int j = 0; j < numExamples; j++) { - outputs[i][j] -= outputMeans[i]; - sum += outputs[i][j] * outputs[i][j]; - } - outputVariances[i] = Math.sqrt(sum); - // Remove variance - for (int j = 0; j < numExamples; j++) { - outputs[i][j] /= outputVariances[i]; - } + DenseVector row = outputMatrix.getRow(i); + double rowMean = row.meanVariance().getMean(); + double rowNorm = Math.sqrt(row.reduce(0.0, a -> a - rowMean, (a,b) -> b + a*a)); + row.foreachInPlace(a -> (a - rowMean) / rowNorm); + outputMeans[i] = rowMean; + outputNorms[i] = rowNorm; } } else { Arrays.fill(featureMeans,0.0); - Arrays.fill(featureVariances,1.0); + Arrays.fill(featureNorms,1.0); Arrays.fill(outputMeans,0.0); - Arrays.fill(outputVariances,1.0); - } - - // Construct the output matrix from the double[][] after scaling - RealMatrix outputMatrix = new Array2DRowRealMatrix(outputs); - - // Array example is useful to compute a submatrix - int[] exampleRows = new int[numExamples]; - for (int i = 0; i < numExamples; ++i) { - exampleRows[i] = i; + Arrays.fill(outputNorms,1.0); } - RealVector one = new ArrayRealVector(numExamples,1.0); - int numToSelect; if ((maxNumFeatures < 1) || (maxNumFeatures > featureIDMap.size())) { numToSelect = featureIDMap.size(); @@ -228,14 +202,14 @@ public SparseLinearModel train(Dataset examples, Map max) { max = absCorr; @@ -307,14 +279,13 @@ private SparseVector trainSingleDimension(SLMState state, int[] exampleRows, int } // Compute the active matrix - int[] activeFeatures = Util.toPrimitiveInt(state.active); - state.xpi = state.X.getSubMatrix(exampleRows, activeFeatures); + state.xpi = state.X.selectColumns(state.active); if (state.active.size() == (numToSelect - 1)) { state.last = true; } - RealVector betapi = newWeights(state); + DenseVector betapi = newWeights(state); if (betapi == null) { // Matrix was not invertible @@ -328,8 +299,8 @@ private SparseVector trainSingleDimension(SLMState state, int[] exampleRows, int Map parameters = new HashMap<>(); for (int i = 0; i < state.numFeatures; ++i) { - if (state.beta.getEntry(i) != 0) { - parameters.put(i, state.beta.getEntry(i)); + if (state.beta.get(i) != 0) { + parameters.put(i, state.beta.get(i)); } } @@ -344,45 +315,23 @@ private SparseVector trainSingleDimension(SLMState state, int[] exampleRows, int * @param target The vector of target values. * @return The OLS solution for the supplied features. 
*/ - static RealVector ordinaryLeastSquares(RealMatrix M, RealVector target) { - RealMatrix inv; - try { - inv = new LUDecomposition(M.transpose().multiply(M)).getSolver().getInverse(); - } catch (SingularMatrixException s) { + static Pair ordinaryLeastSquares(DenseMatrix M, DenseVector target) { + Optional lu = M.matrixMultiply(M,true,false).luFactorization(); + if (lu.isPresent()) { + DenseMatrix inv = (DenseMatrix) lu.get().inverse(); + return new Pair<>(inv.matrixMultiply(M,false,true).leftMultiply(target),inv); + } else { // Matrix is not invertible, there is nothing we can do // We will let the caller decide what to do return null; } - - return inv.multiply(M.transpose()).operate(target); - } - - /** - * Sums inverted matrix. - * @param matrix The Matrix to operate on. - * @return The sum of the inverted matrix. - */ - static double sumInverted(RealMatrix matrix) { - // Why are we not trying to catch the potential exception? - // Because in the context of LARS, if we call this method, we know the matrix is invertible - RealMatrix inv = new LUDecomposition(matrix.transpose().multiply(matrix)).getSolver().getInverse(); - - RealVector one = new ArrayRealVector(matrix.getColumnDimension(),1.0); - - return one.dotProduct(inv.operate(one)); } - /** - * Inverts the matrix, takes the dot product and scales it by the supplied value. - * @param M The matrix to invert. - * @param AA The value to scale by. - * @return The vector of feature values. - */ - static RealVector getwa(RealMatrix M, double AA) { - RealMatrix inv = new LUDecomposition(M.transpose().multiply(M)).getSolver().getInverse(); - RealVector one = new ArrayRealVector(M.getColumnDimension(),1.0); - - return inv.operate(one).mapMultiply(AA); + static DenseVector getWA(DenseMatrix inv, double AA) { + DenseVector ones = new DenseVector(inv.getDimension2Size(),1.0); + DenseVector output = inv.rightMultiply(ones); + output.scaleInPlace(AA); + return output; } /** @@ -393,9 +342,9 @@ static RealVector getwa(RealMatrix M, double AA) { * @param v A vector. * @return (M . v) . 
D^T */ - static RealVector getA(RealMatrix D, RealMatrix M, RealVector v) { - RealVector u = M.operate(v); - return D.transpose().operate(u); + static DenseVector getA(DenseMatrix D, DenseMatrix M, DenseVector v) { + DenseVector u = M.leftMultiply(v); + return D.rightMultiply(u); } static class SLMState { @@ -407,26 +356,26 @@ static class SLMState { protected final Set activeSet; protected final List active; - protected final RealMatrix X; - protected final RealVector y; + protected final DenseMatrix X; + protected final DenseVector y; - protected RealMatrix xpi; - protected RealVector r; - protected RealVector beta; + protected DenseMatrix xpi; + protected DenseVector r; + protected DenseVector beta; protected double C; - protected RealVector corr; + protected DenseVector corr; - protected Boolean last = false; + protected boolean last = false; - public SLMState(RealMatrix features, RealVector outputs, ImmutableFeatureMap featureIDMap, boolean normalize) { - this.numExamples = features.getRowDimension(); - this.numFeatures = features.getColumnDimension(); + public SLMState(DenseMatrix features, DenseVector outputs, ImmutableFeatureMap featureIDMap, boolean normalize) { + this.numExamples = features.getDimension1Size(); + this.numFeatures = features.getDimension2Size(); this.featureIDMap = featureIDMap; this.normalize = normalize; - this.active = new ArrayList<>(); + this.active = new ArrayList<>(numFeatures); this.activeSet = new HashSet<>(); - this.beta = new ArrayRealVector(numFeatures); + this.beta = new DenseVector(numFeatures); this.X = features; this.y = outputs; } @@ -436,11 +385,11 @@ public SLMState(RealMatrix features, RealVector outputs, ImmutableFeatureMap fea * @param values The values. * @return A dense vector representing the values at the active set indices. */ - public RealVector unpack(RealVector values) { - RealVector u = new ArrayRealVector(numFeatures); + public DenseVector unpack(DenseVector values) { + DenseVector u = new DenseVector(numFeatures); for (int i = 0; i < active.size(); ++i) { - u.setEntry(active.get(i), values.getEntry(i)); + u.set(active.get(i), values.get(i)); } return u; diff --git a/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java b/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java index 4048ca794..015dfc569 100644 --- a/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java +++ b/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -66,9 +66,15 @@ public class SparseLinearModel extends SkeletalIndependentRegressionSparseModel private SparseVector[] weights; private final DenseVector featureMeans; + /** + * Note this variable is called a variance, but it actually stores the l2 norm of the centered feature column. + */ private final DenseVector featureVariance; private final boolean bias; private double[] yMean; + /** + * Note this variable is called a variance, but it actually stores the l2 norm of the centered output. + */ private double[] yVariance; // Used to signal if the model has been rewritten to fix the issue with ElasticNet models in 4.0 and 4.1.0. 
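The two comments above describe the preprocessing performed in SLMTrainer.train when normalisation is enabled: each feature column (and each output dimension) is centred by its mean and then divided by the l2 norm of the centred values, and it is that norm, not a variance, which ends up in the renamed featureNorms / yNorms constructor arguments. A minimal plain-Java sketch of the same computation is below; the class and method names are illustrative only and are not part of Tribuo, and a real implementation would also need to guard against a zero norm for a constant column.

import java.util.Arrays;

// Sketch only: centre a column by its mean, scale by the l2 norm of the
// centred values, and return {mean, norm} - the two quantities the model
// keeps so the scaling can be related back to the original units.
public final class ColumnStandardizerSketch {

    static double[] standardizeInPlace(double[] column) {
        double mean = Arrays.stream(column).average().orElse(0.0);
        double norm = Math.sqrt(Arrays.stream(column).map(v -> v - mean).map(v -> v * v).sum());
        for (int i = 0; i < column.length; i++) {
            column[i] = (column[i] - mean) / norm;
        }
        return new double[]{mean, norm};
    }

    public static void main(String[] args) {
        double[] column = {1.0, -2.0, 3.0, -4.0, 5.0};
        double[] stats = standardizeInPlace(column);
        System.out.printf("mean=%.4f, centred l2 norm=%.4f, scaled column=%s%n",
                stats[0], stats[1], Arrays.toString(column));
    }
}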
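For reference, the reworked SLMTrainer.ordinaryLeastSquares earlier in this patch solves the normal equations beta = (M^T M)^-1 M^T y using Tribuo's LU factorization, and returns null when M^T M is singular. The self-contained sketch below (plain arrays, hypothetical class name, not the Tribuo code path) performs the same computation for a two-column design matrix and recovers the coefficients of y = 1 + 2x; the zero-determinant branch plays the role of the empty Optional case in the trainer.

public final class NormalEquationsSketch {

    public static void main(String[] args) {
        // Design matrix with an intercept column and one feature column.
        double[][] m = {{1.0, 1.0}, {1.0, 2.0}, {1.0, 3.0}, {1.0, 4.0}};
        double[] y = {3.0, 5.0, 7.0, 9.0};   // generated by y = 1 + 2 * x

        // Accumulate M^T M (2x2) and M^T y (length 2).
        double[][] mtm = new double[2][2];
        double[] mty = new double[2];
        for (int r = 0; r < m.length; r++) {
            for (int i = 0; i < 2; i++) {
                mty[i] += m[r][i] * y[r];
                for (int j = 0; j < 2; j++) {
                    mtm[i][j] += m[r][i] * m[r][j];
                }
            }
        }

        // Invert the 2x2 matrix directly; a zero determinant means M^T M is
        // singular and there is no unique OLS solution.
        double det = mtm[0][0] * mtm[1][1] - mtm[0][1] * mtm[1][0];
        if (det == 0.0) {
            System.out.println("M^T M is singular, no unique OLS solution");
            return;
        }
        double[][] inv = {
                { mtm[1][1] / det, -mtm[0][1] / det},
                {-mtm[1][0] / det,  mtm[0][0] / det}
        };

        // beta = (M^T M)^-1 M^T y, which prints beta = {1.000, 2.000} here.
        double[] beta = {
                inv[0][0] * mty[0] + inv[0][1] * mty[1],
                inv[1][0] * mty[0] + inv[1][1] * mty[1]
        };
        System.out.printf("beta = {%.3f, %.3f}%n", beta[0], beta[1]);
    }
}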
@@ -76,13 +82,13 @@ public class SparseLinearModel extends SkeletalIndependentRegressionSparseModel SparseLinearModel(String name, String[] dimensionNames, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo labelIDMap, - SparseVector[] weights, DenseVector featureMeans, DenseVector featureVariance, double[] yMean, double[] yVariance, boolean bias) { + SparseVector[] weights, DenseVector featureMeans, DenseVector featureNorms, double[] yMean, double[] yNorms, boolean bias) { super(name, dimensionNames, description, featureIDMap, labelIDMap, generateActiveFeatures(dimensionNames, featureIDMap, weights)); this.weights = weights; this.featureMeans = featureMeans; - this.featureVariance = featureVariance; + this.featureVariance = featureNorms; this.bias = bias; - this.yVariance = yVariance; + this.yVariance = yNorms; this.yMean = yMean; this.enet41MappingFix = true; } diff --git a/Regression/SLM/src/test/java/org/tribuo/regression/slm/TestSLM.java b/Regression/SLM/src/test/java/org/tribuo/regression/slm/TestSLM.java index 437e05757..b1b23727a 100644 --- a/Regression/SLM/src/test/java/org/tribuo/regression/slm/TestSLM.java +++ b/Regression/SLM/src/test/java/org/tribuo/regression/slm/TestSLM.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,21 +18,21 @@ import ai.onnxruntime.OrtException; import com.oracle.labs.mlrg.olcut.util.Pair; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; import org.tribuo.DataSource; import org.tribuo.Dataset; import org.tribuo.Model; import org.tribuo.MutableDataset; import org.tribuo.SparseModel; -import org.tribuo.Trainer; +import org.tribuo.SparseTrainer; import org.tribuo.interop.onnx.OnnxTestUtils; import org.tribuo.regression.Regressor; import org.tribuo.regression.evaluation.RegressionEvaluation; import org.tribuo.regression.evaluation.RegressionEvaluator; import org.tribuo.regression.example.NonlinearGaussianDataSource; import org.tribuo.regression.example.RegressionDataGenerator; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; import org.tribuo.test.Helpers; import java.io.IOException; @@ -47,18 +47,17 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.fail; public class TestSLM { private static final Logger logger = Logger.getLogger(TestSLM.class.getName()); private static final SLMTrainer SFS = new SLMTrainer(false,-1); - private static final SLMTrainer SFSN = new SLMTrainer(false,-1); - private static final ElasticNetCDTrainer ELASTIC_NET = new ElasticNetCDTrainer(1.0,0.5,1e-4,500,false,0); + private static final SLMTrainer SFSN = new SLMTrainer(true,-1); private static final LARSTrainer LARS = new LARSTrainer(10); private static final LARSLassoTrainer LARS_LASSO = new LARSLassoTrainer(-1); + private static final ElasticNetCDTrainer ELASTIC_NET = new ElasticNetCDTrainer(1.0,0.5,1e-4,500,false,0); private static final RegressionEvaluator e = new RegressionEvaluator(); private static final URL TEST_REGRESSION_REORDER_ENET_MODEL = 
TestSLM.class.getResource("enet-4.1.0.model"); @@ -73,10 +72,10 @@ public static void turnDownLogging() { } // This is a bit contrived, but it makes the trainer that failed appear in the stack trace. - public static Model testTrainer(Trainer trainer, - Pair,Dataset> p, - boolean testONNX) { - Model m = trainer.train(p.getA()); + public static SparseModel testTrainer(SparseTrainer trainer, + Pair,Dataset> p, + boolean testONNX) { + SparseModel m = trainer.train(p.getA()); RegressionEvaluation evaluation = e.evaluate(m,p.getB()); Map>> features = m.getTopFeatures(3); Assertions.assertNotNull(features); @@ -102,19 +101,23 @@ public static Model testTrainer(Trainer trainer, } public static Model testSFS(Pair,Dataset> p, boolean testONNX) { - return testTrainer(SFS,p,testONNX); + SparseModel newM = testTrainer(SFS,p,testONNX); + return newM; } public static Model testSFSN(Pair,Dataset> p, boolean testONNX) { - return testTrainer(SFSN,p,testONNX); + SparseModel newM = testTrainer(SFSN,p,testONNX); + return newM; } public static Model testLARS(Pair,Dataset> p, boolean testONNX) { - return testTrainer(LARS,p,testONNX); + SparseModel newM = testTrainer(LARS,p,testONNX); + return newM; } public static Model testLASSO(Pair,Dataset> p, boolean testONNX) { - return testTrainer(LARS_LASSO,p,testONNX); + SparseModel newM = testTrainer(LARS_LASSO,p,testONNX); + return newM; } public static Model testElasticNet(Pair,Dataset> p, boolean testONNX) { diff --git a/THIRD_PARTY_LICENSES.txt b/THIRD_PARTY_LICENSES.txt index d050c1f26..060f9d16c 100644 --- a/THIRD_PARTY_LICENSES.txt +++ b/THIRD_PARTY_LICENSES.txt @@ -728,465 +728,6 @@ OpenCSV 5.4 - Apache 2.0 See the License for the specific language governing permissions and limitations under the License. -Apache Commons Math 3.6.1 - Apache 2.0 - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). 
- - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - -Apache Commons Math includes the following code provided to the ASF under the -Apache License 2.0: - - - The inverse error function implementation in the Erf class is based on CUDA - code developed by Mike Giles, Oxford-Man Institute of Quantitative Finance, - and published in GPU Computing Gems, volume 2, 2010 (grant received on - March 23th 2013) - - The LinearConstraint, LinearObjectiveFunction, LinearOptimizer, - RelationShip, SimplexSolver and SimplexTableau classes in package - org.apache.commons.math3.optimization.linear include software developed by - Benjamin McCann (http://www.benmccann.com) and distributed with - the following copyright: Copyright 2009 Google Inc. 
(grant received on - March 16th 2009) - - The class "org.apache.commons.math3.exception.util.LocalizedFormatsTest" which - is an adapted version of "OrekitMessagesTest" test class for the Orekit library - - The "org.apache.commons.math3.analysis.interpolation.HermiteInterpolator" - has been imported from the Orekit space flight dynamics library. - -=============================================================================== - - - -APACHE COMMONS MATH DERIVATIVE WORKS: - -The Apache commons-math library includes a number of subcomponents -whose implementation is derived from original sources written -in C or Fortran. License terms of the original sources -are reproduced below. - -=============================================================================== -For the lmder, lmpar and qrsolv Fortran routine from minpack and translated in -the LevenbergMarquardtOptimizer class in package -org.apache.commons.math3.optimization.general -Original source copyright and license statement: - -Minpack Copyright Notice (1999) University of Chicago. All rights reserved - -Redistribution and use in source and binary forms, with or -without modification, are permitted provided that the -following conditions are met: - -1. Redistributions of source code must retain the above -copyright notice, this list of conditions and the following -disclaimer. - -2. Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following -disclaimer in the documentation and/or other materials -provided with the distribution. - -3. The end-user documentation included with the -redistribution, if any, must include the following -acknowledgment: - - "This product includes software developed by the - University of Chicago, as Operator of Argonne National - Laboratory. - -Alternately, this acknowledgment may appear in the software -itself, if and wherever such third-party acknowledgments -normally appear. - -4. WARRANTY DISCLAIMER. THE SOFTWARE IS SUPPLIED "AS IS" -WITHOUT WARRANTY OF ANY KIND. THE COPYRIGHT HOLDER, THE -UNITED STATES, THE UNITED STATES DEPARTMENT OF ENERGY, AND -THEIR EMPLOYEES: (1) DISCLAIM ANY WARRANTIES, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO ANY IMPLIED WARRANTIES -OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE -OR NON-INFRINGEMENT, (2) DO NOT ASSUME ANY LEGAL LIABILITY -OR RESPONSIBILITY FOR THE ACCURACY, COMPLETENESS, OR -USEFULNESS OF THE SOFTWARE, (3) DO NOT REPRESENT THAT USE OF -THE SOFTWARE WOULD NOT INFRINGE PRIVATELY OWNED RIGHTS, (4) -DO NOT WARRANT THAT THE SOFTWARE WILL FUNCTION -UNINTERRUPTED, THAT IT IS ERROR-FREE OR THAT ANY ERRORS WILL -BE CORRECTED. - -5. LIMITATION OF LIABILITY. IN NO EVENT WILL THE COPYRIGHT -HOLDER, THE UNITED STATES, THE UNITED STATES DEPARTMENT OF -ENERGY, OR THEIR EMPLOYEES: BE LIABLE FOR ANY INDIRECT, -INCIDENTAL, CONSEQUENTIAL, SPECIAL OR PUNITIVE DAMAGES OF -ANY KIND OR NATURE, INCLUDING BUT NOT LIMITED TO LOSS OF -PROFITS OR LOSS OF DATA, FOR ANY REASON WHATSOEVER, WHETHER -SUCH LIABILITY IS ASSERTED ON THE BASIS OF CONTRACT, TORT -(INCLUDING NEGLIGENCE OR STRICT LIABILITY), OR OTHERWISE, -EVEN IF ANY OF SAID PARTIES HAS BEEN WARNED OF THE -POSSIBILITY OF SUCH LOSS OR DAMAGES. -=============================================================================== - -Copyright and license statement for the odex Fortran routine developed by -E. Hairer and G. 
Wanner and translated in GraggBulirschStoerIntegrator class -in package org.apache.commons.math3.ode.nonstiff: - - -Copyright (c) 2004, Ernst Hairer - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - -- Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - -- Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS -IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED -TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A -PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR -CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -=============================================================================== - -Copyright and license statement for the original Mersenne twister C -routines translated in MersenneTwister class in package -org.apache.commons.math3.random: - - Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, - All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions - are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - 3. The names of its contributors may not be used to endorse or promote - products derived from this software without specific prior written - permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR - CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -=============================================================================== - -The initial code for shuffling an array (originally in class -"org.apache.commons.math3.random.RandomDataGenerator", now replaced by -a method in class "org.apache.commons.math3.util.MathArrays") was -inspired from the algorithm description provided in -"Algorithms", by Ian Craw and John Pulham (University of Aberdeen 1999). -The textbook (containing a proof that the shuffle is uniformly random) is -available here: - http://citeseerx.ist.psu.edu/viewdoc/download;?doi=10.1.1.173.1898&rep=rep1&type=pdf - -=============================================================================== -License statement for the direction numbers in the resource files for Sobol sequences. - ------------------------------------------------------------------------------ -Licence pertaining to sobol.cc and the accompanying sets of direction numbers - ------------------------------------------------------------------------------ -Copyright (c) 2008, Frances Y. Kuo and Stephen Joe -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - * Neither the names of the copyright holders nor the names of the - University of New South Wales and the University of Waikato - and its contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS BE LIABLE FOR ANY -DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -=============================================================================== - -The initial commit of package "org.apache.commons.math3.ml.neuralnet" is -an adapted version of code developed in the context of the Data Processing -and Analysis Consortium (DPAC) of the "Gaia" project of the European Space -Agency (ESA). -=============================================================================== - -The initial commit of the class "org.apache.commons.math3.special.BesselJ" is -an adapted version of code translated from the netlib Fortran program, rjbesl -http://www.netlib.org/specfun/rjbesl by R.J. Cody at Argonne National -Laboratory (USA). There is no license or copyright statement included with the -original Fortran sources. 
-=============================================================================== - - -The BracketFinder (package org.apache.commons.math3.optimization.univariate) -and PowellOptimizer (package org.apache.commons.math3.optimization.general) -classes are based on the Python code in module "optimize.py" (version 0.5) -developed by Travis E. Oliphant for the SciPy library (http://www.scipy.org/) -Copyright © 2003-2009 SciPy Developers. - -SciPy license -Copyright © 2001, 2002 Enthought, Inc. -All rights reserved. - -Copyright © 2003-2013 SciPy Developers. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - * Neither the name of Enthought nor the names of the SciPy Developers may - be used to endorse or promote products derived from this software without - specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY -EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR ANY -DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -=============================================================================== - junit 5.7.1 - Eclipse Public License 2.0 Eclipse Public License - v 2.0 diff --git a/Util/InformationTheory/pom.xml b/Util/InformationTheory/pom.xml index c6507a8c6..8f486ba0c 100644 --- a/Util/InformationTheory/pom.xml +++ b/Util/InformationTheory/pom.xml @@ -37,11 +37,6 @@ com.oracle.labs.olcut olcut-core - - org.apache.commons - commons-math3 - ${commonsmath.version} - org.junit.jupiter junit-jupiter diff --git a/Util/InformationTheory/src/main/java/org/tribuo/util/infotheory/Gamma.java b/Util/InformationTheory/src/main/java/org/tribuo/util/infotheory/Gamma.java new file mode 100644 index 000000000..4e2ada42b --- /dev/null +++ b/Util/InformationTheory/src/main/java/org/tribuo/util/infotheory/Gamma.java @@ -0,0 +1,459 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +/* + * Port of portions of the "Freely Distributable Math Library", version 5.3, + * from C to Java. This file ports the function e_lgamma_r and its helper + * function sinpi to produce Java functions GammaMath.lgamma and + * GammaMath.sinpi; + */ + +/* @(#)e_lgamma_r.c 1.3 95/01/18 + * ==================================================== + * Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved. + * + * Developed at SunSoft, a Sun Microsystems, Inc. business. + * Permission to use, copy, modify, and distribute this + * software is freely granted, provided that this notice + * is preserved. + * ==================================================== + */ + +/* __ieee754_lgamma_r(x, signgamp) + * Reentrant version of the logarithm of the Gamma function + * with user provided pointer for the sign of Gamma(x). + * + * Method: + * 1. Argument Reduction for 0 < x <= 8 + * Since gamma(1+s)=s*gamma(s), for x in [0,8], we may + * reduce x to a number in [1.5,2.5] by + * lgamma(1+s) = log(s) + lgamma(s) + * for example, + * lgamma(7.3) = log(6.3) + lgamma(6.3) + * = log(6.3*5.3) + lgamma(5.3) + * = log(6.3*5.3*4.3*3.3*2.3) + lgamma(2.3) + * 2. Polynomial approximation of lgamma around its + * minimum ymin=1.461632144968362245 to maintain monotonicity. + * On [ymin-0.23, ymin+0.27] (i.e., [1.23164,1.73163]), use + * Let z = x-ymin; + * lgamma(x) = -1.214862905358496078218 + z^2*poly(z) + * where + * poly(z) is a 14 degree polynomial. + * 2. Rational approximation in the primary interval [2,3] + * We use the following approximation: + * s = x-2.0; + * lgamma(x) = 0.5*s + s*P(s)/Q(s) + * with accuracy + * |P/Q - (lgamma(x)-0.5s)| < 2**-61.71 + * Our algorithms are based on the following observation + * + * zeta(2)-1 2 zeta(3)-1 3 + * lgamma(2+s) = s*(1-Euler) + --------- * s - --------- * s + ... + * 2 3 + * + * where Euler = 0.5771... is the Euler constant, which is very + * close to 0.5. + * + * 3. For x>=8, we have + * lgamma(x)~(x-0.5)log(x)-x+0.5*log(2pi)+1/(12x)-1/(360x**3)+.... + * (better formula: + * lgamma(x)~(x-0.5)*(log(x)-1)-.5*(log(2pi)-1) + ...) + * Let z = 1/x, then we approximation + * f(z) = lgamma(x) - (x-0.5)(log(x)-1) + * by + * 3 5 11 + * w = w0 + w1*z + w2*z + w3*z + ... + w6*z + * where + * |w - f(z)| < 2**-58.74 + * + * 4. For negative x, since (G is gamma function) + * -x*G(-x)*G(x) = pi/sin(pi*x), + * we have + * G(x) = pi/(sin(pi*x)*(-x)*G(-x)) + * since G(-x) is positive, sign(G(x)) = sign(sin(pi*x)) for x<0 + * Hence, for x<0, signgam = sign(sin(pi*x)) and + * lgamma(x) = log(|Gamma(x)|) + * = log(pi/(|x*sin(pi*x)|)) - lgamma(-x); + * Note: one should avoid compute pi*(-x) directly in the + * computation of sin(pi*(-x)). + * + * 5. Special Cases + * lgamma(2+s) ~ s*(1-Euler) for tiny s + * lgamma(1)=lgamma(2)=0 + * lgamma(x) ~ -log(x) for tiny x + * lgamma(0) = lgamma(inf) = inf + * lgamma(-integer) = +-inf + * + */ + +package org.tribuo.util.infotheory; + +/** + * Static functions for computing the Gamma and log Gamma functions on real valued inputs. 
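To make the argument-reduction notes above concrete: the port relies on the recurrence lgamma(1+s) = log(s) + lgamma(s), which can be checked numerically against the code that follows. The snippet below is a minimal usage sketch, not part of Gamma.java; it assumes the Gamma class added by this patch is on the classpath, and the class name LogGammaRecurrenceCheck is purely illustrative.

import org.tribuo.util.infotheory.Gamma;

/**
 * Minimal usage sketch (not part of the patch): checks the recurrence
 * lgamma(x+1) = log(x) + lgamma(x) used by the argument reduction above,
 * and the known value Gamma(5) = 4! = 24, i.e. logGamma(5.0) = log(24).
 */
public final class LogGammaRecurrenceCheck {
    public static void main(String[] args) {
        double x = 6.3;
        double direct = Gamma.logGamma(x + 1.0);           // lgamma(7.3) computed directly
        double reduced = Math.log(x) + Gamma.logGamma(x);  // log(6.3) + lgamma(6.3)
        System.out.printf("lgamma(7.3): direct=%.12f reduced=%.12f%n", direct, reduced);
        System.out.printf("logGamma(5.0)=%.12f log(24)=%.12f%n", Gamma.logGamma(5.0), Math.log(24.0));
    }
}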
+ */ +public final class Gamma { + private static final double + two52= 4.50359962737049600000e+15, /* 0x43300000, 0x00000000 */ + half= 5.00000000000000000000e-01, /* 0x3FE00000, 0x00000000 */ + one = 1.00000000000000000000e+00, /* 0x3FF00000, 0x00000000 */ + zero= 0.00000000000000000000e+00, + pi = 3.14159265358979311600e+00, /* 0x400921FB, 0x54442D18 */ + a0 = 7.72156649015328655494e-02, /* 0x3FB3C467, 0xE37DB0C8 */ + a1 = 3.22467033424113591611e-01, /* 0x3FD4A34C, 0xC4A60FAD */ + a2 = 6.73523010531292681824e-02, /* 0x3FB13E00, 0x1A5562A7 */ + a3 = 2.05808084325167332806e-02, /* 0x3F951322, 0xAC92547B */ + a4 = 7.38555086081402883957e-03, /* 0x3F7E404F, 0xB68FEFE8 */ + a5 = 2.89051383673415629091e-03, /* 0x3F67ADD8, 0xCCB7926B */ + a6 = 1.19270763183362067845e-03, /* 0x3F538A94, 0x116F3F5D */ + a7 = 5.10069792153511336608e-04, /* 0x3F40B6C6, 0x89B99C00 */ + a8 = 2.20862790713908385557e-04, /* 0x3F2CF2EC, 0xED10E54D */ + a9 = 1.08011567247583939954e-04, /* 0x3F1C5088, 0x987DFB07 */ + a10 = 2.52144565451257326939e-05, /* 0x3EFA7074, 0x428CFA52 */ + a11 = 4.48640949618915160150e-05, /* 0x3F07858E, 0x90A45837 */ + tc = 1.46163214496836224576e+00, /* 0x3FF762D8, 0x6356BE3F */ + tf = -1.21486290535849611461e-01, /* 0xBFBF19B9, 0xBCC38A42 */ + /* tt = -(tail of tf) */ + tt = -3.63867699703950536541e-18, /* 0xBC50C7CA, 0xA48A971F */ + t0 = 4.83836122723810047042e-01, /* 0x3FDEF72B, 0xC8EE38A2 */ + t1 = -1.47587722994593911752e-01, /* 0xBFC2E427, 0x8DC6C509 */ + t2 = 6.46249402391333854778e-02, /* 0x3FB08B42, 0x94D5419B */ + t3 = -3.27885410759859649565e-02, /* 0xBFA0C9A8, 0xDF35B713 */ + t4 = 1.79706750811820387126e-02, /* 0x3F9266E7, 0x970AF9EC */ + t5 = -1.03142241298341437450e-02, /* 0xBF851F9F, 0xBA91EC6A */ + t6 = 6.10053870246291332635e-03, /* 0x3F78FCE0, 0xE370E344 */ + t7 = -3.68452016781138256760e-03, /* 0xBF6E2EFF, 0xB3E914D7 */ + t8 = 2.25964780900612472250e-03, /* 0x3F6282D3, 0x2E15C915 */ + t9 = -1.40346469989232843813e-03, /* 0xBF56FE8E, 0xBF2D1AF1 */ + t10 = 8.81081882437654011382e-04, /* 0x3F4CDF0C, 0xEF61A8E9 */ + t11 = -5.38595305356740546715e-04, /* 0xBF41A610, 0x9C73E0EC */ + t12 = 3.15632070903625950361e-04, /* 0x3F34AF6D, 0x6C0EBBF7 */ + t13 = -3.12754168375120860518e-04, /* 0xBF347F24, 0xECC38C38 */ + t14 = 3.35529192635519073543e-04, /* 0x3F35FD3E, 0xE8C2D3F4 */ + u0 = -7.72156649015328655494e-02, /* 0xBFB3C467, 0xE37DB0C8 */ + u1 = 6.32827064025093366517e-01, /* 0x3FE4401E, 0x8B005DFF */ + u2 = 1.45492250137234768737e+00, /* 0x3FF7475C, 0xD119BD6F */ + u3 = 9.77717527963372745603e-01, /* 0x3FEF4976, 0x44EA8450 */ + u4 = 2.28963728064692451092e-01, /* 0x3FCD4EAE, 0xF6010924 */ + u5 = 1.33810918536787660377e-02, /* 0x3F8B678B, 0xBF2BAB09 */ + v1 = 2.45597793713041134822e+00, /* 0x4003A5D7, 0xC2BD619C */ + v2 = 2.12848976379893395361e+00, /* 0x40010725, 0xA42B18F5 */ + v3 = 7.69285150456672783825e-01, /* 0x3FE89DFB, 0xE45050AF */ + v4 = 1.04222645593369134254e-01, /* 0x3FBAAE55, 0xD6537C88 */ + v5 = 3.21709242282423911810e-03, /* 0x3F6A5ABB, 0x57D0CF61 */ + s0 = -7.72156649015328655494e-02, /* 0xBFB3C467, 0xE37DB0C8 */ + s1 = 2.14982415960608852501e-01, /* 0x3FCB848B, 0x36E20878 */ + s2 = 3.25778796408930981787e-01, /* 0x3FD4D98F, 0x4F139F59 */ + s3 = 1.46350472652464452805e-01, /* 0x3FC2BB9C, 0xBEE5F2F7 */ + s4 = 2.66422703033638609560e-02, /* 0x3F9B481C, 0x7E939961 */ + s5 = 1.84028451407337715652e-03, /* 0x3F5E26B6, 0x7368F239 */ + s6 = 3.19475326584100867617e-05, /* 0x3F00BFEC, 0xDD17E945 */ + r1 = 1.39200533467621045958e+00, /* 0x3FF645A7, 0x62C4AB74 */ + r2 = 
7.21935547567138069525e-01, /* 0x3FE71A18, 0x93D3DCDC */ + r3 = 1.71933865632803078993e-01, /* 0x3FC601ED, 0xCCFBDF27 */ + r4 = 1.86459191715652901344e-02, /* 0x3F9317EA, 0x742ED475 */ + r5 = 7.77942496381893596434e-04, /* 0x3F497DDA, 0xCA41A95B */ + r6 = 7.32668430744625636189e-06, /* 0x3EDEBAF7, 0xA5B38140 */ + w0 = 4.18938533204672725052e-01, /* 0x3FDACFE3, 0x90C97D69 */ + w1 = 8.33333333333329678849e-02, /* 0x3FB55555, 0x5555553B */ + w2 = -2.77777777728775536470e-03, /* 0xBF66C16C, 0x16B02E5C */ + w3 = 7.93650558643019558500e-04, /* 0x3F4A019F, 0x98CF38B6 */ + w4 = -5.95187557450339963135e-04, /* 0xBF4380CB, 0x8C0FE741 */ + w5 = 8.36339918996282139126e-04, /* 0x3F4B67BA, 0x4CDAD5D1 */ + w6 = -1.63092934096575273989e-03; /* 0xBF5AB89D, 0x0B9E43E4 */ + + /** + * Private constructor to ensure that the class is never instantiated. + */ + private Gamma() {} + + /** + * Return the low-order 32 bits of the double argument as an int. + * @param x The input double. + * @return The lower 32-bits as an int. + */ + private static int __LO(double x) { + return (int)Double.doubleToRawLongBits(x); + } + + /** + * Return the high-order 32 bits of the double argument as an int. + * @param x The input double. + * @return The upper 32-bits as an int. + */ + private static int __HI(double x) { + return (int)(Double.doubleToRawLongBits(x) >> 32); + } + + private static double sin_pi(double x) { + double y,z; + int n,ix; + + ix = 0x7fffffff&__HI(x); + + if (ix<0x3fd00000) { + return Math.sin(pi * x); + } + y = -x; /* x is assumed negative */ + + /* + * argument reduction, make sure inexact flag not raised if input + * is an integer + */ + z = Math.floor(y); + if(z!=y) { /* inexact anyway */ + y *= 0.5; + y = 2.0*(y - Math.floor(y)); /* y = |x| mod 2.0 */ + n = (int) (y*4.0); + } else { + if(ix>=0x43400000) { + y = zero; + n = 0; /* y must be even */ + } else { + if(ix<0x43300000) + z = y+two52; /* exact */ + n = __LO(z)&1; /* lower word of z */ + y = n; + n<<= 2; + } + } + switch (n) { + case 0: + y = Math.sin(pi*y); + break; + case 1: + case 2: + y = Math.cos(pi*(0.5-y)); + break; + case 3: + case 4: + y = Math.sin(pi*(one-y)); + break; + case 5: + case 6: + y = -Math.cos(pi*(y-1.5)); + break; + default: + y = Math.sin(pi*(y-2.0)); + break; + } + return -y; + } + + + /** + * Function to calculate the log of a Gamma function. Negative integer values will return NaN. + * @param x The value to calculate for. + * @return The log of the Gamma function applied to x. 
+ */ + public static double logGamma(double x) { + double t,y,z,nadj,p,p1,p2,p3,q,r,w; + int i,hx,lx,ix; + + if((x <= 0) && (Math.floor(x) == x)) { + return Double.NaN; + } + + hx = __HI(x); + lx = __LO(x); + nadj = zero; + + /* purge off +-inf, NaN, +-0, and negative arguments */ + ix = hx&0x7fffffff; + if(ix>=0x7ff00000) { + return x * x; + } + if ((ix|lx)==0) { + return one / zero; + } + if(ix<0x3b900000) { /* |x|<2**-70, return -log(|x|) */ + if(hx<0) { + return -Math.log(-x); + } else { + return -Math.log(x); + } + } + if(hx<0) { + if(ix>=0x43300000) { /* |x|>=2**52, must be -integer */ + return one / zero; + } + t = sin_pi(x); + if(t==zero) { + return one / zero; /* -integer */ + } + nadj = Math.log(pi/Math.abs(t*x)); + x = -x; + } + + /* purge off 1 and 2 */ + if((((ix-0x3ff00000)|lx)==0)||(((ix-0x40000000)|lx)==0)) { + r = 0; + /* for x < 2.0 */ + } else if(ix<0x40000000) { + if(ix<=0x3feccccc) { /* lgamma(x) = lgamma(x+1)-log(x) */ + r = -Math.log(x); + if(ix>=0x3FE76944) { + y = one-x; + i= 0; + } else if(ix>=0x3FCDA661) { + y= x-(tc-one); + i=1; + } else { + y = x; + i=2; + } + } else { + r = zero; + if(ix>=0x3FFBB4C3) { + y=2.0-x; + i=0; + } /* [1.7316,2] */ + else if(ix>=0x3FF3B4C4) { + y=x-tc; + i=1; + } /* [1.23,1.73] */ + else { + y=x-one; + i=2; + } + } + switch(i) { + case 0: + z = y*y; + p1 = a0+z*(a2+z*(a4+z*(a6+z*(a8+z*a10)))); + p2 = z*(a1+z*(a3+z*(a5+z*(a7+z*(a9+z*a11))))); + p = y*p1+p2; + r += (p-0.5*y); + break; + case 1: + z = y*y; + w = z*y; + p1 = t0+w*(t3+w*(t6+w*(t9 +w*t12))); /* parallel comp */ + p2 = t1+w*(t4+w*(t7+w*(t10+w*t13))); + p3 = t2+w*(t5+w*(t8+w*(t11+w*t14))); + p = z*p1-(tt-w*(p2+y*p3)); + r += (tf + p); + break; + case 2: + p1 = y*(u0+y*(u1+y*(u2+y*(u3+y*(u4+y*u5))))); + p2 = one+y*(v1+y*(v2+y*(v3+y*(v4+y*v5)))); + r += (-0.5*y + p1/p2); + } + } else if(ix<0x40200000) { /* x < 8.0 */ + i = (int)x; + t = zero; + y = x-i; + p = y*(s0+y*(s1+y*(s2+y*(s3+y*(s4+y*(s5+y*s6)))))); + q = one+y*(r1+y*(r2+y*(r3+y*(r4+y*(r5+y*r6))))); + r = half*y+p/q; + z = one; /* lgamma(1+s) = log(s) + lgamma(s) */ + switch(i) { + case 7: + z *= (y+6.0); + //$FALL-THROUGH$ + case 6: + z *= (y+5.0); + //$FALL-THROUGH$ + case 5: + z *= (y+4.0); + //$FALL-THROUGH$ + case 4: + z *= (y+3.0); + //$FALL-THROUGH$ + case 3: + z *= (y+2.0); + //$FALL-THROUGH$ + r += Math.log(z); + break; + } + /* 8.0 <= x < 2**58 */ + } else if (ix < 0x43900000) { + t = Math.log(x); + z = one/x; + y = z*z; + w = w0+z*(w1+y*(w2+y*(w3+y*(w4+y*(w5+y*w6))))); + r = (x-half)*(t-one)+w; + } else { + /* 2**58 <= x <= inf */ + r = x * (Math.log(x) - one); + } + if(hx<0) { + r = nadj - r; + } + return r; + } + + /** + * Function to calculate the sign to be used for the result of a gamma function. + * @param x The value the gamma function is being constructed on. + * @return 1 or -1 depending on the sign that should be applied. The sign is then applied by multiplying by this value. + */ + private static int getSign(double x) { + if(x<0 && ((int)x)%2 == 0) { + return -1; + } else { + return 1; + } + } + + /** + * Function to calculate the value of a Gamma function. Negative integer values will return NaN. + * @param x The value to calculate for. + * @return The value of the Gamma function applied to x. + */ + public static double gamma(double x) { + double r = logGamma(x); + int sign = getSign(x); + return sign * Math.exp(r); + } + + /** + * Computes the regularised partial gamma function P. + *

+ * See RegularisedGammaFunction. + * Throws {@link IllegalStateException} if the iterations don't converge. + * @param a shape (when used as a CDF) + * @param x value / scale (when used as a CDF) + * @param epsilon Tolerance. + * @param maxIterations The maximum number of iterations. + * @return P(a,x). + */ + public static double regularizedGammaP(int a, + double x, + double epsilon, + int maxIterations) { + if (Double.isNaN(x) || (a <= 0) || (x < 0.0)) { + return Double.NaN; + } else if (x == 0.0) { + return 0.0; + } else { + int i; + double ithElement = 1.0 / a; + double accumulator = ithElement; + for (i = 1; i < maxIterations && Math.abs(ithElement/accumulator) > epsilon; i++) { + ithElement *= x / (a + i); + accumulator += ithElement; + if (Double.isInfinite(accumulator)) { + return 1.0; + } + } + if (i >= maxIterations) { + throw new IllegalStateException("Exceeded maximum number of iterations " + maxIterations); + } else { + return Math.exp(-x + (a * Math.log(x)) - logGamma(a)) * accumulator; + } + } + } +} \ No newline at end of file diff --git a/Util/InformationTheory/src/main/java/org/tribuo/util/infotheory/InformationTheory.java b/Util/InformationTheory/src/main/java/org/tribuo/util/infotheory/InformationTheory.java index 787e40622..081785093 100644 --- a/Util/InformationTheory/src/main/java/org/tribuo/util/infotheory/InformationTheory.java +++ b/Util/InformationTheory/src/main/java/org/tribuo/util/infotheory/InformationTheory.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,6 @@ import org.tribuo.util.infotheory.impl.Row; import org.tribuo.util.infotheory.impl.RowList; import org.tribuo.util.infotheory.impl.TripleDistribution; -import org.apache.commons.math3.distribution.ChiSquaredDistribution; import java.util.HashMap; import java.util.List; @@ -136,10 +135,25 @@ public static GTestStatistics gTest(List first, List second, tuple = innerConditionalMI(first,second,conditionList); } double gMetric = 2 * second.size() * tuple.score; - ChiSquaredDistribution dist = new ChiSquaredDistribution(tuple.stateCount); - double prob = dist.cumulativeProbability(gMetric); - GTestStatistics test = new GTestStatistics(gMetric,tuple.stateCount,prob); - return test; + double prob = computeChiSquaredProbability(tuple.stateCount, gMetric); + return new GTestStatistics(gMetric,tuple.stateCount,prob); + } + + /** + * Computes the cumulative probability of the input value under a Chi-Squared distribution + * with the specified degrees of Freedom. + * @param degreesOfFreedom The degrees of freedom in the distribution. + * @param value The observed value. + * @return The cumulative probability of the observed value. + */ + private static double computeChiSquaredProbability(int degreesOfFreedom, double value) { + if (value <= 0) { + return 0.0; + } else { + int shape = degreesOfFreedom / 2; + int scale = 2; + return Gamma.regularizedGammaP(shape, value / scale, 1e-14, Integer.MAX_VALUE); + } } /** @@ -538,6 +552,53 @@ public static double calculateEntropy(DoubleStream vector) { return vector.map((p) -> (- p * Math.log(p) / LOG_BASE)).sum(); } + /** + * Compute the expected mutual information assuming randomized inputs. + * + * @param first The first vector. + * @param second The second vector. 
+ * @return The expected mutual information under a hypergeometric distribution. + */ + public static double expectedMI(List first, List second) { + PairDistribution pd = PairDistribution.constructFromLists(first,second); + + Map firstCount = pd.firstCount; + Map secondCount = pd.secondCount; + long count = pd.count; + + double output = 0.0; + + for (Entry f : firstCount.entrySet()) { + for (Entry s : secondCount.entrySet()) { + long fVal = f.getValue().longValue(); + long sVal = s.getValue().longValue(); + long minCount = Math.min(fVal, sVal); + + long threshold = fVal + sVal - count; + long start = threshold > 1 ? threshold : 1; + + for (long nij = start; nij <= minCount; nij++) { + double acc = ((double) nij) / count; + acc *= Math.log(((double) (count * nij)) / (fVal * sVal)); + //numerator + double logSpace = Gamma.logGamma(fVal + 1); + logSpace += Gamma.logGamma(sVal + 1); + logSpace += Gamma.logGamma(count - fVal + 1); + logSpace += Gamma.logGamma(count - sVal + 1); + //denominator + logSpace -= Gamma.logGamma(count + 1); + logSpace -= Gamma.logGamma(nij + 1); + logSpace -= Gamma.logGamma(fVal - nij + 1); + logSpace -= Gamma.logGamma(sVal - nij + 1); + logSpace -= Gamma.logGamma(count - fVal - sVal + nij + 1); + acc *= Math.exp(logSpace); + output += acc; + } + } + } + return output; + } + + /** * A tuple of the information theoretic value, along with the number of * states in the random variable. Will be a record one day. diff --git a/Util/InformationTheory/src/test/java/org/tribuo/util/infotheory/GammaTest.java b/Util/InformationTheory/src/test/java/org/tribuo/util/infotheory/GammaTest.java new file mode 100644 index 000000000..1bc1844d3 --- /dev/null +++ b/Util/InformationTheory/src/test/java/org/tribuo/util/infotheory/GammaTest.java @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.tribuo.util.infotheory; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import static org.tribuo.util.infotheory.Gamma.gamma; + +public class GammaTest { + + @Test + public void testExamples() { + assertEquals(Double.NaN, gamma(0.0)); + assertEquals(1.77245385, gamma(0.5), 1e-8); + assertEquals(1.0, gamma(1.0)); + assertEquals(24.0, gamma(5.0), 1e-8); + + //some random examples betwixt -100 and 100 + assertEquals(8.06474995572965e+79, gamma(59.86728989339031), 1e+67); + assertEquals(0.0005019871198070064, gamma(-7.260823951121694), 1e-18); + assertEquals(1.5401131084717308e-110, gamma(-75.48705446197417), 1e-124); + assertEquals(95932082427.69138, gamma(15.035762406520718), 1e-3); + assertEquals(4.2868413548339677e+154, gamma(99.32984689647557), 1e+140); + assertEquals(-4.971777508910858e-48, gamma(-40.14784332381653), 1e-60); + assertEquals(5.3603547985340755e-96, gamma(-67.85881128534656), 1e-108); + assertEquals(-1.887428186224555e-151, gamma(-96.63801919072759), 1e-163); + assertEquals(6.0472720813564265e+125, gamma(84.61636884564746), 1e+113); + assertEquals(-7.495823228458869e-128, gamma(-84.57833815656579), 1e-140); + assertEquals(-2.834337137147687e-14, gamma(-16.831988025996992), 1e-26); + assertEquals(8.990293245462624e+78, gamma(59.32945503543496), 1e+66); + assertEquals(3.604695169965482e-83, gamma(-61.045472852581774), 1e-95); + assertEquals(0.00020572694516842935, gamma(-7.545439745563854), 1e-16); + assertEquals(-7.906506608405116e-105, gamma(-72.4403778408159), 1e-117); + assertEquals(780133888.913568, gamma(13.192513244283958), 1e-4); + assertEquals(-3.0601588660760365e-130, gamma(-86.09108451479372), 1e-142); + assertEquals(2.310606358803366e+90, gamma(65.69557419730668), 1e+78); + assertEquals(4.574728496203664e+16, gamma(19.669827320262186), 1e+4); + assertEquals(1.5276823676246256e+74, gamma(56.618507066510915), 1e+62); + + assertEquals(0.0, gamma(-199.55885272585897), 1e-8); + assertEquals(Double.POSITIVE_INFINITY, gamma(404.5418705074535)); + assertEquals(Double.NaN, gamma(-2)); + } +} diff --git a/Util/InformationTheory/src/test/java/org/tribuo/util/infotheory/InformationTheoryTest.java b/Util/InformationTheory/src/test/java/org/tribuo/util/infotheory/InformationTheoryTest.java new file mode 100644 index 000000000..fa544e804 --- /dev/null +++ b/Util/InformationTheory/src/test/java/org/tribuo/util/infotheory/InformationTheoryTest.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.tribuo.util.infotheory; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Arrays; +import java.util.List; + +import org.junit.jupiter.api.Test; + +public class InformationTheoryTest { + + /* + * import numpy as np + * from sklearn.metrics import mutual_info_score + * a = np.random.randint(0,5,100) + * #print(printArrayAsJavaDoubles(a)) + * b = np.random.randint(0,5,100) + * #print(printArrayAsJavaDoubles(b)) + * mi = mutual_info_score(a, b) + * print(f"mi.ln={mi}") + * mi /= np.log(2.0) + * print(f"mi.log2={mi}") + */ + @Test + public void testMi() { + List a = Arrays.asList(0, 3, 2, 3, 4, 4, 4, 1, 3, 3, 4, 3, 2, 3, 2, 4, 2, 2, 1, 4, 1, 2, 0, 4, 4, 4, 3, 3, 2, 2, 0, 4, 0, 1, 3, 0, 4, 0, 0, 4, 0, 0, 2, 2, 2, 2, 0, 3, 0, 2, 2, 3, 1, 0, 1, 0, 3, 4, 4, 4, 0, 1, 1, 3, 3, 1, 3, 4, 0, 3, 4, 1, 0, 3, 2, 2, 2, 1, 1, 2, 3, 2, 1, 3, 0, 4, 4, 0, 4, 0, 2, 1, 4, 0, 3, 0, 1, 1, 1, 0); + List b = Arrays.asList(4, 2, 4, 0, 4, 4, 3, 3, 3, 2, 2, 0, 1, 3, 2, 1, 2, 0, 0, 4, 3, 3, 0, 1, 1, 1, 1, 4, 4, 4, 3, 1, 0, 0, 0, 1, 4, 1, 1, 1, 3, 3, 1, 2, 3, 0, 4, 0, 2, 3, 4, 2, 3, 2, 1, 0, 2, 4, 2, 2, 4, 1, 2, 4, 3, 1, 1, 1, 3, 0, 2, 3, 2, 0, 1, 0, 0, 4, 0, 3, 0, 0, 0, 1, 3, 2, 3, 4, 2, 4, 1, 0, 3, 3, 0, 2, 1, 0, 4, 1); + assertEquals(0.15688780624148022, InformationTheory.mi(a,b),1e-13); + } + + /* + * import numpy as np + * from scipy.stats import entropy + * a = np.random.randint(0,5,100) + * #print(printArrayAsJavaDoubles(a)) + * hist = np.histogram(a, bins=5, density=False)[0] + * a_probs = hist / len(a) + * print(f"a entropy={entropy(a_probs, base=2)}") + */ + @Test + void testEntropy() { + List a = Arrays.asList(0, 3, 2, 3, 4, 4, 4, 1, 3, 3, 4, 3, 2, 3, 2, 4, 2, 2, 1, 4, 1, 2, 0, 4, 4, 4, 3, 3, 2, 2, 0, 4, 0, 1, 3, 0, 4, 0, 0, 4, 0, 0, 2, 2, 2, 2, 0, 3, 0, 2, 2, 3, 1, 0, 1, 0, 3, 4, 4, 4, 0, 1, 1, 3, 3, 1, 3, 4, 0, 3, 4, 1, 0, 3, 2, 2, 2, 1, 1, 2, 3, 2, 1, 3, 0, 4, 4, 0, 4, 0, 2, 1, 4, 0, 3, 0, 1, 1, 1, 0); + List b = Arrays.asList(4, 2, 4, 0, 4, 4, 3, 3, 3, 2, 2, 0, 1, 3, 2, 1, 2, 0, 0, 4, 3, 3, 0, 1, 1, 1, 1, 4, 4, 4, 3, 1, 0, 0, 0, 1, 4, 1, 1, 1, 3, 3, 1, 2, 3, 0, 4, 0, 2, 3, 4, 2, 3, 2, 1, 0, 2, 4, 2, 2, 4, 1, 2, 4, 3, 1, 1, 1, 3, 0, 2, 3, 2, 0, 1, 0, 0, 4, 0, 3, 0, 0, 0, 1, 3, 2, 3, 4, 2, 4, 1, 0, 3, 3, 0, 2, 1, 0, 4, 1); + assertEquals(2.3167546539234776, InformationTheory.entropy(a)); + assertEquals(2.316147658077609, InformationTheory.entropy(b)); + } + +} diff --git a/pom.xml b/pom.xml index b234fb660..649ed249c 100644 --- a/pom.xml +++ b/pom.xml @@ -57,7 +57,6 @@ 5.7.1 5.4 - 3.6.1 3.19.4
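As a cross-check on the dependency swap above: the G-test probability previously obtained from commons-math's ChiSquaredDistribution is now derived from the regularised lower incomplete gamma function via the identity ChiSquaredCDF(x; k) = P(k/2, x/2), which is exactly what the new private computeChiSquaredProbability helper does. The sketch below is illustrative rather than part of the patch; it assumes the Gamma and InformationTheory changes above are applied, and the class name ChiSquaredViaGammaSketch is hypothetical. For k = 2 degrees of freedom the closed form 1 - exp(-x/2) gives an independent check, and the new public expectedMI method is exercised on a small pair of label lists.

import java.util.Arrays;
import java.util.List;

import org.tribuo.util.infotheory.Gamma;
import org.tribuo.util.infotheory.InformationTheory;

/**
 * Illustrative sketch (not part of the patch). Shows the chi-squared CDF
 * identity used to replace commons-math, CDF(x; k) = P(k/2, x/2), checked
 * against the k = 2 closed form, plus a call to the new public expectedMI.
 */
public final class ChiSquaredViaGammaSketch {
    public static void main(String[] args) {
        int degreesOfFreedom = 2;
        double g = 3.5; // an example G statistic
        // Same call pattern as the new computeChiSquaredProbability helper above.
        double viaGammaP = Gamma.regularizedGammaP(degreesOfFreedom / 2, g / 2.0, 1e-14, Integer.MAX_VALUE);
        double closedForm = 1.0 - Math.exp(-g / 2.0); // chi-squared CDF with 2 degrees of freedom
        System.out.printf("P(1, %.3f) = %.12f, 1 - exp(-%.3f/2) = %.12f%n", g / 2.0, viaGammaP, g, closedForm);

        // expectedMI (added above) estimates the mutual information expected by
        // chance for the given label counts; as implemented it uses natural logarithms.
        List<Integer> xs = Arrays.asList(0, 0, 1, 1, 2, 2);
        List<Integer> ys = Arrays.asList(0, 1, 1, 2, 2, 0);
        System.out.printf("expected MI = %.6f%n", InformationTheory.expectedMI(xs, ys));
    }
}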