From c10e9235365bc5b05a6beb806fc9b7520824b2b0 Mon Sep 17 00:00:00 2001 From: carefree0910 Date: Fri, 3 Jan 2025 19:08:20 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8Followed=20`mypy`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- core/learn/callbacks/defaults.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/learn/callbacks/defaults.py b/core/learn/callbacks/defaults.py index 8c551a2..1ed7bc0 100644 --- a/core/learn/callbacks/defaults.py +++ b/core/learn/callbacks/defaults.py @@ -5,6 +5,7 @@ from torch import Tensor from typing import Any from typing import Dict +from typing import List from typing import Tuple from typing import Optional @@ -65,6 +66,7 @@ def before_summary(self, trainer: ITrainer) -> None: full_states = torch.load(ckpt, weights_only=False, map_location=device) states: tensor_dict_type = full_states["states"] exclude = finetune_config.get("exclude", "") + exclude_names: List[str] if not exclude: exclude_names = [] model.load_state_dict(states)