Skip to content

Commit

Permalink
Weights are now downloaded automatically
Browse files Browse the repository at this point in the history
  • Loading branch information
Manojkumarmuru committed Oct 24, 2022
1 parent bebca37 commit c159e69
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
11 changes: 7 additions & 4 deletions examples/efficientdet/efficientdet_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.utils import get_file

from anchors import build_prior_boxes
from efficientdet_blocks import BiFPN, BoxNet, ClassNet
from efficientnet_model import EfficientNet
from utils import create_multibox_head

WEIGHT_PATH = (
'/home/manummk95/Desktop/efficientdet_working/required/weights/')
'https://github.com/oarriaga/altamira-data/releases/download/v0.16/')


def EfficientDet(num_classes, base_weights, head_weights, input_shape,
Expand Down Expand Up @@ -83,9 +84,11 @@ def EfficientDet(num_classes, base_weights, head_weights, input_shape,

if (((base_weights == 'COCO') and (head_weights == 'COCO')) or
((base_weights == 'COCO') and (head_weights is None))):
weights_path = (WEIGHT_PATH + model_name + '-' +
str(base_weights) + '-' + str(head_weights) +
'_weights.hdf5')
model_filename = (model_name + '-' + str(base_weights) + '-' +
str(head_weights) + '_weights.hdf5')
weights_path = get_file(model_filename, WEIGHT_PATH + model_filename,
cache_subdir='paz/models')
print('Loading %s model weights' % weights_path)
model.load_weights(weights_path)

model.prior_boxes = build_prior_boxes(
Expand Down
10 changes: 8 additions & 2 deletions examples/efficientdet/infer_efficientdet.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
from paz.backend.image import show_image, write_image
from paz.datasets import get_class_names
from tensorflow.keras.utils import get_file

from detection import DetectSingleShot
from efficientdet import EFFICIENTDETD0
from utils import raw_images

WEIGHT_PATH = (
'https://github.com/oarriaga/altamira-data/releases/download/v0.16/')
WEIGHT_FILE = 'efficientdet-d0-VOC-VOC_weights.hdf5'

if __name__ == "__main__":
model = EFFICIENTDETD0(num_classes=21, base_weights='COCO',
head_weights=None)
model.load_weights("/home/manummk95/Desktop/efficientdet_working/temp/"
"weight/weights.209-3.73.hdf5")
weights_path = get_file(WEIGHT_FILE, WEIGHT_PATH + WEIGHT_FILE,
cache_subdir='paz/models')
model.load_weights(weights_path)
detections = DetectSingleShot(model, get_class_names('VOC'),
0.5, 0.45)(raw_images)
show_image(detections['image'])
Expand Down

0 comments on commit c159e69

Please sign in to comment.