diff --git a/AUTHORS b/AUTHORS index 6c5e9274d7..5ee87d1522 100644 --- a/AUTHORS +++ b/AUTHORS @@ -233,5 +233,6 @@ BuildTools Anik Patel <74193405+Bobingstern@users.noreply.github.com> Vrushaket Chaudhari <82214275+vrushaket@users.noreply.github.com> Praise Nnamonu <110940850+praisennamonu1@users.noreply.github.com> +vrushaket # Generated by tools/update-authors.js diff --git a/src/function/statistics/corr.js b/src/function/statistics/corr.js index efaa905d69..c16291a935 100644 --- a/src/function/statistics/corr.js +++ b/src/function/statistics/corr.js @@ -1,4 +1,5 @@ import { factory } from '../../utils/factory.js' + const name = 'corr' const dependencies = ['typed', 'matrix', 'mean', 'sqrt', 'sum', 'add', 'subtract', 'multiply', 'pow', 'divide'] @@ -13,8 +14,8 @@ export const createCorr = /* #__PURE__ */ factory(name, dependencies, ({ typed, * Examples: * * math.corr([1, 2, 3, 4, 5], [4, 5, 6, 7, 8]) // returns 1 - * math.corr([1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]) // returns 0.9569941688503644 - * math.corr(math.matrix([[1, 2.2, 3, 4.8, 5], [1, 2, 3, 4, 5]]), math.matrix([[4, 5.3, 6.6, 7, 8], [1, 2, 3, 4, 5]])) // returns DenseMatrix [0.9569941688503644, 1] + * math.corr([1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]) //returns 0.9569941688503644 + * math.corr([[1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]],[[1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]]) // returns [1,1] * * See also: * @@ -28,8 +29,9 @@ export const createCorr = /* #__PURE__ */ factory(name, dependencies, ({ typed, 'Array, Array': function (A, B) { return _corr(A, B) }, - 'Matrix, Matrix': function (xMatrix, yMatrix) { - return matrix(_corr(xMatrix.toArray(), yMatrix.toArray())) + 'Matrix, Matrix': function (A, B) { + const res = _corr(A.toArray(), B.toArray()) + return Array.isArray(res) ? matrix(res) : res } }) /** @@ -40,13 +42,22 @@ export const createCorr = /* #__PURE__ */ factory(name, dependencies, ({ typed, * @private */ function _corr (A, B) { + const correlations = [] if (Array.isArray(A[0]) && Array.isArray(B[0])) { - const correlations = [] + if (A.length !== B.length) { + throw new SyntaxError('Dimension mismatch. Array A and B must have the same length.') + } for (let i = 0; i < A.length; i++) { + if (A[i].length !== B[i].length) { + throw new SyntaxError('Dimension mismatch. Array A and B must have the same number of elements.') + } correlations.push(correlation(A[i], B[i])) } return correlations } else { + if (A.length !== B.length) { + throw new SyntaxError('Dimension mismatch. Array A and B must have the same number of elements.') + } return correlation(A, B) } } diff --git a/test/unit-tests/function/statistics/corr.test.js b/test/unit-tests/function/statistics/corr.test.js index 38071158c2..0ca02691db 100644 --- a/test/unit-tests/function/statistics/corr.test.js +++ b/test/unit-tests/function/statistics/corr.test.js @@ -4,18 +4,42 @@ const corr = math.corr const BigNumber = math.BigNumber describe('correlation', function () { - it('should return the correlation coefficient from an array', function () { + it('should return the correlation coefficient from array', function () { assert.strictEqual(corr([new BigNumber(1), new BigNumber(2.2), new BigNumber(3), new BigNumber(4.8), new BigNumber(5)], [new BigNumber(4), new BigNumber(5.3), new BigNumber(6.6), new BigNumber(7), new BigNumber(8)]).toNumber(), 0.9569941688503653) assert.strictEqual(corr([1, 2, 3, 4, 5], [4, 5, 6, 7, 8]), 1) assert.strictEqual(corr([1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]), 0.9569941688503644) - assert.deepStrictEqual(corr(math.matrix([[1, 2.2, 3, 4.8, 5], [1, 2, 3, 4, 5]]), math.matrix([[4, 5.3, 6.6, 7, 8], [1, 2, 3, 4, 5]]))._data, [0.9569941688503644, 1]) + assert.deepStrictEqual(corr([[1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]], [[1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]]), [1, 1]) }) - it('should throw an error if called with invalid number of arguments', function () { + it('should return the correlation coefficient from matrix', function () { + assert.strictEqual((corr(math.matrix([2, 4, 6, 8]), math.matrix([1, 2, 3, 6]))), 0.9561828874675149) + assert.deepStrictEqual(corr(math.matrix([[1, 2.2, 3, 4.8, 5], [1, 2, 3, 4, 5]]), math.matrix([[4, 5.3, 6.6, 7, 8], [1, 2, 3, 4, 5]])).toArray(), [0.9569941688503644, 1]) + }) + + it('should throw an error if called with zero arguments', function () { assert.throws(function () { corr() }) }) it('should throw an error if called with an empty array', function () { assert.throws(function () { corr([]) }) }) + + it('should throw an error if called with different number of arguments', function () { + assert.throws(function () { corr(math.matrix([2, 4, 6, 8]), math.matrix([1, 2, 3])) }) + }) + + it('should throw an error if called with number of arguments do not have same size', function () { + assert.throws(function () { corr(math.matrix([[1, 2.2, 3, 4.8, 5], [1, 2, 3, 4, 5]]), math.matrix([[4, 5.3, 6.6, 7, 8]])) }) + }) + + it('should throw an error if called with different number of arguments', function () { + assert.throws(function () { corr([[1, 2, 3, 4, 5], [4, 5, 6, 7, 8], [9, 10, 11, 12]], [[1, 2, 3, 4, 5], [4, 5, 6, 7, 8]]) }) + }) + + it('should throw an error if called with number of arguments do not have same size', function () { + assert.throws(function () { corr([[1, 2, 3, 4, 5], [4, 5, 6, 7]], [[1, 2, 3, 4, 5], []]) }) + }) + it('should throw an error if called with number of arguments do not have same size', function () { + assert.throws(function () { corr([1, 2, 3, 4, 5], [1, 2, 3, 4]) }) + }) })