Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed matrix issue in correlation function + error handling #3030

Merged
merged 10 commits into from
Sep 20, 2023
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -233,5 +233,6 @@ BuildTools <[email protected]>
Anik Patel <[email protected]>
Vrushaket Chaudhari <[email protected]>
Praise Nnamonu <[email protected]>
vrushaket <[email protected]>

# Generated by tools/update-authors.js
21 changes: 16 additions & 5 deletions src/function/statistics/corr.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { factory } from '../../utils/factory.js'

const name = 'corr'
const dependencies = ['typed', 'matrix', 'mean', 'sqrt', 'sum', 'add', 'subtract', 'multiply', 'pow', 'divide']

Expand All @@ -13,8 +14,8 @@ export const createCorr = /* #__PURE__ */ factory(name, dependencies, ({ typed,
* Examples:
*
* math.corr([1, 2, 3, 4, 5], [4, 5, 6, 7, 8]) // returns 1
* math.corr([1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]) // returns 0.9569941688503644
* math.corr(math.matrix([[1, 2.2, 3, 4.8, 5], [1, 2, 3, 4, 5]]), math.matrix([[4, 5.3, 6.6, 7, 8], [1, 2, 3, 4, 5]])) // returns DenseMatrix [0.9569941688503644, 1]
* math.corr([1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]) //returns 0.9569941688503644
* math.corr([[1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]],[[1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]]) // returns [1,1]
*
* See also:
*
Expand All @@ -28,8 +29,9 @@ export const createCorr = /* #__PURE__ */ factory(name, dependencies, ({ typed,
'Array, Array': function (A, B) {
return _corr(A, B)
},
'Matrix, Matrix': function (xMatrix, yMatrix) {
return matrix(_corr(xMatrix.toArray(), yMatrix.toArray()))
'Matrix, Matrix': function (A, B) {
const res = _corr(A.toArray(), B.toArray())
return Array.isArray(res) ? matrix(res) : res
}
})
/**
Expand All @@ -40,13 +42,22 @@ export const createCorr = /* #__PURE__ */ factory(name, dependencies, ({ typed,
* @private
*/
function _corr (A, B) {
const correlations = []
vrushaket marked this conversation as resolved.
Show resolved Hide resolved
if (Array.isArray(A[0]) && Array.isArray(B[0])) {
const correlations = []
if (A.length !== B.length) {
throw new SyntaxError('Dimension mismatch. Array A and B must have the same length.')
}
for (let i = 0; i < A.length; i++) {
if (A[i].length !== B[i].length) {
throw new SyntaxError('Dimension mismatch. Array A and B must have the same number of elements.')
}
correlations.push(correlation(A[i], B[i]))
}
return correlations
} else {
if (A.length !== B.length) {
throw new SyntaxError('Dimension mismatch. Array A and B must have the same number of elements.')
}
return correlation(A, B)
}
}
Expand Down
30 changes: 27 additions & 3 deletions test/unit-tests/function/statistics/corr.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,42 @@ const corr = math.corr
const BigNumber = math.BigNumber

describe('correlation', function () {
it('should return the correlation coefficient from an array', function () {
it('should return the correlation coefficient from array', function () {
assert.strictEqual(corr([new BigNumber(1), new BigNumber(2.2), new BigNumber(3), new BigNumber(4.8), new BigNumber(5)], [new BigNumber(4), new BigNumber(5.3), new BigNumber(6.6), new BigNumber(7), new BigNumber(8)]).toNumber(), 0.9569941688503653)
assert.strictEqual(corr([1, 2, 3, 4, 5], [4, 5, 6, 7, 8]), 1)
assert.strictEqual(corr([1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]), 0.9569941688503644)
assert.deepStrictEqual(corr(math.matrix([[1, 2.2, 3, 4.8, 5], [1, 2, 3, 4, 5]]), math.matrix([[4, 5.3, 6.6, 7, 8], [1, 2, 3, 4, 5]]))._data, [0.9569941688503644, 1])
assert.deepStrictEqual(corr([[1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]], [[1, 2.2, 3, 4.8, 5], [4, 5.3, 6.6, 7, 8]]), [1, 1])
})

it('should throw an error if called with invalid number of arguments', function () {
it('should return the correlation coefficient from matrix', function () {
assert.strictEqual((corr(math.matrix([2, 4, 6, 8]), math.matrix([1, 2, 3, 6]))), 0.9561828874675149)
assert.deepStrictEqual(corr(math.matrix([[1, 2.2, 3, 4.8, 5], [1, 2, 3, 4, 5]]), math.matrix([[4, 5.3, 6.6, 7, 8], [1, 2, 3, 4, 5]])).toArray(), [0.9569941688503644, 1])
})

it('should throw an error if called with zero arguments', function () {
assert.throws(function () { corr() })
})

it('should throw an error if called with an empty array', function () {
assert.throws(function () { corr([]) })
})

it('should throw an error if called with different number of arguments', function () {
assert.throws(function () { corr(math.matrix([2, 4, 6, 8]), math.matrix([1, 2, 3])) })
})

it('should throw an error if called with number of arguments do not have same size', function () {
assert.throws(function () { corr(math.matrix([[1, 2.2, 3, 4.8, 5], [1, 2, 3, 4, 5]]), math.matrix([[4, 5.3, 6.6, 7, 8]])) })
})

it('should throw an error if called with different number of arguments', function () {
assert.throws(function () { corr([[1, 2, 3, 4, 5], [4, 5, 6, 7, 8], [9, 10, 11, 12]], [[1, 2, 3, 4, 5], [4, 5, 6, 7, 8]]) })
})

it('should throw an error if called with number of arguments do not have same size', function () {
assert.throws(function () { corr([[1, 2, 3, 4, 5], [4, 5, 6, 7]], [[1, 2, 3, 4, 5], []]) })
})
it('should throw an error if called with number of arguments do not have same size', function () {
assert.throws(function () { corr([1, 2, 3, 4, 5], [1, 2, 3, 4]) })
})
})