format (#223)
rayrayraykk authored Jul 11, 2022
1 parent 8acc7b6 commit b3a2061
Showing 8 changed files with 65 additions and 59 deletions.
17 changes: 6 additions & 11 deletions federatedscope/core/auxiliaries/eunms.py
@@ -2,8 +2,8 @@ class MODE:
"""
Note:
Currently StrEnum cannot be imported with the environment `sys.version_info < (3, 11)`, so we simply create a
MODE class here.
Currently StrEnum cannot be imported with the environment
`sys.version_info < (3, 11)`, so we simply create a MODE class here.
"""
TRAIN = 'train'
TEST = 'test'
@@ -24,12 +24,7 @@ class TRIGGER:
@classmethod
def contains(cls, item):
return item in [
"on_fit_start",
"on_epoch_start",
"on_batch_start",
"on_batch_forward",
"on_batch_backward",
"on_batch_end",
"on_epoch_end",
"on_fit_end"
]
"on_fit_start", "on_epoch_start", "on_batch_start",
"on_batch_forward", "on_batch_backward", "on_batch_end",
"on_epoch_end", "on_fit_end"
]
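
As a quick aside, a minimal usage sketch (not part of this commit, and assuming the federatedscope package is importable) of the plain-class stand-ins that the docstring above describes; the my_trigger value is a made-up example:

from federatedscope.core.auxiliaries.eunms import MODE, TRIGGER

def run_routine(mode):
    # MODE members are plain strings (MODE.TRAIN == 'train'), so ordinary
    # string comparison works without StrEnum.
    if mode == MODE.TRAIN:
        print("running the training routine")

run_routine(MODE.TRAIN)

# TRIGGER.contains guards against typos in hook trigger names.
my_trigger = "on_batch_forward"  # made-up example value
assert TRIGGER.contains(my_trigger), f"unknown trigger: {my_trigger}"
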
4 changes: 3 additions & 1 deletion federatedscope/core/auxiliaries/utils.py
@@ -66,7 +66,9 @@ def update_logger(cfg, clear_before_add=False):
if cfg.outdir == "":
cfg.outdir = os.path.join(os.getcwd(), "exp")
if cfg.expname == "":
cfg.expname = f"{cfg.federate.method}_{cfg.model.type}_on_{cfg.data.type}_lr{cfg.train.optimizer.lr}_lstep{cfg.train.local_update_steps}"
cfg.expname = f"{cfg.federate.method}_{cfg.model.type}_on" \
f"_{cfg.data.type}_lr{cfg.train.optimizer.lr}_lste" \
f"p{cfg.train.local_update_steps}"
cfg.expname = f"{cfg.expname}_{cfg.expname_tag}"
cfg.outdir = os.path.join(cfg.outdir, cfg.expname)

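As an aside, a standalone sketch with made-up values in place of the cfg fields, showing that the wrapped f-string above still concatenates into one experiment name, since Python joins adjacent string literals:

method, model, data = "FedAvg", "convnet2", "femnist"  # made-up values
lr, local_update_steps = 0.01, 1                       # made-up values

expname = f"{method}_{model}_on" \
          f"_{data}_lr{lr}_lste" \
          f"p{local_update_steps}"

# The odd-looking splits ("on" + "_", "lste" + "p") do not change the result.
assert expname == "FedAvg_convnet2_on_femnist_lr0.01_lstep1"
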
3 changes: 2 additions & 1 deletion federatedscope/core/configs/cfg_hpo.py
@@ -77,7 +77,8 @@ def assert_hpo_cfg(cfg):

assert not (cfg.hpo.fedex.use and cfg.federate.use_ss
), "Cannot use secret sharing and FedEx at the same time"
assert cfg.train.optimizer.type == 'SGD' or not cfg.hpo.fedex.use, "SGD is required if FedEx is considered"
assert cfg.train.optimizer.type == 'SGD' or not cfg.hpo.fedex.use, \
"SGD is required if FedEx is considered"
assert cfg.hpo.fedex.sched in [
'adaptive', 'aggressive', 'auto', 'constant', 'scale'
], "schedule of FedEx must be choice from {}".format(
25 changes: 13 additions & 12 deletions federatedscope/core/configs/cfg_training.py
@@ -10,9 +10,9 @@ def extend_training_cfg(cfg):

cfg.trainer.type = 'general'

# ------------------------------------------------------------------------ #
# ---------------------------------------------------------------------- #
# Training related options
# ------------------------------------------------------------------------ #
# ---------------------------------------------------------------------- #
cfg.train = CN()

cfg.train.local_update_steps = 1
@@ -22,9 +22,9 @@ def extend_training_cfg(cfg):
cfg.train.optimizer.type = 'SGD'
cfg.train.optimizer.lr = 0.1

# ------------------------------------------------------------------------ #
# ---------------------------------------------------------------------- #
# Finetune related options
# ------------------------------------------------------------------------ #
# ---------------------------------------------------------------------- #
cfg.finetune = CN()

cfg.finetune.before_eval = False
@@ -42,7 +42,7 @@ def extend_training_cfg(cfg):
cfg.grad = CN()
cfg.grad.grad_clip = -1.0 # negative numbers indicate we do not clip grad

# ------------------------------------------------------------------------ #
# ---------------------------------------------------------------------- #
# Early stopping related options
# ---------------------------------------------------------------------- #
cfg.early_stop = CN()
@@ -66,13 +66,13 @@
def assert_training_cfg(cfg):
if cfg.train.batch_or_epoch not in ['batch', 'epoch']:
raise ValueError(
"Value of 'cfg.train.batch_or_epoch' must be chosen from ['batch', 'epoch']."
)
"Value of 'cfg.train.batch_or_epoch' must be chosen from ["
"'batch', 'epoch'].")

if cfg.finetune.batch_or_epoch not in ['batch', 'epoch']:
raise ValueError(
"Value of 'cfg.finetune.batch_or_epoch' must be chosen from ['batch', 'epoch']."
)
"Value of 'cfg.finetune.batch_or_epoch' must be chosen from ["
"'batch', 'epoch'].")

# TODO: should not be here?
if cfg.backend not in ['torch', 'tensorflow']:
@@ -87,10 +87,11 @@ def assert_training_cfg(cfg):
raise ValueError(
"We only support run with cpu when backend is tensorflow")

if cfg.finetune.before_eval is False and cfg.finetune.local_update_steps <= 0:
if cfg.finetune.before_eval is False and cfg.finetune.local_update_steps\
<= 0:
raise ValueError(
f"When adopting fine-tuning, please set a valid local fine-tune steps, got {cfg.finetune.local_update_steps}"
)
f"When adopting fine-tuning, please set a valid local fine-tune "
f"steps, got {cfg.finetune.local_update_steps}")


register_config("fl_training", extend_training_cfg)
34 changes: 19 additions & 15 deletions federatedscope/core/trainers/context.py
@@ -2,7 +2,8 @@
import logging

from federatedscope.core.auxiliaries.criterion_builder import get_criterion
from federatedscope.core.auxiliaries.model_builder import get_trainable_para_names
from federatedscope.core.auxiliaries.model_builder import \
get_trainable_para_names
from federatedscope.core.auxiliaries.regularizer_builder import get_regularizer
from federatedscope.core.auxiliaries.eunms import MODE

@@ -125,9 +126,11 @@ def setup_vars(self):

# Process training data
if self.train_data is not None or self.train_loader is not None:
# Calculate the number of update steps during training given the local_update_steps
num_train_batch, num_train_batch_last_epoch, num_train_epoch, num_total_train_batch = self.pre_calculate_batch_epoch_num(
self.cfg.train.local_update_steps)
# Calculate the number of update steps during training given the
# local_update_steps
num_train_batch, num_train_batch_last_epoch, num_train_epoch, \
num_total_train_batch = self.pre_calculate_batch_epoch_num(
self.cfg.train.local_update_steps)

self.num_train_epoch = num_train_epoch
self.num_train_batch = num_train_batch
@@ -148,9 +151,9 @@ def setup_vars(self):
self.cfg.data.batch_size)))

def pre_calculate_batch_epoch_num(self, local_update_steps):
num_train_batch = self.num_train_data // self.cfg.data.batch_size + int(
not self.cfg.data.drop_last
and bool(self.num_train_data % self.cfg.data.batch_size))
num_train_batch = self.num_train_data // self.cfg.data.batch_size + \
int(not self.cfg.data.drop_last and bool(
self.num_train_data % self.cfg.data.batch_size))
if self.cfg.train.batch_or_epoch == "epoch":
num_train_epoch = local_update_steps
num_train_batch_last_epoch = num_train_batch
@@ -199,18 +202,19 @@ def reset_used_dataset(self):
len(self.cur_data_splits_used_by_routine) != 0 else None

def check_data_split(self, target_data_split_name, skip=False):
if self.get(
f"{target_data_split_name}_data") is None and self.get(
f"{target_data_split_name}_loader") is None:
if self.get(f"{target_data_split_name}_data") is None and self.get(
f"{target_data_split_name}_loader") is None:
if skip:
logger.warning(
f"No {target_data_split_name}_data or {target_data_split_name}_loader in the trainer, will skip evaluation"
f"If this is not the case you want, please check whether there is typo for the name"
)
f"No {target_data_split_name}_data or"
f" {target_data_split_name}_loader in the trainer, "
f"will skip evaluation"
f"If this is not the case you want, please check "
f"whether there is typo for the name")
return False
else:
raise ValueError(
f"No {target_data_split_name}_data or {target_data_split_name}_loader in the trainer"
)
f"No {target_data_split_name}_data or"
f" {target_data_split_name}_loader in the trainer")
else:
return True
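
As a plain-arithmetic illustration of the reformatted expression in pre_calculate_batch_epoch_num, a self-contained sketch with made-up sizes:

num_train_data = 103  # made-up number of local training samples
batch_size = 10       # made-up cfg.data.batch_size
drop_last = False     # made-up cfg.data.drop_last

num_train_batch = num_train_data // batch_size + \
    int(not drop_last and bool(num_train_data % batch_size))

# 103 // 10 = 10 full batches, plus one trailing partial batch that is kept
# because drop_last is False.
assert num_train_batch == 11
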
23 changes: 11 additions & 12 deletions federatedscope/core/trainers/torch_trainer.py
@@ -98,22 +98,20 @@ def register_default_hooks_train(self):
self.register_hook_in_train(self._hook_on_fit_end, "on_fit_end")

def register_default_hooks_ft(self):
self.register_hook_in_ft(self._hook_on_fit_start_init,
"on_fit_start")
self.register_hook_in_ft(
self._hook_on_fit_start_calculate_model_size, "on_fit_start")
self.register_hook_in_ft(self._hook_on_epoch_start,
"on_epoch_start")
self.register_hook_in_ft(self._hook_on_fit_start_init, "on_fit_start")
self.register_hook_in_ft(self._hook_on_fit_start_calculate_model_size,
"on_fit_start")
self.register_hook_in_ft(self._hook_on_epoch_start, "on_epoch_start")
self.register_hook_in_ft(self._hook_on_batch_start_init,
"on_batch_start")
"on_batch_start")
self.register_hook_in_ft(self._hook_on_batch_forward,
"on_batch_forward")
"on_batch_forward")
self.register_hook_in_ft(self._hook_on_batch_forward_regularizer,
"on_batch_forward")
"on_batch_forward")
self.register_hook_in_ft(self._hook_on_batch_forward_flop_count,
"on_batch_forward")
"on_batch_forward")
self.register_hook_in_ft(self._hook_on_batch_backward,
"on_batch_backward")
"on_batch_backward")
self.register_hook_in_ft(self._hook_on_batch_end, "on_batch_end")
self.register_hook_in_ft(self._hook_on_fit_end, "on_fit_end")

@@ -134,7 +132,8 @@ def _hook_on_fit_start_init(self, ctx):
ctx.model.to(ctx.device)

if ctx.cur_mode in [MODE.TRAIN, MODE.FINETUNE]:
# Initialize optimizer here to avoid the reuse of optimizers across different routines
# Initialize optimizer here to avoid the reuse of optimizers
# across different routines
ctx.optimizer = get_optimizer(ctx.model,
**ctx.cfg[ctx.cur_mode].optimizer)

9 changes: 5 additions & 4 deletions federatedscope/core/trainers/trainer_Ditto.py
@@ -24,10 +24,9 @@ def wrap_DittoTrainer(
init_Ditto_ctx(base_trainer)

# ---------------- action-level plug-in -----------------------
base_trainer.register_hook_in_train(
new_hook=_hook_on_fit_start_del_opt,
trigger='on_fit_start',
insert_pos=-1)
base_trainer.register_hook_in_train(new_hook=_hook_on_fit_start_del_opt,
trigger='on_fit_start',
insert_pos=-1)
base_trainer.register_hook_in_train(
new_hook=hook_on_fit_start_set_regularized_para,
trigger="on_fit_start",
@@ -114,10 +113,12 @@ def hook_on_fit_start_set_regularized_para(ctx):
ctx.optimizer_for_local_model.set_compared_para_group(
compared_global_model_para)


def _hook_on_fit_start_del_opt(ctx):
# remove the unnecessary optimizer
del ctx.optimizer


def _hook_on_batch_end_flop_count(ctx):
# besides the normal forward flops, the regularization adds the cost of
# number of model parameters
9 changes: 6 additions & 3 deletions federatedscope/core/worker/client.py
@@ -308,9 +308,11 @@ def callback_funcs_for_join_in_info(self, message: Message):
for requirement in requirements:
if requirement.lower() == 'num_sample':
if self._cfg.train.batch_or_epoch == 'batch':
num_sample = self._cfg.train.local_update_steps * self._cfg.data.batch_size
num_sample = self._cfg.train.local_update_steps * \
self._cfg.data.batch_size
else:
num_sample = self._cfg.train.local_update_steps * self.trainer.ctx.num_train_batch
num_sample = self._cfg.train.local_update_steps * \
self.trainer.ctx.num_train_batch
join_in_info['num_sample'] = num_sample
else:
raise ValueError(
@@ -364,7 +366,8 @@ def callback_funcs_for_evaluate(self, message: Message):
self._monitor.format_eval_res(eval_metrics,
rnd=self.state,
role='Client #{}'.format(
self.ID)))
self.ID),
return_raw=True))

metrics.update(**eval_metrics)

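And for the num_sample branch in callback_funcs_for_join_in_info, a standalone sketch with made-up numbers covering both cases of the wrapped expressions:

local_update_steps = 5  # made-up cfg.train.local_update_steps
batch_size = 32         # made-up cfg.data.batch_size
num_train_batch = 12    # made-up trainer.ctx.num_train_batch

for batch_or_epoch in ("batch", "epoch"):
    if batch_or_epoch == "batch":
        # Each local update step consumes one batch of samples.
        num_sample = local_update_steps * batch_size
    else:
        # Each local update step covers all num_train_batch local batches.
        num_sample = local_update_steps * num_train_batch
    print(batch_or_epoch, num_sample)  # -> batch 160, then epoch 60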
