
Commit

Fix fine-tuning entries bug when doing restart. (#3616)
iProzd authored Mar 28, 2024
1 parent f16d543 commit 6da8eef
Showing 1 changed file with 41 additions and 37 deletions: deepmd/pt/train/training.py
@@ -558,46 +558,50 @@ def update_single_finetune_params(
             ]
         self.wrapper.load_state_dict(state_dict)
 
-        def single_model_finetune(
-            _model,
-            _model_params,
-            _sample_func,
-        ):
-            old_type_map, new_type_map = (
-                _model_params["type_map"],
-                _model_params["new_type_map"],
-            )
-            if isinstance(_model, EnergyModel):
-                _model.change_out_bias(
-                    _sample_func,
-                    bias_adjust_mode=_model_params.get(
-                        "bias_adjust_mode", "change-by-statistic"
-                    ),
-                    origin_type_map=new_type_map,
-                    full_type_map=old_type_map,
-                )
-            else:
-                # need to updated
-                pass
-
-        # finetune
-        if not self.multi_task:
-            single_model_finetune(
-                self.model, model_params, self.get_sample_func
-            )
-        else:
-            for model_key in self.model_keys:
-                if model_key in self.finetune_links:
-                    log.info(
-                        f"Model branch {model_key} will be fine-tuned. This may take a long time..."
-                    )
-                    single_model_finetune(
-                        self.model[model_key],
-                        model_params["model_dict"][model_key],
-                        self.get_sample_func[model_key],
-                    )
-                else:
-                    log.info(f"Model branch {model_key} will resume training.")
+        if finetune_model is not None:
+
+            def single_model_finetune(
+                _model,
+                _model_params,
+                _sample_func,
+            ):
+                old_type_map, new_type_map = (
+                    _model_params["type_map"],
+                    _model_params["new_type_map"],
+                )
+                if isinstance(_model, EnergyModel):
+                    _model.change_out_bias(
+                        _sample_func,
+                        bias_adjust_mode=_model_params.get(
+                            "bias_adjust_mode", "change-by-statistic"
+                        ),
+                        origin_type_map=new_type_map,
+                        full_type_map=old_type_map,
+                    )
+                else:
+                    # need to updated
+                    pass
+
+            # finetune
+            if not self.multi_task:
+                single_model_finetune(
+                    self.model, model_params, self.get_sample_func
+                )
+            else:
+                for model_key in self.model_keys:
+                    if model_key in self.finetune_links:
+                        log.info(
+                            f"Model branch {model_key} will be fine-tuned. This may take a long time..."
+                        )
+                        single_model_finetune(
+                            self.model[model_key],
+                            model_params["model_dict"][model_key],
+                            self.get_sample_func[model_key],
+                        )
+                    else:
+                        log.info(
+                            f"Model branch {model_key} will resume training."
+                        )
 
         if init_frz_model is not None:
             frz_model = torch.jit.load(init_frz_model, map_location=DEVICE)
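In short, the fix wraps the fine-tuning setup in `if finetune_model is not None:`, so a plain restart no longer re-runs the output-bias adjustment that belongs only to fine-tuning. The sketch below illustrates that control-flow change; it is a deliberately simplified stand-in, not the deepmd-kit API (`initialize`, the dict-based `model`, and the string bias values are all hypothetical):

# Minimal sketch, NOT the real deepmd-kit API: `initialize`, the dict-based
# `model`, and the string bias values are illustrative stand-ins.


def initialize(finetune_model):
    """Mimic the control flow of the section touched by this commit."""
    model = {"out_bias": "loaded-from-checkpoint"}

    def single_model_finetune(_model):
        # Stand-in for the real helper, which re-fits the output bias
        # from data statistics via change_out_bias().
        _model["out_bias"] = "adjusted-by-statistic"

    # Before the fix, this helper ran unconditionally, so even a pure
    # restart (finetune_model is None) clobbered the bias restored from
    # the checkpoint. The added guard limits it to fine-tuning runs.
    if finetune_model is not None:
        single_model_finetune(model)

    return model


# Restart: the checkpoint bias must survive untouched.
assert initialize(finetune_model=None)["out_bias"] == "loaded-from-checkpoint"
# Fine-tuning: the bias is re-fitted for the new data.
assert initialize(finetune_model="pretrained.pt")["out_bias"] == "adjusted-by-statistic"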
