diff --git a/torchvision/prototype/models/_api.py b/torchvision/prototype/models/_api.py index 29da818765c..bc12a9e2e31 100644 --- a/torchvision/prototype/models/_api.py +++ b/torchvision/prototype/models/_api.py @@ -30,9 +30,6 @@ class WeightEntry: transforms: Callable meta: Dict[str, Any] - def state_dict(self, progress: bool) -> Dict[str, Any]: - return load_state_dict_from_url(self.url, progress=progress) - class Weights(Enum): """ @@ -66,7 +63,7 @@ def from_str(cls, value: str) -> "Weights": raise ValueError(f"Invalid value {value} for enum {cls.__name__}.") def state_dict(self, progress: bool) -> OrderedDict: - return self.value.state_dict(progress) + return load_state_dict_from_url(self.url, progress=progress) def __repr__(self): return f"{self.__class__.__name__}.{self._name_}"