Add support for grounding dino (#1137)
* Add sum, norm, normalize unit tests
* Add min/max unit tests
* make tests synchronous
* Cleanup
* Update mean op unit tests
* Add more tensor unit tests
* Update view unit test
* Add tensor construction unit tests
* Add more tensor op unit tests
* Add another squeeze unit test
* Multiple dims for squeeze unit test
* Refactor tensor reduce ops
* Add support for `gt` and `lt` tensor ops
* Add grounding dino implementation
* Allow grounding dino to be usable via the pipeline API
* Add listed support for grounding dino
* Add grounding dino unit tests
* Add zero-shot object detection pipeline unit test for grounding dino
Showing 15 changed files with 915 additions and 274 deletions.
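As a quick illustration of the pipeline support mentioned in the commit message, Grounding DINO can be driven through the zero-shot object detection pipeline. The following is a minimal sketch only; the package import path, model id, image URL, and threshold value are assumptions for illustration and are not taken from this commit.

import { pipeline } from '@huggingface/transformers';

// Hypothetical model id: any compatible ONNX export of Grounding DINO.
const detector = await pipeline('zero-shot-object-detection', 'onnx-community/grounding-dino-tiny-ONNX');

const image = 'https://example.com/cats.jpg'; // placeholder URL
const candidate_labels = ['a cat', 'a remote control'];

// Returns detections above the (assumed) score threshold.
const output = await detector(image, candidate_labels, { threshold: 0.3 });
console.log(output);
// e.g. [{ score: 0.87, label: 'a cat', box: { xmin, ymin, xmax, ymax } }, ...]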
src/models/grounding_dino/image_processing_grounding_dino.js (29 additions & 0 deletions)
@@ -0,0 +1,29 @@
import {
    ImageProcessor,
} from "../../base/image_processors_utils.js";
import { ones } from '../../utils/tensor.js';


/**
 * @typedef {object} GroundingDinoFeatureExtractorResultProps
 * @property {import('../../utils/tensor.js').Tensor} pixel_mask
 * @typedef {import('../../base/image_processors_utils.js').ImageProcessorResult & GroundingDinoFeatureExtractorResultProps} GroundingDinoFeatureExtractorResult
 */

export class GroundingDinoImageProcessor extends ImageProcessor {
    /**
     * Calls the feature extraction process on an array of images, preprocesses
     * each image, and concatenates the resulting features into a single Tensor.
     * @param {import('../../utils/image.js').RawImage[]} images The image(s) to extract features from.
     * @returns {Promise<GroundingDinoFeatureExtractorResult>} An object containing the concatenated pixel values of the preprocessed images.
     */
    async _call(images) {
        const result = await super._call(images);

        const dims = result.pixel_values.dims;
        const pixel_mask = ones([dims[0], dims[2], dims[3]]);

        return { ...result, pixel_mask };
    }
}
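For context, here is a short sketch of how the new image processor might be used on its own. The model id and image URL are placeholders rather than values from this commit; the point is that `pixel_mask` is simply a tensor of ones matching the batched image height and width, since no per-image padding mask is computed here.

import { AutoImageProcessor, RawImage } from '@huggingface/transformers';

// Placeholder model id for a Grounding DINO checkpoint.
const processor = await AutoImageProcessor.from_pretrained('onnx-community/grounding-dino-tiny-ONNX');

const image = await RawImage.read('https://example.com/image.jpg'); // placeholder URL
const { pixel_values, pixel_mask } = await processor(image);

// pixel_values: (batch_size, num_channels, height, width)
// pixel_mask:   (batch_size, height, width), filled with ones
console.log(pixel_values.dims, pixel_mask.dims);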
@@ -0,0 +1,101 @@ | ||
import { Processor } from "../../base/processing_utils.js"; | ||
import { AutoImageProcessor } from "../auto/image_processing_auto.js"; | ||
import { AutoTokenizer } from "../../tokenizers.js"; | ||
import { center_to_corners_format } from "../../base/image_processors_utils.js"; | ||
|
||
/** | ||
* Get token ids of phrases from posmaps and input_ids. | ||
* @param {import('../../utils/tensor.js').Tensor} posmaps A boolean tensor of unbatched text-thresholded logits related to the detected bounding boxes of shape `(hidden_size, )`. | ||
* @param {import('../../utils/tensor.js').Tensor} input_ids A tensor of token ids of shape `(sequence_length, )`. | ||
*/ | ||
function get_phrases_from_posmap(posmaps, input_ids) { | ||
|
||
const left_idx = 0; | ||
const right_idx = posmaps.dims.at(-1) - 1; | ||
|
||
const posmaps_list = posmaps.tolist(); | ||
posmaps_list.fill(false, 0, left_idx + 1); | ||
posmaps_list.fill(false, right_idx); | ||
|
||
const input_ids_list = input_ids.tolist(); | ||
return posmaps_list | ||
.map((val, idx) => val ? idx : null) | ||
.filter(idx => idx !== null) | ||
.map(i => input_ids_list[i]); | ||
} | ||
|
||
export class GroundingDinoProcessor extends Processor { | ||
static tokenizer_class = AutoTokenizer | ||
static image_processor_class = AutoImageProcessor | ||
|
||
/** | ||
* @typedef {import('../../utils/image.js').RawImage} RawImage | ||
*/ | ||
/** | ||
* | ||
* @param {RawImage|RawImage[]|RawImage[][]} images | ||
* @param {string|string[]} text | ||
* @returns {Promise<any>} | ||
*/ | ||
async _call(images, text, options = {}) { | ||
|
||
const image_inputs = images ? await this.image_processor(images, options) : {}; | ||
const text_inputs = text ? this.tokenizer(text, options) : {}; | ||
|
||
return { | ||
...text_inputs, | ||
...image_inputs, | ||
} | ||
} | ||
post_process_grounded_object_detection(outputs, input_ids, { | ||
box_threshold = 0.25, | ||
text_threshold = 0.25, | ||
target_sizes = null | ||
} = {}) { | ||
const { logits, pred_boxes } = outputs; | ||
const batch_size = logits.dims[0]; | ||
|
||
if (target_sizes !== null && target_sizes.length !== batch_size) { | ||
throw Error("Make sure that you pass in as many target sizes as the batch dimension of the logits") | ||
} | ||
const num_queries = logits.dims.at(1); | ||
|
||
const probs = logits.sigmoid(); // (batch_size, num_queries, 256) | ||
const scores = probs.max(-1).tolist(); // (batch_size, num_queries) | ||
|
||
// Convert to [x0, y0, x1, y1] format | ||
const boxes = pred_boxes.tolist() // (batch_size, num_queries, 4) | ||
.map(batch => batch.map(box => center_to_corners_format(box))); | ||
|
||
const results = []; | ||
for (let i = 0; i < batch_size; ++i) { | ||
const target_size = target_sizes !== null ? target_sizes[i] : null; | ||
|
||
// Convert from relative [0, 1] to absolute [0, height] coordinates | ||
if (target_size !== null) { | ||
boxes[i] = boxes[i].map(box => box.map((x, j) => x * target_size[(j + 1) % 2])); | ||
} | ||
|
||
const batch_scores = scores[i]; | ||
const final_scores = []; | ||
const final_phrases = []; | ||
const final_boxes = []; | ||
for (let j = 0; j < num_queries; ++j) { | ||
const score = batch_scores[j]; | ||
if (score <= box_threshold) { | ||
continue; | ||
} | ||
const box = boxes[i][j]; | ||
const prob = probs[i][j]; | ||
|
||
final_scores.push(score); | ||
final_boxes.push(box); | ||
|
||
const phrases = get_phrases_from_posmap(prob.gt(text_threshold), input_ids[i]); | ||
final_phrases.push(phrases); | ||
} | ||
results.push({ scores: final_scores, boxes: final_boxes, labels: this.batch_decode(final_phrases) }); | ||
} | ||
return results; | ||
} | ||
} |
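A sketch of how the processor and post_process_grounded_object_detection might fit together end to end. The model id, AutoModel usage, text prompt, and threshold values below are assumptions for illustration; the commit itself only defines the processor and its post-processing helper.

import { AutoProcessor, AutoModel, RawImage } from '@huggingface/transformers';

// Placeholder model id for a Grounding DINO checkpoint.
const model_id = 'onnx-community/grounding-dino-tiny-ONNX';
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModel.from_pretrained(model_id);

const image = await RawImage.read('https://example.com/image.jpg'); // placeholder URL
const text = 'a cat. a remote control.'; // Grounding DINO expects lower-cased phrases, each ending with '.'

const inputs = await processor(image, text);
const outputs = await model(inputs);

// Keep detections whose scores exceed the thresholds, rescaled to the original image size.
const results = processor.post_process_grounded_object_detection(outputs, inputs.input_ids, {
    box_threshold: 0.3,
    text_threshold: 0.3,
    target_sizes: [[image.height, image.width]],
});
console.log(results);
// [{ scores: [...], boxes: [[x0, y0, x1, y1], ...], labels: ['a cat', ...] }]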