Skip to content

Commit

Permalink
Adds v3-cls model
Browse files Browse the repository at this point in the history
  • Loading branch information
LashaO committed Jan 30, 2024
1 parent 54e18f9 commit ebf48c7
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 63 deletions.
43 changes: 0 additions & 43 deletions dasd

This file was deleted.

48 changes: 36 additions & 12 deletions scoutbot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def fetch(pull=False, config=None):
Raises:
AssertionError: If any model cannot be fetched.
"""
if config == 'v3':
if config in ['v3', 'v3-cls']:
loc.fetch(pull=pull, config=config)
else:
wic.fetch(pull=pull, config=None)
Expand Down Expand Up @@ -191,8 +191,16 @@ def pipeline(

def pipeline_v3(
filepath,
config,
batched_detection_model=None,
loc_thresh=0.45
loc_thresh=0.45,
slice_height=512,
slice_width=512,
overlap_height_ratio=0.25,
overlap_width_ratio=0.25,
perform_standard_pred=False,
postprocess_class_agnostic=True

):
"""
Run the ML pipeline on a given image filepath and return the detections
Expand All @@ -215,7 +223,7 @@ def pipeline_v3(
# Run Localizer

if batched_detection_model is None:
yolov8_model_path = loc.fetch(config='v3')
yolov8_model_path = loc.fetch(config=config)

batched_detection_model = tile_batched.Yolov8DetectionModel(
model_path=yolov8_model_path,
Expand All @@ -226,12 +234,12 @@ def pipeline_v3(
det_result = tile_batched.get_sliced_prediction_batched(
cv2.imread(filepath),
batched_detection_model,
slice_height=512,
slice_width=512,
overlap_height_ratio=0.25,
overlap_width_ratio=0.25,
perform_standard_pred=False,
postprocess_class_agnostic=True
slice_height=slice_height,
slice_width=slice_width,
overlap_height_ratio=overlap_height_ratio,
overlap_width_ratio=overlap_width_ratio,
perform_standard_pred=perform_standard_pred,
postprocess_class_agnostic=postprocess_class_agnostic
)

# Postprocess detections for WIC
Expand Down Expand Up @@ -407,9 +415,16 @@ def batch(

def batch_v3(
filepaths,
loc_thresh=0.45
config,
loc_thresh=0.45,
slice_height=512,
slice_width=512,
overlap_height_ratio=0.25,
overlap_width_ratio=0.25,
perform_standard_pred=False,
postprocess_class_agnostic=True
):
yolov8_model_path = loc.fetch(config='v3')
yolov8_model_path = loc.fetch(config=config)

batched_detection_model = tile_batched.Yolov8DetectionModel(
model_path=yolov8_model_path,
Expand All @@ -420,7 +435,16 @@ def batch_v3(
wic_list = []
detects_list = []
for filepath in filepaths:
wic_, detects = pipeline_v3(filepath, batched_detection_model)
wic_, detects = pipeline_v3(filepath,
batched_detection_model,
loc_thresh=loc_thresh,
slice_height=slice_height,
slice_width=slice_width,
overlap_height_ratio=overlap_height_ratio,
overlap_width_ratio=overlap_width_ratio,
perform_standard_pred=perform_standard_pred,
postprocess_class_agnostic=postprocess_class_agnostic
)
wic_list.append(wic_)
detects_list.append(detects)

Expand Down
21 changes: 20 additions & 1 deletion scoutbot/loc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,28 @@
],
},
'v3': {
'hash': None, # 9e001aa3c10d05ba8a269103d3d358ceeb7d6f3bcc5758c1be4405ff743e0e90 #'46cbbccf922552703a1fe8a756544e43'
'hash': None,
'name': 'yolov8.kaza.pt',
'path': join(PWD, 'models', 'yolo', 'yolov8.kaza.pt'),
'thresh': 0.45,
'slice_height': 512,
'slice_width': 512,
'overlap_height_ratio': 0.25,
'overlap_width_ratio': 0.25,
'perform_standard_pred': False,
'postprocess_class_agnostic': True
},
'v3-cls': {
'hash': None,
'name': 'yolov8-cls.kaza.pt',
'path': join(PWD, 'models', 'yolo', 'yolov8-cls.kaza.pt'),
'thresh': 0.45,
'slice_height': 512,
'slice_width': 512,
'overlap_height_ratio': 0.25,
'overlap_width_ratio': 0.25,
'perform_standard_pred': False,
'postprocess_class_agnostic': True
}
}
CONFIGS[None] = CONFIGS[DEFAULT_CONFIG]
Expand Down
30 changes: 24 additions & 6 deletions scoutbot/scoutbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def pipeline_filepath_validator(ctx, param, value):
'--config',
help='Which ML models to use for inference',
default=None,
type=click.Choice(['phase1', 'mvp', 'old', 'new', 'v3']),
type=click.Choice(['phase1', 'mvp', 'old', 'new', 'v3', 'v3-cls']),
)
def fetch(config):
"""
Expand All @@ -45,7 +45,7 @@ def fetch(config):
'--config',
help='Which ML models to use for inference',
default=None,
type=click.Choice(['phase1', 'mvp', 'old', 'new', 'v3']),
type=click.Choice(['phase1', 'mvp', 'old', 'new', 'v3', 'v3-cls']),
)
@click.option(
'--output',
Expand Down Expand Up @@ -124,9 +124,18 @@ def pipeline(
agg_thresh /= 100.0
agg_nms_thresh /= 100.0

if config == 'v3':
if config in ['v3', 'v3-cls']:
wic_, detects = scoutbot.pipeline_v3(
filepath
filepath,
config,
loc_thresh=loc.CONFIGS[config]['thresh'],
slice_height=loc.CONFIGS[config]['slice_height'],
slice_width=loc.CONFIGS[config]['slice_width'],
overlap_height_ratio=loc.CONFIGS[config]['overlap_height_ratio'],
overlap_width_ratio=loc.CONFIGS[config]['overlap_width_ratio'],
perform_standard_pred=loc.CONFIGS[config]['perform_standard_pred'],
postprocess_class_agnostic=loc.CONFIGS[config]['postprocess_class_agnostic']

)
else:
wic_, detects = scoutbot.pipeline(
Expand Down Expand Up @@ -164,7 +173,7 @@ def pipeline(
'--config',
help='Which ML models to use for inference',
default=None,
type=click.Choice(['phase1', 'mvp', 'old', 'new', 'v3']),
type=click.Choice(['phase1', 'mvp', 'old', 'new', 'v3', 'v3-cls']),
)
@click.option(
'--output',
Expand Down Expand Up @@ -260,9 +269,18 @@ def batch(

log.debug(f'Running batch on {len(filepaths)} files...')

if config == 'v3':
if config in ['v3', 'v3-cls']:
wic_list, detects_list = scoutbot.batch_v3(
filepaths,
config,
loc_thresh=loc.CONFIGS[config]['thresh'],
slice_height=loc.CONFIGS[config]['slice_height'],
slice_width=loc.CONFIGS[config]['slice_width'],
overlap_height_ratio=loc.CONFIGS[config]['overlap_height_ratio'],
overlap_width_ratio=loc.CONFIGS[config]['overlap_width_ratio'],
perform_standard_pred=loc.CONFIGS[config]['perform_standard_pred'],
postprocess_class_agnostic=loc.CONFIGS[config]['postprocess_class_agnostic']

)
else:
wic_list, detects_list = scoutbot.batch(
Expand Down
2 changes: 1 addition & 1 deletion scoutbot/tile_batched/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def __len__(self):


def slice_image(
image: Union[str],
image: Union[str, np.ndarray],
slice_height: Optional[int] = None,
slice_width: Optional[int] = None,
overlap_height_ratio: float = 0.2,
Expand Down

0 comments on commit ebf48c7

Please sign in to comment.