Skip to content

Commit

Permalink
feat: add matrix datatypes in more cases (#3235)
Browse files Browse the repository at this point in the history
  • Loading branch information
dvd101x authored Jul 30, 2024
1 parent 4f15753 commit cf24943
Show file tree
Hide file tree
Showing 19 changed files with 59 additions and 59 deletions.
4 changes: 2 additions & 2 deletions src/expression/transform/filter.transform.js
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ export const createFilterTransform = /* #__PURE__ */ factory(name, dependencies,
'Array, function': _filter,

'Matrix, function': function (x, test) {
return x.create(_filter(x.toArray(), test))
return x.create(_filter(x.toArray(), test), x.datatype())
},

'Array, RegExp': filterRegExp,

'Matrix, RegExp': function (x, test) {
return x.create(filterRegExp(x.toArray(), test))
return x.create(filterRegExp(x.toArray(), test), x.datatype())
}
})

Expand Down
2 changes: 1 addition & 1 deletion src/expression/transform/map.transform.js
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ export const createMapTransform = /* #__PURE__ */ factory(name, dependencies, ({
},

'Matrix, function': function (x, callback) {
return x.create(_map(x.valueOf(), callback, x))
return x.create(_map(x.valueOf(), callback, x), x.datatype())
}
})

Expand Down
2 changes: 1 addition & 1 deletion src/function/matrix/apply.js
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ export const createApply = /* #__PURE__ */ factory(name, dependencies, ({ typed,
}

if (isMatrix(mat)) {
return mat.create(_apply(mat.valueOf(), dim, callback))
return mat.create(_apply(mat.valueOf(), dim, callback), mat.datatype())
} else {
return _apply(mat, dim, callback)
}
Expand Down
2 changes: 1 addition & 1 deletion src/function/matrix/fft.js
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ export const createFft = /* #__PURE__ */ factory(name, dependencies, ({
return typed(name, {
Array: _ndFft,
Matrix: function (matrix) {
return matrix.create(_ndFft(matrix.toArray()))
return matrix.create(_ndFft(matrix.valueOf()), matrix.datatype())
}
})

Expand Down
4 changes: 2 additions & 2 deletions src/function/matrix/filter.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ export const createFilter = /* #__PURE__ */ factory(name, dependencies, ({ typed
'Array, function': _filterCallback,

'Matrix, function': function (x, test) {
return x.create(_filterCallback(x.toArray(), test))
return x.create(_filterCallback(x.valueOf(), test), x.datatype())
},

'Array, RegExp': filterRegExp,

'Matrix, RegExp': function (x, test) {
return x.create(filterRegExp(x.toArray(), test))
return x.create(filterRegExp(x.valueOf(), test), x.datatype())
}
})
})
Expand Down
14 changes: 7 additions & 7 deletions src/function/matrix/size.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ export const createSize = /* #__PURE__ */ factory(name, dependencies, ({ typed,
*
* Examples:
*
* math.size(2.3) // returns []
* math.size('hello world') // returns [11]
* math.size(2.3) // returns []
* math.size('hello world') // returns [11]
*
* const A = [[1, 2, 3], [4, 5, 6]]
* math.size(A) // returns [2, 3]
* math.size(math.range(1,6)) // returns [5]
* math.size(A) // returns [2, 3]
* math.size(math.range(1,6).toArray()) // returns [5]
*
* See also:
*
Expand All @@ -31,20 +31,20 @@ export const createSize = /* #__PURE__ */ factory(name, dependencies, ({ typed,
*/
return typed(name, {
Matrix: function (x) {
return x.create(x.size())
return x.create(x.size(), 'number')
},

Array: arraySize,

string: function (x) {
return (config.matrix === 'Array') ? [x.length] : matrix([x.length])
return (config.matrix === 'Array') ? [x.length] : matrix([x.length], 'dense', 'number')
},

'number | Complex | BigNumber | Unit | boolean | null': function (x) {
// scalar
return (config.matrix === 'Array')
? []
: matrix ? matrix([]) : noMatrix()
: matrix ? matrix([], 'dense', 'number') : noMatrix()
}
})
})
6 changes: 3 additions & 3 deletions src/function/matrix/squeeze.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ import { squeeze as arraySqueeze } from '../../utils/array.js'
import { factory } from '../../utils/factory.js'

const name = 'squeeze'
const dependencies = ['typed', 'matrix']
const dependencies = ['typed']

export const createSqueeze = /* #__PURE__ */ factory(name, dependencies, ({ typed, matrix }) => {
export const createSqueeze = /* #__PURE__ */ factory(name, dependencies, ({ typed }) => {
/**
* Squeeze a matrix, remove inner and outer singleton dimensions from a matrix.
*
Expand Down Expand Up @@ -43,7 +43,7 @@ export const createSqueeze = /* #__PURE__ */ factory(name, dependencies, ({ type
Matrix: function (x) {
const res = arraySqueeze(x.toArray())
// FIXME: return the same type of matrix as the input
return Array.isArray(res) ? matrix(res) : res
return Array.isArray(res) ? x.create(res, x.datatype()) : res
},

any: function (x) {
Expand Down
2 changes: 1 addition & 1 deletion src/function/probability/random.js
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ export const createRandom = /* #__PURE__ */ factory(name, dependencies, ({ typed

function _randomMatrix (size, min, max) {
const res = randomMatrix(size.valueOf(), () => _random(min, max))
return isMatrix(size) ? size.create(res) : res
return isMatrix(size) ? size.create(res, 'number') : res
}

function _random (min, max) {
Expand Down
2 changes: 1 addition & 1 deletion src/function/probability/randomInt.js
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ export const createRandomInt = /* #__PURE__ */ factory(name, dependencies, ({ ty

function _randomIntMatrix (size, min, max) {
const res = randomMatrix(size.valueOf(), () => _randomInt(min, max))
return isMatrix(size) ? size.create(res) : res
return isMatrix(size) ? size.create(res, 'number') : res
}

function _randomInt (min, max) {
Expand Down
4 changes: 2 additions & 2 deletions src/function/statistics/cumsum.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ export const createCumSum = /* #__PURE__ */ factory(name, dependencies, ({ typed
// sum([a, b, c, d, ...])
Array: _cumsum,
Matrix: function (matrix) {
return matrix.create(_cumsum(matrix.valueOf()))
return matrix.create(_cumsum(matrix.valueOf(), matrix.datatype()))
},

// sum([a, b, c, d, ...], dim)
'Array, number | BigNumber': _ncumSumDim,
'Matrix, number | BigNumber': function (matrix, dim) {
return matrix.create(_ncumSumDim(matrix.valueOf(), dim))
return matrix.create(_ncumSumDim(matrix.valueOf(), dim), matrix.datatype())
},

// cumsum(a, b, c, d, ...)
Expand Down
2 changes: 1 addition & 1 deletion src/utils/collection.js
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ export function reduce (mat, dim, callback) {
}

if (isMatrix(mat)) {
return mat.create(_reduce(mat.valueOf(), dim, callback))
return mat.create(_reduce(mat.valueOf(), dim, callback), mat.datatype())
} else {
return _reduce(mat, dim, callback)
}
Expand Down
2 changes: 1 addition & 1 deletion test/benchmark/matrix_operations.js
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ const fiedler = [

// sylvester
(function () {
const A = sylvester.Matrix.create(fiedler)
const A = sylvester.Matrix.create(fiedler, sylvester.Matrix.datatype())

suite.add(pad('matrix operations sylvester A+A'), function () { return A.add(A) })
suite.add(pad('matrix operations sylvester A*A'), function () { return A.multiply(A) })
Expand Down
2 changes: 1 addition & 1 deletion test/node-tests/doc.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ function maybeCheckExpectation (name, expected, expectedFrom, got, gotFrom) {
function checkExpectation (want, got) {
if (Array.isArray(want)) {
if (!Array.isArray(got)) {
want = math.matrix(want)
got = want.valueOf()
}
return approxDeepEqual(got, want, 1e-9)
}
Expand Down
10 changes: 5 additions & 5 deletions test/unit-tests/expression/parse.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ describe('parse', function () {
it('should get/set the matrix correctly for 3d matrices', function () {
const scope = {}
assert.deepStrictEqual(parseAndEval('f=[1,2;3,4]', scope), math.matrix([[1, 2], [3, 4]]))
assert.deepStrictEqual(parseAndEval('size(f)', scope), math.matrix([2, 2]))
assert.deepStrictEqual(parseAndEval('size(f)', scope), math.matrix([2, 2], 'dense', 'number'))

parseAndEval('f[:,:,2]=[5,6;7,8]', scope)
assert.deepStrictEqual(scope.f, math.matrix([
Expand All @@ -761,7 +761,7 @@ describe('parse', function () {
]
]))

assert.deepStrictEqual(parseAndEval('size(f)', scope), math.matrix([2, 2, 2]))
assert.deepStrictEqual(parseAndEval('size(f)', scope), math.matrix([2, 2, 2], 'dense', 'number'))
assert.deepStrictEqual(parseAndEval('f[:,:,1]', scope), math.matrix([[[1], [2]], [[3], [4]]]))
assert.deepStrictEqual(parseAndEval('f[:,:,2]', scope), math.matrix([[[5], [6]], [[7], [8]]]))
assert.deepStrictEqual(parseAndEval('f[:,2,:]', scope), math.matrix([[[2, 6]], [[4, 8]]]))
Expand Down Expand Up @@ -808,11 +808,11 @@ describe('parse', function () {
assert.deepStrictEqual(parseAndEval('d=1:3', scope), math.matrix([1, 2, 3]))
assert.deepStrictEqual(parseAndEval('concat(d,d)', scope), math.matrix([1, 2, 3, 1, 2, 3]))
assert.deepStrictEqual(parseAndEval('e=1+d', scope), math.matrix([2, 3, 4]))
assert.deepStrictEqual(parseAndEval('size(e)', scope), math.matrix([3]))
assert.deepStrictEqual(parseAndEval('size(e)', scope), math.matrix([3], 'dense', 'number'))
assert.deepStrictEqual(parseAndEval('concat(e,e)', scope), math.matrix([2, 3, 4, 2, 3, 4]))
assert.deepStrictEqual(parseAndEval('[[],[]]', scope), math.matrix([[], []]))
assert.deepStrictEqual(parseAndEval('[[],[]]', scope).size(), [2, 0])
assert.deepStrictEqual(parseAndEval('size([[],[]])', scope), math.matrix([2, 0]))
assert.deepStrictEqual(parseAndEval('size([[],[]])', scope), math.matrix([2, 0], 'dense', 'number'))
})

it('should disable arrays as range in a matrix index', function () {
Expand Down Expand Up @@ -1831,7 +1831,7 @@ describe('parse', function () {
assert.ok(parseAndEval('[1,2,3;4,5,6]\'') instanceof Matrix)
assert.deepStrictEqual(parseAndEval('[1:5]'), math.matrix([[1, 2, 3, 4, 5]]))
assert.deepStrictEqual(parseAndEval('[1:5]\''), math.matrix([[1], [2], [3], [4], [5]]))
assert.deepStrictEqual(parseAndEval('size([1:5])'), math.matrix([1, 5]))
assert.deepStrictEqual(parseAndEval('size([1:5])'), math.matrix([1, 5], 'dense', 'number'))
assert.deepStrictEqual(parseAndEval('[1,2;3,4]\''), math.matrix([[1, 3], [2, 4]]))
})

Expand Down
8 changes: 4 additions & 4 deletions test/unit-tests/function/matrix/eigs.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,21 @@ describe('eigs', function () {
const id2 = matrix([[1, 0], [0, 1]])
const realSymMatrix = eigs(id2)
assert(realSymMatrix.values instanceof Matrix)
assert.deepStrictEqual(size(realSymMatrix.values), matrix([2]))
assert.deepStrictEqual(size(realSymMatrix.values), matrix([2], 'dense', 'number'))
testEigenvectors(realSymMatrix, vector => {
assert(vector instanceof Matrix)
assert.deepStrictEqual(size(vector), matrix([2]))
assert.deepStrictEqual(size(vector), matrix([2], 'dense', 'number'))
})
// Check we get exact values in this trivial case with lower precision
const rough = eigs(id2, { precision: 1e-6 })
assert.deepStrictEqual(realSymMatrix, rough)

const genericMatrix = eigs(matrix([[0, 1], [-1, 0]]))
assert(genericMatrix.values instanceof Matrix)
assert.deepStrictEqual(size(genericMatrix.values), matrix([2]))
assert.deepStrictEqual(size(genericMatrix.values), matrix([2], 'dense', 'number'))
testEigenvectors(genericMatrix, vector => {
assert(vector instanceof Matrix)
assert.deepStrictEqual(size(vector), matrix([2]))
assert.deepStrictEqual(size(vector), matrix([2], 'dense', 'number'))
})
})

Expand Down
28 changes: 14 additions & 14 deletions test/unit-tests/function/matrix/size.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,27 @@ describe('size', function () {
})

it('should calculate the size of a DenseMatrix', function () {
assert.deepStrictEqual(size(matrix()), matrix([0]))
assert.deepStrictEqual(size(matrix([[1, 2, 3], [4, 5, 6]])), matrix([2, 3]))
assert.deepStrictEqual(size(matrix([[], []])), matrix([2, 0]))
assert.deepStrictEqual(size(matrix()), matrix([0], 'dense', 'number'))
assert.deepStrictEqual(size(matrix([[1, 2, 3], [4, 5, 6]])), matrix([2, 3], 'dense', 'number'))
assert.deepStrictEqual(size(matrix([[], []])), matrix([2, 0], 'dense', 'number'))
})

it('should calculate the size of a SparseMatrix', function () {
assert.deepStrictEqual(size(matrix('sparse')), matrix([0, 0], 'sparse'))
assert.deepStrictEqual(size(matrix([[1, 2, 3], [4, 5, 6]], 'sparse')), matrix([2, 3], 'sparse'))
assert.deepStrictEqual(size(matrix([[], []], 'sparse')), matrix([2, 0], 'sparse'))
assert.deepStrictEqual(size(matrix('sparse')), matrix([0, 0], 'sparse', 'number'))
assert.deepStrictEqual(size(matrix([[1, 2, 3], [4, 5, 6]], 'sparse')), matrix([2, 3], 'sparse', 'number'))
assert.deepStrictEqual(size(matrix([[], []], 'sparse')), matrix([2, 0], 'sparse', 'number'))
})

it('should calculate the size of a range', function () {
assert.deepStrictEqual(size(math.range(2, 6)), matrix([4]))
assert.deepStrictEqual(size(math.range(2, 6)), matrix([4], 'dense', 'number'))
})

it('should calculate the size of a scalar', function () {
assert.deepStrictEqual(size(2), matrix([]))
assert.deepStrictEqual(size(math.bignumber(2)), matrix([]))
assert.deepStrictEqual(size(math.complex(2, 3)), matrix([]))
assert.deepStrictEqual(size(true), matrix([]))
assert.deepStrictEqual(size(null), matrix([]))
assert.deepStrictEqual(size(2), matrix([], 'dense', 'number'))
assert.deepStrictEqual(size(math.bignumber(2)), matrix([], 'dense', 'number'))
assert.deepStrictEqual(size(math.complex(2, 3)), matrix([], 'dense', 'number'))
assert.deepStrictEqual(size(true), matrix([], 'dense', 'number'))
assert.deepStrictEqual(size(null), matrix([], 'dense', 'number'))
})

it('should calculate the size of a scalar with setting matrix=="array"', function () {
Expand All @@ -53,8 +53,8 @@ describe('size', function () {
})

it('should calculate the size of a string', function () {
assert.deepStrictEqual(size('hello'), matrix([5]))
assert.deepStrictEqual(size(''), matrix([0]))
assert.deepStrictEqual(size('hello'), matrix([5], 'dense', 'number'))
assert.deepStrictEqual(size(''), matrix([0], 'dense', 'number'))
})

it('should throw an error if called with an invalid number of arguments', function () {
Expand Down
10 changes: 5 additions & 5 deletions test/unit-tests/function/matrix/squeeze.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ const matrix = math.matrix
describe('squeeze', function () {
it('should squeeze an matrix', function () {
let m = math.ones(matrix([1, 3, 2]))
assert.deepStrictEqual(size(m), matrix([1, 3, 2]))
assert.deepStrictEqual(size(m), matrix([1, 3, 2], 'dense', 'number'))
assert.deepStrictEqual(size(m.valueOf()), [1, 3, 2])
assert.deepStrictEqual(size(squeeze(m)), matrix([3, 2]))
assert.deepStrictEqual(size(squeeze(m)), matrix([3, 2], 'dense', 'number'))

m = math.ones(matrix([1, 1, 3]))
assert.deepStrictEqual(size(m), matrix([1, 1, 3]))
assert.deepStrictEqual(size(squeeze(m)), matrix([3]))
assert.deepStrictEqual(size(squeeze(math.range(1, 6))), matrix([5]))
assert.deepStrictEqual(size(m), matrix([1, 1, 3], 'dense', 'number'))
assert.deepStrictEqual(size(squeeze(m)), matrix([3], 'dense', 'number'))
assert.deepStrictEqual(size(squeeze(math.range(1, 6))), matrix([5], 'dense', 'number'))

assert.deepStrictEqual(squeeze(2.3), 2.3)
assert.deepStrictEqual(squeeze(matrix([[5]])), 5)
Expand Down
2 changes: 1 addition & 1 deletion test/unit-tests/type/chain/Chain.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ describe('Chain', function () {
})

it('should not break with null or true as value', function () {
assert.deepStrictEqual(new Chain(null).size().done(), math.matrix([]))
assert.deepStrictEqual(new Chain(null).size().done(), math.matrix([], 'dense', 'number'))
assert.strictEqual(new Chain(true).add(1).done(), 2)
})

Expand Down
12 changes: 6 additions & 6 deletions test/unit-tests/type/matrix/function/matrix.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@ describe('matrix', function () {
it('should create an empty matrix with one dimension if called without argument', function () {
const a = matrix()
assert.ok(a instanceof math.Matrix)
assert.deepStrictEqual(math.size(a), matrix([0])) // TODO: wouldn't it be nicer if an empty matrix has zero dimensions?
assert.deepStrictEqual(math.size(a), matrix([0], 'dense', 'number')) // TODO: wouldn't it be nicer if an empty matrix has zero dimensions?
})

it('should create empty matrix, dense format', function () {
const a = matrix('dense')
assert.ok(a instanceof math.Matrix)
assert.deepStrictEqual(math.size(a), matrix([0]))
assert.deepStrictEqual(math.size(a), matrix([0], 'dense', 'number'))
})

it('should create empty matrix, dense format, number datatype', function () {
const a = matrix('dense', 'number')
assert.ok(a instanceof math.Matrix)
assert.deepStrictEqual(math.size(a), matrix([0]))
assert.deepStrictEqual(math.size(a), matrix([0], 'dense', 'number'))
assert(a.datatype(), 'number')
})

Expand All @@ -33,15 +33,15 @@ describe('matrix', function () {
const b = matrix([[1, 2], [3, 4]])
assert.ok(b instanceof math.Matrix)
assert.deepStrictEqual(b, matrix([[1, 2], [3, 4]]))
assert.deepStrictEqual(math.size(b), matrix([2, 2]))
assert.deepStrictEqual(math.size(b), matrix([2, 2], 'dense', 'number'))
})

it('should be the identity if called with a matrix, dense format', function () {
const b = matrix([[1, 2], [3, 4]], 'dense')
const c = matrix(b, 'dense')
assert.ok(c._data !== b._data) // data should be cloned
assert.deepStrictEqual(c, matrix([[1, 2], [3, 4]], 'dense'))
assert.deepStrictEqual(math.size(c), matrix([2, 2], 'dense'))
assert.deepStrictEqual(math.size(c), matrix([2, 2], 'dense', 'number'))
})

it('should be the identity if called with a matrix, dense format, number datatype', function () {
Expand Down Expand Up @@ -73,7 +73,7 @@ describe('matrix', function () {
const d = matrix(math.range(1, 6))
assert.ok(d instanceof math.Matrix)
assert.deepStrictEqual(d, matrix([1, 2, 3, 4, 5]))
assert.deepStrictEqual(math.size(d), matrix([5]))
assert.deepStrictEqual(math.size(d), matrix([5], 'dense', 'number'))
})

it('should throw an error if called with an invalid argument', function () {
Expand Down

0 comments on commit cf24943

Please sign in to comment.