Skip to content

Commit

Permalink
pt: fix multitask print_summary (#3409)
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd authored Mar 4, 2024
1 parent 945f1b5 commit 4454811
Showing 1 changed file with 37 additions and 31 deletions.
68 changes: 37 additions & 31 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,10 @@ def get_loss(loss_params, start_lr, _ntypes, _model):
f"training in {model_key}",
to_numpy_array(self.training_dataloader[model_key].sampler.weights),
)
if validation_data is not None:
if (
validation_data is not None
and validation_data[model_key] is not None
):
validation_data[model_key].print_summary(
f"validation in {model_key}",
to_numpy_array(
Expand Down Expand Up @@ -723,7 +726,7 @@ def log_loss_valid(_task_key="Default"):
)
if input_dict == {}:
# no validation data
return "", None
return {}
_, loss, more_loss = self.wrapper(
**input_dict,
cur_lr=pref_lr,
Expand All @@ -744,23 +747,24 @@ def log_loss_valid(_task_key="Default"):
if not self.multi_task:
train_results = log_loss_train(loss, more_loss)
valid_results = log_loss_valid()
log.info(
format_training_message_per_task(
batch=_step_id,
task_name="trn",
rmse=train_results,
learning_rate=cur_lr,
)
)
if valid_results is not None:
if self.rank == 0:
log.info(
format_training_message_per_task(
batch=_step_id,
task_name="val",
rmse=valid_results,
learning_rate=None,
task_name="trn",
rmse=train_results,
learning_rate=cur_lr,
)
)
if valid_results:
log.info(
format_training_message_per_task(
batch=_step_id,
task_name="val",
rmse=valid_results,
learning_rate=None,
)
)
else:
train_results = {_key: {} for _key in self.model_keys}
valid_results = {_key: {} for _key in self.model_keys}
Expand All @@ -783,33 +787,35 @@ def log_loss_valid(_task_key="Default"):
loss, more_loss, _task_key=_key
)
valid_results[_key] = log_loss_valid(_task_key=_key)
log.info(
format_training_message_per_task(
batch=_step_id,
task_name=_key + "_trn",
rmse=train_results[_key],
learning_rate=cur_lr,
)
)
if valid_results is not None:
if self.rank == 0:
log.info(
format_training_message_per_task(
batch=_step_id,
task_name=_key + "_val",
rmse=valid_results[_key],
learning_rate=None,
task_name=_key + "_trn",
rmse=train_results[_key],
learning_rate=cur_lr,
)
)
if valid_results is not None and valid_results[_key]:
log.info(
format_training_message_per_task(
batch=_step_id,
task_name=_key + "_val",
rmse=valid_results[_key],
learning_rate=None,
)
)

current_time = time.time()
train_time = current_time - self.t0
self.t0 = current_time
log.info(
format_training_message(
batch=_step_id,
wall_time=train_time,
if self.rank == 0:
log.info(
format_training_message(
batch=_step_id,
wall_time=train_time,
)
)
)

if fout:
if self.lcurve_should_print_header:
Expand Down

0 comments on commit 4454811

Please sign in to comment.