From 690c006f9158949041f8ca9e8a2772162736498b Mon Sep 17 00:00:00 2001
From: Neeraj Deshpande
Date: Thu, 22 Jun 2023 11:29:19 +0530
Subject: [PATCH 1/2] feat: initial nanodet model

---
 sahi/models/nanodet.py | 206 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 206 insertions(+)
 create mode 100644 sahi/models/nanodet.py

diff --git a/sahi/models/nanodet.py b/sahi/models/nanodet.py
new file mode 100644
index 000000000..8e76d2e49
--- /dev/null
+++ b/sahi/models/nanodet.py
@@ -0,0 +1,206 @@
+# OBSS SAHI Tool
+# Code written by AnNT, 2023.
+
+import contextlib
+import logging
+import os
+from typing import Any, List, Optional
+
+import numpy as np
+import torch
+
+from nanodet.data.batch_process import stack_batch_img
+from nanodet.data.collate import naive_collate
+from nanodet.data.transform import Pipeline
+from nanodet.model.arch import build_model
+from nanodet.util import cfg, load_config, load_model_weight
+
+from sahi.models.base import DetectionModel
+from sahi.prediction import ObjectPrediction
+from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
+from sahi.utils.import_utils import check_requirements
+
+logger = logging.getLogger(__name__)
+
+
+class NanodetDetectionModel(DetectionModel):
+    """A class for performing object detection using the Nanodet model."""
+
+    def check_dependencies(self) -> None:
+        """Checks the system for the following dependencies: ["nanodet", "torch", "torchvision"].
+
+        Raises:
+            AssertionError: If any of the required dependencies is not installed.
+        """
+        check_requirements(["nanodet", "torch", "torchvision"])
+
+    def load_model(self):
+        """Loads the detection model from configuration and weights.
+
+        Raises:
+            IOError: If the model weights file is not found or unreadable.
+        """
+        load_config(cfg, self.config_path)
+        self.pipeline = Pipeline(cfg.data.val.pipeline, cfg.data.val.keep_ratio)
+
+        self.cfg = cfg
+        # create model
+        model = build_model(self.cfg.model)
+        ckpt = torch.load(self.model_path, map_location=lambda storage, loc: storage)
+        load_model_weight(model, ckpt, logger)
+        self.model = model.eval()
+        self.model.cuda()
+        # set category_mapping
+        if not self.category_mapping:
+            category_mapping = {str(ind): category_name for ind, category_name in enumerate(self.category_names)}
+            self.category_mapping = category_mapping
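For reference, the category_mapping built at the end of load_model (and again in set_model below) is a plain index-to-name dict keyed by stringified category ids. A minimal illustration, with a shortened hypothetical class list:

    category_names = ["person", "bicycle", "car"]
    category_mapping = {str(ind): name for ind, name in enumerate(category_names)}
    assert category_mapping == {"0": "person", "1": "bicycle", "2": "car"}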
+
+    def set_model(self, model: Any, **kwargs):
+        """Sets the Nanodet model to self.model and prepares it for inference.
+
+        Args:
+            model (Any): A Nanodet model
+
+        Raises:
+            TypeError: If the model provided is not a Nanodet model.
+        """
+        self.model = model
+        self.model.eval()
+        self.model.cuda()
+        # set category_mapping
+        if not self.category_mapping:
+            category_mapping = {str(ind): category_name for ind, category_name in enumerate(self.category_names)}
+            self.category_mapping = category_mapping
+
+    def perform_inference(self, image: np.ndarray, image_size: Optional[int] = None):
+        """Performs prediction using self.model and sets the result to self._original_predictions.
+
+        Args:
+            image (np.ndarray): A numpy array that contains the image to be predicted.
+                3 channel image should be in RGB order.
+            image_size (int, optional): Inference input size.
+
+        Raises:
+            AssertionError: If the model is not loaded.
+        """
+        assert self.model is not None, "Model is not loaded, load it by calling .load_model()"
+
+        img_info = {"id": 0}
+        img_info["file_name"] = None
+        height, width = image.shape[:2]
+        img_info["height"] = height
+        img_info["width"] = width
+        meta = dict(img_info=img_info, raw_img=image, img=image)
+        meta = self.pipeline(None, meta, self.cfg.data.val.input_size)
+        meta["img"] = torch.from_numpy(meta["img"].transpose(2, 0, 1)).to(self.device)
+        meta = naive_collate([meta])
+        meta["img"] = stack_batch_img(meta["img"], divisible=32)
+
+        # Muting nanodet logs to avoid clutter
+        with torch.no_grad():
+            with open(os.devnull, "w") as dev_null, contextlib.redirect_stdout(dev_null):
+                results = self.model.inference(meta)
+        # compatibility with sahi v0.8.15
+        if not isinstance(image, list):
+            image = [image]
+        self._original_predictions = results
+
+    @property
+    def category_names(self):
+        """Returns category names in the configuration."""
+        if isinstance(self.cfg.class_names, str):
+            return (self.cfg.class_names,)
+        return self.cfg.class_names
+
+    @property
+    def num_categories(self):
+        """Returns the number of categories in the configuration."""
+        if isinstance(self.cfg.class_names, str):
+            num_categories = 1
+        else:
+            num_categories = len(self.cfg.class_names)
+        return num_categories
+
+    @property
+    def has_mask(self):
+        """Returns False as Nanodet does not support segmentation models as of now."""
+        return False  # fix when Nanodet supports segmentation models
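Judging from how process_prediction below (and the tests later in this series) index into the results, nanodet's inference output is a dict keyed by image index, mapping each category id to rows of [x0, y0, x1, y1, score]. A hedged illustration, with invented values:

    # Hypothetical shape of self._original_predictions (values invented):
    original_predictions = {
        0: {  # image index
            0: [],  # category 0 ("person"): no detections
            2: [[445.0, 309.0, 493.0, 342.0, 0.87]],  # category 2 ("car"): one box + score
        },
    }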
+
+    def process_prediction(self, category_box, category_id, shift_amount, full_shape):
+        """Processes a single category prediction.
+
+        Args:
+            category_box: The bounding box of the category prediction.
+            category_id: The category ID.
+            shift_amount: The shift amount.
+            full_shape: The full shape of the prediction.
+
+        Returns:
+            ObjectPrediction or None: The processed object prediction if valid, otherwise None.
+        """
+        bbox, score = category_box[:4], category_box[4]
+        category_name = self.category_mapping[str(category_id)]
+
+        if score < self.confidence_threshold:
+            return None
+
+        bool_mask = None
+        bbox = [max(0, x) for x in bbox]
+
+        if full_shape is not None:
+            bbox = [min(dim, x) for dim, x in zip(full_shape[::-1] * 2, bbox)]
+
+        if (not (bbox[0] < bbox[2])) or (not (bbox[1] < bbox[3])):
+            logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
+            return None
+
+        return ObjectPrediction(
+            bbox=bbox,
+            category_id=category_id,
+            score=score,
+            bool_mask=bool_mask,
+            category_name=category_name,
+            shift_amount=shift_amount,
+            full_shape=full_shape,
+        )
+
+    def _create_object_prediction_list_from_original_predictions(
+        self,
+        shift_amount_list: Optional[List[List[int]]] = None,
+        full_shape_list: Optional[List[List[int]]] = None,
+    ):
+        """Creates a list of ObjectPrediction from the original predictions.
+
+        Args:
+            shift_amount_list (List[List[int]], optional): The shift amount list.
+            full_shape_list (List[List[int]], optional): The full shape list.
+
+        Returns:
+            List[List[ObjectPrediction]]: The list of ObjectPrediction per image.
+        """
+        if shift_amount_list is None:
+            shift_amount_list = [[0, 0]]
+        original_predictions = self._original_predictions
+        shift_amount_list = fix_shift_amount_list(shift_amount_list)
+        full_shape_list = fix_full_shape_list(full_shape_list)
+        num_categories = self.num_categories
+
+        object_prediction_list_per_image = [
+            [
+                pred
+                for category_id in range(num_categories)
+                for pred in (
+                    self.process_prediction(
+                        category_box,
+                        category_id,
+                        shift_amount_list[image_ind],
+                        full_shape_list[image_ind] if full_shape_list else None,
+                    )
+                    for category_box in original_predictions[image_ind][category_id]
+                )
+                if pred is not None
+            ]
+            for image_ind in original_predictions.keys()
+        ]
+
+        self._object_prediction_list_per_image = object_prediction_list_per_image
From 6169a4c28f0dadac450dca5e938a74ad6a83c145 Mon Sep 17 00:00:00 2001
From: Neeraj Deshpande
Date: Thu, 22 Jun 2023 15:01:29 +0530
Subject: [PATCH 2/2] feat: tests and bugfixes

---
 .gitignore                                |   1 +
 sahi/models/nanodet.py                    | 124 ++++++------
 sahi/utils/nanodet.py                     |  22 ++
 .../models/nanodet/nanodet-plus-m_416.yml | 201 +++++++++++++++++
 tests/test_nanodetmodel.py                | 203 ++++++++++++++++++
 5 files changed, 486 insertions(+), 65 deletions(-)
 create mode 100644 sahi/utils/nanodet.py
 create mode 100644 tests/data/models/nanodet/nanodet-plus-m_416.yml
 create mode 100644 tests/test_nanodetmodel.py

diff --git a/.gitignore b/.gitignore
index e9510eefa..25b6bd7c7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,7 @@
 *.pkl
 *.pth
 *.pt
+*.ckpt
 weights*
 .vscode
 .idea
diff --git a/sahi/models/nanodet.py b/sahi/models/nanodet.py
index 8e76d2e49..2fcdb6e9c 100644
--- a/sahi/models/nanodet.py
+++ b/sahi/models/nanodet.py
@@ -49,7 +49,7 @@ def load_model(self):
         ckpt = torch.load(self.model_path, map_location=lambda storage, loc: storage)
         load_model_weight(model, ckpt, logger)
         self.model = model.eval()
-        self.model.cuda()
+        self.model.to(self.device)
         # set category_mapping
         if not self.category_mapping:
             category_mapping = {str(ind): category_name for ind, category_name in enumerate(self.category_names)}
@@ -66,7 +66,7 @@ def set_model(self, model: Any, **kwargs):
         """
         self.model = model
         self.model.eval()
-        self.model.cuda()
+        self.model.to(self.device)
         # set category_mapping
         if not self.category_mapping:
             category_mapping = {str(ind): category_name for ind, category_name in enumerate(self.category_names)}
@@ -126,81 +126,75 @@ def has_mask(self):
         """Returns False as Nanodet does not support segmentation models as of now."""
         return False  # fix when Nanodet supports segmentation models
 
-    def process_prediction(self, category_box, category_id, shift_amount, full_shape):
-        """Processes a single category prediction.
-
-        Args:
-            category_box: The bounding box of the category prediction.
-            category_id: The category ID.
-            shift_amount: The shift amount.
-            full_shape: The full shape of the prediction.
-
-        Returns:
-            ObjectPrediction or None: The processed object prediction if valid, otherwise None.
-        """
-        bbox, score = category_box[:4], category_box[4]
-        category_name = self.category_mapping[str(category_id)]
-
-        if score < self.confidence_threshold:
-            return None
-
-        bool_mask = None
-        bbox = [max(0, x) for x in bbox]
-
-        if full_shape is not None:
-            bbox = [min(dim, x) for dim, x in zip(full_shape[::-1] * 2, bbox)]
-
-        if (not (bbox[0] < bbox[2])) or (not (bbox[1] < bbox[3])):
-            logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
-            return None
-
-        return ObjectPrediction(
-            bbox=bbox,
-            category_id=category_id,
-            score=score,
-            bool_mask=bool_mask,
-            category_name=category_name,
-            shift_amount=shift_amount,
-            full_shape=full_shape,
-        )
-
     def _create_object_prediction_list_from_original_predictions(
         self,
         shift_amount_list: Optional[List[List[int]]] = None,
         full_shape_list: Optional[List[List[int]]] = None,
     ):
-        """Creates a list of ObjectPrediction from the original predictions.
-
+        """
+        self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
+        self._object_prediction_list_per_image.
         Args:
-            shift_amount_list (List[List[int]], optional): The shift amount list.
-            full_shape_list (List[List[int]], optional): The full shape list.
-
-        Returns:
-            List[List[ObjectPrediction]]: The list of ObjectPrediction per image.
+            shift_amount_list: list of list
+                To shift the box predictions from sliced image to full sized image, should
+                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
+            full_shape_list: list of list
+                Size of the full image after shifting, should be in the form of
+                List[[height, width],[height, width],...]
         """
         if shift_amount_list is None:
-            shift_amount_list = [[0, 0]]
+            shift_amount_list = [[0, 0]] * len(self._original_predictions)
         original_predictions = self._original_predictions
+        category_mapping = self.category_mapping
+
+        # compatibility for sahi v0.8.15
         shift_amount_list = fix_shift_amount_list(shift_amount_list)
         full_shape_list = fix_full_shape_list(full_shape_list)
+
+        # parse boxes from predictions
         num_categories = self.num_categories
+        object_prediction_list_per_image = []
 
-        object_prediction_list_per_image = [
-            [
-                pred
-                for category_id in range(num_categories)
-                for pred in (
-                    self.process_prediction(
-                        category_box,
-                        category_id,
-                        shift_amount_list[image_ind],
-                        full_shape_list[image_ind] if full_shape_list else None,
-                    )
-                    for category_box in original_predictions[image_ind][category_id]
-                )
-                if pred is not None
-            ]
-            for image_ind in original_predictions.keys()
-        ]
+        for image_ind, original_prediction in original_predictions.items():
+            shift_amount = shift_amount_list[image_ind]
+            full_shape = None if full_shape_list is None else full_shape_list[image_ind]
+
+            object_prediction_list = []
+
+            # process predictions
+            for category_id in range(num_categories):
+                category_boxes = original_prediction[category_id]
+
+                for *bbox, score in category_boxes:
+                    # ignore low scored predictions
+                    if score < self.confidence_threshold:
+                        continue
+
+                    category_name = category_mapping[str(category_id)]
+
+                    bool_mask = None
+                    # fix negative box coords
+                    bbox = [max(0, coord) for coord in bbox]
+
+                    # fix out of image box coords: x against width, y against height
+                    if full_shape is not None:
+                        bbox = [min(full_shape[(i + 1) % 2], bbox[i]) for i in range(4)]
+
+                    # ignore invalid predictions
+                    if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]):
+                        logger.warning(f"Ignoring invalid prediction with bbox: {bbox}")
+                        continue
+
+                    object_prediction = ObjectPrediction(
+                        bbox=bbox,
+                        category_id=category_id,
+                        score=score,
+                        bool_mask=bool_mask,
+                        category_name=category_name,
+                        shift_amount=shift_amount,
+                        full_shape=full_shape,
+                    )
+                    object_prediction_list.append(object_prediction)
+            object_prediction_list_per_image.append(object_prediction_list)
 
         self._object_prediction_list_per_image = object_prediction_list_per_image
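A quick sanity check of the clipping index pattern used above, illustrative only and not part of the patch: x coordinates clip against the image width (full_shape[1]) and y coordinates against the image height (full_shape[0]).

    full_shape = [1080, 1920]  # [height, width]
    bbox = [2000.0, 500.0, 2100.0, 1200.0]  # [x0, y0, x1, y1]
    clipped = [min(full_shape[(i + 1) % 2], bbox[i]) for i in range(4)]
    assert clipped == [1920, 500.0, 1920, 1080]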
diff --git a/sahi/utils/nanodet.py b/sahi/utils/nanodet.py
new file mode 100644
index 000000000..135a7669c
--- /dev/null
+++ b/sahi/utils/nanodet.py
@@ -0,0 +1,22 @@
+import logging
+from pathlib import Path
+
+import requests
+
+
+class NanodetConstants:
+    NANODET_PLUS_CONFIG = Path("tests/data/models/nanodet/nanodet-plus-m_416.yml").resolve().as_posix()
+
+    NANODET_PLUS_MODEL = Path("tests/data/models/nanodet/model.ckpt").resolve().as_posix()
+
+    NANODET_PLUS_URL = (
+        "https://github.com/RangiLyu/nanodet/releases/download/v1.0.0-alpha-1/nanodet-plus-m_416_checkpoint.ckpt"
+    )
+
+    def __init__(self) -> None:
+        if not Path(self.NANODET_PLUS_MODEL).exists():
+            logging.info("Downloading Nanodet model.")
+            response = requests.get(self.NANODET_PLUS_URL, allow_redirects=True, timeout=10)
+            with open(self.NANODET_PLUS_MODEL, "wb") as model_file:
+                model_file.write(response.content)
+            logging.info("Downloaded Nanodet model.")
diff --git a/tests/data/models/nanodet/nanodet-plus-m_416.yml b/tests/data/models/nanodet/nanodet-plus-m_416.yml
new file mode 100644
index 000000000..50bc64ac3
--- /dev/null
+++ b/tests/data/models/nanodet/nanodet-plus-m_416.yml
@@ -0,0 +1,201 @@
+# nanodet-plus-m_416
+# COCO mAP(0.5:0.95) = 0.304
+# AP_50 = 0.459
+# AP_75 = 0.317
+# AP_small = 0.106
+# AP_m = 0.322
+# AP_l = 0.477
+save_dir: tests/data/models/nanodet/nanodet-plus-m_416
+model:
+  weight_averager:
+    name: ExpMovingAverager
+    decay: 0.9998
+  arch:
+    name: NanoDetPlus
+    detach_epoch: 10
+    backbone:
+      name: ShuffleNetV2
+      model_size: 1.0x
+      out_stages: [2, 3, 4]
+      activation: LeakyReLU
+    fpn:
+      name: GhostPAN
+      in_channels: [116, 232, 464]
+      out_channels: 96
+      kernel_size: 5
+      num_extra_level: 1
+      use_depthwise: True
+      activation: LeakyReLU
+    head:
+      name: NanoDetPlusHead
+      num_classes: 80
+      input_channel: 96
+      feat_channels: 96
+      stacked_convs: 2
+      kernel_size: 5
+      strides: [8, 16, 32, 64]
+      activation: LeakyReLU
+      reg_max: 7
+      norm_cfg:
+        type: BN
+      loss:
+        loss_qfl:
+          name: QualityFocalLoss
+          use_sigmoid: True
+          beta: 2.0
+          loss_weight: 1.0
+        loss_dfl:
+          name: DistributionFocalLoss
+          loss_weight: 0.25
+        loss_bbox:
+          name: GIoULoss
+          loss_weight: 2.0
+    # Auxiliary head, only used at training time.
+    aux_head:
+      name: SimpleConvHead
+      num_classes: 80
+      input_channel: 192
+      feat_channels: 192
+      stacked_convs: 4
+      strides: [8, 16, 32, 64]
+      activation: LeakyReLU
+      reg_max: 7
+data:
+  train:
+    name: CocoDataset
+    img_path: coco/train2017
+    ann_path: coco/annotations/instances_train2017.json
+    input_size: [416, 416] #[w,h]
+    keep_ratio: False
+    pipeline:
+      perspective: 0.0
+      scale: [0.6, 1.4]
+      stretch: [[0.8, 1.2], [0.8, 1.2]]
+      rotation: 0
+      shear: 0
+      translate: 0.2
+      flip: 0.5
+      brightness: 0.2
+      contrast: [0.6, 1.4]
+      saturation: [0.5, 1.2]
+      normalize: [[103.53, 116.28, 123.675], [57.375, 57.12, 58.395]]
+  val:
+    name: CocoDataset
+    img_path: coco/val2017
+    ann_path: coco/annotations/instances_val2017.json
+    input_size: [416, 416] #[w,h]
+    keep_ratio: False
+    pipeline:
+      normalize: [[103.53, 116.28, 123.675], [57.375, 57.12, 58.395]]
+device:
+  gpu_ids: [0]
+  workers_per_gpu: 10
+  batchsize_per_gpu: 96
+  precision: 32 # set to 16 to use AMP training
+schedule:
+  # resume:
+  # load_model:
+  optimizer:
+    name: AdamW
+    lr: 0.001
+    weight_decay: 0.05
+  warmup:
+    name: linear
+    steps: 500
+    ratio: 0.0001
+  total_epochs: 300
+  lr_schedule:
+    name: CosineAnnealingLR
+    T_max: 300
+    eta_min: 0.00005
+  val_intervals: 10
+grad_clip: 35
+evaluator:
+  name: CocoDetectionEvaluator
+  save_key: mAP
+log:
+  interval: 50
+
+class_names:
+  [
+    "person",
+    "bicycle",
+    "car",
+    "motorcycle",
+    "airplane",
+    "bus",
+    "train",
+    "truck",
+    "boat",
+    "traffic_light",
+    "fire_hydrant",
+    "stop_sign",
+    "parking_meter",
+    "bench",
+    "bird",
+    "cat",
+    "dog",
+    "horse",
+    "sheep",
+    "cow",
+    "elephant",
+    "bear",
+    "zebra",
+    "giraffe",
+    "backpack",
+    "umbrella",
+    "handbag",
+    "tie",
+    "suitcase",
+    "frisbee",
+    "skis",
+    "snowboard",
+    "sports_ball",
+    "kite",
+    "baseball_bat",
+    "baseball_glove",
+    "skateboard",
+    "surfboard",
+    "tennis_racket",
+    "bottle",
+    "wine_glass",
+    "cup",
+    "fork",
+    "knife",
+    "spoon",
+    "bowl",
+    "banana",
+    "apple",
+    "sandwich",
+    "orange",
+    "broccoli",
+    "carrot",
+    "hot_dog",
+    "pizza",
+    "donut",
+    "cake",
+    "chair",
+    "couch",
+    "potted_plant",
+    "bed",
+    "dining_table",
+    "toilet",
+    "tv",
+    "laptop",
+    "mouse",
+    "remote",
+    "keyboard",
+    "cell_phone",
+    "microwave",
+    "oven",
+    "toaster",
+    "sink",
+    "refrigerator",
+    "book",
+    "clock",
+    "vase",
+    "scissors",
+    "teddy_bear",
+    "hair_drier",
+    "toothbrush",
+  ]
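Everything the wrapper reads from this config flows through nanodet's load_config. A hedged check of the fields load_model and the properties above rely on (assuming nanodet's global cfg object behaves as used in load_model):

    from nanodet.util import cfg, load_config

    load_config(cfg, "tests/data/models/nanodet/nanodet-plus-m_416.yml")
    assert len(cfg.class_names) == 80  # drives num_categories and category_names
    assert cfg.data.val.input_size == [416, 416]  # consumed by the preprocessing Pipeline
    assert cfg.data.val.keep_ratio is False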
diff --git a/tests/test_nanodetmodel.py b/tests/test_nanodetmodel.py
new file mode 100644
index 000000000..fa9d2f925
--- /dev/null
+++ b/tests/test_nanodetmodel.py
@@ -0,0 +1,203 @@
+# OBSS SAHI Tool
+# Code written by Fatih C Akyon, 2020.
+
+import unittest
+
+from sahi.models.nanodet import NanodetDetectionModel
+from sahi.utils.cv import read_image
+from sahi.utils.nanodet import NanodetConstants
+
+MODEL_DEVICE = "cpu"
+CONFIDENCE_THRESHOLD = 0.5
+IMAGE_SIZE = 320
+CAR_INDEX = 2
+nanodet_constants = NanodetConstants()
+
+
+class TestNanodetDetectionModel(unittest.TestCase):
+    def test_load_model(self):
+        nanodet_detection_model = NanodetDetectionModel(
+            model_path=nanodet_constants.NANODET_PLUS_MODEL,
+            config_path=nanodet_constants.NANODET_PLUS_CONFIG,
+            device=MODEL_DEVICE,
+            confidence_threshold=CONFIDENCE_THRESHOLD,
+            load_at_init=True,
+        )
+        self.assertNotEqual(nanodet_detection_model.model, None)
+
+    def test_perform_inference(self):
+        nanodet_detection_model = NanodetDetectionModel(
+            model_path=nanodet_constants.NANODET_PLUS_MODEL,
+            config_path=nanodet_constants.NANODET_PLUS_CONFIG,
+            device=MODEL_DEVICE,
+            confidence_threshold=CONFIDENCE_THRESHOLD,
+            load_at_init=True,
+        )
+        # prepare image
+        image_path = "tests/data/small-vehicles1.jpeg"
+        image = read_image(image_path)
+        # perform inference
+        nanodet_detection_model.perform_inference(image)
+        original_predictions = nanodet_detection_model.original_predictions[0]
+
+        # find box of first car detection with conf greater than 0.5
+        box = None
+        for detection in original_predictions[CAR_INDEX]:
+            if detection[-1] > CONFIDENCE_THRESHOLD:
+                box = detection[:4]
+                break
+        # compare
+        self.assertIsNotNone(box)
+        self.assertEqual(list(map(int, box)), [445, 309, 493, 342])
+        self.assertEqual(len(original_predictions), 80)
+
+    def test_convert_original_predictions_without_mask_output(self):
+        nanodet_detection_model = NanodetDetectionModel(
+            model_path=nanodet_constants.NANODET_PLUS_MODEL,
+            config_path=nanodet_constants.NANODET_PLUS_CONFIG,
+            device=MODEL_DEVICE,
+            confidence_threshold=CONFIDENCE_THRESHOLD,
+            load_at_init=True,
+        )
+
+        # prepare image
+        image_path = "tests/data/small-vehicles1.jpeg"
+        image = read_image(image_path)
+
+        # perform inference
+        nanodet_detection_model.perform_inference(image)
+
+        # convert predictions to ObjectPrediction list
+        nanodet_detection_model.convert_original_predictions()
+        object_prediction_list = nanodet_detection_model.object_prediction_list
+
+        # compare
+        self.assertEqual(len(object_prediction_list), 3)
+        self.assertEqual(object_prediction_list[0].category.id, 2)
+        self.assertEqual(object_prediction_list[0].category.name, "car")
+        predicted_bbox = object_prediction_list[0].bbox.to_xywh()
+        desired_bbox = [445, 309, 47, 33]
+        margin = 3
+        for ind, point in enumerate(predicted_bbox):
+            if not (point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin):
+                raise AssertionError(f"desired_bbox: {desired_bbox}, predicted_bbox: {predicted_bbox}")
+
+        self.assertEqual(object_prediction_list[2].category.id, 2)
+        self.assertEqual(object_prediction_list[2].category.name, "car")
+        predicted_bbox = object_prediction_list[2].bbox.to_xywh()
+        desired_bbox = [377, 281, 41, 24]
+        margin = 3
+        for ind, point in enumerate(predicted_bbox):
+            if not (point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin):
+                raise AssertionError(f"desired_bbox: {desired_bbox}, predicted_bbox: {predicted_bbox}")
+
+    def test_get_prediction_nanodet(self):
+        from sahi.predict import get_prediction
+
+        # init model
+        nanodet_detection_model = NanodetDetectionModel(
+            model_path=nanodet_constants.NANODET_PLUS_MODEL,
+            config_path=nanodet_constants.NANODET_PLUS_CONFIG,
+            device=MODEL_DEVICE,
+            confidence_threshold=CONFIDENCE_THRESHOLD,
+            load_at_init=True,
+        )
+
+        # prepare image
+        image_path = "tests/data/small-vehicles1.jpeg"
+        image = read_image(image_path)
+
+        # get full sized prediction
+        prediction_result = get_prediction(
+            image=image,
+            detection_model=nanodet_detection_model,
+            shift_amount=[0, 0],
+            full_shape=None,
+            postprocess=None,
+        )
+
+        object_prediction_list = prediction_result.object_prediction_list
+
+        # compare
+        self.assertEqual(len(object_prediction_list), 3)
+        num_person = 0
+        for object_prediction in object_prediction_list:
+            if object_prediction.category.name == "person":
+                num_person += 1
+        self.assertEqual(num_person, 0)
+        num_truck = 0
+        for object_prediction in object_prediction_list:
+            if object_prediction.category.name == "truck":
+                num_truck += 1
+        self.assertEqual(num_truck, 0)
+        num_car = 0
+        for object_prediction in object_prediction_list:
+            if object_prediction.category.name == "car":
+                num_car += 1
+        self.assertEqual(num_car, 3)
+
+    def test_get_sliced_prediction_nanodet(self):
+        from sahi.predict import get_sliced_prediction
+
+        # init model
+        nanodet_detection_model = NanodetDetectionModel(
+            model_path=nanodet_constants.NANODET_PLUS_MODEL,
+            config_path=nanodet_constants.NANODET_PLUS_CONFIG,
+            device=MODEL_DEVICE,
+            confidence_threshold=CONFIDENCE_THRESHOLD,
+            load_at_init=True,
+        )
+
+        # prepare image
+        image_path = "tests/data/small-vehicles1.jpeg"
+
+        slice_height = 416
+        slice_width = 416
+        overlap_height_ratio = 0.1
+        overlap_width_ratio = 0.2
+        postprocess_type = "GREEDYNMM"
+        match_metric = "IOS"
+        match_threshold = 0.5
+        class_agnostic = True
+
+        # get sliced prediction
+        prediction_result = get_sliced_prediction(
+            image=image_path,
+            detection_model=nanodet_detection_model,
+            slice_height=slice_height,
+            slice_width=slice_width,
+            overlap_height_ratio=overlap_height_ratio,
+            overlap_width_ratio=overlap_width_ratio,
+            perform_standard_pred=False,
+            postprocess_type=postprocess_type,
+            postprocess_match_threshold=match_threshold,
+            postprocess_match_metric=match_metric,
+            postprocess_class_agnostic=class_agnostic,
+        )
+        object_prediction_list = prediction_result.object_prediction_list
+
+        # compare
+        self.assertEqual(len(object_prediction_list), 11)
+        num_person = 0
+        for object_prediction in object_prediction_list:
+            if object_prediction.category.name == "person":
+                num_person += 1
+        self.assertEqual(num_person, 0)
+        num_truck = 0
+        for object_prediction in object_prediction_list:
+            if object_prediction.category.name == "truck":
+                num_truck += 1
+        self.assertEqual(num_truck, 0)
+        num_car = 0
+        for object_prediction in object_prediction_list:
+            if object_prediction.category.name == "car":
+                num_car += 1
+        self.assertEqual(num_car, 11)
+
+
+if __name__ == "__main__":
+    unittest.main()
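Taken together, the series can be exercised end to end. A minimal usage sketch, mirroring the paths and thresholds used in the tests (the export directory is a placeholder, and export_visuals is assumed to behave as in other SAHI examples):

    from sahi.models.nanodet import NanodetDetectionModel
    from sahi.predict import get_sliced_prediction
    from sahi.utils.nanodet import NanodetConstants

    constants = NanodetConstants()  # downloads the checkpoint on first use
    detection_model = NanodetDetectionModel(
        model_path=constants.NANODET_PLUS_MODEL,
        config_path=constants.NANODET_PLUS_CONFIG,
        device="cpu",
        confidence_threshold=0.5,
        load_at_init=True,
    )
    result = get_sliced_prediction(
        image="tests/data/small-vehicles1.jpeg",
        detection_model=detection_model,
        slice_height=416,
        slice_width=416,
        overlap_height_ratio=0.1,
        overlap_width_ratio=0.2,
    )
    result.export_visuals(export_dir="demo_out/")  # hypothetical output directory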