[HotFix] Fix invalid logging in distributed mode #223

Merged (1 commit, Jul 11, 2022)
17 changes: 6 additions & 11 deletions federatedscope/core/auxiliaries/eunms.py
@@ -2,8 +2,8 @@ class MODE:
"""

Note:
Currently StrEnum cannot be imported with the environment `sys.version_info < (3, 11)`, so we simply create a
MODE class here.
Currently StrEnum cannot be imported with the environment
`sys.version_info < (3, 11)`, so we simply create a MODE class here.
"""
TRAIN = 'train'
TEST = 'test'
@@ -24,12 +24,7 @@ class TRIGGER:
@classmethod
def contains(cls, item):
return item in [
"on_fit_start",
"on_epoch_start",
"on_batch_start",
"on_batch_forward",
"on_batch_backward",
"on_batch_end",
"on_epoch_end",
"on_fit_end"
]
"on_fit_start", "on_epoch_start", "on_batch_start",
"on_batch_forward", "on_batch_backward", "on_batch_end",
"on_epoch_end", "on_fit_end"
]
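
For context, a minimal sketch of how such plain string-constant classes are typically used, plus the StrEnum alternative the docstring mentions for Python 3.11+ (illustrative only, not the project's exact code):

    import sys

    class MODE:
        # Plain string constants stand in for StrEnum on Python < 3.11.
        TRAIN = 'train'
        TEST = 'test'

    def run_routine(mode):
        if mode == MODE.TRAIN:  # plain string comparison works on any version
            return "training"
        return "evaluating"

    print(run_routine(MODE.TRAIN))  # training

    if sys.version_info >= (3, 11):
        # On 3.11+ the same constants could be an enum.StrEnum,
        # as the docstring above notes.
        from enum import StrEnum

        class Mode(StrEnum):
            TRAIN = 'train'
            TEST = 'test'
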
4 changes: 3 additions & 1 deletion federatedscope/core/auxiliaries/utils.py
@@ -66,7 +66,9 @@ def update_logger(cfg, clear_before_add=False):
if cfg.outdir == "":
cfg.outdir = os.path.join(os.getcwd(), "exp")
if cfg.expname == "":
cfg.expname = f"{cfg.federate.method}_{cfg.model.type}_on_{cfg.data.type}_lr{cfg.train.optimizer.lr}_lstep{cfg.train.local_update_steps}"
cfg.expname = f"{cfg.federate.method}_{cfg.model.type}_on" \
f"_{cfg.data.type}_lr{cfg.train.optimizer.lr}_lste" \
f"p{cfg.train.local_update_steps}"
cfg.expname = f"{cfg.expname}_{cfg.expname_tag}"
cfg.outdir = os.path.join(cfg.outdir, cfg.expname)

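
A quick check (with made-up values) that the wrapped f-string above, which relies on Python's implicit concatenation of adjacent string literals across backslash continuations, still yields the original experiment name:

    method, model, data, lr, steps = "FedAvg", "convnet2", "femnist", 0.01, 1
    expname = f"{method}_{model}_on" \
              f"_{data}_lr{lr}_lste" \
              f"p{steps}"
    assert expname == f"{method}_{model}_on_{data}_lr{lr}_lstep{steps}"
    print(expname)  # FedAvg_convnet2_on_femnist_lr0.01_lstep1
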
3 changes: 2 additions & 1 deletion federatedscope/core/configs/cfg_hpo.py
@@ -77,7 +77,8 @@ def assert_hpo_cfg(cfg):

assert not (cfg.hpo.fedex.use and cfg.federate.use_ss
), "Cannot use secret sharing and FedEx at the same time"
assert cfg.train.optimizer.type == 'SGD' or not cfg.hpo.fedex.use, "SGD is required if FedEx is considered"
assert cfg.train.optimizer.type == 'SGD' or not cfg.hpo.fedex.use, \
"SGD is required if FedEx is considered"
assert cfg.hpo.fedex.sched in [
'adaptive', 'aggressive', 'auto', 'constant', 'scale'
], "schedule of FedEx must be choice from {}".format(
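
A tiny illustration (made-up values) of the guard pattern used in these assertions, where "assert A or not B" enforces "if B then A":

    def check(optimizer_type, use_fedex):
        # If FedEx is used, the optimizer must be SGD; otherwise anything goes.
        assert optimizer_type == 'SGD' or not use_fedex, \
            "SGD is required if FedEx is considered"

    check('SGD', True)    # ok: FedEx with SGD
    check('Adam', False)  # ok: FedEx not used, optimizer unconstrained
    check('Adam', True)   # AssertionError: FedEx requires SGD
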
25 changes: 13 additions & 12 deletions federatedscope/core/configs/cfg_training.py
@@ -10,9 +10,9 @@ def extend_training_cfg(cfg):

cfg.trainer.type = 'general'

# ------------------------------------------------------------------------ #
# ---------------------------------------------------------------------- #
# Training related options
# ------------------------------------------------------------------------ #
# ---------------------------------------------------------------------- #
cfg.train = CN()

cfg.train.local_update_steps = 1
@@ -22,9 +22,9 @@ def extend_training_cfg(cfg):
cfg.train.optimizer.type = 'SGD'
cfg.train.optimizer.lr = 0.1

# ------------------------------------------------------------------------ #
# ---------------------------------------------------------------------- #
# Finetune related options
# ------------------------------------------------------------------------ #
# ---------------------------------------------------------------------- #
cfg.finetune = CN()

cfg.finetune.before_eval = False
@@ -42,7 +42,7 @@ def extend_training_cfg(cfg):
cfg.grad = CN()
cfg.grad.grad_clip = -1.0 # negative numbers indicate we do not clip grad

# ------------------------------------------------------------------------ #
# ---------------------------------------------------------------------- #
# Early stopping related options
# ---------------------------------------------------------------------- #
cfg.early_stop = CN()
@@ -66,13 +66,13 @@ def assert_training_cfg(cfg):
def assert_training_cfg(cfg):
if cfg.train.batch_or_epoch not in ['batch', 'epoch']:
raise ValueError(
"Value of 'cfg.train.batch_or_epoch' must be chosen from ['batch', 'epoch']."
)
"Value of 'cfg.train.batch_or_epoch' must be chosen from ["
"'batch', 'epoch'].")

if cfg.finetune.batch_or_epoch not in ['batch', 'epoch']:
raise ValueError(
"Value of 'cfg.finetune.batch_or_epoch' must be chosen from ['batch', 'epoch']."
)
"Value of 'cfg.finetune.batch_or_epoch' must be chosen from ["
"'batch', 'epoch'].")

# TODO: should not be here?
if cfg.backend not in ['torch', 'tensorflow']:
@@ -87,10 +87,11 @@ def assert_training_cfg(cfg):
raise ValueError(
"We only support run with cpu when backend is tensorflow")

if cfg.finetune.before_eval is False and cfg.finetune.local_update_steps <= 0:
if cfg.finetune.before_eval is False and cfg.finetune.local_update_steps\
<= 0:
raise ValueError(
f"When adopting fine-tuning, please set a valid local fine-tune steps, got {cfg.finetune.local_update_steps}"
)
f"When adopting fine-tuning, please set a valid local fine-tune "
f"steps, got {cfg.finetune.local_update_steps}")


register_config("fl_training", extend_training_cfg)
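
A self-contained sketch of the extend-then-validate pattern shown here, using yacs directly with placeholder values (the real code registers extend_training_cfg through FederatedScope's register_config):

    from yacs.config import CfgNode as CN

    cfg = CN()
    cfg.train = CN()
    cfg.train.batch_or_epoch = 'batch'
    cfg.finetune = CN()
    cfg.finetune.before_eval = True
    cfg.finetune.local_update_steps = 5

    def assert_training_cfg(cfg):
        if cfg.train.batch_or_epoch not in ['batch', 'epoch']:
            raise ValueError("Value of 'cfg.train.batch_or_epoch' must be "
                             "chosen from ['batch', 'epoch'].")
        if cfg.finetune.before_eval is False and \
                cfg.finetune.local_update_steps <= 0:
            raise ValueError("When adopting fine-tuning, please set a valid "
                             "local fine-tune steps, got "
                             f"{cfg.finetune.local_update_steps}")

    assert_training_cfg(cfg)  # passes with the placeholder values above
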
34 changes: 19 additions & 15 deletions federatedscope/core/trainers/context.py
@@ -2,7 +2,8 @@
import logging

from federatedscope.core.auxiliaries.criterion_builder import get_criterion
from federatedscope.core.auxiliaries.model_builder import get_trainable_para_names
from federatedscope.core.auxiliaries.model_builder import \
get_trainable_para_names
from federatedscope.core.auxiliaries.regularizer_builder import get_regularizer
from federatedscope.core.auxiliaries.eunms import MODE

@@ -125,9 +126,11 @@ def setup_vars(self):

# Process training data
if self.train_data is not None or self.train_loader is not None:
# Calculate the number of update steps during training given the local_update_steps
num_train_batch, num_train_batch_last_epoch, num_train_epoch, num_total_train_batch = self.pre_calculate_batch_epoch_num(
self.cfg.train.local_update_steps)
# Calculate the number of update steps during training given the
# local_update_steps
num_train_batch, num_train_batch_last_epoch, num_train_epoch, \
num_total_train_batch = self.pre_calculate_batch_epoch_num(
self.cfg.train.local_update_steps)

self.num_train_epoch = num_train_epoch
self.num_train_batch = num_train_batch
@@ -148,9 +151,9 @@ def setup_vars(self):
self.cfg.data.batch_size)))

def pre_calculate_batch_epoch_num(self, local_update_steps):
num_train_batch = self.num_train_data // self.cfg.data.batch_size + int(
not self.cfg.data.drop_last
and bool(self.num_train_data % self.cfg.data.batch_size))
num_train_batch = self.num_train_data // self.cfg.data.batch_size + \
int(not self.cfg.data.drop_last and bool(
self.num_train_data % self.cfg.data.batch_size))
if self.cfg.train.batch_or_epoch == "epoch":
num_train_epoch = local_update_steps
num_train_batch_last_epoch = num_train_batch
@@ -199,18 +202,19 @@ def reset_used_dataset(self):
len(self.cur_data_splits_used_by_routine) != 0 else None

def check_data_split(self, target_data_split_name, skip=False):
if self.get(
f"{target_data_split_name}_data") is None and self.get(
f"{target_data_split_name}_loader") is None:
if self.get(f"{target_data_split_name}_data") is None and self.get(
f"{target_data_split_name}_loader") is None:
if skip:
logger.warning(
f"No {target_data_split_name}_data or {target_data_split_name}_loader in the trainer, will skip evaluation"
f"If this is not the case you want, please check whether there is typo for the name"
)
f"No {target_data_split_name}_data or"
f" {target_data_split_name}_loader in the trainer, "
f"will skip evaluation"
f"If this is not the case you want, please check "
f"whether there is typo for the name")
return False
else:
raise ValueError(
f"No {target_data_split_name}_data or {target_data_split_name}_loader in the trainer"
)
f"No {target_data_split_name}_data or"
f" {target_data_split_name}_loader in the trainer")
else:
return True
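
A worked example (made-up numbers) of the batch/epoch bookkeeping in pre_calculate_batch_epoch_num above:

    num_train_data, batch_size, drop_last = 105, 10, False
    local_update_steps = 3  # interpreted as epochs when batch_or_epoch == "epoch"

    num_train_batch = num_train_data // batch_size + int(
        not drop_last and bool(num_train_data % batch_size))
    # 105 // 10 = 10 full batches, plus 1 partial batch kept since drop_last is False
    assert num_train_batch == 11

    # "epoch" mode, as in the branch above:
    num_train_epoch = local_update_steps          # 3 passes over the data
    num_train_batch_last_epoch = num_train_batch  # 11 batches in the last pass
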
23 changes: 11 additions & 12 deletions federatedscope/core/trainers/torch_trainer.py
@@ -98,22 +98,20 @@ def register_default_hooks_train(self):
self.register_hook_in_train(self._hook_on_fit_end, "on_fit_end")

def register_default_hooks_ft(self):
self.register_hook_in_ft(self._hook_on_fit_start_init,
"on_fit_start")
self.register_hook_in_ft(
self._hook_on_fit_start_calculate_model_size, "on_fit_start")
self.register_hook_in_ft(self._hook_on_epoch_start,
"on_epoch_start")
self.register_hook_in_ft(self._hook_on_fit_start_init, "on_fit_start")
self.register_hook_in_ft(self._hook_on_fit_start_calculate_model_size,
"on_fit_start")
self.register_hook_in_ft(self._hook_on_epoch_start, "on_epoch_start")
self.register_hook_in_ft(self._hook_on_batch_start_init,
"on_batch_start")
"on_batch_start")
self.register_hook_in_ft(self._hook_on_batch_forward,
"on_batch_forward")
"on_batch_forward")
self.register_hook_in_ft(self._hook_on_batch_forward_regularizer,
"on_batch_forward")
"on_batch_forward")
self.register_hook_in_ft(self._hook_on_batch_forward_flop_count,
"on_batch_forward")
"on_batch_forward")
self.register_hook_in_ft(self._hook_on_batch_backward,
"on_batch_backward")
"on_batch_backward")
self.register_hook_in_ft(self._hook_on_batch_end, "on_batch_end")
self.register_hook_in_ft(self._hook_on_fit_end, "on_fit_end")

@@ -134,7 +132,8 @@ def _hook_on_fit_start_init(self, ctx):
ctx.model.to(ctx.device)

if ctx.cur_mode in [MODE.TRAIN, MODE.FINETUNE]:
# Initialize optimizer here to avoid the reuse of optimizers across different routines
# Initialize optimizer here to avoid the reuse of optimizers
# across different routines
ctx.optimizer = get_optimizer(ctx.model,
**ctx.cfg[ctx.cur_mode].optimizer)

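
An illustrative-only sketch of trigger-based hooks like the ones registered above (not FederatedScope's actual dispatch code), using the trigger names from eunms.py:

    from collections import defaultdict

    hooks = defaultdict(list)

    def register_hook(hook, trigger):
        hooks[trigger].append(hook)

    def fire(trigger, ctx):
        for hook in hooks[trigger]:
            hook(ctx)

    register_hook(lambda ctx: print("init optimizer"), "on_fit_start")
    register_hook(lambda ctx: print("forward + loss"), "on_batch_forward")
    register_hook(lambda ctx: print("backward + step"), "on_batch_backward")

    # A heavily simplified fine-tune routine: fit -> epochs -> batches -> fit end
    fire("on_fit_start", ctx=None)
    for _ in range(1):      # epochs
        for _ in range(2):  # batches per epoch
            fire("on_batch_forward", ctx=None)
            fire("on_batch_backward", ctx=None)
    fire("on_fit_end", ctx=None)
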
9 changes: 5 additions & 4 deletions federatedscope/core/trainers/trainer_Ditto.py
@@ -24,10 +24,9 @@ def wrap_DittoTrainer(
init_Ditto_ctx(base_trainer)

# ---------------- action-level plug-in -----------------------
base_trainer.register_hook_in_train(
new_hook=_hook_on_fit_start_del_opt,
trigger='on_fit_start',
insert_pos=-1)
base_trainer.register_hook_in_train(new_hook=_hook_on_fit_start_del_opt,
trigger='on_fit_start',
insert_pos=-1)
base_trainer.register_hook_in_train(
new_hook=hook_on_fit_start_set_regularized_para,
trigger="on_fit_start",
@@ -114,10 +113,12 @@ def hook_on_fit_start_set_regularized_para(ctx):
ctx.optimizer_for_local_model.set_compared_para_group(
compared_global_model_para)


def _hook_on_fit_start_del_opt(ctx):
# remove the unnecessary optimizer
del ctx.optimizer


def _hook_on_batch_end_flop_count(ctx):
# besides the normal forward flops, the regularization adds the cost of
# number of model parameters
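
A back-of-the-envelope illustration (hypothetical numbers) of the flop accounting described in _hook_on_batch_end_flop_count, where the regularization term adds roughly one operation per model parameter on top of the normal forward cost:

    forward_flops = 2_000_000  # hypothetical forward cost per batch
    num_params = 50_000        # hypothetical parameter count
    num_batches = 11

    flops_per_batch = forward_flops + num_params
    total_flops = num_batches * flops_per_batch
    print(flops_per_batch, total_flops)  # 2050000 22550000
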
9 changes: 6 additions & 3 deletions federatedscope/core/worker/client.py
@@ -308,9 +308,11 @@ def callback_funcs_for_join_in_info(self, message: Message):
for requirement in requirements:
if requirement.lower() == 'num_sample':
if self._cfg.train.batch_or_epoch == 'batch':
num_sample = self._cfg.train.local_update_steps * self._cfg.data.batch_size
num_sample = self._cfg.train.local_update_steps * \
self._cfg.data.batch_size
else:
num_sample = self._cfg.train.local_update_steps * self.trainer.ctx.num_train_batch
num_sample = self._cfg.train.local_update_steps * \
self.trainer.ctx.num_train_batch
join_in_info['num_sample'] = num_sample
else:
raise ValueError(
@@ -364,7 +366,8 @@ def callback_funcs_for_evaluate(self, message: Message):
self._monitor.format_eval_res(eval_metrics,
rnd=self.state,
role='Client #{}'.format(
self.ID)))
self.ID),
return_raw=True))

metrics.update(**eval_metrics)

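
A worked example (made-up values) of the num_sample computation in callback_funcs_for_join_in_info above:

    local_update_steps = 4
    batch_size = 32        # cfg.data.batch_size
    num_train_batch = 11   # trainer.ctx.num_train_batch

    batch_or_epoch = 'batch'
    if batch_or_epoch == 'batch':
        num_sample = local_update_steps * batch_size       # 4 * 32 = 128
    else:
        num_sample = local_update_steps * num_train_batch  # 4 * 11 = 44
    print(num_sample)  # 128
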