Skip to content

Commit

Permalink
Merge pull request #1322 from lindapaiste/cleanup/handle-arguments
Browse files Browse the repository at this point in the history
Cleanup 🧹  image argument checking
  • Loading branch information
lindapaiste authored May 12, 2022
2 parents 1f5f8b1 + 2f54164 commit b6fd7eb
Show file tree
Hide file tree
Showing 22 changed files with 645 additions and 1,090 deletions.
259 changes: 98 additions & 161 deletions src/BodyPix/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,27 @@
// @ts-check
import * as tf from '@tensorflow/tfjs';
import * as bp from '@tensorflow-models/body-pix';
import handleArguments from "../utils/handleArguments";
import callCallback from '../utils/callcallback';
import p5Utils from '../utils/p5Utils';
import BODYPIX_PALETTE from './BODYPIX_PALETTE';

/**
* @typedef {Record<string, {color: [number, number, number], id: number}>} BodyPixPalette
*/

/**
* @typedef {Object} BodyPixOptions
* @property {import('@tensorflow-models/body-pix/dist/mobilenet').MobileNetMultiplier} [multiplier]
* @property {import('@tensorflow-models/body-pix/dist/mobilenet').OutputStride} [outputStride]
* @property {number} [segmentationThreshold]
* @property {BodyPixPalette} [palette]
* @property {boolean} [returnTensors]
*/

/**
* @type {BodyPixOptions}
*/
const DEFAULTS = {
"multiplier": 0.75,
"outputStride": 16,
Expand All @@ -28,17 +45,11 @@ const DEFAULTS = {

class BodyPix {
/**
* Create BodyPix.
* @param {HTMLVideoElement} video - An HTMLVideoElement.
* @param {{
* multiplier: Number;
* outputStride: Number;
* segmentationThreshold: Number;
* palette: Object;
* returnTensors: Boolean;
* }} options - An object with options.
* @param {Function} callback - A callback to be called when the model is ready.
*/
* Create BodyPix.
* @param {HTMLVideoElement} [video] - An HTMLVideoElement.
* @param {BodyPixOptions} [options] - An object with options.
* @param {ML5Callback<BodyPix>} [callback] - A callback to be called when the model is ready.
*/
constructor(video, options, callback) {
this.video = video;
this.model = null;
Expand All @@ -56,20 +67,20 @@ class BodyPix {
}

/**
* Load the model and set it to this.model
* @return {Promise<Object>} the BodyPix model.
*/
* Load the model and set it to this.model
* @return {Promise<BodyPix>} the BodyPix model.
*/
async loadModel() {
this.model = await bp.load(this.config.multiplier);
this.modelReady = true;
return this;
}

/**
* Returns an rgb array
* @param {Object} p5ColorObj - a p5.Color obj
* @return {Array} an [r,g,b] array
*/
* Returns an rgb array
* @param {Object} p5ColorObj - a p5.Color obj
* @return {Array} an [r,g,b] array
*/
/* eslint class-methods-use-this: "off" */
p5Color2RGB(p5ColorObj) {
const regExp = /\(([^)]+)\)/;
Expand All @@ -79,20 +90,23 @@ class BodyPix {
}

/**
* Returns a p5Image
* @param {*} tfBrowserPixelImage
*/
* Returns a p5Image
* @param {Uint8ClampedArray} tfBrowserPixelImage
* @param {number} segmentationWidth
* @param {number} segmentationHeight
* @return {Promise<p5.Image>}
*/
async convertToP5Image(tfBrowserPixelImage, segmentationWidth, segmentationHeight) {
const blob1 = await p5Utils.rawToBlob(tfBrowserPixelImage, segmentationWidth, segmentationHeight);
const p5Image1 = await p5Utils.blobToP5Image(blob1);
return p5Image1
}

/**
* Returns a bodyPartsSpec object
* @param {Array} colorOptions - an array of [r,g,b] colors
* @return {object} an object with the bodyParts by color and id
*/
* Returns a bodyPartsSpec object
* @param {Array} colorOptions - an array of [r,g,b] colors
* @return {object} an object with the bodyParts by color and id
*/
/* eslint class-methods-use-this: "off" */
bodyPartsSpec(colorOptions) {
const result = colorOptions !== undefined || Object.keys(colorOptions).length >= 24 ? colorOptions : this.config.palette;
Expand All @@ -112,13 +126,23 @@ class BodyPix {
}

/**
* Segments the image with partSegmentation, return result object
* @param {HTMLImageElement | HTMLCanvasElement | object | function | number} imgToSegment -
* takes any of the following params
* @param {object} segmentationOptions - config params for the segmentation
* includes outputStride, segmentationThreshold
* @return {Promise<Object>} a result object with image, raw, bodyParts
*/
* @typedef {Object} SegmentationResult
* @property {{data: Uint8Array | Int32Array, width: number, height: number}} segmentation
* @property {p5.Image | Uint8ClampedArray} personMask - will be a p5 Image if p5 is available,
* or an array of pixel values otherwise.
* @property {p5.Image | Uint8ClampedArray} backgroundMask
* @property {{personMask: tf.Tensor | null, backgroundMask: tf.Tensor | null, partMask?: tf.Tensor | null}} tensor -
* return the Tensor objects for the person and the background if option `returnTensors` is true.
* @property {{personMask: ImageData, backgroundMask: ImageData, partMask?: ImageData}} raw
* @property {BodyPixPalette} [bodyParts] - body parts are included when calling `segmentWithParts`.
*/

/**
* Segments the image with partSegmentation, return result object* @param {InputImage} [imgToSegment]
* @param {BodyPixOptions} [segmentationOptions] - config params for the segmentation
* includes outputStride, segmentationThreshold
* @return {Promise<SegmentationResult>} a result object with image, raw, bodyParts
*/
async segmentWithPartsInternal(imgToSegment, segmentationOptions) {
// estimatePartSegmentation
await this.ready;
Expand Down Expand Up @@ -220,66 +244,34 @@ class BodyPix {
}

/**
* Segments the image with partSegmentation
* @param {HTMLImageElement | HTMLCanvasElement | object | function | number} optionsOrCallback -
* takes any of the following params
* @param {object} configOrCallback - config params for the segmentation
* includes palette, outputStride, segmentationThreshold
* @param {function} cb - a callback function that handles the results of the function.
* @return {Promise<Object>} a promise or the results of a given callback, cb.
*/
async segmentWithParts(optionsOrCallback, configOrCallback, cb) {
let imgToSegment = this.video;
let callback;
let segmentationOptions = this.config;

// Handle the image to predict
if (typeof optionsOrCallback === 'function') {
imgToSegment = this.video;
callback = optionsOrCallback;
// clean the following conditional statement up!
} else if (optionsOrCallback instanceof HTMLImageElement ||
optionsOrCallback instanceof HTMLCanvasElement ||
optionsOrCallback instanceof HTMLVideoElement ||
optionsOrCallback instanceof ImageData) {
imgToSegment = optionsOrCallback;
} else if (typeof optionsOrCallback === 'object' && (optionsOrCallback.elt instanceof HTMLImageElement ||
optionsOrCallback.elt instanceof HTMLCanvasElement ||
optionsOrCallback.elt instanceof ImageData)) {
imgToSegment = optionsOrCallback.elt; // Handle p5.js image
} else if (typeof optionsOrCallback === 'object' && optionsOrCallback.canvas instanceof HTMLCanvasElement) {
imgToSegment = optionsOrCallback.canvas; // Handle p5.js image
} else if (typeof optionsOrCallback === 'object' && optionsOrCallback.elt instanceof HTMLVideoElement) {
imgToSegment = optionsOrCallback.elt; // Handle p5.js image
} else if (!(this.video instanceof HTMLVideoElement)) {
// Handle unsupported input
* Segments the image with partSegmentation
*
* Takes any of the following params:
* - an image to segment
* - config params for the segmentation, includes palette, outputStride, segmentationThreshold
* - a callback function that handles the results of the function.
* @param {(InputImage | BodyPixOptions | ML5Callback<SegmentationResult>[])} [args]
* @return {Promise<SegmentationResult>}
*/
async segmentWithParts(...args) {
const { options = this.config, callback, image = this.video } = handleArguments(...args);

if (!image) {
throw new Error(
'No input image provided. If you want to classify a video, pass the video element in the constructor. ',
'No input image provided. If you want to classify a video, pass the video element in the constructor.'
);
}

if (typeof configOrCallback === 'object') {
segmentationOptions = configOrCallback;
} else if (typeof configOrCallback === 'function') {
callback = configOrCallback;
}

if (typeof cb === 'function') {
callback = cb;
}

return callCallback(this.segmentWithPartsInternal(imgToSegment, segmentationOptions), callback);

return callCallback(this.segmentWithPartsInternal(image, options), callback);
}

/**
* Segments the image with personSegmentation, return result object
* @param {HTMLImageElement | HTMLCanvasElement | object | function | number} imgToSegment -
* takes any of the following params
* @param {object} segmentationOptions - config params for the segmentation
* includes outputStride, segmentationThreshold
* @return {Promise<Object>} a result object with maskBackground, maskPerson, raw
*/
* Segments the image with personSegmentation, return result object
* @param {InputImage} imgToSegment
* @param {BodyPixOptions} segmentationOptions - config params for the segmentation
* includes outputStride, segmentationThreshold
* @return {Promise<SegmentationResult>} a result object with maskBackground, maskPerson, raw
*/
async segmentInternal(imgToSegment, segmentationOptions) {

await this.ready;
Expand Down Expand Up @@ -377,92 +369,37 @@ class BodyPix {
}

/**
* Segments the image with personSegmentation
* @param {HTMLVideoElement | HTMLImageElement | HTMLCanvasElement | object | function | number} optionsOrCallback -
* takes any of the following params
* @param {object} configOrCallback - config params for the segmentation
* includes outputStride, segmentationThreshold
* @param {function} cb - a callback function that handles the results of the function.
* @return {Promise<Object>} a promise or the results of a given callback, cb.
*/
async segment(optionsOrCallback, configOrCallback, cb) {
let imgToSegment = this.video;
let callback;
let segmentationOptions = this.config;

// Handle the image to predict
if (typeof optionsOrCallback === 'function') {
imgToSegment = this.video;
callback = optionsOrCallback;
// clean the following conditional statement up!
} else if (optionsOrCallback instanceof HTMLImageElement ||
optionsOrCallback instanceof HTMLCanvasElement ||
optionsOrCallback instanceof HTMLVideoElement ||
optionsOrCallback instanceof ImageData) {
imgToSegment = optionsOrCallback;
} else if (typeof optionsOrCallback === 'object' && (optionsOrCallback.elt instanceof HTMLImageElement ||
optionsOrCallback.elt instanceof HTMLCanvasElement ||
optionsOrCallback.elt instanceof ImageData)) {
imgToSegment = optionsOrCallback.elt; // Handle p5.js image
} else if (typeof optionsOrCallback === 'object' && optionsOrCallback.canvas instanceof HTMLCanvasElement) {
imgToSegment = optionsOrCallback.canvas; // Handle p5.js image
} else if (typeof optionsOrCallback === 'object' && optionsOrCallback.elt instanceof HTMLVideoElement) {
imgToSegment = optionsOrCallback.elt; // Handle p5.js image
} else if (!(this.video instanceof HTMLVideoElement)) {
// Handle unsupported input
* Segments the image with personSegmentation
*
* Takes any of the following params:
* - an image to segment
* - config params for the segmentation, includes outputStride, segmentationThreshold
* - a callback function that handles the results of the function.
* @param {(InputImage | BodyPixOptions | ML5Callback<SegmentationResult>)[]} [args]
* @return {Promise<SegmentationResult>}
*/
async segment(...args) {
const { options = this.config, callback, image = this.video } = handleArguments(...args);

if (!image) {
throw new Error(
'No input image provided. If you want to classify a video, pass the video element in the constructor. ',
'No input image provided. If you want to classify a video, pass the video element in the constructor.'
);
}

if (typeof configOrCallback === 'object') {
segmentationOptions = configOrCallback;
} else if (typeof configOrCallback === 'function') {
callback = configOrCallback;
}

if (typeof cb === 'function') {
callback = cb;
}

return callCallback(this.segmentInternal(imgToSegment, segmentationOptions), callback);
return callCallback(this.segmentInternal(image, options), callback);
}

}

/**
*
* @param {Object | Function} videoOrOptionsOrCallback
* @param {Object | Function} optionsOrCallback
* @param {Function} cb
* @returns {Promise<Object> | Function}
* @param {(HTMLVideoElement | p5.Video | BodyPixOptions | ML5Callback<BodyPix>)[]} [inputs]
* @return {BodyPix | Promise<BodyPix>}
*/
const bodyPix = (videoOrOptionsOrCallback, optionsOrCallback, cb) => {
let video;
let options = {};
let callback = cb;

if (videoOrOptionsOrCallback instanceof HTMLVideoElement) {
video = videoOrOptionsOrCallback;
} else if (
typeof videoOrOptionsOrCallback === 'object' &&
videoOrOptionsOrCallback.elt instanceof HTMLVideoElement
) {
video = videoOrOptionsOrCallback.elt; // Handle a p5.js video element
} else if (typeof videoOrOptionsOrCallback === 'object') {
options = videoOrOptionsOrCallback;
} else if (typeof videoOrOptionsOrCallback === 'function') {
callback = videoOrOptionsOrCallback;
}

if (typeof optionsOrCallback === 'object') {
options = optionsOrCallback;
} else if (typeof optionsOrCallback === 'function') {
callback = optionsOrCallback;
}

const instance = new BodyPix(video, options, callback);
return callback ? instance : instance.ready;
const bodyPix = (...inputs) => {
const args = handleArguments(...inputs);
const instance = new BodyPix(args.video, args.options || {}, args.callback);
return args.callback ? instance : instance.ready;
}

export default bodyPix;
Loading

0 comments on commit b6fd7eb

Please sign in to comment.