From b83bed69a5b1d3745e5ff39dd9c79b80c00f36ff Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Wed, 18 Sep 2024 14:03:19 +0800 Subject: [PATCH] keep dimensions --- code/samples/matmul.js | 12 ++++++++-- code/samples/mul_add.js | 2 +- code/samples/simple_graph.js | 2 +- common/utils.js | 3 ++- face_recognition/facenet_nchw.js | 1 + face_recognition/facenet_nhwc.js | 1 + .../face_landmark_nchw.js | 1 + .../face_landmark_nhwc.js | 1 + .../ssd_mobilenetv2_face_nchw.js | 1 + .../ssd_mobilenetv2_face_nhwc.js | 1 + .../efficientnet_fp16_nchw.js | 1 + image_classification/mobilenet_nchw.js | 1 + image_classification/mobilenet_nhwc.js | 1 + image_classification/resnet50v1_fp16_nchw.js | 1 + image_classification/resnet50v2_nchw.js | 1 + image_classification/resnet50v2_nhwc.js | 1 + image_classification/squeezenet_nchw.js | 1 + image_classification/squeezenet_nhwc.js | 1 + lenet/lenet.js | 24 ++++++++++++------- nnotepad/js/index.js | 1 + nnotepad/js/nnotepad.js | 10 +++++--- nsnet2/nsnet2.js | 3 +++ object_detection/ssd_mobilenetv1_nchw.js | 1 + object_detection/ssd_mobilenetv1_nhwc.js | 1 + object_detection/tiny_yolov2_nchw.js | 3 ++- object_detection/tiny_yolov2_nhwc.js | 1 + rnnoise/rnnoise.js | 4 ++++ semantic_segmentation/deeplabv3_mnv2_nchw.js | 1 + semantic_segmentation/deeplabv3_mnv2_nhwc.js | 1 + style_transfer/fast_style_transfer_net.js | 9 +++---- 30 files changed, 71 insertions(+), 21 deletions(-) diff --git a/code/samples/matmul.js b/code/samples/matmul.js index d8519ca8..1b773705 100644 --- a/code/samples/matmul.js +++ b/code/samples/matmul.js @@ -2,8 +2,16 @@ const context = await navigator.ml.createContext({deviceType: 'gpu'}); const builder = new MLGraphBuilder(context); // Step 1: Create a computational graph calculating `c = a * b`. 
-const a = builder.input('a', {dataType: 'float32', shape: [3, 4]}); -const b = builder.input('b', {dataType: 'float32', shape: [4, 3]}); +const a = builder.input('a', { + dataType: 'float32', + dimensions: [3, 4], + shape: [3, 4], +}); +const b = builder.input('b', { + dataType: 'float32', + dimensions: [4, 3], + shape: [4, 3], +}); const c = builder.matmul(a, b); // Step 2: Compile it into an executable graph. const graph = await builder.build({c}); diff --git a/code/samples/mul_add.js b/code/samples/mul_add.js index 66f93363..ddfdf512 100644 --- a/code/samples/mul_add.js +++ b/code/samples/mul_add.js @@ -1,4 +1,4 @@ -const operandType = {dataType: 'float32', shape: [2, 2]}; +const operandType = {dataType: 'float32', dimensions: [2, 2], shape: [2, 2]}; const context = await navigator.ml.createContext(); const builder = new MLGraphBuilder(context); // 1. Create a computational graph 'C = 0.2 * A + B'. diff --git a/code/samples/simple_graph.js b/code/samples/simple_graph.js index 9fd4e6b0..bd0fcb99 100644 --- a/code/samples/simple_graph.js +++ b/code/samples/simple_graph.js @@ -18,7 +18,7 @@ const TENSOR_SIZE = 8; const builder = new MLGraphBuilder(context); // Create MLOperandDescriptor object. -const desc = {dataType: 'float32', shape: TENSOR_DIMS}; +const desc = {dataType: 'float32', dimensions: TENSOR_DIMS, shape: TENSOR_DIMS}; // constant1 is a constant MLOperand with the value 0.5. 
const constantBuffer1 = new Float32Array(TENSOR_SIZE).fill(0.5); diff --git a/common/utils.js b/common/utils.js index eab6a4fc..fb2f212b 100644 --- a/common/utils.js +++ b/common/utils.js @@ -121,7 +121,8 @@ export async function buildConstantByNpy(builder, url, targetType = 'float32') { throw new Error(`Conversion from ${npArray.dataType} ` + `to ${targetType} is not supported.`); } - return builder.constant({dataType: type, shape}, typedArray); + return builder.constant( + {dataType: type, dimensions: shape, shape}, typedArray); } // Convert video frame to a canvas element diff --git a/face_recognition/facenet_nchw.js b/face_recognition/facenet_nchw.js index f47ff777..91410492 100644 --- a/face_recognition/facenet_nchw.js +++ b/face_recognition/facenet_nchw.js @@ -140,6 +140,7 @@ export class FaceNetNchw { this.builder_ = new MLGraphBuilder(this.context_); const input = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); diff --git a/face_recognition/facenet_nhwc.js b/face_recognition/facenet_nhwc.js index ab19da8f..29214a48 100644 --- a/face_recognition/facenet_nhwc.js +++ b/face_recognition/facenet_nhwc.js @@ -141,6 +141,7 @@ export class FaceNetNhwc { this.builder_ = new MLGraphBuilder(this.context_); const input = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); diff --git a/facial_landmark_detection/face_landmark_nchw.js b/facial_landmark_detection/face_landmark_nchw.js index 406d02d5..6fc42ac1 100644 --- a/facial_landmark_detection/face_landmark_nchw.js +++ b/facial_landmark_detection/face_landmark_nchw.js @@ -71,6 +71,7 @@ export class FaceLandmarkNchw { this.builder_ = new MLGraphBuilder(this.context_); const input = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); diff --git 
a/facial_landmark_detection/face_landmark_nhwc.js b/facial_landmark_detection/face_landmark_nhwc.js index 066202c7..4ff29509 100644 --- a/facial_landmark_detection/face_landmark_nhwc.js +++ b/facial_landmark_detection/face_landmark_nhwc.js @@ -72,6 +72,7 @@ export class FaceLandmarkNhwc { this.builder_ = new MLGraphBuilder(this.context_); const input = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); diff --git a/facial_landmark_detection/ssd_mobilenetv2_face_nchw.js b/facial_landmark_detection/ssd_mobilenetv2_face_nchw.js index 9c67eacd..6d335e84 100644 --- a/facial_landmark_detection/ssd_mobilenetv2_face_nchw.js +++ b/facial_landmark_detection/ssd_mobilenetv2_face_nchw.js @@ -115,6 +115,7 @@ ${nameArray[1]}`; this.builder_ = new MLGraphBuilder(this.context_); const input = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); diff --git a/facial_landmark_detection/ssd_mobilenetv2_face_nhwc.js b/facial_landmark_detection/ssd_mobilenetv2_face_nhwc.js index 9270f428..e02d01b9 100644 --- a/facial_landmark_detection/ssd_mobilenetv2_face_nhwc.js +++ b/facial_landmark_detection/ssd_mobilenetv2_face_nhwc.js @@ -127,6 +127,7 @@ ${nameArray[1]}`; this.builder_ = new MLGraphBuilder(this.context_); const input = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); diff --git a/image_classification/efficientnet_fp16_nchw.js b/image_classification/efficientnet_fp16_nchw.js index df19b9dd..cec1ccd1 100644 --- a/image_classification/efficientnet_fp16_nchw.js +++ b/image_classification/efficientnet_fp16_nchw.js @@ -77,6 +77,7 @@ export class EfficientNetFP16Nchw { this.builder_ = new MLGraphBuilder(this.context_); let data = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: 
this.inputOptions.inputShape, }); data = this.builder_.cast(data, 'float16'); diff --git a/image_classification/mobilenet_nchw.js b/image_classification/mobilenet_nchw.js index 6d1eed70..014f3e7f 100644 --- a/image_classification/mobilenet_nchw.js +++ b/image_classification/mobilenet_nchw.js @@ -91,6 +91,7 @@ export class MobileNetV2Nchw { this.builder_ = new MLGraphBuilder(this.context_); let data = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); if (this.dataType_ === 'float16') { diff --git a/image_classification/mobilenet_nhwc.js b/image_classification/mobilenet_nhwc.js index fff011d3..441eab85 100644 --- a/image_classification/mobilenet_nhwc.js +++ b/image_classification/mobilenet_nhwc.js @@ -89,6 +89,7 @@ export class MobileNetV2Nhwc { const filterLayout = 'ohwi'; const input = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); const conv0 = this.buildConv_( diff --git a/image_classification/resnet50v1_fp16_nchw.js b/image_classification/resnet50v1_fp16_nchw.js index d204b6e2..d9cdc843 100644 --- a/image_classification/resnet50v1_fp16_nchw.js +++ b/image_classification/resnet50v1_fp16_nchw.js @@ -78,6 +78,7 @@ export class ResNet50V1FP16Nchw { this.builder_ = new MLGraphBuilder(this.context_); let data = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); data = this.builder_.cast(data, 'float16'); diff --git a/image_classification/resnet50v2_nchw.js b/image_classification/resnet50v2_nchw.js index 8cb359d8..56201dee 100644 --- a/image_classification/resnet50v2_nchw.js +++ b/image_classification/resnet50v2_nchw.js @@ -100,6 +100,7 @@ export class ResNet50V2Nchw { this.builder_ = new MLGraphBuilder(this.context_); const data = this.builder_.input('input', { dataType: 'float32', + dimensions: 
this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); const bn1 = this.buildBatchNorm_(data, '0', '', false); diff --git a/image_classification/resnet50v2_nhwc.js b/image_classification/resnet50v2_nhwc.js index 2e15d4a7..3babf6f3 100644 --- a/image_classification/resnet50v2_nhwc.js +++ b/image_classification/resnet50v2_nhwc.js @@ -122,6 +122,7 @@ export class ResNet50V2Nhwc { this.builder_ = new MLGraphBuilder(this.context_); const input = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); const conv1 = await this.buildConv_( diff --git a/image_classification/squeezenet_nchw.js b/image_classification/squeezenet_nchw.js index 16739f74..6fe75e5b 100644 --- a/image_classification/squeezenet_nchw.js +++ b/image_classification/squeezenet_nchw.js @@ -45,6 +45,7 @@ export class SqueezeNetNchw { this.builder_ = new MLGraphBuilder(this.context_); const data = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); const conv0 = this.buildConv_(data, 'conv0', {strides: [2, 2]}); diff --git a/image_classification/squeezenet_nhwc.js b/image_classification/squeezenet_nhwc.js index 91407860..c9d7d2d1 100644 --- a/image_classification/squeezenet_nhwc.js +++ b/image_classification/squeezenet_nhwc.js @@ -56,6 +56,7 @@ export class SqueezeNetNhwc { const layout = 'nhwc'; const placeholder = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); const conv1 = this.buildConv_( diff --git a/lenet/lenet.js b/lenet/lenet.js index e939ed76..3b4d2d03 100644 --- a/lenet/lenet.js +++ b/lenet/lenet.js @@ -27,6 +27,7 @@ export class LeNet { const inputShape = /* nchw */ [1, 1, 28, 28]; let input = this.builder_.input('input', { dataType: 'float32', + dimensions: inputShape, shape: inputShape, }); @@ -50,7 +51,11 @@ export class 
LeNet { conv1FilterData, conv1FitlerShape, this.oihwToOhwiPermutation_); } const conv1Filter = this.builder_.constant( - {dataType: 'float32', shape: conv1FitlerShape}, + { + dataType: 'float32', + dimensions: conv1FitlerShape, + shape: conv1FitlerShape, + }, conv1FilterData); byteOffset += sizeOfShape(conv1FitlerShape) * Float32Array.BYTES_PER_ELEMENT; @@ -59,7 +64,7 @@ export class LeNet { const add1BiasData = new Float32Array(arrayBuffer, byteOffset, sizeOfShape(add1BiasShape)); const add1Bias = this.builder_.constant( - {dataType: 'float32', shape: add1BiasShape}, + {dataType: 'float32', dimensions: add1BiasShape, shape: add1BiasShape}, add1BiasData, ); byteOffset += sizeOfShape(add1BiasShape) * Float32Array.BYTES_PER_ELEMENT; @@ -87,14 +92,17 @@ export class LeNet { conv2FilterData, conv2FilterShape, this.oihwToOhwiPermutation_); } const conv2Filter = this.builder_.constant( - {dataType: 'float32', shape: conv2FilterShape}, + { + dataType: 'float32', + dimensions: conv2FilterShape, + shape: conv2FilterShape}, conv2FilterData); byteOffset += sizeOfShape(conv2FilterShape) * Float32Array.BYTES_PER_ELEMENT; const add2BiasShape = [50]; const add2Bias = this.builder_.constant( - {dataType: 'float32', shape: add2BiasShape}, + {dataType: 'float32', dimensions: add2BiasShape, shape: add2BiasShape}, new Float32Array(arrayBuffer, byteOffset, sizeOfShape(add2BiasShape))); byteOffset += sizeOfShape(add2BiasShape) * Float32Array.BYTES_PER_ELEMENT; conv2Options.bias = add2Bias; @@ -120,7 +128,7 @@ export class LeNet { const matmul1Shape = [500, 800]; const matmul1Weights = this.builder_.constant( - {dataType: 'float32', shape: matmul1Shape}, + {dataType: 'float32', dimensions: matmul1Shape, shape: matmul1Shape}, new Float32Array(arrayBuffer, byteOffset, sizeOfShape(matmul1Shape))); byteOffset += sizeOfShape(matmul1Shape) * Float32Array.BYTES_PER_ELEMENT; const matmul1WeightsTransposed = this.builder_.transpose(matmul1Weights); @@ -128,7 +136,7 @@ export class LeNet { const 
add3BiasShape = [1, 500]; const add3Bias = this.builder_.constant( - {dataType: 'float32', shape: add3BiasShape}, + {dataType: 'float32', dimensions: add3BiasShape, shape: add3BiasShape}, new Float32Array(arrayBuffer, byteOffset, sizeOfShape(add3BiasShape))); byteOffset += sizeOfShape(add3BiasShape) * Float32Array.BYTES_PER_ELEMENT; const add3 = this.builder_.add(matmul1, add3Bias); @@ -140,7 +148,7 @@ export class LeNet { const matmul2Shape = [10, 500]; const matmul2Weights = this.builder_.constant( - {dataType: 'float32', shape: matmul2Shape}, + {dataType: 'float32', dimensions: matmul2Shape, shape: matmul2Shape}, new Float32Array(arrayBuffer, byteOffset, sizeOfShape(matmul2Shape))); byteOffset += sizeOfShape(matmul2Shape) * Float32Array.BYTES_PER_ELEMENT; const matmul2WeightsTransposed = this.builder_.transpose(matmul2Weights); @@ -148,7 +156,7 @@ export class LeNet { const add4BiasShape = [1, 10]; const add4Bias = this.builder_.constant( - {dataType: 'float32', shape: add4BiasShape}, + {dataType: 'float32', dimensions: add4BiasShape, shape: add4BiasShape}, new Float32Array(arrayBuffer, byteOffset, sizeOfShape(add4BiasShape))); const add4 = this.builder_.add(matmul2, add4Bias); diff --git a/nnotepad/js/index.js b/nnotepad/js/index.js index fcb4fbb9..766c9bac 100644 --- a/nnotepad/js/index.js +++ b/nnotepad/js/index.js @@ -130,6 +130,7 @@ function explain(outputs) { .map((output) => [ 'dataType: ' + output.dataType, + 'dimensions: ' + Util.stringify(output.shape), 'shape: ' + Util.stringify(output.shape), 'tensor: ' + dumpTensor(output.shape, output.buffer, 8), ].join('\n'), diff --git a/nnotepad/js/nnotepad.js b/nnotepad/js/nnotepad.js index f3bb711d..08a9fde4 100644 --- a/nnotepad/js/nnotepad.js +++ b/nnotepad/js/nnotepad.js @@ -475,7 +475,8 @@ export class NNotepad { }); }(tensor, 0)); const ctor = WebNNUtil.dataTypeToBufferType(dataType); - return `_.constant({dataType: "${dataType}", shape: ${ + return `_.constant({dataType: "${dataType}", dimensions: ${ + 
Util.stringify(shape)}, shape: ${ Util.stringify(shape)}}, new ${ctor.name}([${ elements.map((n) => Util.stringifyNumber(n, dataType)).join(',')}]))`; } @@ -500,7 +501,8 @@ export class NNotepad { } const dims = shape.value.map((expr) => expr.value); const ctor = WebNNUtil.dataTypeToBufferType(dataType.value); - return `_.constant({dataType: "${dataType.value}", shape: ${ + return `_.constant({dataType: "${dataType.value}", dimensions: ${ + Util.stringify(dims)}, shape: ${ Util.stringify(dims)}}, new ${ ctor.name}(await Util.loadBuffer(${Util.stringify(url.value)})))`; } @@ -516,7 +518,8 @@ export class NNotepad { const dims = shape.value.map((expr) => expr.value); const ctor = WebNNUtil.dataTypeToBufferType(dataType.value); const len = dims.reduce((a, b) => a * b, 1); - return `_.constant({dataType: "${dataType.value}", shape: ${ + return `_.constant({dataType: "${dataType.value}", dimensions: ${ + Util.stringify(dims)}, shape: ${ Util.stringify(dims)}}, new ${ ctor.name}(${len}))`; } @@ -595,6 +598,7 @@ export class NNotepad { return outputOperands.map( (op, index) => ({ dataType: op.dataType(), + dimensions: op.shape(), shape: op.shape(), buffer: maybeProxyForFloat16Array(result.outputs[`output-${index}`]), })); diff --git a/nsnet2/nsnet2.js b/nsnet2/nsnet2.js index 3ff05a5e..e91a8608 100644 --- a/nsnet2/nsnet2.js +++ b/nsnet2/nsnet2.js @@ -38,12 +38,14 @@ export class NSNet2 { // Build up the network. 
const input = this.builder_.input('input', { dataType: 'float32', + dimensions: [batchSize, frames, this.frameSize], shape: [batchSize, frames, this.frameSize], }); const relu20 = this.builder_.relu(this.builder_.add(this.builder_.matmul(input, weight172), biasFcIn0)); const transpose31 = this.builder_.transpose(relu20, {permutation: [1, 0, 2]}); const initialState92 = this.builder_.input('initialState92', { dataType: 'float32', + dimensions: [1, batchSize, this.hiddenSize], shape: [1, batchSize, this.hiddenSize], }); const [gru94, gru93] = this.builder_.gru(transpose31, weight192, recurrentWeight193, frames, this.hiddenSize, @@ -54,6 +56,7 @@ export class NSNet2 { const squeeze95 = this.builder_.reshape(gru93, squeeze95Shape); const initialState155 = this.builder_.input('initialState155', { dataType: 'float32', + dimensions: [1, batchSize, this.hiddenSize], shape: [1, batchSize, this.hiddenSize], }); const [gru157, gru156] = this.builder_.gru(squeeze95, weight212, recurrentWeight213, frames, this.hiddenSize, diff --git a/object_detection/ssd_mobilenetv1_nchw.js b/object_detection/ssd_mobilenetv1_nchw.js index 2455c747..067dab6f 100644 --- a/object_detection/ssd_mobilenetv1_nchw.js +++ b/object_detection/ssd_mobilenetv1_nchw.js @@ -81,6 +81,7 @@ ${nameArray[1]}_BatchNorm_batchnorm`; this.builder_ = new MLGraphBuilder(this.context_); let input = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); if (this.targetDataType_ === 'float16') { diff --git a/object_detection/ssd_mobilenetv1_nhwc.js b/object_detection/ssd_mobilenetv1_nhwc.js index 513bf914..9fe7d316 100644 --- a/object_detection/ssd_mobilenetv1_nhwc.js +++ b/object_detection/ssd_mobilenetv1_nhwc.js @@ -87,6 +87,7 @@ ${nameArray[1]}_BatchNorm_batchnorm`; this.builder_ = new MLGraphBuilder(this.context_); const input = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: 
this.inputOptions.inputShape, }); const strides = [2, 2]; diff --git a/object_detection/tiny_yolov2_nchw.js b/object_detection/tiny_yolov2_nchw.js index ecc5300b..d6bd91dc 100644 --- a/object_detection/tiny_yolov2_nchw.js +++ b/object_detection/tiny_yolov2_nchw.js @@ -63,10 +63,11 @@ export class TinyYoloV2Nchw { this.builder_ = new MLGraphBuilder(this.context_); let image = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); let mulScale = this.builder_.constant( - {dataType: 'float32', shape: [1]}, + {dataType: 'float32', dimensions: [1], shape: [1]}, new Float32Array([0.003921568859368563]), ); const poolOptions = { diff --git a/object_detection/tiny_yolov2_nhwc.js b/object_detection/tiny_yolov2_nhwc.js index e6825ef9..1be3cd95 100644 --- a/object_detection/tiny_yolov2_nhwc.js +++ b/object_detection/tiny_yolov2_nhwc.js @@ -57,6 +57,7 @@ export class TinyYoloV2Nhwc { this.builder_ = new MLGraphBuilder(this.context_); const input = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); diff --git a/rnnoise/rnnoise.js b/rnnoise/rnnoise.js index 67c3cbe0..bde6e0d0 100644 --- a/rnnoise/rnnoise.js +++ b/rnnoise/rnnoise.js @@ -53,6 +53,7 @@ export class RNNoise { // Build up the network. 
const input = this.builder_.input('input', { dataType: 'float32', + dimensions: [this.batchSize_, this.frames_, this.featureSize], shape: [this.batchSize_, this.frames_, this.featureSize], }); const inputDense0 = this.builder_.matmul(input, inputDenseKernel0); @@ -68,6 +69,7 @@ export class RNNoise { [1, 3 * this.vadGruHiddenSize]); const vadGruInitialH = this.builder_.input('vadGruInitialH', { dataType: 'float32', + dimensions: [1, this.batchSize_, this.vadGruHiddenSize], shape: [1, this.batchSize_, this.vadGruHiddenSize], }); const [vadGruYH, vadGruY] = this.builder_.gru(vadGruX, @@ -95,6 +97,7 @@ export class RNNoise { [1, 3 * this.noiseGruHiddenSize]); const noiseGruInitialH = this.builder_.input('noiseGruInitialH', { dataType: 'float32', + dimensions: [1, this.batchSize_, this.noiseGruHiddenSize], shape: [1, this.batchSize_, this.noiseGruHiddenSize], }); const [noiseGruYH, noiseGruY] = this.builder_.gru(noiseGruX, @@ -122,6 +125,7 @@ export class RNNoise { [1, 3 * this.denoiseGruHiddenSize]); const denoiseGruInitialH = this.builder_.input('denoiseGruInitialH', { dataType: 'float32', + dimensions: [1, this.batchSize_, this.denoiseGruHiddenSize], shape: [1, this.batchSize_, this.denoiseGruHiddenSize], }); const [denoiseGruYH, denoiseGruY] = this.builder_.gru(denoiseGruX, diff --git a/semantic_segmentation/deeplabv3_mnv2_nchw.js b/semantic_segmentation/deeplabv3_mnv2_nchw.js index b749eb5d..a0ce7e78 100644 --- a/semantic_segmentation/deeplabv3_mnv2_nchw.js +++ b/semantic_segmentation/deeplabv3_mnv2_nchw.js @@ -90,6 +90,7 @@ export class DeepLabV3MNV2Nchw { const input = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); const conv0 = this.buildConv_( diff --git a/semantic_segmentation/deeplabv3_mnv2_nhwc.js b/semantic_segmentation/deeplabv3_mnv2_nhwc.js index f01608e2..69f10eef 100644 --- a/semantic_segmentation/deeplabv3_mnv2_nhwc.js +++ 
b/semantic_segmentation/deeplabv3_mnv2_nhwc.js @@ -81,6 +81,7 @@ export class DeepLabV3MNV2Nhwc { const strides = [2, 2]; const input = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); const conv0 = await this.buildConv_( diff --git a/style_transfer/fast_style_transfer_net.js b/style_transfer/fast_style_transfer_net.js index 772b23a9..14c40f7c 100644 --- a/style_transfer/fast_style_transfer_net.js +++ b/style_transfer/fast_style_transfer_net.js @@ -96,24 +96,25 @@ export class FastStyleTransferNet { const padding1 = [0, 0, 1, 1]; const padding4 = [0, 0, 4, 4]; this.constAdd_ = this.builder_.constant( - {dataType: 'float32', shape: [1]}, + {dataType: 'float32', dimensions: [1], shape: [1]}, new Float32Array([9.999999717180685e-10]), ); this.constPow_ = this.builder_.constant( - {dataType: 'float32', shape: [1]}, + {dataType: 'float32', dimensions: [1], shape: [1]}, new Float32Array([0.5]), ); const constMul0 = this.builder_.constant( - {dataType: 'float32', shape: [1]}, + {dataType: 'float32', dimensions: [1], shape: [1]}, new Float32Array([150]), ); const constAdd0 = this.builder_.constant( - {dataType: 'float32', shape: [1]}, + {dataType: 'float32', dimensions: [1], shape: [1]}, new Float32Array([127.5]), ); // Build up the network. const input = this.builder_.input('input', { dataType: 'float32', + dimensions: this.inputOptions.inputShape, shape: this.inputOptions.inputShape, }); const conv2D0 = this.builder_.conv2d(this.builder_.pad(input, padding4, padding4, {mode: 'reflection'}), weightConv0);