Skip to content

Commit

Permalink
Merge pull request #294 from Manojkumarmuru/efficientdet
Browse files Browse the repository at this point in the history
Efficientdet
  • Loading branch information
oarriaga authored Jul 4, 2023
2 parents 0d0db2f + 5ab3294 commit fa34803
Show file tree
Hide file tree
Showing 9 changed files with 22 additions and 229 deletions.
40 changes: 0 additions & 40 deletions examples/object_detection/boxes.py

This file was deleted.

3 changes: 1 addition & 2 deletions examples/object_detection/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from paz.datasets import VOC
from paz.abstract import Processor, SequentialProcessor
from paz import processors as pr
from detection import AugmentDetection
# from paz.pipelines import AugmentDetection
from paz.pipelines import AugmentDetection


class ShowBoxes(Processor):
Expand Down
113 changes: 0 additions & 113 deletions examples/object_detection/detection.py

This file was deleted.

22 changes: 0 additions & 22 deletions examples/object_detection/processors.py

This file was deleted.

43 changes: 0 additions & 43 deletions examples/object_detection/test.py

This file was deleted.

17 changes: 14 additions & 3 deletions examples/object_detection/train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import copy
import argparse
import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices('GPU')
Expand All @@ -11,7 +12,7 @@
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint
from paz.optimization.callbacks import LearningRateScheduler
from detection import AugmentDetection
from paz.pipelines import AugmentDetection
from paz.models import SSD300
from paz.datasets import VOC
from paz.optimization import MultiBoxLoss
Expand Down Expand Up @@ -97,19 +98,29 @@
checkpoint = ModelCheckpoint(save_path, verbose=1, save_weights_only=True)
schedule = LearningRateScheduler(
args.learning_rate, args.gamma_decay, args.scheduled_epochs)
evaluate = EvaluateMAP(
evaluate_test = EvaluateMAP(
evaluation_data_managers[0],
DetectSingleShot(model, data_managers[0].class_names, 0.01, 0.45),
args.evaluation_period,
args.save_path,
args.AP_IOU)

data_manager = VOC(args.data_path, ['trainval', 'trainval'],
name=['VOC2007', 'VOC2012'], evaluate=True)
class_names = copy.deepcopy(data_manager.class_names)
evaluate_trainval = EvaluateMAP(
data_manager,
DetectSingleShot(model, class_names, 0.01, 0.45),
args.evaluation_period,
args.save_path,
args.AP_IOU)

# training
model.fit(
sequencers[0],
epochs=args.num_epochs,
verbose=1,
callbacks=[checkpoint, log, schedule, evaluate],
callbacks=[checkpoint, log, schedule, evaluate_trainval, evaluate_test],
validation_data=sequencers[1],
use_multiprocessing=args.multiprocessing,
workers=args.workers)
9 changes: 5 additions & 4 deletions paz/models/detection/efficientdet/efficientdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@ def EFFICIENTDET(image, num_classes, base_weights, head_weights,
model_filename = '-'.join([model_name, str(base_weights),
str(head_weights) + '_weights.hdf5'])

weights_path = get_file(model_filename, WEIGHT_PATH + model_filename,
cache_subdir='paz/models')
print('Loading %s model weights' % weights_path)
model.load_weights(weights_path)
if not ((base_weights is None) and (head_weights is None)):
weights_path = get_file(model_filename, WEIGHT_PATH + model_filename,
cache_subdir='paz/models')
print('Loading %s model weights' % weights_path)
model.load_weights(weights_path)

image_shape = image.shape[1:3].as_list()
model.prior_boxes = build_anchors(
Expand Down
2 changes: 1 addition & 1 deletion paz/pipelines/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, mean=pr.BGR_IMAGENET_MEAN):
super(AugmentBoxes, self).__init__()
self.add(pr.ToImageBoxCoordinates())
self.add(pr.Expand(mean=mean))
self.add(pr.RandomSampleCrop())
self.add(pr.RandomSampleCrop(1.0))
self.add(pr.RandomFlipBoxesLeftRight())
self.add(pr.ToNormalizedBoxCoordinates())

Expand Down
2 changes: 1 addition & 1 deletion paz/pipelines/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self):
super(AugmentImage, self).__init__()
self.add(pr.RandomContrast())
self.add(pr.RandomBrightness())
self.add(pr.RandomSaturation())
self.add(pr.RandomSaturation(0.7))
self.add(pr.RandomHue())


Expand Down

0 comments on commit fa34803

Please sign in to comment.