Skip to content

Commit

Permalink
Fix val.py Ensemble() (#7490)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher authored Apr 20, 2022
1 parent ab5b917 commit 3f3852e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions models/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True):
return model[-1] # return model
else:
print(f'Ensemble created with {weights}\n')
for k in ['names']:
setattr(model, k, getattr(model[-1], k))
for k in 'names', 'nc', 'yaml':
setattr(model, k, getattr(model[0], k))
model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
return model # return ensemble
2 changes: 1 addition & 1 deletion val.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def run(
# Dataloader
if not training:
if pt and not single_cls: # check --weights are trained on --data
ncm = model.model.yaml['nc']
ncm = model.model.nc
assert ncm == nc, f'{weights[0]} ({ncm} classes) trained on different --data than what you passed ({nc} ' \
f'classes). Pass correct combination of --weights and --data that are trained together.'
model.warmup(imgsz=(1 if pt else batch_size, 3, imgsz, imgsz)) # warmup
Expand Down

0 comments on commit 3f3852e

Please sign in to comment.