Skip to content

Commit

Permalink
Add multiclass nms options for object detector.
Browse files Browse the repository at this point in the history
Expose nms threshold to object detector options.

PiperOrigin-RevId: 707955417
  • Loading branch information
MediaPipe Team authored and copybara-github committed Dec 19, 2024
1 parent e00c842 commit e2f9635
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 25 deletions.
1 change: 1 addition & 0 deletions mediapipe/calculators/util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ cc_library(
"//mediapipe/framework/formats:location",
"//mediapipe/framework/port:rectangle",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
],
Expand Down
84 changes: 62 additions & 22 deletions mediapipe/calculators/util/non_max_suppression_calculator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h"
Expand Down Expand Up @@ -121,6 +122,21 @@ float OverlapSimilarity(
return OverlapSimilarity(overlap_type, rect1, rect2);
}

// Builds a list of (index, score) pairs for `detections` and sorts it with
// SortBySecond, so that callers can visit detections in score order without
// reordering the detections themselves.
//
// In each pair, the first value is the index of the detection in
// `detections` from which the score stems, and the second is that
// detection's score(0). Callers pass detections that have already been pruned
// down to a single score; score(0) is read unconditionally, so presumably
// every detection must carry at least one score -- TODO confirm against
// callers.
IndexedScores GetIndexedScores(const Detections& detections) {
  IndexedScores indexed_scores;
  indexed_scores.reserve(detections.size());
  // Convert the size once so the loop condition compares int with int rather
  // than mixing a signed index with the container's unsigned size.
  const int num_detections = static_cast<int>(detections.size());
  for (int index = 0; index < num_detections; ++index) {
    indexed_scores.emplace_back(index, detections.at(index).score(0));
  }
  std::sort(indexed_scores.begin(), indexed_scores.end(), SortBySecond);
  return indexed_scores;
}

} // namespace

// A calculator performing non-maximum suppression on a set of detections.
Expand Down Expand Up @@ -203,7 +219,49 @@ class NonMaxSuppressionCalculator : public CalculatorBase {
}
return absl::OkStatus();
}
auto retained_detections = std::make_unique<Detections>();
if (options_.multiclass_nms()) {
absl::flat_hash_map<int, Detections> category_index_to_detections;
for (const auto& detection : input_detections) {
for (int index : detection.label_id()) {
category_index_to_detections[index].push_back(detection);
}
}
// For each category, do non-maximum suppression separately.
Detections detections_nms;
for (auto& [index, detections] : category_index_to_detections) {
auto retained_detections_per_category = std::make_unique<Detections>();
DoNonMaxSuppression(detections, cc,
retained_detections_per_category.get());
detections_nms.insert(detections_nms.end(),
retained_detections_per_category->begin(),
retained_detections_per_category->end());
}

// Descending sort and shrink the collected detections according to max
// num detections.
IndexedScores indexed_scores = GetIndexedScores(detections_nms);
int max_num_detections = static_cast<int>(indexed_scores.size());
if (options_.max_num_detections() > -1) {
max_num_detections =
std::min(max_num_detections, options_.max_num_detections());
}
retained_detections->reserve(max_num_detections);
for (int i = 0; i < max_num_detections; i++) {
retained_detections->push_back(
detections_nms.at(indexed_scores.at(i).first));
}
} else {
DoNonMaxSuppression(input_detections, cc, retained_detections.get());
}
cc->Outputs().Index(0).Add(retained_detections.release(),
cc->InputTimestamp());
return absl::OkStatus();
}

private:
void DoNonMaxSuppression(Detections& input_detections, CalculatorContext* cc,
Detections* output_detections) {
// Remove all but the maximum scoring label from each input detection. This
// corresponds to non-maximum suppression among detections which have
// identical locations.
Expand All @@ -214,42 +272,24 @@ class NonMaxSuppressionCalculator : public CalculatorBase {
pruned_detections.push_back(detection);
}
}

// Copy all the scores (there is a single score in each detection after
// the above pruning) to an indexed vector for sorting. The first value is
// the index of the detection in the original vector from which the score
// stems, while the second is the actual score.
IndexedScores indexed_scores;
indexed_scores.reserve(pruned_detections.size());
for (int index = 0; index < pruned_detections.size(); ++index) {
indexed_scores.push_back(
std::make_pair(index, pruned_detections[index].score(0)));
}
std::sort(indexed_scores.begin(), indexed_scores.end(), SortBySecond);

IndexedScores indexed_scores = GetIndexedScores(pruned_detections);
const int max_num_detections =
(options_.max_num_detections() > -1)
? options_.max_num_detections()
: static_cast<int>(indexed_scores.size());
// A set of detections and locations, wrapping the location data from each
// detection, which are retained after the non-maximum suppression.
auto* retained_detections = new Detections();
retained_detections->reserve(max_num_detections);
output_detections->reserve(max_num_detections);

if (options_.algorithm() == NonMaxSuppressionCalculatorOptions::WEIGHTED) {
WeightedNonMaxSuppression(indexed_scores, pruned_detections,
max_num_detections, cc, retained_detections);
max_num_detections, cc, output_detections);
} else {
NonMaxSuppression(indexed_scores, pruned_detections, max_num_detections,
cc, retained_detections);
cc, output_detections);
}

cc->Outputs().Index(0).Add(retained_detections, cc->InputTimestamp());

return absl::OkStatus();
}

private:
void NonMaxSuppression(const IndexedScores& indexed_scores,
const Detections& detections, int max_num_detections,
CalculatorContext* cc, Detections* output_detections) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,6 @@ message NonMaxSuppressionCalculatorOptions {
WEIGHTED = 1;
}
optional NmsAlgorithm algorithm = 7 [default = DEFAULT];

optional bool multiclass_nms = 8 [default = false];
}
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,7 @@ void ConfigureNonMaxSuppressionCalculator(
options->set_algorithm(
mediapipe::NonMaxSuppressionCalculatorOptions::DEFAULT);
options->set_max_num_detections(detector_options.max_results());
options->set_multiclass_nms(detector_options.multiclass_nms());
}

// Sets the labels from post PostProcessingSpecs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,8 @@ message DetectorOptions {
// name is in this set will be filtered out. Duplicate or unknown category
// names are ignored. Mutually exclusive with category_allowlist.
repeated string category_denylist = 5;

// Whether to use multiclass NMS, i.e. whether non-maximum suppression is
// run separately for each category.
optional bool multiclass_nms = 7 [default = false];
}
4 changes: 4 additions & 0 deletions mediapipe/tasks/cc/vision/object_detector/object_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ std::unique_ptr<ObjectDetectorOptionsProto> ConvertObjectDetectorOptionsToProto(
for (const std::string& category : options->category_denylist) {
options_proto->add_category_denylist(category);
}
options_proto->set_multiclass_nms(
options->non_max_suppression_options.multiclass_nms);
options_proto->set_min_suppression_threshold(
options->non_max_suppression_options.min_suppression_threshold);
return options_proto;
}

Expand Down
14 changes: 14 additions & 0 deletions mediapipe/tasks/cc/vision/object_detector/object_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ namespace vision {
using ObjectDetectorResult =
::mediapipe::tasks::components::containers::DetectionResult;

// Settings that control non-maximum suppression (NMS).
struct NonMaxSuppressionOptions {
  // When true, multiclass NMS is used: suppression is carried out
  // independently for each category.
  bool multiclass_nms{false};

  // Overlap threshold applied during NMS. Only relevant for models that lack
  // built-in NMS, i.e. models that do not use the Detection_Postprocess
  // TFLite Op.
  float min_suppression_threshold{0.3f};
};

// The options for configuring a mediapipe object detector task.
struct ObjectDetectorOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
Expand Down Expand Up @@ -87,6 +99,8 @@ struct ObjectDetectorOptions {
std::function<void(absl::StatusOr<ObjectDetectorResult>, const Image&,
int64_t)>
result_callback = nullptr;

NonMaxSuppressionOptions non_max_suppression_options;
};

// Performs object detection on single images, video frames, or live stream.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,9 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
task_options.category_allowlist());
detector_options.mutable_category_denylist()->CopyFrom(
task_options.category_denylist());
// TODO: expose min suppression threshold in
// ObjectDetectorOptions.
detector_options.set_min_suppression_threshold(0.3);
detector_options.set_multiclass_nms(task_options.multiclass_nms());
detector_options.set_min_suppression_threshold(
task_options.min_suppression_threshold());
MP_RETURN_IF_ERROR(
components::processors::ConfigureDetectionPostprocessingGraph(
model_resources, detector_options,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,13 @@ message ObjectDetectorOptions {
// category name is in this set will be filtered out. Duplicate or unknown
// category names are ignored. Mutually exclusive with category_allowlist.
repeated string category_denylist = 6;

// Whether to use multiclass NMS, i.e. whether non-maximum suppression is
// run separately for each category.
optional bool multiclass_nms = 7 [default = false];

// Overlap threshold for the non-maximum-suppression calculator. Only used for
// models without built-in non-maximum-suppression, i.e., models that don't
// use the Detection_Postprocess TFLite Op.
optional float min_suppression_threshold = 8 [default = 0.3];
}

0 comments on commit e2f9635

Please sign in to comment.