diff --git a/.eslintrc.js b/.eslintrc.js new file mode 100644 index 0000000..24f3b46 --- /dev/null +++ b/.eslintrc.js @@ -0,0 +1,23 @@ +module.exports = { + root: true, + ignorePatterns: ['.eslintrc.js'], + env: { 'es6': true, 'browser': true, 'node': true, 'mocha': true }, + parserOptions: { ecmaVersion: 2020, sourceType: 'module'}, + globals: { + 'chai': 'readonly', + 'BigInt': 'readonly', + 'BigInt64Array': 'readonly', + }, + rules: { + 'semi': 'error', + 'no-multi-spaces': ['error', { 'exceptions': { 'ArrayExpression': true } }], + 'indent': 2, + 'require-jsdoc': 'off', + 'max-len': ['error', {'code': 100}], + 'prefer-rest-params': 'off' + }, + extends: [ + 'eslint:recommended', + 'google', + ], +} diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..5737055 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,10 @@ +# Set update schedule for GitHub Actions + +version: 2 +updates: + + - package-ecosystem: "github-actions" + directory: "/" + schedule: + # Check for updates to GitHub Actions every weekday + interval: "daily" diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml new file mode 100644 index 0000000..576fb45 --- /dev/null +++ b/.github/workflows/build_test.yml @@ -0,0 +1,36 @@ +name: build and test + +on: [push, pull_request] + +jobs: + + job: + + strategy: + matrix: + platform: [ubuntu-latest, macos-latest, windows-latest] + + runs-on: ${{ matrix.platform }} + + steps: + - name: Git config + run: | + git config --global core.autocrlf false + git config --global core.eol lf + + - name: Checkout repository and submodules + uses: actions/checkout@v2.4.0 + with: + submodules: recursive + + - uses: actions/setup-node@v2.5.1 + with: + node-version: '14.x' + + - run: npm install + + - run: npm run lint + + - run: npm test + env: + CI: true diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fc2beb3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +node_modules/ +.DS_Store +package-lock.json +debug.log \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index d5c926b..93286fa 100644 --- a/README.md +++ b/README.md @@ -1 +1,20 @@ -# webnn-baseline +# WebNN Baseline + +The double-precision baseline implementation of WebNN operations for testing purpose. + +### Install + +```sh +> npm install +``` + +### Test +```sh +> npm test +``` + +## Lint + +```sh +> npm run lint +``` \ No newline at end of file diff --git a/node_setup.js b/node_setup.js new file mode 100644 index 0000000..dbbb663 --- /dev/null +++ b/node_setup.js @@ -0,0 +1,2 @@ +global.chai = require('chai'); +global.fs = require('fs'); diff --git a/package.json b/package.json new file mode 100644 index 0000000..d526960 --- /dev/null +++ b/package.json @@ -0,0 +1,27 @@ +{ + "name": "webnn-baseline", + "version": "0.1.0", + "description": "WebNN API double-precision baseline implementation for testing", + "directories": { + "src": "src", + "test": "test" + }, + "scripts": { + "lint": "eslint . --config .eslintrc.js --ext .js", + "test": "cross-env NODE_ENV=test mocha --require ./node_setup.js --exit test/*.js" + }, + "authors": [ + "Ningxin Hu " + ], + "license": "Apache-2.0", + "devDependencies": { + "chai": "^4.2.0", + "cross-env": "^7.0.2", + "eslint": "^7.18.0", + "eslint-config-google": "^0.14.0", + "eslint-plugin-import": "^2.22.0", + "eslint-plugin-jsdoc": "^30.3.1", + "eslint-plugin-prefer-arrow": "^1.2.2", + "mocha": "^9.0.3" + } +} diff --git a/src/batch_normalization.js b/src/batch_normalization.js new file mode 100644 index 0000000..2e2e206 --- /dev/null +++ b/src/batch_normalization.js @@ -0,0 +1,35 @@ +'use strict'; + +import {add, sub, mul, div, pow} from './binary.js'; +import {reshape} from './reshape.js'; +import {Tensor, Scalar} from './lib/tensor.js'; +import {validateBatchNormalizationParams} from './lib/validate-input.js'; + +/** + * Normalize the tensor values of input features across the batch dimension using + * [Batch-Normalization](http://arxiv.org/abs/1502.03167). + * @param {Tensor} input + * @param {Tensor} mean + * @param {Tensor} variance + * @param {MLBatchNormalizationOptions} [options] + * @return {Tensor} + */ +export function batchNormalization(input, mean, variance, {axis=1, scale, bias, epsilon=1e-5, + activation = (x) => x} = {}) { + validateBatchNormalizationParams(...arguments); + // The output tensor has the same shape as the input tensor. + let output = new Tensor(input.shape); + const shape = new Array(input.rank).fill(1); + shape[axis] = -1; + output = sub(input, reshape(mean, shape)); + output = div(output, + pow(add(reshape(variance, shape), new Scalar(epsilon)), new Scalar(0.5))); + if (scale) { + output = mul(output, reshape(scale, shape)); + } + if (bias) { + output = add(output, reshape(bias, shape)); + } + output = activation(output); + return output; +} diff --git a/src/binary.js b/src/binary.js new file mode 100644 index 0000000..8f4db79 --- /dev/null +++ b/src/binary.js @@ -0,0 +1,34 @@ +'use strict'; + +import {broadcast, getBroadcastShape} from './lib/broadcast.js'; +import {Tensor, sizeOfShape} from './lib/tensor.js'; + +/** + * Compute the element-wise binary operation of two input tensors. + * @param {Tensor} inputA + * @param {Tensor} inputB + * @param {Function} binaryFunc + * @return {Tensor} + */ +function binary(inputA, inputB, binaryFunc) { + const outputShape = getBroadcastShape(inputA.shape, inputB.shape); + const inputABroadcast = broadcast(inputA, outputShape); + const inputBBroadcast = broadcast(inputB, outputShape); + const outputSize = sizeOfShape(outputShape); + const output = new Tensor(outputShape); + for (let i = 0; i < outputSize; ++i) { + const a = inputABroadcast.getValueByIndex(i); + const b = inputBBroadcast.getValueByIndex(i); + const c = binaryFunc(a, b); + output.setValueByIndex(i, c); + } + return output; +} + +export const add = (inputA, inputB) => binary(inputA, inputB, (a, b) => a + b); +export const sub = (inputA, inputB) => binary(inputA, inputB, (a, b) => a - b); +export const mul = (inputA, inputB) => binary(inputA, inputB, (a, b) => a * b); +export const div = (inputA, inputB) => binary(inputA, inputB, (a, b) => a / b); +export const max = (inputA, inputB) => binary(inputA, inputB, (a, b) => Math.max(a, b)); +export const min = (inputA, inputB) => binary(inputA, inputB, (a, b) => Math.min(a, b)); +export const pow = (inputA, inputB) => binary(inputA, inputB, (a, b) => Math.pow(a, b)); diff --git a/src/clamp.js b/src/clamp.js new file mode 100644 index 0000000..5866b75 --- /dev/null +++ b/src/clamp.js @@ -0,0 +1,19 @@ +'use strict'; + +import {Tensor} from './lib/tensor.js'; + +/** + * Clamp the input tensor element-wise within a range specified by the minimum and maximum values. + * @param {Tensor} input + * @param {MLClampOptions} [options] + * @return {Tensor} + */ +export function clamp(input, {minValue=-Infinity, maxValue=Infinity} = {}) { + const output = new Tensor(input.shape); + for (let i = 0; i < input.size; ++i) { + const x = input.getValueByIndex(i); + const y = Math.min(Math.max(x, minValue), maxValue); + output.setValueByIndex(i, y); + } + return output; +} diff --git a/src/concat.js b/src/concat.js new file mode 100644 index 0000000..b322430 --- /dev/null +++ b/src/concat.js @@ -0,0 +1,36 @@ +'use strict'; + +import {Tensor, sizeOfShape} from './lib/tensor.js'; +import {validateConcatParams} from './lib/validate-input.js'; + +/** + * Concatenates the input tensors along a given axis. + * @param {Array.} inputs + * @param {Number} axis + * @return {Tensor} + */ +export function concat(inputs, axis) { + validateConcatParams(...arguments); + const inputShape = inputs[0].shape; + const outputShape = inputShape.slice(); + for (let i = 1; i < inputs.length; ++i) { + outputShape[axis] += inputs[i].shape[axis]; + } + const output = new Tensor(outputShape); + for (let i = 0; i < sizeOfShape(outputShape); ++i) { + const location = output.locationFromIndex(i); + let dim = location[axis]; + let k = 0; + // Find out input k and its dim of axis according to output dim of axis + for (; k < inputs.length; ++k) { + if (dim < inputs[k].shape[axis]) { + break; + } + dim -= inputs[k].shape[axis]; + } + location[axis] = dim; + const inputValue = inputs[k].getValueByLocation(location); + output.setValueByIndex(i, inputValue); + } + return output; +} diff --git a/src/conv2d.js b/src/conv2d.js new file mode 100644 index 0000000..25cc0b2 --- /dev/null +++ b/src/conv2d.js @@ -0,0 +1,140 @@ +'use strict'; + +import {Tensor} from './lib/tensor.js'; +import {validateConv2dParams} from './lib/validate-input.js'; +import {computePaddingForAutoPad} from './lib/compute-padding.js'; +import {transpose} from './transpose.js'; + +/** + * Compute a 2-D convolution given 4-D input and filter tensors. + * @param {Tensor} input + * @param {Tensor} filter + * @param {MLConv2dOptions} options + * @return {Tensor} + */ +export function conv2d(input, filter, {padding = [0, 0, 0, 0], + strides = [1, 1], + groups = 1, + dilations = [1, 1], + activation = (x) => x, + inputLayout = 'nchw', + filterLayout = 'oihw', + bias, + autoPad = 'explicit', +} += {}) { + if (inputLayout === 'nhwc') { + // nhwc -> nchw + input = transpose(input, {permutation: [0, 3, 1, 2]}); + } + if (filterLayout === 'hwio') { + // hwio -> oihw + filter = transpose(filter, {permutation: [3, 2, 0, 1]}); + } else if (filterLayout === 'ohwi') { + // ohwi -> oihw + filter = transpose(filter, {permutation: [0, 3, 1, 2]}); + } else if (filterLayout === 'ihwo') { + // ihwo -> oihw + filter = transpose(filter, {permutation: [3, 0, 1, 2]}); + } + validateConv2dParams(input, filter, {groups, bias}); + + const [batchCount, inputChannels, inputHeight, inputWidth] = input.shape; + const [outputChannels, , filterHeight, filterWidth] = filter.shape; + const [strideHeight, strideWidth] = strides; + const [dilationHeight, dilationWidth] = dilations; + const effectiveFilterHeight = filterHeight + (filterHeight - 1) * (dilationHeight - 1); + const effectiveFilterWidth = filterWidth + (filterWidth - 1) * (dilationWidth - 1); + + let beginningPaddingHeight; + let endingPaddingHeight; + let beginningPaddingWidth; + let endingPaddingWidth; + if (autoPad === 'explicit') { + [beginningPaddingHeight, endingPaddingHeight, beginningPaddingWidth, endingPaddingWidth] = + padding; + } else { + [beginningPaddingHeight, endingPaddingHeight] = computePaddingForAutoPad( + autoPad, inputHeight, effectiveFilterHeight, strideHeight); + [beginningPaddingWidth, endingPaddingWidth] = computePaddingForAutoPad( + autoPad, inputWidth, effectiveFilterWidth, strideWidth); + } + + const outputShape = new Array(4); + outputShape[0] = batchCount; + outputShape[1] = outputChannels; + const outputHeight = + 1 + (inputHeight - effectiveFilterHeight + beginningPaddingHeight + endingPaddingHeight) / + strideHeight; + outputShape[2] = outputHeight; + const outputWidth = + 1 + (inputWidth - effectiveFilterWidth + beginningPaddingWidth + endingPaddingWidth) / + strideWidth; + outputShape[3] = outputWidth; + let output = new Tensor(outputShape); + + const outputChannelsPerGroup = outputChannels / groups; + const inputChannelsPerGroup = inputChannels / groups; + + for (let ib = 0; ib < batchCount; ++ib) { + for (let g = 0; g < groups; ++g) { + for (let oc = 0; oc < outputChannelsPerGroup; ++oc) { + for (let ic = 0; ic < inputChannelsPerGroup; ++ic) { + for (let ih = -beginningPaddingHeight, oh = 0; oh < outputHeight; + ih += strideHeight, ++oh) { + for (let iw = -beginningPaddingWidth, ow = 0; ow < outputWidth; + iw += strideWidth, ++ow) { + const effectiveOutputChannel = oc + g * outputChannelsPerGroup; + const outputLocation = [ib, effectiveOutputChannel, oh, ow]; + for (let kh = 0; kh < filterHeight; ++kh) { + for (let kw = 0; kw < filterWidth; ++kw) { + const dkh = kh * dilationHeight; + const dkw = kw * dilationWidth; + if (ih + dkh < 0 || ih + dkh >= inputHeight || + iw + dkw < 0 || iw + dkw >= inputWidth) { + // Skip the padding values. + continue; + } else { + const effectiveInputChannel = ic + g * inputChannelsPerGroup; + const inputValue = input.getValueByLocation( + [ib, effectiveInputChannel, ih + dkh, iw + dkw]); + const filterValue = filter.getValueByLocation( + [effectiveOutputChannel, ic, kh, kw]); + let outputValue = output.getValueByLocation(outputLocation); + outputValue += inputValue * filterValue; + output.setValueByLocation(outputLocation, outputValue); + } + } + } + } + } + } + } + } + } + + if (bias) { + for (let ib = 0; ib < batchCount; ++ib) { + for (let oc = 0; oc < outputChannels; ++oc) { + for (let oh = 0; oh < outputHeight; ++oh) { + for (let ow = 0; ow < outputWidth; ++ow) { + const outputLocation = [ib, oc, oh, ow]; + const biasValue = bias.getValueByLocation([oc]); + let outputValue = output.getValueByLocation(outputLocation); + outputValue += biasValue; + output.setValueByLocation(outputLocation, outputValue); + } + } + } + } + } + + output = activation(output); + + if (inputLayout === 'nhwc') { + // nchw -> nhwc + output = transpose(output, {permutation: [0, 2, 3, 1]}); + } + + return output; +} diff --git a/src/gemm.js b/src/gemm.js new file mode 100644 index 0000000..7e721a2 --- /dev/null +++ b/src/gemm.js @@ -0,0 +1,40 @@ +'use strict'; + +import {add, mul} from './binary.js'; +import {matmul} from './matmul.js'; +import {Scalar} from './lib/tensor.js'; +import {validateGemmParams} from './lib/validate-input.js'; +import {transpose} from './transpose.js'; + +/** + * Calculate the general matrix multiplication of the Basic Linear Algebra Subprograms. + * The calculation follows the expression alpha * A * B + beta * C + * @param {Tensor} a + * @param {Tensor} b + * @param {MLGemmOptions} options + * @return {Tensor} + */ +export function gemm(a, b, {c = new Scalar(0.0), + alpha: fAlpha = 1.0, + beta: fBeta = 1.0, + aTranspose = false, + bTranspose = false, +} = {}) { + validateGemmParams(...arguments); + const alpha = new Scalar(fAlpha); + const beta = new Scalar(fBeta); + if (aTranspose) { + a = transpose(a); + } + + if (bTranspose) { + b = transpose(b); + } + + let output = matmul(mul(a, alpha), b); + if (c) { + output = add(output, mul(c, beta)); + } + + return output; +} diff --git a/src/gru.js b/src/gru.js new file mode 100644 index 0000000..137b407 --- /dev/null +++ b/src/gru.js @@ -0,0 +1,91 @@ +'use strict'; + +import {concat} from './concat.js'; +import {gruCell} from './gru_cell.js'; +import {reshape} from './reshape.js'; +import {sizeOfShape, Tensor} from './lib/tensor.js'; +import {sigmoid} from './sigmoid.js'; +import {slice} from './slice.js'; +import {squeeze} from './squeeze.js'; +import {tanh} from './tanh.js'; +import {validateGruParams} from './lib/validate-input.js'; + +/** + * Gated Recurrent Unit [GRU] recurrent network using an update gate and a reset gate to compute + * the hidden state that rolls into the output across the temporal sequence of the Network + * @param {Tensor} input + * @param {Tensor} weight + * @param {Tensor} recurrentWeight + * @param {Number} steps + * @param {Number} hiddenSize + * @param {MLGruOptions} options + * @return {Array.} + */ +export function gru(input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, initialHiddenState, resetAfter = true, + returnSequence = false, direction = 'forward', + layout = 'zrn', activations = [sigmoid, tanh]} = {}) { + validateGruParams(...arguments); + const numDirections = (direction === 'both' ? 2 : 1); + const batchSize = input.shape[1]; + + let hiddenState; + if (initialHiddenState) { + hiddenState = initialHiddenState; + } else { + const initialHiddenStateShape = [numDirections, batchSize, hiddenSize]; + hiddenState = new Tensor( + initialHiddenStateShape, new Array(sizeOfShape(initialHiddenStateShape)).fill(0)); + } + + let sequence; + const cellWeight = []; + const cellRecurrentWeight = []; + const cellBias = []; + const cellRecurrentBias = []; + + for (let slot = 0; slot < numDirections; ++slot) { + cellWeight.push( + squeeze(slice(weight, [slot, 0, 0], [1, -1, -1]), [0])); + cellRecurrentWeight.push(squeeze( + slice(recurrentWeight, [slot, 0, 0], [1, -1, -1]), [0])); + cellBias.push( + bias ? (squeeze(slice(bias, [slot, 0], [1, -1]), [0])) : + undefined); + cellRecurrentBias.push( + recurrentBias ? (squeeze(slice(recurrentBias, [slot, 0], [1, -1]), [0])) : undefined); + } + + for (let step = 0; step < steps; ++step) { + const cellHidden = []; + let cellOutput; + + for (let slot = 0; slot < numDirections; ++slot) { + cellHidden.push(squeeze(slice(hiddenState, [slot, 0, 0], [1, -1, -1]), [0])); + } + + for (let slot = 0; slot < numDirections; ++slot) { + const sliceStart = (slot === 1 || direction === 'backward' ? steps - step - 1 : step); + const cellInput = squeeze(slice(input, [sliceStart, 0, 0], [1, -1, -1]), [0]); + + const result = reshape( + gruCell( + cellInput, cellWeight[slot], cellRecurrentWeight[slot], + cellHidden[slot], hiddenSize, {bias: cellBias[slot], + recurrentBias: cellRecurrentBias[slot], resetAfter, layout, activations}), + [1, -1, hiddenSize]); + + cellOutput = (cellOutput ? concat([cellOutput, result], 0) : result); + } + + hiddenState = cellOutput; + + if (returnSequence) { + cellOutput = reshape(cellOutput, [1, numDirections, -1, hiddenSize]); + sequence = + (sequence ? concat([sequence, cellOutput], 0) : cellOutput); + } + } + + return [hiddenState, sequence]; +} diff --git a/src/gru_cell.js b/src/gru_cell.js new file mode 100644 index 0000000..d47a57f --- /dev/null +++ b/src/gru_cell.js @@ -0,0 +1,90 @@ +'use strict'; + +import {add, mul, sub} from './binary.js'; +import {matmul} from './matmul.js'; +import {Scalar} from './lib/tensor.js'; +import {sigmoid} from './sigmoid.js'; +import {slice} from './slice.js'; +import {tanh} from './tanh.js'; +import {transpose} from './transpose.js'; +import {validateGruCellParams} from './lib/validate-input.js'; + +/** + * A single time step of the Gated Recurrent Unit [GRU] recurrent network using an update gate + * and a reset gate to compute the hidden state that rolls into the output across the temporal + * sequence of a recurrent network. + * @param {Tensor} input + * @param {Tensor} weight + * @param {Tensor} recurrentWeight + * @param {Tensor} hiddenState + * @param {Number} hiddenSize + * @param {MLGruCellOptions} options + * @return {Tensor} + */ +export function gruCell(input, weight, recurrentWeight, hiddenState, hiddenSize, + {bias, recurrentBias, resetAfter = true, + layout = 'zrn', activations = [sigmoid, tanh]} = {}) { + validateGruCellParams(...arguments); + + const one = new Scalar(1); + const zero = new Scalar(0); + const starts = layout === 'zrn' ? {z: 0, r: hiddenSize, n: 2 * hiddenSize} : + {r: 0, z: hiddenSize, n: 2 * hiddenSize}; + const activation0 = activations[0]; + const activation1 = activations[1]; + // update gate + const z = activation0( + add( + add( + (bias ? slice(bias, [starts.z], [hiddenSize]) : zero), + (recurrentBias ? slice(recurrentBias, [starts.z], [hiddenSize]) :zero)), + add( + matmul(input, transpose(slice(weight, [starts.z, 0], [hiddenSize, -1]))), + matmul( + hiddenState, + transpose(slice(recurrentWeight, [starts.z, 0], [hiddenSize, -1])))))); + // reset gate + const r = activation0( + add( + add( + (bias ? slice(bias, [starts.r], [hiddenSize]) : zero), + (recurrentBias ? slice(recurrentBias, [starts.r], [hiddenSize]) : zero)), + add( + matmul(input, transpose(slice(weight, [starts.r, 0], [hiddenSize, -1]))), + matmul( + hiddenState, + transpose(slice(recurrentWeight, [starts.r, 0], [hiddenSize, -1])))))); + // new gate + let n; + if (resetAfter) { + n = activation1( + add( + (bias ? slice(bias, [starts.n], [hiddenSize]) : zero), + add( + matmul(input, transpose(slice(weight, [starts.n, 0], [hiddenSize, -1]))), + mul( + r, + add( + (recurrentBias ? slice(recurrentBias, [starts.n], [hiddenSize]) : zero), + matmul( + hiddenState, + transpose( + slice(recurrentWeight, [starts.n, 0], [hiddenSize, -1])))))))); + } else { + n = activation1( + add( + add( + (bias ? slice(bias, [starts.n], [hiddenSize]) : zero), + (recurrentBias ? slice(recurrentBias, [starts.n], [hiddenSize]) : zero)), + add( + matmul( + input, + transpose(slice(weight, [starts.n, 0], [hiddenSize, -1]))), + matmul( + mul(r, hiddenState), + transpose(slice(recurrentWeight, [starts.n, 0], [hiddenSize, -1])))))); + } + // compute the new hidden state + return add(mul(z, hiddenState), mul(n, sub(one, z))); +} + diff --git a/src/leaky_relu.js b/src/leaky_relu.js new file mode 100644 index 0000000..55d4f20 --- /dev/null +++ b/src/leaky_relu.js @@ -0,0 +1,14 @@ +'use strict'; + +import {unary} from './unary.js'; + +/** + * Calculate the leaky version of rectified linear function on the input tensor element-wise. + * @param {Tensor} input + * @param {MLLeakyReluOptions} [options] + * @return {Tensor} + */ +export function leakyRelu(input, options = {}) { + const alpha = options.alpha !== undefined ? options.alpha : 0.01; + return unary(input, (x) => Math.max(0, x) + alpha * Math.min(0, x)); +} diff --git a/src/lib/broadcast.js b/src/lib/broadcast.js new file mode 100644 index 0000000..80cc62d --- /dev/null +++ b/src/lib/broadcast.js @@ -0,0 +1,72 @@ +import {Tensor} from './tensor.js'; + +/** + * Broadcast a Tensor to a compatible shape NumPy-style. + * @param {Tensor} input + * @param {Array} newShape + * @return {Tensor} + */ +export function broadcast(input, newShape) { + const newRank = newShape.length; + if (newRank < input.rank) { + throw new Error(`The rank of new shape ${newRank} is invalid.`); + } + const broadcastAxes = new Array(input.rank).fill(false); + for (let i = 0; i < input.rank; ++i) { + const newAxis = newRank - i - 1; + const axis = input.rank - i - 1; + if (input.shape[axis] === 1 && newShape[newAxis] !== 1) { + broadcastAxes[axis] = true; + } else if (input.shape[axis] !== newShape[newAxis]) { + throw new Error(`The size of new shape at axis ${newAxis} is invalid.`); + } + } + const output = new Tensor(newShape); + for (let index = 0; index < output.size; ++index) { + const location = output.locationFromIndex(index); + const inputLocation = location.slice(-input.rank); + for (let axis = 0; axis < input.rank; ++axis) { + if (broadcastAxes[axis] === true) { + inputLocation[axis] = 0; + } + } + const inputValue = input.getValueByLocation(inputLocation); + output.setValueByIndex(index, inputValue); + } + return output; +} + +/** + * Get broadcast shape of given two input shapes, throw error if they're incompatible. + * @param {Array} shapeA + * @param {Array} shapeB + * @return {Array} + */ +export function getBroadcastShape(shapeA, shapeB) { + // According to General Broadcasting Rules on + // https://numpy.org/doc/stable/user/basics.broadcasting.html. + const outShape = []; + const lenA = shapeA.length; + const lenB = shapeB.length; + const outlen = Math.max(lenA, lenB); + for (let i = 0; i < outlen; ++i) { + let a = shapeA[lenA - i - 1]; + if (a === undefined) { + a = 1; + } + let b = shapeB[lenB - i - 1]; + if (b === undefined) { + b = 1; + } + if (a === 1) { + outShape.unshift(b); + } else if (b === 1) { + outShape.unshift(a); + } else if (a !== b) { + throw new Error(`Shapes [${shapeA}] and [${shapeB}] are incompatible.`); + } else { + outShape.unshift(a); + } + } + return outShape; +} diff --git a/src/lib/compute-padding.js b/src/lib/compute-padding.js new file mode 100644 index 0000000..feb91ac --- /dev/null +++ b/src/lib/compute-padding.js @@ -0,0 +1,28 @@ +/** + * Compute the beginning and ending pad given input, filter and stride. + * @param {String} autoPad + * @param {Number} inputSize + * @param {Number} effectiveFilterSize + * @param {Number} stride + * @return {Array} [paddingBegin, paddingEnd] + */ +export function computePaddingForAutoPad(autoPad, inputSize, effectiveFilterSize, stride) { + const outSize = Math.ceil(inputSize / stride); + const neededInput = (outSize - 1) * stride + effectiveFilterSize; + const totalPadding = neededInput > inputSize ? neededInput - inputSize : 0; + let paddingBegin; + let paddingEnd; + switch (autoPad) { + case 'same-upper': + paddingBegin = Math.floor(totalPadding / 2); + paddingEnd = Math.floor((totalPadding + 1) / 2); + break; + case 'same-lower': + paddingBegin = Math.floor((totalPadding + 1) / 2); + paddingEnd = Math.floor(totalPadding / 2); + break; + default: + throw new Error('The autoPad is invalid.'); + } + return [paddingBegin, paddingEnd]; +} diff --git a/src/lib/tensor.js b/src/lib/tensor.js new file mode 100644 index 0000000..4ae3c6b --- /dev/null +++ b/src/lib/tensor.js @@ -0,0 +1,141 @@ +'use strict'; + +/** + * Compute the number of elements given a shape. + * @param {Array} shape + * @return {Number} + */ +export function sizeOfShape(shape) { + return shape.reduce( + (accumulator, currentValue) => accumulator * currentValue, 1); +} + +/** + * Tensor: the multidimensional array. + */ +export class Tensor { + /** + * Construct a Tensor object + * @param {Array} shape + * @param {Array} [data] + */ + constructor(shape, data = undefined) { + const size = sizeOfShape(shape); + if (data !== undefined) { + if (size !== data.length) { + throw new Error(`The length of data ${data.length} is invalid, expected ${size}.`); + } + // Copy the data. + this.data = data.slice(); + } else { + this.data = new Array(size).fill(0); + } + // Copy the shape. + this.shape = shape.slice(); + // Calculate the strides. + this.strides = new Array(this.rank); + this.strides[this.rank - 1] = 1; + for (let i = this.rank - 2; i >= 0; --i) { + this.strides[i] = this.strides[i + 1] * this.shape[i + 1]; + } + } + + get rank() { + return this.shape.length; + } + + get size() { + return this.data.length; + } + + /** + * Get index in the flat array given the location. + * @param {Array} location + * @return {Number} + */ + indexFromLocation(location) { + if (location.length !== this.rank) { + throw new Error(`The location length ${location.length} is not equal to rank ${this.rank}.`); + } + let index = 0; + for (let i = 0; i < this.rank; ++i) { + if (location[i] >= this.shape[i]) { + throw new Error(`The location value ${location[i]} at axis ${i} is invalid.`); + } + index += this.strides[i] * location[i]; + } + return index; + } + + /** + * Get location from the index of the flat array. + * @param {Number} index + * @return {Array} + */ + locationFromIndex(index) { + if (index >= this.size) { + throw new Error('The index is invalid.'); + } + const location = new Array(this.rank); + for (let i = 0; i < location.length; ++i) { + location[i] = Math.floor(index / this.strides[i]); + index -= location[i] * this.strides[i]; + } + return location; + } + + /** + * Set value given the location. + * @param {Array} location + * @param {Number} value + */ + setValueByLocation(location, value) { + this.data[this.indexFromLocation(location)] = value; + } + + /** + * Get value given the location. + * @param {Array} location + * @return {Number} + */ + getValueByLocation(location) { + return this.data[this.indexFromLocation(location)]; + } + + /** + * Set value given the index. + * @param {Number} index + * @param {Number} value + */ + setValueByIndex(index, value) { + if (index >= this.size) { + throw new Error('The index is invalid.'); + } + this.data[index] = value; + } + + /** + * Get value given the index. + * @param {Number} index + * @return {Number} + */ + getValueByIndex(index) { + if (index >= this.size) { + throw new Error('The index is invalid.'); + } + return this.data[index]; + } +} + +/** + * Scalar: a helper class to create a Tensor with a single value. + */ +export class Scalar extends Tensor { + /** + * Construct a Tensor with a single value. + * @param {Number} value + */ + constructor(value) { + super([1], [value]); + } +} diff --git a/src/lib/validate-input.js b/src/lib/validate-input.js new file mode 100644 index 0000000..357083c --- /dev/null +++ b/src/lib/validate-input.js @@ -0,0 +1,376 @@ +'use strict'; + +/** + * Check the tensor whether it is a 1-D tensor and its length is equal to `expectedSize`. + * @param {Tensor} a + * @param {Number} expectedSize + * @param {String} name + */ +function check1DTensorWithSize(a, expectedSize, name) { + if (a) { + if (a.rank !== 1) { + throw new Error(`The parameter ${name} is not a 1-D tensor.`); + } else { + if (a.shape[0] !== expectedSize) { + throw new Error(`The length ${a.shape[0]} of the ${name} values is not equal to the ` + + `size ${expectedSize} of the input dimension denoted by options.axis.`); + } + } + } +} + +export function validateBatchNormalizationParams(input, mean, variance, + {axis=1, scale, bias} = {}) { + if (!Number.isInteger(axis)) { + throw new Error(`Invalid axis ${axis}, axis should be an integer.`); + } + const dim = input.shape[axis]; + check1DTensorWithSize(mean, dim, 'mean'); + check1DTensorWithSize(variance, dim, 'variance'); + check1DTensorWithSize(scale, dim, 'scale'); + check1DTensorWithSize(bias, dim, 'bias'); +} + + +export function validateConcatParams(inputs, axis) { + const rank = inputs[0].rank; + if (!Number.isInteger(axis)) { + throw new Error(`Invalid axis ${axis}, axis should be an integer.`); + } else { + if (axis < 0 || axis >= rank) { + throw new Error(`Invalid axis ${axis}, axis should be in the interval [0, ${rank}).`); + } + } + const inputShape = inputs[0].shape; + for (let i = 1; i < inputs.length; ++i) { + if (inputs[i].rank !== rank) { + throw new Error('All input tensors should have the same rank.'); + } else { + const shape = inputs[i].shape; + for (let j = 0; j < inputShape.length; ++j) { + if (j !== axis) { + if (inputShape[j] !== shape[j]) { + throw new Error('All input tensors should have the same shape, ' + + 'except for the size of the dimension to concatenate on.'); + } + } + } + } + } +} + +export function validateConv2dParams(input, filter, {bias, groups = 1}) { + const inputChannels = input.shape[1]; + const outputChannels = filter.shape[0]; + const filterInputChannels = filter.shape[1]; + if (input.rank !== 4) { + throw new Error('The input should be a 4-D tensor.'); + } + if (filter.rank !== 4) { + throw new Error('The filter should be a 4-D tensor.'); + } + if (inputChannels !== filterInputChannels * groups) { + throw new Error('The input channels of filter is invalid.'); + } + if (bias && (bias.rank !== 1 || bias.shape[0] != outputChannels)) { + throw new Error('the bias should be a 1-D tensor with the shape of [output_channels].'); + } +} + +export function validateGemmParams(a, b) { + if (a.rank !== 2) { + throw new Error('The input a is not a 2-D tensor.'); + } + if (b.rank !== 2) { + throw new Error('The input b is not a 2-D tensor.'); + } +} + +export function validateGruCellParams(input, weight, recurrentWeight, hiddenState, hiddenSize, + {bias, recurrentBias, layout = 'zrn'} = {}) { + if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) { + throw new Error(`The hiddenSize ${hiddenSize} is invalid.`); + } + if (input.rank !== 2) { + throw new Error(`The input (rank ${input.rank}) is not a 2-D tensor.`); + } + const batchSize = input.shape[0]; + const inputSize = input.shape[1]; + if (weight.rank !== 2) { + throw new Error(`The weight (rank ${weight.rank}) is not a 2-D tensor.`); + } + if (weight.shape[0] !== 3 * hiddenSize || weight.shape[1] !== inputSize) { + throw new Error(`The shape of weight [${weight.shape[0]}, ${weight.shape[1]}] is invalid.`); + } + if (recurrentWeight.rank !== 2) { + throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 2-D tensor.`); + } + if (recurrentWeight.shape[0] !== 3 * hiddenSize || recurrentWeight.shape[1] !== hiddenSize) { + throw new Error(`The shape of recurrentWeight ` + + `[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}] is invalid.`); + } + if (hiddenState.rank !== 2) { + throw new Error(`The hiddenState (rank ${hiddenState.rank}) is not a 2-D tensor.`); + } + if (hiddenState.shape[0] !== batchSize || hiddenState.shape[1] !== hiddenSize) { + throw new Error(`The shape of hiddenState + [${hiddenState.shape[0]}, ${hiddenState.shape[1]}] is invalid.`); + } + if (bias) { + if (bias.rank !== 1) { + throw new Error(`The bias (rank ${bias.rank}) is not a 1-D tensor.`); + } + if (bias.shape[0] !== 3 * hiddenSize) { + throw new Error(`The shape of bias [${bias.shape[0]}] is invalid.`); + } + } + if (recurrentBias) { + if (recurrentBias.rank !== 1) { + throw new Error(`The recurrentBias (rank ${bias.rank}) is not a 1-D tensor.`); + } + if (recurrentBias.shape[0] !== 3 * hiddenSize) { + throw new Error(`The shape of recurrentBias [${recurrentBias.shape[0]}] is invalid.`); + } + } + if (layout !== 'zrn' && layout !== 'rzn') { + throw new Error(`The layout ${layout} is invalid.`); + } +} + +export function validateGruParams(input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, initialHiddenState, + direction = 'forward', layout = 'zrn'}) { + if (!Number.isInteger(steps) || steps <= 0) { + throw new Error(`The steps ${steps} is invalid.`); + } + if (!Number.isInteger(hiddenSize) || hiddenSize <= 0) { + throw new Error(`The hiddenSize ${hiddenSize} is invalid.`); + } + if (input.rank !== 3) { + throw new Error(`The input (rank ${input.rank}) is not a 3-D tensor.`); + } + if (input.shape[0] !== steps) { + throw new Error(`The input.shape[0] ${input.shape[0]} is not equal to steps ${steps}.`); + } + const batchSize = input.shape[1]; + const inputSize = input.shape[2]; + if (direction !== 'forward' && direction !== 'backward' && direction !== 'both') { + throw new Error(`The direction ${direction} is invalid.`); + } + const numDirections = (direction === 'both' ? 2 : 1); + if (weight.rank !== 3) { + throw new Error(`The weight (rank ${weight.rank}) is not a 3-D tensor.`); + } + if (weight.shape[0] !== numDirections || weight.shape[1] !== 3 * hiddenSize || + weight.shape[2] !== inputSize) { + throw new Error(`The shape of weight [${weight.shape[0]}, ${weight.shape[1]}, + ${weight.shape[2]}] is invalid.`); + } + if (recurrentWeight.rank !== 3) { + throw new Error(`The recurrentWeight (rank ${recurrentWeight.rank}) is not a 3-D tensor.`); + } + if (recurrentWeight.shape[0] !== numDirections || + recurrentWeight.shape[1] !== 3 * hiddenSize || + recurrentWeight.shape[2] !== hiddenSize) { + throw new Error(`The shape of recurrentWeight ` + + `[${recurrentWeight.shape[0]}, ${recurrentWeight.shape[1]}, ` + + `${recurrentWeight.shape[2]}] is invalid.`); + } + if (bias) { + if (bias.rank !== 2) { + throw new Error(`The bias (rank ${bias.rank}) is not a 2-D tensor.`); + } + if (bias.shape[0] !== numDirections || bias.shape[1] !== 3 * hiddenSize) { + throw new Error(`The shape of bias [${bias.shape[0]}, ${bias.shape[1]}] is invalid.`); + } + } + if (recurrentBias) { + if (recurrentBias.rank !== 2) { + throw new Error(`The recurrentBias (rank ${recurrentBias.rank}) is not a 2-D tensor.`); + } + if (recurrentBias.shape[0] !== numDirections || recurrentBias.shape[1] !== 3 * hiddenSize) { + throw new Error(`The shape of recurrentBias [${recurrentBias.shape[0]}, + ${recurrentBias.shape[1]}] is invalid.`); + } + } + if (initialHiddenState) { + if (initialHiddenState.rank !== 3) { + throw new Error( + `The initialHiddenState (rank ${initialHiddenState.rank}) is not a 3-D tensor.`); + } + if (initialHiddenState.shape[0] !== numDirections || + initialHiddenState.shape[1] !== batchSize || + initialHiddenState.shape[2] !== hiddenSize) { + throw new Error(`The shape of initialHiddenState [${initialHiddenState.shape[0]}, + ${initialHiddenState.shape[1]}, ${initialHiddenState.shape[2]}] is invalid.`); + } + } + if (layout !== 'zrn' && layout !== 'rzn') { + throw new Error(`The layout ${layout} is invalid.`); + } +} + +export function validateMatmulParams(a, b) { + const aCols = a.shape[a.rank - 1]; + const bRows = b.shape[b.rank - 2]; + if (aCols !== bRows) { + throw new Error( + `The columns (${aCols}) of input a is not equal to rows (${bRows}) of input b.`); + } +} + +export function validatePool2dParams(input, _, {roundingType = 'floor'}) { + if (input.rank !== 4) { + throw new Error('The input should be a 4-D tensor.'); + } + if (roundingType !== 'floor' && roundingType !== 'ceil') { + throw new Error('The rounding type is invalid.'); + } +} + +export function validateReduceParams(input, _, {axes}) { + if (axes.length > input.rank) { + throw new Error(`The length ${axes.length} of axes is bigger` + + `than input rank ${input.rank}.`); + } + for (let i = 0; i < axes.length; ++i) { + if (axes[i] < 0 || axes[i] >= input.rank) { + throw new Error(`The value ${axes[i]} at axis ${i} of axes is invalid.`); + } + } +} + +export function validateSliceParams(input, starts, sizes, {axes} = {}) { + let inpAxes = axes; + const rank = input.rank; + const startsForAllAxes = new Array(rank).fill(0); + if (axes) { + if (axes.length > rank) { + throw new Error(`The length of axes ${axes.length} is greater than rank ${rank}.`); + } else { + for (const axis of axes) { + if (!Number.isInteger(axis)) { + throw new Error(`Invalid axes value ${axis}, it should be an integer.`); + } else { + if (axis >= rank || axis < -rank) { + throw new Error(`Invalid axes value ${axis}, it should be in the interval ` + + `[${-rank}, ${rank}).`); + } + } + } + } + } else { + inpAxes = [...Array(rank).keys()]; + } + const axesLen = inpAxes.length; + if (starts.length !== axesLen) { + throw new Error(`The length ${starts.length} of starts is not equal to the length ` + + `${axesLen} of axes.`); + } + if (sizes.length !== axesLen) { + throw new Error(`The length ${sizes.length} of sizes is not equal` + + ` to the length ${axesLen} of axes.`); + } + for (let i = 0; i < axesLen; ++i) { + const axis = inpAxes[i] >= 0 ? inpAxes[i] : inpAxes[i] + rank; + const size = input.shape[axis]; + const start = starts[i]; + if (!Number.isInteger(start)) { + throw new Error(`Invalid starts value ${start}, it should be an integer.`); + } + startsForAllAxes[axis] = start >= 0 ? start : start + size; + if (start >= size || start < -size) { + throw new Error(`Invalid starts value ${start}, it shoule be in the interval ` + + `[${-size}, ${size}).`); + } else { + const sliceSize = sizes[i]; + if (!Number.isInteger(sliceSize)) { + throw new Error(`Invalid sizes value ${sliceSize}, it should be an integer.`); + } + if (sliceSize >= 0) { + if (start >= 0) { + if (start + sliceSize > size) { + throw new Error(`Invalid sizes value ${sliceSize}, the sum of the start ${start} ` + + `plus the size ${sliceSize} is greater than the dimensional size ${size}`); + } + } else { + if (start + sliceSize > 0) { + throw new Error(`Invalid sizes value ${sliceSize}, the sum of the start ${start} ` + + `plus the size ${sliceSize} is greater than the dimensional size ${size}`); + } + } + } else { + if (sliceSize !== -1) { + throw new Error(`The value ${sliceSize} of sizes is invalid,` + + ` it is required to be -1 when it is negative.`); + } + } + } + } +} + +export function validateSoftmaxParams(x) { + if (x.rank !== 2) { + throw new Error('The input is not a 2-D tensor.'); + } +} + +export function validateSplitParams(input, splits, {axis = 0} = {}) { + let inpAxis; + if (axis !== undefined) { + const rank = input.rank; + if (!Number.isInteger(axis)) { + throw new Error(`The axis ${axis} should be an integer.`); + } + if (axis >= rank || axis < -rank) { + throw new Error(`The axis ${axis} should be in the interval [${-rank}, ${rank}).`); + } + inpAxis = axis >= 0 ? axis : rank + axis; + } + if (typeof splits === 'number') { + if (!Number.isInteger(splits) || splits <= 0) { + throw new Error(`Invalid splits ${splits}, it should be a positive integer.`); + } + if (input.shape[inpAxis] % splits !== 0) { + throw new Error(`The splits ${splits} must evenly divide the dimension size ` + + `${input.shape[inpAxis]} of input along options.axis ${inpAxis}.`); + } + } else if (splits instanceof Array) { + if (!splits.every((v) => Number.isInteger(v) && v > 0)) { + throw new Error(`Invalid splits ${splits}, it should be an Array of positive integers.`); + } + const sum = splits.reduce((a, b) => a + b); + if (sum !== input.shape[inpAxis]) { + throw new Error(`Invalid [${splits}], the sum of sizes ${sum} must equal ` + + `to the dimension size ${input.shape[inpAxis]} of input` + + ` along options.axis ${inpAxis}`); + } + } +} + +export function validateSqueezeParams(input, {axes} = {}) { + if (axes) { + if (axes.length > input.rank) { + throw new Error(`The length of axes ${axes.length} is bigger ` + + `than input rank ${input.rank}.`); + } + + for (const axis of axes) { + if (axis < 0 || axis >= input.rank) { + throw new Error(`The value of axes ${axis} is invalid.`); + } + if (axes && input.shape[axis] !== 1) { + throw new Error(`The value ${input.shape[axis]} ` + + `at axis ${axis} of input shape is not 1.`); + } + } + } +} + +export function validateTranposeParams(input, {permutation}) { + if (permutation.length !== input.rank) { + throw new Error( + `The permutation length ${permutation.length} is not equal to rank ${input.rank}.`); + } +} + diff --git a/src/matmul.js b/src/matmul.js new file mode 100644 index 0000000..262c84c --- /dev/null +++ b/src/matmul.js @@ -0,0 +1,67 @@ + +'use strict'; + +import {broadcast, getBroadcastShape} from './lib/broadcast.js'; +import {reshape} from './reshape.js'; +import {sizeOfShape, Tensor} from './lib/tensor.js'; +import {validateMatmulParams} from './lib/validate-input.js'; + +/** + * Compute the matrix product of two input tensors. + * @param {Tensor} a + * @param {Tensor} b + * @return {Tensor} + */ +export function matmul(a, b) { + const scalarOutput = a.rank === 1 && b.rank === 1; + if (a.rank === 1) { + a = reshape(a, [1, a.shape[0]]); + } + const aRows = a.shape[a.rank - 2]; + const aCols = a.shape[a.rank - 1]; + + if (b.rank === 1) { + b = reshape(b, [b.shape[0], 1]); + } + const bCols = b.shape[b.rank - 1]; + + validateMatmulParams(a, b); + + let cShape = [aRows, bCols]; + if (a.rank > 2 || b.rank > 2) { + // Broadcast + const aBatchDims = a.shape.slice(0, -2); + const bBatchDims = b.shape.slice(0, -2); + const outputBatchDims = getBroadcastShape(aBatchDims, bBatchDims); + const aShape = outputBatchDims.concat(a.shape.slice(-2)); + a = broadcast(a, aShape); + const bShape = outputBatchDims.concat(b.shape.slice(-2)); + b = broadcast(b, bShape); + cShape = outputBatchDims.concat(cShape); + } + let c = new Tensor(cShape); + + for (let i = 0; i < sizeOfShape(cShape); ++i) { + const cLoc = c.locationFromIndex(i); + const m = cLoc[c.rank - 2]; + const n = cLoc[c.rank - 1]; + let cValue = 0; + for (let k = 0; k < aCols; ++k) { + let aLoc = cLoc.slice(0, -2); + aLoc = aLoc.concat(m, k); + let bLoc = cLoc.slice(0, -2); + bLoc = bLoc.concat(k, n); + const aValue = a.getValueByLocation(aLoc); + const bValue = b.getValueByLocation(bLoc); + cValue += aValue * bValue; + } + c.setValueByLocation(cLoc, cValue); + } + + if (scalarOutput) { + const cValue = c.getValueByIndex(0); + c = new Tensor([], [cValue]); + } + + return c; +} diff --git a/src/package.json b/src/package.json new file mode 100644 index 0000000..aead43d --- /dev/null +++ b/src/package.json @@ -0,0 +1,3 @@ +{ + "type": "module" +} \ No newline at end of file diff --git a/src/pool2d.js b/src/pool2d.js new file mode 100644 index 0000000..85fd174 --- /dev/null +++ b/src/pool2d.js @@ -0,0 +1,127 @@ +'use strict'; + +import {computePaddingForAutoPad} from './lib/compute-padding.js'; +import {Tensor} from './lib/tensor.js'; +import {transpose} from './transpose.js'; +import {meanReducer, maxReducer} from './reduce.js'; +import {validatePool2dParams} from './lib/validate-input.js'; + +/** + * Compute a reduction operation across all the elements within the + * moving window over the input tensor. + * @param {Tensor} input + * @param {Function} reductionFunc + * @param {MLPool2dOptions} options + * @return {Tensor} + */ +function pool2d(input, reductionFunc, + {padding = [0, 0, 0, 0], + strides = [1, 1], + dilations = [1, 1], + roundingType = 'floor', + layout = 'nchw', + windowDimensions, + autoPad = 'explicit', + outputSizes, + }= {}) { + validatePool2dParams(...arguments); + const roundingFunc = roundingType === 'floor' ? Math.floor : Math.ceil; + + if (layout === 'nhwc') { + // nhwc -> nchw + input = transpose(input, {permutation: [0, 3, 1, 2]}); + } + + const [batchCount, channels, inputHeight, inputWidth] = input.shape; + const [windowHeight, windowWidth] = windowDimensions ?? [inputHeight, inputWidth]; + const [strideHeight, strideWidth] = strides; + const [dilationHeight, dilationWidth] = dilations; + const effectiveWindowHeight = windowHeight + (windowHeight - 1) * (dilationHeight - 1); + const effectiveWindowWidth = windowWidth + (windowWidth - 1) * (dilationWidth - 1); + + let beginningPaddingHeight; + let endingPaddingHeight; + let beginningPaddingWidth; + let endingPaddingWidth; + if (autoPad === 'explicit') { + [beginningPaddingHeight, endingPaddingHeight, beginningPaddingWidth, endingPaddingWidth] = + padding; + } else { + [beginningPaddingHeight, endingPaddingHeight] = computePaddingForAutoPad( + autoPad, inputHeight, effectiveWindowHeight, strideHeight); + [beginningPaddingWidth, endingPaddingWidth] = computePaddingForAutoPad( + autoPad, inputWidth, effectiveWindowWidth, strideWidth); + } + + const outputShape = new Array(4); + outputShape[0] = batchCount; + outputShape[1] = channels; + const outputHeight = outputSizes ? outputSizes[0] : + roundingFunc( + 1 + (inputHeight - effectiveWindowHeight + beginningPaddingHeight + endingPaddingHeight) / + strideHeight); + outputShape[2] = outputHeight; + const outputWidth = outputSizes ? outputSizes[1] : + roundingFunc( + 1 + (inputWidth - effectiveWindowWidth + beginningPaddingWidth + endingPaddingWidth) / + strideWidth); + outputShape[3] = outputWidth; + let output = new Tensor(outputShape); + + for (let ib = 0; ib < batchCount; ++ib) { + for (let ic = 0; ic < channels; ++ic) { + for (let ih = -beginningPaddingHeight, oh = 0; oh < outputHeight; ih += strideHeight, ++oh) { + for (let iw = -beginningPaddingWidth, ow = 0; ow < outputWidth; iw += strideWidth, ++ow) { + const outputLocation = [ib, ic, oh, ow]; + const valuesInWindow = []; + for (let kh = 0; kh < windowHeight; ++kh) { + for (let kw = 0; kw < windowWidth; ++kw) { + const dkh = kh * dilationHeight; + const dkw = kw * dilationWidth; + if (ih + dkh < 0 || ih + dkh >= inputHeight || + iw + dkw < 0 || iw + dkw >= inputWidth) { + // Skip the padding values. + continue; + } else { + const inputValue = input.getValueByLocation( + [ib, ic, ih + dkh, iw + dkw]); + valuesInWindow.push(inputValue); + } + } + } + const outputValue = valuesInWindow.reduce(reductionFunc); + output.setValueByLocation(outputLocation, outputValue); + } + } + } + } + + if (layout === 'nhwc') { + // nchw -> nhwc + output = transpose(output, {permutation: [0, 2, 3, 1]}); + } + + return output; +} + +/** + * Compute a mean reduction operation across all the elements within the moving window over + * the input tensor. + * @param {Tensor} input + * @param {MLPool2dOptions} options + * @return {Tensor} + */ +export function averagePool2d(input, options = {}) { + return pool2d(input, meanReducer, options); +} + +/** + * Compute a max reduction operation across all the elements within the moving window over + * the input tensor. + * @param {Tensor} input + * @param {MLPool2dOptions} options + * @return {Tensor} + */ +export function maxPool2d(input, options = {}) { + return pool2d(input, maxReducer, options); +} diff --git a/src/reduce.js b/src/reduce.js new file mode 100644 index 0000000..5f8aa27 --- /dev/null +++ b/src/reduce.js @@ -0,0 +1,125 @@ +'use strict'; + +import {squeeze} from './squeeze.js'; +import {sizeOfShape, Tensor} from './lib/tensor.js'; +import {validateReduceParams} from './lib/validate-input.js'; + +/** + * Reduce the input along the dimensions given in axes. + * @param {Tensor} input + * @param {Function} reduceFunc + * @param {MLReduceOptions} options + * @return {Tensor} + */ +function reduce(input, reduceFunc, {keepDimensions = false, axes} = {}) { + const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i); + + const outputShape = input.shape.slice(); + for (let i = 0; i < inpAxes.length; ++i) { + if (inpAxes[i] === -1) { + inpAxes[i] = input.rank - 1; + } + outputShape[inpAxes[i]] = 1; + } + + validateReduceParams(input, reduceFunc, {keepDimensions, axes: inpAxes}); + + // Calculate the "strides" across the reduction dimensions given in axes. + inpAxes.sort((a, b) => a - b); + const reduceDims = inpAxes.map((axis) => input.shape[axis]); + const reduceElements = sizeOfShape(reduceDims); + const reduceStrides = new Array(inpAxes.length); + reduceStrides[reduceStrides.length - 1] = 1; + for (let i = reduceStrides.length - 2; i >= 0; --i) { + reduceStrides[i] = reduceStrides[i + 1] * reduceDims[i + 1]; + } + + let output = new Tensor(outputShape); + for (let outputIndex = 0; outputIndex < sizeOfShape(outputShape); ++outputIndex) { + const valuesToReduce = []; + // Find all values to reduce. + for (let reduceIndex = 0; reduceIndex < reduceElements; ++reduceIndex) { + // Calculate the input location given index of elements to reduce. + const inputLocation = output.locationFromIndex(outputIndex); + let remainingReduceIndex = reduceIndex; + for (let i = 0; i < inpAxes.length; ++i) { + const axis = inpAxes[i]; + inputLocation[axis] = Math.floor(remainingReduceIndex / reduceStrides[i]); + remainingReduceIndex -= inputLocation[axis] * reduceStrides[i]; + } + valuesToReduce.push(input.getValueByLocation(inputLocation)); + } + const outputValue = valuesToReduce.reduce(reduceFunc); + output.setValueByIndex(outputIndex, outputValue); + } + + if (!keepDimensions) { + output = squeeze(output); + } + return output; +} + +/* The max reducer */ +export const maxReducer = (previousValue, currentValue) => Math.max(previousValue, currentValue); + +/** + * Compute the maximum value of all the input values along the axes. + * @param {Tensor} input + * @param {MLReduceOptions} options + * @return {Tensor} + */ +export function reduceMax(input, options = {}) { + return reduce(input, maxReducer, options); +} + +/* The mean reducer */ +export function meanReducer(previousValue, currentValue, currentIndex, array) { + if (currentIndex === array.length - 1) { + return (previousValue + currentValue) / array.length; + } else { + return previousValue + currentValue; + } +} + +/** + * Compute the average value of all the input values along the axes. + * @param {Tensor} input + * @param {MLReduceOptions} options + * @return {Tensor} + */ +export function reduceMean(input, options = {}) { + return reduce(input, meanReducer, options); +} + +/** + * Compute the minimum value of all the input values along the axes. + * @param {Tensor} input + * @param {MLReduceOptions} options + * @return {Tensor} + */ +export function reduceMin(input, options = {}) { + return reduce(input, + (previousValue, currentValue) => Math.min(previousValue, currentValue), options); +} + +/** + * Compute the product of all the input values along the axes. + * @param {Tensor} input + * @param {MLReduceOptions} options + * @return {Tensor} + */ +export function reduceProduct(input, options = {}) { + return reduce(input, + (previousValue, currentValue) => previousValue * currentValue, options); +} + +/** + * Compute the sum of all the input values along the axes. + * @param {Tensor} input + * @param {MLReduceOptions} options + * @return {Tensor} + */ +export function reduceSum(input, options = {}) { + return reduce(input, + (previousValue, currentValue) => previousValue + currentValue, options); +} diff --git a/src/relu.js b/src/relu.js new file mode 100644 index 0000000..b66cbed --- /dev/null +++ b/src/relu.js @@ -0,0 +1,12 @@ +'use strict'; + +import {unary} from './unary.js'; + +/** + * Compute the rectified linear function of the input tensor. + * @param {Tensor} input + * @return {Tensor} + */ +export function relu(input) { + return unary(input, (x) => Math.max(0, x)); +} diff --git a/src/reshape.js b/src/reshape.js new file mode 100644 index 0000000..e1a428f --- /dev/null +++ b/src/reshape.js @@ -0,0 +1,33 @@ +'use strict'; + +import {Tensor, sizeOfShape} from './lib/tensor.js'; + +/** + * Alter the shape of a tensor to a new shape. + * @param {Tensor} input + * @param {Array} newShape + * @return {Tensor} + */ +export function reshape(input, newShape) { + let minusOneAxis; + let elements = 1; + for (let i = 0; i < newShape.length; ++i) { + if (newShape[i] === -1) { + minusOneAxis = i; + } else if (newShape[i] > 0) { + elements *= newShape[i]; + } else { + throw new Error(`The value ${newShape[i]} at axis ${i} of new shape is invalid.`); + } + } + const outputShape = newShape.slice(); + if (minusOneAxis !== undefined) { + outputShape[minusOneAxis] = Math.round(sizeOfShape(input.shape) / elements); + } + if (sizeOfShape(input.shape) !== sizeOfShape(outputShape)) { + throw new Error(`The element size of new shape ${sizeOfShape(outputShape)} is not equal to + element size of old shape ${sizeOfShape(input.shape)} invalid.`); + } + const output = new Tensor(outputShape, input.data); + return output; +} diff --git a/src/sigmoid.js b/src/sigmoid.js new file mode 100644 index 0000000..8627914 --- /dev/null +++ b/src/sigmoid.js @@ -0,0 +1,13 @@ +'use strict'; + +import {unary} from './unary.js'; + +/** + * Compute the sigmoid function of the input tensor. + * The calculation follows the expression 1 / (exp(-x) + 1). + * @param {Tensor} input + * @return {Tensor} + */ +export function sigmoid(input) { + return unary(input, (x) => 1 / (Math.exp(-x) + 1)); +} diff --git a/src/slice.js b/src/slice.js new file mode 100644 index 0000000..90c3910 --- /dev/null +++ b/src/slice.js @@ -0,0 +1,45 @@ +'use strict'; + +import {Tensor, sizeOfShape} from './lib/tensor.js'; +import {validateSliceParams} from './lib/validate-input.js'; + +/** + * Produce a slice of the input tensor. + * @param {Tensor} input + * @param {Array} starts + * @param {Array} sizes + * @param {MLSliceOptions} options + * @return {Tensor} + */ +export function slice(input, starts, sizes, {axes} = {}) { + validateSliceParams(...arguments); + const rank = input.rank; + const startsForAllAxes = new Array(rank).fill(0); + + axes = axes ?? [...Array(rank).keys()]; + const axesLen = axes.length; + const outputShape = input.shape.slice(); + for (let i = 0; i < axesLen; ++i) { + const axis = axes[i] >= 0 ? axes[i] : axes[i] + rank; + const size = input.shape[axis]; + const start = starts[i]; + startsForAllAxes[axis] = start >= 0 ? start : start + size; + const sliceSize = sizes[i]; + if (sliceSize >= 0) { + outputShape[axis] = sliceSize; + } else { + outputShape[axis] = start >= 0 ? size - start : -start; + } + } + const output = new Tensor(outputShape); + for (let outputIndex = 0; outputIndex < sizeOfShape(outputShape); ++outputIndex) { + const loc = output.locationFromIndex(outputIndex); + const selectedInputLoc = loc.slice(); + for (let i = 0; i < loc.length; ++i) { + selectedInputLoc[i] = loc[i] + startsForAllAxes[i]; + } + const inputValue = input.getValueByLocation(selectedInputLoc); + output.setValueByIndex(outputIndex, inputValue); + } + return output; +} diff --git a/src/softmax.js b/src/softmax.js new file mode 100644 index 0000000..356dddb --- /dev/null +++ b/src/softmax.js @@ -0,0 +1,18 @@ +'use strict'; + +import {div, sub} from './binary.js'; +import {exp} from './unary.js'; +import {reduceMax, reduceSum} from './reduce.js'; +import {validateSoftmaxParams} from './lib/validate-input.js'; + +/** + * Compute the softmax values of the 2-D input tensor along axis 1. + * @param {Tensor} x + * @return {Tensor} + */ +export function softmax(x) { + validateSoftmaxParams(...arguments); + const maxX = reduceMax(x, {axes: [1], keepDimensions: true}); + const expX = exp(sub(x, maxX)); + return div(expX, reduceSum(expX, {axes: [1], keepDimensions: true})); +} diff --git a/src/split.js b/src/split.js new file mode 100644 index 0000000..f2c9685 --- /dev/null +++ b/src/split.js @@ -0,0 +1,30 @@ +'use strict'; + +import {slice} from './slice.js'; +import {validateSplitParams} from './lib/validate-input.js'; + +/** + * Split the input tensor into a number of sub tensors along the given axis. + * @param {Tensor} input + * @param {Array|Number} splits + * @param {MLSplitOptions} options + * @return {Array.} + */ +export function split(input, splits, {axis = 0} = {}) { + validateSplitParams(...arguments); + const outputs = []; + let sliceSizes = []; + const rank = input.rank; + const inpAxis = axis >=0 ? axis : rank + axis; + if (typeof splits === 'number') { + sliceSizes = new Array(splits).fill(input.shape[inpAxis] / splits); + } else if (splits instanceof Array) { + sliceSizes = splits.slice(); + } + let start = 0; + for (const size of sliceSizes) { + outputs.push(slice(input, [start], [size], {axes: [inpAxis]})); + start += size; + } + return outputs; +} diff --git a/src/squeeze.js b/src/squeeze.js new file mode 100644 index 0000000..a463f79 --- /dev/null +++ b/src/squeeze.js @@ -0,0 +1,20 @@ +'use strict'; + +import {reshape} from './reshape.js'; +import {validateSqueezeParams} from './lib/validate-input.js'; + +/** + * Reduce the rank of a tensor by eliminating dimensions with size 1 of the tensor shape. + * @param {Tensor} input + * @param {MLSqueezeOptions} options + * @return {Tensor} + */ +export function squeeze(input, {axes} = {}) { + validateSqueezeParams(...arguments); + const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i); + + const outputShape = input.shape.filter((dim, axis) => + !(dim === 1 && inpAxes.indexOf(axis) !== -1)); + const output = reshape(input, outputShape); + return output; +} diff --git a/src/tanh.js b/src/tanh.js new file mode 100644 index 0000000..4364374 --- /dev/null +++ b/src/tanh.js @@ -0,0 +1,13 @@ +'use strict'; + +import {unary} from './unary.js'; + +/** + * Compute the hyperbolic tangent function of the input tensor. + * The calculation follows the expression (exp(2 * x) - 1) / (exp(2 * x) + 1). + * @param {Tensor} input + * @return {Tensor} + */ +export function tanh(input) { + return unary(input, (x) => (Math.exp(2 * x) - 1) / (Math.exp(2 * x) + 1)); +} diff --git a/src/transpose.js b/src/transpose.js new file mode 100644 index 0000000..646c084 --- /dev/null +++ b/src/transpose.js @@ -0,0 +1,30 @@ +'use strict'; + +import {Tensor} from './lib/tensor.js'; +import {validateTranposeParams} from './lib/validate-input.js'; + +/** + * Permute the dimensions of the input tensor according to the permutation argument. + * @param {Tensor} input + * @param {MLTransposeOptions} [options] + * @return {Tensor} + */ +export function transpose(input, {permutation} = {}) { + const inpPermutation = permutation ?? + new Array(input.rank).fill(0).map((e, i, a) => a.length - i - 1); + validateTranposeParams(input, {permutation: inpPermutation}); + + const outputShape = new Array(input.rank).fill(0).map( + (e, i, a) => input.shape[inpPermutation[i]]); + const output = new Tensor(outputShape); + for (let inputIndex = 0; inputIndex < input.size; ++inputIndex) { + const inputValue = input.getValueByIndex(inputIndex); + const inputLocation = input.locationFromIndex(inputIndex); + const outputLocation = new Array(output.rank); + for (let i = 0; i < inpPermutation.length; ++i) { + outputLocation[i] = inputLocation[inpPermutation[i]]; + } + output.setValueByLocation(outputLocation, inputValue); + } + return output; +} diff --git a/src/unary.js b/src/unary.js new file mode 100644 index 0000000..1d45243 --- /dev/null +++ b/src/unary.js @@ -0,0 +1,29 @@ +'use strict'; + +import {Tensor} from './lib/tensor.js'; + +/** + * Compute the element-wise unary operation for input tensor. + * @param {Tensor} input + * @param {Function} unaryFunc + * @return {Tensor} + */ +export function unary(input, unaryFunc) { + const output = new Tensor(input.shape); + for (let i = 0; i < input.size; ++i) { + const x = input.getValueByIndex(i); + const y = unaryFunc(x); + output.setValueByIndex(i, y); + } + return output; +} + +export const abs = (input) => unary(input, Math.abs); +export const ceil = (input) => unary(input, Math.ceil); +export const cos = (input) => unary(input, Math.cos); +export const exp = (input) => unary(input, Math.exp); +export const floor = (input) => unary(input, Math.floor); +export const log = (input) => unary(input, Math.log); +export const neg = (input) => unary(input, (x) => -1 * x); +export const sin = (input) => unary(input, Math.sin); +export const tan = (input) => unary(input, Math.tan); diff --git a/test/batch_normalization_test.js b/test/batch_normalization_test.js new file mode 100644 index 0000000..c2ae7e5 --- /dev/null +++ b/test/batch_normalization_test.js @@ -0,0 +1,471 @@ +'use strict'; + +import {batchNormalization} from '../src/batch_normalization.js'; +import {clamp} from '../src/clamp.js'; +import {leakyRelu} from '../src/leaky_relu.js'; +import {relu} from '../src/relu.js'; +import {sigmoid} from '../src/sigmoid.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test batchNormalization', function() { + function testBatchNorm( + input, mean, variance, expected, scale = undefined, bias = undefined, + options = {}, activation = undefined, activationOptions = {}) { + const inputTensor = new Tensor(input.shape, input.data); + const meanTensor = new Tensor(mean.shape, mean.data); + const varianceTensor = new Tensor(variance.shape, variance.data); + if (scale) { + options.scale = new Tensor(scale.shape, scale.data); + } + if (bias) { + options.bias = new Tensor(bias.shape, bias.data); + } + if (activation === 'relu') { + options.activation = relu; + } else if (activation === 'relu6') { + options.activation = utils.bindTrailingArgs(clamp, {minValue: 0, maxValue: 6}); + } else if (activation === 'sigmoid') { + options.activation = sigmoid; + } else if (activation === 'leakyRelu') { + options.activation = utils.bindTrailingArgs(leakyRelu, activationOptions); + } + const outputTensor = batchNormalization(inputTensor, meanTensor, varianceTensor, options); + utils.checkShape(outputTensor, input.shape); + utils.checkValue(outputTensor, expected); + } + + it('batchNormalization 3D input axis=0', function() { + const input = { + shape: [3, 1, 2], + data: [-1, 0, 1, 2, 3, 4], + }; + const mean = { + shape: [3], + data: [0, 3, 6], + }; + const variance = { + shape: [3], + data: [1.0, 1.5, 2.0], + }; + const scale = { + shape: [3], + data: [1.0, 1.5, 2.0], + }; + const bias = { + shape: [3], + data: [0, 1, 2], + }; + const expected = [ + -0.9995003746877732, + 0, + -1.4486736542238683, + -0.22433682711193415, + -2.241580424529414, + -0.8277202830196093, + ]; + testBatchNorm(input, mean, variance, expected, scale, bias, {epsilon: 1e-3, axis: 0}); + }); + + it('batchNormalization 3D input axis=2', function() { + const input = { + shape: [2, 1, 3], + data: [-1, 0, 1, 2, 3, 4], + }; + const mean = { + shape: [3], + data: [0, 3, 6], + }; + const variance = { + shape: [3], + data: [1.0, 1.5, 2.0], + }; + const scale = { + shape: [3], + data: [1.0, 1.5, 2.0], + }; + const bias = { + shape: [3], + data: [0, 1, 2], + }; + const expected = [ + -0.9995003746877732, + -2.6730104813358024, + -5.069300707549023, + 1.9990007493755464, + 1, + -0.8277202830196093, + ]; + testBatchNorm(input, mean, variance, expected, scale, bias, {epsilon: 1e-3, axis: 2}); + }); + + it('batchNormalization nchw', function() { + const input = { + shape: [1, 2, 1, 3], + data: [-1, 0, 1, 2, 3, 4], + }; + const mean = { + shape: [2], + data: [0, 3], + }; + const variance = { + shape: [2], + data: [1.0, 1.5], + }; + const scale = { + shape: [2], + data: [1.0, 1.5], + }; + const bias = { + shape: [2], + data: [0, 1], + }; + let expected = [ + -0.9999950000374997, + 0, + 0.9999950000374997, + -0.22474078892909666, + 1, + 2.224740788929097, + ]; + testBatchNorm(input, mean, variance, expected, scale, bias); + + expected = [ + 0, + 0, + 0.9999950000374997, + 0, + 1, + 2.224740788929097, + ]; + testBatchNorm(input, mean, variance, expected, scale, bias, {}, 'relu'); + + let expectedScale = [ + -0.9999950000374997, + 0, + 0.9999950000374997, + -1.2247407889290967, + 0, + 1.2247407889290967, + ]; + testBatchNorm(input, mean, variance, expectedScale, scale); + + expectedScale = [ + 0, + 0, + 0.9999950000374997, + 0, + 0, + 1.2247407889290967, + ]; + testBatchNorm( + input, mean, variance, expectedScale, scale, undefined, {}, 'relu'); + + let expectedBias = [ + -0.9999950000374997, + 0, + 0.9999950000374997, + 0.18350614071393556, + 1, + 1.8164938592860644, + ]; + testBatchNorm(input, mean, variance, expectedBias, undefined, bias); + + expectedBias = [ + 0, + 0, + 0.9999950000374997, + 0.18350614071393556, + 1, + 1.8164938592860644, + ]; + testBatchNorm( + input, mean, variance, expectedBias, undefined, bias, {}, 'relu'); + }); + + it('batchNormalization nhwc', function() { + const input = { + shape: [1, 1, 3, 2], + data: [-1, 2, 0, 3, 1, 4], + }; + const mean = { + shape: [2], + data: [0, 3], + }; + const variance = { + shape: [2], + data: [1.0, 1.5], + }; + const scale = { + shape: [2], + data: [1.0, 1.5], + }; + const bias = { + shape: [2], + data: [0, 1], + }; + let expected = [ + -0.9999950000374997, + -0.22474078892909666, + 0, + 1, + 0.9999950000374997, + 2.224740788929097, + ]; + testBatchNorm(input, mean, variance, expected, scale, bias, {axis: 3}); + + expected = [ + 0, + 0, + 0, + 1, + 0.9999950000374997, + 2.224740788929097, + ]; + testBatchNorm( + input, mean, variance, expected, scale, bias, {axis: 3}, 'relu'); + + let expectedScale = [ + -0.9999950000374997, + -1.2247407889290967, + 0, + 0, + 0.9999950000374997, + 1.2247407889290967, + ]; + testBatchNorm( + input, mean, variance, expectedScale, scale, undefined, {axis: 3}); + + expectedScale = [ + 0, + 0, + 0, + 0, + 0.9999950000374997, + 1.2247407889290967, + ]; + testBatchNorm( + input, mean, variance, expectedScale, scale, undefined, {axis: 3}, + 'relu'); + + let expectedBias = [ + -0.9999950000374997, + 0.18350614071393556, + 0, + 1, + 0.9999950000374997, + 1.8164938592860644, + ]; + testBatchNorm( + input, mean, variance, expectedBias, undefined, bias, {axis: 3}); + + expectedBias = [ + 0, + 0.18350614071393556, + 0, + 1, + 0.9999950000374997, + 1.8164938592860644, + ]; + testBatchNorm( + input, mean, variance, expectedBias, undefined, bias, {axis: 3}, + 'relu'); + }); + + it('batchNormalization without options', function() { + const input = { + shape: [1, 2, 1, 3], + data: [-1, 0, 1, 2, 3, 4], + }; + const mean = { + shape: [2], + data: [0, 3], + }; + const variance = { + shape: [2], + data: [1.0, 1.5], + }; + + const expected = [ + -0.9999950000374997, + 0, + 0.9999950000374997, + -0.8164938592860644, + 0, + 0.8164938592860644, + ]; + testBatchNorm(input, mean, variance, expected); + }); + + it('batchNormalization with epsilon', function() { + const input = { + shape: [2, 3, 4, 5], + data: [ + 2.6973534, -1.1874187, -0.18637535, -1.7081367, 0.03293341, + 1.4802791, -0.68332213, 1.618039, -1.6412221, -0.52998835, + 1.5229957, -0.92798537, -0.35554567, 0.717948, 0.50108916, + 1.0521007, -0.68065745, 1.3121722, 0.50907123, 1.5093223, + -0.540522, -0.80794656, -0.17974755, -1.8922086, 2.0955374, + 0.46592507, -0.2936382, -0.43420887, -0.11036888, -1.2171484, + -1.9003569, 0.32063156, 0.38756344, 0.4720109, -0.4177193, + -0.7655141, -1.2207903, 0.52860916, 0.22583283, 1.2220219, + -0.0248001, 0.6148501, 1.0967597, 0.8798244, -0.6854243, + -0.8442876, 1.6188551, -0.6460473, 0.76349306, 2.630077, + -0.85050315, 0.37401453, 0.08842833, -0.5043717, -0.7495827, + -0.98900026, 0.79681706, -0.3573076, 0.8644746, 1.196009, + 0.35148722, 0.39926755, -0.21630785, 1.731195, 1.8644739, + -0.60227305, -1.0833911, -0.6197943, -0.05721893, -0.23889631, + -0.24901256, 1.3885167, -0.67789817, -0.3381054, 0.33224156, + 0.79065573, 1.1667213, -0.47722074, 0.4234017, 0.2317288, + -0.18525974, -0.17303231, 0.41841915, 0.13230574, 0.1261528, + 1.253214, 1.9984859, -1.7275336, 0.6593169, -1.3704892, + 0.63530993, -0.33128706, -1.2268444, 0.87340677, 1.4801403, + 0.09598545, 0.30467814, -0.15848571, -0.16779709, 1.1372787, + 0.3292992, -0.2240395, 0.88280654, 1.3370756, 0.2533313, + 0.84305125, -1.6560661, -0.09365056, -1.301057, -0.1476929, + -1.2850751, -1.286735, -1.9894414, -0.5574838, -0.392564, + -0.92764777, -0.79910755, 0.9099533, 0.9825949, -0.8327678, + ], + }; + const mean = { + shape: [3], + data: [0.3432895, 1.0855169, 1.8725895], + }; + const variance = { + shape: [3], + data: [0.601868, 0.86580527, 0.38809904], + }; + const scale = { + shape: [3], + data: [0.17215693, -0.7909758, 0.12456307], + }; + const bias = { + shape: [3], + data: [0.5280557, -1.4475446, 0.1760742], + }; + const expected = [ + 1.0461560510143535, + 0.1911657558822769, + 0.411483017030364, + 0.07656216394941046, + 0.4597501628527534, + 0.7782930494872715, + 0.3021111766340805, + 0.8086122997615323, + 0.09128923985758142, + 0.33585804528975194, + 0.7876944448589689, + 0.24826382332973346, + 0.37425072177567253, + 0.6105134023421119, + 0.5627854536128674, + 0.6840562790509362, + 0.3026976397472274, + 0.7412947998229634, + 0.5645422085033447, + 0.7846850986217833, + -0.07321674022185398, + 0.15281046195449766, + -0.3781432568821488, + 1.069228592462868, + -2.3012137908767087, + -0.9238656685427483, + -0.2818828948064065, + -0.16307258341183895, + -0.4367820994432674, + 0.4986678021352766, + 1.076115534527926, + -0.8010636134148209, + -0.857634429396731, + -0.9290094112409543, + -0.1770095657600239, + 0.11694655246443442, + 0.50174593552589, + -0.9768462529874629, + -0.7209397395575678, + -1.5629186076568993, + -0.1985106893011581, + -0.07223019353030338, + 0.022908967224644028, + -0.01991865517616051, + -0.32893189685213553, + -0.36029487662419535, + 0.12598165594057448, + -0.32115804313963925, + -0.04288492531707169, + 0.32561827430846624, + -0.361521957824965, + -0.1197762353807654, + -0.17615699930660242, + -0.2931882793427495, + -0.34159812373655585, + -0.38886422038289215, + -0.03630606199291739, + -0.2641547115300692, + -0.022949030768919576, + 0.042502880912016094, + 0.5298599167884727, + 0.5403757765085855, + 0.4048952439640776, + 0.8335165359291996, + 0.8628495735212578, + 0.31994907678513007, + 0.21406094410345505, + 0.31609286635039624, + 0.439908747758301, + 0.39992380327596644, + 0.3976973417614982, + 0.7580972800988877, + 0.30330492315042945, + 0.37808910951390895, + 0.5256241850390062, + 0.6265154745180018, + 0.7092828555654802, + 0.34747154365875904, + 0.5456874044497102, + 0.5035025696328053, + -0.3734843546348592, + -0.38381897682777666, + -0.8837136713424404, + -0.6418906556574361, + -0.6366901915962482, + -1.5892821663854055, + -2.2191858761181757, + 0.9300453045913626, + -1.087320417271379, + 0.6282714256898352, + -1.0670297294533582, + -0.25006208339353875, + 0.5068628600323595, + -1.268269146626674, + -1.7810802446517826, + -0.6111927514300393, + -0.7875797849745898, + -0.39611376119305164, + -0.38824378406828464, + -1.4912936664044165, + -0.12860398848856922, + -0.2378447662843086, + -0.019329917585430262, + 0.07035241521171431, + -0.14360166077049882, + -0.027178453755221377, + -0.520557144088695, + -0.2121032281964033, + -0.45047082948867634, + -0.22277233060240123, + -0.44731566396952194, + -0.44764336338231037, + -0.5863724884155747, + -0.30367373267244635, + -0.2711150715379228, + -0.37675193955506403, + -0.3513753779467118, + -0.013970572256729707, + 0.0003704179619681558, + -0.35802062414185254, + ]; + testBatchNorm( + input, mean, variance, expected, scale, bias, {epsilon: 1e-2}); + }); +}); diff --git a/test/binary_test.js b/test/binary_test.js new file mode 100644 index 0000000..3a0fc53 --- /dev/null +++ b/test/binary_test.js @@ -0,0 +1,1188 @@ +'use strict'; + +import {add, sub, mul, div, max, min, pow} from '../src/binary.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test binary', function() { + function testBinary(inputA, inputB, expected, func) { + const tensorA = new Tensor(inputA.shape, inputA.data); + const tensorB = new Tensor(inputB.shape, inputB.data); + const outputTensor = func(tensorA, tensorB); + utils.checkShape(outputTensor, expected.shape); + utils.checkValue(outputTensor, expected.data); + } + + it('add', function() { + const inputA = { + shape: [3, 4, 5], + data: [ + 0.08939514, -1.5887482, 0.8545348, 0.20523034, -0.41728342, + 1.01752, 0.19677015, 0.5398451, 0.56893295, 1.2511084, + 2.0092728, 1.0606714, 0.4893267, 0.09536829, -2.3467007, + 2.4527607, 0.61307395, -1.0799897, -0.15071101, -0.48422927, + -0.20479254, 0.32798728, -0.37435308, -1.7116562, 1.6952512, + -0.7479369, -0.09019202, 0.14343949, 1.6754607, 1.6427531, + 0.9470988, 0.20872667, -1.9530525, -0.21783416, 0.0309498, + 0.3008434, 1.1686599, 1.4920886, 0.06633294, 0.6674667, + 0.60627925, 0.04302086, -0.03482966, -0.7343786, -0.76851964, + 0.9446942, -0.35489243, 0.44452578, 0.00648887, -0.55656946, + -0.735903, 0.22050636, -0.5008282, -1.3132697, 1.6642882, + -0.48397836, 0.20099205, -0.28786168, 1.3315053, -0.41619393, + ], + }; + const inputB = { + shape: [3, 4, 5], + data: [ + -0.5781865, -0.49248728, -0.2162451, -0.13176449, -0.52118045, + 1.9125274, 0.6508799, 0.71873736, -2.3154447, 0.8080079, + 0.3022368, 0.21394566, -0.6511544, 0.20001237, -0.08041809, + 1.1127822, -1.521739, 0.7249548, -0.91961324, -0.83175105, + -1.4569077, -0.5417681, -1.6476909, 0.1223801, 2.220618, + -0.14914903, 0.7790501, -0.18711103, -0.9941537, -1.828552, + -1.36035, 0.5727087, 2.5213664, -0.3267195, 0.8431539, + 0.12337407, 1.0018097, -0.23469485, -0.4530751, 0.09238022, + 0.7888511, 0.11107288, 0.48171726, 0.34308678, -0.90550417, + 0.203841, 0.02521433, -1.7966009, -1.4287543, 0.3222213, + 1.0590587, -1.7948701, -1.7195907, -0.9120889, -0.9391962, + -0.2566791, -0.5464537, 1.4351872, 0.5705938, -0.30327085, + ], + }; + const expected = { + shape: [3, 4, 5], + data: [ + -0.48879136000000006, + -2.08123548, + 0.6382897000000001, + 0.07346585, + -0.9384638700000001, + 2.9300474000000003, + 0.8476500499999999, + 1.25858246, + -1.74651175, + 2.0591163, + 2.3115096000000004, + 1.27461706, + -0.16182770000000002, + 0.29538065999999996, + -2.4271187899999997, + 3.5655428999999996, + -0.90866505, + -0.35503490000000004, + -1.07032425, + -1.31598032, + -1.6617002399999998, + -0.21378081999999998, + -2.02204398, + -1.5892761, + 3.9158692, + -0.89708593, + 0.6888580799999999, + -0.04367154000000001, + 0.6813069999999999, + -0.1857989, + -0.41325119999999993, + 0.78143537, + 0.5683138999999997, + -0.54455366, + 0.8741037, + 0.42421747, + 2.1704695999999997, + 1.2573937499999999, + -0.38674216, + 0.75984692, + 1.39513035, + 0.15409374, + 0.4468876, + -0.39129182, + -1.67402381, + 1.1485352, + -0.3296781, + -1.35207512, + -1.42226543, + -0.23434815999999997, + 0.32315570000000005, + -1.57436374, + -2.2204189, + -2.2253586, + 0.7250920000000001, + -0.74065746, + -0.34546165000000006, + 1.14732552, + 1.9020991, + -0.71946478, + ], + }; + testBinary(inputA, inputB, expected, add); + }); + + it('add broadcast', function() { + const inputA = { + shape: [3, 4, 5], + data: [ + -0.08539673, 0.11800674, -1.2358714, 0.30089188, -0.73443925, + 1.4894297, 0.16823359, -2.2034893, 1.0740992, -0.35457978, + 0.61524934, 0.462153, 0.5992003, -0.81047946, -2.2757835, + -0.21841764, 1.1650828, -0.56927145, 1.9960726, 0.62048405, + 0.10586528, -1.0386543, -1.9402571, -2.0906122, -0.4305259, + -1.2730165, 1.5639576, 0.53357494, -0.8079486, -0.06450062, + -0.7841324, -0.24135855, 1.9275267, 0.4476717, 0.15467685, + -1.2363592, -0.50745815, 0.03250425, 0.86344534, -0.7938714, + 1.1835734, 1.515135, 0.3092435, -1.311751, -0.6659017, + 0.8815683, -0.31157655, 0.57511795, -1.1924151, -1.8408557, + -0.85080767, -1.3341717, 0.54687303, -0.14426671, -0.15728855, + 0.323939, 1.167636, 0.03020451, 0.91373825, 1.0675793, + ], + }; + const inputB = { + shape: [5], + data: [ + 0.6338172, + 1.630534, + -1.3819867, + -1.0427561, + 1.058136, + ], + }; + const expected = { + shape: [3, 4, 5], + data: [ + 0.5484204699999999, + 1.74854074, + -2.6178581, + -0.7418642200000001, + 0.32369675, + 2.1232469, + 1.79876759, + -3.585476, + 0.03134309999999996, + 0.70355622, + 1.2490665399999998, + 2.0926869999999997, + -0.7827863999999999, + -1.8532355600000001, + -1.2176475000000002, + 0.41539956, + 2.7956168, + -1.95125815, + 0.9533164999999999, + 1.67862005, + 0.7396824799999999, + 0.5918797, + -3.3222438, + -3.1333683, + 0.6276101, + -0.6391993, + 3.1944916, + -0.8484117599999998, + -1.8507047, + 0.99363538, + -0.15031519999999998, + 1.38917545, + 0.5455400000000001, + -0.5950844000000001, + 1.21281285, + -0.6025420000000001, + 1.1230758499999998, + -1.34948245, + -0.17931076000000012, + 0.26426459999999996, + 1.8173906, + 3.145669, + -1.0727432, + -2.3545071, + 0.3922342999999999, + 1.5153854999999998, + 1.3189574499999999, + -0.8068687499999999, + -2.2351712, + -0.7827197000000001, + -0.21699047000000005, + 0.29636229999999997, + -0.8351136699999999, + -1.18702281, + 0.90084745, + 0.9577562, + 2.79817, + -1.35178219, + -0.1290178500000001, + 2.1257153, + ], + }; + testBinary(inputA, inputB, expected, add); + }); + + it('sub', function() { + const inputA = { + shape: [3, 4, 5], + data: [ + 1.7640524, 0.4001572, 0.978738, 2.2408931, 1.867558, + -0.9772779, 0.95008844, -0.1513572, -0.10321885, 0.41059852, + 0.14404356, 1.4542735, 0.7610377, 0.12167501, 0.44386324, + 0.33367434, 1.4940791, -0.20515826, 0.3130677, -0.85409576, + -2.5529897, 0.6536186, 0.8644362, -0.742165, 2.2697546, + -1.4543657, 0.04575852, -0.18718386, 1.5327792, 1.4693588, + 0.15494743, 0.37816253, -0.88778573, -1.9807965, -0.34791216, + 0.15634897, 1.2302907, 1.2023798, -0.3873268, -0.30230275, + -1.048553, -1.420018, -1.7062702, 1.9507754, -0.5096522, + -0.4380743, -1.2527953, 0.7774904, -1.6138978, -0.21274029, + -0.89546657, 0.3869025, -0.51080513, -1.1806322, -0.02818223, + 0.42833188, 0.06651722, 0.3024719, -0.6343221, -0.36274117, + ], + }; + const inputB = { + shape: [3, 4, 5], + data: [ + -0.67246044, -0.35955316, -0.8131463, -1.7262826, 0.17742614, + -0.40178093, -1.6301984, 0.46278226, -0.9072984, 0.0519454, + 0.7290906, 0.12898292, 1.1394007, -1.2348258, 0.40234163, + -0.6848101, -0.87079716, -0.5788497, -0.31155252, 0.05616534, + -1.1651498, 0.9008265, 0.46566245, -1.5362437, 1.4882522, + 1.8958892, 1.1787796, -0.17992483, -1.0707526, 1.0544517, + -0.40317693, 1.222445, 0.20827498, 0.97663903, 0.3563664, + 0.7065732, 0.01050002, 1.7858706, 0.12691209, 0.40198937, + 1.8831507, -1.347759, -1.270485, 0.9693967, -1.1731234, + 1.9436212, -0.41361898, -0.7474548, 1.922942, 1.4805148, + 1.867559, 0.90604466, -0.86122566, 1.9100649, -0.26800337, + 0.8024564, 0.947252, -0.15501009, 0.61407936, 0.9222067, + ], + }; + const expected = { + shape: [3, 4, 5], + data: [ + 2.43651284, + 0.7597103599999999, + 1.7918843, + 3.9671757000000003, + 1.6901318600000002, + -0.5754969700000001, + 2.5802868400000003, + -0.61413946, + 0.80407955, + 0.35865312000000005, + -0.5850470400000001, + 1.3252905799999999, + -0.3783629999999999, + 1.35650081, + 0.04152160999999999, + 1.01848444, + 2.36487626, + 0.37369144, + 0.62462022, + -0.9102610999999999, + -1.3878399, + -0.24720789999999992, + 0.39877375, + 0.7940787, + 0.7815024000000002, + -3.3502549000000004, + -1.13302108, + -0.00725903, + 2.6035318, + 0.4149071, + 0.55812436, + -0.84428247, + -1.0960607100000002, + -2.95743553, + -0.7042785600000001, + -0.55022423, + 1.21979068, + -0.5834907999999999, + -0.51423889, + -0.7042921200000001, + -2.9317037, + -0.07225900000000007, + -0.4357852, + 0.9813786999999999, + 0.6634711999999999, + -2.3816955, + -0.83917632, + 1.5249452, + -3.5368398, + -1.6932550899999999, + -2.76302557, + -0.5191421599999999, + 0.35042052999999995, + -3.0906971, + 0.23982114000000002, + -0.37412451999999996, + -0.8807347799999999, + 0.45748199, + -1.24840146, + -1.28494787, + ], + }; + testBinary(inputA, inputB, expected, sub); + }); + + it('sub broadcast', function() { + const inputA = { + shape: [3, 4, 5], + data: [ + 0.37642553, -1.0994008, 0.2982382, 1.3263859, -0.69456786, + -0.14963454, -0.43515354, 1.8492638, 0.67229474, 0.40746182, + -0.76991606, 0.5392492, -0.6743327, 0.03183056, -0.6358461, + 0.67643327, 0.57659084, -0.20829876, 0.3960067, -1.0930616, + -1.4912575, 0.4393917, 0.1666735, 0.63503146, 2.3831449, + 0.94447947, -0.91282225, 1.1170163, -1.3159074, -0.4615846, + -0.0682416, 1.7133427, -0.74475485, -0.82643855, -0.09845252, + -0.6634783, 1.1266359, -1.0799315, -1.1474687, -0.43782005, + -0.49803245, 1.929532, 0.9494208, 0.08755124, -1.2254355, + 0.844363, -1.0002153, -1.5447711, 1.1880298, 0.3169426, + 0.9208588, 0.31872764, 0.8568306, -0.6510256, -1.0342429, + 0.6815945, -0.80340964, -0.6895498, -0.4555325, 0.01747916, + ], + }; + const inputB = { + shape: [5], + data: [ + -0.35399392, + -1.3749512, + -0.6436184, + -2.2234032, + 0.62523144, + ], + }; + const expected = { + shape: [3, 4, 5], + data: [ + 0.7304194500000001, + 0.2755504, + 0.9418566, + 3.5497891, + -1.3197993000000001, + 0.20435938, + 0.93979766, + 2.4928822, + 2.89569794, + -0.21776962, + -0.41592214, + 1.9142004, + -0.030714299999999972, + 2.25523376, + -1.26107754, + 1.03042719, + 1.95154204, + 0.43531964, + 2.6194099, + -1.7182930399999998, + -1.13726358, + 1.8143429, + 0.8102919000000001, + 2.85843466, + 1.7579134600000001, + 1.29847339, + 0.46212895, + 1.7606347, + 0.9074958, + -1.08681604, + 0.28575232, + 3.0882939, + -0.10113644999999993, + 1.39696465, + -0.72368396, + -0.30948437999999995, + 2.5015871, + -0.4363131, + 1.0759345, + -1.0630514899999999, + -0.14403853, + 3.3044832, + 1.5930392, + 2.3109544399999997, + -1.85066694, + 1.19835692, + 0.3747358999999999, + -0.9011526999999999, + 3.4114329999999997, + -0.30828883999999995, + 1.27485272, + 1.69367884, + 1.5004490000000001, + 1.5723775999999998, + -1.65947434, + 1.03558842, + 0.5715415599999999, + -0.045931399999999956, + 1.7678707, + -0.6077522799999999, + ], + }; + testBinary(inputA, inputB, expected, sub); + }); + + it('mul', function() { + const inputA = { + shape: [3, 4, 5], + data: [ + 5.6232101e-01, 1.3117781e-01, -1.4161869e+00, 2.0386910e-02, + 9.1077393e-01, 7.4952751e-01, -2.8509337e-01, -1.6272701e+00, + 1.0271618e+00, 4.2815253e-01, -7.7895027e-01, 9.7542489e-01, + 3.9352554e-01, 9.7878903e-01, -6.0965502e-01, 6.6299748e-01, + -1.1980454e+00, -7.7857232e-01, -9.8175555e-01, -2.8763762e-01, + -3.2260692e-01, -7.4259090e-01, -1.0055183e+00, -1.4305019e+00, + 6.0624069e-01, -1.5911928e-01, 4.5598033e-01, 1.0880016e-01, + 1.4949993e+00, 6.6210419e-01, -5.6889033e-01, -2.0945708e-01, + -7.1049523e-01, -2.8507587e-01, 1.1723405e+00, -6.3937567e-02, + -5.4250038e-01, -1.2398884e+00, -1.0347517e+00, 1.2763804e+00, + -1.5979607e+00, -5.8152825e-01, -5.0100851e-01, -1.0742084e+00, + -1.1273566e+00, 3.4815140e-04, -5.6024802e-01, 1.0848801e+00, + -5.1780093e-01, -3.8996863e-01, 5.3133094e-01, 2.3897937e-01, + -1.3832775e+00, 6.3414145e-01, 1.0691971e+00, 5.7040757e-01, + 3.0711100e-01, 8.8405716e-01, -2.1583509e+00, 4.3243581e-01, + ], + }; + const inputB = { + shape: [3, 4, 5], + data: [ + 2.0435283, 0.07213961, -1.1644137, -1.2209045, 0.8982674, + 0.21796915, 0.27658972, 0.7744382, -0.52159035, -0.969913, + 0.6081186, -0.04225572, 0.3275312, -0.06443629, -2.257355, + 1.7802691, -1.279233, -3.1389477, -1.1663845, -0.79485595, + 0.679013, 1.0919031, 0.51905185, 1.3186365, 0.6612518, + 0.40741763, 0.05208012, 0.16548257, -0.4570541, 0.10149371, + 0.08249464, 0.3992067, -0.3945879, -0.37389037, 1.4760005, + -0.781274, -0.49022308, 0.27020553, -0.2356837, 0.13846985, + 0.9767852, -1.3560135, 0.78826934, -0.18788454, 0.38178417, + 0.9748209, 1.0242884, 0.7939937, 0.24449475, -1.3840157, + 1.9665064, 0.35833818, -0.87076694, -0.76727265, 0.6157508, + -0.5558823, 0.18417479, -0.93904793, -0.00859687, 0.5034271, + ], + }; + const expected = { + shape: [3, 4, 5], + data: [ + 1.1491188976195832, + 0.009463116054054102, + 1.6490274281205302, + -0.024890470160095, + 0.818118530088882, + 0.1633738742563165, + -0.0788538953821564, + -1.26022012715782, + -0.53575768276863, + -0.41527070482989, + -0.47369414766202195, + -0.0412172810328708, + 0.128891892346848, + -0.06306953378589869, + 1.3762078076721, + 1.180313927021868, + 1.5325792111782002, + 2.443897793147664, + 1.145104456308975, + 0.22863047370083897, + -0.21905429256996, + -0.81083730574179, + -0.521916133823855, + -1.88631201865935, + 0.400877747495742, + -0.0648279999449064, + 0.0237475103040396, + 0.0180045300932112, + -0.68329555956213, + 0.0671994106496449, + -0.046930402972831194, + -0.08361666969843601, + 0.280352820765717, + 0.10658712251237192, + 1.73037516417025, + 0.049952758720358, + 0.2659462071847704, + -0.335024702262852, + 0.24387410923728997, + 0.17674020253094003, + -1.56086436194164, + 0.788560157631375, + -0.3949296475120834, + 0.201827151098136, + -0.43040690382502195, + 0.00033938526108426, + -0.573855548008968, + 0.8613879646553699, + -0.1265996089301175, + 0.539722706427491, + 1.044865694028016, + 0.0856354325033466, + 1.20451231584585, + -0.4865593908163425, + 0.65835896968268, + -0.31707947194901104, + 0.056562103931690005, + -0.8301720460996788, + 0.018555062101682996, + 0.217699905764451, + ], + }; + testBinary(inputA, inputB, expected, mul); + }); + + it('mul broadcast', function() { + const inputA = { + shape: [3, 4, 5], + data: [ + -0.08539673, 0.11800674, -1.2358714, 0.30089188, -0.73443925, + 1.4894297, 0.16823359, -2.2034893, 1.0740992, -0.35457978, + 0.61524934, 0.462153, 0.5992003, -0.81047946, -2.2757835, + -0.21841764, 1.1650828, -0.56927145, 1.9960726, 0.62048405, + 0.10586528, -1.0386543, -1.9402571, -2.0906122, -0.4305259, + -1.2730165, 1.5639576, 0.53357494, -0.8079486, -0.06450062, + -0.7841324, -0.24135855, 1.9275267, 0.4476717, 0.15467685, + -1.2363592, -0.50745815, 0.03250425, 0.86344534, -0.7938714, + 1.1835734, 1.515135, 0.3092435, -1.311751, -0.6659017, + 0.8815683, -0.31157655, 0.57511795, -1.1924151, -1.8408557, + -0.85080767, -1.3341717, 0.54687303, -0.14426671, -0.15728855, + 0.323939, 1.167636, 0.03020451, 0.91373825, 1.0675793, + ], + }; + const inputB = { + shape: [5], + data: [ + 0.6338172, + 1.630534, + -1.3819867, + -1.0427561, + 1.058136, + ], + }; + const expected = { + shape: [3, 4, 5], + data: [ + -0.054125916297756, + 0.19241400179916, + 1.70795783771038, + -0.31375684331046805, + -0.777136610238, + 0.94402616205084, + 0.27431058843705997, + 3.04519290619231, + -1.12002349280512, + -0.37519363009008, + 0.389955613980648, + 0.753556179702, + -0.82808684523601, + 0.8451324008397061, + -2.408088449556, + -0.138436857015408, + 1.8997071182151999, + 0.786725572589715, + -2.0814168796928603, + 0.6565565107307999, + 0.067099235346816, + -1.6935611503961998, + 2.68140950678057, + 2.17999862428442, + -0.4555549537224, + -0.8068597535837999, + 2.5500860413583997, + -0.7373934705332981, + 0.8424933311364601, + -0.06825042804432, + -0.4969966021972799, + -0.39354332196569997, + -2.66381626329489, + -0.46681239597237006, + 0.1636691433516, + -0.78362572633824, + -0.8274277671521001, + -0.04492044119347499, + -0.900362895301574, + -0.8400239077104, + 0.75016917838248, + 2.47047913209, + -0.42737040406145, + 1.3678363569311, + -0.7046145612312, + 0.5587531515147599, + -0.5080361583777, + -0.7948053578312649, + 1.2433981192571102, + -1.9478756869752, + -0.5392565351379239, + -2.1754123186878, + -0.7557712540487009, + 0.150434991879431, + -0.1664326771428, + 0.20531810995079997, + 1.9038701976239998, + -0.041742231100017, + -0.952806133990825, + 1.1296440901848, + ], + }; + testBinary(inputA, inputB, expected, mul); + }); + + it('div', function() { + const inputA = { + shape: [3, 4, 5], + data: [ + 0.5270042, 0.4537819, -1.8297404, 0.03700572, 0.76790243, + 0.5898798, -0.36385882, -0.8056265, -1.1183119, -0.13105401, + 1.1330799, -1.9518042, -0.6598917, -1.1398025, 0.7849575, + -0.5543096, -0.47063765, -0.21694957, 0.44539326, -0.392389, + -3.046143, 0.5433119, 0.43904296, -0.21954103, -1.0840366, + 0.35178012, 0.37923554, -0.47003287, -0.21673147, -0.9301565, + -0.17858909, -1.5504293, 0.41731882, -0.9443685, 0.23810315, + -1.405963, -0.5900577, -0.11048941, -1.6606998, 0.11514787, + -0.37914756, -1.7423562, -1.3032428, 0.60512006, 0.895556, + -0.13190864, 0.40476182, 0.22384356, 0.32962298, 1.285984, + -1.5069984, 0.67646074, -0.38200897, -0.22425893, -0.30224973, + -0.3751471, -1.2261962, 0.1833392, 1.670943, -0.05613302, + ], + }; + const inputB = { + shape: [3, 4, 5], + data: [ + 0.99861497, 0.312701, 0.88252544, 1.4661665, 0.6297575, + 0.546196, 1.4032645, 0.08199525, 1.2524966, 1.8203218, + 2.3599486, 0.909618, 2.367597, 2.03441, 0.00378734, + -0.21793854, 0.69503635, 2.0289354, 0.927713, 0.39934242, + 2.5522432, 1.2869045, -1.3205943, 1.3171606, 1.5200406, + 1.2256087, 1.449712, 0.9327244, -0.31839585, 0.629296, + 0.05438423, 0.06725907, -0.26306832, 1.4524891, 1.0978961, + 0.55183464, 0.35066205, 0.9765769, 2.0791948, -1.0042157, + 1.3768766, 0.454288, -0.88458586, -0.945703, 0.0872165, + 1.2195096, 1.393063, 0.06101841, 2.017021, 2.4229836, + 1.3960866, 0.40859735, 2.1244192, 1.7553957, 1.8674074, + 0.34353632, -1.8345544, 3.116791, -0.61087835, 0.9642319, + ], + }; + const expected = { + shape: [3, 4, 5], + data: [ + 0.5277351289856991, + 1.4511686882996857, + -2.0733004591912954, + 0.025239780065906568, + 1.2193621036668878, + 1.0799782495660897, + -0.25929453784372086, + -9.825282562099634, + -0.8928662161637804, + -0.07199496814244602, + 0.48012905874305906, + -2.1457405196467088, + -0.2787179152533138, + -0.5602619432661067, + 207.2582604149614, + 2.543421645386814, + -0.6771410588813089, + -0.10692778587233483, + 0.48009811223945337, + -0.9825878252553285, + -1.1935159627421086, + 0.4221850960968743, + -0.33245862109203406, + -0.16667749551573285, + -0.7131629247271422, + 0.2870248228492503, + 0.2615937096471575, + -0.5039354283001495, + 0.6806981623661239, + -1.4780905964760622, + -3.2838396351295223, + -23.05160181370334, + -1.5863514846637556, + -0.6501725210881101, + 0.21687220675982, + -2.5477976518473, + -1.6826962027969665, + -0.11313948753037267, + -0.7987225631768606, + -0.11466447895606491, + -0.2753678579474733, + -3.8353559856302604, + 1.4732801629906227, + -0.6398626841619409, + 10.268194665000316, + -0.10816531497579025, + 0.29055528716217427, + 3.668459404301095, + 0.16342069814840796, + 0.5307439967814888, + -1.079444785158743, + 1.6555681039047365, + -0.17981807451184775, + -0.1277540613777281, + -0.16185527057459448, + -1.0920158311063004, + 0.668389119450478, + 0.05882306513333746, + -2.7353121943182304, + -0.058215269584007745, + ], + }; + testBinary(inputA, inputB, expected, div); + }); + + it('div broadcast', function() { + const inputA = { + shape: [3, 4, 5], + data: [ + 2.3807454, 0.33057675, 0.94924647, -1.5023966, -1.7776669, + -0.5327028, 1.0907497, -0.34624946, -0.7946363, 0.19796729, + 1.0819352, -1.4449402, -1.210543, -0.7886692, 1.0946383, + 0.23482153, 2.1321535, 0.9364457, -0.03509518, 1.2650778, + 0.21149701, -0.70492136, 0.67997485, -0.6963267, -0.2903971, + 1.3277828, -0.10128149, -0.8031414, -0.46433768, 1.0217906, + -0.55254066, -0.38687086, -0.51029277, 0.1839255, -0.38548976, + -1.6018361, -0.8871809, -0.932789, 1.2433194, 0.81267405, + 0.58725935, -0.50535834, -0.81579155, -0.5075176, -1.0518801, + 2.4972005, -2.2453218, 0.56400853, -1.2845523, -0.10434349, + -0.98800194, -1.177629, -1.1401963, 1.7549862, -0.13298842, + -0.7657022, 0.55578697, 0.01034931, 0.72003376, -1.8242567, + ], + }; + const inputB = { + shape: [5], + data: [ + 1.3041736, + 1.5910654, + 1.9217191, + 1.8052639, + 1.7239413, + ], + }; + const expected = { + shape: [3, 4, 5], + data: [ + 1.8254819757124359, + 0.2077706862332623, + 0.49395693158276877, + -0.832231010657223, + -1.0311644021754105, + -0.408460039368992, + 0.6855467411961821, + -0.18017693636910825, + -0.4401773613265074, + 0.114834124572571, + 0.8295944650313425, + -0.9081588978052065, + -0.629927131389806, + -0.4368719720147287, + 0.6349626289479811, + 0.1800538900649423, + 1.3400791067419353, + 0.4872958279906777, + -0.0194404707256374, + 0.733828814240949, + 0.16216936916987126, + -0.4430498960005039, + 0.35383675480979504, + -0.3857201708847111, + -0.16844952899498378, + 1.0181028047186358, + -0.06365639652524654, + -0.41792861402064435, + -0.2572131863934132, + 0.5927061437648719, + -0.4236710971606848, + -0.24315207910372508, + -0.2655397295057326, + 0.10188288814726755, + -0.2236095625761736, + -1.228238403230981, + -0.5576017805427734, + -0.48539300046505235, + 0.6887189180484915, + 0.4714047108216504, + 0.4502923153788729, + -0.3176226068394172, + -0.42451133987272127, + -0.28113208268331297, + -0.6101600443124137, + 1.9147761463657906, + -1.4112064783760618, + 0.2934916606698658, + -0.7115592905834986, + -0.060526126962675585, + -0.7575693450626512, + -0.7401512219422282, + -0.5933210009725146, + 0.9721493904575392, + -0.07714208134580916, + -0.5871167764782235, + 0.34931748877198887, + 0.0053854436894549265, + 0.3988523561569032, + -1.058189568287505, + ], + }; + testBinary(inputA, inputB, expected, div); + }); + + it('max', function() { + const inputA = { + shape: [3, 4, 5], + data: [ + 0.54270846, 0.3356357, 0.04034169, 1.6710619, -1.0029255, + 1.4024457, -0.5183214, -1.5897884, 0.16786452, -0.92690915, + -0.8761584, 1.8612522, 0.2960607, 0.11604685, 0.2686291, + -0.5718065, 0.4856556, -1.2307562, -1.7977105, -1.1370704, + 1.0383102, -1.0015849, -1.367141, 0.32427165, 1.2968429, + 1.3039074, -0.6295407, 1.1250858, 1.0206878, -0.769062, + 0.96548617, 1.9100864, 2.1261373, 0.8835118, -0.66880584, + 0.9088927, 1.8120629, -0.25648043, 0.15793198, -1.5175776, + 0.08734574, 0.9441932, -1.0558261, 0.1276651, -2.9616504, + 2.1102998, 0.58067006, -0.7349921, -0.28586444, -0.92654175, + -0.507083, -1.8776977, 0.57921827, 1.460351, 1.4930215, + -0.757663, 1.0773797, -1.1858964, -0.5337765, 0.27636543, + ], + }; + const inputB = { + shape: [3, 4, 5], + data: [ + -0.00724315, -1.4088361, 0.17466596, 1.1395162, 1.3720452, + -0.35610083, -0.5597993, -0.26632488, -0.31922337, -0.2980101, + 0.12268824, -1.1521344, -1.0502838, 0.85281086, -0.83374727, + 0.00551354, 0.08081324, -0.13748081, 0.59067047, -0.20894054, + -0.9008378, -0.06121079, -1.8927814, -0.5113896, 2.0618987, + -0.09704968, 1.9003097, -0.27883208, -0.9971944, -1.0472671, + 0.995112, 0.83037376, 1.5058613, 0.51366556, 0.4476341, + 1.0389726, -0.04508441, -0.2180115, 0.3973936, 0.58917326, + 2.3834932, 0.71679467, 0.06214673, -0.09415992, 0.9173279, + 0.55409455, 0.6537859, -1.1739589, 1.1591603, 0.5907742, + -1.0454807, -0.8065648, 2.0162134, -0.30215183, 0.67375183, + 1.6682644, -2.916385, 0.43166366, -0.7290503, 0.11509943, + ], + }; + const expected = { + shape: [3, 4, 5], + data: [ + 0.54270846, 0.3356357, 0.17466596, 1.6710619, 1.3720452, 1.4024457, + -0.5183214, -0.26632488, 0.16786452, -0.2980101, 0.12268824, 1.8612522, + 0.2960607, 0.85281086, 0.2686291, 0.00551354, 0.4856556, -0.13748081, + 0.59067047, -0.20894054, 1.0383102, -0.06121079, -1.367141, 0.32427165, + 2.0618987, 1.3039074, 1.9003097, 1.1250858, 1.0206878, -0.769062, + 0.995112, 1.9100864, 2.1261373, 0.8835118, 0.4476341, 1.0389726, + 1.8120629, -0.2180115, 0.3973936, 0.58917326, 2.3834932, 0.9441932, + 0.06214673, 0.1276651, 0.9173279, 2.1102998, 0.6537859, -0.7349921, + 1.1591603, 0.5907742, -0.507083, -0.8065648, 2.0162134, 1.460351, + 1.4930215, 1.6682644, 1.0773797, 0.43166366, -0.5337765, 0.27636543, + ], + }; + testBinary(inputA, inputB, expected, max); + }); + + it('max broadcast', function() { + const inputA = { + shape: [3, 4, 5], + data: [ + -0.78042406, -0.18523395, -0.12612817, -0.24858657, 0.36215156, + -0.41349608, 1.540389, 1.9143543, 0.4806893, 0.0123093, + 1.2142435, -0.57421523, -2.1229508, 1.1247561, 0.11206079, + 0.5191412, -0.2109448, -0.97485703, 0.6992101, 1.0161952, + -0.19765139, 0.34198883, -0.24741505, 1.5920583, 0.56292, + 0.09105966, 0.82438636, -0.2996084, -0.97498095, 1.9305013, + 1.4938543, 0.01099077, 0.7837045, 0.6621192, 0.9520401, + -0.63094735, -1.4202772, 2.6008792, -0.3047365, -0.58313465, + -0.37956452, -0.14322324, -1.2261407, -1.1514657, -0.28318587, + -0.06985976, 0.48337674, 0.99673945, -0.54980195, -1.7497128, + 0.62820524, 1.0456259, 0.16508068, 0.5966878, 0.7607826, + 0.9664813, -0.13389224, -0.5757679, 0.38655168, -0.39935285, + ], + }; + const inputB = { + shape: [5], + data: [ + 0.67538136, + 0.3535401, + 1.0303422, + -0.50294054, + -0.25600532, + ], + }; + const expected = { + shape: [3, 4, 5], + data: [ + 0.67538136, 0.3535401, 1.0303422, -0.24858657, 0.36215156, + 0.67538136, 1.540389, 1.9143543, 0.4806893, 0.0123093, + 1.2142435, 0.3535401, 1.0303422, 1.1247561, 0.11206079, + 0.67538136, 0.3535401, 1.0303422, 0.6992101, 1.0161952, + 0.67538136, 0.3535401, 1.0303422, 1.5920583, 0.56292, + 0.67538136, 0.82438636, 1.0303422, -0.50294054, 1.9305013, + 1.4938543, 0.3535401, 1.0303422, 0.6621192, 0.9520401, + 0.67538136, 0.3535401, 2.6008792, -0.3047365, -0.25600532, + 0.67538136, 0.3535401, 1.0303422, -0.50294054, -0.25600532, + 0.67538136, 0.48337674, 1.0303422, -0.50294054, -0.25600532, + 0.67538136, 1.0456259, 1.0303422, 0.5966878, 0.7607826, + 0.9664813, 0.3535401, 1.0303422, 0.38655168, -0.25600532, + ], + }; + testBinary(inputA, inputB, expected, max); + }); + + it('min', function() { + const inputA = { + shape: [3, 4, 5], + data: [ + 0.30360392, 0.79021126, 0.11072686, 1.0779074, -0.02202512, + -0.4660466, 0.5439212, -1.1046865, -0.7237214, 1.7275667, + 0.05005725, 0.03450501, -0.93030375, 0.8889801, 1.6954619, + -0.01362751, -0.276192, 0.05534686, 1.046008, 0.10164198, + 0.5601633, -0.32077986, -0.59266484, -0.39202943, -0.03543149, + -0.311161, -2.6089416, 0.5112193, -1.4783202, -0.8066068, + 0.77324635, 1.5120724, 1.3049824, -0.03303701, 1.201271, + -0.08360443, -1.0856549, -1.268517, -0.77472717, -0.6026987, + -0.37952536, -1.1476341, 0.08269309, 1.0225683, -1.4790517, + 1.9010514, -0.8733177, -0.08186013, 1.1718949, -0.01093488, + -0.3274254, 0.73195547, -0.5514492, -0.7521337, -1.0613606, + 0.6751333, 0.9138903, 1.7775172, 0.5034791, 0.00691956, + ], + }; + const inputB = { + shape: [3, 4, 5], + data: [ + -0.3013072, -0.09710764, 0.19347863, 0.57673335, -0.9459303, + -0.311303, -0.51731133, 0.05566696, 0.1896354, -2.4551184, + 0.49731326, -0.505013, 0.38610065, -0.46502006, 0.11969721, + 0.52275103, 0.25405633, -2.177016, 0.36703554, 0.33286744, + -0.49586803, 0.09411436, 0.38203833, -1.8008012, 0.4627897, + -0.14300857, 0.26225486, 0.10055642, 1.5006567, -0.04743041, + -0.7460712, -1.3833494, -0.2873905, -1.8731467, -1.006253, + -0.21216351, -1.2171068, 0.1594863, -1.7146875, 0.21852039, + 1.3147641, 0.28219756, -0.84008366, -0.979971, 0.2722022, + 1.1494406, -1.4083267, 0.09631079, -0.04712944, -0.8959271, + 1.2020742, -0.24440259, 0.18198308, -1.3384086, -0.5169678, + -0.6608337, 0.30539933, -1.529869, -0.70533603, -2.1911235, + ], + }; + const expected = { + shape: [3, 4, 5], + data: [ + -0.3013072, -0.09710764, 0.11072686, 0.57673335, -0.9459303, + -0.4660466, -0.51731133, -1.1046865, -0.7237214, -2.4551184, + 0.05005725, -0.505013, -0.93030375, -0.46502006, 0.11969721, + -0.01362751, -0.276192, -2.177016, 0.36703554, 0.10164198, + -0.49586803, -0.32077986, -0.59266484, -1.8008012, -0.03543149, + -0.311161, -2.6089416, 0.10055642, -1.4783202, -0.8066068, + -0.7460712, -1.3833494, -0.2873905, -1.8731467, -1.006253, + -0.21216351, -1.2171068, -1.268517, -1.7146875, -0.6026987, + -0.37952536, -1.1476341, -0.84008366, -0.979971, -1.4790517, + 1.1494406, -1.4083267, -0.08186013, -0.04712944, -0.8959271, + -0.3274254, -0.24440259, -0.5514492, -1.3384086, -1.0613606, + -0.6608337, 0.30539933, -1.529869, -0.70533603, -2.1911235, + ], + }; + testBinary(inputA, inputB, expected, min); + }); + + it('min broadcast', function() { + const inputA = { + shape: [3, 4, 5], + data: [ + 0.09259097, -1.2761278, 0.63461846, 0.83395857, -0.6424096, + -0.10002025, 0.2483844, 1.324728, 0.7070375, -0.24927127, + -1.1588863, 0.05159701, -0.27449006, 1.3718864, -0.2961051, + -0.21801688, 0.4596571, -0.2982913, -2.4248464, 0.25273538, + 0.04604488, -0.87013924, 1.554572, 0.41449285, -0.68581927, + 0.21872331, 0.5650471, -1.3366132, -0.34167227, 1.4196033, + -0.9094157, 0.5909053, 0.20646141, 0.23326884, 0.27068487, + -0.2444074, 0.44961262, -1.3790505, -1.4981223, 1.9089019, + 0.6859794, -1.6197531, -0.85252583, 0.3867299, 0.9107394, + 0.63347656, -2.0192556, 0.49276412, 0.5069547, 0.14318226, + -0.5055633, -1.2882828, 0.00957129, 0.41766334, -0.53743577, + 0.3123349, 0.04377401, -0.26201916, -1.6016098, -0.74272215, + ], + }; + const inputB = { + shape: [5], + data: [ + 0.6450575, + -1.302236, + 0.27485028, + 1.8353013, + -0.83993983, + ], + }; + const expected = { + shape: [3, 4, 5], + data: [ + 0.09259097, -1.302236, 0.27485028, 0.83395857, -0.83993983, + -0.10002025, -1.302236, 0.27485028, 0.7070375, -0.83993983, + -1.1588863, -1.302236, -0.27449006, 1.3718864, -0.83993983, + -0.21801688, -1.302236, -0.2982913, -2.4248464, -0.83993983, + 0.04604488, -1.302236, 0.27485028, 0.41449285, -0.83993983, + 0.21872331, -1.302236, -1.3366132, -0.34167227, -0.83993983, + -0.9094157, -1.302236, 0.20646141, 0.23326884, -0.83993983, + -0.2444074, -1.302236, -1.3790505, -1.4981223, -0.83993983, + 0.6450575, -1.6197531, -0.85252583, 0.3867299, -0.83993983, + 0.63347656, -2.0192556, 0.27485028, 0.5069547, -0.83993983, + -0.5055633, -1.302236, 0.00957129, 0.41766334, -0.83993983, + 0.3123349, -1.302236, -0.26201916, -1.6016098, -0.83993983, + ], + }; + testBinary(inputA, inputB, expected, min); + }); + + it('sqrt 1dx1d', function() { + const inputA = { + shape: [3], + data: [1, 4, 9], + }; + const inputB = { + shape: [1], + data: [0.5], + }; + const expected = { + shape: [3], + data: [1, 2, 3], + }; + testBinary(inputA, inputB, expected, pow); + }); + + it('pow 3dx1d', function() { + const inputA = { + shape: [3, 4, 5], + data: [ + 0.33435354, 0.57139647, 0.03689031, 0.7820907, 0.7718887, + 0.17709309, 1.05624, 2.2693596, 1.0328789, 1.6043026, + 2.0692635, 1.7839943, 1.4888871, 0.57544494, 0.2760935, + 0.25916228, 0.24607088, 0.75507194, 0.9365655, 0.66641825, + 0.1919839, 0.42336762, 1.1776822, 1.8486708, 0.7361624, + 0.28052628, 0.261271, 1.0593715, 0.54762685, 0.61064255, + 0.6917134, 0.3692974, 0.01287235, 0.6559981, 0.32968605, + 1.9361054, 1.5982035, 0.49353063, 0.28142217, 0.55740887, + 0.43017766, 2.6145968, 0.4801058, 0.7487864, 1.0473998, + 0.11505236, 0.24899477, 0.21978393, 0.21973193, 0.6550839, + 0.7919175, 0.21990986, 0.2881369, 0.5660939, 0.54675615, + 0.70638055, 0.82219034, 0.6266006, 0.89149487, 0.36557788, + ], + }; + const inputB = { + shape: [1], + data: [0.5], + }; + const expected = { + shape: [3, 4, 5], + data: [ + 0.5782331190791479, + 0.755907712621058, + 0.19206850340438436, + 0.88435892034852, + 0.87857196631807, + 0.4208242982528457, + 1.0277353745006543, + 1.5064393781364054, + 1.0163064990444566, + 1.2666106742010348, + 1.4384934827798144, + 1.3356624947942501, + 1.2201996148171823, + 0.7585808724190191, + 0.5254460010315046, + 0.5090798365679002, + 0.4960553194957191, + 0.8689487556812542, + 0.967763142509571, + 0.8163444432345944, + 0.43815967409153483, + 0.6506670577184617, + 1.085210670791621, + 1.3596583394367867, + 0.8579990675985609, + 0.5296473166173884, + 0.511146749965213, + 1.0292577422589542, + 0.7400181416695134, + 0.7814362098085806, + 0.8316930924325391, + 0.6076984449544034, + 0.11345637928296495, + 0.8099370963229182, + 0.5741829412304061, + 1.3914400454205706, + 1.2642007356428804, + 0.7025173520988646, + 0.5304923844882224, + 0.7465981984976926, + 0.6558793029208957, + 1.6169714901630146, + 0.6928966733936598, + 0.8653244478228961, + 1.0234255224489959, + 0.3391936909790629, + 0.49899375747598285, + 0.46881118800642974, + 0.468755725298369, + 0.8093725347452804, + 0.8898974660038088, + 0.46894547657483593, + 0.5367838484902465, + 0.7523921185126808, + 0.7394296112545129, + 0.8404644846749921, + 0.9067471202049665, + 0.7915810760749653, + 0.944190060316248, + 0.6046303664223291, + ], + }; + testBinary(inputA, inputB, expected, pow); + }); + + it('pow 1dx1d', function() { + const inputA = { + shape: [3], + data: [1, 2, 3], + }; + const inputB = { + shape: [3], + data: [4, 5, 6], + }; + const expected = { + shape: [3], + data: [1., 32., 729.], + }; + testBinary(inputA, inputB, expected, pow); + }); + + it('pow 1dx1d', function() { + const inputA = { + shape: [3], + data: [1, 2, 3], + }; + const inputB = { + shape: [3], + data: [4, 5, 6], + }; + const expected = { + shape: [3], + data: [1., 32., 729.], + }; + testBinary(inputA, inputB, expected, pow); + }); + + it('pow 1dx1d broadcast', function() { + const inputA = { + shape: [3], + data: [1, 2, 3], + }; + const inputB = { + shape: [1], + data: [2], + }; + const expected = { + shape: [3], + data: [1., 4., 9.], + }; + testBinary(inputA, inputB, expected, pow); + }); + + it('pow 2dx1d broadcast', function() { + const inputA = { + shape: [2, 3], + data: [1, 2, 3, 4, 5, 6], + }; + const inputB = { + shape: [3], + data: [1, 2, 3], + }; + const expected = { + shape: [2, 3], + data: [1., 4., 27., 4., 25., 216.], + }; + testBinary(inputA, inputB, expected, pow); + }); +}); diff --git a/test/brodcast_test.js b/test/brodcast_test.js new file mode 100644 index 0000000..8aabc92 --- /dev/null +++ b/test/brodcast_test.js @@ -0,0 +1,105 @@ +'use strict'; + +import {broadcast} from '../src/lib/broadcast.js'; +import {Tensor, sizeOfShape} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test broadcast', function() { + it('broadcast [1] to [3, 4, 5]', function() { + const inputShape = [1]; + const inputData = [0.6338172]; + const inputTensor = new Tensor(inputShape, inputData); + const expectedShape = [3, 4, 5]; + const expectedData = new Array(sizeOfShape(expectedShape)).fill(0.6338172); + const outputTensor = broadcast(inputTensor, [3, 4, 5]); + utils.checkShape(outputTensor, expectedShape); + utils.checkValue(outputTensor, expectedData); + }); + + it('broadcast [5] to [3, 4, 5]', function() { + const inputShape = [5]; + const inputData = [0.6338172, 1.630534, -1.3819867, -1.0427561, 1.058136]; + const inputTensor = new Tensor(inputShape, inputData); + const expectedShape = [3, 4, 5]; + const expectedData = [ + 0.6338172, + 1.630534, + -1.3819867, + -1.0427561, + 1.058136, + 0.6338172, + 1.630534, + -1.3819867, + -1.0427561, + 1.058136, + 0.6338172, + 1.630534, + -1.3819867, + -1.0427561, + 1.058136, + 0.6338172, + 1.630534, + -1.3819867, + -1.0427561, + 1.058136, + 0.6338172, + 1.630534, + -1.3819867, + -1.0427561, + 1.058136, + 0.6338172, + 1.630534, + -1.3819867, + -1.0427561, + 1.058136, + 0.6338172, + 1.630534, + -1.3819867, + -1.0427561, + 1.058136, + 0.6338172, + 1.630534, + -1.3819867, + -1.0427561, + 1.058136, + 0.6338172, + 1.630534, + -1.3819867, + -1.0427561, + 1.058136, + 0.6338172, + 1.630534, + -1.3819867, + -1.0427561, + 1.058136, + 0.6338172, + 1.630534, + -1.3819867, + -1.0427561, + 1.058136, + 0.6338172, + 1.630534, + -1.3819867, + -1.0427561, + 1.058136, + ]; + const outputTensor = broadcast(inputTensor, [3, 4, 5]); + utils.checkShape(outputTensor, expectedShape); + utils.checkValue(outputTensor, expectedData); + }); + + it('broadcast [2, 1, 2] to [2, 2, 2]', function() { + const inputShape = [2, 1, 2]; + const inputData = [ + 0.8189771771430969, 0.9455667734146118, 0.8828932046890259, 0.3519825041294098]; + const inputTensor = new Tensor(inputShape, inputData); + const expectedShape = [2, 2, 2]; + const expectedData = [ + 0.8189771771430969, 0.9455667734146118, 0.8189771771430969, 0.9455667734146118, + 0.8828932046890259, 0.3519825041294098, 0.8828932046890259, 0.3519825041294098, + ]; + const outputTensor = broadcast(inputTensor, [2, 2, 2]); + utils.checkShape(outputTensor, expectedShape); + utils.checkValue(outputTensor, expectedData); + }); +}); diff --git a/test/clamp_test.js b/test/clamp_test.js new file mode 100644 index 0000000..7e1779b --- /dev/null +++ b/test/clamp_test.js @@ -0,0 +1,117 @@ +'use strict'; + +import {clamp} from '../src/clamp.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test clamp', function() { + function testClamp(inputShape, inputValue, expected, options = {}) { + const inputTensor = new Tensor(inputShape, inputValue); + const outputTensor = clamp(inputTensor, options); + utils.checkValue(outputTensor, expected); + } + + it('clamp', function() { + testClamp([3], [-2, 0, 2], [-1, 0, 1], {minValue: -1, maxValue: 1}); + testClamp([3], [-1, 0, 1], [-1, 0, 1], {minValue: -5, maxValue: 5}); + testClamp([3], [-6, 0, 6], [-5, 0, 5], {minValue: -5, maxValue: 5}); + testClamp([3], [-1, 0, 6], [-1, 0, 5], {minValue: -5, maxValue: 5}); + testClamp( + [3, 4, 5], + [ + 0.58585083, 1.1363881, 0.67161655, -0.9741674, -1.6196846, + 0.572627, 1.9026182, -0.7756641, -0.18808974, -1.0357478, + 1.1778295, -2.305167, -2.2636602, 0.3750199, -0.08234365, + -0.47962302, -0.3010948, 0.5369879, -0.413804, -1.096925, + -0.9273629, 0.88833886, -0.52474195, -1.3852776, 0.10217833, + 0.50499475, 1.3289608, 0.21790339, -0.65971124, 0.47400787, + 0.7271749, -0.03890531, -0.04459939, 0.2601329, -0.06985649, + 0.2501139, -1.0219133, -1.1504377, -0.83611137, 0.64221096, + 0.25879756, 1.040239, -0.18669093, -1.1436414, 1.1445535, + -0.01876706, 1.283455, 0.59794647, 2.1886187, -0.21977298, + 0.90072393, 0.8913641, -0.55512637, -0.17248231, -1.4617383, + -1.5487962, 0.1265688, 0.7930071, 0.63802403, 0.3400246, + ], + [ + 0.58585083, 1., 0.67161655, -0.9741674, -1., + 0.572627, 1., -0.7756641, -0.18808974, -1., + 1., -1., -1., 0.3750199, -0.08234365, + -0.47962302, -0.3010948, 0.5369879, -0.413804, -1., + -0.9273629, 0.88833886, -0.52474195, -1., 0.10217833, + 0.50499475, 1., 0.21790339, -0.65971124, 0.47400787, + 0.7271749, -0.03890531, -0.04459939, 0.2601329, -0.06985649, + 0.2501139, -1., -1., -0.83611137, 0.64221096, + 0.25879756, 1., -0.18669093, -1., 1., + -0.01876706, 1., 0.59794647, 1., -0.21977298, + 0.90072393, 0.8913641, -0.55512637, -0.17248231, -1., + -1., 0.1265688, 0.7930071, 0.63802403, 0.3400246, + ], + {minValue: -1, maxValue: 1}); + }); + + it('clamp with defaults', function() { + testClamp([3], [-1, 0, 1], [-1, 0, 1]); + testClamp( + [3, 4, 5], + [ + 0.86301714, -0.5896978, -0.27253276, 0.7375215, 0.43311873, + -0.21018882, 1.3207943, -1.2920012, -0.51867867, -0.28339776, + 0.8165349, 0.0023852, -1.2614918, 0.5140042, 1.0875463, + 0.73930454, 0.61915493, -1.8743135, -0.8998865, 0.4820806, + -0.05488819, 0.5225576, -1.2663426, -0.06149476, -1.389781, + -1.9536786, 0.29577908, 0.8425888, 0.24561642, -0.03299648, + -1.5620143, 1.0061071, -0.0440449, 1.9595621, 0.9423143, + -2.0051255, 0.7550497, -1.3965353, -0.7594955, -0.25075668, + -0.09406245, 0.39756522, -1.022855, -1.150692, 0.6006052, + -0.01325027, 0.17437305, -2.1936834, -0.17713739, -0.8907292, + -0.9206264, 0.9219348, -1.0956712, -1.0928966, -0.3310106, + 0.45028883, -0.8840147, 1.2341441, 1.4498476, -0.8814471, + ], + [ + 0.86301714, 0., 0., 0.7375215, 0.43311873, + 0., 1.3207943, 0., 0., 0., + 0.8165349, 0.0023852, 0., 0.5140042, 1.0875463, + 0.73930454, 0.61915493, 0., 0., 0.4820806, + 0., 0.5225576, 0., 0., 0., + 0., 0.29577908, 0.8425888, 0.24561642, 0., + 0., 1.0061071, 0., 1.9595621, 0.9423143, + 0., 0.7550497, 0., 0., 0., + 0., 0.39756522, 0., 0., 0.6006052, + 0., 0.17437305, 0., 0., 0., + 0., 0.9219348, 0., 0., 0., + 0.45028883, 0., 1.2341441, 1.4498476, 0., + ], + {minValue: 0}); + testClamp( + [3, 4, 5], + [ + -0.24508175, -0.7786755, -1.6853821, 0.30301106, 0.7335949, + 2.0118642, -0.8974095, 1.336235, 1.3423537, 0.19785331, + 0.6021635, 0.8732731, 1.9741, 0.47780856, -0.06013789, + -0.8661688, 0.30532077, 1.0241649, 0.24461035, -0.77992326, + 0.08907621, -0.12915348, 0.26473877, -1.6618484, 0.55078864, + 0.59542316, 0.44485343, -0.00376282, -1.8059362, -0.01932279, + 1.060715, -0.8601289, -1.9892695, -1.540558, 0.3140257, + 0.37287602, 0.8862932, -0.055259, -1.5003284, -0.81850415, + 0.8188394, 0.14049591, 0.6498296, 0.4347888, -0.20496055, + -0.17400683, 1.8571023, 0.41467425, -0.12858754, 0.45542, + 0.22290581, -2.1573563, 0.6500845, 1.8209393, -0.7802799, + 1.4540358, -0.2568697, 0.2934714, 1.0703601, -0.72000146, + ], + [ + -0.24508175, -0.7786755, -1.6853821, 0., 0., + 0., -0.8974095, 0., 0., 0., + 0., 0., 0., 0., -0.06013789, + -0.8661688, 0., 0., 0., -0.77992326, + 0., -0.12915348, 0., -1.6618484, 0., + 0., 0., -0.00376282, -1.8059362, -0.01932279, + 0., -0.8601289, -1.9892695, -1.540558, 0., + 0., 0., -0.055259, -1.5003284, -0.81850415, + 0., 0., 0., 0., -0.20496055, + -0.17400683, 0., 0., -0.12858754, 0., + 0., -2.1573563, 0., 0., -0.7802799, + 0., -0.2568697, 0., 0., -0.72000146, + ], + {maxValue: 0}); + }); +}); diff --git a/test/concat_test.js b/test/concat_test.js new file mode 100644 index 0000000..4ac44e7 --- /dev/null +++ b/test/concat_test.js @@ -0,0 +1,91 @@ +'use strict'; + +import {concat} from '../src/concat.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test concat', function() { + function testConcat(tensors, expected) { + const inputs = []; + for (let i = 0; i < tensors.length; i++) { + inputs.push(new Tensor(tensors[i].shape, tensors[i].value)); + } + const output = concat(inputs, expected.axis); + utils.checkShape(output, expected.shape); + utils.checkValue(output, expected.value); + } + + it('concat 1d', function() { + const tensors = [ + {shape: [2], value: [1, 2]}, + {shape: [2], value: [3, 4]}, + ]; + const expected = {axis: 0, shape: [4], value: [1, 2, 3, 4]}; + testConcat(tensors, expected); + }); + + it('concat 2d axis=0', function() { + const tensors = [ + {shape: [1, 2], value: [1, 2]}, + {shape: [2, 2], value: [3, 4, 5, 6]}, + ]; + const expected = {axis: 0, shape: [3, 2], value: [1, 2, 3, 4, 5, 6]}; + testConcat(tensors, expected); + }); + + it('concat 2d axis=1', function() { + const tensors = [ + {shape: [2, 1], value: [1, 2]}, + {shape: [2, 2], value: [3, 4, 5, 6]}, + ]; + const expected = {axis: 1, shape: [2, 3], value: [1, 3, 4, 2, 5, 6]}; + testConcat(tensors, expected); + }); + + it('concat 2d', function() { + const tensors = [ + {shape: [2, 2], value: [1, 2, 3, 4]}, + {shape: [2, 2], value: [5, 6, 7, 8]}, + ]; + const expected = [ + {axis: 0, shape: [4, 2], value: [1, 2, 3, 4, 5, 6, 7, 8]}, + {axis: 1, shape: [2, 4], value: [1, 2, 5, 6, 3, 4, 7, 8]}, + ]; + for (const test of expected) { + testConcat(tensors, test); + } + }); + + it('concat 3d', function() { + const tensors = [ + { + shape: [2, 2, 2], + value: [1, 2, 3, 4, 5, 6, 7, 8], + }, + { + shape: [2, 2, 2], + value: [9, 10, 11, 12, 13, 14, 15, 16], + }, + ]; + const expected = [ + { + axis: 0, + shape: [4, 2, 2], + value: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + }, + { + axis: 1, + shape: [2, 4, 2], + value: [1, 2, 3, 4, 9, 10, 11, 12, 5, 6, 7, 8, 13, 14, 15, 16], + }, + { + axis: 2, + shape: [2, 2, 4], + value: [1, 2, 9, 10, 3, 4, 11, 12, 5, 6, 13, 14, 7, 8, 15, 16], + }, + ]; + for (const test of expected) { + testConcat(tensors, test); + } + }); +}); diff --git a/test/conv2d_test.js b/test/conv2d_test.js new file mode 100644 index 0000000..5c06248 --- /dev/null +++ b/test/conv2d_test.js @@ -0,0 +1,3249 @@ +'use strict'; + +import {conv2d} from '../src/conv2d.js'; +import {clamp} from '../src/clamp.js'; +import {leakyRelu} from '../src/leaky_relu.js'; +import {relu} from '../src/relu.js'; +import {sigmoid} from '../src/sigmoid.js'; +import {Tensor} from '../src/lib/tensor.js'; + +import * as utils from './utils.js'; + +describe('test conv2d', function() { + function testConv2d( + input, filter, expected, options = {}, bias = undefined, + activation = undefined, fusion = false, activationOptions = {}) { + const inputTensor = new Tensor(input.shape, input.data); + const filterTensor = new Tensor(filter.shape, filter.data); + if (bias) { + options.bias = new Tensor(bias.shape, bias.data); + } + if (activation === 'relu') { + options.activation = relu; + } else if (activation === 'relu6') { + options.activation = utils.bindTrailingArgs(clamp, {minValue: 0, maxValue: 6}); + } else if (activation === 'sigmoid') { + options.activation = sigmoid; + } else if (activation === 'leakyRelu') { + options.activation = utils.bindTrailingArgs(leakyRelu, activationOptions); + } + + const outputTensor = conv2d(inputTensor, filterTensor, options); + utils.checkShape(outputTensor, expected.shape); + utils.checkValue(outputTensor, expected.data); + } + + it('conv2d with padding default', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 1, 3, 3], + data: new Array(9).fill(1), + }; + const options = {padding: [1, 1, 1, 1]}; + const expected = { + shape: [1, 1, 5, 5], + data: [ + 12, 21, 27, 33, 24, 33, 54, 63, 72, 51, 63, 99, 108, + 117, 81, 93, 144, 153, 162, 111, 72, 111, 117, 123, 84, + ], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d input channels with padding default', function() { + const input = { + shape: [1, 3, 2, 3], + data: [ + 1.0, 4.0, 7.0, 10.0, 13.0, 16.0, + 2.0, 5.0, 8.0, 11.0, 14.0, 17.0, + 3.0, 6.0, 9.0, 12.0, 15.0, 18.0, + ], + }; + const filter = { + shape: [3, 3, 1, 1], + data: [1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0], + }; + const options = {}; + const expected = { + shape: [1, 3, 2, 3], + data: [ + 30.0, 66.0, 102.0, 138.0, 174.0, 210.0, + 36.0, 81.0, 126.0, 171.0, 216.0, 261.0, + 42.0, 96.0, 150.0, 204.0, 258.0, 312.0, + ], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with padding explicit autoPad', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 1, 3, 3], + data: new Array(9).fill(1), + }; + const options = { + padding: [1, 1, 1, 1], + autoPad: 'explicit', + }; + const expected = { + shape: [1, 1, 5, 5], + data: [ + 12, 21, 27, 33, 24, 33, 54, 63, 72, 51, 63, 99, 108, + 117, 81, 93, 144, 153, 162, 111, 72, 111, 117, 123, 84, + ], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with padding nchw oihw', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 1, 3, 3], + data: new Array(9).fill(1), + }; + const options = { + padding: [1, 1, 1, 1], + inputLayout: 'nchw', + filterLayout: 'oihw', + }; + const expected = { + shape: [1, 1, 5, 5], + data: [ + 12, 21, 27, 33, 24, 33, 54, 63, 72, 51, 63, 99, 108, + 117, 81, 93, 144, 153, 162, 111, 72, 111, 117, 123, 84, + ], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with padding nchw hwio', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [3, 3, 1, 1], + data: new Array(9).fill(1), + }; + const options = { + padding: [1, 1, 1, 1], + inputLayout: 'nchw', + filterLayout: 'hwio', + }; + const expected = { + shape: [1, 1, 5, 5], + data: [ + 12, 21, 27, 33, 24, 33, 54, 63, 72, 51, 63, 99, 108, + 117, 81, 93, 144, 153, 162, 111, 72, 111, 117, 123, 84, + ], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with padding nchw ohwi', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const options = { + padding: [1, 1, 1, 1], + inputLayout: 'nchw', + filterLayout: 'ohwi', + }; + const expected = { + shape: [1, 1, 5, 5], + data: [ + 12, 21, 27, 33, 24, 33, 54, 63, 72, 51, 63, 99, 108, + 117, 81, 93, 144, 153, 162, 111, 72, 111, 117, 123, 84, + ], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with padding nchw ihwo', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const options = { + padding: [1, 1, 1, 1], + inputLayout: 'nchw', + filterLayout: 'ihwo', + }; + const expected = { + shape: [1, 1, 5, 5], + data: [ + 12, 21, 27, 33, 24, 33, 54, 63, 72, 51, 63, 99, 108, + 117, 81, 93, 144, 153, 162, 111, 72, 111, 117, 123, 84, + ], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with padding nhwc oihw', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 1, 3, 3], + data: new Array(9).fill(1), + }; + const options = { + padding: [1, 1, 1, 1], + inputLayout: 'nhwc', + filterLayout: 'oihw', + }; + const expected = { + shape: [1, 5, 5, 1], + data: [ + 12, 21, 27, 33, 24, 33, 54, 63, 72, 51, 63, 99, 108, + 117, 81, 93, 144, 153, 162, 111, 72, 111, 117, 123, 84, + ], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with padding nhwc hwio', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [3, 3, 1, 1], + data: new Array(9).fill(1), + }; + const options = { + padding: [1, 1, 1, 1], + inputLayout: 'nhwc', + filterLayout: 'hwio', + }; + const expected = { + shape: [1, 5, 5, 1], + data: [ + 12, 21, 27, 33, 24, 33, 54, 63, 72, 51, 63, 99, 108, + 117, 81, 93, 144, 153, 162, 111, 72, 111, 117, 123, 84, + ], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with padding nhwc ohwi', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const options = { + padding: [1, 1, 1, 1], + inputLayout: 'nhwc', + filterLayout: 'ohwi', + }; + const expected = { + shape: [1, 5, 5, 1], + data: [ + 12, 21, 27, 33, 24, 33, 54, 63, 72, 51, 63, 99, 108, + 117, 81, 93, 144, 153, 162, 111, 72, 111, 117, 123, 84, + ], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with padding nhwc ihwo', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const options = { + padding: [1, 1, 1, 1], + inputLayout: 'nhwc', + filterLayout: 'ihwo', + }; + const expected = { + shape: [1, 5, 5, 1], + data: [ + 12, 21, 27, 33, 24, 33, 54, 63, 72, 51, 63, 99, 108, + 117, 81, 93, 144, 153, 162, 111, 72, 111, 117, 123, 84, + ], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d without padding default', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 1, 3, 3], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 1, 3, 3], + data: [54., 63., 72., 99., 108., 117., 144., 153., 162.], + }; + testConv2d(input, filter, expected); + }); + + it('conv2d without padding nchw hwio', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [3, 3, 1, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 1, 3, 3], + data: [54., 63., 72., 99., 108., 117., 144., 153., 162.], + }; + const options = { + inputLayout: 'nchw', + filterLayout: 'hwio', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d without padding nchw ohwi', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 1, 3, 3], + data: [54., 63., 72., 99., 108., 117., 144., 153., 162.], + }; + const options = { + inputLayout: 'nchw', + filterLayout: 'ohwi', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d without padding nchw ihwo', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 1, 3, 3], + data: [54., 63., 72., 99., 108., 117., 144., 153., 162.], + }; + const options = { + inputLayout: 'nchw', + filterLayout: 'ihwo', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d without padding nhwc oihw', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 1, 3, 3], + data: new Array(9).fill(1), + }; + const options = { + inputLayout: 'nhwc', + filterLayout: 'oihw', + }; + const expected = { + shape: [1, 3, 3, 1], + data: [54., 63., 72., 99., 108., 117., 144., 153., 162.], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d without padding nhwc hwio', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [3, 3, 1, 1], + data: new Array(9).fill(1), + }; + const options = { + inputLayout: 'nhwc', + filterLayout: 'hwio', + }; + const expected = { + shape: [1, 3, 3, 1], + data: [54., 63., 72., 99., 108., 117., 144., 153., 162.], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d without padding nhwc ohwi', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const options = { + inputLayout: 'nhwc', + filterLayout: 'ohwi', + }; + const expected = { + shape: [1, 3, 3, 1], + data: [54., 63., 72., 99., 108., 117., 144., 153., 162.], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d without padding nhwc ihwo', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const options = { + inputLayout: 'nhwc', + filterLayout: 'ihwo', + }; + const expected = { + shape: [1, 3, 3, 1], + data: [54., 63., 72., 99., 108., 117., 144., 153., 162.], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with strides=2 and padding default', function() { + const input = { + shape: [1, 1, 7, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + ], + }; + const filter = { + shape: [1, 1, 3, 3], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 1, 4, 3], + data: [12., 27., 24., 63., 108., 81., 123., 198., 141., 112., 177., 124.], + }; + const options = { + padding: [1, 1, 1, 1], + strides: [2, 2], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with strides=2 and padding nchw hwio', function() { + const input = { + shape: [1, 1, 7, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + ], + }; + const filter = { + shape: [3, 3, 1, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 1, 4, 3], + data: [12., 27., 24., 63., 108., 81., 123., 198., 141., 112., 177., 124.], + }; + const options = { + padding: [1, 1, 1, 1], + strides: [2, 2], + inputLayout: 'nchw', + filterLayout: 'hwio', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with strides=2 and padding nchw ohwi', function() { + const input = { + shape: [1, 1, 7, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 1, 4, 3], + data: [12., 27., 24., 63., 108., 81., 123., 198., 141., 112., 177., 124.], + }; + const options = { + padding: [1, 1, 1, 1], + strides: [2, 2], + inputLayout: 'nchw', + filterLayout: 'ohwi', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with strides=2 and padding nchw ihwo', function() { + const input = { + shape: [1, 1, 7, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 1, 4, 3], + data: [12., 27., 24., 63., 108., 81., 123., 198., 141., 112., 177., 124.], + }; + const options = { + padding: [1, 1, 1, 1], + strides: [2, 2], + inputLayout: 'nchw', + filterLayout: 'ihwo', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with strides=2 and padding nhwc oihw', function() { + const input = { + shape: [1, 7, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + ], + }; + const filter = { + shape: [1, 1, 3, 3], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 4, 3, 1], + data: [12., 27., 24., 63., 108., 81., 123., 198., 141., 112., 177., 124.], + }; + const options = { + padding: [1, 1, 1, 1], + strides: [2, 2], + inputLayout: 'nhwc', + filterLayout: 'oihw', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with strides=2 and padding nhwc hwio', function() { + const input = { + shape: [1, 7, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + ], + }; + const filter = { + shape: [3, 3, 1, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 4, 3, 1], + data: [12., 27., 24., 63., 108., 81., 123., 198., 141., 112., 177., 124.], + }; + const options = { + padding: [1, 1, 1, 1], + strides: [2, 2], + inputLayout: 'nhwc', + filterLayout: 'hwio', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with strides=2 and padding nhwc ohwi', function() { + const input = { + shape: [1, 7, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 4, 3, 1], + data: [12., 27., 24., 63., 108., 81., 123., 198., 141., 112., 177., 124.], + }; + const options = { + padding: [1, 1, 1, 1], + strides: [2, 2], + inputLayout: 'nhwc', + filterLayout: 'ohwi', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with strides=2 and padding nhwc ihwo', function() { + const input = { + shape: [1, 7, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 4, 3, 1], + data: [12., 27., 24., 63., 108., 81., 123., 198., 141., 112., 177., 124.], + }; + const options = { + padding: [1, 1, 1, 1], + strides: [2, 2], + inputLayout: 'nhwc', + filterLayout: 'ihwo', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with strides=2 and asymetric padding default', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 1, 4, 2], + data: new Array(8).fill(1), + }; + const expected = { + shape: [1, 1, 3, 3], + data: [33, 45, 27, 104, 120, 66, 72, 80, 43], + }; + const options = { + padding: [1, 2, 0, 1], + strides: [2, 2], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with strides=2 and asymetric padding nchw hwio', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [4, 2, 1, 1], + data: new Array(8).fill(1), + }; + const expected = { + shape: [1, 1, 3, 3], + data: [33, 45, 27, 104, 120, 66, 72, 80, 43], + }; + const options = { + padding: [1, 2, 0, 1], + strides: [2, 2], + inputLayout: 'nchw', + filterLayout: 'hwio', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with strides=2 and asymetric padding nchw ohwi', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 4, 2, 1], + data: new Array(8).fill(1), + }; + const expected = { + shape: [1, 1, 3, 3], + data: [33, 45, 27, 104, 120, 66, 72, 80, 43], + }; + const options = { + padding: [1, 2, 0, 1], + strides: [2, 2], + inputLayout: 'nchw', + filterLayout: 'ohwi', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with strides=2 and asymetric padding nchw ihwo', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 4, 2, 1], + data: new Array(8).fill(1), + }; + const expected = { + shape: [1, 1, 3, 3], + data: [33, 45, 27, 104, 120, 66, 72, 80, 43], + }; + const options = { + padding: [1, 2, 0, 1], + strides: [2, 2], + inputLayout: 'nchw', + filterLayout: 'ihwo', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with strides=2 and asymetric padding nhwc oihw', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 1, 4, 2], + data: new Array(8).fill(1), + }; + const expected = { + shape: [1, 3, 3, 1], + data: [33, 45, 27, 104, 120, 66, 72, 80, 43], + }; + const options = { + padding: [1, 2, 0, 1], + strides: [2, 2], + inputLayout: 'nhwc', + filterLayout: 'oihw', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with strides=2 and asymetric padding nhwc hwio', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [4, 2, 1, 1], + data: new Array(8).fill(1), + }; + const expected = { + shape: [1, 3, 3, 1], + data: [33, 45, 27, 104, 120, 66, 72, 80, 43], + }; + const options = { + padding: [1, 2, 0, 1], + strides: [2, 2], + inputLayout: 'nhwc', + filterLayout: 'hwio', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with strides=2 and asymetric padding nhwc ohwi', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 4, 2, 1], + data: new Array(8).fill(1), + }; + const expected = { + shape: [1, 3, 3, 1], + data: [33, 45, 27, 104, 120, 66, 72, 80, 43], + }; + const options = { + padding: [1, 2, 0, 1], + strides: [2, 2], + inputLayout: 'nhwc', + filterLayout: 'ohwi', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with strides=2 and asymetric padding nhwc ihwo', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 4, 2, 1], + data: new Array(8).fill(1), + }; + const expected = { + shape: [1, 3, 3, 1], + data: [33, 45, 27, 104, 120, 66, 72, 80, 43], + }; + const options = { + padding: [1, 2, 0, 1], + strides: [2, 2], + inputLayout: 'nhwc', + filterLayout: 'ihwo', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with autopad same-lower default', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 1, 3, 3], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 1, 3, 3], + data: [12., 27., 24., 63., 108., 81., 72., 117., 84.], + }; + const options = { + autoPad: 'same-lower', + strides: [2, 2], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with autopad same-lower nchw hwio', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [3, 3, 1, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 1, 3, 3], + data: [12., 27., 24., 63., 108., 81., 72., 117., 84.], + }; + const options = { + autoPad: 'same-lower', + strides: [2, 2], + inputLayout: 'nchw', + filterLayout: 'hwio', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with autopad same-lower nchw ohwi', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 1, 3, 3], + data: [12., 27., 24., 63., 108., 81., 72., 117., 84.], + }; + const options = { + autoPad: 'same-lower', + strides: [2, 2], + inputLayout: 'nchw', + filterLayout: 'ohwi', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with autopad same-lower nchw ihwo', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 1, 3, 3], + data: [12., 27., 24., 63., 108., 81., 72., 117., 84.], + }; + const options = { + autoPad: 'same-lower', + strides: [2, 2], + inputLayout: 'nchw', + filterLayout: 'ihwo', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with autopad same-lower nhwc oihw', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 1, 3, 3], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 3, 3, 1], + data: [12., 27., 24., 63., 108., 81., 72., 117., 84.], + }; + const options = { + autoPad: 'same-lower', + strides: [2, 2], + inputLayout: 'nhwc', + filterLayout: 'oihw', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with autopad same-lower nhwc hwio', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [3, 3, 1, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 3, 3, 1], + data: [12., 27., 24., 63., 108., 81., 72., 117., 84.], + }; + const options = { + autoPad: 'same-lower', + strides: [2, 2], + inputLayout: 'nhwc', + filterLayout: 'hwio', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with autopad same-lower nhwc ohwi', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 3, 3, 1], + data: [12., 27., 24., 63., 108., 81., 72., 117., 84.], + }; + const options = { + autoPad: 'same-lower', + strides: [2, 2], + inputLayout: 'nhwc', + filterLayout: 'ohwi', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with autopad same-lower nhwc ihwo', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 3, 3, 1], + data: [12., 27., 24., 63., 108., 81., 72., 117., 84.], + }; + const options = { + autoPad: 'same-lower', + strides: [2, 2], + inputLayout: 'nhwc', + filterLayout: 'ihwo', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with autopad same-upper default', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 1, 3, 3], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 1, 3, 3], + data: [12., 27., 24., 63., 108., 81., 72., 117., 84.], + }; + const options = { + autoPad: 'same-upper', + strides: [2, 2], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with autopad same-upper nchw hwio', function() { + const input = { + shape: [1, 1, 4, 4], + data: [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + ], + }; + const filter = { + shape: [3, 3, 1, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 1, 2, 2], + data: [45., 39., 66., 50.], + }; + const options = { + autoPad: 'same-upper', + strides: [2, 2], + inputLayout: 'nchw', + filterLayout: 'hwio', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with autopad same-upper nchw ohwi', function() { + const input = { + shape: [1, 1, 4, 4], + data: [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 1, 2, 2], + data: [45., 39., 66., 50.], + }; + const options = { + autoPad: 'same-upper', + strides: [2, 2], + inputLayout: 'nchw', + filterLayout: 'ohwi', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with autopad same-upper nchw ihwo', function() { + const input = { + shape: [1, 1, 4, 4], + data: [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 1, 2, 2], + data: [45., 39., 66., 50.], + }; + const options = { + autoPad: 'same-upper', + strides: [2, 2], + inputLayout: 'nchw', + filterLayout: 'ihwo', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with autopad same-upper nhwc oihw', function() { + const input = { + shape: [1, 4, 4, 1], + data: [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + ], + }; + const filter = { + shape: [1, 1, 3, 3], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 2, 2, 1], + data: [45., 39., 66., 50.], + }; + const options = { + autoPad: 'same-upper', + strides: [2, 2], + inputLayout: 'nhwc', + filterLayout: 'oihw', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with autopad same-upper nhwc hwio', function() { + const input = { + shape: [1, 4, 4, 1], + data: [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + ], + }; + const filter = { + shape: [3, 3, 1, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 2, 2, 1], + data: [45., 39., 66., 50.], + }; + const options = { + autoPad: 'same-upper', + strides: [2, 2], + inputLayout: 'nhwc', + filterLayout: 'hwio', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with autopad same-upper nhwc ohwi', function() { + const input = { + shape: [1, 4, 4, 1], + data: [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 2, 2, 1], + data: [45., 39., 66., 50.], + }; + const options = { + autoPad: 'same-upper', + strides: [2, 2], + inputLayout: 'nhwc', + filterLayout: 'ohwi', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d with autopad same-upper nhwc ihwo', function() { + const input = { + shape: [1, 4, 4, 1], + data: [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 2, 2, 1], + data: [45., 39., 66., 50.], + }; + const options = { + autoPad: 'same-upper', + strides: [2, 2], + inputLayout: 'nhwc', + filterLayout: 'ihwo', + }; + testConv2d(input, filter, expected, options); + }); + + it('fused depthwise conv2d default', function() { + // It is based on Android NNAPI CTS: V1_2/depthwise_conv2d_v1_2.mod.py + const input = { + shape: [1, 4, 2, 2], + data: [ + 10, + 10, + 10, + 10, + 21, + 22, + 23, + 24, + 10, + 20, + 30, + 40, + 0, + 0, + 0, + 0, + ], + }; + const filter = { + shape: [4, 1, 2, 2], + data: [ + 0.25, + 0.25, + 0.25, + 0.25, + 0.0, + 1.0, + 0.0, + 1.0, + 10.0, + 20.0, + 30.0, + 40.0, + 50.0, + 50.0, + 50.0, + 50.0, + ], + }; + const bias = { + shape: [4], + data: [6000, 7000, 8000, 9000], + }; + let expected = { + shape: [1, 4, 1, 1], + data: [6010, 7046, 11000, 9000], + }; + const options = {groups: 4}; + testConv2d(input, filter, expected, options, bias); + testConv2d(input, filter, expected, options, bias, 'relu', true); + expected = { + shape: [1, 4, 1, 1], + data: [6, 6, 6, 6], + }; + testConv2d(input, filter, expected, options, bias, 'relu6', true); + }); + + it('fused depthwise conv2d nchw hwio', function() { + // It is based on Android NNAPI CTS: V1_2/depthwise_conv2d_v1_2.mod.py + const input = { + shape: [1, 4, 2, 2], + data: [ + 10, + 10, + 10, + 10, + 21, + 22, + 23, + 24, + 10, + 20, + 30, + 40, + 0, + 0, + 0, + 0, + ], + }; + const filter = { + shape: [2, 2, 1, 4], + data: [ + 0.25, + 0.0, + 10.0, + 50.0, + 0.25, + 1.0, + 20.0, + 50.0, + 0.25, + 0.0, + 30.0, + 50.0, + 0.25, + 1.0, + 40.0, + 50.0, + ], + }; + const bias = { + shape: [4], + data: [6000, 7000, 8000, 9000], + }; + let expected = { + shape: [1, 4, 1, 1], + data: [6010, 7046, 11000, 9000], + }; + const options = { + groups: 4, + inputLayout: 'nchw', + filterLayout: 'hwio', + }; + testConv2d(input, filter, expected, options, bias); + testConv2d(input, filter, expected, options, bias, 'relu', true); + expected = { + shape: [1, 4, 1, 1], + data: [6, 6, 6, 6], + }; + testConv2d(input, filter, expected, options, bias, 'relu6', true); + }); + + it('fused depthwise conv2d nchw ohwi', function() { + // It is based on Android NNAPI CTS: V1_2/depthwise_conv2d_v1_2.mod.py + const input = { + shape: [1, 4, 2, 2], + data: [ + 10, + 10, + 10, + 10, + 21, + 22, + 23, + 24, + 10, + 20, + 30, + 40, + 0, + 0, + 0, + 0, + ], + }; + const filter = { + shape: [4, 2, 2, 1], + data: [ + 0.25, + 0.25, + 0.25, + 0.25, + 0.0, + 1.0, + 0.0, + 1.0, + 10.0, + 20.0, + 30.0, + 40.0, + 50.0, + 50.0, + 50.0, + 50.0, + ], + }; + const bias = { + shape: [4], + data: [6000, 7000, 8000, 9000], + }; + let expected = { + shape: [1, 4, 1, 1], + data: [6010, 7046, 11000, 9000], + }; + const options = { + groups: 4, + inputLayout: 'nchw', + filterLayout: 'ohwi', + }; + testConv2d(input, filter, expected, options, bias); + testConv2d(input, filter, expected, options, bias, 'relu', true); + expected = { + shape: [1, 4, 1, 1], + data: [6, 6, 6, 6], + }; + testConv2d(input, filter, expected, options, bias, 'relu6', true); + }); + + it('fused depthwise conv2d nchw ihwo', function() { + // It is based on Android NNAPI CTS: V1_2/depthwise_conv2d_v1_2.mod.py + const input = { + shape: [1, 4, 2, 2], + data: [ + 10, + 10, + 10, + 10, + 21, + 22, + 23, + 24, + 10, + 20, + 30, + 40, + 0, + 0, + 0, + 0, + ], + }; + const filter = { + shape: [1, 2, 2, 4], + data: [ + 0.25, + 0.0, + 10.0, + 50.0, + 0.25, + 1.0, + 20.0, + 50.0, + 0.25, + 0.0, + 30.0, + 50.0, + 0.25, + 1.0, + 40.0, + 50.0, + ], + }; + const bias = { + shape: [4], + data: [6000, 7000, 8000, 9000], + }; + let expected = { + shape: [1, 4, 1, 1], + data: [6010, 7046, 11000, 9000], + }; + const options = { + groups: 4, + inputLayout: 'nchw', + filterLayout: 'ihwo', + }; + testConv2d(input, filter, expected, options, bias); + testConv2d(input, filter, expected, options, bias, 'relu', true); + expected = { + shape: [1, 4, 1, 1], + data: [6, 6, 6, 6], + }; + testConv2d(input, filter, expected, options, bias, 'relu6', true); + }); + + it('fused depthwise conv2d nhwc oihw', function() { + // It is based on Android NNAPI CTS: V1_2/depthwise_conv2d_v1_2.mod.py + const input = { + shape: [1, 2, 2, 4], + data: [ + 10, + 21, + 10, + 0, + 10, + 22, + 20, + 0, + 10, + 23, + 30, + 0, + 10, + 24, + 40, + 0, + ], + }; + const filter = { + shape: [4, 1, 2, 2], + data: [ + 0.25, + 0.25, + 0.25, + 0.25, + 0.0, + 1.0, + 0.0, + 1.0, + 10.0, + 20.0, + 30.0, + 40.0, + 50.0, + 50.0, + 50.0, + 50.0, + ], + }; + const bias = { + shape: [4], + data: [6000, 7000, 8000, 9000], + }; + let expected = { + shape: [1, 1, 1, 4], + data: [6010, 7046, 11000, 9000], + }; + const options = { + groups: 4, + inputLayout: 'nhwc', + filterLayout: 'oihw', + }; + testConv2d(input, filter, expected, options, bias); + testConv2d(input, filter, expected, options, bias, 'relu', true); + expected = { + shape: [1, 1, 1, 4], + data: [6, 6, 6, 6], + }; + testConv2d(input, filter, expected, options, bias, 'relu6', true); + }); + + it('fused depthwise conv2d nhwc hwio', function() { + // It is based on Android NNAPI CTS: V1_2/depthwise_conv2d_v1_2.mod.py + const input = { + shape: [1, 2, 2, 4], + data: [ + 10, + 21, + 10, + 0, + 10, + 22, + 20, + 0, + 10, + 23, + 30, + 0, + 10, + 24, + 40, + 0, + ], + }; + const filter = { + shape: [2, 2, 1, 4], + data: [ + 0.25, + 0.0, + 10.0, + 50.0, + 0.25, + 1.0, + 20.0, + 50.0, + 0.25, + 0.0, + 30.0, + 50.0, + 0.25, + 1.0, + 40.0, + 50.0, + ], + }; + const bias = { + shape: [4], + data: [6000, 7000, 8000, 9000], + }; + let expected = { + shape: [1, 1, 1, 4], + data: [6010, 7046, 11000, 9000], + }; + const options = { + groups: 4, + inputLayout: 'nhwc', + filterLayout: 'hwio', + }; + testConv2d(input, filter, expected, options, bias); + testConv2d(input, filter, expected, options, bias, 'relu', true); + expected = { + shape: [1, 1, 1, 4], + data: [6, 6, 6, 6], + }; + testConv2d(input, filter, expected, options, bias, 'relu6', true); + }); + + it('fused depthwise conv2d nhwc ohwi', function() { + // It is based on Android NNAPI CTS: V1_2/depthwise_conv2d_v1_2.mod.py + const input = { + shape: [1, 2, 2, 4], + data: [ + 10, + 21, + 10, + 0, + 10, + 22, + 20, + 0, + 10, + 23, + 30, + 0, + 10, + 24, + 40, + 0, + ], + }; + const filter = { + shape: [4, 2, 2, 1], + data: [ + 0.25, + 0.25, + 0.25, + 0.25, + 0.0, + 1.0, + 0.0, + 1.0, + 10.0, + 20.0, + 30.0, + 40.0, + 50.0, + 50.0, + 50.0, + 50.0, + ], + }; + const bias = { + shape: [4], + data: [6000, 7000, 8000, 9000], + }; + let expected = { + shape: [1, 1, 1, 4], + data: [6010, 7046, 11000, 9000], + }; + const options = { + groups: 4, + inputLayout: 'nhwc', + filterLayout: 'ohwi', + }; + testConv2d(input, filter, expected, options, bias); + testConv2d(input, filter, expected, options, bias, 'relu', true); + expected = { + shape: [1, 1, 1, 4], + data: [6, 6, 6, 6], + }; + testConv2d(input, filter, expected, options, bias, 'relu6', true); + }); + + it('fused depthwise conv2d nhwc ihwo', function() { + // It is based on Android NNAPI CTS: V1_2/depthwise_conv2d_v1_2.mod.py + const input = { + shape: [1, 2, 2, 4], + data: [ + 10, + 21, + 10, + 0, + 10, + 22, + 20, + 0, + 10, + 23, + 30, + 0, + 10, + 24, + 40, + 0, + ], + }; + const filter = { + shape: [1, 2, 2, 4], + data: [ + 0.25, + 0.0, + 10.0, + 50.0, + 0.25, + 1.0, + 20.0, + 50.0, + 0.25, + 0.0, + 30.0, + 50.0, + 0.25, + 1.0, + 40.0, + 50.0, + ], + }; + const bias = { + shape: [4], + data: [6000, 7000, 8000, 9000], + }; + let expected = { + shape: [1, 1, 1, 4], + data: [6010, 7046, 11000, 9000], + }; + const options = { + groups: 4, + inputLayout: 'nhwc', + filterLayout: 'ihwo', + }; + testConv2d(input, filter, expected, options, bias); + testConv2d(input, filter, expected, options, bias, 'relu', true); + expected = { + shape: [1, 1, 1, 4], + data: [6, 6, 6, 6], + }; + testConv2d(input, filter, expected, options, bias, 'relu6', true); + }); + + it('depthwise conv2d nchw oihw', function() { + const input = { + shape: [1, 4, 2, 2], + data: [ + 10, + 10, + 10, + 10, + 21, + 22, + 23, + 24, + 10, + 20, + 30, + 40, + 0, + 0, + 0, + 0, + ], + }; + const filter = { + shape: [4, 1, 2, 2], + data: [ + 0.25, + 0.25, + 0.25, + 0.25, + 0.0, + 1.0, + 0.0, + 1.0, + 10.0, + 20.0, + 30.0, + 40.0, + 50.0, + 50.0, + 50.0, + 50.0, + ], + }; + let expected = { + shape: [1, 4, 1, 1], + data: [10, 46, 3000, 0], + }; + const options = { + groups: 4, + inputLayout: 'nchw', + filterLayout: 'oihw', + }; + testConv2d(input, filter, expected, options); + testConv2d(input, filter, expected, options, undefined, 'relu', true); + expected = { + shape: [1, 4, 1, 1], + data: [6, 6, 6, 0], + }; + testConv2d(input, filter, expected, options, undefined, 'relu6', true); + }); + + it('fused depthwise conv2d explicit autoPad', function() { + const input = { + shape: [1, 2, 3, 3], + data: [ + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + ], + }; + const filter = { + shape: [2, 1, 2, 2], + data: [ + 0.25, + 0.25, + 0.25, + 0.25, + 0.0, + 1.0, + 0.0, + 1.0, + ], + }; + let expected = { + shape: [1, 2, 3, 3], + data: [ + 10, + 10, + 5, + 10, + 10, + 5, + 5, + 5, + 2.5, + 47, + 49, + 0, + 53, + 55, + 0, + 28, + 29, + 0, + ], + }; + const options = { + groups: 2, + padding: [0, 1, 0, 1], + autoPad: 'explicit', + }; + testConv2d(input, filter, expected, options); + testConv2d(input, filter, expected, options, undefined, 'relu', true); + expected = { + shape: [1, 2, 3, 3], + data: [ + 6, + 6, + 5, + 6, + 6, + 5, + 5, + 5, + 2.5, + 6, + 6, + 0, + 6, + 6, + 0, + 6, + 6, + 0, + ], + }; + testConv2d(input, filter, expected, options, undefined, 'relu6', true); + }); + + it('fused depthwise conv2d same-upper autoPad', function() { + const input = { + shape: [1, 2, 3, 3], + data: [ + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + ], + }; + const filter = { + shape: [2, 1, 2, 2], + data: [ + 0.25, + 0.25, + 0.25, + 0.25, + 0.0, + 1.0, + 0.0, + 1.0, + ], + }; + let expected = { + shape: [1, 2, 3, 3], + data: [ + 10, + 10, + 5, + 10, + 10, + 5, + 5, + 5, + 2.5, + 47, + 49, + 0, + 53, + 55, + 0, + 28, + 29, + 0, + ], + }; + const options = { + groups: 2, + autoPad: 'same-upper', + }; + testConv2d(input, filter, expected, options); + testConv2d(input, filter, expected, options, undefined, 'relu', true); + expected = { + shape: [1, 2, 3, 3], + data: [ + 6, + 6, + 5, + 6, + 6, + 5, + 5, + 5, + 2.5, + 6, + 6, + 0, + 6, + 6, + 0, + 6, + 6, + 0, + ], + }; + testConv2d(input, filter, expected, options, undefined, 'relu6', true); + }); + + it('fused depthwise conv2d same-lower autoPad', function() { + const input = { + shape: [1, 2, 3, 3], + data: [ + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + ], + }; + const filter = { + shape: [2, 1, 2, 2], + data: [ + 0.25, + 0.25, + 0.25, + 0.25, + 0.0, + 1.0, + 0.0, + 1.0, + ], + }; + let expected = { + shape: [1, 2, 3, 3], + data: [ + 2.5, + 5, + 5, + 5, + 10, + 10, + 5, + 10, + 10, + 21, + 22, + 23, + 45, + 47, + 49, + 51, + 53, + 55, + ], + }; + const options = { + groups: 2, + autoPad: 'same-lower', + }; + testConv2d(input, filter, expected, options); + testConv2d(input, filter, expected, options, undefined, 'relu', true); + expected = { + shape: [1, 2, 3, 3], + data: [ + 2.5, + 5, + 5, + 5, + 6, + 6, + 5, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + ], + }; + testConv2d(input, filter, expected, options, undefined, 'relu6', true); + }); + + it('fused conv2d with padding default', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 1, 3, 3], + data: new Array(9).fill(1), + }; + const bias = { + shape: [1], + data: [-100], + }; + const options = { + padding: [1, 1, 1, 1], + }; + let expected = { + shape: [1, 1, 5, 5], + data: [ + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 8., + 17., 0., 0., 44., 53., 62., 11., 0., 11., 17., 23., 0., + ], + }; + testConv2d(input, filter, expected, options, bias, 'relu'); + testConv2d(input, filter, expected, options, bias, 'relu', true); + expected = { + shape: [1, 1, 5, 5], + data: [ + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 6., + 6., 0., 0., 6., 6., 6., 6., 0., 6., 6., 6., 0., + ], + }; + testConv2d(input, filter, expected, options, bias, 'relu6', true); + expected = { + shape: [1, 1, 5, 5], + data: [ + -8.800000131130219, + -7.900000117719173, + -7.300000108778477, + -6.70000009983778, + -7.600000113248825, + -6.70000009983778, + -4.6000000685453415, + -3.7000000551342964, + -2.8000000417232513, + -4.90000007301569, + -3.7000000551342964, + -0.10000000149011612, + 8, + 17, + -1.9000000283122063, + -0.7000000104308128, + 44, + 53, + 62, + 11, + -2.8000000417232513, + 11, + 17, + 23, + -1.600000023841858, + ], + }; + testConv2d( + input, filter, expected, options, bias, 'leakyRelu', true, + {alpha: 0.10000000149011612}); + expected = { + shape: [1, 1, 5, 5], + data: [ + 6.054601895401186e-39, + 4.906094730649281e-35, + 1.9792598779469048e-32, + 7.984904245686979e-30, + 9.854154686111257e-34, + 7.984904245686979e-30, + 1.0530617357553813e-20, + 8.533047625744066e-17, + 6.914400106935423e-13, + 5.242885663363465e-22, + 8.533047625744066e-17, + 0.2689414213699951, + 0.9996646498695336, + 0.9999999586006244, + 5.602796406145939e-9, + 0.0009110511944006454, + 1, + 1, + 1, + 0.999983298578152, + 6.914400106935423e-13, + 0.999983298578152, + 0.9999999586006244, + 0.9999999998973812, + 1.12535162055095e-7, + ], + }; + testConv2d(input, filter, expected, options, bias, 'sigmoid', true); + }); + + it('fused conv2d with padding nchw hwio', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [3, 3, 1, 1], + data: new Array(9).fill(1), + }; + const bias = { + shape: [1], + data: [-100], + }; + const options = { + padding: [1, 1, 1, 1], + inputLayout: 'nchw', + filterLayout: 'hwio', + }; + let expected = { + shape: [1, 1, 5, 5], + data: [ + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 8., + 17., 0., 0., 44., 53., 62., 11., 0., 11., 17., 23., 0., + ], + }; + testConv2d(input, filter, expected, options, bias, 'relu'); + testConv2d(input, filter, expected, options, bias, 'relu', true); + expected = { + shape: [1, 1, 5, 5], + data: [ + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 6., + 6., 0., 0., 6., 6., 6., 6., 0., 6., 6., 6., 0., + ], + }; + testConv2d(input, filter, expected, options, bias, 'relu6', true); + }); + + it('fused conv2d with padding nchw ohwi', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const bias = { + shape: [1], + data: [-100], + }; + const options = { + padding: [1, 1, 1, 1], + inputLayout: 'nchw', + filterLayout: 'ohwi', + }; + let expected = { + shape: [1, 1, 5, 5], + data: [ + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 8., + 17., 0., 0., 44., 53., 62., 11., 0., 11., 17., 23., 0., + ], + }; + testConv2d(input, filter, expected, options, bias, 'relu'); + testConv2d(input, filter, expected, options, bias, 'relu', true); + expected = { + shape: [1, 1, 5, 5], + data: [ + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 6., + 6., 0., 0., 6., 6., 6., 6., 0., 6., 6., 6., 0., + ], + }; + testConv2d(input, filter, expected, options, bias, 'relu6', true); + }); + + it('fused conv2d with padding nchw ihwo', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const bias = { + shape: [1], + data: [-100], + }; + const options = { + padding: [1, 1, 1, 1], + inputLayout: 'nchw', + filterLayout: 'ihwo', + }; + let expected = { + shape: [1, 1, 5, 5], + data: [ + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 8., + 17., 0., 0., 44., 53., 62., 11., 0., 11., 17., 23., 0., + ], + }; + testConv2d(input, filter, expected, options, bias, 'relu'); + testConv2d(input, filter, expected, options, bias, 'relu', true); + expected = { + shape: [1, 1, 5, 5], + data: [ + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 6., + 6., 0., 0., 6., 6., 6., 6., 0., 6., 6., 6., 0., + ], + }; + testConv2d(input, filter, expected, options, bias, 'relu6', true); + }); + + it('fused conv2d with padding nhwc oihw', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 1, 3, 3], + data: new Array(9).fill(1), + }; + const bias = { + shape: [1], + data: [-100], + }; + const options = { + padding: [1, 1, 1, 1], + inputLayout: 'nhwc', + filterLayout: 'oihw', + }; + let expected = { + shape: [1, 5, 5, 1], + data: [ + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 8., + 17., 0., 0., 44., 53., 62., 11., 0., 11., 17., 23., 0., + ], + }; + testConv2d(input, filter, expected, options, bias, 'relu'); + testConv2d(input, filter, expected, options, bias, 'relu', true); + expected = { + shape: [1, 5, 5, 1], + data: [ + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 6., + 6., 0., 0., 6., 6., 6., 6., 0., 6., 6., 6., 0., + ], + }; + testConv2d(input, filter, expected, options, bias, 'relu6', true); + expected = { + shape: [1, 5, 5, 1], + data: [ + -8.800000131130219, + -7.900000117719173, + -7.300000108778477, + -6.70000009983778, + -7.600000113248825, + -6.70000009983778, + -4.6000000685453415, + -3.7000000551342964, + -2.8000000417232513, + -4.90000007301569, + -3.7000000551342964, + -0.10000000149011612, + 8, + 17, + -1.9000000283122063, + -0.7000000104308128, + 44, + 53, + 62, + 11, + -2.8000000417232513, + 11, + 17, + 23, + -1.600000023841858, + ], + }; + testConv2d( + input, filter, expected, options, bias, 'leakyRelu', true, + {alpha: 0.10000000149011612}); + expected = { + shape: [1, 5, 5, 1], + data: [ + 6.054601895401186e-39, + 4.906094730649281e-35, + 1.9792598779469048e-32, + 7.984904245686979e-30, + 9.854154686111257e-34, + 7.984904245686979e-30, + 1.0530617357553813e-20, + 8.533047625744066e-17, + 6.914400106935423e-13, + 5.242885663363465e-22, + 8.533047625744066e-17, + 0.2689414213699951, + 0.9996646498695336, + 0.9999999586006244, + 5.602796406145939e-9, + 0.0009110511944006454, + 1, + 1, + 1, + 0.999983298578152, + 6.914400106935423e-13, + 0.999983298578152, + 0.9999999586006244, + 0.9999999998973812, + 1.12535162055095e-7, + ], + }; + testConv2d(input, filter, expected, options, bias, 'sigmoid', true); + }); + + it('fused conv2d with padding nhwc hwio', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [3, 3, 1, 1], + data: new Array(9).fill(1), + }; + const bias = { + shape: [1], + data: [-100], + }; + const options = { + padding: [1, 1, 1, 1], + inputLayout: 'nhwc', + filterLayout: 'hwio', + }; + let expected = { + shape: [1, 5, 5, 1], + data: [ + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 8., + 17., 0., 0., 44., 53., 62., 11., 0., 11., 17., 23., 0., + ], + }; + testConv2d(input, filter, expected, options, bias, 'relu'); + testConv2d(input, filter, expected, options, bias, 'relu', true); + expected = { + shape: [1, 5, 5, 1], + data: [ + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 6., + 6., 0., 0., 6., 6., 6., 6., 0., 6., 6., 6., 0., + ], + }; + testConv2d(input, filter, expected, options, bias, 'relu6', true); + }); + + it('fused conv2d with padding nhwc ohwi', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const bias = { + shape: [1], + data: [-100], + }; + const options = { + padding: [1, 1, 1, 1], + inputLayout: 'nhwc', + filterLayout: 'ohwi', + }; + let expected = { + shape: [1, 5, 5, 1], + data: [ + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 8., + 17., 0., 0., 44., 53., 62., 11., 0., 11., 17., 23., 0., + ], + }; + testConv2d(input, filter, expected, options, bias, 'relu'); + testConv2d(input, filter, expected, options, bias, 'relu', true); + expected = { + shape: [1, 5, 5, 1], + data: [ + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 6., + 6., 0., 0., 6., 6., 6., 6., 0., 6., 6., 6., 0., + ], + }; + testConv2d(input, filter, expected, options, bias, 'relu6', true); + }); + + it('fused conv2d with padding nhwc ihwo', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: new Array(9).fill(1), + }; + const bias = { + shape: [1], + data: [-100], + }; + const options = { + padding: [1, 1, 1, 1], + inputLayout: 'nhwc', + filterLayout: 'ihwo', + }; + let expected = { + shape: [1, 5, 5, 1], + data: [ + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 8., + 17., 0., 0., 44., 53., 62., 11., 0., 11., 17., 23., 0., + ], + }; + testConv2d(input, filter, expected, options, bias, 'relu'); + testConv2d(input, filter, expected, options, bias, 'relu', true); + expected = { + shape: [1, 5, 5, 1], + data: [ + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 6., + 6., 0., 0., 6., 6., 6., 6., 0., 6., 6., 6., 0., + ], + }; + testConv2d(input, filter, expected, options, bias, 'relu6', true); + }); + + it('conv2d input=1x1x5x5 dilations=2', function() { + const input = { + shape: [1, 1, 5, 5], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 1, 3, 3], + data: new Array(9).fill(1), + }; + const expected = { + shape: [1, 1, 1, 1], + data: [108], + }; + const options = { + dilations: [2, 2], + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d input=1x5x5x1 dilations=4 nhwc', function() { + const input = { + shape: [1, 5, 5, 1], + data: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + ], + }; + const filter = { + shape: [1, 1, 2, 2], + data: new Array(4).fill(1), + }; + const expected = { + shape: [1, 1, 1, 1], + data: [48], + }; + const options = { + dilations: [4, 4], + inputLayout: 'nhwc', + }; + testConv2d(input, filter, expected, options); + }); + + it('conv2d input=1x65x65x1 dilations=4 nhwc', function() { + const input = { + shape: [1, 65, 65, 1], + data: [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, + 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, + 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, + 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, + 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, + 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, + 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, + 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, + 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, + 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, + 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, + 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, + 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, + 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, + 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, + 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, + 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, + 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, + 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, + 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, + 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, + 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, + 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, + 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, + 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, + 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, + 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, + 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, + 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, + 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, + 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, + 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, + 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, + 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, + 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, + 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, + 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, + 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, + 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, + 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, + 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, + 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, + 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, + 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, + 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, + 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, + 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, + 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, + 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, + 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, + 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, + 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, + 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, + 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, + 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, + 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, + 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, + 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, + 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, + 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, + 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, + 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, + 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, + 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, + 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, + 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, + 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, + 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, + 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, + 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, + 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, + 57, 58, 59, 60, 61, 62, 63, 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, + 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 57, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 57, 48, 49, 50, 51, 52, + 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, + ], + }; + const filter = { + shape: [1, 3, 3, 1], + data: [ + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + ], + }; + const expected = { + shape: [1, 57, 57, 1], + data: [ + 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, + 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, + 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, + 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, + 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, + 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, + 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, + 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, + 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, + 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, + 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, + 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, + 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, + 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, + 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, + 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, + 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, + 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, + 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, + 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, + 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, + 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, + 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, + 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, + 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, + 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, + 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, + 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, + 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, + 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, + 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, + 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, + 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, + 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, + 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, + 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, + 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, + 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, + 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, + 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, + 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, + 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, + 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, + 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, + 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, + 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, + 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, + 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, + 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, + 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, + 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, + 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, + 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, + 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, + 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, + 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, + 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, + 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, + 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, + 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, + 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, + 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, + 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, + 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, + 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, + 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, + 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, + 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, + 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, + 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, + 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, + 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, + 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, + 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, + 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, + 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, + 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, + 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, + 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, + 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, + 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, + 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, + 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, + 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, + 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, + 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, + 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, + 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, + 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, + 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, + 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, + 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, + 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, + 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, + 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, + 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, + 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, + 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, + 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, + 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, + 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, + 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, + 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, + 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, + 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, + 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, + 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, + 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, + 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, + 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, + 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, + 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, + 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, + 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, + 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, + 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, + 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, + 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, + 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, + 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, + 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, + 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, + 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, + 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, + 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, + 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, + 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, + 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, + 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, + 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, + 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, + 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, + 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, + 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, + 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, + 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, + 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, + 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, + 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, + 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, + 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, + 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, + 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, + 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, + 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, + 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, + 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, + 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, + 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, + 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, + 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, + 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, + 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, + 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, + 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, + 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, + 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, + 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, + 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, + 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, + 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, + 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, + 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, + 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, + 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, + 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, + 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, + 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, + 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, + 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, + 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, + 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, + 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, + 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, + 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, + 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, + 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, + 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, + 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, + 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, + 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, + 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, + 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, + 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, + 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, + 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, + 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, + 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, + 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, + 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, + 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, + 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, 39, + 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, + 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, + 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, 165, + 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, 36, + 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, + 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, + 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, 162, + 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, 33, + 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, + 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, + 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, 159, + 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, 30, + 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, + 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, + 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, 156, + 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, 27, + 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, + 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, + 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, 163, + 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, 24, + 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, + 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, + 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, 150, + 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, 21, + 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, + 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, + 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, 147, + 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, 18, + 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, + 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, + 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, 144, + 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, 15, + 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, + 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, + 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, 151, + 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, 183, + 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, + 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, + 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 139, 132, 135, 138, + 151, 144, 147, 150, 163, 156, 159, 162, 165, 168, 171, 174, 177, 180, + 183, + ], + }; + const options = { + dilations: [4, 4], + inputLayout: 'nhwc', + filterLayout: 'ihwo', + }; + testConv2d(input, filter, expected, options); + }); +}); diff --git a/test/gemm_test.js b/test/gemm_test.js new file mode 100644 index 0000000..3aa9dea --- /dev/null +++ b/test/gemm_test.js @@ -0,0 +1,487 @@ +'use strict'; + +import {gemm} from '../src/gemm.js'; +import {Tensor, Scalar} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test gemm', function() { + function testGemm(A, B, expected, C = undefined, options = {}) { + const a = new Tensor(A.shape, A.value); + const b = new Tensor(B.shape, B.value); + if (C !== undefined) { + if (typeof C === 'number') { + options.c = new Scalar(C); + } else { + options.c = new Tensor(C.shape, C.value); + } + } + const output = gemm(a, b, options); + utils.checkShape(output, expected.shape); + utils.checkValue(output, expected.value); + } + + it('gemm all attributes', function() { + testGemm( + { + shape: [4, 3], + value: [ + 0.16255096, + 0.36969426, + 0.86049074, + 0.52583617, + 0.40931734, + 0.5978712, + 0.02344302, + 0.22055508, + 0.45481342, + 0.70314133, + 0.7205724, + 0.3864092, + ], + }, + { + shape: [5, 4], + value: [ + 0.51755, 0.9102891, 0.28687012, 0.88691926, 0.8710911, + 0.70967263, 0.7630267, 0.5303827, 0.19783545, 0.03771181, + 0.9599521, 0.9786202, 0.64332396, 0.05995055, 0.28637242, + 0.27146435, 0.89837885, 0.25250614, 0.71064854, 0.6231573, + ], + }, + { + shape: [3, 5], + value: [ + 0.5243318683310813, + 0.557397912385277, + 0.528117080331477, + 0.27080579384162223, + 0.4888170726819199, + 0.5426185448391568, + 0.6217472531769015, + 0.5888327516148626, + 0.31766935679671005, + 0.565719856634743, + 0.5917375902862806, + 0.762459989982706, + 0.5893491460226568, + 0.3935235504287967, + 0.677412680257146, + ], + }, + { + shape: [1, 5], + value: [ + 0.645844, + 0.94571555, + 0.9641909, + 0.53538203, + 0.87259406, + ], + }, + {alpha: 0.25, beta: 0.35, aTranspose: true, bTranspose: true}); + }); + + it('gemm alpha', function() { + testGemm( + { + shape: [3, 5], + value: [ + 0.6454119, + 0.3664521, + 0.9339664, + 0.80538565, + 0.95667505, + 0.88069373, + 0.9214904, + 0.36922687, + 0.14497812, + 0.9269842, + 0.35244322, + 0.8430814, + 0.8758174, + 0.72640723, + 0.54108346, + ], + }, + { + shape: [5, 4], + value: [ + 0.63330984, 0.35902178, 0.66004074, 0.55209464, 0.61517054, + 0.9752662, 0.7728582, 0.6539411, 0.7348231, 0.16676156, + 0.02067508, 0.8232157, 0.8965447, 0.13730305, 0.84798694, + 0.5877349, 0.59452033, 0.77880573, 0.74310994, 0.30480564, + ], + }, + { + shape: [3, 4], + value: [ + 1.3056516655501957, + 0.8002504434655275, + 1.061197370189373, + 1.0648877745287364, + 1.038516251507783, + 1.0091530323988844, + 1.0564498368939703, + 0.880269401543591, + 1.1791785627830518, + 0.8079765443192164, + 0.9601925967309116, + 1.029377197893551, + ], + }, + undefined, {alpha: 0.5}); + }); + + it('gemm beta', function() { + testGemm( + { + shape: [2, 7], + value: [ + 0.49310637, + 0.90337706, + 0.48412615, + 0.12574232, + 0.46357006, + 0.19181924, + 0.35951078, + 0.6291139, + 0.6532259, + 0.5328194, + 0.34455, + 0.30500737, + 0.42374912, + 0.5333988, + ], + }, + { + shape: [7, 4], + value: [ + 0.09745993, 0.15800808, 0.8296135, 0.32482183, 0.17396742, + 0.8128901, 0.57767284, 0.5426502, 0.3885252, 0.14400595, + 0.4086032, 0.07829365, 0.02829663, 0.68851817, 0.08823209, + 0.99632514, 0.5235559, 0.01747696, 0.13607074, 0.93223673, + 0.6470155, 0.9115019, 0.8971734, 0.8677906, 0.31736255, + 0.3187993, 0.5536664, 0.7935204, + ], + }, + { + shape: [2, 4], + value: [ + 1.0496741612243639, + 1.36938856972795, + 1.7876274752838266, + 2.092117621875701, + 1.1667527091648682, + 1.6092673496321952, + 2.0779392344028675, + 2.4137995216249193, + ], + }, + { + shape: [1, 4], + value: [0.34378892, 0.20655482, 0.4271018, 0.78929764], + }, + {beta: 0.5}); + }); + + it('gemm bias', function() { + testGemm( + { + shape: [3, 6], + value: [ + 0.26634723, + 0.6047771, + 0.55068576, + 0.9991724, + 0.67497027, + 0.92930806, + 0.2860512, + 0.6988867, + 0.89526093, + 0.3717174, + 0.19019075, + 0.4499846, + 0.51885146, + 0.4031741, + 0.6269008, + 0.27176815, + 0.8668971, + 0.5799844, + ], + }, + { + shape: [6, 4], + value: [ + 0.9380275, 0.12814452, 0.6522671, 0.6646124, 0.17909846, + 0.901151, 0.3000952, 0.326868, 0.3915249, 0.45130828, + 0.581687, 0.9738232, 0.38067418, 0.9359868, 0.4762245, + 0.07865977, 0.9511043, 0.7688367, 0.88563424, 0.7623637, + 0.04992711, 0.2871258, 0.09989859, 0.78412026, + ], + }, + { + shape: [3, 4], + value: [ + 2.1861426866108147, + 2.8305439820330496, + 1.9655125113296532, + 2.7772129031129267, + 1.739566756989102, + 1.8645036093741494, + 1.835017045627564, + 2.620343856710825, + 2.358861093402937, + 2.009745597962233, + 2.069979204417261, + 3.0899011883030036, + ], + }, + { + shape: [3, 4], + value: [ + 0.5436559, + 0.2819061, + 0.1235218, + 0.5443856, + 0.6506955, + 0.1706562, + 0.52752787, + 0.80288535, + 0.5975874, + 0.20960917, + 0.29078278, + 0.8657452, + ], + }); + }); + + it('gemm no bias', function() { + testGemm( + { + shape: [2, 10], + value: [ + 0.97596496, 0.47531518, 0.7147315, 0.14236908, 0.06151228, + 0.05889508, 0.3534669, 0.31915423, 0.61336106, 0.5946216, + 0.21969128, 0.7347848, 0.4087221, 0.00412959, 0.77303815, + 0.6495765, 0.3174799, 0.62841094, 0.7002717, 0.63384914, + ], + }, + { + shape: [10, 3], + value: [ + 0.51739925, 0.25108355, 0.31373033, 0.6488124, 0.9777175, + 0.13308926, 0.47903556, 0.23692878, 0.0822504, 0.3080891, + 0.51966125, 0.969734, 0.6691261, 0.59346807, 0.7651862, + 0.48655444, 0.48373327, 0.2799068, 0.35760838, 0.19906454, + 0.3612888, 0.11448191, 0.19188708, 0.00769753, 0.3161914, + 0.323555, 0.17573832, 0.79587144, 0.91238266, 0.5517277, + ], + }, + { + shape: [2, 3], + value: [ + 2.099535174427932, + 1.8906747304468192, + 1.1958703062765146, + 2.5321421018503787, + 2.6342243688257088, + 1.5699927914492209, + ], + }); + }); + + it('gemm scalar bias', function() { + testGemm( + { + shape: [2, 3], + value: [ + 0.41595492, + 0.7063231, + 0.3784654, + 0.3524597, + 0.41936764, + 0.08190536, + ], + }, + { + shape: [3, 4], + value: [ + 0.38356313, + 0.92939967, + 0.06164686, + 0.09034675, + 0.34704673, + 0.9492532, + 0.7738587, + 0.93576515, + 0.49937814, + 0.38543963, + 0.02364575, + 0.80216527, + ], + }, + { + shape: [2, 4], + value: [ + 3.7336694407389186, + 4.342943392035599, + 3.721185688897571, + 4.142124516565133, + 3.461632460193509, + 3.897231574768164, + 3.48819604416023, + 3.6299748461695684, + ], + }, + 3.14); + }); + + it('gemm broadcasting bias', function() { + testGemm( + { + shape: [3, 7], + value: [ + 0.96122783, 0.7414551, 0.22178489, 0.23116009, 0.19249596, + 0.860125, 0.24145897, 0.43657154, 0.20278022, 0.01261093, + 0.526355, 0.94473153, 0.59416693, 0.5121616, 0.93981737, + 0.9942615, 0.46400633, 0.40644044, 0.43731472, 0.22579351, + 0.6787937, + ], + }, + { + shape: [7, 3], + value: [ + 0.1004637, 0.31921694, 0.7323029, 0.05150159, 0.9162225, + 0.89180815, 0.00931315, 0.3568885, 0.9506084, 0.04976705, + 0.6065987, 0.99300903, 0.29279497, 0.29296732, 0.38377914, + 0.80959237, 0.5812153, 0.34052548, 0.2931774, 0.12963536, + 0.58294684, + ], + }, + { + shape: [3, 3], + value: [ + 1.7498733917245672, + 2.5712126946215443, + 3.091094720699889, + 1.7664618383456954, + 2.115494714166447, + 2.676713501071726, + 1.4580698060088526, + 2.748510677662516, + 3.838076489432427, + ], + }, + {shape: [1], value: [0.7780463]}); + }); + + it('gemm aTranpose', function() { + testGemm( + { + shape: [6, 3], + value: [ + 0.2714853, + 0.0877158, + 0.31404206, + 0.2387523, + 0.25758955, + 0.37354097, + 0.6452827, + 0.8840964, + 0.14744024, + 0.65488476, + 0.35878596, + 0.3690042, + 0.4229308, + 0.40953776, + 0.85461134, + 0.7056307, + 0.17941293, + 0.4431382, + ], + }, + { + shape: [6, 4], + value: [ + 0.09183464, 0.8638833, 0.9302645, 0.06964016, 0.43232033, + 0.7631357, 0.5690705, 0.57837325, 0.17691018, 0.77424425, + 0.8207884, 0.429646, 0.31379876, 0.29592493, 0.10828935, + 0.00203794, 0.9140249, 0.31878716, 0.11819135, 0.6350557, + 0.7605846, 0.40116578, 0.32365775, 0.07651691, + ], + }, + { + shape: [3, 4], + value: [ + 1.3710694580379748, + 1.5280349660791157, + 1.2673472343664909, + 0.7581492662083144, + 0.8991952125197271, + 1.2655619679904546, + 1.0991664828167445, + 0.809478525550666, + 1.4503861699104, + 1.2299214444777435, + 0.9101225708947001, + 0.8786485304251702, + ], + }, + 0.0, {aTranspose: true}); + }); + + it('gemm bTranspose', function() { + testGemm( + { + shape: [3, 6], + value: [ + 0.4520783, + 0.25709572, + 0.28996432, + 0.03766193, + 0.0546827, + 0.46305302, + 0.91171485, + 0.48380807, + 0.09058774, + 0.6646215, + 0.35773644, + 0.03604647, + 0.21229707, + 0.18758385, + 0.01589681, + 0.9606218, + 0.08803706, + 0.18099776, + ], + }, + { + shape: [4, 6], + value: [ + 0.1482661, 0.27676222, 0.10893039, 0.8347901, 0.7146212, + 0.7316929, 0.97991717, 0.97123116, 0.69798464, 0.8436566, + 0.9630883, 0.23252074, 0.09898344, 0.08882044, 0.90780985, + 0.7116153, 0.5819304, 0.6742051, 0.5233705, 0.5594687, + 0.963364, 0.1351259, 0.8119938, 0.13756031, + ], + }, + { + shape: [3, 4], + value: [ + 0.5790980251807042, + 1.0871967394943738, + 0.7016311248544318, + 0.7729714738584772, + 1.11578439712886, + 2.340149006498089, + 0.9208884121099968, + 1.2203550096802216, + 1.0823897209392859, + 1.3386246706912768, + 0.9089606799007333, + 0.45756027099352364, + ], + }, + 0.0, {bTranspose: true}); + }); +}); diff --git a/test/gru_cell_test.js b/test/gru_cell_test.js new file mode 100644 index 0000000..3939d34 --- /dev/null +++ b/test/gru_cell_test.js @@ -0,0 +1,345 @@ +'use strict'; + +import {gruCell} from '../src/gru_cell.js'; +import {Tensor} from '../src/lib/tensor.js'; +import {sigmoid} from '../src/sigmoid.js'; +import {tanh} from '../src/tanh.js'; +import * as utils from './utils.js'; + +describe('test gruCell', function() { + it('gruCell defaults', function() { + const batchSize = 3; + const inputSize = 2; + const hiddenSize = 5; + const input = new Tensor([batchSize, inputSize], [1, 2, 3, 4, 5, 6]); + const weight = new Tensor([3 * hiddenSize, inputSize], + new Array(3 * hiddenSize * inputSize).fill(0.1)); + const recurrentWeight = new Tensor([3 * hiddenSize, hiddenSize], + new Array(3 * hiddenSize * hiddenSize).fill(0.1)); + const hiddenState = new Tensor([batchSize, hiddenSize], + new Array(batchSize * hiddenSize).fill(0)); + const output = gruCell( + input, weight, recurrentWeight, hiddenState, hiddenSize); + utils.checkShape(output, [batchSize, hiddenSize]); + const expected = [ + 0.12397026217591961, + 0.12397026217591961, + 0.12397026217591961, + 0.12397026217591961, + 0.12397026217591961, + 0.20053661855501925, + 0.20053661855501925, + 0.20053661855501925, + 0.20053661855501925, + 0.20053661855501925, + 0.19991654116571125, + 0.19991654116571125, + 0.19991654116571125, + 0.19991654116571125, + 0.19991654116571125, + ]; + utils.checkValue(output, expected); + }); + + it('gruCell with bias', function() { + const batchSize = 3; + const inputSize = 3; + const hiddenSize = 3; + const input = new Tensor([batchSize, inputSize], [1, 2, 3, 4, 5, 6, 7, 8, 9]); + const weight = new Tensor([3 * hiddenSize, inputSize], + new Array(3 * hiddenSize * inputSize).fill(0.1)); + const recurrentWeight = new Tensor([3 * hiddenSize, hiddenSize], + new Array(3 * hiddenSize * hiddenSize).fill(0.1)); + const hiddenState = new Tensor([batchSize, hiddenSize], + new Array(batchSize * hiddenSize).fill(0)); + const bias = new Tensor([3 * hiddenSize], new Array(3 * hiddenSize).fill(0.1)); + const output = gruCell( + input, weight, recurrentWeight, hiddenState, hiddenSize, {bias}); + utils.checkShape(output, [batchSize, hiddenSize]); + const expected = [ + 0.20053661855501925, + 0.20053661855501925, + 0.20053661855501925, + 0.15482337214048048, + 0.15482337214048048, + 0.15482337214048048, + 0.07484276504070396, + 0.07484276504070396, + 0.07484276504070396, + ]; + utils.checkValue(output, expected); + }); + + it('gruCell with recurrentBias', function() { + const batchSize = 3; + const inputSize = 3; + const hiddenSize = 3; + const input = new Tensor([batchSize, inputSize], [1, 2, 3, 4, 5, 6, 7, 8, 9]); + const weight = new Tensor([3 * hiddenSize, inputSize], + new Array(3 * hiddenSize * inputSize).fill(0.1)); + const recurrentWeight = new Tensor([3 * hiddenSize, hiddenSize], + new Array(3 * hiddenSize * hiddenSize).fill(0.1)); + const hiddenState = new Tensor([batchSize, hiddenSize], + new Array(batchSize * hiddenSize).fill(0)); + const recurrentBias = new Tensor([3 * hiddenSize], new Array(3 * hiddenSize).fill(1)); + const output = gruCell( + input, weight, recurrentWeight, hiddenState, hiddenSize, + {recurrentBias}); + utils.checkShape(output, [batchSize, hiddenSize]); + const expected = [ + 0.14985295238282167, + 0.14985295238282167, + 0.14985295238282167, + 0.07467770390292117, + 0.07467770390292117, + 0.07467770390292117, + 0.032218815985522856, + 0.032218815985522856, + 0.032218815985522856, + ]; + utils.checkValue(output, expected); + }); + + it('gruCell with explict resetAfter true', function() { + const batchSize = 3; + const inputSize = 3; + const hiddenSize = 3; + const input = new Tensor([batchSize, inputSize], [1, 2, 3, 4, 5, 6, 7, 8, 9]); + const weight = new Tensor([3 * hiddenSize, inputSize], + new Array(3 * hiddenSize * inputSize).fill(0.1)); + const recurrentWeight = new Tensor([3 * hiddenSize, hiddenSize], + new Array(3 * hiddenSize * hiddenSize).fill(0.1)); + const hiddenState = new Tensor([batchSize, hiddenSize], + new Array(batchSize * hiddenSize).fill(2)); + const bias = new Tensor([3 * hiddenSize], new Array(3 * hiddenSize).fill(0.1)); + const recurrentBias = new Tensor([3 * hiddenSize], new Array(3 * hiddenSize).fill(1)); + const resetAfter = true; + const output = gruCell( + input, weight, recurrentWeight, hiddenState, hiddenSize, + {bias, recurrentBias, resetAfter}); + utils.checkShape(output, [batchSize, hiddenSize]); + const expected = [ + 1.9064574801795497, + 1.9064574801795497, + 1.9064574801795497, + 1.9606870240735346, + 1.9606870240735346, + 1.9606870240735346, + 1.9836880687096186, + 1.9836880687096186, + 1.9836880687096186, + ]; + utils.checkValue(output, expected); + }); + + it('gruCell with resetAfter false', function() { + const batchSize = 3; + const inputSize = 3; + const hiddenSize = 3; + const input = new Tensor([batchSize, inputSize], [1, 2, 3, 4, 5, 6, 7, 8, 9]); + const weight = new Tensor([3 * hiddenSize, inputSize], + new Array(3 * hiddenSize * inputSize).fill(0.1)); + const recurrentWeight = new Tensor([3 * hiddenSize, hiddenSize], + new Array(3 * hiddenSize * hiddenSize).fill(0.1)); + const hiddenState = new Tensor([batchSize, hiddenSize], + new Array(batchSize * hiddenSize).fill(2)); + const bias = new Tensor([3 * hiddenSize], new Array(3 * hiddenSize).fill(0.1)); + const recurrentBias = new Tensor([3 * hiddenSize], new Array(3 * hiddenSize).fill(1)); + const resetAfter = false; + const output = gruCell( + input, weight, recurrentWeight, hiddenState, hiddenSize, + {bias, recurrentBias, resetAfter}); + utils.checkShape(output, [batchSize, hiddenSize]); + const expected = [ + 1.906856117423314, + 1.906856117423314, + 1.906856117423314, + 1.9606980991458889, + 1.9606980991458889, + 1.9606980991458889, + 1.983688371193181, + 1.983688371193181, + 1.983688371193181, + ]; + utils.checkValue(output, expected); + }); + + it('gruCell with default zrn layout', function() { + const batchSize = 3; + const inputSize = 3; + const hiddenSize = 3; + const input = new Tensor([batchSize, inputSize], [1, 2, 3, 4, 5, 6, 7, 8, 9]); + const weight = new Tensor([3 * hiddenSize, inputSize], + new Array(3 * hiddenSize * inputSize).fill(0.1)); + const recurrentWeight = new Tensor([3 * hiddenSize, hiddenSize], + new Array(3 * hiddenSize * hiddenSize).fill(0.1)); + const hiddenState = new Tensor([batchSize, hiddenSize], + new Array(batchSize * hiddenSize).fill(2)); + const bias = new Tensor([3 * hiddenSize], + [ + 1.9853785, + 2.2497437, + 0.6179927, + 0.3148022, + -0.4366297, + -0.9718124, + -1.257099, + -1.5698853, + -0.39671835, + ]); + const recurrentBias = new Tensor([3 * hiddenSize], new Array(3 * hiddenSize).fill(1)); + const resetAfter = true; + const output = gruCell( + input, weight, recurrentWeight, hiddenState, hiddenSize, + {bias, recurrentBias, resetAfter}); + utils.checkShape(output, [batchSize, hiddenSize]); + const expected = [ + 1.9801673183552388, + 1.9812534682811542, + 1.9376592706336329, + 1.9935192730591977, + 1.9947569570033654, + 1.9759958501762682, + 1.997469445392646, + 1.9980404252433588, + 1.9902071255213296, + ]; + utils.checkValue(output, expected); + }); + + it('gruCell with explict zrn layout', function() { + const batchSize = 3; + const inputSize = 3; + const hiddenSize = 3; + const input = new Tensor([batchSize, inputSize], [1, 2, 3, 4, 5, 6, 7, 8, 9]); + const weight = new Tensor([3 * hiddenSize, inputSize], + new Array(3 * hiddenSize * inputSize).fill(0.1)); + const recurrentWeight = new Tensor([3 * hiddenSize, hiddenSize], + new Array(3 * hiddenSize * hiddenSize).fill(0.1)); + const hiddenState = new Tensor([batchSize, hiddenSize], + new Array(batchSize * hiddenSize).fill(2)); + const bias = new Tensor([3 * hiddenSize], + [ + 1.9853785, + 2.2497437, + 0.6179927, + 0.3148022, + -0.4366297, + -0.9718124, + -1.257099, + -1.5698853, + -0.39671835, + ]); + const recurrentBias = new Tensor([3 * hiddenSize], new Array(3 * hiddenSize).fill(1)); + const resetAfter = true; + const layout = 'zrn'; + const output = gruCell( + input, weight, recurrentWeight, hiddenState, hiddenSize, + {bias, recurrentBias, resetAfter, layout}); + utils.checkShape(output, [batchSize, hiddenSize]); + const expected = [ + 1.9801673183552388, + 1.9812534682811542, + 1.9376592706336329, + 1.9935192730591977, + 1.9947569570033654, + 1.9759958501762682, + 1.997469445392646, + 1.9980404252433588, + 1.9902071255213296, + ]; + utils.checkValue(output, expected); + }); + + it('gruCell with rzn layout', function() { + const batchSize = 3; + const inputSize = 3; + const hiddenSize = 3; + const input = new Tensor([batchSize, inputSize], [1, 2, 3, 4, 5, 6, 7, 8, 9]); + const weight = new Tensor([3 * hiddenSize, inputSize], + new Array(3 * hiddenSize * inputSize).fill(0.1)); + const recurrentWeight = new Tensor([3 * hiddenSize, hiddenSize], + new Array(3 * hiddenSize * hiddenSize).fill(0.1)); + const hiddenState = new Tensor([batchSize, hiddenSize], + new Array(batchSize * hiddenSize).fill(2)); + const bias = new Tensor([3 * hiddenSize], + [ + 0.3148022, + -0.4366297, + -0.9718124, + 1.9853785, + 2.2497437, + 0.6179927, + -1.257099, + -1.5698853, + -0.39671835, + ]); + const recurrentBias = new Tensor([3 * hiddenSize], new Array(3 * hiddenSize).fill(1)); + const resetAfter = true; + const layout = 'rzn'; + const output = gruCell( + input, weight, recurrentWeight, hiddenState, hiddenSize, + {bias, recurrentBias, resetAfter, layout}); + utils.checkShape(output, [batchSize, hiddenSize]); + const expected = [ + 1.9801673183552388, + 1.9812534682811542, + 1.9376592706336329, + 1.9935192730591977, + 1.9947569570033654, + 1.9759958501762682, + 1.997469445392646, + 1.9980404252433588, + 1.9902071255213296, + ]; + utils.checkValue(output, expected); + }); + + it('gruCell with [tanh, sigmoid] activations', function() { + const batchSize = 3; + const inputSize = 3; + const hiddenSize = 3; + const input = new Tensor([batchSize, inputSize], [1, 2, 3, 4, 5, 6, 7, 8, 9]); + const weight = new Tensor([3 * hiddenSize, inputSize], + new Array(3 * hiddenSize * inputSize).fill(0.1)); + const recurrentWeight = new Tensor([3 * hiddenSize, hiddenSize], + new Array(3 * hiddenSize * hiddenSize).fill(0.1)); + const hiddenState = new Tensor([batchSize, hiddenSize], + new Array(batchSize * hiddenSize).fill(2)); + const bias = new Tensor([3 * hiddenSize], + [ + 1.9853785, + 2.2497437, + 0.6179927, + 0.3148022, + -0.4366297, + -0.9718124, + -1.257099, + -1.5698853, + -0.39671835, + ]); + const recurrentBias = new Tensor([3 * hiddenSize], new Array(3 * hiddenSize).fill(1)); + const resetAfter = true; + const output = gruCell( + input, weight, recurrentWeight, hiddenState, hiddenSize, + { + bias, + recurrentBias, + resetAfter, + activations: [tanh, sigmoid], + }); + utils.checkShape(output, [batchSize, hiddenSize]); + const expected = [ + 1.9994052973405467, + 1.9996265670444457, + 1.9916469375315222, + 1.9999129608020485, + 1.9999467564798181, + 1.9987442445027492, + 1.9999865812225888, + 1.99999193815786, + 1.9997998773572325, + ]; + utils.checkValue(output, expected); + }); +}); diff --git a/test/gru_test.js b/test/gru_test.js new file mode 100644 index 0000000..43c4bb0 --- /dev/null +++ b/test/gru_test.js @@ -0,0 +1,407 @@ +'use strict'; + +import {gru} from '../src/gru.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test gru', function() { + it('gru with 1 step', function() { + const steps = 1; + const numDirections = 1; + const batchSize = 3; + const inputSize = 3; + const hiddenSize = 3; + const input = new Tensor([steps, batchSize, inputSize], [1, 2, 3, 4, 5, 6, 7, 8, 9]); + const weight = new Tensor([numDirections, 3 * hiddenSize, inputSize], + new Array(numDirections * 3 * hiddenSize * inputSize).fill(0.1)); + const recurrentWeight = new Tensor([numDirections, 3 * hiddenSize, hiddenSize], + new Array(numDirections * 3 * hiddenSize * hiddenSize).fill(0.1)); + const bias = new Tensor([numDirections, 3 * hiddenSize], + [ + 0.3148022, + -0.4366297, + -0.9718124, + 1.9853785, + 2.2497437, + 0.6179927, + -1.257099, + -1.5698853, + -0.39671835, + ]); + const recurrentBias = new Tensor([numDirections, 3 * hiddenSize], + new Array(numDirections * 3 * hiddenSize).fill(1)); + const initialHiddenState = new Tensor([numDirections, batchSize, hiddenSize], + new Array(numDirections * batchSize * hiddenSize).fill(2)); + const resetAfter = true; + const layout = 'rzn'; + const outputs = gru( + input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, initialHiddenState, resetAfter, layout}); + utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]); + const expected = [ + 1.9801673183552388, + 1.9812534682811542, + 1.9376592706336329, + 1.9935192730591977, + 1.9947569570033654, + 1.9759958501762682, + 1.997469445392646, + 1.9980404252433588, + 1.9902071255213296, + ]; + utils.checkValue(outputs[0], expected); + }); + + it('gru with 2 steps', function() { + const steps = 2; + const numDirections = 1; + const batchSize = 3; + const inputSize = 3; + const hiddenSize = 5; + const input = new Tensor([steps, batchSize, inputSize], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]); + const weight = new Tensor([numDirections, 3 * hiddenSize, inputSize], + new Array(numDirections * 3 * hiddenSize * inputSize).fill(0.1)); + const recurrentWeight = new Tensor([numDirections, 3 * hiddenSize, hiddenSize], + new Array(numDirections * 3 * hiddenSize * hiddenSize) + .fill(0.1)); + const initialHiddenState = new Tensor([numDirections, batchSize, hiddenSize], + new Array(numDirections * batchSize * hiddenSize).fill(0)); + const bias = new Tensor([numDirections, 3 * hiddenSize], + new Array(numDirections * 3 * hiddenSize).fill(0.1)); + const recurrentBias = new Tensor([numDirections, 3 * hiddenSize], + new Array(numDirections * 3 * hiddenSize).fill(0)); + const outputs = gru( + input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, initialHiddenState}); + utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]); + const expected = [ + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + ]; + utils.checkValue(outputs[0], expected); + }); + + it('gru with explict returnSequence false', function() { + const steps = 2; + const numDirections = 1; + const batchSize = 3; + const inputSize = 3; + const hiddenSize = 5; + const input = new Tensor([steps, batchSize, inputSize], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]); + const weight = new Tensor([numDirections, 3 * hiddenSize, inputSize], + new Array(numDirections * 3 * hiddenSize * inputSize).fill(0.1)); + const recurrentWeight = new Tensor([numDirections, 3 * hiddenSize, hiddenSize], + new Array(numDirections * 3 * hiddenSize * hiddenSize) + .fill(0.1)); + const initialHiddenState = new Tensor([numDirections, batchSize, hiddenSize], + new Array(numDirections * batchSize * hiddenSize).fill(0)); + const bias = new Tensor([numDirections, 3 * hiddenSize], + new Array(numDirections * 3 * hiddenSize).fill(0.1)); + const recurrentBias = new Tensor([numDirections, 3 * hiddenSize], + new Array(numDirections * 3 * hiddenSize).fill(0)); + const returnSequence = false; + const outputs = gru( + input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, initialHiddenState, returnSequence}); + utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]); + const expected = [ + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + ]; + utils.checkValue(outputs[0], expected); + }); + + it('gru with returnSequence true', function() { + const steps = 2; + const numDirections = 1; + const batchSize = 3; + const inputSize = 3; + const hiddenSize = 5; + const input = new Tensor([steps, batchSize, inputSize], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]); + const weight = new Tensor([numDirections, 3 * hiddenSize, inputSize], + new Array(numDirections * 3 * hiddenSize * inputSize).fill(0.1)); + const recurrentWeight = new Tensor([numDirections, 3 * hiddenSize, hiddenSize], + new Array(numDirections * 3 * hiddenSize * hiddenSize) + .fill(0.1)); + const initialHiddenState = new Tensor([numDirections, batchSize, hiddenSize], + new Array(numDirections * batchSize * hiddenSize).fill(0)); + const bias = new Tensor([numDirections, 3 * hiddenSize], + new Array(numDirections * 3 * hiddenSize).fill(0.1)); + const recurrentBias = new Tensor([numDirections, 3 * hiddenSize], + new Array(numDirections * 3 * hiddenSize).fill(0)); + const returnSequence = true; + const outputs = gru( + input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, initialHiddenState, returnSequence}); + utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]); + utils.checkShape(outputs[1], [steps, numDirections, batchSize, hiddenSize]); + const expected = [ + [ + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + ], + [ + 0.20053661855501925, + 0.20053661855501925, + 0.20053661855501925, + 0.20053661855501925, + 0.20053661855501925, + 0.15482337214048048, + 0.15482337214048048, + 0.15482337214048048, + 0.15482337214048048, + 0.15482337214048048, + 0.07484276504070396, + 0.07484276504070396, + 0.07484276504070396, + 0.07484276504070396, + 0.07484276504070396, + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + ], + ]; + for (let i = 0; i < expected.length; ++i) { + utils.checkValue(outputs[i], expected[i]); + } + }); + + it('gru with explict forward direction', function() { + const steps = 2; + const numDirections = 1; + const batchSize = 3; + const inputSize = 3; + const hiddenSize = 5; + const input = new Tensor([steps, batchSize, inputSize], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]); + const weight = new Tensor([numDirections, 3 * hiddenSize, inputSize], + new Array(numDirections * 3 * hiddenSize * inputSize).fill(0.1)); + const recurrentWeight = new Tensor([numDirections, 3 * hiddenSize, hiddenSize], + new Array(numDirections * 3 * hiddenSize * hiddenSize) + .fill(0.1)); + const initialHiddenState = new Tensor([numDirections, batchSize, hiddenSize], + new Array(numDirections * batchSize * hiddenSize).fill(0)); + const bias = new Tensor([numDirections, 3 * hiddenSize], + new Array(numDirections * 3 * hiddenSize).fill(0.1)); + const recurrentBias = new Tensor([numDirections, 3 * hiddenSize], + new Array(numDirections * 3 * hiddenSize).fill(0)); + const direction = 'forward'; + const outputs = gru( + input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, initialHiddenState, direction}); + utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]); + const expected = [ + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + ]; + utils.checkValue(outputs[0], expected); + }); + + it('gru with backward direction', function() { + const steps = 2; + const numDirections = 1; + const batchSize = 3; + const inputSize = 3; + const hiddenSize = 5; + const input = new Tensor([steps, batchSize, inputSize], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]); + const weight = new Tensor([numDirections, 3 * hiddenSize, inputSize], + new Array(numDirections * 3 * hiddenSize * inputSize).fill(0.1)); + const recurrentWeight = new Tensor([numDirections, 3 * hiddenSize, hiddenSize], + new Array(numDirections * 3 * hiddenSize * hiddenSize).fill(0.1)); + const initialHiddenState = new Tensor([numDirections, batchSize, hiddenSize], + new Array(numDirections * batchSize * hiddenSize).fill(0)); + const bias = new Tensor([numDirections, 3 * hiddenSize], + new Array(numDirections * 3 * hiddenSize).fill(0.1)); + const recurrentBias = new Tensor([numDirections, 3 * hiddenSize], + new Array(numDirections * 3 * hiddenSize).fill(0)); + const direction = 'backward'; + const outputs = gru( + input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, initialHiddenState, direction}); + utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]); + const expected = [ + 0.22227008136062426, + 0.22227008136062426, + 0.22227008136062426, + 0.22227008136062426, + 0.22227008136062426, + 0.1652493513699554, + 0.1652493513699554, + 0.1652493513699554, + 0.1652493513699554, + 0.1652493513699554, + 0.07972921857068853, + 0.07972921857068853, + 0.07972921857068853, + 0.07972921857068853, + 0.07972921857068853, + ]; + utils.checkValue(outputs[0], expected); + }); + + it('gru with both direction', function() { + const steps = 2; + const numDirections = 2; + const batchSize = 3; + const inputSize = 3; + const hiddenSize = 5; + const input = new Tensor([steps, batchSize, inputSize], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]); + const weight = new Tensor([numDirections, 3 * hiddenSize, inputSize], + new Array(numDirections * 3 * hiddenSize * inputSize).fill(0.1)); + const recurrentWeight = new Tensor([numDirections, 3 * hiddenSize, hiddenSize], + new Array(numDirections * 3 * hiddenSize * hiddenSize).fill(0.1)); + const initialHiddenState = new Tensor([numDirections, batchSize, hiddenSize], + new Array(numDirections * batchSize * hiddenSize).fill(0)); + const bias = new Tensor([numDirections, 3 * hiddenSize], + new Array(numDirections * 3 * hiddenSize).fill(0.1)); + const recurrentBias = new Tensor([numDirections, 3 * hiddenSize], + new Array(numDirections * 3 * hiddenSize).fill(0)); + const direction = 'both'; + const outputs = gru( + input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias, initialHiddenState, direction}); + utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]); + const expected = [ + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + 0.22227008136062426, + 0.22227008136062426, + 0.22227008136062426, + 0.22227008136062426, + 0.22227008136062426, + 0.1652493513699554, + 0.1652493513699554, + 0.1652493513699554, + 0.1652493513699554, + 0.1652493513699554, + 0.07972921857068853, + 0.07972921857068853, + 0.07972921857068853, + 0.07972921857068853, + 0.07972921857068853, + ]; + utils.checkValue(outputs[0], expected); + }); + + it('gru without initialHiddenState', function() { + const steps = 2; + const numDirections = 1; + const batchSize = 3; + const inputSize = 3; + const hiddenSize = 5; + const input = new Tensor([steps, batchSize, inputSize], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]); + const weight = new Tensor([numDirections, 3 * hiddenSize, inputSize], + new Array(numDirections * 3 * hiddenSize * inputSize).fill(0.1)); + const recurrentWeight = new Tensor([numDirections, 3 * hiddenSize, hiddenSize], + new Array(numDirections * 3 * hiddenSize * hiddenSize).fill(0.1)); + const bias = new Tensor([numDirections, 3 * hiddenSize], + new Array(numDirections * 3 * hiddenSize).fill(0.1)); + const recurrentBias = new Tensor([numDirections, 3 * hiddenSize], + new Array(numDirections * 3 * hiddenSize).fill(0)); + const outputs = gru( + input, weight, recurrentWeight, steps, hiddenSize, + {bias, recurrentBias}); + utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]); + const expected = [ + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.22391088955449673, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.16530139937319663, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + 0.0797327116380732, + ]; + utils.checkValue(outputs[0], expected); + }); +}); diff --git a/test/leaky_relu_test.js b/test/leaky_relu_test.js new file mode 100644 index 0000000..c46c78f --- /dev/null +++ b/test/leaky_relu_test.js @@ -0,0 +1,182 @@ +'use strict'; + +import {leakyRelu} from '../src/leaky_relu.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test leakyRelu', function() { + function testLeakyRelu(input, expected, options = {}) { + const inputTensor = new Tensor(input.shape, input.value); + const outputTensor = leakyRelu(inputTensor, options); + utils.checkValue(outputTensor, expected); + } + + it('leakyRelu', function() { + testLeakyRelu( + {shape: [3], value: [-1, 0, 1]}, [-0.1, 0., 1.], {alpha: 0.1}); + testLeakyRelu( + { + shape: [3, 4, 5], + value: [ + 0.5945598, -0.735546, 0.9624621, 0.7178781, -2.2841945, + 1.4461595, 0.13227068, -0.05931347, 0.25514695, 0.83969593, + 3.4556108, 1.6048287, 0.30937293, -0.11302311, -0.55214405, + 0.15766327, 0.40505877, 0.7130178, -0.53093743, 0.77193236, + -1.6821449, -0.8352944, 0.08011059, 0.53667474, 0.11023884, + -0.61316216, 0.53726774, -0.7437747, -0.5286507, 1.2811732, + -0.19160618, -0.5079444, 0.33344734, 1.4179748, -0.09760198, + 1.0317479, 0.7191149, 0.9713708, -0.32984316, 0.15518457, + 0.16741018, -0.8231882, 0.24937603, -1.1336567, 2.3608718, + 1.2201307, -0.09541762, -0.61066127, 0.91480494, 0.9309983, + -0.08354045, -0.44542325, 3.088639, -0.90056187, 0.25742382, + 1.3762826, 0.39736032, 0.49137968, -0.5622506, 1.1100211, + ], + }, + [ + 0.5945598, + -0.07355460000000001, + 0.9624621, + 0.7178781, + -0.22841945, + 1.4461595, + 0.13227068, + -0.005931347, + 0.25514695, + 0.83969593, + 3.4556108, + 1.6048287, + 0.30937293, + -0.011302311, + -0.055214405, + 0.15766327, + 0.40505877, + 0.7130178, + -0.053093743000000006, + 0.77193236, + -0.16821449, + -0.08352944000000001, + 0.08011059, + 0.53667474, + 0.11023884, + -0.06131621600000001, + 0.53726774, + -0.07437747, + -0.05286507000000001, + 1.2811732, + -0.019160618, + -0.050794439999999996, + 0.33344734, + 1.4179748, + -0.009760198000000001, + 1.0317479, + 0.7191149, + 0.9713708, + -0.03298431600000001, + 0.15518457, + 0.16741018, + -0.08231882000000001, + 0.24937603, + -0.11336567, + 2.3608718, + 1.2201307, + -0.009541762, + -0.061066127000000005, + 0.91480494, + 0.9309983, + -0.008354045000000001, + -0.044542325, + 3.088639, + -0.09005618700000001, + 0.25742382, + 1.3762826, + 0.39736032, + 0.49137968, + -0.05622506000000001, + 1.1100211, + ], + {alpha: 0.1}); + }); + + it('leakyRelu default', function() { + testLeakyRelu( + { + shape: [3, 4, 5], + value: [ + 1.2178663, 0.08626969, -0.25983566, 0.03568677, -1.5386598, + 0.2786136, 0.1057941, -0.5374242, -0.11235637, 0.07136911, + 1.1007954, -0.3993358, -1.5691061, 0.7312798, 0.7960611, + 0.6767248, -0.30511293, 0.85154665, -0.97270423, 0.33083355, + -0.96259284, 1.0446007, 1.2399997, -0.4430618, -0.88743573, + -1.1777387, 0.4861841, 1.0564232, -0.92164683, -1.7308608, + 0.08230155, -0.7713891, -0.77213866, -1.0124619, -1.2846667, + 1.0307417, 0.9004573, -0.593318, 0.29095086, -0.50655633, + -0.6983193, 0.69927245, -1.1014417, -0.36207023, 1.1648387, + 0.0049276, -0.12467039, 2.7892349, 0.8076212, 2.2155113, + 1.5295383, -2.2338881, -1.7535976, -1.1389159, -0.16080397, + 0.4859151, 0.34155434, 0.91066486, 0.65148973, 0.13155791, + ], + }, + [ + 1.2178663, + 0.08626969, + -0.0025983566000000002, + 0.03568677, + -0.015386598000000001, + 0.2786136, + 0.1057941, + -0.005374242, + -0.0011235637, + 0.07136911, + 1.1007954, + -0.003993358000000001, + -0.015691061, + 0.7312798, + 0.7960611, + 0.6767248, + -0.0030511293, + 0.85154665, + -0.0097270423, + 0.33083355, + -0.0096259284, + 1.0446007, + 1.2399997, + -0.004430618, + -0.0088743573, + -0.011777387, + 0.4861841, + 1.0564232, + -0.0092164683, + -0.017308608, + 0.08230155, + -0.0077138910000000005, + -0.0077213865999999996, + -0.010124619, + -0.012846667, + 1.0307417, + 0.9004573, + -0.005933180000000001, + 0.29095086, + -0.0050655633, + -0.0069831929999999995, + 0.69927245, + -0.011014417, + -0.0036207023, + 1.1648387, + 0.0049276, + -0.0012467039000000001, + 2.7892349, + 0.8076212, + 2.2155113, + 1.5295383, + -0.022338881, + -0.017535976, + -0.011389159, + -0.0016080397, + 0.4859151, + 0.34155434, + 0.91066486, + 0.65148973, + 0.13155791, + ]); + }); +}); diff --git a/test/matmul_test.js b/test/matmul_test.js new file mode 100644 index 0000000..1f11f2f --- /dev/null +++ b/test/matmul_test.js @@ -0,0 +1,382 @@ +'use strict'; + +import {matmul} from '../src/matmul.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test matmul', function() { + function testMatmul(A, B, expected) { + const a = new Tensor(A.shape, A.value); + const b = new Tensor(B.shape, B.value); + const c = matmul(a, b); + utils.checkShape(c, expected.shape); + utils.checkValue(c, expected.value); + } + + it('matmul 1d', function() { + testMatmul( + {shape: [4], value: [0.9025404, 0.89538723, 0.16789329, 0.7440875]}, + {shape: [4], value: [0.8782074, 0.22533207, 0.7134056, 0.04190519]}, + {shape: [], value: [1.145334257418975]}); + }); + + it('matmul 1dx2d', function() { + testMatmul( + {shape: [4], value: [0.1309212, 0.9090703, 0.62183434, 0.9195683]}, { + shape: [4, 3], + value: [ + 0.3093976, + -1.2924036, + -0.64339244, + 1.1423386, + 1.5052135, + 1.8182521, + -1.825652, + -0.39694095, + -0.90111053, + 0.7807154, + -1.9163561, + -0.13988003, + ], + }, + {shape: [1, 3], + value: [ + 0.6616408255448397, + -0.8099099769211229, + 0.8797145586262527, + ]}); + }); + + it('matmul 2dx1d', function() { + testMatmul( + { + shape: [3, 4], + value: [ + 0.3582649, + 0.83665735, + 0.30253866, + 0.6446781, + 0.4684662, + 0.94761264, + 0.4122941, + 0.6787481, + 0.15072346, + 0.2820577, + 0.67296237, + 0.3856028, + ], + }, + {shape: [4], value: [0.25528687, 0.2126722, 0.26320502, 0.8297401]}, + {shape: [3, 1], + value: [0.8839390494404162, + 0.992826528000594, + 0.5955407318122876, + ]}); + }); + + it('matmul 2d', function() { + testMatmul( + { + shape: [3, 4], + value: [ + 0.9602246, + 0.97682184, + -0.33201018, + 0.8248904, + 0.40872088, + 0.18995902, + 0.69355214, + -0.37210146, + 0.18104352, + 3.270753, + -0.803097, + -0.7268995, + ], + }, + { + shape: [4, 3], + value: [ + 0.17467105, + -1.2045133, + -0.02621938, + 0.6096196, + 1.4499376, + 1.3465316, + 0.03289436, + 1.0754977, + -0.61485314, + 0.94857556, + -0.36462623, + 1.402278, + ], + }, + { + shape: [3, 3], + value: [ + 1.5347627892239333, + -0.39812544905177394, + 2.651008143473561, + -0.14295794996907119, + 0.6647106735468218, + -0.703152987090222, + 1.3096017351837559, + 3.9256360751909685, + 3.8738969565609622, + ], + }); + }); + + it('matmul 3d', function() { + testMatmul( + { + shape: [2, 3, 4], + value: [ + 0.19521078, 0.11637875, 0.54684865, 0.13257395, -0.05654722, + -0.64351636, -1.0019655, -1.6156989, 0.01625126, 1.2386297, + -0.1242797, 0.40350053, -0.5883816, 0.93452644, -0.01409106, + -0.7825521, -1.2281458, -1.2388189, 0.7644939, -0.8567167, + 0.3942727, -0.772506, -0.06412488, -0.9848109, + ], + }, + { + shape: [2, 4, 3], + value: [ + -2.7142005, 0.41909233, 0.80572236, 0.19983047, -1.9361104, + 1.1919757, 0.61684674, 0.23732206, 0.74679494, 0.4595843, + -0.90667343, 0.7676448, 0.48643762, 0.41120672, 1.1319419, + 1.9692143, -0.44463134, 0.17005378, 1.1589569, -0.4333597, + -0.47976026, 0.01067371, -0.79455626, -1.4024538, + ], + }, + { + shape: [2, 3, 3], + value: [ + -0.10833446333599143, + -0.13393279743161207, + 0.8061598404542069, + -1.3357226841086192, + 2.4493429579826187, + -2.801162847627581, + 0.31218775792583, + -2.7866505894621443, + 1.7064441398054897, + 1.5293882188296952, + -0.029578043761553485, + 0.5971594643551588, + -2.1600450781503135, + 0.39520467130026193, + -0.7661237278034159, + -1.414270346305937, + 1.3158847659611541, + 1.7268425818711388, + ], + }); + }); + + it('matmul 3dx2d', function() { + testMatmul( + { + shape: [2, 3, 4], + value: [ + -0.57675153, -0.40231872, 0.10705414, -0.66516143, 0.3206562, + 0.43695804, -1.8614748, 0.77510875, -1.2424866, -0.58930343, + 0.40949076, 0.5517746, 0.09809388, 0.5084747, 0.76594603, + 0.8050488, -0.03979152, 2.4019558, -0.54937273, -0.1696853, + -1.223669, 1.0791223, -0.61921734, 2.1074235, + ], + }, + { + shape: [4, 3], + value: [ + -0.38534147, + -0.18395364, + -2.548874, + 0.4525641, + -0.41875792, + 0.57480955, + -0.41603103, + 0.6973883, + 0.9531734, + 1.3292471, + -1.003955, + -0.7639869, + ], + }, + { + shape: [2, 3, 3], + value: [ + -0.8885304730243201, + 1.0170201418415437, + 1.849026114390587, + 1.878931727122919, + -2.3183105665073347, + -2.9326257919832135, + 0.7751679607198163, + 0.20695260775599766, + 2.7969716845016883, + 0.9437692220342353, + -0.5050435022828981, + 0.15727981777314692, + 1.1053746974552965, + -1.211287928674362, + 1.0880805766675583, + 4.01880262734377, + -2.774385931084078, + 1.539002458305059, + ], + }); + }); + + it('matmul 3dx2d should be 3d', function() { + testMatmul( + { + shape: [1, 3, 4], + value: [ + 0.25500464, + -1.105212, + -0.5368534, + -0.01583702, + 0.9875369, + 1.3744136, + 0.61079186, + 0.74018836, + -0.56111795, + -0.16432828, + 1.3176169, + -0.249416, + ], + }, + { + shape: [4, 3], + value: [ + 0.2545374, + -1.6150205, + -0.64508885, + -0.3454305, + 0.38700557, + 1.3147515, + -0.3379386, + 1.1804152, + 1.9414345, + -1.5912915, + 0.40443325, + -0.23596671, + ], + }, + { + shape: [1, 3, 3], + value: [ + 0.653306953532106, + -1.4796758522265552, + -2.65610848747696, + -1.607664893851476, + -0.04264183969545593, + 2.181115876300509, + -0.13444155392012996, + 2.2970839255543356, + 2.7628408608418473, + ], + }); + }); + + it('matmul 4d', function() { + testMatmul( + { + shape: [1, 2, 3, 4], + value: [ + -0.8074054, -0.72524256, 0.4510249, 1.6203358, 1.9851393, + 0.501528, 1.3975041, -2.3231244, 0.70866925, 0.24667543, + -0.6271161, -0.9634111, -0.5911732, -0.09888726, -1.0926677, + 0.47262478, 0.6141726, -0.634484, -0.07425678, -1.2638812, + -1.1002079, -1.5324054, -1.1643038, -0.05644368, + ], + }, + { + shape: [1, 2, 4, 3], + value: [ + -0.45605758, -0.43318668, 0.61509126, -2.2228749, 0.50257015, + -0.29311436, -0.64561933, -0.6439757, 1.6211574, -0.28852704, + -0.46247238, 0.5082442, 1.2357981, -0.82043344, -0.926581, + -0.8955289, 0.74586314, -0.8022598, -0.5360306, -0.08719682, + 0.72717273, 1.1277325, 2.0261378, -1.4311641, + ], + }, + { + shape: [1, 2, 3, 3], + value: [ + 1.221645749907327, + -1.0545376270456461, + 1.2706596306239777, + -2.252143281998571, + -0.4334607112616218, + 2.1588963856372976, + -0.188674175668765, + 0.6663841820450125, + -1.1427098588511797, + 0.47668332938386415, + 1.4641419805936096, + -0.8438586440388712, + -0.05832390830187206, + -3.531448486441424, + 1.6947642397339107, + 0.57312758193175, + -0.25315643602996796, + 1.4829491816053335, + ], + }); + }); + + it('matmul 4dx2d', function() { + testMatmul( + { + shape: [1, 2, 3, 4], + value: [ + -0.40162078, -0.5607968, -1.4350457, -0.22855183, -0.1357853, + -1.3434876, 1.0602195, -0.17137937, 0.44751146, 0.78427273, + -0.49435133, -0.9062699, -0.6109297, 0.645001, 0.6632162, + 0.903104, 2.4085212, 0.7805757, -0.9099179, -0.6195976, + 0.38710263, 0.5102191, -0.03610202, 1.2280966, + ], + }, + { + shape: [4, 3], + value: [ + 0.01829041, + -0.73948264, + -0.95898634, + -0.5105271, + 2.1705306, + 1.2495605, + -1.9865801, + -0.58367056, + -0.80371356, + -0.583849, + -1.2323712, + 1.3314632, + ], + }, + { + shape: [1, 2, 3, 3], + value: [ + 3.2632291428668005, + 0.1990196002350671, + 0.5334566494457813, + -1.3227480270318333, + -3.2232859838291446, + -2.6288509024476023, + 1.1189859750587483, + 2.7767602451442084, + -0.2585093062909366, + -2.185273116934897, + 0.35170718388113653, + 2.061254897681926, + 1.814924023726412, + 1.2078703550705958, + -1.4280204169121542, + -0.8987038289844296, + -0.6712086999562717, + 1.9305046113199869, + ], + }); + }); +}); diff --git a/test/package.json b/test/package.json new file mode 100644 index 0000000..aead43d --- /dev/null +++ b/test/package.json @@ -0,0 +1,3 @@ +{ + "type": "module" +} \ No newline at end of file diff --git a/test/pool2d_test.js b/test/pool2d_test.js new file mode 100644 index 0000000..39b0996 --- /dev/null +++ b/test/pool2d_test.js @@ -0,0 +1,512 @@ +'use strict'; +import {averagePool2d, maxPool2d} from '../src/pool2d.js'; +import {Tensor} from '../src/lib/tensor.js'; + +import * as utils from './utils.js'; + +describe('test pool2d', function() { + it('maxPool2d default', function() { + const x = new Tensor([1, 1, 4, 4], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]); + const windowDimensions = [3, 3]; + const y = maxPool2d(x, {windowDimensions}); + utils.checkShape(y, [1, 1, 2, 2]); + utils.checkValue(y, [11, 12, 15, 16]); + }); + + it('maxPool2d nhwc', function() { + const x = new Tensor([1, 4, 4, 1], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]); + const windowDimensions = [3, 3]; + const layout = 'nhwc'; + const y = maxPool2d(x, {windowDimensions, layout}); + utils.checkShape(y, [1, 2, 2, 1]); + utils.checkValue(y, [11, 12, 15, 16]); + }); + + it('maxPool2d dilations default', function() { + const x = new Tensor([1, 1, 4, 4], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]); + const windowDimensions = [2, 2]; + const dilations = [2, 2]; + const y = maxPool2d(x, {windowDimensions, dilations}); + utils.checkShape(y, [1, 1, 2, 2]); + utils.checkValue(y, [11, 12, 15, 16]); + }); + + it('maxPool2d dilations nhwc', function() { + const x = new Tensor([1, 4, 4, 1], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]); + const windowDimensions = [2, 2]; + const dilations = [2, 2]; + const layout = 'nhwc'; + const y = maxPool2d(x, {windowDimensions, dilations, layout}); + utils.checkShape(y, [1, 2, 2, 1]); + utils.checkValue(y, [11, 12, 15, 16]); + }); + + it('maxPool2d pads default', function() { + const x = new Tensor([1, 1, 5, 5], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + ]); + const windowDimensions = [5, 5]; + const padding = [2, 2, 2, 2]; + const y = maxPool2d(x, {windowDimensions, padding}); + utils.checkShape(y, [1, 1, 5, 5]); + const expected = [ + 13, 14, 15, 15, 15, 18, 19, 20, 20, 20, 23, 24, 25, + 25, 25, 23, 24, 25, 25, 25, 23, 24, 25, 25, 25, + ]; + utils.checkValue(y, expected); + }); + + it('maxPool2d pads nhwc', function() { + const x = new Tensor([1, 5, 5, 1], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + ]); + const windowDimensions = [5, 5]; + const padding = [2, 2, 2, 2]; + const layout = 'nhwc'; + const y = maxPool2d(x, {windowDimensions, padding, layout}); + utils.checkShape(y, [1, 5, 5, 1]); + const expected = [ + 13, 14, 15, 15, 15, 18, 19, 20, 20, 20, 23, 24, 25, + 25, 25, 23, 24, 25, 25, 25, 23, 24, 25, 25, 25, + ]; + utils.checkValue(y, expected); + }); + + it('maxPool2d autoPad same-upper default', function() { + const x = new Tensor([1, 1, 5, 5], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + ]); + const windowDimensions = [5, 5]; + const autoPad = 'same-upper'; + const y = maxPool2d(x, {windowDimensions, autoPad}); + utils.checkShape(y, [1, 1, 5, 5]); + const expected = [ + 13, 14, 15, 15, 15, 18, 19, 20, 20, 20, 23, 24, 25, + 25, 25, 23, 24, 25, 25, 25, 23, 24, 25, 25, 25, + ]; + utils.checkValue(y, expected); + }); + + it('maxPool2d autoPad explicit nhwc', function() { + const x = new Tensor([1, 7, 7, 1], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, + ]); + const windowDimensions = [4, 4]; + const padding = [2, 1, 2, 1]; + const strides = [2, 2]; + const autoPad = 'explicit'; + const layout = 'nhwc'; + const y = maxPool2d( + x, {windowDimensions, autoPad, padding, strides, layout}); + utils.checkShape(y, [1, 4, 4, 1]); + const expected = [ + 9, + 11, + 13, + 14, + 23, + 25, + 27, + 28, + 37, + 39, + 41, + 42, + 44, + 46, + 48, + 49, + ]; + utils.checkValue(y, expected); + }); + + it('maxPool2d autoPad same-lower nhwc', function() { + const x = new Tensor([1, 7, 7, 1], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, + ]); + const windowDimensions = [4, 4]; + const strides = [2, 2]; + const autoPad = 'same-lower'; + const layout = 'nhwc'; + const y = + maxPool2d(x, {windowDimensions, autoPad, strides, layout}); + utils.checkShape(y, [1, 4, 4, 1]); + const expected = [ + 9, + 11, + 13, + 14, + 23, + 25, + 27, + 28, + 37, + 39, + 41, + 42, + 44, + 46, + 48, + 49, + ]; + utils.checkValue(y, expected); + }); + + it('maxPool2d autoPad same-upper nhwc', function() { + const x = new Tensor([1, 5, 5, 1], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + ]); + const windowDimensions = [5, 5]; + const autoPad = 'same-upper'; + const layout = 'nhwc'; + const y = maxPool2d(x, {windowDimensions, autoPad, layout}); + utils.checkShape(y, [1, 5, 5, 1]); + const expected = [ + 13, 14, 15, 15, 15, 18, 19, 20, 20, 20, 23, 24, 25, + 25, 25, 23, 24, 25, 25, 25, 23, 24, 25, 25, 25, + ]; + utils.checkValue(y, expected); + }); + + it('maxPool2d strides default', function() { + const x = new Tensor([1, 1, 5, 5], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + ]); + const windowDimensions = [2, 2]; + const strides = [2, 2]; + const y = maxPool2d(x, {windowDimensions, strides}); + utils.checkShape(y, [1, 1, 2, 2]); + const expected = [7, 9, 17, 19]; + utils.checkValue(y, expected); + }); + + it('maxPool2d strides nhwc', function() { + const x = new Tensor([1, 5, 5, 1], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + ]); + const windowDimensions = [2, 2]; + const strides = [2, 2]; + const layout = 'nhwc'; + const y = maxPool2d(x, {windowDimensions, strides, layout}); + utils.checkShape(y, [1, 2, 2, 1]); + const expected = [7, 9, 17, 19]; + utils.checkValue(y, expected); + }); + + it('averagePool2d default', function() { + const x = new Tensor([1, 1, 4, 4], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]); + const windowDimensions = [3, 3]; + const y = averagePool2d(x, {windowDimensions}); + utils.checkShape(y, [1, 1, 2, 2]); + const expected = [6, 7, 10, 11]; + utils.checkValue(y, expected); + }); + + it('averagePool2d nhwc', function() { + const x = new Tensor([1, 4, 4, 1], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]); + const windowDimensions = [3, 3]; + const layout = 'nhwc'; + const y = averagePool2d(x, {windowDimensions, layout}); + utils.checkShape(y, [1, 2, 2, 1]); + const expected = [6, 7, 10, 11]; + utils.checkValue(y, expected); + }); + + it('averagePool2d pads default', function() { + const x = new Tensor([1, 5, 5, 1], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + ]); + const windowDimensions = [5, 5]; + const padding = [2, 2, 2, 2]; + const layout = 'nhwc'; + const y = averagePool2d(x, {windowDimensions, padding, layout}); + utils.checkShape(y, [1, 5, 5, 1]); + const expected = [ + 7, 7.5, 8, 8.5, 9, 9.5, 10, 10.5, 11, 11.5, 12, 12.5, 13, + 13.5, 14, 14.5, 15, 15.5, 16, 16.5, 17, 17.5, 18, 18.5, 19, + ]; + utils.checkValue(y, expected); + }); + + it('averagePool2d pads nhwc', function() { + const x = new Tensor([1, 5, 5, 1], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + ]); + const windowDimensions = [5, 5]; + const padding = [2, 2, 2, 2]; + const layout = 'nhwc'; + const y = averagePool2d(x, {windowDimensions, padding, layout}); + utils.checkShape(y, [1, 5, 5, 1]); + const expected = [ + 7, 7.5, 8, 8.5, 9, 9.5, 10, 10.5, 11, 11.5, 12, 12.5, 13, + 13.5, 14, 14.5, 15, 15.5, 16, 16.5, 17, 17.5, 18, 18.5, 19, + ]; + utils.checkValue(y, expected); + }); + + it('averagePool2d autoPad same-upper default', function() { + const x = new Tensor([1, 1, 5, 5], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + ]); + const windowDimensions = [5, 5]; + const autoPad = 'same-upper'; + const y = averagePool2d(x, {windowDimensions, autoPad}); + utils.checkShape(y, [1, 1, 5, 5]); + const expected = [ + 7, 7.5, 8, 8.5, 9, 9.5, 10, 10.5, 11, 11.5, 12, 12.5, 13, + 13.5, 14, 14.5, 15, 15.5, 16, 16.5, 17, 17.5, 18, 18.5, 19, + ]; + utils.checkValue(y, expected); + }); + + it('averagePool2d autoPad same-upper nhwc', function() { + const x = new Tensor([1, 5, 5, 1], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + ]); + const windowDimensions = [5, 5]; + const autoPad = 'same-upper'; + const layout = 'nhwc'; + const y = averagePool2d(x, {windowDimensions, autoPad, layout}); + utils.checkShape(y, [1, 5, 5, 1]); + const expected = [ + 7, 7.5, 8, 8.5, 9, 9.5, 10, 10.5, 11, 11.5, 12, 12.5, 13, + 13.5, 14, 14.5, 15, 15.5, 16, 16.5, 17, 17.5, 18, 18.5, 19, + ]; + utils.checkValue(y, expected); + }); + + it('averagePool2d autoPad explicit nhwc', function() { + const x = new Tensor([1, 7, 7, 1], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, + ]); + const windowDimensions = [4, 4]; + const padding = [2, 1, 2, 1]; + const strides = [2, 2]; + const autoPad = 'explicit'; + const layout = 'nhwc'; + const y = averagePool2d( + x, {windowDimensions, autoPad, padding, strides, layout}); + utils.checkShape(y, [1, 4, 4, 1]); + const expected = [ + 5, + 6, + 8, + 9.5, + 12, + 13, + 15, + 16.5, + 26, + 27, + 29, + 30.5, + 36.5, + 37.5, + 39.5, + 41, + ]; + utils.checkValue(y, expected); + }); + + it('averagePool2d autoPad same-lower nhwc', function() { + const x = new Tensor([1, 7, 7, 1], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, + ]); + const windowDimensions = [4, 4]; + const strides = [2, 2]; + const autoPad = 'same-lower'; + const layout = 'nhwc'; + const y = + averagePool2d(x, {windowDimensions, autoPad, strides, layout}); + utils.checkShape(y, [1, 4, 4, 1]); + const expected = [ + 5, + 6, + 8, + 9.5, + 12, + 13, + 15, + 16.5, + 26, + 27, + 29, + 30.5, + 36.5, + 37.5, + 39.5, + 41, + ]; + utils.checkValue(y, expected); + }); + + it('averagePool2d strides default', function() { + const x = new Tensor([1, 1, 5, 5], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + ]); + const windowDimensions = [2, 2]; + const strides = [2, 2]; + const y = averagePool2d(x, {windowDimensions, strides}); + utils.checkShape(y, [1, 1, 2, 2]); + const expected = [4, 6, 14, 16]; + utils.checkValue(y, expected); + }); + + it('averagePool2d strides nhwc', function() { + const x = new Tensor([1, 5, 5, 1], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + ]); + const windowDimensions = [2, 2]; + const strides = [2, 2]; + const layout = 'nhwc'; + const y = averagePool2d(x, {windowDimensions, strides, layout}); + utils.checkShape(y, [1, 2, 2, 1]); + const expected = [4, 6, 14, 16]; + utils.checkValue(y, expected); + }); + + it('averagePool2d pads outputSizes=[3,3]', function() { + const x = new Tensor([1, 1, 7, 7], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, + ]); + const windowDimensions = [4, 4]; + const strides = [2, 2]; + const padding = [1, 1, 1, 1]; + const outputSizes = [3, 3]; + const y = + averagePool2d(x, {windowDimensions, strides, padding, outputSizes}); + utils.checkShape(y, [1, 1, 3, 3]); + const expected = [9, 10.5, 12.5, 19.5, 21, 23, 33.5, 35, 37]; + utils.checkValue(y, expected); + }); + + it('averagePool2d pads outputSizes=[4,4]', function() { + const x = new Tensor([1, 1, 7, 7], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, + ]); + const windowDimensions = [4, 4]; + const strides = [2, 2]; + const padding = [1, 1, 1, 1]; + const outputSizes = [4, 4]; + const y = + averagePool2d(x, {windowDimensions, strides, padding, outputSizes}); + utils.checkShape(y, [1, 1, 4, 4]); + const expected = [9, 10.5, 12.5, 13.5, 19.5, 21, 23, 24, 33.5, 35, 37, 38, 40.5, 42, 44, 45]; + utils.checkValue(y, expected); + }); + + it('averagePool2d pads roundingType=floor', function() { + const x = new Tensor([1, 1, 7, 7], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, + ]); + const windowDimensions = [4, 4]; + const strides = [2, 2]; + const padding = [1, 1, 1, 1]; + const roundingType = 'floor'; + const y = + averagePool2d(x, {windowDimensions, strides, padding, roundingType}); + utils.checkShape(y, [1, 1, 3, 3]); + const expected = [9, 10.5, 12.5, 19.5, 21, 23, 33.5, 35, 37]; + utils.checkValue(y, expected); + }); + + it('averagePool2d pads roundingType=ceil', function() { + const x = new Tensor([1, 1, 7, 7], [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, + ]); + const windowDimensions = [4, 4]; + const strides = [2, 2]; + const padding = [1, 1, 1, 1]; + const roundingType = 'ceil'; + const y = + averagePool2d(x, {windowDimensions, strides, padding, roundingType}); + utils.checkShape(y, [1, 1, 4, 4]); + const expected = [9, 10.5, 12.5, 13.5, 19.5, 21, 23, 24, 33.5, 35, 37, 38, 40.5, 42, 44, 45]; + utils.checkValue(y, expected); + }); + + it('global averagePool2d default', function() { + const x = new Tensor([1, 3, 5, 5], [ + -1.1289884, 0.34016284, 0.497431, 2.1915932, 0.42038894, + -0.18261199, -0.15769927, -0.26465914, 0.03877424, 0.39492005, + -0.33410737, 0.74918455, -1.3542547, -0.0222946, 0.7094626, + -0.09399617, 0.790736, -0.75826526, 0.27656242, 0.46543223, + -1.2342638, 1.1549494, 0.24823844, 0.75670505, -1.7108902, + -1.4767597, -1.4969662, -0.31936142, 0.5327554, -0.06070877, + 0.31212643, 2.2274113, 1.2775147, 0.59886885, -1.5765078, + 0.18522178, 0.22655599, 0.88869494, 0.38609484, -0.05860576, + -0.72732115, -0.0046324, -1.3593693, -0.6295078, 1.384531, + 0.06825881, 0.19907428, 0.20298219, -0.8399954, 1.3583295, + 0.02117888, -1.0636739, -0.30460566, -0.92678875, -0.09120782, + -0.88333017, -0.9641269, 0.6065926, -0.5830042, -0.81138134, + 1.3569402, 1.2891295, 0.2508177, 0.20211531, 0.8832168, + -0.19886094, -0.61088, 0.682026, -0.5253442, 1.5022339, + 1.0256356, 1.0642492, -0.4169051, -0.8740329, 1.1494869, + ]); + const y = averagePool2d(x); + utils.checkShape(y, [1, 3, 1, 1]); + const expected = [ + 0.07170040239999997, + 0.05194737240000002, + 0.07117922839999995, + ]; + utils.checkValue(y, expected); + }); + + it('global averagePool2d nhwc', function() { + const x = new Tensor([1, 5, 5, 3], [ + -1.1289884, -1.4767597, 0.02117888, 0.34016284, -1.4969662, + -1.0636739, 0.497431, -0.31936142, -0.30460566, 2.1915932, + 0.5327554, -0.92678875, 0.42038894, -0.06070877, -0.09120782, + -0.18261199, 0.31212643, -0.88333017, -0.15769927, 2.2274113, + -0.9641269, -0.26465914, 1.2775147, 0.6065926, 0.03877424, + 0.59886885, -0.5830042, 0.39492005, -1.5765078, -0.81138134, + -0.33410737, 0.18522178, 1.3569402, 0.74918455, 0.22655599, + 1.2891295, -1.3542547, 0.88869494, 0.2508177, -0.0222946, + 0.38609484, 0.20211531, 0.7094626, -0.05860576, 0.8832168, + -0.09399617, -0.72732115, -0.19886094, 0.790736, -0.0046324, + -0.61088, -0.75826526, -1.3593693, 0.682026, 0.27656242, + -0.6295078, -0.5253442, 0.46543223, 1.384531, 1.5022339, + -1.2342638, 0.06825881, 1.0256356, 1.1549494, 0.19907428, + 1.0642492, 0.24823844, 0.20298219, -0.4169051, 0.75670505, + -0.8399954, -0.8740329, -1.7108902, 1.3583295, 1.1494869, + ]); + const layout = 'nhwc'; + const y = averagePool2d(x, {layout}); + utils.checkShape(y, [1, 1, 1, 3]); + const expected = [ + 0.07170040239999997, + 0.05194737240000002, + 0.07117922839999995, + ]; + utils.checkValue(y, expected); + }); +}); diff --git a/test/reduce_test.js b/test/reduce_test.js new file mode 100644 index 0000000..1e9a734 --- /dev/null +++ b/test/reduce_test.js @@ -0,0 +1,513 @@ +'use strict'; + +import * as reducers from '../src/reduce.js'; +import {Tensor} from '../src/lib/tensor.js'; + +import * as utils from './utils.js'; + +describe('test reduce', function() { + function testReduce(op, options, input, expected) { + const x = new Tensor(input.shape, input.values); + const y = reducers['reduce' + op](x, options); + utils.checkShape(y, expected.shape); + utils.checkValue(y, expected.values); + } + + it('reduceMax default', function() { + testReduce( + 'Max', {}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [], values: [600.]}); + }); + + it('reduceMax default axes keep dims', function() { + testReduce( + 'Max', {keepDimensions: true}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [1, 1, 1], values: [600.]}); + }); + + it('reduceMax axes0 do not keep dims', function() { + testReduce( + 'Max', {axes: [0], keepDimensions: false}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [2, 2], values: [500., 100., 600., 400.]}); + }); + + it('reduceMax axes1 do not keep dims', function() { + testReduce( + 'Max', {axes: [1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [3, 2], values: [200., 100., 300., 400., 600., 6.]}); + }); + + it('reduceMax axes2 do not keep dims', function() { + testReduce( + 'Max', {axes: [2], keepDimensions: false}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [3, 2], values: [100., 200., 300., 400., 500., 600.]}); + }); + + it('reduceMax negative axes do not keep dims', function() { + testReduce( + 'Max', {axes: [-1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [3, 2], values: [100., 200., 300., 400., 500., 600.]}); + }); + + it('reduceMax axes0 keep dims', function() { + testReduce( + 'Max', {axes: [0], keepDimensions: true}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [1, 2, 2], values: [500., 100., 600., 400.]}); + }); + + it('reduceMax axes1 keep dims', function() { + testReduce( + 'Max', {axes: [1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [3, 1, 2], values: [200., 100., 300., 400., 600., 6.]}); + }); + + it('reduceMax axes2 keep dims', function() { + testReduce( + 'Max', {axes: [2], keepDimensions: true}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [3, 2, 1], values: [100., 200., 300., 400., 500., 600.]}); + }); + + it('reduceMax negative axes keep dims', function() { + testReduce( + 'Max', {axes: [-1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [3, 2, 1], values: [100., 200., 300., 400., 500., 600.]}); + }); + + it('reduceMean default', function() { + testReduce( + 'Mean', {}, { + shape: [3, 2, 2], + values: [5., 1., 20., 2., 30., 1., 40., 2., 55., 1., 60., 2.], + }, + {shape: [], values: [18.25]}); + }); + + it('reduceMean default axes keep dims', function() { + testReduce( + 'Mean', {keepDimensions: true}, { + shape: [3, 2, 2], + values: [5., 1., 20., 2., 30., 1., 40., 2., 55., 1., 60., 2.], + }, + {shape: [1, 1, 1], values: [18.25]}); + }); + + it('reduceMean axes0 do not keep dims', function() { + testReduce( + 'Mean', {axes: [0], keepDimensions: false}, { + shape: [3, 2, 2], + values: [5., 1., 20., 2., 30., 1., 40., 2., 55., 1., 60., 2.], + }, + {shape: [2, 2], values: [30., 1., 40., 2.]}); + }); + + it('reduceMean axes1 do not keep dims', function() { + testReduce( + 'Mean', {axes: [1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [5., 1., 20., 2., 30., 1., 40., 2., 55., 1., 60., 2.], + }, + {shape: [3, 2], values: [12.5, 1.5, 35., 1.5, 57.5, 1.5]}); + }); + + it('reduceMean axes2 do not keep dims', function() { + testReduce( + 'Mean', {axes: [2], keepDimensions: false}, { + shape: [3, 2, 2], + values: [5., 1., 20., 2., 30., 1., 40., 2., 55., 1., 60., 2.], + }, + {shape: [3, 2], values: [3., 11., 15.5, 21., 28., 31.]}); + }); + + it('reduceMean negative axes do not keep dims', function() { + testReduce( + 'Mean', {axes: [-1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [5., 1., 20., 2., 30., 1., 40., 2., 55., 1., 60., 2.], + }, + {shape: [3, 2], values: [3., 11., 15.5, 21., 28., 31.]}); + }); + + it('reduceMean axes0 keep dims', function() { + testReduce( + 'Mean', {axes: [0], keepDimensions: true}, { + shape: [3, 2, 2], + values: [5., 1., 20., 2., 30., 1., 40., 2., 55., 1., 60., 2.], + }, + {shape: [1, 2, 2], values: [30., 1., 40., 2.]}); + }); + + it('reduceMean axes1 keep dims', function() { + testReduce( + 'Mean', {axes: [1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [5., 1., 20., 2., 30., 1., 40., 2., 55., 1., 60., 2.], + }, + {shape: [3, 1, 2], values: [12.5, 1.5, 35., 1.5, 57.5, 1.5]}); + }); + + it('reduceMean axes2 keep dims', function() { + testReduce( + 'Mean', {axes: [2], keepDimensions: true}, { + shape: [3, 2, 2], + values: [5., 1., 20., 2., 30., 1., 40., 2., 55., 1., 60., 2.], + }, + {shape: [3, 2, 1], values: [3., 11., 15.5, 21., 28., 31.]}); + }); + + it('reduceMean negative axes keep dims', function() { + testReduce( + 'Mean', {axes: [-1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [5., 1., 20., 2., 30., 1., 40., 2., 55., 1., 60., 2.], + }, + {shape: [3, 2, 1], values: [3., 11., 15.5, 21., 28., 31.]}); + }); + + it('reduceMin default', function() { + testReduce( + 'Min', {}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [], values: [1.]}); + }); + + it('reduceMin default axes keep dims', function() { + testReduce( + 'Min', {keepDimensions: true}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [1, 1, 1], values: [1.]}); + }); + + it('reduceMin axes0 do not keep dims', function() { + testReduce( + 'Min', {axes: [0], keepDimensions: false}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [2, 2], values: [1., 3., 4., 2.]}); + }); + + it('reduceMin axes1 do not keep dims', function() { + testReduce( + 'Min', {axes: [1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [3, 2], values: [1., 2., 4., 3., 500., 5.]}); + }); + + it('reduceMin axes2 do not keep dims', function() { + testReduce( + 'Min', {axes: [2], keepDimensions: false}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [3, 2], values: [1., 2., 3., 4., 5., 6.]}); + }); + + it('reduceMin negative axes do not keep dims', function() { + testReduce( + 'Min', {axes: [-1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [3, 2], values: [1., 2., 3., 4., 5., 6.]}); + }); + + it('reduceMin axes0 keep dims', function() { + testReduce( + 'Min', {axes: [0], keepDimensions: true}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [1, 2, 2], values: [1., 3., 4., 2.]}); + }); + + it('reduceMin axes1 keep dims', function() { + testReduce( + 'Min', {axes: [1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [3, 1, 2], values: [1., 2., 4., 3., 500., 5.]}); + }); + + it('reduceMin axes2 keep dims', function() { + testReduce( + 'Min', {axes: [2], keepDimensions: true}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [3, 2, 1], values: [1., 2., 3., 4., 5., 6.]}); + }); + + it('reduceMin negative axes keep dims', function() { + testReduce( + 'Min', {axes: [-1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [1., 100., 200., 2., 300., 3., 4., 400., 500., 5., 600., 6.], + }, + {shape: [3, 2, 1], values: [1., 2., 3., 4., 5., 6.]}); + }); + + it('reduceProduct default', function() { + testReduce( + 'Product', {}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + {shape: [], values: [0.]}); + }); + + it('reduceProduct default axes keep dims', function() { + testReduce( + 'Product', {keepDimensions: true}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + {shape: [1, 1, 1], values: [0.]}); + }); + + it('reduceProduct axes0 do not keep dims', function() { + testReduce( + 'Product', {axes: [0], keepDimensions: false}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + { + shape: [2, 2], + values: [0., 45., 120., 231.], + }); + }); + + it('reduceProduct axes1 do not keep dims', function() { + testReduce( + 'Product', {axes: [1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + { + shape: [3, 2], + values: [0., 3., 24., 35., 80., 99.], + }); + }); + + it('reduceProduct axes2 do not keep dims', function() { + testReduce( + 'Product', {axes: [2], keepDimensions: false}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + { + shape: [3, 2], + values: [0., 6., 20., 42., 72., 110.], + }); + }); + + it('reduceProduct negative axes do not keep dims', function() { + testReduce( + 'Product', {axes: [-1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + { + shape: [3, 2], + values: [0., 6., 20., 42., 72., 110.], + }); + }); + + it('reduceProduct axes0 keep dims', function() { + testReduce( + 'Product', {axes: [0], keepDimensions: true}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + { + shape: [1, 2, 2], + values: [0., 45., 120., 231.], + }); + }); + + it('reduceProduct axes1 keep dims', function() { + testReduce( + 'Product', {axes: [1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + { + shape: [3, 1, 2], + values: [0., 3., 24., 35., 80., 99.], + }); + }); + + it('reduceProduct axes2 keep dims', function() { + testReduce( + 'Product', {axes: [2], keepDimensions: true}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + { + shape: [3, 2, 1], + values: [0., 6., 20., 42., 72., 110.], + }); + }); + + it('reduceProduct negative axes keep dims', function() { + testReduce( + 'Product', {axes: [-1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + { + shape: [3, 2, 1], + values: [0., 6., 20., 42., 72., 110.], + }); + }); + + it('reduceSum default', function() { + testReduce( + 'Sum', {}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + {shape: [], values: [66.]}); + }); + + it('reduceSum default axes keep dims', function() { + testReduce( + 'Sum', {keepDimensions: true}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + {shape: [1, 1, 1], values: [66.]}); + }); + + it('reduceSum axes0 do not keep dims', function() { + testReduce( + 'Sum', {axes: [0], keepDimensions: false}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + { + shape: [2, 2], + values: [12., 15., 18., 21.], + }); + }); + + it('reduceSum axes1 do not keep dims', function() { + testReduce( + 'Sum', {axes: [1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + { + shape: [3, 2], + values: [2., 4., 10., 12., 18., 20.], + }); + }); + + it('reduceSum axes2 do not keep dims', function() { + testReduce( + 'Sum', {axes: [2], keepDimensions: false}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + { + shape: [3, 2], + values: [1., 5., 9., 13., 17., 21.], + }); + }); + + it('reduceSum negative axes do not keep dims', function() { + testReduce( + 'Sum', {axes: [-1], keepDimensions: false}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + { + shape: [3, 2], + values: [1., 5., 9., 13., 17., 21.], + }); + }); + + it('reduceSum axes0 keep dims', function() { + testReduce( + 'Sum', {axes: [0], keepDimensions: true}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + { + shape: [1, 2, 2], + values: [12., 15., 18., 21.], + }); + }); + + it('reduceSum axes1 keep dims', function() { + testReduce( + 'Sum', {axes: [1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + { + shape: [3, 1, 2], + values: [2., 4., 10., 12., 18., 20.], + }); + }); + + it('reduceSum axes2 keep dims', function() { + testReduce( + 'Sum', {axes: [2], keepDimensions: true}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + { + shape: [3, 2, 1], + values: [1., 5., 9., 13., 17., 21.], + }); + }); + + it('reduceSum negative axes keep dims', function() { + testReduce( + 'Sum', {axes: [-1], keepDimensions: true}, { + shape: [3, 2, 2], + values: [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.], + }, + { + shape: [3, 2, 1], + values: [1., 5., 9., 13., 17., 21.], + }); + }); +}); diff --git a/test/relu_test.js b/test/relu_test.js new file mode 100644 index 0000000..f849ea1 --- /dev/null +++ b/test/relu_test.js @@ -0,0 +1,42 @@ +'use strict'; + +import {relu} from '../src/relu.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test relu', function() { + it('relu', function() { + const inputShape = [3, 4, 5]; + const inputData = [ + -1.483762, 0.6447428, -1.2266507, -1.7132527, 0.9777725, + -0.34438756, -0.99921757, -1.2882805, 1.3725083, -0.06386258, + -0.44738683, -0.6776338, 0.5027815, -1.0428967, -1.4220539, + 0.00880813, -1.2053454, 1.1644533, -1.6577007, -0.33448243, + 0.69386536, 0.06171616, -0.20644434, 1.0620342, -0.8824057, + -0.7676657, 0.7517342, 1.4035656, -0.29105335, 0.18367627, + 1.3628657, -0.39770076, -0.1550809, -1.2575449, 0.5797014, + -0.02414344, 0.9181723, -1.1963434, 0.56652546, -0.25052008, + -0.02097719, -2.6274924, 0.7993208, -0.31359985, 0.9019325, + -0.02042965, 0.5222995, 1.3394557, -1.0482218, 1.1774449, + 0.8999488, -1.1143959, 1.0122099, -0.48604885, -0.06009902, + -0.1766853, 1.4515465, -0.7182982, 2.0361354, 0.7899623, + ]; + const inputTensor = new Tensor(inputShape, inputData); + const expectedShape = [3, 4, 5]; + const expectedData = [ + 0., 0.6447428, 0., 0., 0.9777725, 0., + 0., 0., 1.3725083, 0., 0., 0., + 0.5027815, 0., 0., 0.00880813, 0., 1.1644533, + 0., 0., 0.69386536, 0.06171616, 0., 1.0620342, + 0., 0., 0.7517342, 1.4035656, 0., 0.18367627, + 1.3628657, 0., 0., 0., 0.5797014, 0., + 0.9181723, 0., 0.56652546, 0., 0., 0., + 0.7993208, 0., 0.9019325, 0., 0.5222995, 1.3394557, + 0., 1.1774449, 0.8999488, 0., 1.0122099, 0., + 0., 0., 1.4515465, 0., 2.0361354, 0.7899623, + ]; + const outputTensor = relu(inputTensor); + utils.checkShape(outputTensor, expectedShape); + utils.checkValue(outputTensor, expectedData); + }); +}); diff --git a/test/reshape_test.js b/test/reshape_test.js new file mode 100644 index 0000000..640aea2 --- /dev/null +++ b/test/reshape_test.js @@ -0,0 +1,47 @@ +'use strict'; + +import {reshape} from '../src/reshape.js'; +import {Tensor, sizeOfShape} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test reshape', function() { + function testReshape(oldShape, newShape, expectedShape) { + const bufferSize = sizeOfShape(oldShape); + const inputBuffer = new Array(bufferSize); + for (let i = 0; i < inputBuffer.length; ++i) { + inputBuffer[i] = Math.random(); + } + const x = new Tensor(oldShape, inputBuffer); + const y = reshape(x, newShape); + utils.checkShape(y, expectedShape ? expectedShape : newShape); + utils.checkValue(y, inputBuffer); + } + + it('reshape reordered_all_dims', function() { + testReshape([2, 3, 4], [4, 2, 3]); + }); + + it('reshape reordered_last_dims', function() { + testReshape([2, 3, 4], [2, 4, 3]); + }); + + it('reshape reduced_dims', function() { + testReshape([2, 3, 4], [2, 12]); + }); + + it('reshape extended_dims', function() { + testReshape([2, 3, 4], [2, 3, 2, 2]); + }); + + it('reshape one_dim', function() { + testReshape([2, 3, 4], [24]); + }); + + it('reshape [2, 3, 4] to negative_dim [2, -1, 2]', function() { + testReshape([2, 3, 4], [2, -1, 2], [2, 6, 2]); + }); + + it('reshape [2, 3, 4] to negative_dim [-1, 2, 3, 4]', function() { + testReshape([2, 3, 4], [-1, 2, 3, 4], [1, 2, 3, 4]); + }); +}); diff --git a/test/sigmoid_test.js b/test/sigmoid_test.js new file mode 100644 index 0000000..d7c0ca5 --- /dev/null +++ b/test/sigmoid_test.js @@ -0,0 +1,49 @@ +'use strict'; + +import {sigmoid} from '../src/sigmoid.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test sigmoid', function() { + async function testSigmoid(input, expected, shape) { + const inputTensor = new Tensor(shape, input); + const outputTensor = sigmoid(inputTensor); + utils.checkValue(outputTensor, expected); + } + it('sigmoid 1d', async function() { + testSigmoid([-1, 0, 1], [0.26894143, 0.5, 0.7310586], [3]); + }); + + it('sigmoid 3d', async function() { + testSigmoid( + [ + -0.18371736, 0.4805392, 2.7183356, 0.03039639, 0.04197176, + -1.1536852, -2.0124357, -0.885673, -0.25776535, 1.0151213, + -0.22013742, 0.13626824, 0.8574488, -0.15987602, 0.7025059, + -0.8209337, 1.2621661, 0.4055987, -0.65470445, 0.14290208, + 1.6874043, -0.7997532, -1.0582826, 1.0813274, -1.9656292, + -0.13285251, 0.87344545, -0.07760263, 1.0503976, -0.23713546, + 0.21536243, 0.59599924, -0.8221842, 0.10256762, -0.67856175, + 1.1891315, -0.6567207, -0.2958169, -1.9581499, -0.9223802, + -0.32011083, -0.31802705, 0.7264381, 1.0234208, 0.673269, + 0.96394795, 0.6152301, -0.4362364, -1.2325221, -0.11140272, + -0.43866253, 0.5770897, 0.42372307, -0.33066413, -0.46210232, + -0.6456375, 2.0984166, -1.2020895, 1.5637838, -0.7114222, + ], + [ + 0.4541994, 0.61787516, 0.9381, 0.50759846, 0.5104914, + 0.23981662, 0.11790343, 0.29200357, 0.43591312, 0.7340212, + 0.44518682, 0.53401446, 0.7021274, 0.4601159, 0.66874313, + 0.3055655, 0.77939874, 0.6000321, 0.34193018, 0.53566486, + 0.8438825, 0.31007832, 0.2576378, 0.7467451, 0.12285913, + 0.46683565, 0.70546216, 0.48060906, 0.7408512, 0.44099236, + 0.55363345, 0.64474046, 0.3053002, 0.52561945, 0.33658236, + 0.7665857, 0.34147665, 0.4265804, 0.12366741, 0.28447315, + 0.42064875, 0.42115664, 0.67402315, 0.7356384, 0.6622347, + 0.7239115, 0.64913297, 0.39263815, 0.2257403, 0.47217807, + 0.39205968, 0.6403975, 0.6043738, 0.41807905, 0.38648725, + 0.34397328, 0.89074916, 0.2311037, 0.8268956, 0.32928467, + ], + [3, 4, 5]); + }); +}); diff --git a/test/slice_test.js b/test/slice_test.js new file mode 100644 index 0000000..2d7b03f --- /dev/null +++ b/test/slice_test.js @@ -0,0 +1,199 @@ +'use strict'; + +import {slice} from '../src/slice.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test slice', function() { + function testSlice(inputShape, inputData, starts, sizes, axes, expectedShape, expected) { + const input = new Tensor(inputShape, inputData); + const options = {}; + if (axes) { + options.axes = axes; + } + const output = slice(input, starts, sizes, options); + utils.checkShape(output, expectedShape); + utils.checkValue(output, expected); + } + + it('slice default axes', function() { + const inputShape = [3, 4, 5]; + const inputData = [ + 1.3165863e+00, 4.1239005e-02, 4.6697399e-01, -6.6145003e-02, + -3.7128052e-01, -1.0660021e+00, 7.5784922e-01, 3.5759725e-02, + 1.9211160e+00, -8.1603736e-01, 1.1800343e-01, -1.8293047e+00, + -2.1316205e-01, -3.6369815e-01, 6.4205879e-01, 7.1544610e-02, + 6.8498695e-01, 1.0001093e+00, -5.6261641e-01, -7.3343945e-01, + 1.6827687e+00, 1.2653192e+00, 5.8872145e-01, 3.1535852e-01, + 3.5038650e-01, 3.5865438e-01, -3.6469769e-01, -8.7751287e-01, + 2.7995768e-01, -1.6042528e+00, 8.6336482e-01, -1.7991974e+00, + -6.8652731e-01, 1.3729302e-03, -7.7775210e-01, 1.0199220e-01, + 4.2299256e-01, 1.1432177e-01, -5.0116669e-02, 1.5525131e+00, + -8.7060851e-01, 4.5739245e-01, 1.3543987e-01, -1.5927458e-02, + 9.1792661e-01, -4.5001405e-01, 1.9954188e-01, -5.1338053e-01, + -4.1026011e-01, -1.2718531e+00, 4.2538303e-01, -1.5449624e-01, + -3.4380481e-01, 7.8374326e-01, 1.7837452e+00, 9.6105379e-01, + -4.8783422e-01, -9.4987392e-01, -8.8750905e-01, -9.8019439e-01, + ]; + const starts = [0, 0, 1]; + const sizes = [2, 3, 4]; + const expectedShape = [2, 3, 4]; + const expected = [ + 4.1239005e-02, 4.6697399e-01, -6.6145003e-02, -3.7128052e-01, + 7.5784922e-01, 3.5759725e-02, 1.9211160e+00, -8.1603736e-01, + -1.8293047e+00, -2.1316205e-01, -3.6369815e-01, 6.4205879e-01, + 1.2653192e+00, 5.8872145e-01, 3.1535852e-01, 3.5038650e-01, + -3.6469769e-01, -8.7751287e-01, 2.7995768e-01, -1.6042528e+00, + -1.7991974e+00, -6.8652731e-01, 1.3729302e-03, -7.7775210e-01, + ]; + testSlice( + inputShape, inputData, starts, sizes, undefined, expectedShape, + expected); + }); + + it('slice with negative starts', function() { + const inputShape = [3, 4, 5]; + const inputData = [ + 1.3165863e+00, 4.1239005e-02, 4.6697399e-01, -6.6145003e-02, + -3.7128052e-01, -1.0660021e+00, 7.5784922e-01, 3.5759725e-02, + 1.9211160e+00, -8.1603736e-01, 1.1800343e-01, -1.8293047e+00, + -2.1316205e-01, -3.6369815e-01, 6.4205879e-01, 7.1544610e-02, + 6.8498695e-01, 1.0001093e+00, -5.6261641e-01, -7.3343945e-01, + 1.6827687e+00, 1.2653192e+00, 5.8872145e-01, 3.1535852e-01, + 3.5038650e-01, 3.5865438e-01, -3.6469769e-01, -8.7751287e-01, + 2.7995768e-01, -1.6042528e+00, 8.6336482e-01, -1.7991974e+00, + -6.8652731e-01, 1.3729302e-03, -7.7775210e-01, 1.0199220e-01, + 4.2299256e-01, 1.1432177e-01, -5.0116669e-02, 1.5525131e+00, + -8.7060851e-01, 4.5739245e-01, 1.3543987e-01, -1.5927458e-02, + 9.1792661e-01, -4.5001405e-01, 1.9954188e-01, -5.1338053e-01, + -4.1026011e-01, -1.2718531e+00, 4.2538303e-01, -1.5449624e-01, + -3.4380481e-01, 7.8374326e-01, 1.7837452e+00, 9.6105379e-01, + -4.8783422e-01, -9.4987392e-01, -8.8750905e-01, -9.8019439e-01, + ]; + const starts = [-3, -4, -4]; + const sizes = [2, 3, 4]; + const expectedShape = [2, 3, 4]; + const expected = [ + 4.1239005e-02, 4.6697399e-01, -6.6145003e-02, -3.7128052e-01, + 7.5784922e-01, 3.5759725e-02, 1.9211160e+00, -8.1603736e-01, + -1.8293047e+00, -2.1316205e-01, -3.6369815e-01, 6.4205879e-01, + 1.2653192e+00, 5.8872145e-01, 3.1535852e-01, 3.5038650e-01, + -3.6469769e-01, -8.7751287e-01, 2.7995768e-01, -1.6042528e+00, + -1.7991974e+00, -6.8652731e-01, 1.3729302e-03, -7.7775210e-01, + ]; + testSlice( + inputShape, inputData, starts, sizes, undefined, expectedShape, + expected); + }); + + it('slice with axes', function() { + const inputShape = [3, 4, 5]; + const inputData = [ + 1.3165863e+00, 4.1239005e-02, 4.6697399e-01, -6.6145003e-02, + -3.7128052e-01, -1.0660021e+00, 7.5784922e-01, 3.5759725e-02, + 1.9211160e+00, -8.1603736e-01, 1.1800343e-01, -1.8293047e+00, + -2.1316205e-01, -3.6369815e-01, 6.4205879e-01, 7.1544610e-02, + 6.8498695e-01, 1.0001093e+00, -5.6261641e-01, -7.3343945e-01, + 1.6827687e+00, 1.2653192e+00, 5.8872145e-01, 3.1535852e-01, + 3.5038650e-01, 3.5865438e-01, -3.6469769e-01, -8.7751287e-01, + 2.7995768e-01, -1.6042528e+00, 8.6336482e-01, -1.7991974e+00, + -6.8652731e-01, 1.3729302e-03, -7.7775210e-01, 1.0199220e-01, + 4.2299256e-01, 1.1432177e-01, -5.0116669e-02, 1.5525131e+00, + -8.7060851e-01, 4.5739245e-01, 1.3543987e-01, -1.5927458e-02, + 9.1792661e-01, -4.5001405e-01, 1.9954188e-01, -5.1338053e-01, + -4.1026011e-01, -1.2718531e+00, 4.2538303e-01, -1.5449624e-01, + -3.4380481e-01, 7.8374326e-01, 1.7837452e+00, 9.6105379e-01, + -4.8783422e-01, -9.4987392e-01, -8.8750905e-01, -9.8019439e-01, + ]; + const starts = [0, 1]; + const sizes = [2, 4]; + const axes = [0, 2]; + const expectedShape = [2, 4, 4]; + const expected = [ + 4.1239005e-02, 4.6697399e-01, -6.6145003e-02, -3.7128052e-01, + 7.5784922e-01, 3.5759725e-02, 1.9211160e+00, -8.1603736e-01, + -1.8293047e+00, -2.1316205e-01, -3.6369815e-01, 6.4205879e-01, + 6.8498695e-01, 1.0001093e+00, -5.6261641e-01, -7.3343945e-01, + 1.2653192e+00, 5.8872145e-01, 3.1535852e-01, 3.5038650e-01, + -3.6469769e-01, -8.7751287e-01, 2.7995768e-01, -1.6042528e+00, + -1.7991974e+00, -6.8652731e-01, 1.3729302e-03, -7.7775210e-01, + 4.2299256e-01, 1.1432177e-01, -5.0116669e-02, 1.5525131e+00, + ]; + testSlice( + inputShape, inputData, starts, sizes, axes, expectedShape, expected); + }); + + it('slice with negative axes', function() { + const inputShape = [3, 4, 5]; + const inputData = [ + 1.3165863e+00, 4.1239005e-02, 4.6697399e-01, -6.6145003e-02, + -3.7128052e-01, -1.0660021e+00, 7.5784922e-01, 3.5759725e-02, + 1.9211160e+00, -8.1603736e-01, 1.1800343e-01, -1.8293047e+00, + -2.1316205e-01, -3.6369815e-01, 6.4205879e-01, 7.1544610e-02, + 6.8498695e-01, 1.0001093e+00, -5.6261641e-01, -7.3343945e-01, + 1.6827687e+00, 1.2653192e+00, 5.8872145e-01, 3.1535852e-01, + 3.5038650e-01, 3.5865438e-01, -3.6469769e-01, -8.7751287e-01, + 2.7995768e-01, -1.6042528e+00, 8.6336482e-01, -1.7991974e+00, + -6.8652731e-01, 1.3729302e-03, -7.7775210e-01, 1.0199220e-01, + 4.2299256e-01, 1.1432177e-01, -5.0116669e-02, 1.5525131e+00, + -8.7060851e-01, 4.5739245e-01, 1.3543987e-01, -1.5927458e-02, + 9.1792661e-01, -4.5001405e-01, 1.9954188e-01, -5.1338053e-01, + -4.1026011e-01, -1.2718531e+00, 4.2538303e-01, -1.5449624e-01, + -3.4380481e-01, 7.8374326e-01, 1.7837452e+00, 9.6105379e-01, + -4.8783422e-01, -9.4987392e-01, -8.8750905e-01, -9.8019439e-01, + ]; + const starts = [0, 1]; + const sizes = [2, 4]; + const axes = [-3, -1]; + const expectedShape = [2, 4, 4]; + const expected = [ + 4.1239005e-02, 4.6697399e-01, -6.6145003e-02, -3.7128052e-01, + 7.5784922e-01, 3.5759725e-02, 1.9211160e+00, -8.1603736e-01, + -1.8293047e+00, -2.1316205e-01, -3.6369815e-01, 6.4205879e-01, + 6.8498695e-01, 1.0001093e+00, -5.6261641e-01, -7.3343945e-01, + 1.2653192e+00, 5.8872145e-01, 3.1535852e-01, 3.5038650e-01, + -3.6469769e-01, -8.7751287e-01, 2.7995768e-01, -1.6042528e+00, + -1.7991974e+00, -6.8652731e-01, 1.3729302e-03, -7.7775210e-01, + 4.2299256e-01, 1.1432177e-01, -5.0116669e-02, 1.5525131e+00, + ]; + testSlice( + inputShape, inputData, starts, sizes, axes, expectedShape, expected); + }); + + it('slice with -1 sizes', function() { + const inputShape = [3, 4, 5]; + const inputData = [ + 1.3165863e+00, 4.1239005e-02, 4.6697399e-01, -6.6145003e-02, + -3.7128052e-01, -1.0660021e+00, 7.5784922e-01, 3.5759725e-02, + 1.9211160e+00, -8.1603736e-01, 1.1800343e-01, -1.8293047e+00, + -2.1316205e-01, -3.6369815e-01, 6.4205879e-01, 7.1544610e-02, + 6.8498695e-01, 1.0001093e+00, -5.6261641e-01, -7.3343945e-01, + 1.6827687e+00, 1.2653192e+00, 5.8872145e-01, 3.1535852e-01, + 3.5038650e-01, 3.5865438e-01, -3.6469769e-01, -8.7751287e-01, + 2.7995768e-01, -1.6042528e+00, 8.6336482e-01, -1.7991974e+00, + -6.8652731e-01, 1.3729302e-03, -7.7775210e-01, 1.0199220e-01, + 4.2299256e-01, 1.1432177e-01, -5.0116669e-02, 1.5525131e+00, + -8.7060851e-01, 4.5739245e-01, 1.3543987e-01, -1.5927458e-02, + 9.1792661e-01, -4.5001405e-01, 1.9954188e-01, -5.1338053e-01, + -4.1026011e-01, -1.2718531e+00, 4.2538303e-01, -1.5449624e-01, + -3.4380481e-01, 7.8374326e-01, 1.7837452e+00, 9.6105379e-01, + -4.8783422e-01, -9.4987392e-01, -8.8750905e-01, -9.8019439e-01, + ]; + const starts = [0, -4, 1]; + const sizes = [2, -1, 4]; + const axes = [0, 1, 2]; + const expectedShape = [2, 4, 4]; + const expected = [ + 4.1239005e-02, 4.6697399e-01, -6.6145003e-02, -3.7128052e-01, + 7.5784922e-01, 3.5759725e-02, 1.9211160e+00, -8.1603736e-01, + -1.8293047e+00, -2.1316205e-01, -3.6369815e-01, 6.4205879e-01, + 6.8498695e-01, 1.0001093e+00, -5.6261641e-01, -7.3343945e-01, + 1.2653192e+00, 5.8872145e-01, 3.1535852e-01, 3.5038650e-01, + -3.6469769e-01, -8.7751287e-01, 2.7995768e-01, -1.6042528e+00, + -1.7991974e+00, -6.8652731e-01, 1.3729302e-03, -7.7775210e-01, + 4.2299256e-01, 1.1432177e-01, -5.0116669e-02, 1.5525131e+00, + ]; + testSlice( + inputShape, inputData, starts, sizes, axes, expectedShape, expected); + }); +}); diff --git a/test/softmax_test.js b/test/softmax_test.js new file mode 100644 index 0000000..f081b32 --- /dev/null +++ b/test/softmax_test.js @@ -0,0 +1,41 @@ +'use strict'; + +import {softmax} from '../src/softmax.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test softmax', function() { + it('softmax', function() { + const x = new Tensor([3, 4], [ + 0.4301911, + 0.54719144, + -1.1637765, + 0.18390046, + 0.58390397, + 0.1735679, + 0.539724, + -0.953514, + -0.59202826, + -0.17344485, + 0.14395015, + -0.37920907, + ]); + const y = softmax(x); + utils.checkShape(y, [3, 4]); + const expected = [ + 0.3216537706230936, + 0.36157737612692964, + 0.06533370899961785, + 0.2514351442503589, + 0.35271572559067943, + 0.23400122550752556, + 0.33747196917113453, + 0.07581107973066051, + 0.17110128476868686, + 0.2600409450936177, + 0.35717794384660767, + 0.2116798262910878, + ]; + utils.checkValue(y, expected); + }); +}); diff --git a/test/split_test.js b/test/split_test.js new file mode 100644 index 0000000..738271e --- /dev/null +++ b/test/split_test.js @@ -0,0 +1,79 @@ +'use strict'; + +import {split} from '../src/split.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test split', function() { + function testSplit( + inputShape, inputValue, expectedArray, splits, axis = undefined) { + const x = new Tensor(inputShape, inputValue); + const options = {}; + if (axis !== undefined) { + options.axis = axis; + } + const splittedOutputs = split(x, splits, options); + for (let i = 0; i < splittedOutputs.length; ++i) { + utils.checkShape(splittedOutputs[i], expectedArray[i].shape); + utils.checkValue(splittedOutputs[i], expectedArray[i].value); + } + } + + it('split', function() { + testSplit( + [6], [1, 2, 3, 4, 5, 6], + [ + {shape: [2], value: [1, 2]}, + {shape: [2], value: [3, 4]}, + {shape: [2], value: [5, 6]}, + ], + 3); + testSplit( + [6], [1, 2, 3, 4, 5, 6], + [ + {shape: [2], value: [1, 2]}, + {shape: [2], value: [3, 4]}, + {shape: [2], value: [5, 6]}, + ], + 3, -1); + testSplit( + [6], [1, 2, 3, 4, 5, 6], + [{shape: [2], value: [1, 2]}, {shape: [4], value: [3, 4, 5, 6]}], + [2, 4]); + testSplit( + [6], [1, 2, 3, 4, 5, 6], + [{shape: [2], value: [1, 2]}, {shape: [4], value: [3, 4, 5, 6]}], + [2, 4], -1); + }); + + it('split 2d', function() { + testSplit( + [2, 6], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + [ + {shape: [2, 3], value: [1, 2, 3, 7, 8, 9]}, + {shape: [2, 3], value: [4, 5, 6, 10, 11, 12]}, + ], + 2, 1); + testSplit( + [2, 6], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + [ + {shape: [2, 3], value: [1, 2, 3, 7, 8, 9]}, + {shape: [2, 3], value: [4, 5, 6, 10, 11, 12]}, + ], + 2, -1); + testSplit( + [2, 6], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + [ + {shape: [2, 2], value: [1, 2, 7, 8]}, + {shape: [2, 4], value: [3, 4, 5, 6, 9, 10, 11, 12]}, + ], + [2, 4], 1); + testSplit( + [2, 6], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + [ + {shape: [2, 2], value: [1, 2, 7, 8]}, + {shape: [2, 4], value: [3, 4, 5, 6, 9, 10, 11, 12]}, + ], + [2, 4], -1); + }); +}); diff --git a/test/squeeze_test.js b/test/squeeze_test.js new file mode 100644 index 0000000..e2103be --- /dev/null +++ b/test/squeeze_test.js @@ -0,0 +1,35 @@ +'use strict'; + +import {squeeze} from '../src/squeeze.js'; +import {Tensor, sizeOfShape} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test squeeze', function() { + function testSqueeze(oldShape, axes, expectedShape) { + const bufferSize = sizeOfShape(oldShape); + const inputBuffer = new Array(bufferSize); + for (let i = 0; i < inputBuffer.length; ++i) { + inputBuffer[i] = Math.random(); + } + const x = new Tensor(oldShape, inputBuffer); + const y = squeeze(x, {axes}); + utils.checkShape(y, expectedShape); + utils.checkValue(y, inputBuffer); + } + + it('squeeze one dimension by default', function() { + testSqueeze([1, 3, 4, 5], undefined, [3, 4, 5]); + }); + + it('squeeze one dimension with axes', function() { + testSqueeze([1, 3, 1, 5], [0], [3, 1, 5]); + }); + + it('squeeze two dimensions by default', function() { + testSqueeze([1, 3, 1, 5], undefined, [3, 5]); + }); + + it('squeeze two dimensions with axes', function() { + testSqueeze([1, 3, 1, 5], [0, 2], [3, 5]); + }); +}); diff --git a/test/tanh_test.js b/test/tanh_test.js new file mode 100644 index 0000000..1d44aa2 --- /dev/null +++ b/test/tanh_test.js @@ -0,0 +1,99 @@ +'use strict'; + +import {tanh} from '../src/tanh.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test tanh', function() { + function testTanh(input, expected, shape) { + const x = new Tensor(shape, input); + const y = tanh(x); + utils.checkShape(y, shape); + utils.checkValue(y, expected); + } + + it('tanh 1d', function() { + testTanh([-1, 0, 1], [-0.7615941559557649, 0, 0.7615941559557649], [3]); + }); + + it('tanh 3d', function() { + testTanh( + [ + 0.15102264, -1.1556778, -0.0657572, -0.04362043, 1.13937, + 0.5458485, -1.1451102, 0.3929889, 0.56226826, -0.68606883, + 0.46685237, -0.53841704, 0.7025275, -1.5314125, 0.28699, + 0.84823394, -0.18585628, -0.319641, 0.41442505, 0.88782656, + 1.0844846, -0.56016934, 0.531165, 0.73836696, 1.0364187, + -0.07221687, -0.9580888, 1.8173703, -1.5682113, -1.272829, + 2.331454, 0.2967249, 0.21472701, -0.9332915, 2.3962052, + 0.498327, 0.53040606, 1.6241137, 0.8147571, -0.6471784, + 0.8489049, -0.33946696, -0.67703784, -0.07758674, 0.7667829, + 0.58996105, 0.7728692, -0.47817922, 2.1541011, -1.1611695, + 2.1465113, 0.64678246, 1.239878, -0.10861816, 0.07814338, + -1.026162, -0.8464255, 0.53589034, 0.93667775, 1.2927296, + ], + [ + 0.14988485243804905, + -0.8196262968070427, + -0.06566258539315876, + -0.04359278490025841, + 0.81420184519224, + 0.4974022861162581, + -0.8161277031478561, + 0.37393406888200575, + 0.5096584488637819, + -0.5954506103157979, + 0.43565258419537356, + -0.491788789926786, + 0.6059696311244674, + -0.9106660002679733, + 0.279362062510579, + 0.6901457130337738, + -0.18374545679834775, + -0.30918227539202525, + 0.3922234692633119, + 0.7103185679032363, + 0.7948562387257937, + -0.5081030654144844, + 0.48627111897365854, + 0.628157503380644, + 0.776469901729632, + -0.07209158770240355, + -0.7434231595576823, + 0.9485755721878665, + -0.916740776910177, + -0.854562546243227, + 0.9812985744347655, + 0.2883125958646078, + 0.21148657183129607, + -0.7321248135617321, + 0.9835515078926846, + 0.4608004134276178, + 0.485691423797799, + 0.9252187290148373, + 0.6722061563570123, + -0.569767419642911, + 0.6904969313562962, + -0.3270014357504291, + -0.5895903065238024, + -0.07743143093124916, + 0.6450548687979449, + 0.5298675936947264, + 0.6485947437799567, + -0.4447842241322541, + 0.9734419663341614, + -0.8214206480085401, + 0.973041226248526, + 0.5694999552658558, + 0.8454207971475833, + -0.10819301066972077, + 0.07798470962146853, + -0.7723644708905993, + -0.6891974525265244, + 0.4898708088154638, + 0.7336921191412772, + 0.8598397504927623, + ], + [3, 4, 5]); + }); +}); diff --git a/test/transpose_test.js b/test/transpose_test.js new file mode 100644 index 0000000..cb5ae2a --- /dev/null +++ b/test/transpose_test.js @@ -0,0 +1,89 @@ +'use strict'; + +import {Tensor} from '../src/lib/tensor.js'; +import {transpose} from '../src/transpose.js'; +import * as utils from './utils.js'; + +describe('test transpose', function() { + function checkTranspose( + inputShape, inputData, expectedShape, expected, permutation = undefined) { + const inputTensor = new Tensor(inputShape, inputData); + const outputTensor = transpose(inputTensor, {permutation}); + utils.checkShape(outputTensor, expectedShape); + utils.checkValue(outputTensor, expected); + } + + it('transpose default', function() { + const inputShape = [2, 3, 4]; + const inputData = [ + 0.43376675, 0.264609, 0.26321858, 0.04260185, 0.6862414, 0.26150206, + 0.04169406, 0.24857993, 0.14914423, 0.19905873, 0.33851373, 0.74131566, + 0.91501445, 0.21852633, 0.02267954, 0.22069663, 0.95799077, 0.17188412, + 0.09732241, 0.03296741, 0.04709655, 0.50648814, 0.13075736, 0.82511896, + ]; + const expected = [ + 0.43376675, 0.91501445, 0.6862414, 0.95799077, 0.14914423, 0.04709655, + 0.264609, 0.21852633, 0.26150206, 0.17188412, 0.19905873, 0.50648814, + 0.26321858, 0.02267954, 0.04169406, 0.09732241, 0.33851373, 0.13075736, + 0.04260185, 0.22069663, 0.24857993, 0.03296741, 0.74131566, 0.82511896, + ]; + checkTranspose(inputShape, inputData, [4, 3, 2], expected); + }); + + it('transpose permutations', function() { + const permutations = + [[0, 1, 2], [0, 2, 1], [1, 0, 2], [1, 2, 0], [2, 0, 1], [2, 1, 0]]; + const inputShape = [2, 3, 4]; + const inputData = [ + 0.7760998, 0.8363521, 0.10145967, 0.00533229, 0.8190919, 0.83241564, + 0.39479077, 0.5622921, 0.9306249, 0.00480607, 0.39600816, 0.35415828, + 0.43689877, 0.7603583, 0.14368972, 0.11940759, 0.4834097, 0.6982117, + 0.7195266, 0.72893023, 0.896649, 0.13060148, 0.07824122, 0.33766487, + ]; + const expectedShapes = + [[2, 3, 4], [2, 4, 3], [3, 2, 4], [3, 4, 2], [4, 2, 3], [4, 3, 2]]; + const expecteds = [ + [ + 0.7760998, 0.8363521, 0.10145967, 0.00533229, 0.8190919, 0.83241564, + 0.39479077, 0.5622921, 0.9306249, 0.00480607, 0.39600816, 0.35415828, + 0.43689877, 0.7603583, 0.14368972, 0.11940759, 0.4834097, 0.6982117, + 0.7195266, 0.72893023, 0.896649, 0.13060148, 0.07824122, 0.33766487, + ], + [ + 0.7760998, 0.8190919, 0.9306249, 0.8363521, 0.83241564, 0.00480607, + 0.10145967, 0.39479077, 0.39600816, 0.00533229, 0.5622921, 0.35415828, + 0.43689877, 0.4834097, 0.896649, 0.7603583, 0.6982117, 0.13060148, + 0.14368972, 0.7195266, 0.07824122, 0.11940759, 0.72893023, 0.33766487, + ], + [ + 0.7760998, 0.8363521, 0.10145967, 0.00533229, 0.43689877, 0.7603583, + 0.14368972, 0.11940759, 0.8190919, 0.83241564, 0.39479077, 0.5622921, + 0.4834097, 0.6982117, 0.7195266, 0.72893023, 0.9306249, 0.00480607, + 0.39600816, 0.35415828, 0.896649, 0.13060148, 0.07824122, 0.33766487, + ], + [ + 0.7760998, 0.43689877, 0.8363521, 0.7603583, 0.10145967, 0.14368972, + 0.00533229, 0.11940759, 0.8190919, 0.4834097, 0.83241564, 0.6982117, + 0.39479077, 0.7195266, 0.5622921, 0.72893023, 0.9306249, 0.896649, + 0.00480607, 0.13060148, 0.39600816, 0.07824122, 0.35415828, 0.33766487, + ], + [ + 0.7760998, 0.8190919, 0.9306249, 0.43689877, 0.4834097, 0.896649, + 0.8363521, 0.83241564, 0.00480607, 0.7603583, 0.6982117, 0.13060148, + 0.10145967, 0.39479077, 0.39600816, 0.14368972, 0.7195266, 0.07824122, + 0.00533229, 0.5622921, 0.35415828, 0.11940759, 0.72893023, 0.33766487, + ], + [ + 0.7760998, 0.43689877, 0.8190919, 0.4834097, 0.9306249, 0.896649, + 0.8363521, 0.7603583, 0.83241564, 0.6982117, 0.00480607, 0.13060148, + 0.10145967, 0.14368972, 0.39479077, 0.7195266, 0.39600816, 0.07824122, + 0.00533229, 0.11940759, 0.5622921, 0.72893023, 0.35415828, 0.33766487, + ], + ]; + for (let i = 0; i < permutations.length; ++i) { + checkTranspose( + inputShape, inputData, expectedShapes[i], expecteds[i], + permutations[i]); + } + }); +}); diff --git a/test/unary_test.js b/test/unary_test.js new file mode 100644 index 0000000..be8899a --- /dev/null +++ b/test/unary_test.js @@ -0,0 +1,739 @@ +'use strict'; + +import {Tensor} from '../src/lib/tensor.js'; +import * as unaryFunctions from '../src/unary.js'; +import * as utils from './utils.js'; + +describe('test unary', function() { + function testUnary(op, input, expected, shape) { + const x = new Tensor(shape, input); + const y = unaryFunctions[op](x); + utils.checkShape(y, shape); + utils.checkValue(y, expected); + } + + it('abs', function() { + testUnary('abs', [-1, 0, 1], [1, 0, 1], [3]); + testUnary( + 'abs', + [-1.1, 0, 1.1, 2.2, 0, -2.2], + [1.1, 0, 1.1, 2.2, 0, 2.2], + [2, 3]); + testUnary( + 'abs', + [-1.1, 0, 1.1, 2.2, 0, -2.2], + [1.1, 0, 1.1, 2.2, 0, 2.2], + [1, 2, 3]); + testUnary( + 'abs', + [-1.1, 0, 1.1, 2.2, 0, -2.2], + [1.1, 0, 1.1, 2.2, 0, 2.2], + [1, 2, 3, 1]); + }); + + it('ceil', function() { + testUnary('ceil', [-1.1, 0, 1.1], [-1, 0, 2], [3]); + testUnary( + 'ceil', + [-1.1, 0, 1.1, -2.2, 0, 2.2], + [-1, 0, 2, -2, 0, 3], + [2, 3]); + testUnary( + 'ceil', + [-1.1, 0, 1.1, -2.2, 0, 2.2], + [-1, 0, 2, -2, 0, 3], + [1, 2, 3]); + testUnary( + 'ceil', + [-1.1, 0, 1.1, -2.2, 0, 2.2], + [-1, 0, 2, -2, 0, 3], + [1, 2, 3, 1]); + }); + + it('cos', function() { + testUnary( + 'cos', + [1.4124068, 1.9740626, -0.06506752, 0.73539704], + [ + 0.15772809760857773, + -0.39242469654349826, + 0.9978838556864368, + 0.7415644450136674, + ], + [4]); + testUnary( + 'cos', + [ + 1.4124068, 1.9740626, -0.06506752, 0.73539704, + -0.56439203, 0.89806247, 0.12939146, -0.34816208, + -1.0759926, 0.66291636, 0.21504708, -0.71527237, + ], + [ + 0.15772809760857773, + -0.39242469654349826, + 0.9978838556864368, + 0.7415644450136674, + 0.8449139610653698, + 0.6231265199397442, + 0.9916405976730124, + 0.9400013446543031, + 0.4748589278431268, + 0.7882008050685133, + 0.9769663487271947, + 0.7549146426895217, + ], + [3, 4]); + testUnary( + 'cos', + [ + 1.4124068, 1.9740626, + -0.06506752, 0.73539704, + -0.56439203, 0.89806247, + 0.12939146, -0.34816208, + -1.0759926, 0.66291636, + 0.21504708, -0.71527237, + ], + [ + 0.15772809760857773, + -0.39242469654349826, + 0.9978838556864368, + 0.7415644450136674, + 0.8449139610653698, + 0.6231265199397442, + 0.9916405976730124, + 0.9400013446543031, + 0.4748589278431268, + 0.7882008050685133, + 0.9769663487271947, + 0.7549146426895217, + ], + [3, 2, 2]); + testUnary( + 'cos', + [ + 1.4124068, + 1.9740626, + -0.06506752, + 0.73539704, + -0.56439203, + 0.89806247, + 0.12939146, + -0.34816208, + -1.0759926, + 0.66291636, + 0.21504708, + -0.71527237, + ], + [ + 0.15772809760857773, + -0.39242469654349826, + 0.9978838556864368, + 0.7415644450136674, + 0.8449139610653698, + 0.6231265199397442, + 0.9916405976730124, + 0.9400013446543031, + 0.4748589278431268, + 0.7882008050685133, + 0.9769663487271947, + 0.7549146426895217, + ], + [3, 2, 2, 1]); + }); + + it('exp', function() { + testUnary('exp', [-1, 0, 1], [0.36787944117144233, 1, 2.718281828459045], [3]); + testUnary( + 'exp', + [ + 0.3143407, 0.03632548, 0.5354084, -0.5000897, + 1.2028517, -1.2581364, -1.5108215, -1.2340564, + 1.3860914, -0.2944251, -1.5065757, -0.4673513, + ], + [ + 1.3693561967375985, + 1.036993312152022, + 1.708145706528859, + 0.6064762563524844, + 3.3295984129224556, + 0.2841831370156004, + 0.22072857493397907, + 0.29110932356722374, + 3.999188237901296, + 0.7449597417366962, + 0.2216677366530061, + 0.6266599061142515, + ], + [3, 4]); + testUnary( + 'exp', + [ + 0.3143407, 0.03632548, 0.5354084, -0.5000897, 1.2028517, + -1.2581364, -1.5108215, -1.2340564, 1.3860914, -0.2944251, + -1.5065757, -0.4673513, 0.56616277, 0.77866685, -0.01097398, + 1.0758846, 0.6035437, 0.36806744, 0.03906458, -0.54385495, + 0.10609569, -0.40644982, -1.2890846, 1.3825086, 0.51489764, + 1.6407244, -0.67886734, -0.6556329, 1.0399923, 0.1484657, + 1.011217, 0.8451463, 0.75473833, -2.0161264, 1.6406634, + -0.01692923, -0.7986609, 0.97758174, 0.893054, -0.01632686, + -1.9721986, -0.75843745, 0.42327842, -0.08648382, -1.3960054, + 0.7547995, -0.42002508, -1.784105, 1.0171342, 0.3634587, + 0.4158588, -1.0103701, -0.23202766, 0.6390487, -0.22796124, + 0.11259284, 0.3690759, -0.18703128, 0.07711394, 2.9116163, + ], + [ + 1.3693561967375985, + 1.036993312152022, + 1.708145706528859, + 0.6064762563524844, + 3.3295984129224556, + 0.2841831370156004, + 0.22072857493397907, + 0.29110932356722374, + 3.999188237901296, + 0.7449597417366962, + 0.2216677366530061, + 0.6266599061142515, + 1.7614948056979465, + 2.1785659734395244, + 0.9890860144586422, + 2.9325859189764807, + 1.828587297220383, + 1.4449394824067432, + 1.0398376341962599, + 0.580506111445723, + 1.111928271832297, + 0.666010515185213, + 0.27552288133235436, + 3.984885583357506, + 1.6734671969280184, + 5.158905269957428, + 0.5071911422650927, + 0.5191134117181747, + 2.829195229464421, + 1.1600530072738744, + 2.7489444455169916, + 2.328318422639379, + 2.127054864081674, + 0.13317031580877914, + 5.158590586333908, + 0.9832132641755142, + 0.4499310635787923, + 2.6580206785468157, + 2.4425779049391343, + 0.983805700764556, + 0.13915058318029658, + 0.4683977504013186, + 1.526959372817342, + 0.9171503881579351, + 0.2475839902489274, + 2.127184980007265, + 0.6570303412874574, + 0.16794730659465534, + 2.765258719399421, + 1.4382954540763537, + 1.5156718408966148, + 0.36408420706837397, + 0.7929241907202421, + 1.8946776149031732, + 0.7961551182097963, + 1.119176156404052, + 1.446397381069858, + 0.8294177917911054, + 1.0801651433977786, + 18.38649265135158, + ], + [3, 4, 5]); + testUnary( + 'exp', + [ + 0.3143407, 0.03632548, 0.5354084, -0.5000897, 1.2028517, + -1.2581364, -1.5108215, -1.2340564, 1.3860914, -0.2944251, + -1.5065757, -0.4673513, 0.56616277, 0.77866685, -0.01097398, + 1.0758846, 0.6035437, 0.36806744, 0.03906458, -0.54385495, + 0.10609569, -0.40644982, -1.2890846, 1.3825086, 0.51489764, + 1.6407244, -0.67886734, -0.6556329, 1.0399923, 0.1484657, + 1.011217, 0.8451463, 0.75473833, -2.0161264, 1.6406634, + -0.01692923, -0.7986609, 0.97758174, 0.893054, -0.01632686, + -1.9721986, -0.75843745, 0.42327842, -0.08648382, -1.3960054, + 0.7547995, -0.42002508, -1.784105, 1.0171342, 0.3634587, + 0.4158588, -1.0103701, -0.23202766, 0.6390487, -0.22796124, + 0.11259284, 0.3690759, -0.18703128, 0.07711394, 2.9116163, + ], + [ + 1.3693561967375985, + 1.036993312152022, + 1.708145706528859, + 0.6064762563524844, + 3.3295984129224556, + 0.2841831370156004, + 0.22072857493397907, + 0.29110932356722374, + 3.999188237901296, + 0.7449597417366962, + 0.2216677366530061, + 0.6266599061142515, + 1.7614948056979465, + 2.1785659734395244, + 0.9890860144586422, + 2.9325859189764807, + 1.828587297220383, + 1.4449394824067432, + 1.0398376341962599, + 0.580506111445723, + 1.111928271832297, + 0.666010515185213, + 0.27552288133235436, + 3.984885583357506, + 1.6734671969280184, + 5.158905269957428, + 0.5071911422650927, + 0.5191134117181747, + 2.829195229464421, + 1.1600530072738744, + 2.7489444455169916, + 2.328318422639379, + 2.127054864081674, + 0.13317031580877914, + 5.158590586333908, + 0.9832132641755142, + 0.4499310635787923, + 2.6580206785468157, + 2.4425779049391343, + 0.983805700764556, + 0.13915058318029658, + 0.4683977504013186, + 1.526959372817342, + 0.9171503881579351, + 0.2475839902489274, + 2.127184980007265, + 0.6570303412874574, + 0.16794730659465534, + 2.765258719399421, + 1.4382954540763537, + 1.5156718408966148, + 0.36408420706837397, + 0.7929241907202421, + 1.8946776149031732, + 0.7961551182097963, + 1.119176156404052, + 1.446397381069858, + 0.8294177917911054, + 1.0801651433977786, + 18.38649265135158, + ], + [3, 2, 2, 5]); + }); + + it('floor', function() { + testUnary('floor', [-1.1, 0, 1.1], [-2, 0, 1], [3]); + testUnary( + 'floor', + [-1.1, 0, 1.1, -2.2, 0, 2.2], + [-2, 0, 1, -3, 0, 2], + [2, 3]); + testUnary( + 'floor', + [-1.1, 0, 1.1, -2.2, 0, 2.2], + [-2, 0, 1, -3, 0, 2], + [1, 2, 3]); + testUnary( + 'floor', + [-1.1, 0, 1.1, -2.2, 0, 2.2], + [-2, 0, 1, -3, 0, 2], + [1, 2, 3, 1]); + }); + + it('log', function() { + testUnary( + 'log', + [1.4599811, 0.34325936, 1.0420732], + [ + 0.37842349043097573, + -1.0692689659512902, + 0.04121219038394666, + ], + [3]); + testUnary( + 'log', + [ + 1.4599811, 0.34325936, 1.0420732, 0.10867598, + 0.39999306, 0.03704359, 1.5873954, 0.44784936, + 0.69070333, 1.8561625, 1.4088289, 0.06367786, + ], + [ + 0.37842349043097573, + -1.0692689659512902, + 0.04121219038394666, + -2.219384484434575, + -0.9163080820246681, + -3.2956599516545957, + 0.4620945598501042, + -0.8032983531118589, + -0.3700448817029436, + 0.618511184410702, + 0.34275879190200165, + -2.753918343538326, + ], + [3, 4]); + testUnary( + 'log', + [ + 1.4599811, 0.34325936, 1.0420732, 0.10867598, 0.39999306, + 0.03704359, 1.5873954, 0.44784936, 0.69070333, 1.8561625, + 1.4088289, 0.06367786, 0.32938832, 1.2429568, 1.1544572, + 0.47578564, 1.868428, 1.2279319, 1.0712656, 1.17982, + 1.460244, 0.62389, 0.79644215, 0.4196875, 0.372386, + 1.8887448, 1.4791015, 0.98091763, 0.45482925, 0.50871295, + 0.11605832, 0.86883324, 0.6235918, 1.392687, 0.75550365, + 0.35920736, 0.04935746, 0.13449927, 1.3587855, 0.9073937, + 1.0731584, 1.7933426, 1.9806778, 0.43379396, 1.3261564, + 0.52664477, 0.041302, 1.5167572, 0.6400343, 0.7669278, + 1.1766342, 1.6620969, 1.2579637, 1.7453014, 0.5470841, + 1.5960937, 0.37127188, 1.9055833, 1.3749765, 0.43101534, + ], + [ + 0.37842349043097573, + -1.0692689659512902, + 0.04121219038394666, + -2.219384484434575, + -0.9163080820246681, + -3.2956599516545957, + 0.4620945598501042, + -0.8032983531118589, + -0.3700448817029436, + 0.618511184410702, + 0.34275879190200165, + -2.753918343538326, + -1.1105179202764903, + 0.21749305729871296, + 0.1436302767995352, + -0.7427878623169412, + 0.6250974356178759, + 0.2053313721611498, + 0.0688407532508926, + 0.16536188446892103, + 0.3786035450443755, + -0.47178120820349856, + -0.22760078252711477, + -0.8682448922645803, + -0.987824328270859, + 0.6359124814574089, + 0.3914348088248874, + -0.019266788283548882, + -0.7878332051896428, + -0.6758713704300671, + -2.1536624555958537, + -0.14060407086584029, + -0.4722592913397491, + 0.3312349746469, + -0.2803706660434222, + -1.023855452786281, + -3.0086663593800362, + -2.006196507464266, + 0.306591286066901, + -0.09717885469019541, + 0.0706060762388415, + 0.5840812527784782, + 0.6834391093595378, + -0.835185604153331, + 0.2822848335262163, + -0.6412290184429051, + -3.1868443540375373, + 0.41657463482092677, + -0.4462335103145133, + -0.26536261503132674, + 0.162657989828439, + 0.5080799979827856, + 0.22949430253601164, + 0.5569272628026896, + -0.6031527406633165, + 0.467559206577478, + -0.990820654574951, + 0.644788155936501, + 0.31843664006339245, + -0.8416115978644251, + ], + [3, 4, 5]); + testUnary( + 'log', + [ + 1.4599811, 0.34325936, 1.0420732, 0.10867598, 0.39999306, + 0.03704359, 1.5873954, 0.44784936, 0.69070333, 1.8561625, + 1.4088289, 0.06367786, 0.32938832, 1.2429568, 1.1544572, + 0.47578564, 1.868428, 1.2279319, 1.0712656, 1.17982, + 1.460244, 0.62389, 0.79644215, 0.4196875, 0.372386, + 1.8887448, 1.4791015, 0.98091763, 0.45482925, 0.50871295, + 0.11605832, 0.86883324, 0.6235918, 1.392687, 0.75550365, + 0.35920736, 0.04935746, 0.13449927, 1.3587855, 0.9073937, + 1.0731584, 1.7933426, 1.9806778, 0.43379396, 1.3261564, + 0.52664477, 0.041302, 1.5167572, 0.6400343, 0.7669278, + 1.1766342, 1.6620969, 1.2579637, 1.7453014, 0.5470841, + 1.5960937, 0.37127188, 1.9055833, 1.3749765, 0.43101534, + ], + [ + 0.37842349043097573, + -1.0692689659512902, + 0.04121219038394666, + -2.219384484434575, + -0.9163080820246681, + -3.2956599516545957, + 0.4620945598501042, + -0.8032983531118589, + -0.3700448817029436, + 0.618511184410702, + 0.34275879190200165, + -2.753918343538326, + -1.1105179202764903, + 0.21749305729871296, + 0.1436302767995352, + -0.7427878623169412, + 0.6250974356178759, + 0.2053313721611498, + 0.0688407532508926, + 0.16536188446892103, + 0.3786035450443755, + -0.47178120820349856, + -0.22760078252711477, + -0.8682448922645803, + -0.987824328270859, + 0.6359124814574089, + 0.3914348088248874, + -0.019266788283548882, + -0.7878332051896428, + -0.6758713704300671, + -2.1536624555958537, + -0.14060407086584029, + -0.4722592913397491, + 0.3312349746469, + -0.2803706660434222, + -1.023855452786281, + -3.0086663593800362, + -2.006196507464266, + 0.306591286066901, + -0.09717885469019541, + 0.0706060762388415, + 0.5840812527784782, + 0.6834391093595378, + -0.835185604153331, + 0.2822848335262163, + -0.6412290184429051, + -3.1868443540375373, + 0.41657463482092677, + -0.4462335103145133, + -0.26536261503132674, + 0.162657989828439, + 0.5080799979827856, + 0.22949430253601164, + 0.5569272628026896, + -0.6031527406633165, + 0.467559206577478, + -0.990820654574951, + 0.644788155936501, + 0.31843664006339245, + -0.8416115978644251, + ], + [3, 2, 2, 5]); + }); + + it('neg', function() { + testUnary('neg', [-1.1, 0, 1.1], [1.1, -0, -1.1], [3]); + testUnary( + 'neg', + [-1, 0, 1.1, -2.2, 0, 2], + [1, -0, -1.1, 2.2, -0, -2], + [2, 3]); + testUnary( + 'neg', + [-1, 0, 1.1, -2.2, 0, 2], + [1, -0, -1.1, 2.2, -0, -2], + [1, 2, 3]); + testUnary( + 'neg', + [-1, 0, 1.1, -2.2, 0, 2], + [1, -0, -1.1, 2.2, -0, -2], + [1, 2, 3, 1]); + }); + + it('sin', function() { + testUnary( + 'sin', + [1.4124068, 1.9740626, -0.06506752, 0.73539704], + [ + 0.9874825807196697, + 0.9197841363835013, + -0.06502161610088251, + 0.6708816392565617, + ], + [4]); + testUnary( + 'sin', + [ + 1.4124068, 1.9740626, -0.06506752, 0.73539704, + -0.56439203, 0.89806247, 0.12939146, -0.34816208, + -1.0759926, 0.66291636, 0.21504708, -0.71527237, + ], + [ + 0.9874825807196697, + 0.9197841363835013, + -0.06502161610088251, + 0.6708816392565617, + -0.5349022325592097, + 0.7821210521062475, + 0.12903071357901844, + -0.34117073738540654, + -0.8800619288707335, + 0.6154181431265636, + 0.21339342411295945, + -0.6558230571220807, + ], + [3, 4]); + testUnary( + 'sin', + [ + 1.4124068, 1.9740626, + -0.06506752, 0.73539704, + -0.56439203, 0.89806247, + 0.12939146, -0.34816208, + -1.0759926, 0.66291636, + 0.21504708, -0.71527237, + ], + [ + 0.9874825807196697, + 0.9197841363835013, + -0.06502161610088251, + 0.6708816392565617, + -0.5349022325592097, + 0.7821210521062475, + 0.12903071357901844, + -0.34117073738540654, + -0.8800619288707335, + 0.6154181431265636, + 0.21339342411295945, + -0.6558230571220807, + ], + [3, 2, 2]); + testUnary( + 'sin', + [ + 1.4124068, + 1.9740626, + -0.06506752, + 0.73539704, + -0.56439203, + 0.89806247, + 0.12939146, + -0.34816208, + -1.0759926, + 0.66291636, + 0.21504708, + -0.71527237, + ], + [ + 0.9874825807196697, + 0.9197841363835013, + -0.06502161610088251, + 0.6708816392565617, + -0.5349022325592097, + 0.7821210521062475, + 0.12903071357901844, + -0.34117073738540654, + -0.8800619288707335, + 0.6154181431265636, + 0.21339342411295945, + -0.6558230571220807, + ], + [3, 2, 2, 1]); + }); + + it('tan', function() { + testUnary( + 'tan', + [1.4124068, 1.9740626, -0.06506752, 0.73539704], + [ + 6.260663735197218, + -2.3438487548949354, + -0.06515950301265735, + 0.9046842034668978, + ], + [4]); + testUnary( + 'tan', + [ + 1.4124068, 1.9740626, -0.06506752, 0.73539704, + -0.56439203, 0.89806247, 0.12939146, -0.34816208, + -1.0759926, 0.66291636, 0.21504708, -0.71527237, + ], + [ + 6.260663735197218, + -2.3438487548949354, + -0.06515950301265735, + 0.9046842034668978, + -0.633084855036293, + 1.2551560992491184, + 0.1301184258508601, + -0.3629470737734702, + -1.8533123782006025, + 0.7807885239004154, + 0.2184245387683735, + -0.8687380268391548, + ], + [3, 4]); + testUnary( + 'tan', + [ + 1.4124068, 1.9740626, + -0.06506752, 0.73539704, + -0.56439203, 0.89806247, + 0.12939146, -0.34816208, + -1.0759926, 0.66291636, + 0.21504708, -0.71527237, + ], + [ + 6.260663735197218, + -2.3438487548949354, + -0.06515950301265735, + 0.9046842034668978, + -0.633084855036293, + 1.2551560992491184, + 0.1301184258508601, + -0.3629470737734702, + -1.8533123782006025, + 0.7807885239004154, + 0.2184245387683735, + -0.8687380268391548, + ], + [3, 2, 2]); + testUnary( + 'tan', + [ + 1.4124068, + 1.9740626, + -0.06506752, + 0.73539704, + -0.56439203, + 0.89806247, + 0.12939146, + -0.34816208, + -1.0759926, + 0.66291636, + 0.21504708, + -0.71527237, + ], + [ + 6.260663735197218, + -2.3438487548949354, + -0.06515950301265735, + 0.9046842034668978, + -0.633084855036293, + 1.2551560992491184, + 0.1301184258508601, + -0.3629470737734702, + -1.8533123782006025, + 0.7807885239004154, + 0.2184245387683735, + -0.8687380268391548, + ], + [3, 2, 2, 1]); + }); +}); diff --git a/test/utils.js b/test/utils.js new file mode 100644 index 0000000..1e60b0a --- /dev/null +++ b/test/utils.js @@ -0,0 +1,58 @@ +'use strict'; + +const assert = chai.assert; + +/** + * Get bitwise of the given value. + * @param {Number} value + * @return {Number} A 64-bit signed integer. + */ +function getBitwise(value) { + const buffer = new ArrayBuffer(8); + const int64Array = new BigInt64Array(buffer); + int64Array[0] = value < 0 ? ~BigInt(0) : BigInt(0); + const f64Array = new Float64Array(buffer); + f64Array[0] = value; + return int64Array[0]; +} + +/** + * Asserts that the distance between a and b whether is close enough to the given ULP distance. + * @param {Number} a + * @param {Number} b + * @param {Number} nulp A BigInt value. + * @param {String} message A message to report when the assertion fails + * @return {Boolean} A boolean value: + * true: The distance between a and b is close enough to the given ULP distance. + * false: The distance between a and b is far away from the given ULP distance. + */ +assert.isAlmostEqualUlp = function(a, b, nulp, message) { + const aBitwise = getBitwise(a); + const bBitwise = getBitwise(b); + let distance = aBitwise - bBitwise; + distance = distance >= 0 ? distance : -distance; + return assert.isTrue(distance <= nulp, message); +}; + +export function checkValue(tensor, expected, nulp = 0) { + assert.isTrue(tensor.size === expected.length); + for (let i = 0; i < expected.length; ++i) { + assert.isAlmostEqualUlp(tensor.getValueByIndex(i), expected[i], nulp, + `${tensor.getValueByIndex(i)} is almost equal to ${expected[i]}`); + } +} + +export function checkShape(tensor, expected) { + assert.equal(tensor.rank, expected.length, + `Tensor has expected rank ${expected.length}: ${tensor.rank}`); + for (let i = 0; i < expected.length; ++i) { + assert.equal(tensor.shape[i], expected[i], + `Tensor line ${i} has expected length ${expected[i]}: ${tensor.shape[i]}`); + } +} + +export function bindTrailingArgs(fn, ...boundArgs) { + return function(...args) { + return fn(...args, ...boundArgs); + }; +}