Skip to content

Commit

Permalink
Merge pull request #1348 from lindapaiste/fix/image-conversion
Browse files Browse the repository at this point in the history
Cleanup of tensor to image conversions 🧹
  • Loading branch information
lindapaiste authored May 13, 2022
2 parents b6fd7eb + a40ac46 commit 2c5cd1e
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 231 deletions.
73 changes: 21 additions & 52 deletions src/BodyPix/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
// @ts-check
import * as tf from '@tensorflow/tfjs';
import * as bp from '@tensorflow-models/body-pix';
import handleArguments from "../utils/handleArguments";
import callCallback from '../utils/callcallback';
import generatedImageResult from '../utils/generatedImageResult';
import handleArguments from '../utils/handleArguments';
import p5Utils from '../utils/p5Utils';
import BODYPIX_PALETTE from './BODYPIX_PALETTE';

Expand Down Expand Up @@ -68,7 +69,7 @@ class BodyPix {

/**
* Load the model and set it to this.model
* @return {Promise<BodyPix>} the BodyPix model.
* @return {Promise<BodyPix>}
*/
async loadModel() {
this.model = await bp.load(this.config.multiplier);
Expand All @@ -89,19 +90,6 @@ class BodyPix {
return [r, g, b]
}

/**
* Returns a p5Image
* @param {Uint8ClampedArray} tfBrowserPixelImage
* @param {number} segmentationWidth
* @param {number} segmentationHeight
* @return {Promise<p5.Image>}
*/
async convertToP5Image(tfBrowserPixelImage, segmentationWidth, segmentationHeight) {
const blob1 = await p5Utils.rawToBlob(tfBrowserPixelImage, segmentationWidth, segmentationHeight);
const p5Image1 = await p5Utils.blobToP5Image(blob1);
return p5Image1
}

/**
* Returns a bodyPartsSpec object
* @param {Array} colorOptions - an array of [r,g,b] colors
Expand Down Expand Up @@ -138,7 +126,8 @@ class BodyPix {
*/

/**
* Segments the image with partSegmentation, return result object* @param {InputImage} [imgToSegment]
* Segments the image with partSegmentation, return result object
* @param {InputImage} [imgToSegment]
* @param {BodyPixOptions} [segmentationOptions] - config params for the segmentation
* includes outputStride, segmentationThreshold
* @return {Promise<SegmentationResult>} a result object with image, raw, bodyParts
Expand Down Expand Up @@ -212,27 +201,16 @@ class BodyPix {
}
})

const personMaskPixels = await tf.browser.toPixels(personMask);
const bgMaskPixels = await tf.browser.toPixels(backgroundMask);
const partMaskPixels = await tf.browser.toPixels(partMask);

// otherwise, return the pixels
result.personMask = personMaskPixels;
result.backgroundMask = bgMaskPixels;
result.partMask = partMaskPixels;
const personMaskRes = await generatedImageResult(personMask, this.config);
const bgMaskRes = await generatedImageResult(backgroundMask, this.config);
const partMaskRes = await generatedImageResult(partMask, this.config);

// if p5 exists, convert to p5 image
if (p5Utils.checkP5()) {
result.personMask = await this.convertToP5Image(personMaskPixels, segmentation.width, segmentation.height)
result.backgroundMask = await this.convertToP5Image(bgMaskPixels, segmentation.width, segmentation.height)
result.partMask = await this.convertToP5Image(partMaskPixels, segmentation.width, segmentation.height)
}
// if p5 exists, return p5 image. otherwise, return the pixels.
result.personMask = personMaskRes.image || personMaskRes.raw;
result.backgroundMask = bgMaskRes.image || bgMaskRes.raw;
result.partMask = partMaskRes.image || partMaskRes.raw;

if (!this.config.returnTensors) {
personMask.dispose();
backgroundMask.dispose();
partMask.dispose();
} else {
if (this.config.returnTensors) {
// return tensors
result.tensor.personMask = personMask;
result.tensor.backgroundMask = backgroundMask;
Expand Down Expand Up @@ -342,28 +320,19 @@ class BodyPix {
}
})

const personMaskPixels = await tf.browser.toPixels(personMask);
const bgMaskPixels = await tf.browser.toPixels(backgroundMask);

// if p5 exists, convert to p5 image
if (p5Utils.checkP5()) {
result.personMask = await this.convertToP5Image(personMaskPixels, segmentation.width, segmentation.height)
result.backgroundMask = await this.convertToP5Image(bgMaskPixels, segmentation.width, segmentation.height)
} else {
// otherwise, return the pixels
result.personMask = personMaskPixels;
result.backgroundMask = bgMaskPixels;
}
const personMaskRes = await generatedImageResult(personMask, this.config);
const bgMaskRes = await generatedImageResult(backgroundMask, this.config);

// if p5 exists, return p5 image. otherwise, return the pixels.
result.personMask = personMaskRes.image || personMaskRes.raw;
result.backgroundMask = bgMaskRes.image || bgMaskRes.raw;

if (!this.config.returnTensors) {
personMask.dispose();
backgroundMask.dispose();
} else {
if (this.config.returnTensors) {
// return tensors
result.tensor.personMask = personMask;
result.tensor.backgroundMask = backgroundMask;
}


return result;

}
Expand Down
45 changes: 5 additions & 40 deletions src/CVAE/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import * as tf from '@tensorflow/tfjs';
import axios from "axios";
import callCallback from '../utils/callcallback';
import p5Utils from '../utils/p5Utils';
import generatedImageResult from '../utils/generatedImageResult';

class Cvae {
/**
Expand Down Expand Up @@ -49,31 +49,12 @@ class Cvae {
* Generate a random result.
* @param {String} label - A label of the feature you want to generate
* @param {function} callback - A function to handle the results of ".generate()". Likely a function to do something with the generated image data.
* @return {raw: ImageData, src: Blob, image: p5.Image}
* @return {Promise<{ raws: Uint8ClampedArray, src: Blob, image: p5.Image }>}
*/
async generate(label, callback) {
return callCallback(this.generateInternal(label), callback);
}

loadAsync(url){
return new Promise((resolve, reject) => {
if(!this.ready) reject();
loadImage(url, (img) => {
resolve(img);
});
});
};

getBlob(inputCanvas) {
return new Promise((resolve, reject) => {
if (!this.ready) reject();

inputCanvas.toBlob((blob) => {
resolve(blob);
});
});
}

async generateInternal(label) {
const res = tf.tidy(() => {
this.latentDim = tf.randomUniform([1, 16]);
Expand All @@ -91,26 +72,10 @@ class Cvae {
const temp = this.model.predict([this.latentDim, input]);
return temp.reshape([temp.shape[1], temp.shape[2], temp.shape[3]]);
});


const raws = await tf.browser.toPixels(res);
res.dispose();

const canvas = document.createElement('canvas'); // consider using offScreneCanvas
const ctx = canvas.getContext('2d');
const [x, y] = res.shape;
canvas.width = x;
canvas.height = y;
const imgData = ctx.createImageData(x, y);
const data = imgData.data;
for (let i = 0; i < x * y * 4; i += 1) data[i] = raws[i];
ctx.putImageData(imgData, 0, 0);

const src = URL.createObjectURL(await this.getBlob(canvas));
let image;
/* global loadImage */
if (p5Utils.checkP5()) image = await this.loadAsync(src);
return { src, raws, image };
const { raw, image, blob } = await generatedImageResult(res, { returnTensors: false });
const src = typeof URL !== 'undefined' ? URL.createObjectURL(blob) : undefined;
return { src, raws: raw, image };
}

}
Expand Down
57 changes: 18 additions & 39 deletions src/CartoonGAN/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
*/

import * as tf from '@tensorflow/tfjs';
import handleArguments from "../utils/handleArguments";
import callCallback from '../utils/callcallback';
import p5Utils from '../utils/p5Utils';
import generatedImageResult from '../utils/generatedImageResult';
import handleArguments from '../utils/handleArguments';

const IMAGE_SIZE = 256;

Expand Down Expand Up @@ -88,45 +88,24 @@ class Cartoon {
async generateInternal(src) {
await this.ready;
await tf.nextFrame();
// adds resizeBilinear to resize image to 256x256 as required by the model
let img = tf.browser.fromPixels(src).resizeBilinear([IMAGE_SIZE,IMAGE_SIZE]);
if (img.shape[0] !== IMAGE_SIZE || img.shape[1] !== IMAGE_SIZE) {
throw new Error(`Input size should be ${IMAGE_SIZE}*${IMAGE_SIZE} but ${img.shape} is found`);
} else if (img.shape[2] !== 3) {
throw new Error(`Input color channel number should be 3 but ${img.shape[2]} is found`);
}
img = img.sub(127.5).div(127.5).reshape([1, IMAGE_SIZE, IMAGE_SIZE, 3]);

const alpha = tf.ones([IMAGE_SIZE, IMAGE_SIZE, 1]).tile([1, 1, 1]).mul(255)
let res = this.model.predict(img);
res = res.add(1).mul(127.5).reshape([IMAGE_SIZE, IMAGE_SIZE, 3]).floor();
res = res.concat(alpha, 2)
const result = this.resultFinalize(res);

if(this.config.returnTensors){
return result;
}

img.dispose();
res.dispose();
return result;
const result = tf.tidy(() => {
// adds resizeBilinear to resize image to 256x256 as required by the model
let img = tf.browser.fromPixels(src).resizeBilinear([IMAGE_SIZE, IMAGE_SIZE]);
if (img.shape[0] !== IMAGE_SIZE || img.shape[1] !== IMAGE_SIZE) {
throw new Error(`Input size should be ${IMAGE_SIZE}*${IMAGE_SIZE} but ${img.shape} is found`);
} else if (img.shape[2] !== 3) {
throw new Error(`Input color channel number should be 3 but ${img.shape[2]} is found`);
}
img = img.sub(127.5).div(127.5).reshape([1, IMAGE_SIZE, IMAGE_SIZE, 3]);

const alpha = tf.ones([IMAGE_SIZE, IMAGE_SIZE, 1]).tile([1, 1, 1]).mul(255)
let res = this.model.predict(img);
res = res.add(1).mul(127.5).reshape([IMAGE_SIZE, IMAGE_SIZE, 3]).floor();
return res.concat(alpha, 2).cast('int32');
})
return generatedImageResult(result, this.config);
}

/**
* @private
* @param {tf.Tensor3D} res
* @return {Promise<CartoonResult>}
*/
async resultFinalize(res){
const tensor = res;
const raw = await res.data();
const blob = await p5Utils.rawToBlob(raw, res.shape[0], res.shape[1]);
const image = await p5Utils.blobToP5Image(blob);
if(this.config.returnTensors){
return {tensor, raw, blob, image};
}
return {raw, blob, image};
}
}

/**
Expand Down
42 changes: 4 additions & 38 deletions src/DCGAN/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ This version is based on alantian's TensorFlow.js implementation: https://github
import * as tf from '@tensorflow/tfjs';
import axios from 'axios';
import callCallback from '../utils/callcallback';
import handleArguments from "../utils/handleArguments";
import p5Utils from '../utils/p5Utils';
import generatedImageResult from '../utils/generatedImageResult';
import handleArguments from '../utils/handleArguments';

// Default pre-trained face model

Expand Down Expand Up @@ -103,43 +103,9 @@ class DCGANBase {
* @return {Promise<object>} includes blob, raw, and tensor. If p5 exists, also includes a p5.Image.
*/
async generateInternal(latentVector) {

const {
modelLatentDim
} = this.modelInfo;
const { modelLatentDim } = this.modelInfo;
const imageTensor = await this.compute(modelLatentDim, latentVector);

// get the raw data from tensor
const raw = await tf.browser.toPixels(imageTensor);
// get the blob from raw
const [imgHeight, imgWidth] = imageTensor.shape;
const blob = await p5Utils.rawToBlob(raw, imgWidth, imgHeight);

// get the p5.Image object
let p5Image;
if (p5Utils.checkP5()) {
p5Image = await p5Utils.blobToP5Image(blob);
}

// wrap up the final js result object
const result = {};
result.blob = blob;
result.raw = raw;


if (p5Utils.checkP5()) {
result.image = p5Image;
}

if(!this.config.returnTensors){
result.tensor = null;
imageTensor.dispose();
} else {
result.tensor = imageTensor;
}

return result;

return generatedImageResult(imageTensor, this.config);
}


Expand Down
Loading

0 comments on commit 2c5cd1e

Please sign in to comment.