Skip to content

Commit

Permalink
Fix auto detectors (#1497)
Browse files Browse the repository at this point in the history
* fix yolo predictor

* fix predict
  • Loading branch information
zhreshold authored Nov 4, 2020
1 parent 218007a commit 6c1a584
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 4 deletions.
1 change: 1 addition & 0 deletions gluoncv/auto/estimators/faster_rcnn/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import mxnet as mx
from mxnet import gluon

from ....data.batchify import Tuple
from ....data.transforms.presets.rcnn import load_test, transform_test
from ....data.transforms.presets.rcnn import FasterRCNNDefaultTrainTransform, FasterRCNNDefaultValTransform
from ....model_zoo import get_model
Expand Down
2 changes: 1 addition & 1 deletion gluoncv/auto/estimators/ssd/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class SSD:
@dataclass
class TrainCfg:
# Batch size during training
batch_size : int = 32
batch_size : int = 16
# starting epoch
start_epoch : int = 0
# total epoch for training
Expand Down
1 change: 1 addition & 0 deletions gluoncv/auto/estimators/ssd/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mxnet.contrib import amp

from .... import utils as gutils
from ....data.batchify import Tuple, Stack, Pad
from ....utils.metrics.voc_detection import VOC07MApMetric, VOCMApMetric
from ....model_zoo import get_model
from ....model_zoo import custom_ssd
Expand Down
2 changes: 1 addition & 1 deletion gluoncv/auto/estimators/yolo/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class YOLOv3:
@dataclass
class TrainCfg:
# Training mini-batch size
batch_size : int = 32
batch_size : int = 16
# Training epochs.
epochs : int = 20
# Starting epoch for resuming, default is 0 for new training.
Expand Down
1 change: 1 addition & 0 deletions gluoncv/auto/estimators/yolo/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mxnet.contrib import amp

from .... import utils as gutils
from ....data.batchify import Tuple, Stack, Pad
from ....data.transforms.presets.yolo import load_test, transform_test
from ....model_zoo import get_model
from ....model_zoo import custom_yolov3
Expand Down
12 changes: 11 additions & 1 deletion tests/auto/test_auto_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,38 +21,48 @@
from gluoncv.auto.tasks import ObjectDetection
from autogluon.core.scheduler.resource import get_cpu_count, get_gpu_count

IMAGE_CLASS_DATASET, _, _ = ImageClassification.Dataset.from_folders('https://autogluon.s3.amazonaws.com/datasets/shopee-iet.zip')
IMAGE_CLASS_DATASET, _, IMAGE_CLASS_TEST = ImageClassification.Dataset.from_folders(
'https://autogluon.s3.amazonaws.com/datasets/shopee-iet.zip')
OBJECT_DETCTION_DATASET = ObjectDetection.Dataset.from_voc('https://autogluon.s3.amazonaws.com/datasets/tiny_motorbike.zip')

def test_image_classification_estimator():
from gluoncv.auto.estimators import ImageClassificationEstimator
est = ImageClassificationEstimator({'train': {'epochs': 1, 'batch_size': 8}, 'gpus': list(range(get_gpu_count()))})
res = est.fit(IMAGE_CLASS_DATASET)
assert res.get('valid_acc', 0) > 0
test_result = est.predict(IMAGE_CLASS_TEST)

def test_center_net_estimator():
from gluoncv.auto.estimators import CenterNetEstimator
est = CenterNetEstimator({'train': {'epochs': 1, 'batch_size': 8}, 'gpus': list(range(get_gpu_count()))})
res = est.fit(OBJECT_DETCTION_DATASET)
assert res.get('valid_map', 0) > 0
_, _, test_data = OBJECT_DETCTION_DATASET.random_split()
test_result = est.predict(test_data)

def test_ssd_estimator():
from gluoncv.auto.estimators import SSDEstimator
est = SSDEstimator({'train': {'epochs': 1, 'batch_size': 8}, 'gpus': list(range(get_gpu_count()))})
res = est.fit(OBJECT_DETCTION_DATASET)
assert res.get('valid_map', 0) > 0
_, _, test_data = OBJECT_DETCTION_DATASET.random_split()
test_result = est.predict(test_data)

def test_yolo3_estimator():
from gluoncv.auto.estimators import YOLOv3Estimator
est = YOLOv3Estimator({'train': {'epochs': 1, 'batch_size': 8}, 'gpus': list(range(get_gpu_count()))})
res = est.fit(OBJECT_DETCTION_DATASET)
assert res.get('valid_map', 0) > 0
_, _, test_data = OBJECT_DETCTION_DATASET.random_split()
test_result = est.predict(test_data)

def test_frcnn_estimator():
from gluoncv.auto.estimators import FasterRCNNEstimator
est = FasterRCNNEstimator({'train': {'epochs': 1}, 'gpus': list(range(get_gpu_count()))})
res = est.fit(OBJECT_DETCTION_DATASET)
assert res.get('valid_map', 0) > 0
_, _, test_data = OBJECT_DETCTION_DATASET.random_split()
test_result = est.predict(test_data)

if __name__ == '__main__':
import nose
Expand Down
6 changes: 5 additions & 1 deletion tests/auto/test_auto_tasks.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from gluoncv.auto.tasks import ImageClassification
from gluoncv.auto.tasks import ObjectDetection

IMAGE_CLASS_DATASET, _, _ = ImageClassification.Dataset.from_folders('https://autogluon.s3.amazonaws.com/datasets/shopee-iet.zip')
IMAGE_CLASS_DATASET, _, IMAGE_CLASS_TEST = ImageClassification.Dataset.from_folders(
'https://autogluon.s3.amazonaws.com/datasets/shopee-iet.zip')
OBJECT_DETCTION_DATASET = ObjectDetection.Dataset.from_voc('https://autogluon.s3.amazonaws.com/datasets/tiny_motorbike.zip')

def test_image_classification():
from gluoncv.auto.tasks import ImageClassification
task = ImageClassification({'num_trials': 1})
classifier = task.fit(IMAGE_CLASS_DATASET)
assert task.fit_summary.get('valid_acc', 0) > 0
test_result = classifier.predict(IMAGE_CLASS_TEST)

def test_center_net_estimator():
from gluoncv.auto.tasks import ObjectDetection
task = ObjectDetection({'num_trials': 1})
detector = task.fit(OBJECT_DETCTION_DATASET)
assert task.fit_summary.get('valid_map', 0) > 0
_, _, test_data = OBJECT_DETCTION_DATASET.random_split()
test_result = detector.predict(test_data)

0 comments on commit 6c1a584

Please sign in to comment.