Skip to content

Commit

Permalink
Explicitly remove the stride keys from the checkpoint if they are pre…
Browse files Browse the repository at this point in the history
…sent which should fix the issue with DeciDet checkpoints (#1397)

Co-authored-by: Eugene Khvedchenya <[email protected]>
  • Loading branch information
Louis-Dupont and BloodAxe authored Aug 21, 2023
1 parent d822731 commit 51d61bd
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/super_gradients/training/models/detection_models/yolo_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections
import math
import warnings
from typing import Union, Type, List, Tuple, Optional
Expand Down Expand Up @@ -590,6 +591,15 @@ def forward(self, x):

def load_state_dict(self, state_dict, strict=True):
try:
keys_dropped_in_sg_320 = {
"stride",
"_head.anchors._stride",
"_head.anchors._anchors",
"_head.anchors._anchor_grid",
"_head._modules_list.14.stride",
}
state_dict = collections.OrderedDict([(k, v) for k, v in state_dict.items() if k not in keys_dropped_in_sg_320])

super().load_state_dict(state_dict, strict)
except RuntimeError as e:
raise RuntimeError(
Expand Down

0 comments on commit 51d61bd

Please sign in to comment.