Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StyleTransfer infers the number of filters #1414

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 39 additions & 37 deletions src/StyleTransfer/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT

/* eslint max-len: "off" */
/* eslint no-trailing-spaces: "off" */
/*
Fast Style Transfer
This implementation is heavily based on github.com/reiinakano/fast-style-transfer-deeplearnjs by Reiichiro Nakano.
Expand All @@ -13,28 +11,19 @@ The original TensorFlow implementation was developed by Logan Engstrom: github.c

import * as tf from '@tensorflow/tfjs';
import handleArguments from "../utils/handleArguments";
import Video from './../utils/Video';
import CheckpointLoader from '../utils/checkpointLoader';
import { array3DToImage } from '../utils/imageUtilities';
import { array3DToImage, mediaReady } from '../utils/imageUtilities';
import callCallback from '../utils/callcallback';

const IMAGE_SIZE = 200;

class StyleTransfer extends Video {
class StyleTransfer {
/**
* Create a new Style Transfer Instance。
* @param {string} model - The path to Style Transfer model.
* @param {HTMLVideoElement || p5.Video} video - Optional. A HTML video element or a p5 video element.
* @param {HTMLVideoElement} video - Optional. A HTML video element or a p5 video element.
* @param {function} callback - Optional. A function to be called once the model is loaded. If no callback is provided, it will return a promise that will be resolved once the model has loaded.
*/
constructor(model, video, callback) {
super(video, IMAGE_SIZE);
/**
* Boolean value that specifies if the model has loaded.
* @type {boolean}
* @public
*/
this.ready = false;
this.video = video;
/**
* @private
* @type {Record<string, tf.Tensor>}
Expand All @@ -43,9 +32,7 @@ class StyleTransfer extends Video {
this.timesScalar = tf.scalar(150);
this.plusScalar = tf.scalar(255.0 / 2);
this.epsilonScalar = tf.scalar(1e-3);
this.video = null;
this.ready = callCallback(this.load(model), callback);
// this.then = this.ready.then;
}

/**
Expand All @@ -54,10 +41,10 @@ class StyleTransfer extends Video {
* @return {Promise<StyleTransfer>}
*/
async load(model) {
if (this.videoElt) {
await this.loadVideo();
}
await this.loadCheckpoints(model);
await Promise.all([
mediaReady(this.video, false),
this.loadCheckpoints(model)
]);
return this;
}

Expand All @@ -83,8 +70,8 @@ class StyleTransfer extends Video {
const moments = tf.moments(input, [0, 1]);
const mu = moments.mean;
const sigmaSq = moments.variance;
const shift = this.variables[StyleTransfer.getVariableName(id)];
const scale = this.variables[StyleTransfer.getVariableName(id + 1)];
const shift = this.getVariable(id);
const scale = this.getVariable(id + 1);
const epsilon = this.epsilonScalar;
const normalized = tf.div(tf.sub(input.asType('float32'), mu), tf.sqrt(tf.add(sigmaSq, epsilon)));
const shifted = tf.add(tf.mul(scale, normalized), shift);
Expand All @@ -102,7 +89,7 @@ class StyleTransfer extends Video {
*/
convLayer(input, strides, relu, id) {
return tf.tidy(() => {
const y = tf.conv2d(input, this.variables[StyleTransfer.getVariableName(id)], [strides, strides], 'same');
const y = tf.conv2d(input, this.getVariable(id), [strides, strides], 'same');
const y2 = this.instanceNorm(y, id + 1);
return relu ? tf.relu(y2) : y2;
});
Expand All @@ -124,18 +111,19 @@ class StyleTransfer extends Video {

/**
* @param {tf.Tensor3D} input
* @param {number} numFilters
* @param {number} strides
* @param {number} id
* @return {tf.Tensor3D}
*/
convTransposeLayer(input, numFilters, strides, id) {
convTransposeLayer(input, strides, id) {
return tf.tidy(() => {
const filter = this.getVariable(id);
const outDepth = filter.shape[2];
const [height, width] = input.shape;
const newRows = height * strides;
const newCols = width * strides;
const newShape = [newRows, newCols, numFilters];
const y = tf.conv2dTranspose(input, this.variables[StyleTransfer.getVariableName(id)], newShape, [strides, strides], 'same');
const newShape = [newRows, newCols, outDepth];
const y = tf.conv2dTranspose(input, filter, newShape, [strides, strides], 'same');
const y2 = this.instanceNorm(y, id + 1);
const y3 = tf.relu(y2);
return y3;
Expand All @@ -154,10 +142,18 @@ class StyleTransfer extends Video {

/**
* @private
*
* Applies each layer of the model in sequence, where the output of one layer
* is used as the input of the next layer.
*
* @param {ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement} input
* @return {Promise<HTMLImageElement>}
*/
async transferInternal(input) {
await Promise.all([
mediaReady(input, true),
this.ready
]);
const image = tf.browser.fromPixels(input);
const result = array3DToImage(tf.tidy(() => {
const conv1 = this.convLayer(image, 1, true, 0);
Expand All @@ -168,8 +164,8 @@ class StyleTransfer extends Video {
const res3 = this.residualBlock(res2, 21);
const res4 = this.residualBlock(res3, 27);
const res5 = this.residualBlock(res4, 33);
const convT1 = this.convTransposeLayer(res5, 64, 2, 39);
const convT2 = this.convTransposeLayer(convT1, 32, 2, 42);
const convT1 = this.convTransposeLayer(res5, 2, 39);
const convT2 = this.convTransposeLayer(convT1, 2, 42);
const convT3 = this.convLayer(convT2, 1, false, 45);
const outTanh = tf.tanh(convT3);
const scaled = tf.mul(this.timesScalar, outTanh);
Expand All @@ -179,7 +175,6 @@ class StyleTransfer extends Video {
return normalized;
}));
image.dispose();
await tf.nextFrame();
return result;
}

Expand All @@ -195,12 +190,19 @@ class StyleTransfer extends Video {
this.epsilonScalar.dispose();
}

// Static Methods
static getVariableName(id) {
if (id === 0) {
return 'Variable';
}
return `Variable_${id}`;
/**
* @private
*
* Access a variable's tensor from its numeric index.
* Model contains variables with ids from 0 to 47.
* The returned tensor will be 4D if `id` is divisible by 3, or 1D otherwise.
*
* @param {number} id
* @returns {tf.Tensor}
*/
getVariable(id) {
const key = id === 0 ? 'Variable' : `Variable_${id}`;
return this.variables[key];
}
}

Expand Down