Skip to content

Commit

Permalink
fix: function map not always working with matrices (#3242)
Browse files Browse the repository at this point in the history
* Removed maxArgumentCount in favor of applyCallback

* Making a pure _recurse function

* Added cbrt tests, removed unnecesary changes in functions.

* Fixed main bottleneck

* Restored back function before unintended change

* Fix format

---------

Co-authored-by: Jos de Jong <[email protected]>
  • Loading branch information
dvd101x and josdejong authored Aug 1, 2024
1 parent 4e2eeac commit c8e4bbd
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 72 deletions.
4 changes: 2 additions & 2 deletions src/function/matrix/map.js
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ function _map (array, callback) {
const recurse = function (value, index) {
if (Array.isArray(value)) {
return value.map(function (child, i) {
// we create a copy of the index array and append the new index value
// we create a copy of the index array and append the new index value
return recurse(child, index.concat(i))
})
} else {
// invoke the callback function with the right number of arguments
// invoke the callback function with the right number of arguments
return applyCallback(callback, value, index, array, 'map')
}
}
Expand Down
11 changes: 2 additions & 9 deletions src/type/matrix/DenseMatrix.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { isInteger } from '../../utils/number.js'
import { clone, deepStrictEqual } from '../../utils/object.js'
import { DimensionError } from '../../error/DimensionError.js'
import { factory } from '../../utils/factory.js'
import { maxArgumentCount } from '../../utils/function.js'
import { applyCallback } from '../../utils/applyCallback.js'

const name = 'DenseMatrix'
const dependencies = [
Expand Down Expand Up @@ -550,21 +550,14 @@ export const createDenseMatrixClass = /* #__PURE__ */ factory(name, dependencies
DenseMatrix.prototype.map = function (callback) {
// matrix instance
const me = this
const args = maxArgumentCount(callback)
const recurse = function (value, index) {
if (isArray(value)) {
return value.map(function (child, i) {
return recurse(child, index.concat(i))
})
} else {
// invoke the callback function with the right number of arguments
if (args === 1) {
return callback(value)
} else if (args === 2) {
return callback(value, index)
} else { // 3 or -1
return callback(value, index, me)
}
return applyCallback(callback, value, index, me, 'map')
}
}

Expand Down
13 changes: 5 additions & 8 deletions src/type/matrix/SparseMatrix.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { clone, deepStrictEqual } from '../../utils/object.js'
import { arraySize, getArrayDataType, processSizesWildcard, unsqueeze, validateIndex } from '../../utils/array.js'
import { factory } from '../../utils/factory.js'
import { DimensionError } from '../../error/DimensionError.js'
import { maxArgumentCount } from '../../utils/function.js'
import { applyCallback } from '../../utils/applyCallback.js'

const name = 'SparseMatrix'
const dependencies = [
Expand Down Expand Up @@ -854,12 +854,9 @@ export const createSparseMatrixClass = /* #__PURE__ */ factory(name, dependencie
const rows = this._size[0]
const columns = this._size[1]
// invoke callback
const args = maxArgumentCount(callback)
const invoke = function (v, i, j) {
// invoke callback
if (args === 1) return callback(v)
if (args === 2) return callback(v, [i, j])
return callback(v, [i, j], me)
return applyCallback(callback, v, [i, j], me, 'map')
}
// invoke _map
return _map(this, 0, rows - 1, 0, columns - 1, invoke, skipZeros)
Expand Down Expand Up @@ -890,11 +887,11 @@ export const createSparseMatrixClass = /* #__PURE__ */ factory(name, dependencie
// invoke callback
const invoke = function (v, x, y) {
// invoke callback
v = callback(v, x, y)
const value = callback(v, x, y)
// check value != 0
if (!eq(v, zero)) {
if (!eq(value, zero)) {
// store value
values.push(v)
values.push(value)
// index
index.push(x)
}
Expand Down
14 changes: 0 additions & 14 deletions src/utils/function.js
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,3 @@ export function memoizeCompare (fn, isEqual) {

return memoize
}

/**
* Find the maximum number of arguments expected by a typed function.
* @param {function} fn A typed function
* @return {number} Returns the maximum number of expected arguments.
* Returns -1 when no signatures where found on the function.
*/
export function maxArgumentCount (fn) {
return Object.keys(fn.signatures || {})
.reduce(function (args, signature) {
const count = (signature.match(/,/g) || []).length + 1
return Math.max(args, count)
}, -1)
}
6 changes: 5 additions & 1 deletion test/unit-tests/function/matrix/map.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,15 @@ describe('map', function () {
it('should invoke a typed function with correct number of arguments (4)', function () {
// cbrt has a syntax cbrt(x, allRoots), but it should invoke cbrt(x) here
assert.deepStrictEqual(math.map([1, 8, 27], math.cbrt), [1, 2, 3])
assert.deepStrictEqual(math.map(math.matrix([1, 8, 27]), math.cbrt), math.matrix([1, 2, 3]))
assert.deepStrictEqual(math.map(math.matrix([1, 8, 27], 'sparse'), math.cbrt), math.matrix([1, 2, 3], 'sparse'))
})

it('should invoke a typed function with correct number of arguments (5)', function () {
// cbrt has a syntax cbrt(x, allRoots), but it should invoke cbrt(x) here
// format has a syntax format(x, options), but it should invoke format(x) here
assert.deepStrictEqual(math.map([1, 8, 27], math.format), ['1', '8', '27'])
assert.deepStrictEqual(math.map(math.matrix([1, 8, 27]), math.format), math.matrix(['1', '8', '27']))
assert.deepStrictEqual(math.map(math.matrix([1, 8, 27], 'sparse'), math.format), math.matrix(['1', '8', '27'], 'sparse'))
})

it('should throw an error if called with unsupported type', function () {
Expand Down
39 changes: 1 addition & 38 deletions test/unit-tests/utils/function.test.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import assert from 'assert'
import { maxArgumentCount, memoize, memoizeCompare } from '../../../src/utils/function.js'
import { memoize, memoizeCompare } from '../../../src/utils/function.js'
import { deepStrictEqual } from '../../../src/utils/object.js'

describe('util.function', function () {
Expand Down Expand Up @@ -108,41 +108,4 @@ describe('util.function', function () {
assert.strictEqual(execCount, 5)
})
})

describe('maxArgumentCount', function () {
it('should calculate the max argument count of a typed function', function () {
const a = function () {}
a.signatures = {
'number, number': function () {},
number: function () {}
}
assert.strictEqual(maxArgumentCount(a), 2)

const b = function () {}
b.signatures = {
number: function () {},
'number, number': function () {}
}
assert.strictEqual(maxArgumentCount(b), 2)

const c = function () {}
c.signatures = {
number: function () {},
BigNumber: function () {}
}
assert.strictEqual(maxArgumentCount(c), 1)

const d = function () {}
d.signatures = {
'number,number': function () {},
number: function () {},
'number,any,number': function () {}
}
assert.strictEqual(maxArgumentCount(d), 3)
})

it('should return -1 for regular functions', function () {
assert.strictEqual(maxArgumentCount(function () {}), -1)
})
})
})

0 comments on commit c8e4bbd

Please sign in to comment.