diff --git a/src/BodyPix/index.js b/src/BodyPix/index.js index 01ca98e9c..28fef3997 100644 --- a/src/BodyPix/index.js +++ b/src/BodyPix/index.js @@ -14,8 +14,9 @@ // @ts-check import * as tf from '@tensorflow/tfjs'; import * as bp from '@tensorflow-models/body-pix'; -import handleArguments from "../utils/handleArguments"; import callCallback from '../utils/callcallback'; +import generatedImageResult from '../utils/generatedImageResult'; +import handleArguments from '../utils/handleArguments'; import p5Utils from '../utils/p5Utils'; import BODYPIX_PALETTE from './BODYPIX_PALETTE'; @@ -68,7 +69,7 @@ class BodyPix { /** * Load the model and set it to this.model - * @return {Promise} the BodyPix model. + * @return {Promise} */ async loadModel() { this.model = await bp.load(this.config.multiplier); @@ -89,19 +90,6 @@ class BodyPix { return [r, g, b] } - /** - * Returns a p5Image - * @param {Uint8ClampedArray} tfBrowserPixelImage - * @param {number} segmentationWidth - * @param {number} segmentationHeight - * @return {Promise} - */ - async convertToP5Image(tfBrowserPixelImage, segmentationWidth, segmentationHeight) { - const blob1 = await p5Utils.rawToBlob(tfBrowserPixelImage, segmentationWidth, segmentationHeight); - const p5Image1 = await p5Utils.blobToP5Image(blob1); - return p5Image1 - } - /** * Returns a bodyPartsSpec object * @param {Array} colorOptions - an array of [r,g,b] colors @@ -138,7 +126,8 @@ class BodyPix { */ /** - * Segments the image with partSegmentation, return result object* @param {InputImage} [imgToSegment] + * Segments the image with partSegmentation, return result object + * @param {InputImage} [imgToSegment] * @param {BodyPixOptions} [segmentationOptions] - config params for the segmentation * includes outputStride, segmentationThreshold * @return {Promise} a result object with image, raw, bodyParts @@ -212,27 +201,16 @@ class BodyPix { } }) - const personMaskPixels = await tf.browser.toPixels(personMask); - const bgMaskPixels = await 
tf.browser.toPixels(backgroundMask); - const partMaskPixels = await tf.browser.toPixels(partMask); - - // otherwise, return the pixels - result.personMask = personMaskPixels; - result.backgroundMask = bgMaskPixels; - result.partMask = partMaskPixels; + const personMaskRes = await generatedImageResult(personMask, this.config); + const bgMaskRes = await generatedImageResult(backgroundMask, this.config); + const partMaskRes = await generatedImageResult(partMask, this.config); - // if p5 exists, convert to p5 image - if (p5Utils.checkP5()) { - result.personMask = await this.convertToP5Image(personMaskPixels, segmentation.width, segmentation.height) - result.backgroundMask = await this.convertToP5Image(bgMaskPixels, segmentation.width, segmentation.height) - result.partMask = await this.convertToP5Image(partMaskPixels, segmentation.width, segmentation.height) - } + // if p5 exists, return p5 image. otherwise, return the pixels. + result.personMask = personMaskRes.image || personMaskRes.raw; + result.backgroundMask = bgMaskRes.image || bgMaskRes.raw; + result.partMask = partMaskRes.image || partMaskRes.raw; - if (!this.config.returnTensors) { - personMask.dispose(); - backgroundMask.dispose(); - partMask.dispose(); - } else { + if (this.config.returnTensors) { // return tensors result.tensor.personMask = personMask; result.tensor.backgroundMask = backgroundMask; @@ -342,28 +320,19 @@ class BodyPix { } }) - const personMaskPixels = await tf.browser.toPixels(personMask); - const bgMaskPixels = await tf.browser.toPixels(backgroundMask); - - // if p5 exists, convert to p5 image - if (p5Utils.checkP5()) { - result.personMask = await this.convertToP5Image(personMaskPixels, segmentation.width, segmentation.height) - result.backgroundMask = await this.convertToP5Image(bgMaskPixels, segmentation.width, segmentation.height) - } else { - // otherwise, return the pixels - result.personMask = personMaskPixels; - result.backgroundMask = bgMaskPixels; - } + const personMaskRes = await 
generatedImageResult(personMask, this.config); + const bgMaskRes = await generatedImageResult(backgroundMask, this.config); + + // if p5 exists, return p5 image. otherwise, return the pixels. + result.personMask = personMaskRes.image || personMaskRes.raw; + result.backgroundMask = bgMaskRes.image || bgMaskRes.raw; - if (!this.config.returnTensors) { - personMask.dispose(); - backgroundMask.dispose(); - } else { + if (this.config.returnTensors) { + // return tensors result.tensor.personMask = personMask; result.tensor.backgroundMask = backgroundMask; } - return result; } diff --git a/src/CVAE/index.js b/src/CVAE/index.js index 13620fd75..800b2f38c 100644 --- a/src/CVAE/index.js +++ b/src/CVAE/index.js @@ -12,7 +12,7 @@ import * as tf from '@tensorflow/tfjs'; import axios from "axios"; import callCallback from '../utils/callcallback'; -import p5Utils from '../utils/p5Utils'; +import generatedImageResult from '../utils/generatedImageResult'; class Cvae { /** @@ -49,31 +49,12 @@ class Cvae { * Generate a random result. * @param {String} label - A label of the feature your want to generate * @param {function} callback - A function to handle the results of ".generate()". Likely a function to do something with the generated image data. 
- * @return {raw: ImageData, src: Blob, image: p5.Image} + * @return {Promise<{ raws: Uint8ClampedArray, src: Blob, image: p5.Image }>} */ async generate(label, callback) { return callCallback(this.generateInternal(label), callback); } - loadAsync(url){ - return new Promise((resolve, reject) => { - if(!this.ready) reject(); - loadImage(url, (img) => { - resolve(img); - }); - }); - }; - - getBlob(inputCanvas) { - return new Promise((resolve, reject) => { - if (!this.ready) reject(); - - inputCanvas.toBlob((blob) => { - resolve(blob); - }); - }); - } - async generateInternal(label) { const res = tf.tidy(() => { this.latentDim = tf.randomUniform([1, 16]); @@ -91,26 +72,10 @@ class Cvae { const temp = this.model.predict([this.latentDim, input]); return temp.reshape([temp.shape[1], temp.shape[2], temp.shape[3]]); }); - - - const raws = await tf.browser.toPixels(res); - res.dispose(); - - const canvas = document.createElement('canvas'); // consider using offScreneCanvas - const ctx = canvas.getContext('2d'); - const [x, y] = res.shape; - canvas.width = x; - canvas.height = y; - const imgData = ctx.createImageData(x, y); - const data = imgData.data; - for (let i = 0; i < x * y * 4; i += 1) data[i] = raws[i]; - ctx.putImageData(imgData, 0, 0); - const src = URL.createObjectURL(await this.getBlob(canvas)); - let image; - /* global loadImage */ - if (p5Utils.checkP5()) image = await this.loadAsync(src); - return { src, raws, image }; + const { raw, image, blob } = await generatedImageResult(res, { returnTensors: false }); + const src = typeof URL !== 'undefined' ? 
URL.createObjectURL(blob) : undefined; + return { src, raws: raw, image }; } } diff --git a/src/CartoonGAN/index.js b/src/CartoonGAN/index.js index 6ae6df102..099b6cb2b 100644 --- a/src/CartoonGAN/index.js +++ b/src/CartoonGAN/index.js @@ -8,9 +8,9 @@ */ import * as tf from '@tensorflow/tfjs'; -import handleArguments from "../utils/handleArguments"; import callCallback from '../utils/callcallback'; -import p5Utils from '../utils/p5Utils'; +import generatedImageResult from '../utils/generatedImageResult'; +import handleArguments from '../utils/handleArguments'; const IMAGE_SIZE = 256; @@ -88,45 +88,24 @@ class Cartoon { async generateInternal(src) { await this.ready; await tf.nextFrame(); - // adds resizeBilinear to resize image to 256x256 as required by the model - let img = tf.browser.fromPixels(src).resizeBilinear([IMAGE_SIZE,IMAGE_SIZE]); - if (img.shape[0] !== IMAGE_SIZE || img.shape[1] !== IMAGE_SIZE) { - throw new Error(`Input size should be ${IMAGE_SIZE}*${IMAGE_SIZE} but ${img.shape} is found`); - } else if (img.shape[2] !== 3) { - throw new Error(`Input color channel number should be 3 but ${img.shape[2]} is found`); - } - img = img.sub(127.5).div(127.5).reshape([1, IMAGE_SIZE, IMAGE_SIZE, 3]); - - const alpha = tf.ones([IMAGE_SIZE, IMAGE_SIZE, 1]).tile([1, 1, 1]).mul(255) - let res = this.model.predict(img); - res = res.add(1).mul(127.5).reshape([IMAGE_SIZE, IMAGE_SIZE, 3]).floor(); - res = res.concat(alpha, 2) - const result = this.resultFinalize(res); - - if(this.config.returnTensors){ - return result; - } - - img.dispose(); - res.dispose(); - return result; + const result = tf.tidy(() => { + // adds resizeBilinear to resize image to 256x256 as required by the model + let img = tf.browser.fromPixels(src).resizeBilinear([IMAGE_SIZE, IMAGE_SIZE]); + if (img.shape[0] !== IMAGE_SIZE || img.shape[1] !== IMAGE_SIZE) { + throw new Error(`Input size should be ${IMAGE_SIZE}*${IMAGE_SIZE} but ${img.shape} is found`); + } else if (img.shape[2] !== 3) { + throw new 
Error(`Input color channel number should be 3 but ${img.shape[2]} is found`); + } + img = img.sub(127.5).div(127.5).reshape([1, IMAGE_SIZE, IMAGE_SIZE, 3]); + + const alpha = tf.ones([IMAGE_SIZE, IMAGE_SIZE, 1]).tile([1, 1, 1]).mul(255) + let res = this.model.predict(img); + res = res.add(1).mul(127.5).reshape([IMAGE_SIZE, IMAGE_SIZE, 3]).floor(); + return res.concat(alpha, 2).cast('int32'); + }) + return generatedImageResult(result, this.config); } - /** - * @private - * @param {tf.Tensor3D} res - * @return {Promise} - */ - async resultFinalize(res){ - const tensor = res; - const raw = await res.data(); - const blob = await p5Utils.rawToBlob(raw, res.shape[0], res.shape[1]); - const image = await p5Utils.blobToP5Image(blob); - if(this.config.returnTensors){ - return {tensor, raw, blob, image}; - } - return {raw, blob, image}; - } } /** diff --git a/src/DCGAN/index.js b/src/DCGAN/index.js index af84cbc87..164f6e06d 100644 --- a/src/DCGAN/index.js +++ b/src/DCGAN/index.js @@ -11,8 +11,8 @@ This version is based on alantian's TensorFlow.js implementation: https://github import * as tf from '@tensorflow/tfjs'; import axios from 'axios'; import callCallback from '../utils/callcallback'; -import handleArguments from "../utils/handleArguments"; -import p5Utils from '../utils/p5Utils'; +import generatedImageResult from '../utils/generatedImageResult'; +import handleArguments from '../utils/handleArguments'; // Default pre-trained face model @@ -103,43 +103,9 @@ class DCGANBase { * @return {object} includes blob, raw, and tensor. 
if P5 exists, then a p5Image */ async generateInternal(latentVector) { - - const { - modelLatentDim - } = this.modelInfo; + const { modelLatentDim } = this.modelInfo; const imageTensor = await this.compute(modelLatentDim, latentVector); - - // get the raw data from tensor - const raw = await tf.browser.toPixels(imageTensor); - // get the blob from raw - const [imgHeight, imgWidth] = imageTensor.shape; - const blob = await p5Utils.rawToBlob(raw, imgWidth, imgHeight); - - // get the p5.Image object - let p5Image; - if (p5Utils.checkP5()) { - p5Image = await p5Utils.blobToP5Image(blob); - } - - // wrap up the final js result object - const result = {}; - result.blob = blob; - result.raw = raw; - - - if (p5Utils.checkP5()) { - result.image = p5Image; - } - - if(!this.config.returnTensors){ - result.tensor = null; - imageTensor.dispose(); - } else { - result.tensor = imageTensor; - } - - return result; - + return generatedImageResult(imageTensor, this.config); } diff --git a/src/UNET/index.js b/src/UNET/index.js index d854038e3..484f5100a 100644 --- a/src/UNET/index.js +++ b/src/UNET/index.js @@ -9,11 +9,8 @@ Image Classifier using pre-trained networks import * as tf from '@tensorflow/tfjs'; import callCallback from '../utils/callcallback'; +import generatedImageResult from '../utils/generatedImageResult'; import handleArguments from "../utils/handleArguments"; -import { - array3DToImage -} from '../utils/imageUtilities'; -import p5Utils from '../utils/p5Utils'; const DEFAULTS = { modelPath: 'https://raw.githubusercontent.com/zaidalyafeai/HostedModels/master/unet-128/model.json', @@ -53,28 +50,6 @@ class UNET { return callCallback(this.segmentInternal(image), callback); } - static dataURLtoBlob(dataurl) { - const arr = dataurl.split(','); - const mime = arr[0].match(/:(.*?);/)[1]; - const bstr = atob(arr[1]); - let n = bstr.length; - const u8arr = new Uint8Array(n); - - while (n) { - u8arr[n] = bstr.charCodeAt(n); - n -= 1; - } - return new Blob([u8arr], { - type: mime 
- }); - } - - async convertToP5Image(tfBrowserPixelImage){ - const blob1 = await p5Utils.rawToBlob(tfBrowserPixelImage, this.config.imageSize, this.config.imageSize); - const p5Image1 = await p5Utils.blobToP5Image(blob1); - return p5Image1 - } - async segmentInternal(imgToPredict) { // Wait for the model to be ready await this.ready; @@ -128,51 +103,28 @@ class UNET { this.isPredicting = false; - // these come first because array3DToImage() will dispose of the input tensor - const maskFeat = await tf.browser.toPixels(featureMask); - const maskBg = await tf.browser.toPixels(backgroundMask); - const mask = await tf.browser.toPixels(segmentation); - - const maskFeatDom = array3DToImage(featureMask); - const maskBgDom = array3DToImage(backgroundMask); - const maskFeatBlob = UNET.dataURLtoBlob(maskFeatDom.src); - const maskBgBlob = UNET.dataURLtoBlob(maskBgDom.src); - - - let pFeatureMask; - let pBgMask; - let pMask; - - if (p5Utils.checkP5()) { - pFeatureMask = await this.convertToP5Image(maskFeat); - pBgMask = await this.convertToP5Image(maskBg) - pMask = await this.convertToP5Image(mask) - } - - if(!this.config.returnTensors){ - featureMask.dispose(); - backgroundMask.dispose(); - segmentation.dispose(); - } + const maskFeat = await generatedImageResult(featureMask, this.config); + const maskBg = await generatedImageResult(backgroundMask, this.config); + const mask = await generatedImageResult(segmentation, this.config); return { - segmentation:mask, + segmentation: mask.raw, blob: { - featureMask: maskFeatBlob, - backgroundMask: maskBgBlob + featureMask: maskFeat.blob, + backgroundMask: maskBg.blob }, tensor: { - featureMask, - backgroundMask, + featureMask: maskFeat.tensor, + backgroundMask: maskBg.tensor, }, raw: { - featureMask: maskFeat, - backgroundMask: maskBg + featureMask: maskFeat.raw, + backgroundMask: maskBg.raw }, // returns if p5 is available - featureMask: pFeatureMask, - backgroundMask: pBgMask, - mask: pMask + featureMask: maskFeat.image, + 
backgroundMask: maskBg.image, + mask: mask.image }; } } diff --git a/src/utils/generatedImageResult.js b/src/utils/generatedImageResult.js new file mode 100644 index 000000000..53277bc99 --- /dev/null +++ b/src/utils/generatedImageResult.js @@ -0,0 +1,34 @@ +import * as tf from '@tensorflow/tfjs'; +import p5Utils from './p5Utils'; + +/** +* @typedef {Object} GeneratedImageResult +* @property {Uint8ClampedArray} raw - an array of all pixel values +* @property {Blob} blob - image blob +* @property {p5.Image | null} image - p5 Image, if p5 is available. +* @property {tf.Tensor3D} [tensor] - the original tensor, if `returnTensors` is true. +*/ + +/** + * Utility function for returning an image in multiple formats. + * + * Takes a Tensor and returns an object containing `blob`, `raw`, `image`, and optionally `tensor`. + * Will dispose of the Tensor if not returning it. + * + * Accepts options as an object with property `returnTensors` so that models can pass this.config. + * + * @param {tf.Tensor3D} tensor + * @param {{ returnTensors: boolean }} options + * @return {Promise<GeneratedImageResult>} + */ +export default async function generatedImageResult(tensor, options) { + const raw = await tf.browser.toPixels(tensor); // toPixels resolves to a Uint8ClampedArray of RGBA values + const [height, width] = tensor.shape; + const blob = await p5Utils.rawToBlob(raw, width, height); + const image = await p5Utils.blobToP5Image(blob); + if (options.returnTensors) { + return { tensor, raw, blob, image }; + } + tensor.dispose(); + return { raw, blob, image }; +}