diff --git a/src/StyleTransfer/index.js b/src/StyleTransfer/index.js index a38b9bd4a..b3331ac5d 100644 --- a/src/StyleTransfer/index.js +++ b/src/StyleTransfer/index.js @@ -3,8 +3,6 @@ // This software is released under the MIT License. // https://opensource.org/licenses/MIT -/* eslint max-len: "off" */ -/* eslint no-trailing-spaces: "off" */ /* Fast Style Transfer This implementation is heavily based on github.com/reiinakano/fast-style-transfer-deeplearnjs by Reiichiro Nakano. @@ -13,28 +11,19 @@ The original TensorFlow implementation was developed by Logan Engstrom: github.c import * as tf from '@tensorflow/tfjs'; import handleArguments from "../utils/handleArguments"; -import Video from './../utils/Video'; import CheckpointLoader from '../utils/checkpointLoader'; -import { array3DToImage } from '../utils/imageUtilities'; +import { array3DToImage, mediaReady } from '../utils/imageUtilities'; import callCallback from '../utils/callcallback'; -const IMAGE_SIZE = 200; - -class StyleTransfer extends Video { +class StyleTransfer { /** * Create a new Style Transfer Instance。 * @param {string} model - The path to Style Transfer model. - * @param {HTMLVideoElement || p5.Video} video - Optional. A HTML video element or a p5 video element. + * @param {HTMLVideoElement} video - Optional. A HTML video element or a p5 video element. * @param {function} callback - Optional. A function to be called once the model is loaded. If no callback is provided, it will return a promise that will be resolved once the model has loaded. */ constructor(model, video, callback) { - super(video, IMAGE_SIZE); - /** - * Boolean value that specifies if the model has loaded. - * @type {boolean} - * @public - */ - this.ready = false; + this.video = video; /** * @private * @type {Record} @@ -43,9 +32,7 @@ class StyleTransfer extends Video { this.timesScalar = tf.scalar(150); this.plusScalar = tf.scalar(255.0 / 2); this.epsilonScalar = tf.scalar(1e-3); - this.video = null; this.ready = callCallback(this.load(model), callback); - // this.then = this.ready.then; } /** @@ -54,10 +41,10 @@ class StyleTransfer extends Video { * @return {Promise} */ async load(model) { - if (this.videoElt) { - await this.loadVideo(); - } - await this.loadCheckpoints(model); + await Promise.all([ + mediaReady(this.video, false), + this.loadCheckpoints(model) + ]); return this; } @@ -83,8 +70,8 @@ class StyleTransfer extends Video { const moments = tf.moments(input, [0, 1]); const mu = moments.mean; const sigmaSq = moments.variance; - const shift = this.variables[StyleTransfer.getVariableName(id)]; - const scale = this.variables[StyleTransfer.getVariableName(id + 1)]; + const shift = this.getVariable(id); + const scale = this.getVariable(id + 1); const epsilon = this.epsilonScalar; const normalized = tf.div(tf.sub(input.asType('float32'), mu), tf.sqrt(tf.add(sigmaSq, epsilon))); const shifted = tf.add(tf.mul(scale, normalized), shift); @@ -102,7 +89,7 @@ class StyleTransfer extends Video { */ convLayer(input, strides, relu, id) { return tf.tidy(() => { - const y = tf.conv2d(input, this.variables[StyleTransfer.getVariableName(id)], [strides, strides], 'same'); + const y = tf.conv2d(input, this.getVariable(id), [strides, strides], 'same'); const y2 = this.instanceNorm(y, id + 1); return relu ? tf.relu(y2) : y2; }); @@ -124,18 +111,19 @@ class StyleTransfer extends Video { /** * @param {tf.Tensor3D} input - * @param {number} numFilters * @param {number} strides * @param {number} id * @return {tf.Tensor3D} */ - convTransposeLayer(input, numFilters, strides, id) { + convTransposeLayer(input, strides, id) { return tf.tidy(() => { + const filter = this.getVariable(id); + const outDepth = filter.shape[2]; const [height, width] = input.shape; const newRows = height * strides; const newCols = width * strides; - const newShape = [newRows, newCols, numFilters]; - const y = tf.conv2dTranspose(input, this.variables[StyleTransfer.getVariableName(id)], newShape, [strides, strides], 'same'); + const newShape = [newRows, newCols, outDepth]; + const y = tf.conv2dTranspose(input, filter, newShape, [strides, strides], 'same'); const y2 = this.instanceNorm(y, id + 1); const y3 = tf.relu(y2); return y3; @@ -154,10 +142,18 @@ class StyleTransfer extends Video { /** * @private + * + * Applies each layer of the model in sequence, where the output of one layer + * is used as the input of the next layer. + * * @param {ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement} input * @return {Promise} */ async transferInternal(input) { + await Promise.all([ + mediaReady(input, true), + this.ready + ]); const image = tf.browser.fromPixels(input); const result = array3DToImage(tf.tidy(() => { const conv1 = this.convLayer(image, 1, true, 0); @@ -168,8 +164,8 @@ class StyleTransfer extends Video { const res3 = this.residualBlock(res2, 21); const res4 = this.residualBlock(res3, 27); const res5 = this.residualBlock(res4, 33); - const convT1 = this.convTransposeLayer(res5, 64, 2, 39); - const convT2 = this.convTransposeLayer(convT1, 32, 2, 42); + const convT1 = this.convTransposeLayer(res5, 2, 39); + const convT2 = this.convTransposeLayer(convT1, 2, 42); const convT3 = this.convLayer(convT2, 1, false, 45); const outTanh = tf.tanh(convT3); const scaled = tf.mul(this.timesScalar, outTanh); @@ -179,7 +175,6 @@ class StyleTransfer extends Video { return normalized; })); image.dispose(); - await tf.nextFrame(); return result; } @@ -195,12 +190,19 @@ class StyleTransfer extends Video { this.epsilonScalar.dispose(); } - // Static Methods - static getVariableName(id) { - if (id === 0) { - return 'Variable'; - } - return `Variable_${id}`; + /** + * @private + * + * Access a variable's tensor from its numeric index. + * Model contains variables with ids from 0 to 47. + * The returned tensor will be 4D if `id` is divisible by 3, or 1D otherwise. + * + * @param {number} id + * @returns {tf.Tensor} + */ + getVariable(id) { + const key = id === 0 ? 'Variable' : `Variable_${id}`; + return this.variables[key]; } }