From 75dd470ad854aec655b429db7e30c8e1b92760e3 Mon Sep 17 00:00:00 2001 From: Weirui Kuang <39145382+rayrayraykk@users.noreply.github.com> Date: Fri, 28 Oct 2022 17:05:27 +0800 Subject: [PATCH 01/21] merge action and refactor --- .github/workflows/codecov.yml | 46 ++ .github/workflows/codeql.yml | 39 ++ .github/workflows/pre-commit.yml | 3 +- .github/workflows/test_distribute.yml | 42 ++ README.md | 2 +- doc/source/attack.rst | 6 +- doc/source/autotune.rst | 23 + doc/source/core.rst | 39 +- doc/source/cv.rst | 4 + doc/source/gfl.rst | 4 + doc/source/mf.rst | 4 + doc/source/nlp.rst | 4 + ...xtra_dependencies_torch1.10-application.sh | 2 +- .../attack/auxiliary/poisoning_data.py | 8 +- federatedscope/attack/auxiliary/utils.py | 2 +- federatedscope/autotune/algos.py | 18 +- federatedscope/autotune/fedex/client.py | 3 +- federatedscope/autotune/fedex/server.py | 6 +- federatedscope/autotune/utils.py | 10 +- federatedscope/contrib/metrics/example.py | 3 +- federatedscope/contrib/metrics/poison_acc.py | 3 +- .../contrib/trainer/torch_example.py | 11 +- .../core/auxiliaries/metric_builder.py | 4 +- .../core/auxiliaries/model_builder.py | 8 +- .../core/auxiliaries/trainer_builder.py | 9 +- federatedscope/core/auxiliaries/utils.py | 286 +++-------- federatedscope/core/configs/README.md | 1 - federatedscope/core/configs/cfg_training.py | 1 - federatedscope/core/data/base_data.py | 5 +- federatedscope/core/data/base_translator.py | 2 +- federatedscope/core/data/utils.py | 85 +++- federatedscope/core/fed_runner.py | 457 ++++++++++++++++- federatedscope/core/monitors/early_stopper.py | 12 +- .../core/monitors/metric_calculator.py | 32 +- federatedscope/core/monitors/monitor.py | 51 +- federatedscope/core/trainers/README.md | 461 ++++++++++++++++++ federatedscope/core/trainers/__init__.py | 3 +- federatedscope/core/trainers/base_trainer.py | 10 +- federatedscope/core/trainers/context.py | 237 ++++++--- .../core/{auxiliaries => trainers}/enums.py | 0 federatedscope/core/trainers/tf_trainer.py | 89 +++- federatedscope/core/trainers/torch_trainer.py | 177 +++++-- federatedscope/core/trainers/trainer.py | 92 ++-- federatedscope/core/trainers/trainer_Ditto.py | 131 ++++- federatedscope/core/trainers/trainer_FedEM.py | 101 +++- .../core/trainers/trainer_fedprox.py | 35 +- .../core/trainers/trainer_multi_model.py | 33 +- federatedscope/core/trainers/trainer_nbafl.py | 54 +- .../core/trainers/trainer_pFedMe.py | 88 +++- federatedscope/core/trainers/utils.py | 83 ++++ federatedscope/core/workers/client.py | 15 +- federatedscope/core/workers/server.py | 26 +- federatedscope/cv/dataset/leaf_cv.py | 2 +- .../fedavg_gin_minibatch_on_cikmcup.yaml | 1 - .../isolated_gin_minibatch_on_cikmcup.yaml | 1 - federatedscope/gfl/fedsageplus/worker.py | 6 +- federatedscope/gfl/flitplus/trainer.py | 2 +- federatedscope/gfl/gcflplus/worker.py | 6 +- federatedscope/gfl/trainer/graphtrainer.py | 2 +- federatedscope/gfl/trainer/linktrainer.py | 4 +- federatedscope/gfl/trainer/nodetrainer.py | 5 +- federatedscope/main.py | 12 +- federatedscope/mf/trainer/trainer.py | 3 +- federatedscope/mf/trainer/trainer_sgdmf.py | 2 +- federatedscope/nlp/dataset/leaf_nlp.py | 2 +- federatedscope/nlp/dataset/leaf_synthetic.py | 2 +- federatedscope/nlp/dataset/leaf_twitter.py | 2 +- federatedscope/nlp/trainer/trainer.py | 4 +- federatedscope/tabular/dataloader/toy.py | 11 +- .../vertical_fl/worker/vertical_server.py | 18 +- .../distributed_client_1.yaml | 2 +- .../distributed_client_2.yaml | 2 +- .../distributed_client_3.yaml | 2 +- 
.../distributed_server.yaml | 2 +- scripts/distributed_scripts/gen_data.py | 2 +- .../distributed_scripts/run_distributed_lr.sh | 3 - .../example_configs/femnist_global_train.yaml | 11 +- setup.py | 16 +- tests/test_CRA_gan_attack.py | 10 +- tests/test_MIA_gradient_ascent.py | 10 +- tests/test_PIA_toy.py | 10 +- tests/test_asyn_cifar10.py | 38 +- tests/test_backdoor_attack.py | 10 +- tests/test_ditto.py | 10 +- tests/test_efficient_simulation.py | 18 +- tests/test_external_dataset.py | 26 +- tests/test_fedem.py | 10 +- tests/test_fedopt.py | 10 +- tests/test_fedprox.py | 10 +- tests/test_fedsageplus.py | 10 +- tests/test_femnist.py | 10 +- tests/test_finetune_lr.py | 10 +- tests/test_global_train_lr.py | 10 +- tests/test_graph_node_trainer.py | 10 +- tests/test_local_train_lr.py | 10 +- tests/test_mf.py | 10 +- tests/test_nbafl.py | 10 +- tests/test_optimizer.py | 10 +- tests/test_pfedme.py | 10 +- tests/test_rec_IG_opt_attack.py | 10 +- tests/test_rec_opt_attack.py | 10 +- tests/test_toy_lr.py | 18 +- tests/test_unseen_clients_lr.py | 10 +- tests/test_vertical_fl.py | 10 +- 104 files changed, 2473 insertions(+), 842 deletions(-) create mode 100644 .github/workflows/codecov.yml create mode 100644 .github/workflows/codeql.yml create mode 100644 .github/workflows/test_distribute.yml create mode 100644 federatedscope/core/trainers/README.md rename federatedscope/core/{auxiliaries => trainers}/enums.py (100%) create mode 100644 federatedscope/core/trainers/utils.py diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml new file mode 100644 index 000000000..37bb48dbb --- /dev/null +++ b/.github/workflows/codecov.yml @@ -0,0 +1,46 @@ +name: Codecov UnitTests + +on: [push, pull_request] + +jobs: + run: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: ['3.9'] + torch-version: ['1.10.1'] + torchvision-version: ['0.11.2'] + torchaudio-version: ['0.10.1'] + torchtext-version: ['0.11.1'] + env: + OS: ${{ matrix.os }} + PYTHON: '3.9' + steps: + - uses: actions/checkout@master + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@master + with: + python-version: ${{ matrix.python-version }} + - name: Install PyTorch ${{ matrix.torch-version }}+cpu + run: | + pip install numpy typing-extensions dataclasses + pip install torch==${{ matrix.torch-version}}+cpu torchvision==${{matrix.torchvision-version}}+cpu torchaudio==${{matrix.torchaudio-version}} torchtext==${{matrix.torchtext-version}} -f https://download.pytorch.org/whl/torch_stable.html + - name: Install FS + run: | + pip install -e .[test] + - name: Generate coverage report + run: | + pytest --cov=./ --cov-report=xml + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + token: ${{ secrets.CODECOV_TOKEN }} + directory: ./coverage/reports/ + env_vars: OS,PYTHON + fail_ci_if_error: true + files: ./coverage1.xml,./coverage2.xml + flags: unittests + name: codecov-umbrella + path_to_write_report: ./coverage/codecov_report.txt + verbose: true \ No newline at end of file diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 000000000..8f142525e --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,39 @@ +name: CodeQL (Code Scanning) + +on: + pull_request: + branches: [ "master" ] + schedule: + - cron: '0 8 * * *' + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + strategy: + fail-fast: false + matrix: + 
language: [ 'python' ] + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: ${{ matrix.language }} + + # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v2 + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v2 + with: + category: "/language:${{matrix.language}}" diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 2d6edc3b6..aad4a946f 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -1,4 +1,4 @@ -name: pre-commit +name: Pre-commit (Required) on: [push, pull_request] @@ -6,6 +6,7 @@ jobs: run: runs-on: ${{ matrix.os }} strategy: + fail-fast: True matrix: os: [ubuntu-latest] env: diff --git a/.github/workflows/test_distribute.yml b/.github/workflows/test_distribute.yml new file mode 100644 index 000000000..957c241fc --- /dev/null +++ b/.github/workflows/test_distribute.yml @@ -0,0 +1,42 @@ +name: UnitTests for Distributed Mode + +on: [push, pull_request] + +jobs: + run: + runs-on: ${{ matrix.os }} + timeout-minutes: 10 + strategy: + matrix: + os: [ubuntu-latest] + python-version: ['3.9'] + torch-version: ['1.10.1'] + torchvision-version: ['0.11.2'] + torchaudio-version: ['0.10.1'] + env: + OS: ${{ matrix.os }} + PYTHON: '3.9' + steps: + - uses: actions/checkout@master + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@master + with: + python-version: ${{ matrix.python-version }} + - name: Install PyTorch ${{ matrix.torch-version }}+cpu + run: | + pip install numpy typing-extensions dataclasses + pip install torch==${{ matrix.torch-version}}+cpu torchvision==${{matrix.torchvision-version}}+cpu torchaudio==${{matrix.torchaudio-version}}+cpu -f https://download.pytorch.org/whl/torch_stable.html + - name: Install FS + run: | + pip install -e .[test] + - name: Test Distributed (LR on toy) + run: | + python scripts/distributed_scripts/gen_data.py + python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_server.yaml & + sleep 2 + python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_1.yaml & + sleep 2 + python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_2.yaml & + sleep 2 + python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_3.yaml + [ $? -eq 1 ] && exit 1 || echo "Passed" \ No newline at end of file diff --git a/README.md b/README.md index 273b8e694..14f1624c0 100644 --- a/README.md +++ b/README.md @@ -104,7 +104,7 @@ conda activate fs If your backend is torch, please install torch in advance ([torch-get-started](https://pytorch.org/get-started/locally/)). 
For example, if your cuda version is 11.3 please execute the following command: ```bash -conda install -y pytorch=1.10.1 torchvision=0.11.2 torchaudio=0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge +conda install -y pytorch=1.10.1 torchvision=0.11.2 torchaudio=0.10.1 torchtext=0.11.1 cudatoolkit=11.3 -c pytorch -c conda-forge ``` For users with Apple M1 chips: diff --git a/doc/source/attack.rst b/doc/source/attack.rst index 730a5019c..fd5f0715e 100644 --- a/doc/source/attack.rst +++ b/doc/source/attack.rst @@ -6,6 +6,7 @@ federatedscope.attack.privacy_attacks .. automodule:: federatedscope.attack.privacy_attacks :members: + :private-members: federatedscope.attack.worker_as_attacker @@ -13,20 +14,21 @@ federatedscope.attack.worker_as_attacker .. automodule:: federatedscope.attack.worker_as_attacker :members: + :private-members: federatedscope.attack.auxiliary -------------------------------- .. automodule:: federatedscope.attack.auxiliary :members: - - + :private-members: federatedscope.attack.trainer --------------------------------- .. automodule:: federatedscope.attack.trainer :members: + :private-members: diff --git a/doc/source/autotune.rst b/doc/source/autotune.rst index e8e1b153c..b74ec10cb 100644 --- a/doc/source/autotune.rst +++ b/doc/source/autotune.rst @@ -6,6 +6,7 @@ federatedscope.autotune.choice_types .. automodule:: federatedscope.autotune.choice_types :members: + :private-members: federatedscope.autotune.algos ----------------------- @@ -13,3 +14,25 @@ federatedscope.autotune.algos .. automodule:: federatedscope.autotune.algos :show-inheritance: :members: + :private-members: + +federatedscope.autotune.hpbandster +----------------------- + +.. automodule:: federatedscope.autotune.hpbandster + :members: + :private-members: + +federatedscope.autotune.smac +----------------------- + +.. automodule:: federatedscope.autotune.smac + :members: + :private-members: + +federatedscope.autotune.utils +----------------------- + +.. automodule:: federatedscope.autotune.utils + :members: + :private-members: diff --git a/doc/source/core.rst b/doc/source/core.rst index 11b9c87e5..9c7c3fae8 100644 --- a/doc/source/core.rst +++ b/doc/source/core.rst @@ -1,33 +1,52 @@ Core Module References ======================= -federatedscope.core.configs +federatedscope.core.fed_runner ----------------------- -.. automodule:: federatedscope.core.configs +.. automodule:: federatedscope.core.fed_runner :members: + :private-members: +federatedscope.core.workers +----------------------- -federatedscope.core.monitors +.. automodule:: federatedscope.core.worker + :members: + :private-members: + +federatedscope.core.trainers ----------------------- -.. automodule:: federatedscope.core.monitors +.. automodule:: federatedscope.core.trainers :members: + :private-members: -federatedscope.core.fed_runner +federatedscope.core.data ----------------------- -.. automodule:: federatedscope.core.fed_runner +.. automodule:: federatedscope.core.data :members: + :private-members: -federatedscope.core.worker +federatedscope.core.splitters ----------------------- -.. automodule:: federatedscope.core.worker +.. automodule:: federatedscope.core.splitters :members: + :private-members: -federatedscope.core.trainers +federatedscope.core.configs ----------------------- -.. automodule:: federatedscope.core.trainers +.. automodule:: federatedscope.core.configs + :members: + :private-members: + + +federatedscope.core.monitors +----------------------- + +.. 
automodule:: federatedscope.core.monitors :members: + :private-members: diff --git a/doc/source/cv.rst b/doc/source/cv.rst index ef3700bcc..12d1c7e31 100644 --- a/doc/source/cv.rst +++ b/doc/source/cv.rst @@ -6,21 +6,25 @@ federatedscope.cv.dataset .. automodule:: federatedscope.cv.dataset :members: + :private-members: federatedscope.cv.dataloader ----------------------- .. automodule:: federatedscope.cv.dataloader :members: + :private-members: federatedscope.cv.model ----------------------- .. automodule:: federatedscope.cv.model :members: + :private-members: federatedscope.cv.trainer ----------------------- .. automodule:: federatedscope.cv.trainer :members: + :private-members: diff --git a/doc/source/gfl.rst b/doc/source/gfl.rst index 3e7d07e8a..ed8482b4d 100644 --- a/doc/source/gfl.rst +++ b/doc/source/gfl.rst @@ -6,21 +6,25 @@ federatedscope.gfl.dataset .. automodule:: federatedscope.gfl.dataset :members: + :private-members: federatedscope.gfl.dataloader ----------------------- .. automodule:: federatedscope.gfl.dataloader :members: + :private-members: federatedscope.gfl.model ----------------------- .. automodule:: federatedscope.gfl.model :members: + :private-members: federatedscope.gfl.trainer ----------------------- .. automodule:: federatedscope.gfl.trainer :members: + :private-members: diff --git a/doc/source/mf.rst b/doc/source/mf.rst index 94485783a..01f63ea0b 100644 --- a/doc/source/mf.rst +++ b/doc/source/mf.rst @@ -6,21 +6,25 @@ federatedscope.mf.dataset .. automodule:: federatedscope.mf.dataset :members: + :private-members: federatedscope.mf.model ----------------------- .. automodule:: federatedscope.mf.model :members: + :private-members: federatedscope.mf.dataloader ----------------------- .. automodule:: federatedscope.mf.dataloader :members: + :private-members: federatedscope.mf.trainer ----------------------- .. automodule:: federatedscope.mf.trainer :members: + :private-members: diff --git a/doc/source/nlp.rst b/doc/source/nlp.rst index 26c87fdcb..938505421 100644 --- a/doc/source/nlp.rst +++ b/doc/source/nlp.rst @@ -6,21 +6,25 @@ federatedscope.nlp.dataset .. automodule:: federatedscope.nlp.dataset :members: + :private-members: federatedscope.nlp.dataloader ----------------------- .. automodule:: federatedscope.nlp.dataloader :members: + :private-members: federatedscope.nlp.model ----------------------- .. automodule:: federatedscope.nlp.model :members: + :private-members: federatedscope.nlp.trainer ----------------------- .. 
automodule:: federatedscope.nlp.trainer :members: + :private-members: diff --git a/environment/extra_dependencies_torch1.10-application.sh b/environment/extra_dependencies_torch1.10-application.sh index 4657afe7e..3e7963f9b 100644 --- a/environment/extra_dependencies_torch1.10-application.sh +++ b/environment/extra_dependencies_torch1.10-application.sh @@ -8,7 +8,7 @@ conda install -y nltk # Speech and NLP conda install -y sentencepiece textgrid typeguard -c conda-forge conda install -y transformers==4.16.2 tokenizers==0.10.3 datasets -c huggingface -c conda-forge -conda install -y torchtext -c pytorch +conda install -y torchtext==0.9.0 -c pytorch # Tabular conda install -y openml==0.12.2 diff --git a/federatedscope/attack/auxiliary/poisoning_data.py b/federatedscope/attack/auxiliary/poisoning_data.py index 0d0a9581e..acab87cae 100644 --- a/federatedscope/attack/auxiliary/poisoning_data.py +++ b/federatedscope/attack/auxiliary/poisoning_data.py @@ -1,16 +1,12 @@ -from re import M import torch from PIL import Image import numpy as np -from torchvision.datasets import MNIST, EMNIST, CIFAR10 -from torchvision.datasets import DatasetFolder from torchvision import transforms from federatedscope.core.auxiliaries.transform_builder import get_transform from federatedscope.attack.auxiliary.backdoor_utils import selectTrigger -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader from federatedscope.attack.auxiliary.backdoor_utils import normalize -from federatedscope.core.auxiliaries.enums import MODE -import matplotlib +from federatedscope.core.trainers.enums import MODE import pickle import logging import os diff --git a/federatedscope/attack/auxiliary/utils.py b/federatedscope/attack/auxiliary/utils.py index 932e0f002..7fc7dab29 100644 --- a/federatedscope/attack/auxiliary/utils.py +++ b/federatedscope/attack/auxiliary/utils.py @@ -284,7 +284,7 @@ def get_passive_PIA_auxiliary_dataset(dataset_name): def _generate_data(instance_num=1000, feature_num=5, save_data=False): """ - Generate data in FedRunner format + Generate data in Runner format Args: instance_num: feature_num: diff --git a/federatedscope/autotune/algos.py b/federatedscope/autotune/algos.py index b7f66f156..27469d136 100644 --- a/federatedscope/autotune/algos.py +++ b/federatedscope/autotune/algos.py @@ -13,7 +13,7 @@ from federatedscope.core.auxiliaries.data_builder import get_data from federatedscope.core.auxiliaries.worker_builder import get_client_cls, \ get_server_cls -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.autotune.utils import parse_search_space, \ config2cmdargs, config2str, summarize_hpo_results @@ -26,10 +26,10 @@ def make_trial(trial_cfg): trial_cfg.merge_from_other_cfg(modified_config) trial_cfg.freeze() # TODO: enable client-wise configuration - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(trial_cfg), - client_class=get_client_cls(trial_cfg), - config=trial_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(trial_cfg), + client_class=get_client_cls(trial_cfg), + config=trial_cfg.clone()) results = Fed_runner.run() key1, key2 = trial_cfg.hpo.metric.split('.') return results[key1][key2] @@ -53,10 +53,10 @@ def run(self): self._trial_cfg.merge_from_other_cfg(modified_config) self._trial_cfg.freeze() # TODO: enable client-wise configuration - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(self._trial_cfg), - 
client_class=get_client_cls(self._trial_cfg), - config=self._trial_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(self._trial_cfg), + client_class=get_client_cls(self._trial_cfg), + config=self._trial_cfg.clone()) results = Fed_runner.run() key1, key2 = self._trial_cfg.hpo.metric.split('.') self._returns['perf'] = results[key1][key2] diff --git a/federatedscope/autotune/fedex/client.py b/federatedscope/autotune/fedex/client.py index 6b3e5c25e..a20b6682f 100644 --- a/federatedscope/autotune/fedex/client.py +++ b/federatedscope/autotune/fedex/client.py @@ -27,8 +27,7 @@ def _apply_hyperparams(self, hyperparams): self._cfg.defrost() self._cfg.merge_from_list(cmd_args, check_cfg=False) self._cfg.freeze(inform=False, check_cfg=False) - - self.trainer.ctx.setup_vars() + self.trainer.cfg = self._cfg def callback_funcs_for_model_para(self, message: Message): round, sender, content = message.state, message.sender, message.content diff --git a/federatedscope/autotune/fedex/server.py b/federatedscope/autotune/fedex/server.py index fac5e6488..f693d95b0 100644 --- a/federatedscope/autotune/fedex/server.py +++ b/federatedscope/autotune/fedex/server.py @@ -10,7 +10,7 @@ from federatedscope.core.message import Message from federatedscope.core.workers import Server -from federatedscope.core.auxiliaries.utils import merge_dict +from federatedscope.core.auxiliaries.utils import merge_dict_of_results logger = logging.getLogger(__name__) @@ -377,8 +377,8 @@ def check_and_move_on(self, else: # in the evaluation process # Get all the message & aggregate formatted_eval_res = self.merge_eval_results_from_all_clients() - self.history_results = merge_dict(self.history_results, - formatted_eval_res) + self.history_results = merge_dict_of_results( + self.history_results, formatted_eval_res) self.check_and_save() else: move_on_flag = False diff --git a/federatedscope/autotune/utils.py b/federatedscope/autotune/utils.py index cd6523a42..959d9112e 100644 --- a/federatedscope/autotune/utils.py +++ b/federatedscope/autotune/utils.py @@ -139,7 +139,7 @@ def eval_in_fs(cfg, config, budget): from federatedscope.core.auxiliaries.data_builder import get_data from federatedscope.core.auxiliaries.worker_builder import \ get_client_cls, get_server_cls - from federatedscope.core.fed_runner import FedRunner + from federatedscope.core.fed_runner import get_runner from federatedscope.autotune.utils import config2cmdargs from os.path import join as osp @@ -167,10 +167,10 @@ def eval_in_fs(cfg, config, budget): data, modified_config = get_data(config=trial_cfg.clone()) trial_cfg.merge_from_other_cfg(modified_config) trial_cfg.freeze() - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(trial_cfg), - client_class=get_client_cls(trial_cfg), - config=trial_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(trial_cfg), + client_class=get_client_cls(trial_cfg), + config=trial_cfg.clone()) results = Fed_runner.run() key1, key2 = trial_cfg.hpo.metric.split('.') return results[key1][key2] diff --git a/federatedscope/contrib/metrics/example.py b/federatedscope/contrib/metrics/example.py index f162606a9..67f4252d8 100644 --- a/federatedscope/contrib/metrics/example.py +++ b/federatedscope/contrib/metrics/example.py @@ -9,8 +9,9 @@ def MyMetric(ctx, **kwargs): def call_my_metric(types): if METRIC_NAME in types: + the_larger_the_better = True metric_builder = MyMetric - return METRIC_NAME, metric_builder + return METRIC_NAME, metric_builder, the_larger_the_better 
register_metric(METRIC_NAME, call_my_metric) diff --git a/federatedscope/contrib/metrics/poison_acc.py b/federatedscope/contrib/metrics/poison_acc.py index 0093e9e5d..d919fdca7 100644 --- a/federatedscope/contrib/metrics/poison_acc.py +++ b/federatedscope/contrib/metrics/poison_acc.py @@ -25,7 +25,8 @@ def load_poison_metrics(ctx, y_true, y_pred, y_prob, **kwargs): def call_poison_metric(types): if 'poison_attack_acc' in types: - return 'poison_attack_acc', load_poison_metrics + the_larger_the_better = True + return 'poison_attack_acc', load_poison_metrics, the_larger_the_better register_metric('poison_attack_acc', call_poison_metric) diff --git a/federatedscope/contrib/trainer/torch_example.py b/federatedscope/contrib/trainer/torch_example.py index 18cd5d7a0..09f21b997 100644 --- a/federatedscope/contrib/trainer/torch_example.py +++ b/federatedscope/contrib/trainer/torch_example.py @@ -1,4 +1,3 @@ -import inspect from federatedscope.register import register_trainer from federatedscope.core.trainers import BaseTrainer @@ -84,21 +83,15 @@ def evaluate(self, target_data_split_name='test'): def update(self, model_parameters, strict=False): self.model.load_state_dict(model_parameters, strict) + return self.get_model_para() def get_model_para(self): return self.model.cpu().state_dict() - def print_trainer_meta_info(self): - sign = inspect.signature(self.__init__).parameters.values() - meta_info = tuple([(val.name, getattr(self, val.name)) - for val in sign]) - return f'{self.__class__.__name__}{meta_info}' - def call_my_torch_trainer(trainer_type): if trainer_type == 'mytorchtrainer': - trainer_builder = MyTorchTrainer - return trainer_builder + return MyTorchTrainer register_trainer('mytorchtrainer', call_my_torch_trainer) diff --git a/federatedscope/core/auxiliaries/metric_builder.py b/federatedscope/core/auxiliaries/metric_builder.py index ddbcbed0b..0d825754a 100644 --- a/federatedscope/core/auxiliaries/metric_builder.py +++ b/federatedscope/core/auxiliaries/metric_builder.py @@ -16,6 +16,6 @@ def get_metric(types): for func in register.metric_dict.values(): res = func(types) if res is not None: - name, metric = res - metrics[name] = metric + name, metric, the_larger_the_better = res + metrics[name] = metric, the_larger_the_better return metrics diff --git a/federatedscope/core/auxiliaries/model_builder.py b/federatedscope/core/auxiliaries/model_builder.py index 522159153..4bec0c2ff 100644 --- a/federatedscope/core/auxiliaries/model_builder.py +++ b/federatedscope/core/auxiliaries/model_builder.py @@ -1,8 +1,7 @@ import logging -import numpy as np - import federatedscope.register as register +from federatedscope.core.data.wrap_dataset import WrapDataset logger = logging.getLogger(__name__) @@ -39,10 +38,10 @@ def get_shape_from_data(data, model_config, backend='torch'): if model_config.task.startswith('graph'): # graph-level task data_representative = next(iter(data['train'])) - return (data_representative.x.shape, num_label, num_edge_features) + return data_representative.x.shape, num_label, num_edge_features else: # node/link-level task - return (data['data'].x.shape, num_label, num_edge_features) + return data['data'].x.shape, num_label, num_edge_features if isinstance(data, dict): keys = list(data.keys()) @@ -58,7 +57,6 @@ def get_shape_from_data(data, model_config, backend='torch'): key_representative = keys[0] logger.warning(f'We chose the key {key_representative} as the ' f'representative key to extract data shape.') - data_representative = data[key_representative] else: # Handle the 
data with non-dict format diff --git a/federatedscope/core/auxiliaries/trainer_builder.py b/federatedscope/core/auxiliaries/trainer_builder.py index 41d3ffe89..599291762 100644 --- a/federatedscope/core/auxiliaries/trainer_builder.py +++ b/federatedscope/core/auxiliaries/trainer_builder.py @@ -2,6 +2,7 @@ import importlib import federatedscope.register as register +from federatedscope.core.trainers import Trainer logger = logging.getLogger(__name__) @@ -45,8 +46,7 @@ def get_trainer(model=None, only_for_eval=only_for_eval, monitor=monitor) elif config.backend == 'tensorflow': - from federatedscope.core.trainers.tf_trainer import \ - GeneralTFTrainer + from federatedscope.core.trainers import GeneralTFTrainer trainer = GeneralTFTrainer(model=model, data=data, device=device, @@ -108,6 +108,11 @@ def get_trainer(model=None, raise ValueError('Trainer {} is not provided'.format( config.trainer.type)) + if not isinstance(trainer, Trainer): + logger.warning(f'When using {trainer}, trainer plug-in cannot be ' + f'enabled. Please use {Trainer} instead.') + return trainer + # differential privacy plug-in if config.nbafl.use: from federatedscope.core.trainers import wrap_nbafl_trainer diff --git a/federatedscope/core/auxiliaries/utils.py b/federatedscope/core/auxiliaries/utils.py index 190a96102..0a0d323ec 100644 --- a/federatedscope/core/auxiliaries/utils.py +++ b/federatedscope/core/auxiliaries/utils.py @@ -1,110 +1,51 @@ -import collections -import json import logging import math import os import random import signal -import ssl -import urllib.request -from os import path as osp import pickle import numpy as np -# Blind torch try: import torch - import torchvision - import torch.distributions as distributions except ImportError: torch = None - torchvision = None - distributions = None + +try: + import tensorflow as tf +except ImportError: + tf = None logger = logging.getLogger(__name__) -def setup_seed(seed): - np.random.seed(seed) - random.seed(seed) - if torch is not None: - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - else: - import tensorflow as tf - tf.set_random_seed(seed) +# ****** Worker-related utils ****** +class Timeout(object): + def __init__(self, seconds, max_failure=5): + self.seconds = seconds + self.max_failure = max_failure + def __enter__(self): + def signal_handler(signum, frame): + raise TimeoutError() -def get_dataset(type, root, transform, target_transform, download=True): - if isinstance(type, str): - if hasattr(torchvision.datasets, type): - return getattr(torchvision.datasets, - type)(root=root, - transform=transform, - target_transform=target_transform, - download=download) - else: - raise NotImplementedError('Dataset {} not implement'.format(type)) - else: - raise TypeError() - - -def save_local_data(dir_path, - train_data=None, - train_targets=None, - test_data=None, - test_targets=None, - val_data=None, - val_targets=None): - r""" - https://github.com/omarfoq/FedEM/blob/main/data/femnist/generate_data.py - - save (`train_data`, `train_targets`) in {dir_path}/train.pt, - (`val_data`, `val_targets`) in {dir_path}/val.pt - and (`test_data`, `test_targets`) in {dir_path}/test.pt - :param dir_path: - :param train_data: - :param train_targets: - :param test_data: - :param test_targets: - :param val_data: - :param val_targets - """ - if (train_data is not None) and (train_targets is not None): - torch.save((train_data, train_targets), osp.join(dir_path, "train.pt")) + if self.seconds > 0: + 
signal.signal(signal.SIGALRM, signal_handler) + signal.alarm(self.seconds) + return self - if (test_data is not None) and (test_targets is not None): - torch.save((test_data, test_targets), osp.join(dir_path, "test.pt")) + def __exit__(self, exc_type, exc_value, traceback): + signal.alarm(0) - if (val_data is not None) and (val_targets is not None): - torch.save((val_data, val_targets), osp.join(dir_path, "val.pt")) + def reset(self): + signal.alarm(self.seconds) + def block(self): + signal.alarm(0) -def filter_by_specified_keywords(param_name, filter_keywords): - ''' - Arguments: - param_name (str): parameter name. - Returns: - preserve (bool): whether to preserve this parameter. - ''' - preserve = True - for kw in filter_keywords: - if kw in param_name: - preserve = False - break - return preserve - - -def get_random(type, sample_shape, params, device): - if not hasattr(distributions, type): - raise NotImplementedError("Distribution {} is not implemented, " - "please refer to ```torch.distributions```" - "(https://pytorch.org/docs/stable/ " - "distributions.html).".format(type)) - generator = getattr(distributions, type)(**params) - return generator.sample(sample_shape=sample_shape).to(device) + def exceed_max_failure(self, num_failure): + return num_failure > self.max_failure def batch_iter(data, batch_size=64, shuffled=True): @@ -124,73 +65,33 @@ def batch_iter(data, batch_size=64, shuffled=True): yield {'x': data_x[sample_index], 'y': data_y[sample_index]} -def merge_dict(dict1, dict2): - # Merge results for history +def merge_dict_of_results(dict1, dict2): + """ + Merge two ``dict`` according to their keys, and concatenate their value. + + Args: + dict1: ``dict`` to be merged + dict2: ``dict`` to be merged + + Returns: + dict1: Merged ``dict``. + + """ for key, value in dict2.items(): if key not in dict1: if isinstance(value, dict): - dict1[key] = merge_dict({}, value) + dict1[key] = merge_dict_of_results({}, value) else: dict1[key] = [value] else: if isinstance(value, dict): - merge_dict(dict1[key], value) + merge_dict_of_results(dict1[key], value) else: dict1[key].append(value) return dict1 -def download_url(url: str, folder='folder'): - r"""Downloads the content of an url to a folder. - - Modified from `https://github.com/pyg-team/pytorch_geometric/blob/master - /torch_geometric/data/download.py` - - Args: - url (string): The url of target file. - folder (string): The target folder. - - Returns: - path (string): File path of downloaded files. - """ - - file = url.rpartition('/')[2] - file = file if file[0] == '?' 
else file.split('?')[0] - path = osp.join(folder, file) - if osp.exists(path): - logger.info(f'File {file} exists, use existing file.') - return path - - logger.info(f'Downloading {url}') - os.makedirs(folder, exist_ok=True) - ctx = ssl._create_unverified_context() - data = urllib.request.urlopen(url, context=ctx) - with open(path, 'wb') as f: - f.write(data.read()) - - return path - - -def move_to(obj, device): - import torch - if torch.is_tensor(obj): - return obj.to(device) - elif isinstance(obj, dict): - res = {} - for k, v in obj.items(): - res[k] = move_to(v, device) - return res - elif isinstance(obj, list): - res = [] - for v in obj: - res.append(move_to(v, device)) - return res - else: - raise TypeError("Invalid type for move_to") - - def param2tensor(param): - import torch if isinstance(param, list): param = torch.FloatTensor(param) elif isinstance(param, int): @@ -200,63 +101,10 @@ def param2tensor(param): return param -class Timeout(object): - def __init__(self, seconds, max_failure=5): - self.seconds = seconds - self.max_failure = max_failure - - def __enter__(self): - def signal_handler(signum, frame): - raise TimeoutError() - - if self.seconds > 0: - signal.signal(signal.SIGALRM, signal_handler) - signal.alarm(self.seconds) - return self - - def __exit__(self, exc_type, exc_value, traceback): - signal.alarm(0) - - def reset(self): - signal.alarm(self.seconds) - - def block(self): - signal.alarm(0) - - def exceed_max_failure(self, num_failure): - return num_failure > self.max_failure - - -def format_log_hooks(hooks_set): - def format_dict(target_dict): - print_dict = collections.defaultdict(list) - for k, v in target_dict.items(): - for element in v: - print_dict[k].append(element.__name__) - return print_dict - - if isinstance(hooks_set, list): - print_obj = [format_dict(_) for _ in hooks_set] - elif isinstance(hooks_set, dict): - print_obj = format_dict(hooks_set) - return json.dumps(print_obj, indent=2).replace('\n', '\n\t') - - -def get_resource_info(filename): - if filename is None or not os.path.exists(filename): - logger.info('The device information file is not provided') - return None - - # Users can develop this loading function according to resource_info_file - # As an example, we use the device_info provided by FedScale (FedScale: - # Benchmarking Model and System Performance of Federated Learning - # at Scale), which can be downloaded from - # https://github.com/SymbioticLab/FedScale/blob/master/benchmark/dataset/ - # data/device_info/client_device_capacity The expected format is - # { INDEX:{'computation': FLOAT_VALUE_1, 'communication': FLOAT_VALUE_2}} - with open(filename, 'br') as f: - device_info = pickle.load(f) - return device_info +def merge_param_dict(raw_param, filtered_param): + for key in filtered_param.keys(): + raw_param[key] = filtered_param[key] + return raw_param def calculate_time_cost(instance_number, @@ -278,28 +126,30 @@ def calculate_time_cost(instance_number, return comp_cost, comm_cost -def calculate_batch_epoch_num(steps, batch_or_epoch, num_data, batch_size, - drop_last): - num_batch_per_epoch = num_data // batch_size + int( - not drop_last and bool(num_data % batch_size)) - if num_batch_per_epoch == 0: - raise RuntimeError( - "The number of batch is 0, please check 'batch_size' or set " - "'drop_last' as False") - elif batch_or_epoch == "epoch": - num_epoch = steps - num_batch_last_epoch = num_batch_per_epoch - num_total_batch = steps * num_batch_per_epoch - else: - num_epoch = math.ceil(steps / num_batch_per_epoch) - 
num_batch_last_epoch = steps % num_batch_per_epoch or \ - num_batch_per_epoch - num_total_batch = steps - return num_batch_per_epoch, num_batch_last_epoch, num_epoch, \ - num_total_batch +# ****** Runner-related utils ****** +def setup_seed(seed): + np.random.seed(seed) + random.seed(seed) + if torch is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + if tf is not None: + tf.set_random_seed(seed) -def merge_param_dict(raw_param, filtered_param): - for key in filtered_param.keys(): - raw_param[key] = filtered_param[key] - return raw_param +def get_resource_info(filename): + if filename is None or not os.path.exists(filename): + logger.info('The device information file is not provided') + return None + + # Users can develop this loading function according to resource_info_file + # As an example, we use the device_info provided by FedScale (FedScale: + # Benchmarking Model and System Performance of Federated Learning + # at Scale), which can be downloaded from + # https://github.com/SymbioticLab/FedScale/blob/master/benchmark/dataset/ + # data/device_info/client_device_capacity The expected format is + # { INDEX:{'computation': FLOAT_VALUE_1, 'communication': FLOAT_VALUE_2}} + with open(filename, 'br') as f: + device_info = pickle.load(f) + return device_info diff --git a/federatedscope/core/configs/README.md b/federatedscope/core/configs/README.md index 815048234..b741cfdc5 100644 --- a/federatedscope/core/configs/README.md +++ b/federatedscope/core/configs/README.md @@ -180,7 +180,6 @@ The following configurations are related to the grad clipping. | `early_stop.patience` | (int) 5 | How long to wait after last time the monitored metric improved. | Note that the actual_checking_round = `early_step.patience` * `eval.freq`. To disable the early stop, set the `early_stop.patience` <=0 | | `early_stop.delta` | (float) 0. | Minimum change in the monitored metric to indicate a improvement. 
| - | | `early_stop.improve_indicaator_mode` | (string) 'best' | Early stop when there is no improvement within the last `early_step.patience` rounds, in ['mean', 'best'] | Chosen from 'mean' or 'best' | -| `early_step.the_smaller_the_better` | (bool) True | The optimized direction of the chosen metric | - | ### FL Setting diff --git a/federatedscope/core/configs/cfg_training.py b/federatedscope/core/configs/cfg_training.py index e4089a15a..543186fd1 100644 --- a/federatedscope/core/configs/cfg_training.py +++ b/federatedscope/core/configs/cfg_training.py @@ -64,7 +64,6 @@ def extend_training_cfg(cfg): cfg.early_stop.delta = 0.0 # Early stop when no improve to last `patience` round, in ['mean', 'best'] cfg.early_stop.improve_indicator_mode = 'best' - cfg.early_stop.the_smaller_the_better = True # --------------- register corresponding check function ---------- cfg.register_cfg_check_fun(assert_training_cfg) diff --git a/federatedscope/core/data/base_data.py b/federatedscope/core/data/base_data.py index c219f5049..ec5d67e87 100644 --- a/federatedscope/core/data/base_data.py +++ b/federatedscope/core/data/base_data.py @@ -59,7 +59,7 @@ def preprocess(self, datadict): merged_max_data_id=self.global_cfg.federate.client_num, specified_dataset_name=['test']) # `0` indicate Server - datadict[0] = server_data + datadict[0] = ClientData(self.global_cfg, **server_data) if self.global_cfg.federate.method == "global": if self.global_cfg.federate.client_num != 1: @@ -73,9 +73,10 @@ def preprocess(self, datadict): else: logger.info(f"Will merge data from clients whose ids in " f"[1, {self.global_cfg.federate.client_num}]") - datadict[1] = merge_data( + merged_data = merge_data( all_data=datadict, merged_max_data_id=self.global_cfg.federate.client_num) + datadict[1] = ClientData(self.global_cfg, **merged_data) datadict = self.attack(datadict) return datadict diff --git a/federatedscope/core/data/base_translator.py b/federatedscope/core/data/base_translator.py index 4fc82a0d3..8dd71f0c9 100644 --- a/federatedscope/core/data/base_translator.py +++ b/federatedscope/core/data/base_translator.py @@ -10,7 +10,7 @@ class BaseDataTranslator: """ Perform process: - Dataset -> ML split -> FL split -> Data (passed to FedRunner) + Dataset -> ML split -> FL split -> Data (passed to Runner) """ def __init__(self, global_cfg, client_cfgs=None): diff --git a/federatedscope/core/data/utils.py b/federatedscope/core/data/utils.py index 1b46cac67..763279db4 100644 --- a/federatedscope/core/data/utils.py +++ b/federatedscope/core/data/utils.py @@ -3,10 +3,14 @@ import logging import os import re -from collections import defaultdict +import ssl +import urllib.request import numpy as np +import os.path as osp + from random import shuffle +from collections import defaultdict logger = logging.getLogger(__name__) @@ -577,9 +581,9 @@ def merge_data(all_data, merged_max_data_id=None, specified_dataset_name=None): torch.utils.data.DataLoader): if isinstance(all_data[id_contain_all_dataset_key][data_name].dataset, WrapDataset): - data_elem_names = list(all_data[id_contain_all_dataset_key] - [data_name].dataset.dataset.keys()) # # e.g., x, y + data_elem_names = list(all_data[id_contain_all_dataset_key] + [data_name].dataset.dataset.keys()) merged_data = {name: defaultdict(list) for name in dataset_names} for data_id in range(1, merged_max_data_id + 1): for d_name in dataset_names: @@ -593,19 +597,21 @@ def merge_data(all_data, merged_max_data_id=None, specified_dataset_name=None): for elem_name in data_elem_names: 
merged_data[d_name][elem_name] = np.concatenate( merged_data[d_name][elem_name]) - for name in all_data[id_contain_all_dataset_key]: - all_data[id_contain_all_dataset_key][ - name].dataset.dataset = merged_data[name] + merged_data[d_name] = WrapDataset(merged_data[d_name]) else: - merged_data = copy.deepcopy(all_data[id_contain_all_dataset_key]) + client_data = copy.deepcopy(all_data[id_contain_all_dataset_key]) for data_id in range(1, merged_max_data_id + 1): if data_id == id_contain_all_dataset_key: continue for d_name in dataset_names: if d_name not in all_data[data_id]: continue - merged_data[d_name].dataset.extend( + client_data[d_name].dataset.extend( all_data[data_id][d_name].dataset) + merged_data = { + key: client_data[key].dataset + for key in client_data + } else: raise NotImplementedError( "Un-supported type when merging data across different clients." @@ -615,3 +621,66 @@ def merge_data(all_data, merged_max_data_id=None, specified_dataset_name=None): " 1): {data_id: {train: {x:ndarray, y:ndarray}} }" " 2): {data_id: {train: DataLoader }") return merged_data + + +def save_local_data(dir_path, + train_data=None, + train_targets=None, + test_data=None, + test_targets=None, + val_data=None, + val_targets=None): + r""" + https://github.com/omarfoq/FedEM/blob/main/data/femnist/generate_data.py + + save (`train_data`, `train_targets`) in {dir_path}/train.pt, + (`val_data`, `val_targets`) in {dir_path}/val.pt + and (`test_data`, `test_targets`) in {dir_path}/test.pt + :param dir_path: + :param train_data: + :param train_targets: + :param test_data: + :param test_targets: + :param val_data: + :param val_targets + """ + import torch + if (train_data is not None) and (train_targets is not None): + torch.save((train_data, train_targets), osp.join(dir_path, "train.pt")) + + if (test_data is not None) and (test_targets is not None): + torch.save((test_data, test_targets), osp.join(dir_path, "test.pt")) + + if (val_data is not None) and (val_targets is not None): + torch.save((val_data, val_targets), osp.join(dir_path, "val.pt")) + + +def download_url(url: str, folder='folder'): + r"""Downloads the content of an url to a folder. + + Modified from `https://github.com/pyg-team/pytorch_geometric/blob/master + /torch_geometric/data/download.py` + + Args: + url (string): The url of target file. + folder (string): The target folder. + + Returns: + path (string): File path of downloaded files. + """ + + file = url.rpartition('/')[2] + file = file if file[0] == '?' 
else file.split('?')[0] + path = osp.join(folder, file) + if osp.exists(path): + logger.info(f'File {file} exists, use existing file.') + return path + + logger.info(f'Downloading {url}') + os.makedirs(folder, exist_ok=True) + ctx = ssl._create_unverified_context() + data = urllib.request.urlopen(url, context=ctx) + with open(path, 'wb') as f: + f.write(data.read()) + + return path diff --git a/federatedscope/core/fed_runner.py b/federatedscope/core/fed_runner.py index dac9e6649..aadf6e7be 100644 --- a/federatedscope/core/fed_runner.py +++ b/federatedscope/core/fed_runner.py @@ -1,3 +1,4 @@ +import abc import logging from collections import deque @@ -14,6 +15,457 @@ logger = logging.getLogger(__name__) +def get_runner(data, server_class, client_class, config, client_configs=None): + # Instantiate a Runner based on a configuration file + mode = config.federate.mode.lower() + runner_dict = { + 'standalone': StandaloneRunner, + 'distributed': DistributedRunner + } + return runner_dict[mode](data=data, + server_class=server_class, + client_class=client_class, + config=config, + client_configs=client_configs) + + +class BaseRunner(object): + """ + This class is used to construct an FL course, which includes `_set_up` + and `run`. + + Arguments: + data: The data used in the FL courses, which are formatted as { + 'ID':data} for standalone mode. More details can be found in + federatedscope.core.auxiliaries.data_builder . + server_class: The server class is used for instantiating a ( + customized) server. + client_class: The client class is used for instantiating a ( + customized) client. + config: The configurations of the FL course. + client_configs: The clients' configurations. + """ + def __init__(self, + data, + server_class=Server, + client_class=Client, + config=None, + client_configs=None): + self.data = data + self.server_class = server_class + self.client_class = client_class + assert config is not None, \ + "When using Runner, you should specify the `config` para" + if not config.is_ready_for_run: + config.ready_for_run() + self.cfg = config + self.client_cfgs = client_configs + + self.mode = self.cfg.federate.mode.lower() + self.gpu_manager = GPUManager(gpu_available=self.cfg.use_gpu, + specified_device=self.cfg.device) + + self.unseen_clients_id = [] + if self.cfg.federate.unseen_clients_rate > 0: + self.unseen_clients_id = np.random.choice( + np.arange(1, self.cfg.federate.client_num + 1), + size=max( + 1, + int(self.cfg.federate.unseen_clients_rate * + self.cfg.federate.client_num)), + replace=False).tolist() + # get resource information + self.resource_info = get_resource_info( + config.federate.resource_info_file) + + # Set up for Runner + self._set_up() + + @abc.abstractmethod + def _set_up(self): + """ + Set up client and/or server + """ + raise NotImplementedError + + @abc.abstractmethod + def _get_server_args(self, resource_info, client_resource_info): + """ + Get the args for instantiating the server. + + Args: + resource_info: information of resource + client_resource_info: information of client's resource + + Returns: + server_data: None or data which server holds. + model: model to be aggregated. + kw: kwargs dict to instantiate the server. + """ + raise NotImplementedError + + @abc.abstractmethod + def _get_client_args(self, client_id, resource_info): + """ + Get the args for instantiating the server. + + Args: + client_id: ID of client + resource_info: information of resource + + Returns: + client_data: data which client holds. 
+ kw: kwargs dict to instantiate the client. + """ + raise NotImplementedError + + @abc.abstractmethod + def run(self): + """ + Launch the worker + + Returns: + best_results: best results during the FL course + """ + raise NotImplementedError + + def _setup_server(self, resource_info=None, client_resource_info=None): + """ + Set up the server + """ + assert self.server_class is not None, \ + "`server_class` cannot be None." + self.server_id = 0 + server_data, model, kw = self._get_server_args(resource_info, + client_resource_info) + self._server_device = self.gpu_manager.auto_choice() + server = self.server_class( + ID=self.server_id, + config=self.cfg, + data=server_data, + model=model, + client_num=self.cfg.federate.client_num, + total_round_num=self.cfg.federate.total_round_num, + device=self._server_device, + unseen_clients_id=self.unseen_clients_id, + **kw) + if self.cfg.nbafl.use: + from federatedscope.core.trainers.trainer_nbafl import \ + wrap_nbafl_server + wrap_nbafl_server(server) + logger.info('Server has been set up ... ') + return server + + def _setup_client(self, + client_id=-1, + client_model=None, + resource_info=None): + """ + Set up the Client + """ + assert self.client_class is not None, \ + "`client_class` cannot be None" + self.server_id = 0 + client_data, kw = self._get_client_args(client_id, resource_info) + client_specific_config = self.cfg.clone() + if self.client_cfgs: + client_specific_config.defrost() + client_specific_config.merge_from_other_cfg( + self.client_cfgs.get('client_{}'.format(client_id))) + client_specific_config.freeze() + client_device = self._server_device if \ + self.cfg.federate.share_local_model else \ + self.gpu_manager.auto_choice() + client = self.client_class(ID=client_id, + server_id=self.server_id, + config=client_specific_config, + data=client_data, + model=client_model + or get_model(client_specific_config.model, + client_data, + backend=self.cfg.backend), + device=client_device, + is_unseen_client=client_id + in self.unseen_clients_id, + **kw) + + if client_id == -1: + logger.info('Client (address {}:{}) has been set up ... '.format( + self.client_address['host'], self.client_address['port'])) + else: + logger.info(f'Client {client_id} has been set up ... ') + + return client + + +class StandaloneRunner(BaseRunner): + def _set_up(self): + """ + To set up server and client for standalone mode. + """ + self.is_run_online = True if self.cfg.federate.online_aggr else False + self.shared_comm_queue = deque() + + if self.cfg.backend == 'torch': + import torch + torch.set_num_threads(1) + + assert self.cfg.federate.client_num != 0, \ + "In standalone mode, self.cfg.federate.client_num should be " \ + "non-zero. 
" \ + "This is usually cased by using synthetic data and users not " \ + "specify a non-zero value for client_num" + + if self.cfg.federate.method == "global": + self.cfg.defrost() + self.cfg.federate.client_num = 1 + self.cfg.federate.sample_client_num = 1 + self.cfg.freeze() + + # sample resource information + if self.resource_info is not None: + if len(self.resource_info) < self.cfg.federate.client_num + 1: + replace = True + logger.warning( + f"Because the provided the number of resource information " + f"{len(self.resource_info)} is less than the number of " + f"participants {self.cfg.federate.client_num + 1}, one " + f"candidate might be selected multiple times.") + else: + replace = False + sampled_index = np.random.choice( + list(self.resource_info.keys()), + size=self.cfg.federate.client_num + 1, + replace=replace) + server_resource_info = self.resource_info[sampled_index[0]] + client_resource_info = [ + self.resource_info[x] for x in sampled_index[1:] + ] + else: + server_resource_info = None + client_resource_info = None + + self.server = self._setup_server( + resource_info=server_resource_info, + client_resource_info=client_resource_info) + + self.client = dict() + # assume the client-wise data are consistent in their input&output + # shape + self._shared_client_model = get_model( + self.cfg.model, self.data[1], backend=self.cfg.backend + ) if self.cfg.federate.share_local_model else None + for client_id in range(1, self.cfg.federate.client_num + 1): + self.client[client_id] = self._setup_client( + client_id=client_id, + client_model=self._shared_client_model, + resource_info=client_resource_info[client_id - 1] + if client_resource_info is not None else None) + + # in standalone mode, by default, we print the trainer info only + # once for better logs readability + trainer_representative = self.client[1].trainer + if trainer_representative is not None and hasattr( + trainer_representative, 'print_trainer_meta_info'): + trainer_representative.print_trainer_meta_info() + + def _get_server_args(self, resource_info=None, client_resource_info=None): + if self.server_id in self.data: + server_data = self.data[self.server_id] + model = get_model(self.cfg.model, + server_data, + backend=self.cfg.backend) + else: + server_data = None + data_representative = self.data[1] + model = get_model( + self.cfg.model, data_representative, backend=self.cfg.backend + ) # get the model according to client's data if the server + # does not own data + kw = { + 'shared_comm_queue': self.shared_comm_queue, + 'resource_info': resource_info, + 'client_resource_info': client_resource_info + } + return server_data, model, kw + + def _get_client_args(self, client_id=-1, resource_info=None): + client_data = self.data[client_id] + kw = { + 'shared_comm_queue': self.shared_comm_queue, + 'resource_info': resource_info + } + return client_data, kw + + def run(self): + for each_client in self.client: + # Launch each client + self.client[each_client].join_in() + + if self.is_run_online: + self._run_simulation_online() + else: + self._run_simulation() + # TODO: avoid using private attr + self.server._monitor.finish_fed_runner(fl_mode=self.mode) + return self.server.best_results + + def _handle_msg(self, msg, rcv=-1): + """ + To simulate the message handling process (used only for the + standalone mode) + """ + if rcv != -1: + # simulate broadcast one-by-one + self.client[rcv].msg_handlers[msg.msg_type](msg) + return + + _, receiver = msg.sender, msg.receiver + download_bytes, upload_bytes = msg.count_bytes() + if 
not isinstance(receiver, list): + receiver = [receiver] + for each_receiver in receiver: + if each_receiver == 0: + self.server.msg_handlers[msg.msg_type](msg) + self.server._monitor.track_download_bytes(download_bytes) + else: + self.client[each_receiver].msg_handlers[msg.msg_type](msg) + self.client[each_receiver]._monitor.track_download_bytes( + download_bytes) + + def _run_simulation_online(self): + """ + Run for online aggregation. + Any broadcast operation would be executed client-by-clien to avoid + the existence of #clients messages at the same time. Currently, + only consider centralized topology + """ + def is_broadcast(msg): + return len(msg.receiver) >= 1 and msg.sender == 0 + + cached_bc_msgs = [] + cur_idx = 0 + while True: + if len(self.shared_comm_queue) > 0: + msg = self.shared_comm_queue.popleft() + if is_broadcast(msg): + cached_bc_msgs.append(msg) + # assume there is at least one client + msg = cached_bc_msgs[0] + self._handle_msg(msg, rcv=msg.receiver[cur_idx]) + cur_idx += 1 + if cur_idx >= len(msg.receiver): + del cached_bc_msgs[0] + cur_idx = 0 + else: + self._handle_msg(msg) + elif len(cached_bc_msgs) > 0: + msg = cached_bc_msgs[0] + self._handle_msg(msg, rcv=msg.receiver[cur_idx]) + cur_idx += 1 + if cur_idx >= len(msg.receiver): + del cached_bc_msgs[0] + cur_idx = 0 + else: + # finished + break + + def _run_simulation(self): + """ + Run for standalone simulation (W/O online aggr) + """ + server_msg_cache = list() + while True: + if len(self.shared_comm_queue) > 0: + msg = self.shared_comm_queue.popleft() + if msg.receiver == [self.server_id]: + # For the server, move the received message to a + # cache for reordering the messages according to + # the timestamps + heapq.heappush(server_msg_cache, msg) + else: + self._handle_msg(msg) + elif len(server_msg_cache) > 0: + msg = heapq.heappop(server_msg_cache) + if self.cfg.asyn.use and self.cfg.asyn.aggregator \ + == 'time_up': + # When the timestamp of the received message beyond + # the deadline for the currency round, trigger the + # time up event first and push the message back to + # the cache + if self.server.trigger_for_time_up(msg.timestamp): + heapq.heappush(server_msg_cache, msg) + else: + self._handle_msg(msg) + else: + self._handle_msg(msg) + else: + if self.cfg.asyn.use and self.cfg.asyn.aggregator \ + == 'time_up': + self.server.trigger_for_time_up() + if len(self.shared_comm_queue) == 0 and \ + len(server_msg_cache) == 0: + break + else: + # terminate when shared_comm_queue and + # server_msg_cache are all empty + break + + +class DistributedRunner(BaseRunner): + def _set_up(self): + """ + To set up server or client for distributed mode. 
+ """ + # sample resource information + if self.resource_info is not None: + sampled_index = np.random.choice(list(self.resource_info.keys())) + sampled_resource = self.resource_info[sampled_index] + else: + sampled_resource = None + + self.server_address = { + 'host': self.cfg.distribute.server_host, + 'port': self.cfg.distribute.server_port + } + if self.cfg.distribute.role == 'server': + self.server = self._setup_server(resource_info=sampled_resource) + elif self.cfg.distribute.role == 'client': + # When we set up the client in the distributed mode, we assume + # the server has been set up and number with #0 + self.client_address = { + 'host': self.cfg.distribute.client_host, + 'port': self.cfg.distribute.client_port + } + self.client = self._setup_client(resource_info=sampled_resource) + + def _get_server_args(self, resource_info, client_resource_info): + server_data = self.data + model = get_model(self.cfg.model, + server_data, + backend=self.cfg.backend) + kw = self.server_address + kw.update({'resource_info': resource_info}) + return server_data, model, kw + + def _get_client_args(self, client_id, resource_info): + client_data = self.data + kw = self.client_address + kw['server_host'] = self.server_address['host'] + kw['server_port'] = self.server_address['port'] + kw['resource_info'] = resource_info + return client_data, kw + + def run(self): + if self.cfg.distribute.role == 'server': + self.server.run() + return self.server.best_results + elif self.cfg.distribute.role == 'client': + self.client.join_in() + self.client.run() + + +# TODO: remove FedRunner (keep now for forward compatibility) class FedRunner(object): """ This class is used to construct an FL course, which includes `_set_up` @@ -36,6 +488,10 @@ def __init__(self, client_class=Client, config=None, client_configs=None): + logger.warning('`federate.core.fed_runner.FedRunner` will be ' + 'removed in the future, please use' + '`federate.core.fed_runner.get_runner` to get ' + 'Runner.') self.data = data self.server_class = server_class self.client_class = client_class @@ -228,7 +684,6 @@ def is_broadcast(msg): break def _run_simulation(self): - server_msg_cache = list() while True: if len(self.shared_comm_queue) > 0: diff --git a/federatedscope/core/monitors/early_stopper.py b/federatedscope/core/monitors/early_stopper.py index ab3516305..48fb146d8 100644 --- a/federatedscope/core/monitors/early_stopper.py +++ b/federatedscope/core/monitors/early_stopper.py @@ -13,7 +13,7 @@ def __init__(self, patience=5, delta=0, improve_indicator_mode='best', - the_smaller_the_better=True): + the_larger_the_better=True): """ Args: patience (int): How long to wait after last time the monitored @@ -39,7 +39,7 @@ def __init__(self, self.counter_no_improve = 0 self.best_metric = None self.early_stopped = False - self.the_smaller_the_better = the_smaller_the_better + self.the_larger_the_better = the_larger_the_better self.delta = delta self.improve_indicator_mode = improve_indicator_mode # For expansion usages of comparisons @@ -54,12 +54,12 @@ def __track_and_check_best(self, history_result): new_result = history_result[-1] if self.best_metric is None: self.best_metric = new_result - elif self.the_smaller_the_better and self.comparator( + elif not self.the_larger_the_better and self.comparator( self.improvement_operator(self.best_metric, -self.delta), new_result): # add(best_metric, -delta) < new_result self.counter_no_improve += 1 - elif not self.the_smaller_the_better and self.comparator( + elif self.the_larger_the_better and 
self.comparator( new_result, self.improvement_operator(self.best_metric, self.delta)): # new_result < add(best_metric, delta) @@ -74,12 +74,12 @@ def __track_and_check_best(self, history_result): def __track_and_check_mean(self, history_result): new_result = history_result[-1] if len(history_result) > self.patience: - if self.the_smaller_the_better and self.comparator( + if not self.the_larger_the_better and self.comparator( self.improvement_operator( np.mean(history_result[-self.patience - 1:-1]), -self.delta), new_result): self.early_stopped = True - elif not self.the_smaller_the_better and self.comparator( + elif self.the_larger_the_better and self.comparator( new_result, self.improvement_operator( np.mean(history_result[-self.patience - 1:-1]), diff --git a/federatedscope/core/monitors/metric_calculator.py b/federatedscope/core/monitors/metric_calculator.py index 6d32122e4..ed3ec4b3b 100644 --- a/federatedscope/core/monitors/metric_calculator.py +++ b/federatedscope/core/monitors/metric_calculator.py @@ -16,7 +16,6 @@ logger = logging.getLogger(__name__) -# TODO: make this as a sub-module of monitor class class MetricCalculator(object): def __init__(self, eval_metric: Union[Set[str], List[str], str]): @@ -41,7 +40,7 @@ def get_metric_funcs(self, eval_metric): def eval(self, ctx): results = {} y_true, y_pred, y_prob = self._check_and_parse(ctx) - for metric, func in self.eval_metric.items(): + for metric, (func, _) in self.eval_metric.items(): results["{}_{}".format(ctx.cur_split, metric)] = func(ctx=ctx, y_true=y_true, @@ -218,18 +217,21 @@ def eval_imp_ratio(ctx, y_true, y_prob, y_pred, **kwargs): return (base - perform) / base * 100. +# SUPPORT_METRICS dict, key: `metric_name`, value: (eval_func, +# the_larger_the_better) SUPPORT_METRICS = { - 'loss': eval_loss, - 'avg_loss': eval_avg_loss, - 'total': eval_total, - 'correct': eval_correct, - 'acc': eval_acc, - 'ap': eval_ap, - 'f1': eval_f1_score, - 'roc_auc': eval_roc_auc, - 'rmse': eval_rmse, - 'mse': eval_mse, - 'loss_regular': eval_regular, - 'imp_ratio': eval_imp_ratio, - **dict.fromkeys([f'hits@{n}' for n in range(1, 101)], eval_hits) + 'loss': (eval_loss, False), + 'avg_loss': (eval_avg_loss, False), + 'total': (eval_total, False), + 'correct': (eval_correct, True), + 'acc': (eval_acc, True), + 'ap': (eval_ap, True), + 'f1': (eval_f1_score, True), + 'roc_auc': (eval_roc_auc, True), + 'rmse': (eval_rmse, False), + 'mse': (eval_mse, False), + 'loss_regular': (eval_regular, False), + 'imp_ratio': (eval_imp_ratio, True), + 'std': (None, False), + **dict.fromkeys([f'hits@{n}' for n in range(1, 101)], (eval_hits, True)) } diff --git a/federatedscope/core/monitors/monitor.py b/federatedscope/core/monitors/monitor.py index 0f4832164..3050e8e56 100644 --- a/federatedscope/core/monitors/monitor.py +++ b/federatedscope/core/monitors/monitor.py @@ -10,6 +10,7 @@ import numpy as np from federatedscope.core.auxiliaries.logging import logline_2_wandb_dict +from federatedscope.core.monitors.metric_calculator import MetricCalculator try: import torch @@ -40,6 +41,17 @@ def __init__(self, cfg, monitored_object=None): # self.use_tensorboard = cfg.use_tensorboard self.monitored_object = monitored_object + self.metric_calculator = MetricCalculator(cfg.eval.metrics) + + # Obtain the whether the larger the better + self.round_wise_update_key = cfg.eval.best_res_update_round_wise_key + for mode in ['train', 'val', 'test']: + if mode in self.round_wise_update_key: + update_key = self.round_wise_update_key.split(f'{mode}_')[1] + assert update_key in 
self.metric_calculator.eval_metric, \ + f'{update_key} not found in metrics.' + self.the_larger_the_better = self.metric_calculator.eval_metric[ + update_key][1] # ======= efficiency indicators of the worker to be monitored ======= # leveraged the flops counter provided by [fvcore]( @@ -80,6 +92,10 @@ def __init__(self, cfg, monitored_object=None): "cfg.wandb.use=True but not install the wandb package") exit() + def eval(self, ctx): + results = self.metric_calculator.eval(ctx) + return results + def global_converged(self): self.global_convergence_wall_time = datetime.datetime.now( ) - self.fl_begin_wall_time @@ -517,11 +533,7 @@ def track_upload_bytes(self, bytes): def track_download_bytes(self, bytes): self.total_download_bytes += bytes - def update_best_result(self, - best_results, - new_results, - results_type, - round_wise_update_key="val_loss"): + def update_best_result(self, best_results, new_results, results_type): """ update best evaluation results. by default, the update is based on validation loss with @@ -538,7 +550,7 @@ def update_best_result(self, best_result = best_results[results_type] # update different keys separately: the best values can be in # different rounds - if round_wise_update_key is None: + if self.round_wise_update_key is None: for key in new_results: cur_result = new_results[key] if 'loss' in key or 'std' in key: # the smaller, @@ -569,24 +581,10 @@ def update_best_result(self, # update different keys round-wise: if find better # round_wise_update_key, update others at the same time else: - if round_wise_update_key not in [ - "val_loss", "test_loss", "loss", "val_avg_loss", - "test_avg_loss", "avg_loss", "test_acc", "test_std", - "val_acc", "val_std", "val_imp_ratio", "train_loss", - "train_avg_loss" - ]: - raise NotImplementedError( - f"We currently support round_wise_update_key as one " - f"of ['val_loss', 'test_loss', 'loss', " - f"'val_avg_loss', 'test_avg_loss', 'avg_loss," - f"''val_acc', 'val_std', 'test_acc', 'test_std', " - f"'val_imp_ratio'] for round-wise best results " - f" update, but got {round_wise_update_key}.") - found_round_wise_update_key = False sorted_keys = [] for key in new_results: - if round_wise_update_key in key: + if self.round_wise_update_key in key: sorted_keys.insert(0, key) found_round_wise_update_key = True else: @@ -597,15 +595,13 @@ def update_best_result(self, "is not in target results, " "use another key or check the name. 
\n" f"Got eval.best_res_update_round_wise_key" - f"={round_wise_update_key}, " + f"={self.round_wise_update_key}, " f"the keys of results are {list(new_results.keys())}") for key in sorted_keys: cur_result = new_results[key] - if update_best_this_round or \ - ('loss' in round_wise_update_key and 'loss' in - key) or \ - ('std' in round_wise_update_key and 'std' in key): + if update_best_this_round or ( + not self.the_larger_the_better): # The smaller the better if results_type in [ "client_best_individual", @@ -617,8 +613,7 @@ def update_best_result(self, best_result[key]: best_result[key] = cur_result update_best_this_round = True - elif update_best_this_round or \ - 'acc' in round_wise_update_key and 'acc' in key: + elif update_best_this_round or self.the_larger_the_better: # The larger the better if results_type in [ "client_best_individual", diff --git a/federatedscope/core/trainers/README.md b/federatedscope/core/trainers/README.md new file mode 100644 index 000000000..7a9ae4198 --- /dev/null +++ b/federatedscope/core/trainers/README.md @@ -0,0 +1,461 @@ +# Local Learning Abstraction: Trainer + +FederatedScope decouples the local learning process and details of FL communication and schedule, allowing users to freely customize the local learning algorithm via the `trainer`. Each worker holds a `trainer` object to manage the details of local learning, such as the loss function, optimizer, training step, evaluation, etc. + +This tutorial is a shorter version of [full version tutorial](https://federatedscope.io/docs/trainer/), where you can learn more details about FS Trainer. + +## Code Structure + +The code structure is shown below, and we will discuss all the concepts of our FS Trainer later. + +```bash +federatedscope/core +├── trainers +│ ├── BaseTrainer +│ │ ├── Trainer +│ │ │ ├── GeneralTorchTrainer +│ │ │ ├── GeneralTFTrainer +│ │ │ ├── Context +│ │ │ ├── ... +│ │ ├── UserDefineTrainer +│ │ ├── ... +``` + +## FS Trainer + +A typical machine-learning process consists of the following procedures: + +1. Preparing data. +2. Iterations over training datasets to update the model parameters +3. Evaluation of the quality of the learned model on validation/evaluation datasets +4. Saving, loading, and monitoring the model and intermediate results + +### BaseTrainer + +`BaseTrainer` is an abstract class of our Trainer, which provide the interface of each method. And you can implement your own trainer by inheriting from `BaseTrainer`. More examples can be found in `federatedscope/contrib/trainer`. + +```python +class BaseTrainer(abc.ABC): + def __init__(self, model, data, device, **kwargs): + self.model = model + self.data = data + self.device = device + self.kwargs = kwargs + + @abc.abstractmethod + def train(self): + raise NotImplementedError + + @abc.abstractmethod + def evaluate(self, target_data_split_name='test'): + raise NotImplementedError + + @abc.abstractmethod + def update(self, model_parameters, strict=False): + raise NotImplementedError + + @abc.abstractmethod + def get_model_para(self): + raise NotImplementedError + + ... ... +``` + +### Trainer + +As the figure shows, in FederatedScope `Trainer` (a subclass of `BaseTrainer`), these above procedures are provided with high-level `routines` abstraction, which is made up of `Context` class and several pluggable `Hooks`. And we provide `GeneralTorchTrainer` and `GeneralTFTrainer` for `PyTorch` and `TensorFlow`, separately. 
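+
+For illustration, a minimal usage sketch is given below. It assumes that `model`, `data`, `device`, `cfg`, and `monitor` have already been constructed elsewhere (e.g., by the corresponding FederatedScope builders); these names are placeholders rather than part of this patch.
+
+```python
+from federatedscope.core.trainers import GeneralTorchTrainer
+
+# Build a trainer for one worker; the keyword arguments mirror Trainer.__init__.
+trainer = GeneralTorchTrainer(model=model,
+                              data=data,
+                              device=device,
+                              config=cfg,
+                              monitor=monitor)
+
+# Run the local training routine on the `train` split.
+trainer.train()
+
+# Run the evaluation routine on the `test` split.
+eval_metrics = trainer.evaluate(target_data_split_name='test')
+
+# Exchange parameters with the FL course: download, then upload.
+trainer.update(global_model_para, strict=False)  # `global_model_para` is a placeholder
+local_model_para = trainer.get_model_para()
+```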
+
+#### Context
+
+The `Context` class (a subclass of `dict`) is used to hold learning-related attributes, including the data, model, optimizer, etc.; users can add or delete these attributes in hook functions. We classify and show the default attributes below:
+
+* Data-related attributes
+  * `ctx.data`: the raw data (not split) the trainer holds
+  * `ctx.num_samples`: the number of samples used in training
+  * `ctx.train_data`, `ctx.val_data`, `ctx.test_data`: the split data the trainer holds
+  * `ctx.train_loader`, `ctx.val_loader`, `ctx.test_loader`: the DataLoader of each split data
+  * `ctx.num_train_data`, `ctx.num_val_data`, `ctx.num_test_data`: the number of samples of the split data
+* Model-related attributes
+  * `ctx.model`: the model used
+  * `ctx.models`: the models, when multiple models are used
+  * `ctx.mirrored_models`: the mirrored models
+  * `ctx.trainable_para_names`: the trainable parameter names of the model
+* Optimizer-related attributes
+  * `ctx.optimizer`: see [`torch.optim`](https://pytorch.org/docs/stable/optim.html#module-torch.optim) for details
+  * `ctx.scheduler`: decays the learning rate of each parameter group
+  * `ctx.criterion`: loss/criterion function
+  * `ctx.regularizer`: regular terms
+  * `ctx.grad_clip`: gradient clipping
+* Mode-related attributes
+  * `ctx.cur_mode`: mode of the trainer, which is one of `['train', 'val', 'test']`
+  * `ctx.mode_stack`: stack of modes, only used for switching modes
+  * `ctx.cur_split`: the data split, which is one of `['train', 'val', 'test']` (Note: using `train` data in `test` mode is allowed)
+  * `ctx.split_stack`: stack of splits, only used for switching data splits
+* Metric-related attributes
+  * `ctx.loss_batch_total`: accumulated batch loss
+  * `ctx.loss_regular_total`: accumulated loss of the regular term
+  * `ctx.y_true`: true labels of the batch data
+  * `ctx.y_prob`: output of the model with the batch data as input
+  * `ctx.ys_true`: true labels of the data
+  * `ctx.ys_prob`: output of the model
+  * `ctx.eval_metrics`: evaluation metrics calculated by `Monitor`
+  * `ctx.monitor`: used to monitor the trainer's behavior and statistics
+* Other (statistics) attributes (`@property`, queried from `cfg` if not set)
+  * `ctx.cfg`: configuration of the FL course, see [link](https://github.com/alibaba/FederatedScope/tree/master/federatedscope/core/configs) for details
+  * `ctx.device`: current device, such as `cpu` and `gpu0`
+  * `ctx.num_train_batch_last_epoch`, `ctx.num_total_train_batch`: the numbers of batches
+  * `ctx.num_train_epoch`, `ctx.num_val_epoch`, `ctx.num_test_epoch`: the number of epochs in each data split
+  * `ctx.num_train_batch`, `ctx.num_val_batch`, `ctx.num_test_batch`: the number of batches in each data split
+
+#### Hooks
+
+The `Hooks` represent fine-grained learning behaviors at different points in time, providing a simple yet powerful way to customize learning behaviors with a few modifications and easy re-use of the fruitful default hooks. In this section, we will show the details of each hook used in `Trainer`.
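+
+As a quick sketch (assuming an already-built trainer instance, here named `my_trainer`, which is not part of this patch), a custom hook is simply a callable over `ctx` that is registered at one of the triggers listed below:
+
+```python
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def _hook_on_batch_end_log_loss(ctx):
+    # Read metric-related attributes maintained by the default hooks.
+    logger.info(f'[{ctx.cur_mode}] batch {ctx.cur_batch_i}: '
+                f'loss={ctx.loss_batch.item():.4f}')
+
+
+# Append the hook to the train hooks triggered at `on_batch_end`.
+my_trainer.register_hook_in_train(new_hook=_hook_on_batch_end_log_loss,
+                                  trigger='on_batch_end',
+                                  insert_pos=-1)
+```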
+ +##### Hook trigger + +The hook trigger is where the hook functions are executed, and all the hook functions are executed following the pattern below: + +* **on_fit_start** + * **on_epoch_start** + * **on_batch_start** + * **on_batch_forward** + * **on_batch_backward** + * **on_batch_end** + * **on_epoch_end** +* **on_fit_end** + +##### Train hooks + +Train hooks are executed when `ctx.cur_mode` is `train`, following the execution paradigm as shown below: + +* **on_fit_start** + + `_hook_on_fit_start_init` + + `_hook_on_fit_start_calculate_model_size` + + * **on_epoch_start** + + `_hook_on_epoch_start` + + * **on_batch_start** + + `_hook_on_batch_start_init` + + * **on_batch_forward** + + `_hook_on_batch_forward` + + `_hook_on_batch_forward_regularizer` + + `_hook_on_batch_forward_flop_count` + + * **on_batch_backward** + + `_hook_on_batch_backward` + + * **on_batch_end** + + `_hook_on_batch_end` + + * **on_epoch_end** + + `None` + +* **on_fit_end** + + `_hook_on_fit_end` + +##### Evaluation (val/test) hooks + +Evaluation hooks are executed when `ctx.cur_mode` is `val` or `test`, following the execution paradigm as shown below: + +* **on_fit_start** + + `_hook_on_fit_start_init` + + * **on_epoch_start** + + `_hook_on_epoch_start` + + * **on_batch_start** + + `_hook_on_batch_start_init` + + * **on_batch_forward** + + `_hook_on_batch_forward` + + * **on_batch_backward** + + `None` + + * **on_batch_end** + + `_hook_on_batch_end` + + * **on_epoch_end** + + `None` + +* **on_fit_end** + + `_hook_on_fit_end` + +##### Finetune hooks + +Finetune hooks are executed when `ctx.cur_mode` is `finetune`, following the execution paradigm as shown below: + +* **on_fit_start** + + `_hook_on_fit_start_init` + + `_hook_on_fit_start_calculate_model_size` + + * **on_epoch_start** + + `_hook_on_epoch_start` + + * **on_batch_start** + + `_hook_on_batch_start_init` + + * **on_batch_forward** + + `_hook_on_batch_forward` + + `_hook_on_batch_forward_regularizer` + + `_hook_on_batch_forward_flop_count` + + * **on_batch_backward** + + `_hook_on_batch_backward` + + * **on_batch_end** + + `_hook_on_batch_end` + + * **on_epoch_end** + + `None` + +* **on_fit_end** + + `_hook_on_fit_end` + +##### Hook functions + +In this section, we will briefly describe what the hook functions do with the attributes/variables in `ctx`. 
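+
+As a hedged sketch (mirroring the built-in `_hook_on_batch_forward` of `GeneralTorchTrainer`, under a hypothetical name), a hook body typically wraps short-lived variables in `CtxVar` so that they are cleaned up automatically at the end of their lifecycle, as the tables below describe for the default hooks:
+
+```python
+from federatedscope.core.trainers.context import CtxVar
+from federatedscope.core.trainers.enums import LIFECYCLE
+
+
+def _hook_on_batch_forward_toy(ctx):
+    # Move the current batch to the device and run a forward pass.
+    x, label = [_.to(ctx.device) for _ in ctx.data_batch]
+    pred = ctx.model(x)
+
+    # Batch-lifecycle variables are deleted once the batch ends.
+    ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
+    ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)
+    ctx.loss_batch = CtxVar(ctx.criterion(pred, label), LIFECYCLE.BATCH)
+    ctx.batch_size = CtxVar(len(label), LIFECYCLE.BATCH)
+```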
+ +###### GeneralTorchTrainer + +* `_hook_on_fit_start_init` + + | Modified attribute | Operation | + | ------------------------ | ----------------------- | + | `ctx.model` | Move to `ctx.device` | + | `ctx.optimizer` | Initialize by `ctx.cfg` | + | `ctx.scheduler` | Initialize by `ctx.cfg` | + | `ctx.loss_batch_total` | Initialize to `0` | + | `ctx.loss_regular_total` | Initialize to `0` | + | `ctx.num_samples` | Initialize to `0` | + | `ctx.ys_true` | Initialize to `[]` | + | `ctx.ys_prob` | Initialize to `[]` | + +* `_hook_on_fit_start_calculate_model_size` + + | Modified attribute | Operation | + | ------------------ | ---------------- | + | `ctx.monitor` | Track model size | + +* `_hook_on_epoch_start` + + | Modified attribute | Operation | + | ---------------------------- | --------------------- | + | `ctx.{ctx.cur_split}_loader` | Initialize DataLoader | + +* `_hook_on_batch_start_init` + + | Modified attribute | Operation | + | ------------------ | --------------------- | + | `ctx.data_batch` | Initialize batch data | + +* `_hook_on_batch_forward` + + | Modified attribute | Operation | + | ------------------ | ----------------------------------- | + | `ctx.y_true` | Move to `ctx.device` | + | `ctx.y_prob` | Forward propagation to get `y_prob` | + | `ctx.loss_batch` | Calculate the loss | + | `ctx.batch_size` | Get the batch_size | + +* `_hook_on_batch_forward_regularizer` + + | Modified attribute | Operation | + | ------------------ | ----------------------------------------- | + | `ctx.loss_regular` | Calculate the regular loss | + | `ctx.loss_task` | Sum the `ctx.loss_regular` and `ctx.loss` | + +* `_hook_on_batch_forward_flop_count` + + | Modified attribute | Operation | + | ------------------ | ------------------- | + | `ctx.monitor` | Track average flops | + +* `_hook_on_batch_backward` + + | Modified attribute | Operation | + | ------------------ | -------------------- | + | `ctx.optimizer` | Update by gradient | + | `ctx.loss_task` | Backward propagation | + | `ctx.scheduler` | Update by gradient | + +* `_hook_on_batch_end ` + + | Modified attribute | Operation | + | ------------------------ | ---------------------- | + | `ctx.num_samples` | Add `ctx.batch_size` | + | `ctx.loss_batch_total` | Add batch loss | + | `ctx.loss_regular_total` | Add batch regular loss | + | `ctx.ys_true` | Append `ctx.y_true` | + | `ctx.ys_prob` | Append `ctx.ys_prob` | + +* `_hook_on_fit_end ` + + | Modified attribute | Operation | + | ------------------ | ---------------------------------------- | + | `ctx.ys_true` | Convert to `numpy.array` | + | `ctx.ys_prob` | Convert to `numpy.array` | + | `ctx.monitor` | Evaluate the results | + | `ctx.eval_metrics` | Get evaluated results from `ctx.monitor` | + +###### DittoTrainer + +* `_hook_on_fit_start_set_regularized_para` + + | Modified attribute | Operation | + | -------------------------------- | ------------------------------------------------------------ | + | `ctx.global_model` | Move to `ctx.device` and set to `train` mode | + | `ctx.local_model` | Move to `ctx.device` and set to `train` mode | + | `ctx.optimizer_for_global_model` | Initialize by `ctx.cfg` and wrapped by `wrap_regularized_optimizer` | + | `ctx.optimizer_for_local_model` | Initialize by `ctx.cfg` and set compared parameter group | + +* `_hook_on_fit_start_clean` + + | Modified attribute | Operation | + | ----------------------------------- | ----------------- | + | `ctx.optimizer` | Delete | + | `ctx.num_samples_local_model_train` | Initialize to `0` | + +* 
`_hook_on_fit_start_switch_local_model` + + | Modified attribute | Operation | + | ------------------ | ----------------------------------------------- | + | `ctx.model` | Set to `ctx.local_model` and set to `eval` mode | + +* `_hook_on_batch_start_switch_model` + + | Modified attribute | Operation | + | ----------------------------- | ------------------------------------------------------------ | + | `ctx.use_local_model_current` | Set to `True` or `False` | + | `ctx.model` | Set to `ctx.local_model` or `ctx.global_model` | + | `ctx.optimizer` | Set to `ctx.optimizer_for_local_model` or `ctx.optimizer_for_global_model` | + +* `_hook_on_batch_forward_cnt_num` + + | Modified attribute | Operation | + | ----------------------------------- | -------------------- | + | `ctx.num_samples_local_model_train` | Add `ctx.batch_size` | + +* `_hook_on_batch_end_flop_count` + + | Modified attribute | Operation | + | ------------------ | ------------------- | + | `ctx.monitor` | Monitor total flops | + +* `_hook_on_fit_end_calibrate` + + | Modified attribute | Operation | + | ------------------ | -------------------------------------------------- | + | `ctx.num_samples` | Minus `ctx.num_samples_local_model_train` | + | `ctx.eval_metrics` | Record `train_total` and `train_total_local_model` | + +* `_hook_on_fit_end_switch_global_model` + + | Modified attribute | Operation | + | ------------------ | ------------------------- | + | `ctx.model ` | Set to `ctx.global_model` | + +* `_hook_on_fit_end_free_cuda` + + | Modified attribute | Operation | + | ------------------ | ------------- | + | `ctx.global_model` | Move to `cpu` | + | `ctx.local_model` | Move to `cpu` | + +###### pFedMeTrainer + +* `_hook_on_fit_start_set_local_para_tmp` + + | Modified attribute | Operation | + | ---------------------------- | ------------------------------------------------------------ | + | `ctx.optimizer` | Wrapped by `wrap_regularized_optimizer` and set compared parameter group | + | `ctx.pFedMe_outer_lr` | Initialize to `ctx.cfg.train.optimizer.lr` | + | `ctx.pFedMe_local_model_tmp` | Copy from `ctx.model` | + +* `_hook_on_batch_start_init_pfedme` + + | Modified attribute | Operation | + | ------------------------------- | ---------------------------------- | + | `ctx.data_batch_cache` | Copy from `ctx.data_batch` | + | `ctx.pFedMe_approx_fit_counter` | Count to refresh data every K step | + +* `_hook_on_batch_end_flop_count` + + | Modified attribute | Operation | + | ------------------ | ------------------- | + | `ctx.monitor` | Monitor total flops | + +* `_hook_on_epoch_end_flop_count` + + | Modified attribute | Operation | + | ------------------ | ------------------- | + | `ctx.monitor` | Monitor total flops | + +* `_hook_on_epoch_end_update_local` + + | Modified attribute | Operation | + | ------------------ | ------------------------------------------------- | + | `ctx.model` | Update parameters by `ctx.pFedMe_local_model_tmp` | + | `ctx.optimizer` | Set compared parameter group | + +* `_hook_on_fit_end_update_local` + + | Modified attribute | Operation | + | ---------------------------- | ------------------------------------------------- | + | `ctx.model` | Update parameters by `ctx.pFedMe_local_model_tmp` | + | `ctx.pFedMe_local_model_tmp` | Delete | + +###### FedProxTrainer & NbaflTrainer + +* `_hook_record_initialization` + + | Modified attribute | Operation | + | ------------------ | --------------------- | + | `ctx.weight_init` | Copy from `ctx.model` | + +* `_hook_del_initialization` + + | Modified attribute | 
Operation | + | ------------------ | ------------- | + | `ctx.weight_init` | Set to `None` | + +* `_hook_inject_noise_in_upload` + + | Modified attribute | Operation | + | ------------------ | -------------------------- | + | `ctx.model` | Inject noise to parameters | + diff --git a/federatedscope/core/trainers/__init__.py b/federatedscope/core/trainers/__init__.py index de9221ee1..072d6369d 100644 --- a/federatedscope/core/trainers/__init__.py +++ b/federatedscope/core/trainers/__init__.py @@ -1,6 +1,7 @@ from federatedscope.core.trainers.base_trainer import BaseTrainer from federatedscope.core.trainers.trainer import Trainer from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer +from federatedscope.core.trainers.tf_trainer import GeneralTFTrainer from federatedscope.core.trainers.trainer_multi_model import \ GeneralMultiModelTrainer from federatedscope.core.trainers.trainer_pFedMe import wrap_pFedMeTrainer @@ -15,5 +16,5 @@ 'Trainer', 'Context', 'GeneralTorchTrainer', 'GeneralMultiModelTrainer', 'wrap_pFedMeTrainer', 'wrap_DittoTrainer', 'FedEMTrainer', 'wrap_fedprox_trainer', 'wrap_nbafl_trainer', 'wrap_nbafl_server', - 'BaseTrainer' + 'BaseTrainer', 'GeneralTFTrainer' ] diff --git a/federatedscope/core/trainers/base_trainer.py b/federatedscope/core/trainers/base_trainer.py index 50bf57245..1d0637d42 100644 --- a/federatedscope/core/trainers/base_trainer.py +++ b/federatedscope/core/trainers/base_trainer.py @@ -1,4 +1,5 @@ import abc +import inspect class BaseTrainer(abc.ABC): @@ -24,6 +25,11 @@ def update(self, model_parameters, strict=False): def get_model_para(self): raise NotImplementedError - @abc.abstractmethod def print_trainer_meta_info(self): - raise NotImplementedError + """ + Returns: String contains meta information of Trainer. + """ + sign = inspect.signature(self.__init__).parameters.values() + meta_info = tuple([(val.name, getattr(self, val.name)) + for val in sign]) + return f'{self.__class__.__name__}{meta_info}' diff --git a/federatedscope/core/trainers/context.py b/federatedscope/core/trainers/context.py index e612339df..06dffea82 100644 --- a/federatedscope/core/trainers/context.py +++ b/federatedscope/core/trainers/context.py @@ -5,9 +5,8 @@ from federatedscope.core.auxiliaries.model_builder import \ get_trainable_para_names from federatedscope.core.auxiliaries.regularizer_builder import get_regularizer -from federatedscope.core.auxiliaries.enums import MODE -from federatedscope.core.auxiliaries.utils import calculate_batch_epoch_num -from federatedscope.core.data import ClientData +from federatedscope.core.trainers.enums import MODE +from federatedscope.core.trainers.utils import calculate_batch_epoch_num logger = logging.getLogger(__name__) @@ -46,7 +45,9 @@ def clear(self, lifecycle): class Context(LifecycleDict): - """Record and pass variables among different hook functions + """ + Record and pass variables among different hook functions. + Arguments: model: training model cfg: config @@ -55,54 +56,88 @@ class Context(LifecycleDict): init_dict (dict): a dict used to initialize the instance of Context init_attr (bool): if set up the static variables Note: - - The variables within an instance of class `Context` - can be set/get as an attribute. + - The variables within an instance of class `Context` can be set/get \ + as an attribute. ``` ctx.${NAME_VARIABLE} = ${VALUE_VARIABLE} ``` - where `${NAME_VARIABLE}` and `${VALUE_VARIABLE}` + where ``${NAME_VARIABLE}`` and ``${VALUE_VARIABLE}`` is the name and value of the variable. 
- - To achieve automatically lifecycle management, you can - wrap the variable with `CtxVar` and a lifecycle parameter + - To achieve automatically lifecycle management, you can \ + wrap the variable with ``CtxVar`` and a lifecycle parameter \ as follows ``` - ctx.${NAME_VARIABLE} = CtxVar(${VALUE_VARIABLE}, ${LFECYCLE}) + ctx.${NAME_VARIABLE} = CtxVar(${VALUE_VARIABLE}, ${LIFECYCLE}) ``` - The parameter `${LFECYCLE}` can be chosen from `LIFECYCLE.BATCH`, - `LIFECYCLE.EPOCH` and `LIFECYCLE.ROUTINE`. - Then the variable `ctx.${NAME_VARIABLE}` will be deleted at + The parameter ``${LIFECYCLE}`` can be chosen from \ + ``LIFECYCLE.BATCH``, ``LIFECYCLE.EPOCH`` and ``LIFECYCLE.ROUTINE``. \ + Then the variable ``ctx.${NAME_VARIABLE}`` will be deleted at \ the end of the corresponding stage - - `LIFECYCLE.BATCH`: the variables will + - ``LIFECYCLE.BATCH``: the variables will \ be deleted after running a batch - - `LIFECYCLE.EPOCH`: the variables will be + - ``LIFECYCLE.EPOCH``: the variables will be \ deleted after running a epoch - - `LIFECYCLE.ROUTINE`: the variables will be + - ``LIFECYCLE.ROUTINE``: the variables will be \ deleted after running a routine More details please refer to our [tutorial](https://federatedscope.io/docs/trainer/). - - Context also maintains some special variables across - different routines, like - - cfg - - model - - data - - device - - ${split}_data: the dataset object of data split - named `${split}` - - ${split}_loader: the data loader object of data - split named `${split}` - - num_${split}_data: the number of examples within - the dataset named `${split}` + We classify and show the default attributes below: + + Data-related attributes + - ``ctx.data``: the raw data (not split) the trainer holds + - ``ctx.num_samples``: the number of samples used in training + - ``ctx.train_data``, ``ctx.val_data``, ``ctx.test_data``: the \ + split data the trainer holds + - ``ctx.train_loader``, ``ctx.val_loader``, ``ctx.test_loader``: \ + the DataLoader of each split data + - ``ctx.num_train_data``, ``ctx.num_val_data``, \ + ``ctx.num_test_data``: the number of samples of the split data \ + Model-related attributes + - ``ctx.model``: the model used + - ``ctx.models``: the multi models if use + - ``ctx.mirrored_models``: the mirrored models + - ``ctx.trainable_para_names``: the trainable parameter names of \ + the model + Optimizer-related attributes + - ``ctx.optimizer``: see ``torch.optim`` + - ``ctx.scheduler``: decays the learning rate of each parameter group + - ``ctx.criterion``: loss/criterion function + - ``ctx.regularizer``: regular terms + - ``ctx.grad_clip``: gradient clipping + Mode-related attributes + - ``ctx.cur_mode``: mode of trainer, which is one of ``['train', \ + 'val', 'test']`` + - ``ctx.mode_stack``: stack of mode, only used for switching mode + - ``ctx.cur_split``: split of data, which is one of ``['train', \ + 'val', 'test']`` (Note: use ``train`` data in ``test`` mode is \ + allowed) + - ``ctx.split_stack``: stack of split, only used for switching data \ + split + Metric-related attributes + - ``ctx.loss_batch_total``: Loss of current batch + - ``ctx.loss_regular_total``: Loss of regular term + - ``ctx.y_true``: true label of batch data + - ``ctx.y_prob``: output of the model with batch data as input + - ``ctx.ys_true``: true label of data + - ``ctx.ys_prob``: output of the model + - ``ctx.eval_metrics``: evaluation metrics calculated by \ + ``ctx.monitor`` + - ``ctx.monitor``: used for monitor trainer's behavior and statistics + Other (statistics) 
attributes (@property, query from ``cfg`` if not \ + set) + - ``ctx.cfg``: configuration of FL course + - ``ctx.device``: current device, such as ``cpu`` and ``gpu0``. + - ``ctx.num_train_batch_last_epoch``, \ + ``ctx.num_total_train_batch``: the number of batch + - ``ctx.num_train_epoch``, ``ctx.num_val_epoch``, \ + ``ctx.num_test_epoch``: the number of epoch in each data split + - ``ctx.num_train_batch``, ``ctx.num_val_batch``, \ + ``ctx.num_test_batch``: the number of batch in each data split """ - def __init__(self, - model, - cfg, - data=None, - device=None, - init_dict=None, - init_attr=True): - super(Context, self).__init__(init_dict) + def __init__(self, model, cfg, data=None, device=None): + super(Context, self).__init__({}) self.cfg = cfg self.model = model @@ -117,19 +152,15 @@ def __init__(self, self.lifecycles = collections.defaultdict(set) - if init_attr: - # setup static variables for training/evaluation - self.setup_vars() - - def setup_vars(self): + # Setup optimize-related context variable if self.cfg.backend == 'torch': self.trainable_para_names = get_trainable_para_names(self.model) + # TODO: make `criterion` and `regularizer` @property and cached + # to compare whether changes happen self.criterion = get_criterion(self.cfg.criterion.type, self.device) self.regularizer = get_regularizer(self.cfg.regularizer.type) self.grad_clip = self.cfg.grad.grad_clip - if isinstance(self.data, ClientData): - self.data.setup(self.cfg) elif self.cfg.backend == 'tensorflow': self.trainable_para_names = self.model.trainable_variables() self.criterion = None @@ -137,31 +168,91 @@ def setup_vars(self): self.optimizer = None self.grad_clip = None - # Process training data - if self.get('train_data', None) is not None or self.get( - 'train_loader', None) is not None: - # Calculate the number of update steps during training given the - # local_update_steps - self.num_train_batch, self.num_train_batch_last_epoch, \ - self.num_train_epoch, self.num_total_train_batch = \ + # Train related property, query from `cfg` if not set + @property + def num_train_batch(self): + if self.get('num_train_batch'): + return self.get('num_train_batch') + return self._calculate_batch_epoch_num(mode='train')[0] + + @property + def num_train_batch_last_epoch(self): + if self.get('num_train_batch_last_epoch'): + return self.get('num_train_batch_last_epoch') + return self._calculate_batch_epoch_num(mode='train')[1] + + @property + def num_train_epoch(self): + if self.get('num_train_epoch'): + return self.get('num_train_epoch') + return self._calculate_batch_epoch_num(mode='train')[2] + + @property + def num_total_train_batch(self): + if self.get('num_total_train_batch'): + return self.get('num_total_train_batch') + return self._calculate_batch_epoch_num(mode='train')[3] + + # Val related property, query from `cfg` if not set + @property + def num_val_batch(self): + if self.get('num_val_batch'): + return self.get('num_val_batch') + return self._calculate_batch_epoch_num(mode='val')[0] + + @property + def num_val_epoch(self): + if self.get('num_val_epoch'): + return self.get('num_val_epoch') + return self._calculate_batch_epoch_num(mode='val')[2] + + # Test related property, query from `cfg` if not set + @property + def num_test_batch(self): + if self.get('num_test_batch'): + return self.get('num_test_batch') + return self._calculate_batch_epoch_num(mode='test')[0] + + @property + def num_test_epoch(self): + if self.get('num_test_epoch'): + return self.get('num_test_epoch') + return 
self._calculate_batch_epoch_num(mode='test')[2] + + def _calculate_batch_epoch_num(self, mode='train'): + if self.cur_mode is not None and self.cur_mode != mode: + logger.warning( + f'cur_mode `{self.cur_mode}` mismatch mode `{mode}`, ' + f'will use `{mode}` to calculate `ctx.var`.') + if self.cur_split is None: + logger.warning( + f'cur_split `{self.cur_split}` not found in data_split, ' + f'will use `train` split to calculate `ctx.var`.') + cur_split = 'train' + else: + cur_split = self.cur_split + + num_batch_last_epoch, num_total_batch = None, None + if mode in ['train', 'finetune']: + num_batch, num_batch_last_epoch, num_epoch, num_total_batch = \ calculate_batch_epoch_num( self.cfg.train.local_update_steps, - self.cfg.train.batch_or_epoch, self.num_train_data, + self.cfg.train.batch_or_epoch, + self.get(f'num_{cur_split}_data'), self.cfg.dataloader.batch_size, self.cfg.dataloader.drop_last) + elif mode in ['val', 'test']: + num_epoch = 1 + num_batch = self.get(f'num_{cur_split}_data' + ) // self.cfg.dataloader.batch_size + int( + not self.cfg.dataloader.drop_last + and bool( + self.get(f'num_{cur_split}_data') % + self.cfg.dataloader.batch_size)) + else: + raise ValueError(f'Invalid mode {mode}.') - # Process evaluation data - for mode in ["val", "test"]: - setattr(self, "num_{}_epoch".format(mode), 1) - if self.get("{}_data".format(mode)) is not None or self.get( - "{}_loader".format(mode)) is not None: - setattr( - self, "num_{}_batch".format(mode), - getattr(self, "num_{}_data".format(mode)) // - self.cfg.dataloader.batch_size + - int(not self.cfg.dataloader.drop_last and bool( - getattr(self, "num_{}_data".format(mode)) % - self.cfg.dataloader.batch_size))) + return num_batch, num_batch_last_epoch, num_epoch, num_total_batch def track_mode(self, mode): self.mode_stack.append(mode) @@ -202,7 +293,7 @@ def check_split(self, target_split_name, skip=False): logger.warning( f"No {target_split_name}_data or" f" {target_split_name}_loader in the trainer, " - f"will skip evaluation" + f"will skip evaluation." f"If this is not the case you want, please check " f"whether there is typo for the name") return False @@ -212,25 +303,33 @@ def check_split(self, target_split_name, skip=False): else: return True + def merge_from_dict(self, other_dict): + for key, value in other_dict.items(): + setattr(self, key, value) + class CtxVar(object): - """Basic variable class + """ + Basic variable class + Arguments: lifecycle: specific lifecycle of the attribute """ - LIEFTCYCLES = ["batch", "epoch", "routine", None] + LIFECYCLES = ["batch", "epoch", "routine", None] def __init__(self, obj, lifecycle=None): - assert lifecycle in CtxVar.LIEFTCYCLES + assert lifecycle in CtxVar.LIFECYCLES self.obj = obj self.lifecycle = lifecycle def lifecycle(lifecycle): - """Manage the lifecycle of the variables within context, + """ + Manage the lifecycle of the variables within context, \ and blind these operations from user. 
- Args: + + Arguments: lifecycle: the type of lifecycle, choose from "batch/epoch/routine" """ if lifecycle == "routine": diff --git a/federatedscope/core/auxiliaries/enums.py b/federatedscope/core/trainers/enums.py similarity index 100% rename from federatedscope/core/auxiliaries/enums.py rename to federatedscope/core/trainers/enums.py diff --git a/federatedscope/core/trainers/tf_trainer.py b/federatedscope/core/trainers/tf_trainer.py index 549311128..adda494c1 100644 --- a/federatedscope/core/trainers/tf_trainer.py +++ b/federatedscope/core/trainers/tf_trainer.py @@ -1,11 +1,14 @@ -import tensorflow as tf +try: + import tensorflow as tf +except ImportError: + tf = None import numpy as np from federatedscope.core.trainers import Trainer -from federatedscope.core.auxiliaries.enums import MODE +from federatedscope.core.trainers.enums import MODE from federatedscope.core.auxiliaries.utils import batch_iter from federatedscope.core.trainers.context import CtxVar -from federatedscope.core.auxiliaries.enums import LIFECYCLE +from federatedscope.core.trainers.enums import LIFECYCLE class GeneralTFTrainer(Trainer): @@ -25,7 +28,6 @@ def train(self, target_data_split_name="train", hooks_set=None): def parse_data(self, data): """Populate "{}_data", "{}_loader" and "num_{}_data" for different modes - """ init_dict = dict() if isinstance(data, dict): @@ -69,6 +71,20 @@ def register_default_hooks_eval(self): self.register_hook_in_eval(self._hook_on_fit_end, "on_fit_end") def _hook_on_fit_start_init(self, ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.model`` Move to `ctx.device` + ``ctx.loss_batch_total`` Initialize to 0 + ``ctx.loss_regular_total`` Initialize to 0 + ``ctx.num_samples`` Initialize to 0 + ``ctx.ys_true`` Initialize to ``[]`` + ``ctx.ys_prob`` Initialize to ``[]`` + ================================== =========================== + """ # prepare model ctx.model.to(ctx.device) @@ -80,11 +96,29 @@ def _hook_on_fit_start_init(self, ctx): ctx.ys_prob = CtxVar([], LIFECYCLE.ROUTINE) def _hook_on_epoch_start(self, ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.{cur_split}_loader`` Initialize DataLoader + ================================== =========================== + """ # prepare dataloader setattr(ctx, "{}_loader".format(ctx.cur_split), batch_iter(ctx.get("{}_data".format(ctx.cur_split)))) def _hook_on_batch_start_init(self, ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.data_batch`` Initialize batch data + ================================== =========================== + """ # prepare data batch try: ctx.data_batch = next(ctx.get("{}_loader".format(ctx.cur_split))) @@ -92,7 +126,21 @@ def _hook_on_batch_start_init(self, ctx): raise StopIteration def _hook_on_batch_forward(self, ctx): - + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== 
=========================== + ``ctx.optimizer`` Initialize optimizer + ``ctx.batch_size`` Calculate batch size + ``ctx.loss_batch`` Calculate batch loss + ``ctx.model`` Forward propagation + ``ctx.y_true`` Get y_true from batch + ``ctx.y_prob`` Forward propagation to get \ + `y_prob` + ================================== =========================== + """ ctx.optimizer = ctx.model.optimizer ctx.batch_size = len(ctx.data_batch) @@ -120,6 +168,19 @@ def _hook_on_batch_backward(self, ctx): pass def _hook_on_batch_end(self, ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.num_samples`` Add ``ctx.batch_size`` + ``ctx.loss_batch_total`` Add batch loss + ``ctx.loss_regular_total`` Add batch regular loss + ``ctx.ys_true`` Append ``ctx.y_true`` + ``ctx.ys_prob`` Append ``ctx.ys_prob`` + ================================== =========================== + """ # TODO: the same with the torch_trainer # update statistics ctx.num_samples += ctx.batch_size @@ -131,12 +192,24 @@ def _hook_on_batch_end(self, ctx): ctx.ys_prob.append(ctx.y_prob.detach().cpu().numpy()) def _hook_on_fit_end(self, ctx): - """Evaluate metrics. - + """ + Evaluate metrics. + + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.ys_true`` Convert to `numpy.array` + ``ctx.ys_prob`` Convert to `numpy.array` + ``ctx.monitor`` Evaluate the results + ``ctx.eval_metrics`` Get evaluated results from \ + ``ctx.monitor`` + ================================== =========================== """ ctx.ys_true = CtxVar(np.concatenate(ctx.ys_true), LIFECYCLE.ROUTINE) ctx.ys_prob = CtxVar(np.concatenate(ctx.ys_prob), LIFECYCLE.ROUTINE) - results = self.metric_calculator.eval(ctx) + results = self.ctx.monitor.eval(ctx) setattr(ctx, 'eval_metrics', results) def update(self, model_parameters, strict=False): diff --git a/federatedscope/core/trainers/torch_trainer.py b/federatedscope/core/trainers/torch_trainer.py index a5c2a098a..c343cbabf 100644 --- a/federatedscope/core/trainers/torch_trainer.py +++ b/federatedscope/core/trainers/torch_trainer.py @@ -10,12 +10,12 @@ DataLoader = None Dataset = None -from federatedscope.core.auxiliaries.enums import MODE -from federatedscope.core.auxiliaries.enums import LIFECYCLE -from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer -from federatedscope.core.auxiliaries.scheduler_builder import get_scheduler +from federatedscope.core.trainers.enums import MODE, LIFECYCLE from federatedscope.core.trainers.trainer import Trainer from federatedscope.core.trainers.context import CtxVar +from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer +from federatedscope.core.auxiliaries.scheduler_builder import get_scheduler +from federatedscope.core.data import ClientData from federatedscope.core.data.wrap_dataset import WrapDataset from federatedscope.core.auxiliaries.dataloader_builder import get_dataloader from federatedscope.core.auxiliaries.ReIterator import ReIterator @@ -32,10 +32,20 @@ def get_model_para(self): self.ctx.model.state_dict() if self.cfg.federate. share_local_model else self.ctx.model.cpu().state_dict()) + def setup_data(self, ctx): + """ + Initialization data by ``cfg``. 
+ """ + if isinstance(ctx.data, ClientData): + ctx.data.setup(ctx.cfg) + else: + logger.warning(f'The data type should be `ClientData` to ' + f'enable new `config`, but got ' + f'{type(ctx.data)} instead.') + def parse_data(self, data): """Populate "${split}_data", "${split}_loader" and "num_${ split}_data" for different data splits - """ init_dict = dict() if isinstance(data, dict): @@ -135,6 +145,22 @@ def register_default_hooks_eval(self): self.register_hook_in_eval(self._hook_on_fit_end, "on_fit_end") def _hook_on_fit_start_init(self, ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.model`` Move to ``ctx.device`` + ``ctx.optimizer`` Initialize by ``ctx.cfg`` + ``ctx.scheduler`` Initialize by ``ctx.cfg`` + ``ctx.loss_batch_total`` Initialize to 0 + ``ctx.loss_regular_total`` Initialize to 0 + ``ctx.num_samples`` Initialize to 0 + ``ctx.ys_true`` Initialize to ``[]`` + ``ctx.ys_prob`` Initialize to ``[]`` + ================================== =========================== + """ # prepare model and optimizer ctx.model.to(ctx.device) @@ -158,17 +184,35 @@ def _hook_on_fit_start_init(self, ctx): ctx.ys_prob = CtxVar([], LIFECYCLE.ROUTINE) def _hook_on_fit_start_calculate_model_size(self, ctx): - if not isinstance(self.ctx.monitor, Monitor): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.monitor`` Track model size + ================================== =========================== + """ + if not isinstance(ctx.monitor, Monitor): logger.warning( f"The trainer {type(self)} does contain a valid monitor, " f"this may be caused by initializing trainer subclasses " f"without passing a valid monitor instance." 
f"Plz check whether this is you want.") return - if self.ctx.monitor.total_model_size == 0: - self.ctx.monitor.track_model_size(ctx.models) + if ctx.monitor.total_model_size == 0: + ctx.monitor.track_model_size(ctx.models) def _hook_on_epoch_start(self, ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.{ctx.cur_split}_loader`` Initialize DataLoader + ================================== =========================== + """ # prepare dataloader if ctx.get("{}_loader".format(ctx.cur_split)) is None: loader = get_dataloader( @@ -183,6 +227,15 @@ def _hook_on_epoch_start(self, ctx): ctx.get("{}_loader".format(ctx.cur_split)).reset() def _hook_on_batch_start_init(self, ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.data_batch`` Initialize batch data + ================================== =========================== + """ # prepare data batch try: ctx.data_batch = CtxVar( @@ -192,6 +245,18 @@ def _hook_on_batch_start_init(self, ctx): raise StopIteration def _hook_on_batch_forward(self, ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.y_true`` Move to `ctx.device` + ``ctx.y_prob`` Forward propagation get y_prob + ``ctx.loss_batch`` Calculate the loss + ``ctx.batch_size`` Get the batch_size + ================================== =========================== + """ x, label = [_.to(ctx.device) for _ in ctx.data_batch] pred = ctx.model(x) if len(label.size()) == 0: @@ -204,25 +269,29 @@ def _hook_on_batch_forward(self, ctx): def _hook_on_batch_forward_flop_count(self, ctx): """ - the monitoring hook to calculate the flops during the fl course - - Note: for customized cases that the forward process is not only - based on ctx.model, please override this function (inheritance - case) or replace this hook (plug-in case) - - :param ctx: - :return: + The monitoring hook to calculate the flops during the fl course + + Note: + For customized cases that the forward process is not only \ + based on ctx.model, please override this function (inheritance \ + case) or replace this hook (plug-in case) + + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.monitor`` Track average flops + ================================== =========================== """ - if not isinstance(self.ctx.monitor, Monitor): + if not isinstance(ctx.monitor, Monitor): logger.warning( f"The trainer {type(self)} does contain a valid monitor, " f"this may be caused by initializing trainer subclasses " f"without passing a valid monitor instance." 
- f"Plz check whether this is you want.") + f"Please check whether this is you want.") return - if self.cfg.eval.count_flops and self.ctx.monitor.flops_per_sample \ - == 0: + if self.cfg.eval.count_flops and ctx.monitor.flops_per_sample == 0: # calculate the flops_per_sample try: x, y = [_.to(ctx.device) for _ in ctx.data_batch] @@ -235,9 +304,9 @@ def _hook_on_batch_forward_flop_count(self, ctx): "by internal model nums as self.mirrored_models=True." "if this is not the case you want, " "please customize the count hook") - self.ctx.monitor.track_avg_flops(flops_one_batch, - ctx.batch_size) + ctx.monitor.track_avg_flops(flops_one_batch, ctx.batch_size) except: + # Raise warning at the first failure logger.warning( "current flop count implementation is for general " "trainer case: " @@ -245,21 +314,42 @@ def _hook_on_batch_forward_flop_count(self, ctx): "2) the ctx.model takes only x as input." "Please check the forward format or implement your own " "flop_count function") - self.ctx.monitor.flops_per_sample = -1 # warning at the - # first failure + ctx.monitor.flops_per_sample = -1 # by default, we assume the data has the same input shape, # thus simply multiply the flops to avoid redundant forward - self.ctx.monitor.total_flops +=\ - self.ctx.monitor.flops_per_sample * ctx.batch_size + ctx.monitor.total_flops += ctx.monitor.flops_per_sample * \ + ctx.batch_size def _hook_on_batch_forward_regularizer(self, ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.loss_regular`` Calculate the regular loss + ``ctx.loss_task`` Sum the ``ctx.loss_regular`` \ + and ``ctx.loss`` + ================================== =========================== + """ ctx.loss_regular = CtxVar( self.cfg.regularizer.mu * ctx.regularizer(ctx), LIFECYCLE.BATCH) ctx.loss_task = CtxVar(ctx.loss_batch + ctx.loss_regular, LIFECYCLE.BATCH) def _hook_on_batch_backward(self, ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.optimizer`` Update by gradient + ``ctx.loss_task`` Backward propagation + ``ctx.scheduler`` Update by gradient + ================================== =========================== + """ ctx.optimizer.zero_grad() ctx.loss_task.backward() if ctx.grad_clip > 0: @@ -271,6 +361,19 @@ def _hook_on_batch_backward(self, ctx): ctx.scheduler.step() def _hook_on_batch_end(self, ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.num_samples`` Add ``ctx.batch_size`` + ``ctx.loss_batch_total`` Add batch loss + ``ctx.loss_regular_total`` Add batch regular loss + ``ctx.ys_true`` Append ``ctx.y_true`` + ``ctx.ys_prob`` Append ``ctx.ys_prob`` + ================================== =========================== + """ # update statistics ctx.num_samples += ctx.batch_size ctx.loss_batch_total += ctx.loss_batch.item() * ctx.batch_size @@ -280,12 +383,24 @@ def _hook_on_batch_end(self, ctx): ctx.ys_prob.append(ctx.y_prob.detach().cpu().numpy()) def _hook_on_fit_end(self, ctx): - """Evaluate metrics. - + """ + Evaluate metrics. 
+ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.ys_true`` Convert to ``numpy.array`` + ``ctx.ys_prob`` Convert to ``numpy.array`` + ``ctx.monitor`` Evaluate the results + ``ctx.eval_metrics`` Get evaluated results from \ + ``ctx.monitor`` + ================================== =========================== """ ctx.ys_true = CtxVar(np.concatenate(ctx.ys_true), LIFECYCLE.ROUTINE) ctx.ys_prob = CtxVar(np.concatenate(ctx.ys_prob), LIFECYCLE.ROUTINE) - results = self.metric_calculator.eval(ctx) + results = ctx.monitor.eval(ctx) setattr(ctx, 'eval_metrics', results) def save_model(self, path, cur_round=-1): @@ -305,8 +420,8 @@ def load_model(self, path): raise ValueError("The file {} does NOT exist".format(path)) def discharge_model(self): - """Discharge the model from GPU device - + """ + Discharge the model from GPU device """ # Avoid memory leak if not self.cfg.federate.share_local_model: diff --git a/federatedscope/core/trainers/trainer.py b/federatedscope/core/trainers/trainer.py index 409b6631f..ad63c02aa 100644 --- a/federatedscope/core/trainers/trainer.py +++ b/federatedscope/core/trainers/trainer.py @@ -3,23 +3,11 @@ import logging from federatedscope.core.trainers.base_trainer import BaseTrainer -from federatedscope.core.auxiliaries.enums import MODE -from federatedscope.core.auxiliaries.enums import LIFECYCLE +from federatedscope.core.trainers.enums import MODE, LIFECYCLE from federatedscope.core.auxiliaries.decorators import use_diff -from federatedscope.core.auxiliaries.utils import format_log_hooks -from federatedscope.core.auxiliaries.utils import filter_by_specified_keywords -from federatedscope.core.trainers.context import Context -from federatedscope.core.trainers.context import CtxVar -from federatedscope.core.trainers.context import lifecycle -from federatedscope.core.monitors.metric_calculator import MetricCalculator - -try: - import torch - from torch.utils.data import DataLoader, Dataset -except ImportError: - torch = None - DataLoader = None - Dataset = None +from federatedscope.core.trainers.utils import format_log_hooks, \ + filter_by_specified_keywords +from federatedscope.core.trainers.context import Context, CtxVar, lifecycle logger = logging.getLogger(__name__) @@ -41,18 +29,15 @@ def __init__(self, config, only_for_eval=False, monitor=None): - self.cfg = config - self.metric_calculator = MetricCalculator(config.eval.metrics) - - self.ctx = Context(model, - self.cfg, - data, - device, - init_dict=self.parse_data(data)) - - if monitor is None: - logger.warning( - f"Will not use monitor in trainer with class {type(self)}") + self._cfg = config + + self.ctx = Context(model, self.cfg, data, device) + + # Parse data and setup init vars in ctx + self._setup_data_related_var_in_ctx(self.ctx) + + assert monitor is not None, \ + f"Monitor not found in trainer with class {type(self)}" self.ctx.monitor = monitor # the "model_nums", and "models" are used for multi-model case and # model size calculation @@ -85,9 +70,38 @@ def __init__(self, # once for better logs readability pass + @property + def cfg(self): + return self._cfg + + @cfg.setter + def cfg(self, new_cfg): + self._cfg = new_cfg + self._setup_data_related_var_in_ctx(self.ctx) + def parse_data(self, data): + """ + Populate ``${split}_data``, ``${split}_loader`` and \ + ``num_${split}_data`` for different data splits + """ + raise 
NotImplementedError + + def setup_data(self, ctx): + """ + Initialization data by ``cfg``. + """ pass + def _setup_data_related_var_in_ctx(self, ctx): + """ + Populate ``${split}_data``, ``${split}_loader`` and \ + ``num_${split}_data`` for different data splits, and setup init var \ + in ctx. + """ + self.setup_data(ctx) + init_dict = self.parse_data(ctx.data) + ctx.merge_from_dict(init_dict) + def register_default_hooks_train(self): pass @@ -265,7 +279,8 @@ def _run_routine(self, mode, hooks_set, dataset_name=None): @lifecycle(LIFECYCLE.EPOCH) def _run_epoch(self, hooks_set): - for epoch_i in range(self.ctx.get(f"num_{self.ctx.cur_split}_epoch")): + for epoch_i in range( + getattr(self.ctx, f"num_{self.ctx.cur_split}_epoch")): self.ctx.cur_epoch_i = CtxVar(epoch_i, "epoch") for hook in hooks_set["on_epoch_start"]: @@ -278,7 +293,8 @@ def _run_epoch(self, hooks_set): @lifecycle(LIFECYCLE.BATCH) def _run_batch(self, hooks_set): - for batch_i in range(self.ctx.get(f"num_{self.ctx.cur_split}_batch")): + for batch_i in range( + getattr(self.ctx, f"num_{self.ctx.cur_split}_batch")): self.ctx.cur_batch_i = CtxVar(batch_i, LIFECYCLE.BATCH) for hook in hooks_set["on_batch_start"]: @@ -301,26 +317,26 @@ def _run_batch(self, hooks_set): break def update(self, model_parameters, strict=False): - ''' + """ Called by the FL client to update the model parameters Arguments: model_parameters (dict): {model_name: model_val} strict (bool): ensure the k-v paris are strictly same - ''' + """ pass def get_model_para(self): - ''' + """ :return: model_parameters (dict): {model_name: model_val} - ''' + """ pass def print_trainer_meta_info(self): - ''' + """ print some meta info for code-users, e.g., model type; the para names will be filtered out, etc., - ''' + """ logger.info(f"Model meta-info: {type(self.ctx.model)}.") logger.debug(f"Model meta-info: {self.ctx.model}.") # logger.info(f"Data meta-info: {self.ctx['data']}.") @@ -348,7 +364,7 @@ def print_trainer_meta_info(self): t{format_log_hooks(self.hooks_in_eval)}") def _param_filter(self, state_dict, filter_keywords=None): - ''' + """ model parameter filter when transmit between local and gloabl, which is useful in personalization. e.g., setting cfg.personalization.local_param= ['bn', 'norms'] @@ -362,7 +378,7 @@ def _param_filter(self, state_dict, filter_keywords=None): Returns: state_dict (dict): remove the keys that match any of the given keywords. 
- ''' + """ if self.cfg.federate.method in ["local", "global"]: return {} diff --git a/federatedscope/core/trainers/trainer_Ditto.py b/federatedscope/core/trainers/trainer_Ditto.py index e4e4a0200..f30b50f46 100644 --- a/federatedscope/core/trainers/trainer_Ditto.py +++ b/federatedscope/core/trainers/trainer_Ditto.py @@ -6,7 +6,7 @@ from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer from federatedscope.core.optimizer import wrap_regularized_optimizer -from federatedscope.core.auxiliaries.utils import calculate_batch_epoch_num +from federatedscope.core.trainers.utils import calculate_batch_epoch_num from typing import Type logger = logging.getLogger(__name__) @@ -34,16 +34,17 @@ def wrap_DittoTrainer( trigger='on_fit_start', insert_pos=-1) base_trainer.register_hook_in_train( - new_hook=hook_on_fit_start_set_regularized_para, + new_hook=_hook_on_fit_start_set_regularized_para, trigger="on_fit_start", insert_pos=0) base_trainer.register_hook_in_train( - new_hook=hook_on_batch_start_switch_model, + new_hook=_hook_on_batch_start_switch_model, trigger="on_batch_start", insert_pos=0) - base_trainer.register_hook_in_train(new_hook=hook_on_batch_forward_cnt_num, - trigger="on_batch_forward", - insert_pos=-1) + base_trainer.register_hook_in_train( + new_hook=_hook_on_batch_forward_cnt_num, + trigger="on_batch_forward", + insert_pos=-1) base_trainer.register_hook_in_train(new_hook=_hook_on_batch_end_flop_count, trigger="on_batch_end", insert_pos=-1) @@ -52,18 +53,18 @@ def wrap_DittoTrainer( insert_pos=-1) # evaluation is based on the local personalized model base_trainer.register_hook_in_eval( - new_hook=hook_on_fit_start_switch_local_model, + new_hook=_hook_on_fit_start_switch_local_model, trigger="on_fit_start", insert_pos=0) base_trainer.register_hook_in_eval( - new_hook=hook_on_fit_end_switch_global_model, + new_hook=_hook_on_fit_end_switch_global_model, trigger="on_fit_end", insert_pos=-1) - base_trainer.register_hook_in_train(new_hook=hook_on_fit_end_free_cuda, + base_trainer.register_hook_in_train(new_hook=_hook_on_fit_end_free_cuda, trigger="on_fit_end", insert_pos=-1) - base_trainer.register_hook_in_eval(new_hook=hook_on_fit_end_free_cuda, + base_trainer.register_hook_in_eval(new_hook=_hook_on_fit_end_free_cuda, trigger="on_fit_end", insert_pos=-1) @@ -117,7 +118,23 @@ def init_Ditto_ctx(base_trainer): ctx.num_train_epoch += ctx.num_train_epoch_for_local_model -def hook_on_fit_start_set_regularized_para(ctx): +def _hook_on_fit_start_set_regularized_para(ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.global_model`` Move to ``ctx.device`` and set \ + to ``train`` mode + ``ctx.local_model`` Move to ``ctx.device`` and set \ + to ``train`` mode + ``ctx.optimizer_for_global_model`` Initialize by ``ctx.cfg`` and \ + wrapped by ``wrap_regularized_optimizer`` + ``ctx.optimizer_for_local_model`` Initialize by ``ctx.cfg`` and \ + set compared parameter group + ================================== =========================== + """ # set the compared model data for local personalized model ctx.global_model.to(ctx.device) ctx.local_model.to(ctx.device) @@ -140,12 +157,34 @@ def hook_on_fit_start_set_regularized_para(ctx): def _hook_on_fit_start_clean(ctx): + """ + Note: + The modified attributes and according 
operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.optimizer`` Delete + ``ctx.num_..._local_model_train`` Initialize to 0 + ================================== =========================== + """ # remove the unnecessary optimizer del ctx.optimizer ctx.num_samples_local_model_train = 0 def _hook_on_fit_end_calibrate(ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.num_samples`` Minus \ + ``ctx.num_samples_local_model_train`` + ``ctx.eval_metrics`` Record ``train_total`` and \ + ``train_total_local_model`` + ================================== =========================== + """ # make the num_samples_train only related to the global model. # (num_samples_train will be used in aggregation process) ctx.num_samples -= ctx.num_samples_local_model_train @@ -155,17 +194,48 @@ def _hook_on_fit_end_calibrate(ctx): def _hook_on_batch_end_flop_count(ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.monitor`` Monitor total flops + ================================== =========================== + """ # besides the normal forward flops, the regularization adds the cost of # number of model parameters ctx.monitor.total_flops += ctx.monitor.total_model_size / 2 -def hook_on_batch_forward_cnt_num(ctx): +def _hook_on_batch_forward_cnt_num(ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.num_..._local_model_train`` Add `ctx.batch_size` + ================================== =========================== + """ if ctx.use_local_model_current: ctx.num_samples_local_model_train += ctx.batch_size -def hook_on_batch_start_switch_model(ctx): +def _hook_on_batch_start_switch_model(ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.use_local_model_current`` Set to ``True`` or ``False`` + ``ctx.model`` Set to ``ctx.local_model`` or \ + ``ctx.global_model`` + ``ctx.optimizer`` Set to \ + ``ctx.optimizer_for_local_model`` or ``ctx.optimizer_for_global_model`` + ================================== =========================== + """ if ctx.cfg.train.batch_or_epoch == 'batch': if ctx.cur_epoch_i == (ctx.num_train_epoch - 1): ctx.use_local_model_current = \ @@ -205,15 +275,44 @@ def hook_on_batch_start_switch_model(ctx): # ctx.model = ctx.global_model -def hook_on_fit_start_switch_local_model(ctx): +def _hook_on_fit_start_switch_local_model(ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.model`` Set to ``ctx.local_model`` and \ + set to ``eval`` mode + ================================== =========================== + """ 
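    # A minimal usage sketch of the Ditto plug-in wrapped above, assuming the
    # keyword names of GeneralTorchTrainer's constructor (model/data/device/
    # config/monitor) and that the hooks are registered in place; the real
    # signatures may differ:
    #
    #   from federatedscope.core.trainers import GeneralTorchTrainer
    #   from federatedscope.core.trainers.trainer_Ditto import \
    #       wrap_DittoTrainer
    #
    #   base_trainer = GeneralTorchTrainer(model=model, data=data,
    #                                      device=device, config=cfg,
    #                                      monitor=monitor)
    #   wrap_DittoTrainer(base_trainer)  # registers the hooks above
    #   base_trainer.train()             # updates global + personalized model
    #   base_trainer.evaluate(target_data_split_name='test')  # local model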
ctx.model = ctx.local_model ctx.model.eval() -def hook_on_fit_end_switch_global_model(ctx): +def _hook_on_fit_end_switch_global_model(ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.model`` Set to ``ctx.global_model`` + ================================== =========================== + """ ctx.model = ctx.global_model -def hook_on_fit_end_free_cuda(ctx): +def _hook_on_fit_end_free_cuda(ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.global_model`` Move to ``cpu`` + ``ctx.local_model`` Move to ``cpu`` + ================================== =========================== + """ ctx.global_model.to(torch.device("cpu")) ctx.local_model.to(torch.device("cpu")) diff --git a/federatedscope/core/trainers/trainer_FedEM.py b/federatedscope/core/trainers/trainer_FedEM.py index e6da6df63..2a391d3f9 100644 --- a/federatedscope/core/trainers/trainer_FedEM.py +++ b/federatedscope/core/trainers/trainer_FedEM.py @@ -4,7 +4,7 @@ import torch from torch.nn.functional import softmax as f_softmax -from federatedscope.core.auxiliaries.enums import LIFECYCLE +from federatedscope.core.trainers.enums import LIFECYCLE from federatedscope.core.trainers.context import CtxVar from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer from federatedscope.core.trainers.trainer_multi_model import \ @@ -13,8 +13,8 @@ class FedEMTrainer(GeneralMultiModelTrainer): """ - The FedEM implementation, "Federated Multi-Task Learning under a - Mixture of Distributions (NeurIPS 2021)" + The FedEM implementation, "Federated Multi-Task Learning under a \ + Mixture of Distributions (NeurIPS 2021)" \ based on the Algorithm 1 in their paper and official codes: https://github.com/omarfoq/FedEM """ @@ -60,7 +60,7 @@ def register_multiple_model_hooks(self): # First register hooks for model 0 # ---------------- train hooks ----------------------- self.register_hook_in_train( - new_hook=self.hook_on_fit_start_mixture_weights_update, + new_hook=self._hook_on_fit_start_mixture_weights_update, trigger="on_fit_start", insert_pos=0) # insert at the front self.register_hook_in_train( @@ -72,21 +72,21 @@ def register_multiple_model_hooks(self): trigger="on_fit_end", insert_pos=-1) self.register_hook_in_train( - new_hook=self.hook_on_batch_forward_weighted_loss, + new_hook=self._hook_on_batch_forward_weighted_loss, trigger="on_batch_forward", insert_pos=-1) self.register_hook_in_train( - new_hook=self.hook_on_batch_start_track_batch_idx, + new_hook=self._hook_on_batch_start_track_batch_idx, trigger="on_batch_start", insert_pos=0) # insert at the front # ---------------- eval hooks ----------------------- self.register_hook_in_eval( - new_hook=self.hook_on_batch_end_gather_loss, + new_hook=self._hook_on_batch_end_gather_loss, trigger="on_batch_end", insert_pos=0 ) # insert at the front, (we need gather the loss before clean it) self.register_hook_in_eval( - new_hook=self.hook_on_batch_start_track_batch_idx, + new_hook=self._hook_on_batch_start_track_batch_idx, trigger="on_batch_start", insert_pos=0) # insert at the front # replace the original evaluation into the ensemble one @@ -106,30 +106,67 @@ def register_multiple_model_hooks(self): for
_ in range(1, self.model_nums) ]) - def hook_on_batch_start_track_batch_idx(self, ctx): + def _hook_on_batch_start_track_batch_idx(self, ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.cur_batch_idx`` Increment modulo \ + ``num_train_batch`` + ================================== =========================== + """ # for both train & eval ctx.cur_batch_idx = (self.ctx.cur_batch_idx + 1) % self.ctx.num_train_batch - def hook_on_batch_forward_weighted_loss(self, ctx): + def _hook_on_batch_forward_weighted_loss(self, ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.loss_batch`` Multiply by \ + ``weights_internal_models`` + ================================== =========================== + """ # for only train ctx.loss_batch *= self.weights_internal_models[ctx.cur_model_idx] - def hook_on_batch_end_gather_loss(self, ctx): + def _hook_on_batch_end_gather_loss(self, ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.all_losses_model_batch`` Gather loss + ================================== =========================== + """ # for only eval # before clean the loss_batch; we record it # for further weights_data_sample update ctx.all_losses_model_batch[ctx.cur_model_idx][ ctx.cur_batch_idx] = ctx.loss_batch.item() - def hook_on_fit_start_mixture_weights_update(self, ctx): + def _hook_on_fit_start_mixture_weights_update(self, ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.mode`` Evaluate + ================================== =========================== + """ # for only train if ctx.cur_model_idx != 0: # do the mixture_weights_update once pass else: # gathers losses for all sample in iterator - # for each internal model, calling *evaluate()* + # for each internal model, calling `evaluate()` for model_idx in range(self.model_nums): self._switch_model_ctx(model_idx) self.evaluate(target_data_split_name="train") @@ -144,16 +181,46 @@ def hook_on_fit_start_mixture_weights_update(self, ctx): self._switch_model_ctx(0) def _hook_on_fit_start_flop_count(self, ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.monitor`` Count total_flops + ================================== =========================== + """ self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * \ - self.model_nums * ctx.num_train_data + self.model_nums * ctx.num_train_data def _hook_on_fit_end_flop_count(self, ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.monitor`` Count total_flops +
================================== =========================== + """ self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * \ - self.model_nums * ctx.num_train_data + self.model_nums * ctx.num_train_data def _hook_on_fit_end_ensemble_eval(self, ctx): """ - Ensemble evaluation + Ensemble evaluation + + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.ys_prob_ensemble`` Ensemble ys_prob + ``ctx.ys_true`` Concatenate results + ``ctx.ys_prob`` Concatenate results + ``ctx.eval_metrics`` Get evaluated results from \ + ``ctx.monitor`` + ================================== =========================== """ if ctx.get("ys_prob_ensemble", None) is None: ctx.ys_prob_ensemble = CtxVar(0, LIFECYCLE.ROUTINE) @@ -166,4 +233,4 @@ def _hook_on_fit_end_ensemble_eval(self, ctx): ctx.ys_true = CtxVar(np.concatenate(ctx.ys_true), LIFECYCLE.ROUTINE) ctx.ys_prob = ctx.ys_prob_ensemble - ctx.eval_metrics = self.metric_calculator.eval(ctx) + ctx.eval_metrics = self.ctx.monitor.eval(ctx) diff --git a/federatedscope/core/trainers/trainer_fedprox.py b/federatedscope/core/trainers/trainer_fedprox.py index 89e02da65..c5cacf274 100644 --- a/federatedscope/core/trainers/trainer_fedprox.py +++ b/federatedscope/core/trainers/trainer_fedprox.py @@ -1,6 +1,8 @@ -from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer from typing import Type -from copy import deepcopy + +from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer +from federatedscope.core.trainers.trainer_nbafl import \ + _hook_record_initialization, _hook_del_initialization def wrap_fedprox_trainer( @@ -16,19 +18,19 @@ def wrap_fedprox_trainer( init_fedprox_ctx(base_trainer) # ---------------- action-level plug-in ----------------------- - base_trainer.register_hook_in_train(new_hook=record_initialization, + base_trainer.register_hook_in_train(new_hook=_hook_record_initialization, trigger='on_fit_start', insert_pos=-1) - base_trainer.register_hook_in_eval(new_hook=record_initialization, + base_trainer.register_hook_in_eval(new_hook=_hook_record_initialization, trigger='on_fit_start', insert_pos=-1) - base_trainer.register_hook_in_train(new_hook=del_initialization, + base_trainer.register_hook_in_train(new_hook=_hook_del_initialization, trigger='on_fit_end', insert_pos=-1) - base_trainer.register_hook_in_eval(new_hook=del_initialization, + base_trainer.register_hook_in_eval(new_hook=_hook_del_initialization, trigger='on_fit_end', insert_pos=-1) @@ -50,24 +52,3 @@ def init_fedprox_ctx(base_trainer): from federatedscope.core.auxiliaries.regularizer_builder import \ get_regularizer ctx.regularizer = get_regularizer(cfg.regularizer.type) - - -# ---------------------------------------------------------------------- # -# Additional functions for FedProx algorithm -# ---------------------------------------------------------------------- # - - -# Trainer -def record_initialization(ctx): - """Record the initialized weights within local updates - - """ - ctx.weight_init = deepcopy( - [_.data.detach() for _ in ctx.model.parameters()]) - - -def del_initialization(ctx): - """Clear the variable to avoid memory leakage - - """ - ctx.weight_init = None diff --git a/federatedscope/core/trainers/trainer_multi_model.py b/federatedscope/core/trainers/trainer_multi_model.py index b2fe60ce7..d230070ed 100644 --- 
a/federatedscope/core/trainers/trainer_multi_model.py +++ b/federatedscope/core/trainers/trainer_multi_model.py @@ -121,22 +121,21 @@ def init_multiple_models(self): def register_multiple_model_hooks(self): """ - By default, all internal models adopt the same hook_set. - ========================= Extension ============================= - Users can override this function to register customized hooks - for different internal models. - - Note: - for sequential mode, users can append interact_hook on - begin/end triggers such as - " -> (on_fit_end, _interact_to_other_models) -> " + By default, all internal models adopt the same hook_set. - for parallel mode, users can append interact_hook on any - trigger they want such as - " -> (on_xxx_point, _interact_to_other_models) -> " + Extension + Users can override this function to register customized hooks \ + for different internal models. - self.ctx, we must tell the running hooks which data_loader to - call and which num_samples to count + Note: + - for sequential mode, users can append interact_hook on \ + begin/end triggers such as \ + " -> (on_fit_end, _interact_to_other_models) -> " + - for parallel mode, users can append interact_hook on any \ + trigger they want such as \ + " -> (on_xxx_point, _interact_to_other_models) -> " + - we must tell the running hooks which data_loader to \ + call and which num_samples to count """ self.hooks_in_train_multiple_models.extend([ @@ -217,9 +216,9 @@ def _run_routine(self, mode, hooks_set, dataset_name=None): Note: Considering evaluation could be in ```hooks_set[ - "on_epoch_end"]```, there could be two data loaders in - self.ctx, we must tell the running hooks which data_loader to call - and which num_samples to count + "on_epoch_end"]```, there could be two data loaders in \ + self.ctx, we must tell the running hooks which data_loader to \ + call and which num_samples to count """ num_samples_model = list() diff --git a/federatedscope/core/trainers/trainer_nbafl.py b/federatedscope/core/trainers/trainer_nbafl.py index 53959b2fe..27bfe5f41 100644 --- a/federatedscope/core/trainers/trainer_nbafl.py +++ b/federatedscope/core/trainers/trainer_nbafl.py @@ -1,4 +1,4 @@ -from federatedscope.core.auxiliaries.utils import get_random +from federatedscope.core.trainers.utils import get_random from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer from typing import Type from copy import deepcopy @@ -24,23 +24,23 @@ def wrap_nbafl_trainer( init_nbafl_ctx(base_trainer) # ---------------- action-level plug-in ----------------------- - base_trainer.register_hook_in_train(new_hook=record_initialization, + base_trainer.register_hook_in_train(new_hook=_hook_record_initialization, trigger='on_fit_start', insert_pos=-1) - base_trainer.register_hook_in_eval(new_hook=record_initialization, + base_trainer.register_hook_in_eval(new_hook=_hook_record_initialization, trigger='on_fit_start', insert_pos=-1) - base_trainer.register_hook_in_train(new_hook=del_initialization, + base_trainer.register_hook_in_train(new_hook=_hook_del_initialization, trigger='on_fit_end', insert_pos=-1) - base_trainer.register_hook_in_eval(new_hook=del_initialization, + base_trainer.register_hook_in_eval(new_hook=_hook_del_initialization, trigger='on_fit_end', insert_pos=-1) - base_trainer.register_hook_in_train(new_hook=inject_noise_in_upload, + base_trainer.register_hook_in_train(new_hook=_hook_inject_noise_in_upload, trigger='on_fit_end', insert_pos=-1) return base_trainer @@ -78,24 +78,48 @@ def init_nbafl_ctx(base_trainer): # 
Trainer -def record_initialization(ctx): - """Record the initialized weights within local updates - +def _hook_record_initialization(ctx): + """ + Record the initialized weights within local updates + + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.weight_init`` Copy from `ctx.model` + ================================== =========================== """ ctx.weight_init = deepcopy( [_.data.detach() for _ in ctx.model.parameters()]) -def del_initialization(ctx): - """Clear the variable to avoid memory leakage - +def _hook_del_initialization(ctx): + """ + Clear the variable to avoid memory leakage + + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.weight_init`` Set to `None` + ================================== =========================== """ ctx.weight_init = None -def inject_noise_in_upload(ctx): - """Inject noise into weights before the client upload them to server - +def _hook_inject_noise_in_upload(ctx): + """ + Inject noise into weights before the client upload them to server + + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.model`` Inject noise to parameters + ================================== =========================== """ for p in ctx.model.parameters(): noise = get_random("Normal", p.shape, { diff --git a/federatedscope/core/trainers/trainer_pFedMe.py b/federatedscope/core/trainers/trainer_pFedMe.py index 601429d1d..dac1e81f0 100644 --- a/federatedscope/core/trainers/trainer_pFedMe.py +++ b/federatedscope/core/trainers/trainer_pFedMe.py @@ -22,14 +22,14 @@ def wrap_pFedMeTrainer( # ---------------- action-level plug-in ----------------------- base_trainer.register_hook_in_train( - new_hook=hook_on_fit_start_set_local_para_tmp, + new_hook=_hook_on_fit_start_set_local_para_tmp, trigger="on_fit_start", insert_pos=-1) base_trainer.register_hook_in_train( - new_hook=hook_on_epoch_end_update_local, + new_hook=_hook_on_epoch_end_update_local, trigger="on_epoch_end", insert_pos=-1) - base_trainer.register_hook_in_train(new_hook=hook_on_fit_end_update_local, + base_trainer.register_hook_in_train(new_hook=_hook_on_fit_end_update_local, trigger="on_fit_end", insert_pos=-1) @@ -47,7 +47,7 @@ def wrap_pFedMeTrainer( base_trainer.hooks_in_train["on_batch_start"] # 2) replace the original hooks for "on_batch_start" base_trainer.replace_hook_in_train( - new_hook=hook_on_batch_start_init_pfedme, + new_hook=_hook_on_batch_start_init_pfedme, target_trigger="on_batch_start", target_hook_name=None) @@ -59,6 +59,15 @@ def init_pFedMe_ctx(base_trainer): init necessary attributes used in pFedMe, some new attributes will be with prefix `pFedMe` optimizer to avoid namespace pollution + + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.optimizer_for_global_model`` False + ================================== =========================== + """ ctx = base_trainer.ctx cfg = 
base_trainer.cfg @@ -75,7 +84,20 @@ def init_pFedMe_ctx(base_trainer): ctx.pFedMe_local_model_tmp = None -def hook_on_fit_start_set_local_para_tmp(ctx): +def _hook_on_fit_start_set_local_para_tmp(ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.optimizer`` Wrapped by \ + ``wrap_regularized_optimizer`` and set compared parameter group + ``ctx.pFedMe_outer_lr`` Initialize to \ + ``ctx.cfg.train.optimizer.lr`` + ``ctx.pFedMe_local_model_tmp`` Copy from ``ctx.model`` + ================================== =========================== + """ # the optimizer used in pFedMe is based on Moreau Envelopes regularization # besides, there are two distinct lr for the approximate model and base # model @@ -94,7 +116,17 @@ def hook_on_fit_start_set_local_para_tmp(ctx): ctx.optimizer.set_compared_para_group(compared_global_model_para) -def hook_on_batch_start_init_pfedme(ctx): +def _hook_on_batch_start_init_pfedme(ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.data_batch_cache`` Copy from ``ctx.data_batch`` + ``ctx.pFedMe_approx_fit_counter`` Count to refresh data every K step + ================================== =========================== + """ # refresh data every K step if ctx.pFedMe_approx_fit_counter == 0: if ctx.cur_mode == "train": @@ -113,17 +145,46 @@ def hook_on_batch_start_init_pfedme(ctx): def _hook_on_batch_end_flop_count(ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.monitor`` Monitor total flops + ================================== =========================== + """ # besides the normal forward flops, pFedMe introduces # 1) the regularization adds the cost of number of model parameters ctx.monitor.total_flops += ctx.monitor.total_model_size / 2 def _hook_on_epoch_end_flop_count(ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.monitor`` Monitor total flops + ================================== =========================== + """ # due to the local weight updating ctx.monitor.total_flops += ctx.monitor.total_model_size / 2 -def hook_on_epoch_end_update_local(ctx): +def _hook_on_epoch_end_update_local(ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.model`` Update parameters by \ + ``ctx.pFedMe_local_model_tmp`` + ``ctx.optimizer`` Set compared parameter group + ================================== =========================== + """ # update local weight after finding approximate theta for client_param, local_para_tmp in zip( ctx.model.parameters(), ctx.pFedMe_local_model_tmp.parameters()): @@ -140,7 +201,18 @@ def hook_on_epoch_end_update_local(ctx): ctx.optimizer.set_compared_para_group(compared_global_model_para) -def 
hook_on_fit_end_update_local(ctx): +def _hook_on_fit_end_update_local(ctx): + """ + Note: + The modified attributes and according operations are shown below: + ================================== =========================== + Attribute Operation + ================================== =========================== + ``ctx.model`` Update parameters by + ``ctx.pFedMe_local_model_tmp`` + ``ctx.pFedMe_local_model_tmp`` Delete + ================================== =========================== + """ for param, local_para_tmp in zip(ctx.model.parameters(), ctx.pFedMe_local_model_tmp.parameters()): param.data = local_para_tmp.data diff --git a/federatedscope/core/trainers/utils.py b/federatedscope/core/trainers/utils.py new file mode 100644 index 000000000..d34a2d30b --- /dev/null +++ b/federatedscope/core/trainers/utils.py @@ -0,0 +1,83 @@ +import collections +import json +import math + + +def format_log_hooks(hooks_set): + def format_dict(target_dict): + print_dict = collections.defaultdict(list) + for k, v in target_dict.items(): + for element in v: + print_dict[k].append(element.__name__) + return print_dict + + if isinstance(hooks_set, list): + print_obj = [format_dict(_) for _ in hooks_set] + elif isinstance(hooks_set, dict): + print_obj = format_dict(hooks_set) + return json.dumps(print_obj, indent=2).replace('\n', '\n\t') + + +def filter_by_specified_keywords(param_name, filter_keywords): + """ + Arguments: + param_name (str): parameter name. + Returns: + preserve (bool): whether to preserve this parameter. + """ + preserve = True + for kw in filter_keywords: + if kw in param_name: + preserve = False + break + return preserve + + +def move_to(obj, device): + import torch + if torch.is_tensor(obj): + return obj.to(device) + elif isinstance(obj, dict): + res = {} + for k, v in obj.items(): + res[k] = move_to(v, device) + return res + elif isinstance(obj, list): + res = [] + for v in obj: + res.append(move_to(v, device)) + return res + else: + raise TypeError("Invalid type for move_to") + + +def get_random(dis_type, sample_shape, params, device): + import torch.distributions as distributions + if not hasattr(distributions, dis_type): + raise NotImplementedError("Distribution {} is not implemented, " + "please refer to ```torch.distributions```" + "(https://pytorch.org/docs/stable/ " + "distributions.html).".format(dis_type)) + generator = getattr(distributions, dis_type)(**params) + return generator.sample(sample_shape=sample_shape).to(device) + + +def calculate_batch_epoch_num(steps, batch_or_epoch, num_data, batch_size, + drop_last): + num_batch_per_epoch = num_data // batch_size + int( + not drop_last and bool(num_data % batch_size)) + if num_batch_per_epoch == 0: + raise RuntimeError( + "The number of batch is 0, please check 'batch_size' or set " + "'drop_last' as False") + elif batch_or_epoch == "epoch": + num_epoch = steps + num_batch_last_epoch = num_batch_per_epoch + num_total_batch = steps * num_batch_per_epoch + else: + num_epoch = math.ceil(steps / num_batch_per_epoch) + num_batch_last_epoch = steps % num_batch_per_epoch or \ + num_batch_per_epoch + num_total_batch = steps + return num_batch_per_epoch, num_batch_last_epoch, num_epoch, \ + num_total_batch diff --git a/federatedscope/core/workers/client.py b/federatedscope/core/workers/client.py index 7c145b99f..e08fff908 100644 --- a/federatedscope/core/workers/client.py +++ b/federatedscope/core/workers/client.py @@ -10,7 +10,7 @@ from federatedscope.core.workers import Worker from federatedscope.core.auxiliaries.trainer_builder import 
get_trainer from federatedscope.core.secret_sharing import AdditiveSecretSharing -from federatedscope.core.auxiliaries.utils import merge_dict, \ +from federatedscope.core.auxiliaries.utils import merge_dict_of_results, \ calculate_time_cost from federatedscope.core.workers.base_client import BaseClient @@ -84,7 +84,7 @@ def __init__(self, self.early_stopper = EarlyStopper( patience, self._cfg.early_stop.delta, self._cfg.early_stop.improve_indicator_mode, - self._cfg.early_stop.the_smaller_the_better) + self._monitor.the_larger_the_better) # Secret Sharing Manager and message buffer self.ss_manager = AdditiveSecretSharing( @@ -468,13 +468,10 @@ def callback_funcs_for_evaluate(self, message: Message): role='Client #{}'.format(self.ID), forms='raw', return_raw=True) - self._monitor.update_best_result( - self.best_results, - formatted_eval_res['Results_raw'], - results_type=f"client #{self.ID}", - round_wise_update_key=self._cfg.eval. - best_res_update_round_wise_key) - self.history_results = merge_dict( + self._monitor.update_best_result(self.best_results, + formatted_eval_res['Results_raw'], + results_type=f"client #{self.ID}") + self.history_results = merge_dict_of_results( self.history_results, formatted_eval_res['Results_raw']) self.early_stopper.track_and_check(self.history_results[ self._cfg.eval.best_res_update_round_wise_key]) diff --git a/federatedscope/core/workers/server.py b/federatedscope/core/workers/server.py index 0bb9a0d82..33e9c09a6 100644 --- a/federatedscope/core/workers/server.py +++ b/federatedscope/core/workers/server.py @@ -12,8 +12,8 @@ gRPCCommManager from federatedscope.core.auxiliaries.aggregator_builder import get_aggregator from federatedscope.core.auxiliaries.sampler_builder import get_sampler -from federatedscope.core.auxiliaries.utils import merge_dict, Timeout, \ - merge_param_dict +from federatedscope.core.auxiliaries.utils import merge_dict_of_results, \ + Timeout, merge_param_dict from federatedscope.core.auxiliaries.trainer_builder import get_trainer from federatedscope.core.secret_sharing import AdditiveSecretSharing from federatedscope.core.workers.base_server import BaseServer @@ -60,7 +60,7 @@ def __init__(self, self.early_stopper = EarlyStopper( self._cfg.early_stop.patience, self._cfg.early_stop.delta, self._cfg.early_stop.improve_indicator_mode, - self._cfg.early_stop.the_smaller_the_better) + self._monitor.the_larger_the_better) if self._cfg.federate.share_local_model: # put the model to the specified device @@ -473,8 +473,8 @@ def _merge_and_format_eval_results(self): # Get all the message & aggregate formatted_eval_res = \ self.merge_eval_results_from_all_clients() - self.history_results = merge_dict(self.history_results, - formatted_eval_res) + self.history_results = merge_dict_of_results(self.history_results, + formatted_eval_res) if self.mode == 'standalone' and \ self._monitor.wandb_online_track and \ self._monitor.use_wandb: @@ -572,9 +572,7 @@ def merge_eval_results_from_all_clients(self): self.best_results, metrics_all_clients, results_type="unseen_client_best_individual" - if merge_type == "unseen" else "client_best_individual", - round_wise_update_key=self._cfg.eval. 
- best_res_update_round_wise_key) + if merge_type == "unseen" else "client_best_individual") self._monitor.save_formatted_results(formatted_logs) for form in self._cfg.eval.report: if form != "raw": @@ -585,9 +583,7 @@ def merge_eval_results_from_all_clients(self): formatted_logs[f"Results_{metric_name}"], results_type=f"unseen_client_summarized_{form}" if merge_type == "unseen" else - f"client_summarized_{form}", - round_wise_update_key=self._cfg.eval. - best_res_update_round_wise_key) + f"client_summarized_{form}") return formatted_logs_all_set @@ -841,11 +837,9 @@ def eval(self): self._monitor.update_best_result( self.best_results, formatted_eval_res['Results_raw'], - results_type="server_global_eval", - round_wise_update_key=self._cfg.eval. - best_res_update_round_wise_key) - self.history_results = merge_dict(self.history_results, - formatted_eval_res) + results_type="server_global_eval") + self.history_results = merge_dict_of_results( + self.history_results, formatted_eval_res) self._monitor.save_formatted_results(formatted_eval_res) logger.info(formatted_eval_res) self.check_and_save() diff --git a/federatedscope/cv/dataset/leaf_cv.py b/federatedscope/cv/dataset/leaf_cv.py index f1be136e0..12b79ec00 100644 --- a/federatedscope/cv/dataset/leaf_cv.py +++ b/federatedscope/cv/dataset/leaf_cv.py @@ -12,7 +12,7 @@ from sklearn.model_selection import train_test_split -from federatedscope.core.auxiliaries.utils import save_local_data, download_url +from federatedscope.core.data.utils import save_local_data, download_url from federatedscope.cv.dataset.leaf import LEAF IMAGE_SIZE = {'femnist': (28, 28), 'celeba': (84, 84, 3)} diff --git a/federatedscope/gfl/baseline/fedavg_gin_minibatch_on_cikmcup.yaml b/federatedscope/gfl/baseline/fedavg_gin_minibatch_on_cikmcup.yaml index c7865d4c8..c932fda46 100644 --- a/federatedscope/gfl/baseline/fedavg_gin_minibatch_on_cikmcup.yaml +++ b/federatedscope/gfl/baseline/fedavg_gin_minibatch_on_cikmcup.yaml @@ -3,7 +3,6 @@ device: 0 early_stop: patience: 20 improve_indicator_mode: mean - the_smaller_the_better: False federate: mode: 'standalone' make_global_eval: False diff --git a/federatedscope/gfl/baseline/isolated_gin_minibatch_on_cikmcup.yaml b/federatedscope/gfl/baseline/isolated_gin_minibatch_on_cikmcup.yaml index 3153c315d..f91465547 100644 --- a/federatedscope/gfl/baseline/isolated_gin_minibatch_on_cikmcup.yaml +++ b/federatedscope/gfl/baseline/isolated_gin_minibatch_on_cikmcup.yaml @@ -3,7 +3,6 @@ device: 0 early_stop: patience: 20 improve_indicator_mode: mean - the_smaller_the_better: False federate: mode: standalone method: local diff --git a/federatedscope/gfl/fedsageplus/worker.py b/federatedscope/gfl/fedsageplus/worker.py index 467b6d867..071e7d4e3 100644 --- a/federatedscope/gfl/fedsageplus/worker.py +++ b/federatedscope/gfl/fedsageplus/worker.py @@ -7,7 +7,7 @@ from federatedscope.core.message import Message from federatedscope.core.workers.server import Server from federatedscope.core.workers.client import Client -from federatedscope.core.auxiliaries.utils import merge_dict +from federatedscope.core.auxiliaries.utils import merge_dict_of_results from federatedscope.core.data import ClientData from federatedscope.gfl.trainer.nodetrainer import NodeMiniBatchTrainer @@ -235,8 +235,8 @@ def check_and_move_on(self, check_eval_result=False): else: # in the evaluation process # Get all the message & aggregate formatted_eval_res = self.merge_eval_results_from_all_clients() - self.history_results = merge_dict(self.history_results, - 
formatted_eval_res) + self.history_results = merge_dict_of_results( + self.history_results, formatted_eval_res) self.check_and_save() diff --git a/federatedscope/gfl/flitplus/trainer.py b/federatedscope/gfl/flitplus/trainer.py index bc22ccb39..e1965bcaf 100644 --- a/federatedscope/gfl/flitplus/trainer.py +++ b/federatedscope/gfl/flitplus/trainer.py @@ -1,7 +1,7 @@ import torch from copy import deepcopy -from federatedscope.core.auxiliaries.enums import LIFECYCLE +from federatedscope.core.trainers.enums import LIFECYCLE from federatedscope.core.trainers.context import CtxVar from federatedscope.gfl.loss.vat import VATLoss from federatedscope.core.trainers import GeneralTorchTrainer diff --git a/federatedscope/gfl/gcflplus/worker.py b/federatedscope/gfl/gcflplus/worker.py index 26a1d62df..2cbc85d04 100644 --- a/federatedscope/gfl/gcflplus/worker.py +++ b/federatedscope/gfl/gcflplus/worker.py @@ -6,7 +6,7 @@ from federatedscope.core.message import Message from federatedscope.core.workers.server import Server from federatedscope.core.workers.client import Client -from federatedscope.core.auxiliaries.utils import merge_dict +from federatedscope.core.auxiliaries.utils import merge_dict_of_results from federatedscope.gfl.gcflplus.utils import compute_pairwise_distances, \ min_cut, norm @@ -172,8 +172,8 @@ def check_and_move_on(self, check_eval_result=False): else: # in the evaluation process # Get all the message & aggregate formatted_eval_res = self.merge_eval_results_from_all_clients() - self.history_results = merge_dict(self.history_results, - formatted_eval_res) + self.history_results = merge_dict_of_results( + self.history_results, formatted_eval_res) self.check_and_save() diff --git a/federatedscope/gfl/trainer/graphtrainer.py b/federatedscope/gfl/trainer/graphtrainer.py index 2479ce933..cfa59c529 100644 --- a/federatedscope/gfl/trainer/graphtrainer.py +++ b/federatedscope/gfl/trainer/graphtrainer.py @@ -4,7 +4,7 @@ from federatedscope.register import register_trainer from federatedscope.core.trainers import GeneralTorchTrainer from federatedscope.core.trainers.context import CtxVar -from federatedscope.core.auxiliaries.enums import LIFECYCLE +from federatedscope.core.trainers.enums import LIFECYCLE logger = logging.getLogger(__name__) diff --git a/federatedscope/gfl/trainer/linktrainer.py b/federatedscope/gfl/trainer/linktrainer.py index e2885b1a4..6313bdf8b 100644 --- a/federatedscope/gfl/trainer/linktrainer.py +++ b/federatedscope/gfl/trainer/linktrainer.py @@ -3,7 +3,7 @@ from torch.utils.data import DataLoader from torch_geometric.loader import GraphSAINTRandomWalkSampler, NeighborSampler -from federatedscope.core.auxiliaries.enums import LIFECYCLE +from federatedscope.core.trainers.enums import LIFECYCLE from federatedscope.core.monitors import Monitor from federatedscope.core.trainers.context import CtxVar from federatedscope.register import register_trainer @@ -38,7 +38,6 @@ def register_default_hooks_train(self): def parse_data(self, data): """Populate "{}_data", "{}_loader" and "num_{}_data" for different modes - """ init_dict = dict() if isinstance(data, dict): @@ -139,7 +138,6 @@ class LinkMiniBatchTrainer(GeneralTorchTrainer): def parse_data(self, data): """Populate "{}_data", "{}_loader" and "num_{}_data" for different modes - """ init_dict = dict() if isinstance(data, dict): diff --git a/federatedscope/gfl/trainer/nodetrainer.py b/federatedscope/gfl/trainer/nodetrainer.py index 7871bdb2c..277a51fd0 100644 --- a/federatedscope/gfl/trainer/nodetrainer.py +++ 
b/federatedscope/gfl/trainer/nodetrainer.py @@ -1,7 +1,7 @@ import torch from torch_geometric.loader import GraphSAINTRandomWalkSampler, NeighborSampler -from federatedscope.core.auxiliaries.enums import LIFECYCLE +from federatedscope.core.trainers.enums import LIFECYCLE from federatedscope.core.monitors import Monitor from federatedscope.core.trainers.context import CtxVar from federatedscope.register import register_trainer @@ -16,7 +16,6 @@ class NodeFullBatchTrainer(GeneralTorchTrainer): def parse_data(self, data): """Populate "{}_data", "{}_loader" and "num_{}_data" for different modes - """ init_dict = dict() if isinstance(data, dict): @@ -92,7 +91,6 @@ class NodeMiniBatchTrainer(GeneralTorchTrainer): def parse_data(self, data): """Populate "{}_data", "{}_loader" and "num_{}_data" for different modes - """ init_dict = dict() if isinstance(data, dict): @@ -124,7 +122,6 @@ def parse_data(self, data): return init_dict def _hook_on_epoch_start(self, ctx): - # TODO: blind torch if not isinstance(ctx.get("{}_loader".format(ctx.cur_split)), ReIterator): if isinstance(ctx.get("{}_loader".format(ctx.cur_split)), diff --git a/federatedscope/main.py b/federatedscope/main.py index 0eae9c8c3..f7dabc993 100644 --- a/federatedscope/main.py +++ b/federatedscope/main.py @@ -14,7 +14,7 @@ from federatedscope.core.auxiliaries.worker_builder import get_client_cls, \ get_server_cls from federatedscope.core.configs.config import global_cfg, CfgNode -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner if os.environ.get('https_proxy'): del os.environ['https_proxy'] @@ -44,9 +44,9 @@ init_cfg.freeze() - runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone(), - client_configs=client_cfgs) + runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone(), + client_configs=client_cfgs) _ = runner.run() diff --git a/federatedscope/mf/trainer/trainer.py b/federatedscope/mf/trainer/trainer.py index f169c4d16..76757632f 100644 --- a/federatedscope/mf/trainer/trainer.py +++ b/federatedscope/mf/trainer/trainer.py @@ -1,7 +1,7 @@ import numpy from wandb.wandb_torch import torch -from federatedscope.core.auxiliaries.enums import LIFECYCLE +from federatedscope.core.trainers.enums import LIFECYCLE from federatedscope.core.monitors import Monitor from federatedscope.core.trainers.context import CtxVar from federatedscope.mf.dataloader.dataloader import MFDataLoader @@ -24,7 +24,6 @@ class MFTrainer(GeneralTorchTrainer): def parse_data(self, data): """Populate "{}_data", "{}_loader" and "num_{}_data" for different modes - """ init_dict = dict() if isinstance(data, dict): diff --git a/federatedscope/mf/trainer/trainer_sgdmf.py b/federatedscope/mf/trainer/trainer_sgdmf.py index 653eeb555..c1398857e 100644 --- a/federatedscope/mf/trainer/trainer_sgdmf.py +++ b/federatedscope/mf/trainer/trainer_sgdmf.py @@ -1,7 +1,7 @@ import logging from federatedscope.mf.trainer.trainer import MFTrainer -from federatedscope.core.auxiliaries.utils import get_random +from federatedscope.core.trainers.utils import get_random from typing import Type import numpy as np diff --git a/federatedscope/nlp/dataset/leaf_nlp.py b/federatedscope/nlp/dataset/leaf_nlp.py index 7f9803496..daa3cdd69 100644 --- a/federatedscope/nlp/dataset/leaf_nlp.py +++ b/federatedscope/nlp/dataset/leaf_nlp.py @@ -12,7 +12,7 @@ from sklearn.model_selection import 
train_test_split -from federatedscope.core.auxiliaries.utils import save_local_data, download_url +from federatedscope.core.data.utils import save_local_data, download_url from federatedscope.cv.dataset.leaf import LEAF from federatedscope.nlp.dataset.utils import * diff --git a/federatedscope/nlp/dataset/leaf_synthetic.py b/federatedscope/nlp/dataset/leaf_synthetic.py index b5f01f032..768e22958 100644 --- a/federatedscope/nlp/dataset/leaf_synthetic.py +++ b/federatedscope/nlp/dataset/leaf_synthetic.py @@ -8,7 +8,7 @@ from sklearn.utils import shuffle from torch.utils.data import Dataset -from federatedscope.core.auxiliaries.utils import save_local_data +from federatedscope.core.data.utils import save_local_data from federatedscope.cv.dataset.leaf import LEAF diff --git a/federatedscope/nlp/dataset/leaf_twitter.py b/federatedscope/nlp/dataset/leaf_twitter.py index 733f54aef..bd04d07da 100644 --- a/federatedscope/nlp/dataset/leaf_twitter.py +++ b/federatedscope/nlp/dataset/leaf_twitter.py @@ -10,7 +10,7 @@ from tqdm import tqdm from sklearn.model_selection import train_test_split -from federatedscope.core.auxiliaries.utils import save_local_data, download_url +from federatedscope.core.data.utils import save_local_data, download_url from federatedscope.cv.dataset.leaf import LEAF, LocalDataset from federatedscope.nlp.dataset.utils import * diff --git a/federatedscope/nlp/trainer/trainer.py b/federatedscope/nlp/trainer/trainer.py index 8f3fac5fc..2e640e25d 100644 --- a/federatedscope/nlp/trainer/trainer.py +++ b/federatedscope/nlp/trainer/trainer.py @@ -1,11 +1,11 @@ from federatedscope.register import register_trainer from federatedscope.core.trainers import GeneralTorchTrainer -from federatedscope.core.auxiliaries import utils +from federatedscope.core.trainers.utils import move_to class NLPTrainer(GeneralTorchTrainer): def _hook_on_batch_forward(self, ctx): - x, label = [utils.move_to(_, ctx.device) for _ in ctx.data_batch] + x, label = [move_to(_, ctx.device) for _ in ctx.data_batch] if isinstance(x, dict): pred = ctx.model(**x)[0] else: diff --git a/federatedscope/tabular/dataloader/toy.py b/federatedscope/tabular/dataloader/toy.py index a5a4cb0aa..26114fa71 100644 --- a/federatedscope/tabular/dataloader/toy.py +++ b/federatedscope/tabular/dataloader/toy.py @@ -6,8 +6,6 @@ def load_toy_data(config=None): - generate = config.federate.mode.lower() == 'standalone' - def _generate_data(client_num=5, instance_num=1000, feature_num=5, @@ -104,13 +102,8 @@ def _generate_data(client_num=5, return data - if generate: - data = _generate_data(client_num=config.federate.client_num, - save_data=config.data.save_data) - else: - with open(config.distribute.data_file, 'rb') as f: - data = pickle.load(f) - data = {config.distribute.data_idx: data} + data = _generate_data(client_num=config.federate.client_num, + save_data=config.data.save_data) for client_id in data.keys(): data[client_id] = { k: WrapDataset(v) diff --git a/federatedscope/vertical_fl/worker/vertical_server.py b/federatedscope/vertical_fl/worker/vertical_server.py index d1e2da946..53348470c 100644 --- a/federatedscope/vertical_fl/worker/vertical_server.py +++ b/federatedscope/vertical_fl/worker/vertical_server.py @@ -84,12 +84,9 @@ def callback_funcs_for_encryped_gradient(self, message: Message): if self.state % self._cfg.eval.freq == 0 and self.state != \ self.total_round_num: metrics = self.evaluate() - self._monitor.update_best_result( - self.best_results, - metrics, - results_type='server_global_eval', - 
round_wise_update_key=self._cfg.eval. - best_res_update_round_wise_key) + self._monitor.update_best_result(self.best_results, + metrics, + results_type='server_global_eval') formatted_logs = self._monitor.format_eval_res( metrics, rnd=self.state, @@ -104,12 +101,9 @@ def callback_funcs_for_encryped_gradient(self, message: Message): self.broadcast_model_para() else: metrics = self.evaluate() - self._monitor.update_best_result( - self.best_results, - metrics, - results_type='server_global_eval', - round_wise_update_key=self._cfg.eval. - best_res_update_round_wise_key) + self._monitor.update_best_result(self.best_results, + metrics, + results_type='server_global_eval') formatted_logs = self._monitor.format_eval_res( metrics, rnd=self.state, diff --git a/scripts/distributed_scripts/distributed_configs/distributed_client_1.yaml b/scripts/distributed_scripts/distributed_configs/distributed_client_1.yaml index a5c31dcf0..e98e4236e 100644 --- a/scripts/distributed_scripts/distributed_configs/distributed_client_1.yaml +++ b/scripts/distributed_scripts/distributed_configs/distributed_client_1.yaml @@ -12,7 +12,7 @@ distribute: client_host: '127.0.0.1' client_port: 50052 role: 'client' - data_file: 'toy_data/client_1_data' + data_idx: 1 trainer: type: 'general' eval: diff --git a/scripts/distributed_scripts/distributed_configs/distributed_client_2.yaml b/scripts/distributed_scripts/distributed_configs/distributed_client_2.yaml index 2fbf540de..0acd42a41 100644 --- a/scripts/distributed_scripts/distributed_configs/distributed_client_2.yaml +++ b/scripts/distributed_scripts/distributed_configs/distributed_client_2.yaml @@ -12,7 +12,7 @@ distribute: client_host: '127.0.0.1' client_port: 50053 role: 'client' - data_file: 'toy_data/client_2_data' + data_idx: 2 trainer: type: 'general' eval: diff --git a/scripts/distributed_scripts/distributed_configs/distributed_client_3.yaml b/scripts/distributed_scripts/distributed_configs/distributed_client_3.yaml index 334725920..66e793493 100644 --- a/scripts/distributed_scripts/distributed_configs/distributed_client_3.yaml +++ b/scripts/distributed_scripts/distributed_configs/distributed_client_3.yaml @@ -12,7 +12,7 @@ distribute: client_host: '127.0.0.1' client_port: 50054 role: 'client' - data_file: 'toy_data/client_3_data' + data_idx: 3 trainer: type: 'general' eval: diff --git a/scripts/distributed_scripts/distributed_configs/distributed_server.yaml b/scripts/distributed_scripts/distributed_configs/distributed_server.yaml index 6728ff9c2..6ea218f31 100644 --- a/scripts/distributed_scripts/distributed_configs/distributed_server.yaml +++ b/scripts/distributed_scripts/distributed_configs/distributed_server.yaml @@ -10,7 +10,7 @@ distribute: server_host: '127.0.0.1' server_port: 50051 role: 'server' - data_file: 'toy_data/server_data' + data_idx: 0 trainer: type: 'general' eval: diff --git a/scripts/distributed_scripts/gen_data.py b/scripts/distributed_scripts/gen_data.py index 6ef6b6ca7..2784a8121 100644 --- a/scripts/distributed_scripts/gen_data.py +++ b/scripts/distributed_scripts/gen_data.py @@ -8,7 +8,7 @@ def generate_data(client_num=3, feature_num=5, save_data=True): """ - Generate data in FedRunner format + Generate data in Runner format Args: client_num: instance_num: diff --git a/scripts/distributed_scripts/run_distributed_lr.sh b/scripts/distributed_scripts/run_distributed_lr.sh index efce4cb4b..aa1840e25 100755 --- a/scripts/distributed_scripts/run_distributed_lr.sh +++ b/scripts/distributed_scripts/run_distributed_lr.sh @@ -4,9 +4,6 @@ cd ..
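# A rough sketch of how the new ``distribute.data_idx`` field above is assumed
# to replace the per-worker ``data_file`` pickles: each worker now builds the
# full toy dataset in-process and keeps only its own partition (index 0 for
# the server, 1..N for the clients), so the explicit gen_data.py step is
# dropped from this script. The helper below is hypothetical and only
# illustrates the idea; the actual selection lives in the core data utilities.
def _select_partition_by_idx(all_data, data_idx):
    # all_data: dict like {0: server_split, 1: client_1_split, ...}
    return {data_idx: all_data[data_idx]}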
echo "Test distributed mode with LR..." -echo "Data generation" -python scripts/distributed_scripts/gen_data.py - ### server owns global test data python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_server.yaml & ### server doesn't own data diff --git a/scripts/example_configs/femnist_global_train.yaml b/scripts/example_configs/femnist_global_train.yaml index 15871047a..8126ed2a1 100644 --- a/scripts/example_configs/femnist_global_train.yaml +++ b/scripts/example_configs/femnist_global_train.yaml @@ -5,11 +5,15 @@ early_stop: seed: 12345 federate: mode: standalone - local_update_steps: 1 - batch_or_epoch: epoch total_round_num: 300 sample_client_rate: 0.2 method: global +train: + local_update_steps: 1 + batch_or_epoch: epoch + optimizer: + lr: 0.01 + weight_decay: 0.0 data: root: data/ type: femnist @@ -22,9 +26,6 @@ model: type: convnet2 hidden: 2048 out_channels: 62 -optimizer: - lr: 0.01 - weight_decay: 0.0 grad: grad_clip: 5.0 criterion: diff --git a/setup.py b/setup.py index 8a402f766..b3e2d398c 100644 --- a/setup.py +++ b/setup.py @@ -9,21 +9,31 @@ minimal_requires = [ 'numpy<1.23.0', 'scikit-learn==1.0.2', 'scipy==1.7.3', 'pandas', 'grpcio>=1.45.0', 'grpcio-tools', 'pyyaml>=5.1', 'fvcore', 'iopath', - 'wandb', 'tensorboard', 'tensorboardX', 'pympler', 'protobuf==3.19.4' + 'wandb', 'tensorboard', 'tensorboardX', 'pympler', 'protobuf==3.19.4', + 'matplotlib' ] -test_requires = [] +test_requires = ['pytest', 'pytest-cov'] dev_requires = test_requires + ['pre-commit'] org_requires = ['paramiko==2.11.0', 'celery[redis]', 'cmd2'] +app_requires = [ + 'torch-geometric==2.0.4', 'nltk', 'transformers==4.16.2', + 'tokenizers==0.10.3', 'datasets', 'sentencepiece', 'textgrid', 'typeguard', + 'openml==0.12.2' +] + benchmark_hpo_requires = [ 'configspace==0.5.0', 'hpbandster==0.7.4', 'smac==1.3.3', 'optuna==2.10.0' ] benchmark_htl_requires = ['learn2learn'] +full_requires = org_requires + benchmark_hpo_requires + \ + benchmark_htl_requires + app_requires + with open("README.md", "r") as fh: long_description = fh.read() @@ -45,10 +55,12 @@ install_requires=minimal_requires, extras_require={ 'test': test_requires, + 'app': app_requires, 'org': org_requires, 'dev': dev_requires, 'hpo': benchmark_hpo_requires, 'htl': benchmark_htl_requires, + 'full': full_requires }, license="Apache License 2.0", classifiers=[ diff --git a/tests/test_CRA_gan_attack.py b/tests/test_CRA_gan_attack.py index 6fad02276..5a1bb72dd 100644 --- a/tests/test_CRA_gan_attack.py +++ b/tests/test_CRA_gan_attack.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -68,10 +68,10 @@ def test_CRA_GAN_femnist_standalone(self): init_cfg.merge_from_other_cfg(modified_cfg) self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) diff --git a/tests/test_MIA_gradient_ascent.py 
b/tests/test_MIA_gradient_ascent.py index feb4248cf..b4c3dd750 100644 --- a/tests/test_MIA_gradient_ascent.py +++ b/tests/test_MIA_gradient_ascent.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -68,10 +68,10 @@ def test_GradAscent_femnist_standalone(self): init_cfg.merge_from_other_cfg(modified_cfg) self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) diff --git a/tests/test_PIA_toy.py b/tests/test_PIA_toy.py index 24753ea28..e25555926 100644 --- a/tests/test_PIA_toy.py +++ b/tests/test_PIA_toy.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -41,10 +41,10 @@ def test_PIA_toy_standalone(self): self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) diff --git a/tests/test_asyn_cifar10.py b/tests/test_asyn_cifar10.py index 59a87cad0..530b2ad01 100644 --- a/tests/test_asyn_cifar10.py +++ b/tests/test_asyn_cifar10.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -33,7 +33,7 @@ def set_config_cifar10_goalAchieved_afterReceiving(self, cfg): cfg.data.root = 'test_data/' cfg.data.type = 'CIFAR10@torchvision' - cfg.data.args = [{'download': False}] + cfg.data.args = [{'download': True}] cfg.data.splits = [0.8, 0.2, 0.2] cfg.data.batch_size = 10 cfg.data.subsample = 0.2 @@ -49,7 +49,7 @@ def set_config_cifar10_goalAchieved_afterReceiving(self, cfg): cfg.data.splitter_args = [{'alpha': 0.2}] cfg.model.type = 'convnet2' - cfg.model.hidden = 512 + cfg.model.hidden = 128 cfg.model.out_channels = 10 cfg.train.local_update_steps = 2 @@ -92,7 +92,7 @@ def set_config_cifar10_timeUp_afterAggregating(self, cfg): cfg.data.root = 'test_data/' cfg.data.type = 'CIFAR10@torchvision' - cfg.data.args = [{'download': False}] + cfg.data.args = [{'download': True}] cfg.data.splits = [0.8, 0.2, 0.2] cfg.data.batch_size = 10 cfg.data.subsample = 0.2 @@ -108,7 
+108,7 @@ def set_config_cifar10_timeUp_afterAggregating(self, cfg): cfg.data.splitter_args = [{'alpha': 0.2}] cfg.model.type = 'convnet2' - cfg.model.hidden = 512 + cfg.model.hidden = 128 cfg.model.out_channels = 10 cfg.train.local_update_steps = 2 @@ -152,7 +152,7 @@ def set_config_cifar10_overselection(self, cfg): cfg.data.root = 'test_data/' cfg.data.type = 'CIFAR10@torchvision' - cfg.data.args = [{'download': False}] + cfg.data.args = [{'download': True}] cfg.data.splits = [0.8, 0.2, 0.2] cfg.data.batch_size = 10 cfg.data.subsample = 0.2 @@ -168,7 +168,7 @@ def set_config_cifar10_overselection(self, cfg): cfg.data.splitter_args = [{'alpha': 0.2}] cfg.model.type = 'convnet2' - cfg.model.hidden = 512 + cfg.model.hidden = 128 cfg.model.out_channels = 10 cfg.train.local_update_steps = 2 @@ -202,10 +202,10 @@ def test_asyn_cifar10_goalAchieved_afterReceiving(self): init_cfg.merge_from_other_cfg(modified_cfg) self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) @@ -223,10 +223,10 @@ def test_asyn_cifar10_timeUp_afterAggregating(self): init_cfg.merge_from_other_cfg(modified_cfg) self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) @@ -244,10 +244,10 @@ def test_asyn_cifar10_overselection(self): init_cfg.merge_from_other_cfg(modified_cfg) self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) diff --git a/tests/test_backdoor_attack.py b/tests/test_backdoor_attack.py index 1d7a97b43..f0683c812 100644 --- a/tests/test_backdoor_attack.py +++ b/tests/test_backdoor_attack.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -71,10 +71,10 @@ def test_backdoor_edge_femnist_standalone(self): init_cfg.merge_from_other_cfg(modified_cfg) self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) diff --git a/tests/test_ditto.py b/tests/test_ditto.py 
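All of the test updates in this patch follow one pattern: the direct ``FedRunner(...)`` construction is replaced by the ``get_runner(...)`` helper with the same keyword arguments. A minimal sketch of that pattern outside the test harness (``init_cfg`` is assumed to be an already-merged ``CfgNode``; the import path below matches this patch and is relocated to ``federatedscope.core.auxiliaries.runner_builder`` later in the series)::

    from federatedscope.core.auxiliaries.data_builder import get_data
    from federatedscope.core.auxiliaries.worker_builder import \
        get_server_cls, get_client_cls
    from federatedscope.core.fed_runner import get_runner

    # Build the data and let the builder refine the config if necessary.
    data, modified_cfg = get_data(config=init_cfg.clone())
    init_cfg.merge_from_other_cfg(modified_cfg)

    # get_runner chooses StandaloneRunner or DistributedRunner according
    # to cfg.federate.mode and wires up the server/client classes.
    runner = get_runner(data=data,
                        server_class=get_server_cls(init_cfg),
                        client_class=get_client_cls(init_cfg),
                        config=init_cfg.clone())
    best_results = runner.run()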
index 07cd8710a..77f23e445 100644 --- a/tests/test_ditto.py +++ b/tests/test_ditto.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -73,10 +73,10 @@ def test_femnist_standalone(self): init_cfg.merge_from_other_cfg(modified_cfg) self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) diff --git a/tests/test_efficient_simulation.py b/tests/test_efficient_simulation.py index 8ec927b51..aa16cdf9a 100644 --- a/tests/test_efficient_simulation.py +++ b/tests/test_efficient_simulation.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -21,10 +21,10 @@ def test_toy_example_standalone_cmp_sim_impl(self): update_logger(case_cfg) data, _ = get_data(case_cfg.clone()) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(case_cfg), - client_class=get_client_cls(case_cfg), - config=case_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(case_cfg), + client_class=get_client_cls(case_cfg), + config=case_cfg.clone()) efficient_test_results = Fed_runner.run() setup_seed(case_cfg.seed) @@ -33,10 +33,10 @@ def test_toy_example_standalone_cmp_sim_impl(self): 'False' ]) data, _ = get_data(case_cfg.clone()) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(case_cfg), - client_class=get_client_cls(case_cfg), - config=case_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(case_cfg), + client_class=get_client_cls(case_cfg), + config=case_cfg.clone()) ordinary_test_results = Fed_runner.run() gap = efficient_test_results["client_summarized_weighted_avg"][ 'test_loss'] - ordinary_test_results[ diff --git a/tests/test_external_dataset.py b/tests/test_external_dataset.py index 7d243c70d..806efe603 100644 --- a/tests/test_external_dataset.py +++ b/tests/test_external_dataset.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -27,6 +27,8 @@ def set_config_torchvision_dataset(self, cfg): cfg.train.batch_or_epoch = 'epoch' cfg.federate.client_num = 5 cfg.federate.sample_client_rate = 0.2 + cfg.federate.share_local_model = True + cfg.federate.online_aggr = True cfg.data.root = 'test_data/' cfg.data.type = 'MNIST@torchvision' @@ -67,10 +69,12 @@ 
def set_config_torchtext_dataset(self, cfg): cfg.federate.mode = 'standalone' cfg.train.local_update_steps = 1 - cfg.federate.total_round_num = 20 + cfg.federate.total_round_num = 10 cfg.train.batch_or_epoch = 'epoch' cfg.federate.client_num = 5 cfg.federate.sample_client_rate = 0.2 + cfg.federate.share_local_model = True + cfg.federate.online_aggr = True cfg.data.root = 'test_data/' cfg.data.args = [{'max_len': 100}] @@ -107,10 +111,10 @@ def test_torchvision_dataset_standalone(self): init_cfg.merge_from_other_cfg(modified_cfg) self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) @@ -129,17 +133,17 @@ def test_torchtext_dataset_standalone(self): init_cfg.merge_from_other_cfg(modified_cfg) self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) init_cfg.merge_from_other_cfg(backup_cfg) self.assertGreater( test_best_results["client_summarized_weighted_avg"]['test_acc'], - 0.65) + 0.6) if __name__ == '__main__': diff --git a/tests/test_fedem.py b/tests/test_fedem.py index d3de07e18..28a7dec23 100644 --- a/tests/test_fedem.py +++ b/tests/test_fedem.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -67,10 +67,10 @@ def test_femnist_standalone(self): init_cfg.merge_from_other_cfg(modified_cfg) self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) diff --git a/tests/test_fedopt.py b/tests/test_fedopt.py index 0d598e2c3..30d9a53aa 100644 --- a/tests/test_fedopt.py +++ b/tests/test_fedopt.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -67,10 +67,10 @@ def test_fedopt_standalone(self): init_cfg.merge_from_other_cfg(modified_cfg) self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + 
server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_results = Fed_runner.run() init_cfg.merge_from_other_cfg(backup_cfg) diff --git a/tests/test_fedprox.py b/tests/test_fedprox.py index 10ce8b583..2f7efb049 100644 --- a/tests/test_fedprox.py +++ b/tests/test_fedprox.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -66,10 +66,10 @@ def test_fedprox_standalone(self): init_cfg.merge_from_other_cfg(modified_cfg) self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_results = Fed_runner.run() init_cfg.merge_from_other_cfg(backup_cfg) diff --git a/tests/test_fedsageplus.py b/tests/test_fedsageplus.py index ee9706219..85173664f 100644 --- a/tests/test_fedsageplus.py +++ b/tests/test_fedsageplus.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -64,10 +64,10 @@ def test_fedsageplus_standalone(self): self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() init_cfg.merge_from_other_cfg(backup_cfg) diff --git a/tests/test_femnist.py b/tests/test_femnist.py index ad99f693e..31141b2bc 100644 --- a/tests/test_femnist.py +++ b/tests/test_femnist.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls SAMPLE_CLIENT_NUM = 5 @@ -67,10 +67,10 @@ def test_femnist_standalone(self): self.assertEqual(init_cfg.federate.sample_client_num, SAMPLE_CLIENT_NUM) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) diff --git a/tests/test_finetune_lr.py b/tests/test_finetune_lr.py index 8f6254705..ef137d59b 100644 --- 
a/tests/test_finetune_lr.py +++ b/tests/test_finetune_lr.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -42,10 +42,10 @@ def test_toy_example_standalone(self): self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) diff --git a/tests/test_global_train_lr.py b/tests/test_global_train_lr.py index 41167e57a..f2ed6632a 100644 --- a/tests/test_global_train_lr.py +++ b/tests/test_global_train_lr.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -40,10 +40,10 @@ def test_toy_example_standalone(self): self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) diff --git a/tests/test_graph_node_trainer.py b/tests/test_graph_node_trainer.py index a6e441bd5..8f74ac660 100644 --- a/tests/test_graph_node_trainer.py +++ b/tests/test_graph_node_trainer.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -57,10 +57,10 @@ def test_node_standalone(self): self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() init_cfg.merge_from_other_cfg(backup_cfg) diff --git a/tests/test_local_train_lr.py b/tests/test_local_train_lr.py index f87a96cfa..9d06c9968 100644 --- a/tests/test_local_train_lr.py +++ b/tests/test_local_train_lr.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from 
federatedscope.core.auxiliaries.worker_builder import get_server_cls, \ get_client_cls @@ -41,10 +41,10 @@ def test_toy_example_standalone(self): self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) diff --git a/tests/test_mf.py b/tests/test_mf.py index c602d8621..8ff641a65 100644 --- a/tests/test_mf.py +++ b/tests/test_mf.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -56,10 +56,10 @@ def test_mf_standalone(self): init_cfg.merge_from_other_cfg(modified_cfg) self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_results = Fed_runner.run() init_cfg.merge_from_other_cfg(backup_cfg) diff --git a/tests/test_nbafl.py b/tests/test_nbafl.py index b328905ba..c61036bb7 100644 --- a/tests/test_nbafl.py +++ b/tests/test_nbafl.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -72,10 +72,10 @@ def test_nbafl_standalone(self): # Run on first 10 clients init_cfg.merge_from_list(['federate.client_num', 10]) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 377a6a66f..ba2733053 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -66,10 +66,10 @@ def test_femnist_standalone(self): init_cfg.merge_from_other_cfg(modified_cfg) self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + 
server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) diff --git a/tests/test_pfedme.py b/tests/test_pfedme.py index bb69bc107..48e793083 100644 --- a/tests/test_pfedme.py +++ b/tests/test_pfedme.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -69,10 +69,10 @@ def test_femnist_standalone(self): init_cfg.merge_from_other_cfg(modified_cfg) self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) diff --git a/tests/test_rec_IG_opt_attack.py b/tests/test_rec_IG_opt_attack.py index 63e029351..25fc840d4 100644 --- a/tests/test_rec_IG_opt_attack.py +++ b/tests/test_rec_IG_opt_attack.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -69,10 +69,10 @@ def test_IG_rec_femnist_standalone(self): init_cfg.merge_from_other_cfg(modified_cfg) self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) diff --git a/tests/test_rec_opt_attack.py b/tests/test_rec_opt_attack.py index f28942457..be5ecfef2 100644 --- a/tests/test_rec_opt_attack.py +++ b/tests/test_rec_opt_attack.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -75,10 +75,10 @@ def test_rec_femnist_standalone(self): # else: # server_class = Server - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) diff --git a/tests/test_toy_lr.py b/tests/test_toy_lr.py index 276997de2..523a4bec4 100644 --- a/tests/test_toy_lr.py +++ 
b/tests/test_toy_lr.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -40,10 +40,10 @@ def test_toy_example_standalone(self): self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) @@ -64,10 +64,10 @@ def test_toy_example_standalone_global_eval(self): self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) diff --git a/tests/test_unseen_clients_lr.py b/tests/test_unseen_clients_lr.py index db87af090..53e9d6ec0 100644 --- a/tests/test_unseen_clients_lr.py +++ b/tests/test_unseen_clients_lr.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls @@ -38,10 +38,10 @@ def test_toy_example_standalone(self): self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_best_results = Fed_runner.run() print(test_best_results) diff --git a/tests/test_vertical_fl.py b/tests/test_vertical_fl.py index 95613c466..b5394c0f9 100644 --- a/tests/test_vertical_fl.py +++ b/tests/test_vertical_fl.py @@ -6,7 +6,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import FedRunner +from federatedscope.core.fed_runner import get_runner class vFLTest(unittest.TestCase): @@ -52,10 +52,10 @@ def test_vFL(self): init_cfg.merge_from_other_cfg(modified_config) self.assertIsNotNone(data) - Fed_runner = FedRunner(data=data, - server_class=get_server_cls(init_cfg), - client_class=get_client_cls(init_cfg), - config=init_cfg.clone()) + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) test_results = Fed_runner.run() init_cfg.merge_from_other_cfg(backup_cfg) From c023a950a377f6c72950e17f77ec201534d39f00 Mon Sep 17 00:00:00 2001 From: rayrayraykk 
<18007356109@163.com> Date: Fri, 28 Oct 2022 17:06:35 +0800 Subject: [PATCH 02/21] rm codecov for now --- .github/workflows/codecov.yml | 46 ----------------------------------- 1 file changed, 46 deletions(-) delete mode 100644 .github/workflows/codecov.yml diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml deleted file mode 100644 index 37bb48dbb..000000000 --- a/.github/workflows/codecov.yml +++ /dev/null @@ -1,46 +0,0 @@ -name: Codecov UnitTests - -on: [push, pull_request] - -jobs: - run: - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: ['3.9'] - torch-version: ['1.10.1'] - torchvision-version: ['0.11.2'] - torchaudio-version: ['0.10.1'] - torchtext-version: ['0.11.1'] - env: - OS: ${{ matrix.os }} - PYTHON: '3.9' - steps: - - uses: actions/checkout@master - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@master - with: - python-version: ${{ matrix.python-version }} - - name: Install PyTorch ${{ matrix.torch-version }}+cpu - run: | - pip install numpy typing-extensions dataclasses - pip install torch==${{ matrix.torch-version}}+cpu torchvision==${{matrix.torchvision-version}}+cpu torchaudio==${{matrix.torchaudio-version}} torchtext==${{matrix.torchtext-version}} -f https://download.pytorch.org/whl/torch_stable.html - - name: Install FS - run: | - pip install -e .[test] - - name: Generate coverage report - run: | - pytest --cov=./ --cov-report=xml - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 - with: - token: ${{ secrets.CODECOV_TOKEN }} - directory: ./coverage/reports/ - env_vars: OS,PYTHON - fail_ci_if_error: true - files: ./coverage1.xml,./coverage2.xml - flags: unittests - name: codecov-umbrella - path_to_write_report: ./coverage/codecov_report.txt - verbose: true \ No newline at end of file From 03777918d71357bd0f1cfc78a6ba7776b27a7e6b Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Mon, 31 Oct 2022 15:11:52 +0800 Subject: [PATCH 03/21] update docstring --- federatedscope/core/aggregators/aggregator.py | 2 +- .../core/auxiliaries/aggregator_builder.py | 46 +++++++++++ .../core/auxiliaries/criterion_builder.py | 23 ++++-- .../core/auxiliaries/trainer_builder.py | 76 ++++++++++++++++++- 4 files changed, 138 insertions(+), 9 deletions(-) diff --git a/federatedscope/core/aggregators/aggregator.py b/federatedscope/core/aggregators/aggregator.py index c8e2052ac..4f966b74b 100644 --- a/federatedscope/core/aggregators/aggregator.py +++ b/federatedscope/core/aggregators/aggregator.py @@ -11,7 +11,7 @@ def aggregate(self, agg_info): class NoCommunicationAggregator(Aggregator): - """"Clients do not communicate. Each client work locally + """Clients do not communicate. Each client work locally """ def aggregate(self, agg_info): # do nothing diff --git a/federatedscope/core/auxiliaries/aggregator_builder.py b/federatedscope/core/auxiliaries/aggregator_builder.py index 778ff0c6f..c6b07c284 100644 --- a/federatedscope/core/auxiliaries/aggregator_builder.py +++ b/federatedscope/core/auxiliaries/aggregator_builder.py @@ -6,6 +6,52 @@ def get_aggregator(method, model=None, device=None, online=False, config=None): + """ + This function builds an aggregator, which is a protocol for aggregate \ + all clients' model(s). + + Arguments: + method: key to determine which aggregator to use + model: model to be aggregated + device: where to aggregate models (``cpu`` or ``gpu``) + online: ``True`` or ``False`` to use online aggregator. 
+ config: configurations for FL, see ``federatedscope.core.configs`` + + Returns: + An instance of aggregator (see ``core.aggregator`` for details) + + Note: + The key-value pairs of ``method`` and aggregators: + ================================== =========================== + Method Aggregator + ================================== =========================== + ``tensorflow`` ``cross_backends.FedAvgAggregator`` + ``local`` \ + ``core.aggregators.NoCommunicationAggregator`` + ``global`` \ + ``core.aggregators.NoCommunicationAggregator`` + ``fedavg`` \ + ``core.aggregators.OnlineClientsAvgAggregator`` or \ + ``core.aggregators.AsynClientsAvgAggregator`` or \ + ``ClientsAvgAggregator`` + ``pfedme`` \ + ``core.aggregators.ServerClientsInterpolateAggregator`` + ``ditto`` \ + ``core.aggregators.OnlineClientsAvgAggregator`` or \ + ``core.aggregators.AsynClientsAvgAggregator`` or \ + ``ClientsAvgAggregator`` + ``fedsageplus`` \ + ``core.aggregators.OnlineClientsAvgAggregator`` or \ + ``core.aggregators.AsynClientsAvgAggregator`` or \ + ``ClientsAvgAggregator`` + ``gcflplus`` \ + ``core.aggregators.OnlineClientsAvgAggregator`` or \ + ``core.aggregators.AsynClientsAvgAggregator`` or \ + ``ClientsAvgAggregator`` + ``fedopt`` \ + ``core.aggregators.FedOptAggregator`` + ================================== =========================== + """ if config.backend == 'tensorflow': from federatedscope.cross_backends import FedAvgAggregator return FedAvgAggregator(model=model, device=device) diff --git a/federatedscope/core/auxiliaries/criterion_builder.py b/federatedscope/core/auxiliaries/criterion_builder.py index 4192a9b02..2baf1b85a 100644 --- a/federatedscope/core/auxiliaries/criterion_builder.py +++ b/federatedscope/core/auxiliaries/criterion_builder.py @@ -17,17 +17,28 @@ f'available.') -def get_criterion(type, device): +def get_criterion(criterion_type, device): + """ + This function builds an instance of loss functions from \ + "https://pytorch.org/docs/stable/nn.html#loss-functions". + + Arguments: + criterion_type: loss function type + device: move to device (``cpu`` or ``gpu``) + + Returns: + An instance of loss functions. + """ for func in register.criterion_dict.values(): - criterion = func(type, device) + criterion = func(criterion_type, device) if criterion is not None: return criterion - if isinstance(type, str): - if hasattr(nn, type): - return getattr(nn, type)() + if isinstance(criterion_type, str): + if hasattr(nn, criterion_type): + return getattr(nn, criterion_type)() else: raise NotImplementedError( - 'Criterion {} not implement'.format(type)) + 'Criterion {} not implement'.format(criterion_type)) else: raise TypeError() diff --git a/federatedscope/core/auxiliaries/trainer_builder.py b/federatedscope/core/auxiliaries/trainer_builder.py index 599291762..6f773fc45 100644 --- a/federatedscope/core/auxiliaries/trainer_builder.py +++ b/federatedscope/core/auxiliaries/trainer_builder.py @@ -36,6 +36,75 @@ def get_trainer(model=None, only_for_eval=False, is_attacker=False, monitor=None): + """ + This function builds an instance of trainer. 
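A minimal calling sketch for ``get_trainer`` (the variable names are hypothetical: ``model``, ``data``, ``device``, ``cfg`` and ``monitor`` are assumed to come from the corresponding builders, and the ``evaluate`` call shown is an assumption about the trainer routines rather than part of this builder)::

    from federatedscope.core.auxiliaries.trainer_builder import get_trainer

    trainer = get_trainer(model=model,
                          data=data,
                          device=device,
                          config=cfg,
                          monitor=monitor)
    trainer.train()  # run one local training routine
    metrics = trainer.evaluate(target_data_split_name='test')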
+ + Arguments: + model: model used in FL course + data: data used in FL course + device: where to train model (``cpu`` or ``gpu``) + config: configurations for FL, see ``federatedscope.core.configs`` + only_for_eval: ``True`` or ``False``, if ``True``, ``train`` \ + routine will be removed in this trainer + is_attacker: ``True`` or ``False`` to determine whether this client \ + is an attacker + monitor: an instance of ``federatedscope.core.monitors.Monitor`` to \ + observe the evaluation and system metrics + + Returns: + An instance of trainer. + + Note: + The key-value pairs of ``cfg.trainer.type`` and trainers: + ================================== =========================== + Trainer Type Source + ================================== =========================== + ``general`` \ + ``core.trainers.GeneralTorchTrainer`` and \ + ``core.trainers.GeneralTFTrainer`` + ``cvtrainer`` ``cv.trainer.trainer.CVTrainer`` + ``nlptrainer`` ``nlp.trainer.trainer.NLPTrainer`` + ``graphminibatch_trainer`` \ + ``gfl.trainer.graphtrainer.GraphMiniBatchTrainer`` + ``linkfullbatch_trainer`` \ + ``gfl.trainer.linktrainer.LinkFullBatchTrainer`` + ``linkminibatch_trainer`` \ + ``gfl.trainer.linktrainer.LinkMiniBatchTrainer`` + ``nodefullbatch_trainer`` \ + ``gfl.trainer.nodetrainer.NodeFullBatchTrainer`` + ``nodeminibatch_trainer`` \ + ``gfl.trainer.nodetrainer.NodeMiniBatchTrainer`` + ``flitplustrainer`` \ + ``gfl.flitplus.trainer.FLITPlusTrainer`` + ``flittrainer`` \ + ``gfl.flitplus.trainer.FLITTrainer`` + ``fedvattrainer`` \ + ``gfl.flitplus.trainer.FedVATTrainer`` + ``fedfocaltrainer`` \ + ``gfl.flitplus.trainer.FedFocalTrainer`` + ``mftrainer`` \ + ``federatedscope.mf.trainer.MFTrainer`` + ``mytorchtrainer`` \ + ``contrib.trainer.torch_example.MyTorchTrainer`` + ================================== =========================== + Wrapper functions are shown below: + ================================== =========================== + Wrapper Functions Source + ================================== =========================== + ``nbafl`` \ + ``core.trainers.wrap_nbafl_trainer`` + ``sgdmf`` ``mf.trainer.wrap_MFTrainer`` + ``pfedme`` \ + ``core.trainers.wrap_pFedMeTrainer`` + ``ditto`` ``core.trainers.wrap_DittoTrainer`` + ``fedem`` ``core.trainers.FedEMTrainer`` + ``fedprox`` \ + ``core.trainers.wrap_fedprox_trainer`` + ``attack`` \ + ``attack.trainer.wrap_benignTrainer`` and \ + ``attack.auxiliary.attack_trainer_builder.wrap_attacker_trainer`` + ================================== =========================== + """ if config.trainer.type == 'general': if config.backend == 'torch': from federatedscope.core.trainers import GeneralTorchTrainer @@ -109,8 +178,11 @@ def get_trainer(model=None, config.trainer.type)) if not isinstance(trainer, Trainer): - logger.warning(f'When using {trainer}, trainer plug-in cannot be ' - f'enabled. Please use {Trainer} instead.') + logger.warning(f'Hook-like plug-in functions cannot be enabled when ' + f'using {trainer}. 
If you want use our wrapper ' + f'functions for your trainer please consider ' + f'inheriting from ' + f'`federatedscope.core.trainers.Trainer` instead.') return trainer # differential privacy plug-in From 574a3ffb51a723758e9a5c5aec40b79327d467da Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Mon, 31 Oct 2022 16:35:31 +0800 Subject: [PATCH 04/21] update API ref for xxbuidler --- doc/source/core.rst | 21 ++++++++++++++++++- federatedscope/autotune/algos.py | 2 +- federatedscope/autotune/utils.py | 2 +- .../core/auxiliaries/runner_builder.py | 15 +++++++++++++ federatedscope/core/fed_runner.py | 15 ------------- federatedscope/core/workers/__init__.py | 5 ----- federatedscope/main.py | 2 +- 7 files changed, 38 insertions(+), 24 deletions(-) create mode 100644 federatedscope/core/auxiliaries/runner_builder.py diff --git a/doc/source/core.rst b/doc/source/core.rst index 9c7c3fae8..39d626954 100644 --- a/doc/source/core.rst +++ b/doc/source/core.rst @@ -11,7 +11,7 @@ federatedscope.core.fed_runner federatedscope.core.workers ----------------------- -.. automodule:: federatedscope.core.worker +.. automodule:: federatedscope.core.workers :members: :private-members: @@ -50,3 +50,22 @@ federatedscope.core.monitors .. automodule:: federatedscope.core.monitors :members: :private-members: + +federatedscope.core.auxiliaries +----------------------- + +.. autofunction:: federatedscope.core.auxiliaries.aggregator_builder.get_aggregator() +.. autofunction:: federatedscope.core.auxiliaries.criterion_builder.get_criterion() +.. autofunction:: federatedscope.core.auxiliaries.data_builder.get_data() +.. autofunction:: federatedscope.core.auxiliaries.dataloader_builder.get_dataloader() +.. autofunction:: federatedscope.core.auxiliaries.metric_builder.get_metric() +.. autofunction:: federatedscope.core.auxiliaries.model_builder.get_model() +.. autofunction:: federatedscope.core.auxiliaries.optimizer_builder.get_optimizer() +.. autofunction:: federatedscope.core.auxiliaries.regularizer_builder.get_regularizer() +.. autofunction:: federatedscope.core.auxiliaries.runner_builder.get_runner() +.. autofunction:: federatedscope.core.auxiliaries.sampler_builder.get_sampler() +.. autofunction:: federatedscope.core.auxiliaries.scheduler_builder.get_scheduler() +.. autofunction:: federatedscope.core.auxiliaries.splitter_builder.get_splitter() +.. autofunction:: federatedscope.core.auxiliaries.trainer_builder.get_trainer() +.. autofunction:: federatedscope.core.auxiliaries.transform_builder.get_transform() +.. 
autofunction:: federatedscope.core.auxiliaries.worker_builder.get_worker() diff --git a/federatedscope/autotune/algos.py b/federatedscope/autotune/algos.py index 27469d136..b84c2b8d6 100644 --- a/federatedscope/autotune/algos.py +++ b/federatedscope/autotune/algos.py @@ -13,7 +13,7 @@ from federatedscope.core.auxiliaries.data_builder import get_data from federatedscope.core.auxiliaries.worker_builder import get_client_cls, \ get_server_cls -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.autotune.utils import parse_search_space, \ config2cmdargs, config2str, summarize_hpo_results diff --git a/federatedscope/autotune/utils.py b/federatedscope/autotune/utils.py index 959d9112e..aefa771a5 100644 --- a/federatedscope/autotune/utils.py +++ b/federatedscope/autotune/utils.py @@ -139,7 +139,7 @@ def eval_in_fs(cfg, config, budget): from federatedscope.core.auxiliaries.data_builder import get_data from federatedscope.core.auxiliaries.worker_builder import \ get_client_cls, get_server_cls - from federatedscope.core.fed_runner import get_runner + from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.autotune.utils import config2cmdargs from os.path import join as osp diff --git a/federatedscope/core/auxiliaries/runner_builder.py b/federatedscope/core/auxiliaries/runner_builder.py new file mode 100644 index 000000000..678892e50 --- /dev/null +++ b/federatedscope/core/auxiliaries/runner_builder.py @@ -0,0 +1,15 @@ +from federatedscope.core.fed_runner import StandaloneRunner, DistributedRunner + + +def get_runner(data, server_class, client_class, config, client_configs=None): + # Instantiate a Runner based on a configuration file + mode = config.federate.mode.lower() + runner_dict = { + 'standalone': StandaloneRunner, + 'distributed': DistributedRunner + } + return runner_dict[mode](data=data, + server_class=server_class, + client_class=client_class, + config=config, + client_configs=client_configs) diff --git a/federatedscope/core/fed_runner.py b/federatedscope/core/fed_runner.py index aadf6e7be..c14fbd642 100644 --- a/federatedscope/core/fed_runner.py +++ b/federatedscope/core/fed_runner.py @@ -10,25 +10,10 @@ from federatedscope.core.gpu_manager import GPUManager from federatedscope.core.auxiliaries.model_builder import get_model from federatedscope.core.auxiliaries.utils import get_resource_info -from federatedscope.core.data.utils import merge_data logger = logging.getLogger(__name__) -def get_runner(data, server_class, client_class, config, client_configs=None): - # Instantiate a Runner based on a configuration file - mode = config.federate.mode.lower() - runner_dict = { - 'standalone': StandaloneRunner, - 'distributed': DistributedRunner - } - return runner_dict[mode](data=data, - server_class=server_class, - client_class=client_class, - config=config, - client_configs=client_configs) - - class BaseRunner(object): """ This class is used to construct an FL course, which includes `_set_up` diff --git a/federatedscope/core/workers/__init__.py b/federatedscope/core/workers/__init__.py index 777c989da..aaee5bfa9 100644 --- a/federatedscope/core/workers/__init__.py +++ b/federatedscope/core/workers/__init__.py @@ -1,8 +1,3 @@ -from __future__ import absolute_import -from __future__ import print_function -from __future__ import division -from __future__ import with_statement - from federatedscope.core.workers.base_worker import Worker from 
federatedscope.core.workers.server import Server from federatedscope.core.workers.client import Client diff --git a/federatedscope/main.py b/federatedscope/main.py index f7dabc993..20120766f 100644 --- a/federatedscope/main.py +++ b/federatedscope/main.py @@ -14,7 +14,7 @@ from federatedscope.core.auxiliaries.worker_builder import get_client_cls, \ get_server_cls from federatedscope.core.configs.config import global_cfg, CfgNode -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner if os.environ.get('https_proxy'): del os.environ['https_proxy'] From ab4ccb655f62285b18303f623e7a3335c7576816 Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Mon, 31 Oct 2022 16:49:10 +0800 Subject: [PATCH 05/21] fix minor bugs in test --- tests/test_CRA_gan_attack.py | 2 +- tests/test_MIA_gradient_ascent.py | 2 +- tests/test_PIA_toy.py | 2 +- tests/test_asyn_cifar10.py | 2 +- tests/test_backdoor_attack.py | 2 +- tests/test_ditto.py | 2 +- tests/test_efficient_simulation.py | 2 +- tests/test_external_dataset.py | 2 +- tests/test_fedem.py | 2 +- tests/test_fedopt.py | 2 +- tests/test_fedprox.py | 2 +- tests/test_fedsageplus.py | 2 +- tests/test_femnist.py | 2 +- tests/test_finetune_lr.py | 2 +- tests/test_global_train_lr.py | 2 +- tests/test_graph_node_trainer.py | 2 +- tests/test_local_train_lr.py | 2 +- tests/test_mf.py | 2 +- tests/test_nbafl.py | 2 +- tests/test_optimizer.py | 2 +- tests/test_pfedme.py | 2 +- tests/test_rec_IG_opt_attack.py | 2 +- tests/test_rec_opt_attack.py | 2 +- tests/test_toy_lr.py | 2 +- tests/test_unseen_clients_lr.py | 2 +- tests/test_vertical_fl.py | 2 +- 26 files changed, 26 insertions(+), 26 deletions(-) diff --git a/tests/test_CRA_gan_attack.py b/tests/test_CRA_gan_attack.py index 5a1bb72dd..d7c28b393 100644 --- a/tests/test_CRA_gan_attack.py +++ b/tests/test_CRA_gan_attack.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_MIA_gradient_ascent.py b/tests/test_MIA_gradient_ascent.py index b4c3dd750..23e98868b 100644 --- a/tests/test_MIA_gradient_ascent.py +++ b/tests/test_MIA_gradient_ascent.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_PIA_toy.py b/tests/test_PIA_toy.py index e25555926..df0c4b5e9 100644 --- a/tests/test_PIA_toy.py +++ b/tests/test_PIA_toy.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_asyn_cifar10.py 
b/tests/test_asyn_cifar10.py index 530b2ad01..d0a07823b 100644 --- a/tests/test_asyn_cifar10.py +++ b/tests/test_asyn_cifar10.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_backdoor_attack.py b/tests/test_backdoor_attack.py index f0683c812..187afff7b 100644 --- a/tests/test_backdoor_attack.py +++ b/tests/test_backdoor_attack.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_ditto.py b/tests/test_ditto.py index 77f23e445..c1cf561ce 100644 --- a/tests/test_ditto.py +++ b/tests/test_ditto.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_efficient_simulation.py b/tests/test_efficient_simulation.py index aa16cdf9a..232028af1 100644 --- a/tests/test_efficient_simulation.py +++ b/tests/test_efficient_simulation.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_external_dataset.py b/tests/test_external_dataset.py index 806efe603..c95097548 100644 --- a/tests/test_external_dataset.py +++ b/tests/test_external_dataset.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_fedem.py b/tests/test_fedem.py index 28a7dec23..9750d020f 100644 --- a/tests/test_fedem.py +++ b/tests/test_fedem.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_fedopt.py b/tests/test_fedopt.py index 
30d9a53aa..2eaff505f 100644 --- a/tests/test_fedopt.py +++ b/tests/test_fedopt.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_fedprox.py b/tests/test_fedprox.py index 2f7efb049..2a66e4689 100644 --- a/tests/test_fedprox.py +++ b/tests/test_fedprox.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_fedsageplus.py b/tests/test_fedsageplus.py index 85173664f..65987da41 100644 --- a/tests/test_fedsageplus.py +++ b/tests/test_fedsageplus.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_femnist.py b/tests/test_femnist.py index 31141b2bc..a0f63a8b7 100644 --- a/tests/test_femnist.py +++ b/tests/test_femnist.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls SAMPLE_CLIENT_NUM = 5 diff --git a/tests/test_finetune_lr.py b/tests/test_finetune_lr.py index ef137d59b..7660dbd82 100644 --- a/tests/test_finetune_lr.py +++ b/tests/test_finetune_lr.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_global_train_lr.py b/tests/test_global_train_lr.py index f2ed6632a..684aff96e 100644 --- a/tests/test_global_train_lr.py +++ b/tests/test_global_train_lr.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_graph_node_trainer.py b/tests/test_graph_node_trainer.py index 8f74ac660..fd874eb0d 100644 --- 
a/tests/test_graph_node_trainer.py +++ b/tests/test_graph_node_trainer.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_local_train_lr.py b/tests/test_local_train_lr.py index 9d06c9968..f7fc7f6ae 100644 --- a/tests/test_local_train_lr.py +++ b/tests/test_local_train_lr.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, \ get_client_cls diff --git a/tests/test_mf.py b/tests/test_mf.py index 8ff641a65..9dbfe3fb0 100644 --- a/tests/test_mf.py +++ b/tests/test_mf.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_nbafl.py b/tests/test_nbafl.py index c61036bb7..3a8889988 100644 --- a/tests/test_nbafl.py +++ b/tests/test_nbafl.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index ba2733053..2c06d322e 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_pfedme.py b/tests/test_pfedme.py index 48e793083..cf1e8741c 100644 --- a/tests/test_pfedme.py +++ b/tests/test_pfedme.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_rec_IG_opt_attack.py b/tests/test_rec_IG_opt_attack.py index 25fc840d4..7cc6bbd25 100644 --- a/tests/test_rec_IG_opt_attack.py +++ b/tests/test_rec_IG_opt_attack.py @@ -5,7 +5,7 @@ from 
federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_rec_opt_attack.py b/tests/test_rec_opt_attack.py index be5ecfef2..92517ee17 100644 --- a/tests/test_rec_opt_attack.py +++ b/tests/test_rec_opt_attack.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_toy_lr.py b/tests/test_toy_lr.py index 523a4bec4..7eb1a673d 100644 --- a/tests/test_toy_lr.py +++ b/tests/test_toy_lr.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_unseen_clients_lr.py b/tests/test_unseen_clients_lr.py index 53e9d6ec0..019711b38 100644 --- a/tests/test_unseen_clients_lr.py +++ b/tests/test_unseen_clients_lr.py @@ -5,7 +5,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls diff --git a/tests/test_vertical_fl.py b/tests/test_vertical_fl.py index b5394c0f9..69da23f2e 100644 --- a/tests/test_vertical_fl.py +++ b/tests/test_vertical_fl.py @@ -6,7 +6,7 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger from federatedscope.core.configs.config import global_cfg -from federatedscope.core.fed_runner import get_runner +from federatedscope.core.auxiliaries.runner_builder import get_runner class vFLTest(unittest.TestCase): From 1056b40fad277058904d7eee61087167e0f694c1 Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Mon, 31 Oct 2022 18:24:18 +0800 Subject: [PATCH 06/21] update docstring for builders --- federatedscope/contrib/scheduler/example.py | 4 +- .../core/auxiliaries/criterion_builder.py | 5 +- .../core/auxiliaries/data_builder.py | 75 ++++++++++++++++++- .../core/auxiliaries/dataloader_builder.py | 21 +++++- .../core/auxiliaries/metric_builder.py | 35 +++++++++ .../core/auxiliaries/model_builder.py | 40 +++++++--- .../core/auxiliaries/optimizer_builder.py | 29 +++++-- .../core/auxiliaries/regularizer_builder.py | 22 +++++- .../core/auxiliaries/runner_builder.py | 36 ++++++--- 9 files changed, 223 insertions(+), 44 deletions(-) diff --git a/federatedscope/contrib/scheduler/example.py b/federatedscope/contrib/scheduler/example.py index 
642d6fab3..e505829a7 100644 --- a/federatedscope/contrib/scheduler/example.py +++ b/federatedscope/contrib/scheduler/example.py @@ -1,14 +1,14 @@ from federatedscope.register import register_scheduler -def call_my_scheduler(optimizer, type): +def call_my_scheduler(optimizer, reg_type): try: import torch.optim as optim except ImportError: optim = None scheduler = None - if type == 'myscheduler': + if reg_type == 'myscheduler': if optim is not None: lr_lambda = [lambda epoch: epoch // 30] scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) diff --git a/federatedscope/core/auxiliaries/criterion_builder.py b/federatedscope/core/auxiliaries/criterion_builder.py index 2baf1b85a..269c5868a 100644 --- a/federatedscope/core/auxiliaries/criterion_builder.py +++ b/federatedscope/core/auxiliaries/criterion_builder.py @@ -19,8 +19,9 @@ def get_criterion(criterion_type, device): """ - This function builds an instance of loss functions from \ - "https://pytorch.org/docs/stable/nn.html#loss-functions". + This function builds an instance of loss functions from: \ + "https://pytorch.org/docs/stable/nn.html#loss-functions", + where the ``criterion_type`` is chosen from. Arguments: criterion_type: loss function type diff --git a/federatedscope/core/auxiliaries/data_builder.py b/federatedscope/core/auxiliaries/data_builder.py index 617d9689e..3f534a2f9 100644 --- a/federatedscope/core/auxiliaries/data_builder.py +++ b/federatedscope/core/auxiliaries/data_builder.py @@ -37,12 +37,79 @@ def get_data(config, client_cfgs=None): """Instantiate the data and update the configuration accordingly if necessary. + Arguments: - config: a cfg node object. - client_cfgs: dict of client-specific cfg node object. + config: a cfg node object, see ``cfg.data`` for details + client_cfgs: dict of client-specific cfg node object Returns: - obj: The dataset object. - cfg.node: The updated configuration. + The dataset object and the updated configuration. 
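+
+    Example:
+        A minimal usage sketch; ``init_cfg`` below is assumed to be an
+        already-initialized configuration node whose ``data.type`` is one
+        of the values listed in the Note below:
+
+        >>> from federatedscope.core.auxiliaries.data_builder import \
+        ...     get_data
+        >>> data, modified_cfg = get_data(config=init_cfg.clone())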
+ + Note: + The available ``data.type`` is shown below: + ================================== =========================== + Data type Domain + ================================== =========================== + FEMNIST CV + Celeba CV + ``${DNAME}@torchvision`` CV + Shakespeare NLP + SubReddit NLP + Twitter (Sentiment140) NLP + ``${DNAME}@torchtext`` NLP + ``${DNAME}@huggingface_datasets`` NLP + Cora Graph (node-level) + CiteSeer Graph (node-level) + PubMed Graph (node-level) + DBLP_conf Graph (node-level) + DBLP_org Graph (node-level) + csbm Graph (node-level) + Epinions Graph (link-level) + Ciao Graph (link-level) + FB15k Graph (link-level) + FB15k-237 Graph (link-level) + WN18 Graph (link-level) + MUTAG Graph (graph-level) + BZR Graph (graph-level) + COX2 Graph (graph-level) + DHFR Graph (graph-level) + PTC_MR Graph (graph-level) + AIDS Graph (graph-level) + NCI1 Graph (graph-level) + ENZYMES Graph (graph-level) + DD Graph (graph-level) + PROTEINS Graph (graph-level) + COLLAB Graph (graph-level) + IMDB-BINARY Graph (graph-level) + IMDB-MULTI Graph (graph-level) + REDDIT-BINARY Graph (graph-level) + HIV Graph (graph-level) + ESOL Graph (graph-level) + FREESOLV Graph (graph-level) + LIPO Graph (graph-level) + PCBA Graph (graph-level) + MUV Graph (graph-level) + BACE Graph (graph-level) + BBBP Graph (graph-level) + TOX21 Graph (graph-level) + TOXCAST Graph (graph-level) + SIDER Graph (graph-level) + CLINTOX Graph (graph-level) + graph_multi_domain_mol Graph (graph-level) + graph_multi_domain_small Graph (graph-level) + graph_multi_domain_biochem Graph (graph-level) + cikmcup Graph (graph-level) + toy Tabular + synthetic Tabular + quadratic Tabular + ``${DNAME}openml`` Tabular + vertical_fl_data Tabular(vertical) + VFLMovieLens1M Recommendation + VFLMovieLens10M Recommendation + HFLMovieLens1M Recommendation + HFLMovieLens10M Recommendation + VFLNetflix Recommendation + HFLNetflix Recommendation + ================================== =========================== """ # Fix the seed for data generation setup_seed(12345) diff --git a/federatedscope/core/auxiliaries/dataloader_builder.py b/federatedscope/core/auxiliaries/dataloader_builder.py index a412d76d2..4b9574113 100644 --- a/federatedscope/core/auxiliaries/dataloader_builder.py +++ b/federatedscope/core/auxiliaries/dataloader_builder.py @@ -15,13 +15,26 @@ def get_dataloader(dataset, config, split='train'): Args: dataset: dataset from which to load the data. config: configs containing batch_size, shuffle, etc. - split: current split (default: 'train'), if split is 'test', shuffle - will be `False`. And in PyG, 'test' split will use - `NeighborSampler` by default. + split: current split (default: ``train``), if split is ``test``, \ + ``cfg.dataloader.shuffle`` will be ``False``. And in PyG, ``test`` \ + split will use ``NeighborSampler`` by default. Returns: - dataloader: Instance of specific DataLoader configured by config. + Instance of specific ``DataLoader`` configured by config. 
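+
+    Example:
+        A rough sketch; ``train_dataset`` and ``cfg`` are assumed to be an
+        existing dataset object and a configuration whose
+        ``dataloader.type`` is one of the values in the Note below:
+
+        >>> from federatedscope.core.auxiliaries.dataloader_builder \
+        ...     import get_dataloader
+        >>> train_loader = get_dataloader(train_dataset, cfg, split='train')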
+ Note: + The key-value pairs of ``dataloader.type`` and ``DataLoader``: + ======================== =============================== + ``dataloader.type`` Source + ======================== =============================== + ``raw`` No DataLoader + ``base`` ``torch.utils.data.DataLoader`` + ``pyg`` ``torch_geometric.loader.DataLoader`` + ``graphsaint-rw`` \ + ``torch_geometric.loader.GraphSAINTRandomWalkSampler`` + ``neighbor`` ``torch_geometric.loader.NeighborSampler`` + ``mf`` ``federatedscope.mf.dataloader.MFDataLoader`` + ======================== =============================== """ # DataLoader builder only support torch backend now. if config.backend != 'torch': diff --git a/federatedscope/core/auxiliaries/metric_builder.py b/federatedscope/core/auxiliaries/metric_builder.py index 0d825754a..f25d5af93 100644 --- a/federatedscope/core/auxiliaries/metric_builder.py +++ b/federatedscope/core/auxiliaries/metric_builder.py @@ -12,6 +12,41 @@ def get_metric(types): + """ + This function returns a dict, where the key is metric name, and value is \ + the function of how to calculate the metric and a bool to indicate the \ + metric is larger the better. + + Args: + types: list of metric names + + Returns: + A metric calculator dict, such as \ + ``{'loss': (eval_loss, False), 'acc': (eval_acc, True), ...}`` + + Note: + The key-value pairs of built-in metric and related funcs and \ + ``the_larger_the_better`` sign is shown below: + ================= ============================================= ===== + Metric name Source \ + The larger the better + ================= ============================================= ===== + ``loss`` ``monitors.metric_calculator.eval_loss`` False + ``avg_loss`` ``monitors.metric_calculator.eval_avg_loss`` False + ``total`` ``monitors.metric_calculator.eval_total`` False + ``correct`` ``monitors.metric_calculator.eval_correct`` True + ``acc`` ``monitors.metric_calculator.eval_acc`` True + ``ap`` ``monitors.metric_calculator.eval_ap`` True + ``f1`` ``monitors.metric_calculator.eval_f1_score`` True + ``roc_auc`` ``monitors.metric_calculator.eval_roc_auc`` True + ``rmse`` ``monitors.metric_calculator.eval_rmse`` False + ``mse`` ``monitors.metric_calculator.eval_mse`` False + ``loss_regular`` ``monitors.metric_calculator.eval_regular`` False + ``imp_ratio`` ``monitors.metric_calculator.eval_imp_ratio`` True + ``std`` ``None`` False + ``hits@{n}`` ``monitors.metric_calculator.eval_hits`` True + ================= ============================================= ===== + """ metrics = dict() for func in register.metric_dict.values(): res = func(types) diff --git a/federatedscope/core/auxiliaries/model_builder.py b/federatedscope/core/auxiliaries/model_builder.py index 4bec0c2ff..cddaf732b 100644 --- a/federatedscope/core/auxiliaries/model_builder.py +++ b/federatedscope/core/auxiliaries/model_builder.py @@ -15,13 +15,12 @@ def get_shape_from_data(data, model_config, backend='torch'): """ - Extract the input shape from the given data, which can be used to build - the data. Users can also use `data.input_shape` to specify the shape + Extract the input shape from the given data, which can be used to build \ + the data. Users can also use `data.input_shape` to specify the shape. 
+ Arguments: - data (`ClientData`): the data used for local training or evaluation - The expected data format: - 1): {train/val/test: {x:ndarray, y:ndarray}}} - 2): {train/val/test: DataLoader} + data (`ClientData`): the data used for local training or evaluation \ + Returns: shape (tuple): the input shape """ @@ -91,11 +90,32 @@ def get_shape_from_data(data, model_config, backend='torch'): def get_model(model_config, local_data=None, backend='torch'): """ + This function builds an instance of model to be trained. + Arguments: - local_data (object): the model to be instantiated is - responsible for the given data. + model_config: ``cfg.model``, a submodule of ``cfg`` + local_data: the model to be instantiated is responsible for the \ + given data + backend: chosen from ``torch`` and ``tensorflow`` Returns: - model (torch.Module): the instantiated model. + model (``torch.Module``): the instantiated model. + + Note: + The key-value pairs of built-in model and source are shown below: + =================================== ============================== + Model type Source + =================================== ============================== + ``lr`` ``core.lr.LogisticRegression`` \ + or ``cross_backends.LogisticRegression`` + ``mlp`` ``core.mlp.MLP`` + ``quadratic`` ``tabular.model.QuadraticModel`` + ``convnet2, convnet5, vgg11`` ``cv.model.get_cnn()`` + ``lstm`` ``nlp.model.get_rnn()`` + ``{}@transformers`` ``nlp.model.get_transformer()`` + ``gcn, sage, gpr, gat, gin, mpnn`` ``gfl.model.get_gnn()`` + ``vmfnet, hmfnet`` \ + ``mf.model.model_builder.get_mfnet()`` + =================================== ============================== """ if local_data is not None: input_shape = get_shape_from_data(local_data, model_config, backend) @@ -135,7 +155,7 @@ def get_model(model_config, local_data=None, backend='torch'): from federatedscope.tabular.model import QuadraticModel model = QuadraticModel(input_shape[-1], 1) - elif model_config.type.lower() in ['convnet2', 'convnet5', 'vgg11', 'lr']: + elif model_config.type.lower() in ['convnet2', 'convnet5', 'vgg11']: from federatedscope.cv.model import get_cnn model = get_cnn(model_config, input_shape) elif model_config.type.lower() in ['lstm']: diff --git a/federatedscope/core/auxiliaries/optimizer_builder.py b/federatedscope/core/auxiliaries/optimizer_builder.py index bd6d1bd13..cd529f09d 100644 --- a/federatedscope/core/auxiliaries/optimizer_builder.py +++ b/federatedscope/core/auxiliaries/optimizer_builder.py @@ -17,7 +17,20 @@ f'available.') -def get_optimizer(model, type, lr, **kwargs): +def get_optimizer(model, opt_type, lr, **kwargs): + """ + This function returns an instantiated optimizer to optimize the model. 
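+
+    A rough usage sketch (``model`` is assumed to be an existing
+    ``torch.nn.Module``; any extra keyword arguments are forwarded to the
+    underlying optimizer constructor):
+
+    >>> from federatedscope.core.auxiliaries.optimizer_builder import \
+    ...     get_optimizer
+    >>> optimizer = get_optimizer(model, 'SGD', 0.01, momentum=0.9)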
+ + Args: + model: model to be optimized + opt_type: type of optimizer, see \ + https://pytorch.org/docs/stable/optim.html + lr: learning rate + **kwargs: kwargs dict + + Returns: + An instantiated optimizer + """ if torch is None: return None # in case of users have not called the cfg.freeze() @@ -30,19 +43,19 @@ def get_optimizer(model, type, lr, **kwargs): del tmp_kwargs['is_ready_for_run'] for func in register.optimizer_dict.values(): - optimizer = func(model, type, lr, **tmp_kwargs) + optimizer = func(model, opt_type, lr, **tmp_kwargs) if optimizer is not None: return optimizer - if isinstance(type, str): - if hasattr(torch.optim, type): + if isinstance(opt_type, str): + if hasattr(torch.optim, opt_type): if isinstance(model, torch.nn.Module): - return getattr(torch.optim, type)(model.parameters(), lr, - **tmp_kwargs) + return getattr(torch.optim, opt_type)(model.parameters(), lr, + **tmp_kwargs) else: - return getattr(torch.optim, type)(model, lr, **tmp_kwargs) + return getattr(torch.optim, opt_type)(model, lr, **tmp_kwargs) else: raise NotImplementedError( - 'Optimizer {} not implement'.format(type)) + 'Optimizer {} not implement'.format(opt_type)) else: raise TypeError() diff --git a/federatedscope/core/auxiliaries/regularizer_builder.py b/federatedscope/core/auxiliaries/regularizer_builder.py index 75af98cf9..c96f35c91 100644 --- a/federatedscope/core/auxiliaries/regularizer_builder.py +++ b/federatedscope/core/auxiliaries/regularizer_builder.py @@ -6,17 +6,31 @@ Module = object -def get_regularizer(type): - if type is None or type == '': +def get_regularizer(reg_type): + """ + This function builds an instance of scheduler to regularize training. + + Args: + reg_type: type of scheduler, such as see \ + https://pytorch.org/docs/stable/optim.html for details + + Returns: + An instantiated scheduler. + + Note: + We do not provide built-in scheduler, please follow \ + ``contrib.scheduler.example`` to implement your own scheduler. + """ + if reg_type is None or reg_type == '': return DummyRegularizer() for func in regularizer_dict.values(): - regularizer = func(type) + regularizer = func(reg_type) if regularizer is not None: return regularizer() raise NotImplementedError( - "Regularizer {} is not implemented.".format(type)) + "Regularizer {} is not implemented.".format(reg_type)) class DummyRegularizer(Module): diff --git a/federatedscope/core/auxiliaries/runner_builder.py b/federatedscope/core/auxiliaries/runner_builder.py index 678892e50..f336f99b1 100644 --- a/federatedscope/core/auxiliaries/runner_builder.py +++ b/federatedscope/core/auxiliaries/runner_builder.py @@ -2,14 +2,30 @@ def get_runner(data, server_class, client_class, config, client_configs=None): - # Instantiate a Runner based on a configuration file + """ + Instantiate a runner based on a configuration file + + Args: + data: ``core.data.StandaloneDataDict`` in standalone mode, \ + ``core.data.ClientData`` in distribute mode + server_class: server class + client_class: client class + config: configurations for FL, see ``federatedscope.core.configs`` + client_configs: client-specific configurations + + Returns: + An instantiated FedRunner to run the FL course. 
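+
+    Example:
+        A minimal sketch, assuming ``data`` and ``cfg`` have already been
+        prepared (e.g., via ``get_data`` and the global configuration):
+
+        >>> from federatedscope.core.auxiliaries.worker_builder import \
+        ...     get_server_cls, get_client_cls
+        >>> runner = get_runner(data=data,
+        ...                     server_class=get_server_cls(cfg),
+        ...                     client_class=get_client_cls(cfg),
+        ...                     config=cfg.clone())
+        >>> _ = runner.run()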
+ """ + mode = config.federate.mode.lower() - runner_dict = { - 'standalone': StandaloneRunner, - 'distributed': DistributedRunner - } - return runner_dict[mode](data=data, - server_class=server_class, - client_class=client_class, - config=config, - client_configs=client_configs) + + if mode == 'standalone': + runner_cls = StandaloneRunner + elif mode == 'distributed': + runner_cls = DistributedRunner + + return runner_cls(data=data, + server_class=server_class, + client_class=client_class, + config=config, + client_configs=client_configs) From f6a8ca0786837606dfc255168354ea3c4320c8c8 Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Mon, 31 Oct 2022 19:21:41 +0800 Subject: [PATCH 07/21] fix minor bugs --- .../attack/privacy_attacks/passive_PIA.py | 2 +- .../worker_as_attacker/server_attacker.py | 2 +- .../core/auxiliaries/runner_builder.py | 11 ++++++++++ .../core/auxiliaries/sampler_builder.py | 22 +++++++++++++++++++ 4 files changed, 35 insertions(+), 2 deletions(-) diff --git a/federatedscope/attack/privacy_attacks/passive_PIA.py b/federatedscope/attack/privacy_attacks/passive_PIA.py index 8d7ecef5f..f02618ccd 100644 --- a/federatedscope/attack/privacy_attacks/passive_PIA.py +++ b/federatedscope/attack/privacy_attacks/passive_PIA.py @@ -93,7 +93,7 @@ def _get_parameter_updates(self, model, previous_para, x_batch, y_batch): # get last phase model parameters model.load_state_dict(previous_para, strict=False) - optimizer = get_optimizer(type=self.fl_type_optimizer, + optimizer = get_optimizer(opt_type=self.fl_type_optimizer, model=model, lr=self.fl_lr) diff --git a/federatedscope/attack/worker_as_attacker/server_attacker.py b/federatedscope/attack/worker_as_attacker/server_attacker.py index 226568242..1ee3b1f97 100644 --- a/federatedscope/attack/worker_as_attacker/server_attacker.py +++ b/federatedscope/attack/worker_as_attacker/server_attacker.py @@ -333,7 +333,7 @@ def __init__(self, batch_size=100) # self.optimizer = get_optimizer( - # type=self._cfg.fedopt.type_optimizer, model=self.model, + # opt_type=self._cfg.fedopt.type_optimizer, model=self.model, # lr=self._cfg.fedopt.optimizer.lr) # print(self.optimizer) def callback_funcs_model_para(self, message: Message): diff --git a/federatedscope/core/auxiliaries/runner_builder.py b/federatedscope/core/auxiliaries/runner_builder.py index f336f99b1..16a5f08da 100644 --- a/federatedscope/core/auxiliaries/runner_builder.py +++ b/federatedscope/core/auxiliaries/runner_builder.py @@ -15,6 +15,17 @@ def get_runner(data, server_class, client_class, config, client_configs=None): Returns: An instantiated FedRunner to run the FL course. 
+ + Note: + The key-value pairs of built-in runner and source are shown below: + =================================== ============================== + Mode Source + =================================== ============================== + ``standalone`` \ + ``core.fed_runner.StandaloneRunner`` + ``distributed`` \ + ``core.fed_runner.DistributedRunner`` + =================================== ============================== """ mode = config.federate.mode.lower() diff --git a/federatedscope/core/auxiliaries/sampler_builder.py b/federatedscope/core/auxiliaries/sampler_builder.py index 0b7d2ff44..94e9c804b 100644 --- a/federatedscope/core/auxiliaries/sampler_builder.py +++ b/federatedscope/core/auxiliaries/sampler_builder.py @@ -9,6 +9,28 @@ def get_sampler(sample_strategy='uniform', client_num=None, client_info=None, bins=10): + """ + This function builds a sampler for sampling clients who should join the \ + aggregation per communication round. + + Args: + sample_strategy: Sampling strategy of sampler + client_num: total number of client joining the FL course + client_info: client information + bins: size of bins for group sampler + + Returns: + An instantiated Sampler to sample during aggregation. + + Note: + The key-value pairs of built-in sampler and source are shown below: + =================================== ============================== + Sampling strategy Source + =================================== ============================== + ``uniform`` ``core.sampler.UniformSampler`` + ``group`` ``core.sampler.GroupSampler`` + =================================== ============================== + """ if sample_strategy == 'uniform': return UniformSampler(client_num=client_num) elif sample_strategy == 'group': From b85aa7b2f5b0155c89f40a07025f474398ce8ac9 Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Mon, 31 Oct 2022 20:29:42 +0800 Subject: [PATCH 08/21] fix minor bugs --- doc/source/core.rst | 3 +- .../core/auxiliaries/data_builder.py | 2 +- .../core/auxiliaries/regularizer_builder.py | 8 +-- .../core/auxiliaries/runner_builder.py | 14 ++--- .../core/auxiliaries/scheduler_builder.py | 30 +++++++--- .../core/auxiliaries/splitter_builder.py | 25 +++++++++ .../core/auxiliaries/transform_builder.py | 12 ++-- .../core/auxiliaries/worker_builder.py | 55 +++++++++++++++++++ 8 files changed, 120 insertions(+), 29 deletions(-) diff --git a/doc/source/core.rst b/doc/source/core.rst index 39d626954..b1d3bbe42 100644 --- a/doc/source/core.rst +++ b/doc/source/core.rst @@ -68,4 +68,5 @@ federatedscope.core.auxiliaries .. autofunction:: federatedscope.core.auxiliaries.splitter_builder.get_splitter() .. autofunction:: federatedscope.core.auxiliaries.trainer_builder.get_trainer() .. autofunction:: federatedscope.core.auxiliaries.transform_builder.get_transform() -.. autofunction:: federatedscope.core.auxiliaries.worker_builder.get_worker() +.. autofunction:: federatedscope.core.auxiliaries.worker_builder.get_client_cls() +.. autofunction:: federatedscope.core.auxiliaries.worker_builder.get_server_cls() diff --git a/federatedscope/core/auxiliaries/data_builder.py b/federatedscope/core/auxiliaries/data_builder.py index 3f534a2f9..a7d07d262 100644 --- a/federatedscope/core/auxiliaries/data_builder.py +++ b/federatedscope/core/auxiliaries/data_builder.py @@ -39,7 +39,7 @@ def get_data(config, client_cfgs=None): necessary. 
Arguments: - config: a cfg node object, see ``cfg.data`` for details + config: a cfg node object client_cfgs: dict of client-specific cfg node object Returns: The dataset object and the updated configuration. diff --git a/federatedscope/core/auxiliaries/regularizer_builder.py b/federatedscope/core/auxiliaries/regularizer_builder.py index c96f35c91..cea57d99e 100644 --- a/federatedscope/core/auxiliaries/regularizer_builder.py +++ b/federatedscope/core/auxiliaries/regularizer_builder.py @@ -8,18 +8,14 @@ def get_regularizer(reg_type): """ - This function builds an instance of scheduler to regularize training. + This function builds an instance of regularizer to regularize training. Args: reg_type: type of scheduler, such as see \ https://pytorch.org/docs/stable/optim.html for details Returns: - An instantiated scheduler. - - Note: - We do not provide built-in scheduler, please follow \ - ``contrib.scheduler.example`` to implement your own scheduler. + An instantiated regularizer. """ if reg_type is None or reg_type == '': return DummyRegularizer() diff --git a/federatedscope/core/auxiliaries/runner_builder.py b/federatedscope/core/auxiliaries/runner_builder.py index 16a5f08da..fd7df011f 100644 --- a/federatedscope/core/auxiliaries/runner_builder.py +++ b/federatedscope/core/auxiliaries/runner_builder.py @@ -18,14 +18,12 @@ def get_runner(data, server_class, client_class, config, client_configs=None): Note: The key-value pairs of built-in runner and source are shown below: - =================================== ============================== - Mode Source - =================================== ============================== - ``standalone`` \ - ``core.fed_runner.StandaloneRunner`` - ``distributed`` \ - ``core.fed_runner.DistributedRunner`` - =================================== ============================== + =============================== ============================== + Mode Source + =============================== ============================== + ``standalone`` ``core.fed_runner.StandaloneRunner`` + ``distributed`` ``core.fed_runner.DistributedRunner`` + =============================== ============================== """ mode = config.federate.mode.lower() diff --git a/federatedscope/core/auxiliaries/scheduler_builder.py b/federatedscope/core/auxiliaries/scheduler_builder.py index afc7a5f99..d71444a9e 100644 --- a/federatedscope/core/auxiliaries/scheduler_builder.py +++ b/federatedscope/core/auxiliaries/scheduler_builder.py @@ -16,19 +16,35 @@ f'available.') -def get_scheduler(optimizer, type, **kwargs): +def get_scheduler(optimizer, scheduler_type, **kwargs): + """ + This function builds an instance of scheduler. + + Args: + optimizer: optimizer to be scheduled + scheduler_type: type of scheduler + **kwargs: kwargs dict + + Returns: + An instantiated scheduler. + + Note: + Please follow ``contrib.scheduler.example`` to implement your own \ + scheduler. 
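+
+    Example:
+        A rough sketch (``optimizer`` is assumed to be an existing
+        ``torch.optim.Optimizer`` instance):
+
+        >>> from federatedscope.core.auxiliaries.scheduler_builder import \
+        ...     get_scheduler
+        >>> scheduler = get_scheduler(optimizer, 'StepLR', step_size=30)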
+ """ for func in register.scheduler_dict.values(): - scheduler = func(optimizer, type) + scheduler = func(optimizer, scheduler_type) if scheduler is not None: return scheduler - if torch is None or type == '': + if torch is None or scheduler_type == '': return None - if isinstance(type, str): - if hasattr(torch.optim.lr_scheduler, type): - return getattr(torch.optim.lr_scheduler, type)(optimizer, **kwargs) + if isinstance(scheduler_type, str): + if hasattr(torch.optim.lr_scheduler, scheduler_type): + return getattr(torch.optim.lr_scheduler, scheduler_type)(optimizer, + **kwargs) else: raise NotImplementedError( - 'Scheduler {} not implement'.format(type)) + 'Scheduler {} not implement'.format(scheduler_type)) else: raise TypeError() diff --git a/federatedscope/core/auxiliaries/splitter_builder.py b/federatedscope/core/auxiliaries/splitter_builder.py index fe2a138d7..b6c47fc4a 100644 --- a/federatedscope/core/auxiliaries/splitter_builder.py +++ b/federatedscope/core/auxiliaries/splitter_builder.py @@ -5,6 +5,31 @@ def get_splitter(config): + """ + This function is to build splitter to generate simulated federation \ + datasets from non-FL dataset. + + Args: + config: configurations for FL, see ``federatedscope.core.configs`` + + Returns: + An instance of splitter (see ``core.splitters`` for details) + + Note: + The key-value pairs of ``cfg.data.splitter`` and domain: + =================== ================================================ + Splitter type Domain + =================== ================================================ + lda Generic + iid Generic + louvain Graph (node-level) + random Graph (node-level) + rel_type Graph (link-level) + scaffold Molecular + scaffold_lda Molecular + rand_chunk Graph (graph-level) + =================== ================================================ + """ client_num = config.federate.client_num if config.data.splitter_args: kwargs = config.data.splitter_args[0] diff --git a/federatedscope/core/auxiliaries/transform_builder.py b/federatedscope/core/auxiliaries/transform_builder.py index 6cd1d81e2..1390499d0 100644 --- a/federatedscope/core/auxiliaries/transform_builder.py +++ b/federatedscope/core/auxiliaries/transform_builder.py @@ -3,16 +3,16 @@ def get_transform(config, package): - r""" + """ + This function is to build transforms applying to dataset. Args: - config: `CN` from `federatedscope/core/configs/config.py` - package: one of package from ['torchvision', 'torch_geometric', - 'torchtext', 'torchaudio'] + config: ``CN`` from ``federatedscope/core/configs/config.py`` + package: one of package from \ + ``['torchvision', 'torch_geometric', 'torchtext', 'torchaudio']`` Returns: - dict of transform functions. - + Dict of transform functions. """ transform_funcs = {} for name in ['transform', 'target_transform', 'pre_transform']: diff --git a/federatedscope/core/auxiliaries/worker_builder.py b/federatedscope/core/auxiliaries/worker_builder.py index f0f94e375..b0847e781 100644 --- a/federatedscope/core/auxiliaries/worker_builder.py +++ b/federatedscope/core/auxiliaries/worker_builder.py @@ -15,6 +15,32 @@ def get_client_cls(cfg): + """ + This function return a class of client. + + Args: + cfg: configurations for FL, see ``federatedscope.core.configs`` + + Returns: + A client class decided by ``cfg``. 
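+
+    Example:
+        A rough sketch; ``cfg`` is assumed to be a configuration node with
+        ``cfg.federate.method`` set to one of the types in the Note below:
+
+        >>> from federatedscope.core.auxiliaries.worker_builder import \
+        ...     get_client_cls
+        >>> client_cls = get_client_cls(cfg)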
+ + Note: + The key-value pairs of client type and source: + ==================== ============================================== + Client type Source + ==================== ============================================== + ``local`` ``core.workers.Client`` + ``fedavg`` ``core.workers.Client`` + ``pfedme`` ``core.workers.Client`` + ``ditto`` ``core.workers.Client`` + ``fedex`` ``autotune.fedex.FedExClient`` + ``vfl`` ``vertical_fl.worker.vFLClient`` + ``fedsageplus`` ``gfl.fedsageplus.worker.FedSagePlusClient`` + ``gcflplus`` ``gfl.gcflplus.worker.GCFLPlusClient`` + ``gradascent`` \ + ``attack.worker_as_attacker.active_client`` + ==================== ============================================== + """ for func in register.worker_dict.values(): worker_class = func(cfg.federate.method.lower()) if worker_class is not None: @@ -62,6 +88,35 @@ def get_client_cls(cfg): def get_server_cls(cfg): + """ + This function return a class of server. + + Args: + cfg: configurations for FL, see ``federatedscope.core.configs`` + + Returns: + A server class decided by ``cfg``. + + Note: + The key-value pairs of server type and source: + ==================== ============================================== + Server type Source + ==================== ============================================== + ``local`` ``core.workers.Server`` + ``fedavg`` ``core.workers.Server`` + ``pfedme`` ``core.workers.Server`` + ``ditto`` ``core.workers.Server`` + ``fedex`` ``autotune.fedex.FedExServer`` + ``vfl`` ``vertical_fl.worker.vFLServer`` + ``fedsageplus`` ``gfl.fedsageplus.worker.FedSagePlusServer`` + ``gcflplus`` ``gfl.gcflplus.worker.GCFLPlusServer`` + ``attack`` \ + ``attack.worker_as_attacker.server_attacker.PassiveServer`` and \ + ``attack.worker_as_attacker.server_attacker.PassivePIAServer`` + ``backdoor`` \ + ``attack.worker_as_attacker.server_attacker.BackdoorServer`` + ==================== ============================================== + """ for func in register.worker_dict.values(): worker_class = func(cfg.federate.method.lower()) if worker_class is not None: From e1a0a327e247ec1517be5b96eadaccde58686f2a Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Mon, 31 Oct 2022 20:34:21 +0800 Subject: [PATCH 09/21] roll back optimizer --- .../attack/privacy_attacks/passive_PIA.py | 2 +- .../worker_as_attacker/server_attacker.py | 2 +- .../core/auxiliaries/optimizer_builder.py | 18 +++++++++--------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/federatedscope/attack/privacy_attacks/passive_PIA.py b/federatedscope/attack/privacy_attacks/passive_PIA.py index f02618ccd..8d7ecef5f 100644 --- a/federatedscope/attack/privacy_attacks/passive_PIA.py +++ b/federatedscope/attack/privacy_attacks/passive_PIA.py @@ -93,7 +93,7 @@ def _get_parameter_updates(self, model, previous_para, x_batch, y_batch): # get last phase model parameters model.load_state_dict(previous_para, strict=False) - optimizer = get_optimizer(opt_type=self.fl_type_optimizer, + optimizer = get_optimizer(type=self.fl_type_optimizer, model=model, lr=self.fl_lr) diff --git a/federatedscope/attack/worker_as_attacker/server_attacker.py b/federatedscope/attack/worker_as_attacker/server_attacker.py index 1ee3b1f97..226568242 100644 --- a/federatedscope/attack/worker_as_attacker/server_attacker.py +++ b/federatedscope/attack/worker_as_attacker/server_attacker.py @@ -333,7 +333,7 @@ def __init__(self, batch_size=100) # self.optimizer = get_optimizer( - # opt_type=self._cfg.fedopt.type_optimizer, model=self.model, + # 
type=self._cfg.fedopt.type_optimizer, model=self.model, # lr=self._cfg.fedopt.optimizer.lr) # print(self.optimizer) def callback_funcs_model_para(self, message: Message): diff --git a/federatedscope/core/auxiliaries/optimizer_builder.py b/federatedscope/core/auxiliaries/optimizer_builder.py index cd529f09d..75a43b031 100644 --- a/federatedscope/core/auxiliaries/optimizer_builder.py +++ b/federatedscope/core/auxiliaries/optimizer_builder.py @@ -17,13 +17,13 @@ f'available.') -def get_optimizer(model, opt_type, lr, **kwargs): +def get_optimizer(model, type, lr, **kwargs): """ This function returns an instantiated optimizer to optimize the model. Args: model: model to be optimized - opt_type: type of optimizer, see \ + type: type of optimizer, see \ https://pytorch.org/docs/stable/optim.html lr: learning rate **kwargs: kwargs dict @@ -43,19 +43,19 @@ def get_optimizer(model, opt_type, lr, **kwargs): del tmp_kwargs['is_ready_for_run'] for func in register.optimizer_dict.values(): - optimizer = func(model, opt_type, lr, **tmp_kwargs) + optimizer = func(model, type, lr, **tmp_kwargs) if optimizer is not None: return optimizer - if isinstance(opt_type, str): - if hasattr(torch.optim, opt_type): + if isinstance(type, str): + if hasattr(torch.optim, type): if isinstance(model, torch.nn.Module): - return getattr(torch.optim, opt_type)(model.parameters(), lr, - **tmp_kwargs) + return getattr(torch.optim, type)(model.parameters(), lr, + **tmp_kwargs) else: - return getattr(torch.optim, opt_type)(model, lr, **tmp_kwargs) + return getattr(torch.optim, type)(model, lr, **tmp_kwargs) else: raise NotImplementedError( - 'Optimizer {} not implement'.format(opt_type)) + 'Optimizer {} not implement'.format(type)) else: raise TypeError() From ba9d313c98d6c9c88aafcfb7c5443358d78f8e54 Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Mon, 31 Oct 2022 20:40:55 +0800 Subject: [PATCH 10/21] roll back for scheduler --- .../core/auxiliaries/scheduler_builder.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/federatedscope/core/auxiliaries/scheduler_builder.py b/federatedscope/core/auxiliaries/scheduler_builder.py index d71444a9e..28093cd73 100644 --- a/federatedscope/core/auxiliaries/scheduler_builder.py +++ b/federatedscope/core/auxiliaries/scheduler_builder.py @@ -16,13 +16,13 @@ f'available.') -def get_scheduler(optimizer, scheduler_type, **kwargs): +def get_scheduler(optimizer, type, **kwargs): """ This function builds an instance of scheduler. Args: optimizer: optimizer to be scheduled - scheduler_type: type of scheduler + type: type of scheduler **kwargs: kwargs dict Returns: @@ -33,16 +33,15 @@ def get_scheduler(optimizer, scheduler_type, **kwargs): scheduler. 
""" for func in register.scheduler_dict.values(): - scheduler = func(optimizer, scheduler_type) + scheduler = func(optimizer, type) if scheduler is not None: return scheduler - if torch is None or scheduler_type == '': + if torch is None or type == '': return None - if isinstance(scheduler_type, str): - if hasattr(torch.optim.lr_scheduler, scheduler_type): - return getattr(torch.optim.lr_scheduler, scheduler_type)(optimizer, - **kwargs) + if isinstance(type, str): + if hasattr(torch.optim.lr_scheduler, type): + return getattr(torch.optim.lr_scheduler, type)(optimizer, **kwargs) else: raise NotImplementedError( 'Scheduler {} not implement'.format(scheduler_type)) From 782549a2eac6c4544d2b387790bb1893fd5d5485 Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Tue, 1 Nov 2022 15:12:48 +0800 Subject: [PATCH 11/21] update docstring for workers --- federatedscope/core/configs/cfg_data.py | 16 +-- federatedscope/core/fed_runner.py | 86 ++++++++----- federatedscope/core/workers/__init__.py | 4 +- federatedscope/core/workers/base_client.py | 54 ++++++--- federatedscope/core/workers/base_server.py | 51 +++++--- federatedscope/core/workers/base_worker.py | 16 ++- federatedscope/core/workers/client.py | 57 +++++---- federatedscope/core/workers/server.py | 133 ++++++++++++--------- 8 files changed, 267 insertions(+), 150 deletions(-) diff --git a/federatedscope/core/configs/cfg_data.py b/federatedscope/core/configs/cfg_data.py index b9eb0d48c..e8c64b381 100644 --- a/federatedscope/core/configs/cfg_data.py +++ b/federatedscope/core/configs/cfg_data.py @@ -88,35 +88,35 @@ def assert_data_cfg(cfg): # For compatibility with older versions of FS # TODO: delete this code block if cfg.data.loader != '': - logger.warning('config `cfg.data.loader` will be remove in the ' + logger.warning('config `cfg.data.loader` will be removed in the ' 'future, use `cfg.dataloader.type` instead.') cfg.dataloader.type = cfg.data.loader if cfg.data.batch_size != 64: - logger.warning('config `cfg.data.batch_size` will be remove in the ' + logger.warning('config `cfg.data.batch_size` will be removed in the ' 'future, use `cfg.dataloader.batch_size` instead.') cfg.dataloader.batch_size = cfg.data.batch_size if not cfg.data.shuffle: - logger.warning('config `cfg.data.shuffle` will be remove in the ' + logger.warning('config `cfg.data.shuffle` will be removed in the ' 'future, use `cfg.dataloader.shuffle` instead.') cfg.dataloader.shuffle = cfg.data.shuffle if cfg.data.num_workers != 0: - logger.warning('config `cfg.data.num_workers` will be remove in the ' + logger.warning('config `cfg.data.num_workers` will be removed in the ' 'future, use `cfg.dataloader.num_workers` instead.') cfg.dataloader.num_workers = cfg.data.num_workers if cfg.data.drop_last: - logger.warning('config `cfg.data.drop_last` will be remove in the ' + logger.warning('config `cfg.data.drop_last` will be removed in the ' 'future, use `cfg.dataloader.drop_last` instead.') cfg.dataloader.drop_last = cfg.data.drop_last if cfg.data.walk_length != 2: - logger.warning('config `cfg.data.walk_length` will be remove in the ' + logger.warning('config `cfg.data.walk_length` will be removed in the ' 'future, use `cfg.dataloader.walk_length` instead.') cfg.dataloader.walk_length = cfg.data.walk_length if cfg.data.num_steps != 30: - logger.warning('config `cfg.data.num_steps` will be remove in the ' + logger.warning('config `cfg.data.num_steps` will be removed in the ' 'future, use `cfg.dataloader.num_steps` instead.') cfg.dataloader.num_steps = 
cfg.data.num_steps if cfg.data.sizes != [10, 5]: - logger.warning('config `cfg.data.sizes` will be remove in the ' + logger.warning('config `cfg.data.sizes` will be removed in the ' 'future, use `cfg.dataloader.sizes` instead.') cfg.dataloader.sizes = cfg.data.sizes # -------------------------------------------------------------------- diff --git a/federatedscope/core/fed_runner.py b/federatedscope/core/fed_runner.py index c14fbd642..50d75ebc3 100644 --- a/federatedscope/core/fed_runner.py +++ b/federatedscope/core/fed_runner.py @@ -16,19 +16,31 @@ class BaseRunner(object): """ - This class is used to construct an FL course, which includes `_set_up` - and `run`. + This class is a base class to construct an FL course, which includes \ + ``_set_up()`` and ``run()``. - Arguments: - data: The data used in the FL courses, which are formatted as { - 'ID':data} for standalone mode. More details can be found in + Args: + data: The data used in the FL courses, which are formatted as \ + ``{'ID':data}`` for standalone mode. More details can be found in \ federatedscope.core.auxiliaries.data_builder . - server_class: The server class is used for instantiating a ( + server_class: The server class is used for instantiating a ( \ customized) server. - client_class: The client class is used for instantiating a ( + client_class: The client class is used for instantiating a ( \ customized) client. config: The configurations of the FL course. client_configs: The clients' configurations. + + Attributes: + data: The data used in the FL courses, which are formatted as \ + ``{'ID':data}`` for standalone mode. More details can be found in \ + federatedscope.core.auxiliaries.data_builder . + server: The instantiated server. + client: The instantiate client(s). + cfg : The configurations of the FL course. + client_cfgs: The clients' configurations. + mode: The run mode for FL, ``distributed`` or ``standalone`` + gpu_manager: manager of GPU resource + resource_info: information of resource """ def __init__(self, data, @@ -69,7 +81,7 @@ def __init__(self, @abc.abstractmethod def _set_up(self): """ - Set up client and/or server + Set up and instantiate the client/server. """ raise NotImplementedError @@ -83,9 +95,8 @@ def _get_server_args(self, resource_info, client_resource_info): client_resource_info: information of client's resource Returns: - server_data: None or data which server holds. - model: model to be aggregated. - kw: kwargs dict to instantiate the server. + (server_data, model, kw): None or data which server holds; model \ + to be aggregated; kwargs dict to instantiate the server. """ raise NotImplementedError @@ -99,24 +110,31 @@ def _get_client_args(self, client_id, resource_info): resource_info: information of resource Returns: - client_data: data which client holds. - kw: kwargs dict to instantiate the client. + (client_data, kw): data which client holds; kwargs dict to \ + instantiate the client. """ raise NotImplementedError @abc.abstractmethod def run(self): """ - Launch the worker + Launch the FL course Returns: - best_results: best results during the FL course + dict: best results during the FL course """ raise NotImplementedError def _setup_server(self, resource_info=None, client_resource_info=None): """ - Set up the server + Set up and instantiate the server. + + Args: + resource_info: information of resource + client_resource_info: information of client's resource + + Returns: + Instantiate server. """ assert self.server_class is not None, \ "`server_class` cannot be None." 
@@ -146,7 +164,15 @@ def _setup_client(self, client_model=None, resource_info=None): """ - Set up the Client + Set up and instantiate the client. + + Args: + client_id: ID of client + client_model: model of client + resource_info: information of resource + + Returns: + Instantiate client. """ assert self.client_class is not None, \ "`client_class` cannot be None" @@ -185,9 +211,6 @@ def _setup_client(self, class StandaloneRunner(BaseRunner): def _set_up(self): - """ - To set up server and client for standalone mode. - """ self.is_run_online = True if self.cfg.federate.online_aggr else False self.shared_comm_queue = deque() @@ -297,7 +320,7 @@ def run(self): def _handle_msg(self, msg, rcv=-1): """ - To simulate the message handling process (used only for the + To simulate the message handling process (used only for the \ standalone mode) """ if rcv != -1: @@ -321,9 +344,9 @@ def _handle_msg(self, msg, rcv=-1): def _run_simulation_online(self): """ Run for online aggregation. - Any broadcast operation would be executed client-by-clien to avoid - the existence of #clients messages at the same time. Currently, - only consider centralized topology + Any broadcast operation would be executed client-by-clien to avoid \ + the existence of #clients messages at the same time. Currently, \ + only consider centralized topology \ """ def is_broadcast(msg): return len(msg.receiver) >= 1 and msg.sender == 0 @@ -399,9 +422,6 @@ def _run_simulation(self): class DistributedRunner(BaseRunner): def _set_up(self): - """ - To set up server or client for distributed mode. - """ # sample resource information if self.resource_info is not None: sampled_index = np.random.choice(list(self.resource_info.keys())) @@ -457,15 +477,19 @@ class FedRunner(object): and `run`. Arguments: - data: The data used in the FL courses, which are formatted as { - 'ID':data} for standalone mode. More details can be found in + data: The data used in the FL courses, which are formatted as \ + ``{'ID':data}`` for standalone mode. More details can be found in \ federatedscope.core.auxiliaries.data_builder . - server_class: The server class is used for instantiating a ( + server_class: The server class is used for instantiating a ( \ customized) server. - client_class: The client class is used for instantiating a ( + client_class: The client class is used for instantiating a ( \ customized) client. config: The configurations of the FL course. client_configs: The clients' configurations. + + Warnings: + ``FedRunner`` will be removed in the future, consider \ + using ``StandaloneRunner`` or ``DistributedRunner`` instead! 
""" def __init__(self, data, diff --git a/federatedscope/core/workers/__init__.py b/federatedscope/core/workers/__init__.py index aaee5bfa9..f2ab836ec 100644 --- a/federatedscope/core/workers/__init__.py +++ b/federatedscope/core/workers/__init__.py @@ -1,5 +1,7 @@ from federatedscope.core.workers.base_worker import Worker +from federatedscope.core.workers.base_server import BaseServer +from federatedscope.core.workers.base_client import BaseClient from federatedscope.core.workers.server import Server from federatedscope.core.workers.client import Client -__all__ = ['Worker', 'Server', 'Client'] +__all__ = ['Worker', 'BaseServer', 'BaseClient', 'Server', 'Client'] diff --git a/federatedscope/core/workers/base_client.py b/federatedscope/core/workers/base_client.py index 8ec2d3b83..5c983f741 100644 --- a/federatedscope/core/workers/base_client.py +++ b/federatedscope/core/workers/base_client.py @@ -5,20 +5,43 @@ class BaseClient(Worker): def __init__(self, ID, state, config, model, strategy): super(BaseClient, self).__init__(ID, state, config, model, strategy) + # TODO: move to worker self.msg_handlers = dict() + # TODO: move to worker def register_handlers(self, msg_type, callback_func): """ To bind a message type with a handling function. Arguments: msg_type (str): The defined message type - callback_func: The handling functions to handle the received - message + callback_func: The handling functions to handle the received \ + message """ self.msg_handlers[msg_type] = callback_func def _register_default_handlers(self): + """ + Register default handler dic to handle message, which includes \ + sender, receiver, state, and content. More detail can be found in \ + ``federatedscope.core.message``. + + Note: + the default handlers to handle messages and related callback \ + function are shown below: + ============================ ================================== + Message type Callback function + ============================ ================================== + ``assign_client_id`` ``callback_funcs_for_assign_id()`` + ``ask_for_join_in_info`` ``callback_funcs_for_join_in_info()`` + ``address`` ``callback_funcs_for_address()`` + ``model_para`` ``callback_funcs_for_model_para()`` + ``ss_model_para`` ``callback_funcs_for_model_para()`` + ``evaluate`` ``callback_funcs_for_evaluate()`` + ``finish`` ``callback_funcs_for_finish()`` + ``converged`` ``callback_funcs_for_converged()`` + ============================ ================================== + """ self.register_handlers('assign_client_id', self.callback_funcs_for_assign_id) self.register_handlers('ask_for_join_in_info', @@ -35,7 +58,7 @@ def _register_default_handlers(self): @abc.abstractmethod def run(self): """ - To listen to the message and handle them accordingly (used for + To listen to the message and handle them accordingly (used for \ distributed mode) """ raise NotImplementedError @@ -43,23 +66,21 @@ def run(self): @abc.abstractmethod def callback_funcs_for_model_para(self, message): """ - The handling function for receiving model parameters, - which triggers the local training process. + The handling function for receiving model parameters, \ + which triggers the local training process. \ This handling function is widely used in various FL courses. Arguments: - message: The received message, which includes sender, receiver, - state, and content. 
- More detail can be found in federatedscope.core.message + message: The received message """ raise NotImplementedError @abc.abstractmethod def callback_funcs_for_assign_id(self, message): """ - The handling function for receiving the client_ID assigned by the - server (during the joining process), - which is used in the distributed mode. + The handling function for receiving the client_ID assigned by the \ + server (during the joining process), which is used in the \ + distributed mode. Arguments: message: The received message @@ -69,8 +90,9 @@ def callback_funcs_for_assign_id(self, message): @abc.abstractmethod def callback_funcs_for_join_in_info(self, message): """ - The handling function for receiving the request of join in information - (such as batch_size, num_of_samples) during the joining process. + The handling function for receiving the request of join in \ + information (such as ``batch_size``, ``num_of_samples``) during \ + the joining process. Arguments: message: The received message @@ -80,7 +102,7 @@ def callback_funcs_for_join_in_info(self, message): @abc.abstractmethod def callback_funcs_for_address(self, message): """ - The handling function for receiving other clients' IP addresses, + The handling function for receiving other clients' IP addresses, \ which is used for constructing a complex topology Arguments: @@ -101,7 +123,7 @@ def callback_funcs_for_evaluate(self, message): @abc.abstractmethod def callback_funcs_for_finish(self, message): """ - The handling function for receiving the signal of finishing the FL + The handling function for receiving the signal of finishing the FL \ course. Arguments: @@ -112,7 +134,7 @@ def callback_funcs_for_finish(self, message): @abc.abstractmethod def callback_funcs_for_converged(self, message): """ - The handling function for receiving the signal that the FL course + The handling function for receiving the signal that the FL course \ converged Arguments: diff --git a/federatedscope/core/workers/base_server.py b/federatedscope/core/workers/base_server.py index 10788bf0b..b6cfadc69 100644 --- a/federatedscope/core/workers/base_server.py +++ b/federatedscope/core/workers/base_server.py @@ -5,20 +5,39 @@ class BaseServer(Worker): def __init__(self, ID, state, config, model, strategy): super(BaseServer, self).__init__(ID, state, config, model, strategy) + # TODO: move to worker self.msg_handlers = dict() + # TODO: move to worker def register_handlers(self, msg_type, callback_func): """ To bind a message type with a handling function. Arguments: msg_type (str): The defined message type - callback_func: The handling functions to handle the received - message + callback_func: The handling functions to handle the received \ + message """ self.msg_handlers[msg_type] = callback_func def _register_default_handlers(self): + """ + Register default handler dic to handle message, which includes \ + sender, receiver, state, and content. More detail can be found in \ + ``federatedscope.core.message``. 
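+
+        Example:
+            A rough sketch of how a customized server could additionally
+            bind its own message type (the type name here is illustrative):
+
+            >>> self.register_handlers('my_custom_msg_type',
+            ...                        self.callback_funcs_for_metrics)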
+ + Note: + the default handlers to handle messages and related callback \ + function are shown below: + ============================ ================================== + Message type Callback function + ============================ ================================== + ``join_in`` ``callback_funcs_for_join_in()`` + ``join_in_info`` ``callback_funcs_for_join_in()`` + ``model_para`` ``callback_funcs_model_para()`` + ``metrics`` ``callback_funcs_for_metrics`` + ============================ ================================== + """ self.register_handlers('join_in', self.callback_funcs_for_join_in) self.register_handlers('join_in_info', self.callback_funcs_for_join_in) self.register_handlers('model_para', self.callback_funcs_model_para) @@ -27,7 +46,7 @@ def _register_default_handlers(self): @abc.abstractmethod def run(self): """ - To start the FL course, listen and handle messages (for distributed + To start the FL course, listen and handle messages (for distributed \ mode). """ raise NotImplementedError @@ -35,25 +54,23 @@ def run(self): @abc.abstractmethod def callback_funcs_model_para(self, message): """ - The handling function for receiving model parameters, which triggers - check_and_move_on (perform aggregation when enough feedback has - been received). - This handling function is widely used in various FL courses. + The handling function for receiving model parameters, which triggers \ + ``check_and_move_on`` (perform aggregation when enough feedback has \ + been received). This handling function is widely used in various FL \ + courses. Arguments: - message: The received message, which includes sender, receiver, - state, and content. More detail can be found in - federatedscope.core.message + message: The received message. """ raise NotImplementedError @abc.abstractmethod def callback_funcs_for_join_in(self, message): """ - The handling function for receiving the join in information. The - server might request for some information (such as num_of_samples) - if necessary, assign IDs for the servers. - If all the clients have joined in, the training process will be + The handling function for receiving the join in information. The \ + server might request for some information (such as \ + ``num_of_samples``) if necessary, assign IDs for the servers. \ + If all the clients have joined in, the training process will be \ triggered. Arguments: @@ -64,9 +81,9 @@ def callback_funcs_for_join_in(self, message): @abc.abstractmethod def callback_funcs_for_metrics(self, message): """ - The handling function for receiving the evaluation results, - which triggers check_and_move_on - (perform aggregation when enough feedback has been received). + The handling function for receiving the evaluation results, \ + which triggers ``check_and_move_on`` (perform aggregation when \ + enough feedback has been received). Arguments: message: The received message diff --git a/federatedscope/core/workers/base_worker.py b/federatedscope/core/workers/base_worker.py index f7f064de9..a7fba51fd 100644 --- a/federatedscope/core/workers/base_worker.py +++ b/federatedscope/core/workers/base_worker.py @@ -3,7 +3,21 @@ class Worker(object): """ - The base worker class. 
+ The base worker class, the parent of ``BaseClient`` and ``BaseServer`` + + Args: + ID: ID of worker + state: the training round index + config: the configuration of FL course + model: the model maintained locally + + Attributes: + ID: ID of worker + state: the training round index + model: the model maintained locally + cfg: the configuration of FL course + mode: the run mode for FL, ``distributed`` or ``standalone`` + monitor: monite FL course and record metrics """ def __init__(self, ID=-1, state=0, config=None, model=None, strategy=None): self._ID = ID diff --git a/federatedscope/core/workers/client.py b/federatedscope/core/workers/client.py index e08fff908..a6eabf310 100644 --- a/federatedscope/core/workers/client.py +++ b/federatedscope/core/workers/client.py @@ -7,7 +7,6 @@ from federatedscope.core.communication import StandaloneCommManager, \ gRPCCommManager from federatedscope.core.monitors.early_stopper import EarlyStopper -from federatedscope.core.workers import Worker from federatedscope.core.auxiliaries.trainer_builder import get_trainer from federatedscope.core.secret_sharing import AdditiveSecretSharing from federatedscope.core.auxiliaries.utils import merge_dict_of_results, \ @@ -19,9 +18,9 @@ class Client(BaseClient): """ - The Client class, which describes the behaviors of client in an FL course. - The behaviors are described by the handling functions (named as - callback_funcs_for_xxx) + The Client class, which describes the behaviors of client in an FL \ + course. The behaviors are described by the handling functions (named as \ + ``callback_funcs_for_xxx``) Arguments: ID: The unique ID of the client, which is assigned by the server @@ -32,7 +31,25 @@ class Client(BaseClient): data: The data owned by the client model: The model maintained locally device: The device to run local training and evaluation - strategy: redundant attribute + + Attributes: + ID: ID of worker + state: the training round index + model: the model maintained locally + cfg: the configuration of FL course, \ + see ``federatedscope.core.configs`` + mode: the run mode for FL, ``distributed`` or ``standalone`` + monitor: monite FL course and record metrics, \ + see ``federatedscope.core.monitors.monitor.Monitor`` + trainer: instantiated trainer, see ``federatedscope.core.trainers`` + best_results: best results ever seen + history_results: all evaluation results + early_stopper: determine when to early stop, \ + see ``federatedscope.core.monitors.early_stopper.EarlyStopper`` + ss_manager: secret sharing manager + msg_buffer: dict buffer for storing message + comm_manager: manager for communication, \ + see ``federatedscope.core.communication`` """ def __init__(self, ID=-1, @@ -170,7 +187,7 @@ def _calculate_model_delta(self, init_model, updated_model): def join_in(self): """ - To send 'join_in' message to the server for joining in the FL course. + To send ``join_in`` message to the server for joining in the FL course. """ self.comm_manager.send( Message(msg_type='join_in', @@ -181,7 +198,7 @@ def join_in(self): def run(self): """ - To listen to the message and handle them accordingly (used for + To listen to the message and handle them accordingly (used for \ distributed mode) """ while True: @@ -194,14 +211,12 @@ def run(self): def callback_funcs_for_model_para(self, message: Message): """ - The handling function for receiving model parameters, - which triggers the local training process. + The handling function for receiving model parameters, \ + which triggers the local training process. 
\ This handling function is widely used in various FL courses. Arguments: - message: The received message, which includes sender, receiver, - state, and content. - More detail can be found in federatedscope.core.message + message: The received message """ if 'ss' in message.msg_type: # A fragment of the shared secret @@ -362,9 +377,9 @@ def callback_funcs_for_model_para(self, message: Message): def callback_funcs_for_assign_id(self, message: Message): """ - The handling function for receiving the client_ID assigned by the - server (during the joining process), - which is used in the distributed mode. + The handling function for receiving the client_ID assigned by the \ + server (during the joining process), which is used in the \ + distributed mode. Arguments: message: The received message @@ -376,8 +391,9 @@ def callback_funcs_for_assign_id(self, message: Message): def callback_funcs_for_join_in_info(self, message: Message): """ - The handling function for receiving the request of join in information - (such as batch_size, num_of_samples) during the joining process. + The handling function for receiving the request of join in \ + information (such as ``batch_size``, ``num_of_samples``) during \ + the joining process. Arguments: message: The received message @@ -416,7 +432,7 @@ def callback_funcs_for_join_in_info(self, message: Message): def callback_funcs_for_address(self, message: Message): """ - The handling function for receiving other clients' IP addresses, + The handling function for receiving other clients' IP addresses, \ which is used for constructing a complex topology Arguments: @@ -486,7 +502,7 @@ def callback_funcs_for_evaluate(self, message: Message): def callback_funcs_for_finish(self, message: Message): """ - The handling function for receiving the signal of finishing the FL + The handling function for receiving the signal of finishing the FL \ course. Arguments: @@ -504,11 +520,10 @@ def callback_funcs_for_finish(self, message: Message): def callback_funcs_for_converged(self, message: Message): """ - The handling function for receiving the signal that the FL course + The handling function for receiving the signal that the FL course \ converged Arguments: message: The received message """ - self._monitor.global_converged() diff --git a/federatedscope/core/workers/server.py b/federatedscope/core/workers/server.py index 33e9c09a6..d2099423b 100644 --- a/federatedscope/core/workers/server.py +++ b/federatedscope/core/workers/server.py @@ -23,9 +23,9 @@ class Server(BaseServer): """ - The Server class, which describes the behaviors of server in an FL course. - The behaviors are described by the handled functions (named as - callback_funcs_for_xxx). + The Server class, which describes the behaviors of server in an FL \ + course. The behaviors are described by the handled functions (named as \ + ``callback_funcs_for_xxx``). 
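As a rough illustration of the ``callback_funcs_for_xxx`` convention and the ``register_handlers`` calls shown earlier in this patch, a custom message type could be wired up as below; ``MyServer``, the ``heartbeat`` message type and the import path are assumptions, not part of the patch::

    import logging

    from federatedscope.core.workers import Server

    logger = logging.getLogger(__name__)


    class MyServer(Server):
        # Hypothetical subclass that reacts to one extra message type
        def __init__(self, *args, **kwargs):
            super(MyServer, self).__init__(*args, **kwargs)
            # same two-argument form as the default handlers in this patch
            self.register_handlers('heartbeat',
                                   self.callback_funcs_for_heartbeat)

        def callback_funcs_for_heartbeat(self, message):
            # `message.sender` / `message.state` follow the Message fields
            # used by the other callbacks above
            logger.info('Heartbeat from client #%s at round %s',
                        message.sender, message.state)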
Arguments: ID: The unique ID of the server, which is set to 0 by default @@ -36,7 +36,28 @@ class Server(BaseServer): client_num: The (expected) client num to start the FL course total_round_num: The total number of the training round device: The device to run local training and evaluation - strategy: redundant attribute + + Attributes: + ID: ID of worker + state: the training round index + model: the model maintained locally + cfg: the configuration of FL course, \ + see ``federatedscope.core.configs`` + mode: the run mode for FL, ``distributed`` or ``standalone`` + monitor: monite FL course and record metrics, \ + see ``federatedscope.core.monitors.monitor.Monitor`` + trainer: instantiated trainer, see ``federatedscope.core.trainers`` + best_results: best results ever seen + history_results: all evaluation results + early_stopper: determine when to early stop, \ + see ``federatedscope.core.monitors.early_stopper.EarlyStopper`` + aggregators: a protocol for aggregate all clients' model(s), see \ + ``federatedscope.core.aggregators`` + sample_client_num: number of client aggregated in each round + msg_buffer: dict buffer for storing message + staled_msg_buffer: list buffer for storing staled message + comm_manager: manager for communication, \ + see ``federatedscope.core.communication`` """ def __init__(self, ID=-1, @@ -207,7 +228,7 @@ def register_noise_injector(self, func): def run(self): """ - To start the FL course, listen and handle messages (for distributed + To start the FL course, listen and handle messages (for distributed \ mode). """ @@ -269,13 +290,16 @@ def check_and_move_on(self, check_eval_result=False, min_received_num=None): """ - To check the message_buffer. When enough messages are receiving, - some events (such as perform aggregation, evaluation, and move to + To check the message_buffer. When enough messages are receiving, \ + some events (such as perform aggregation, evaluation, and move to \ the next training round) would be triggered. Arguments: - check_eval_result (bool): If True, check the message buffer for - evaluation; and check the message buffer for training otherwise. + check_eval_result (bool): If True, check the message buffer for \ + evaluation; and check the message buffer for training \ + otherwise. + min_received_num: number of minimal received message, used for \ + async mode """ if min_received_num is None: if self._cfg.asyn.use: @@ -333,7 +357,8 @@ def check_and_move_on(self, def check_and_save(self): """ - To save the results and save model after each evaluation. + To save the results and save model after each evaluation, and check \ + whether to early stop. 
""" # early stopping @@ -500,13 +525,11 @@ def save_best_results(self): def save_client_eval_results(self): """ - save the evaluation results of each client when the fl course - early stopped or terminated - - :return: + save the evaluation results of each client when the fl course \ + early stopped or terminated """ - round = max(self.msg_buffer['eval'].keys()) - eval_msg_buffer = self.msg_buffer['eval'][round] + rnd = max(self.msg_buffer['eval'].keys()) + eval_msg_buffer = self.msg_buffer['eval'][rnd] with open(os.path.join(self._cfg.outdir, "eval_results.log"), "a") as outfile: @@ -521,12 +544,12 @@ def save_client_eval_results(self): def merge_eval_results_from_all_clients(self): """ - Merge evaluation results from all clients, update best, - log the merged results and save them into eval_results.log + Merge evaluation results from all clients, update best, \ + log the merged results and save them into eval_results.log - :returns: the formatted merged results + Returns: + the formatted merged results """ - round = max(self.msg_buffer['eval'].keys()) eval_msg_buffer = self.msg_buffer['eval'][round] eval_res_participated_clients = [] @@ -596,14 +619,14 @@ def broadcast_model_para(self, Arguments: msg_type: 'model_para' or other user defined msg_type - sample_client_num: the number of sampled clients in the broadcast - behavior. And sample_client_num = -1 denotes to broadcast to - all the clients. - filter_unseen_clients: whether filter out the unseen clients that - do not contribute to FL process by training on their local - data and uploading their local model update. The splitting is - useful to check participation generalization gap in [ICLR'22, - What Do We Mean by Generalization in Federated Learning?] + sample_client_num: the number of sampled clients in the broadcast \ + behavior. And ``sample_client_num = -1`` denotes to \ + broadcast to all the clients. + filter_unseen_clients: whether filter out the unseen clients that \ + do not contribute to FL process by training on their local \ + data and uploading their local model update. The splitting is \ + useful to check participation generalization gap in [ICLR'22, \ + What Do We Mean by Generalization in Federated Learning?] 
\ You may want to set it to be False when in evaluation stage """ if filter_unseen_clients: @@ -651,7 +674,7 @@ def broadcast_model_para(self, def broadcast_client_address(self): """ - To broadcast the communication addresses of clients (used for + To broadcast the communication addresses of clients (used for \ additive secret sharing) """ @@ -671,12 +694,14 @@ def check_buffer(self, To check the message buffer Arguments: - cur_round (int): The current round number - min_received_num (int): The minimal number of the receiving messages - check_eval_result (bool): To check training results for evaluation - results - :returns: Whether enough messages have been received or not - :rtype: bool + cur_round (int): The current round number + min_received_num (int): The minimal number of the receiving \ + messages + check_eval_result (bool): To check training results for \ + evaluation results + + Returns + bool: Whether enough messages have been received or not """ if check_eval_result: @@ -776,7 +801,7 @@ def trigger_for_start(self): def trigger_for_time_up(self, check_timestamp=None): """ - The handler for time up: modify the currency timestamp + The handler for time up: modify the currency timestamp \ and check the trigger condition """ if self.is_finish: @@ -812,7 +837,7 @@ def terminate(self, msg_type='finish'): def eval(self): """ - To conduct evaluation. When cfg.federate.make_global_eval=True, + To conduct evaluation. When ``cfg.federate.make_global_eval=True``, \ a global evaluation is conducted by the server. """ @@ -850,15 +875,13 @@ def eval(self): def callback_funcs_model_para(self, message: Message): """ - The handling function for receiving model parameters, which triggers - check_and_move_on (perform aggregation when enough feedback has - been received). - This handling function is widely used in various FL courses. + The handling function for receiving model parameters, which triggers \ + ``check_and_move_on`` (perform aggregation when enough feedback has \ + been received). This handling function is widely used in various FL \ + courses. Arguments: - message: The received message, which includes sender, receiver, - state, and content. More detail can be found in - federatedscope.core.message + message: The received message. """ if self.is_finish: return 'finish' @@ -899,10 +922,10 @@ def callback_funcs_model_para(self, message: Message): def callback_funcs_for_join_in(self, message: Message): """ - The handling function for receiving the join in information. The - server might request for some information (such as num_of_samples) - if necessary, assign IDs for the servers. - If all the clients have joined in, the training process will be + The handling function for receiving the join in information. The \ + server might request for some information (such as \ + ``num_of_samples``) if necessary, assign IDs for the servers. \ + If all the clients have joined in, the training process will be \ triggered. Arguments: @@ -946,21 +969,21 @@ def callback_funcs_for_join_in(self, message: Message): def callback_funcs_for_metrics(self, message: Message): """ - The handling function for receiving the evaluation results, - which triggers check_and_move_on - (perform aggregation when enough feedback has been received). + The handling function for receiving the evaluation results, \ + which triggers ``check_and_move_on`` (perform aggregation when \ + enough feedback has been received). 
Arguments: message: The received message """ - round = message.state + rnd = message.state sender = message.sender content = message.content - if round not in self.msg_buffer['eval'].keys(): - self.msg_buffer['eval'][round] = dict() + if rnd not in self.msg_buffer['eval'].keys(): + self.msg_buffer['eval'][rnd] = dict() - self.msg_buffer['eval'][round][sender] = content + self.msg_buffer['eval'][rnd][sender] = content return self.check_and_move_on(check_eval_result=True) From 6156ae2cdb404404487146e45c4ddb9ce25a84ba Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Tue, 1 Nov 2022 17:45:06 +0800 Subject: [PATCH 12/21] update docstring for workers --- doc/source/core.rst | 13 ++- federatedscope/core/aggregators/aggregator.py | 15 +++ .../asyn_clients_avg_aggregator.py | 23 ++-- .../aggregators/clients_avg_aggregator.py | 33 ++++-- .../core/aggregators/fedopt_aggregator.py | 10 +- .../server_clients_interpolate_aggregator.py | 9 +- federatedscope/core/configs/config.py | 107 ++++++++++-------- federatedscope/core/configs/yacs_config.py | 48 ++++---- federatedscope/core/data/base_data.py | 2 +- .../core/splitters/base_splitter.py | 13 ++- .../core/splitters/generic/iid_splitter.py | 6 + .../core/splitters/generic/lda_splitter.py | 9 ++ .../core/splitters/graph/louvain_splitter.py | 7 +- .../splitters/graph/randchunk_splitter.py | 15 +-- .../core/splitters/graph/random_splitter.py | 15 ++- .../core/splitters/graph/reltype_splitter.py | 9 +- .../splitters/graph/scaffold_lda_splitter.py | 16 +-- .../core/splitters/graph/scaffold_splitter.py | 16 ++- 18 files changed, 225 insertions(+), 141 deletions(-) diff --git a/doc/source/core.rst b/doc/source/core.rst index b1d3bbe42..92e01dd9f 100644 --- a/doc/source/core.rst +++ b/doc/source/core.rst @@ -31,10 +31,15 @@ federatedscope.core.data federatedscope.core.splitters ----------------------- - .. automodule:: federatedscope.core.splitters :members: :private-members: +.. automodule:: federatedscope.core.splitters.generic + :members: + :private-members: +.. automodule:: federatedscope.core.splitters.graph + :members: + :private-members: federatedscope.core.configs ----------------------- @@ -51,6 +56,12 @@ federatedscope.core.monitors :members: :private-members: +federatedscope.core.aggregators +----------------------- + +.. automodule:: federatedscope.core.aggregators + :members: + federatedscope.core.auxiliaries ----------------------- diff --git a/federatedscope/core/aggregators/aggregator.py b/federatedscope/core/aggregators/aggregator.py index 4f966b74b..367bcedeb 100644 --- a/federatedscope/core/aggregators/aggregator.py +++ b/federatedscope/core/aggregators/aggregator.py @@ -2,11 +2,20 @@ class Aggregator(ABC): + """ + Abstract class of Aggregator. + """ def __init__(self): pass @abstractmethod def aggregate(self, agg_info): + """ + Aggregation function. + + Args: + agg_info: information to be aggregated. + """ pass @@ -14,5 +23,11 @@ class NoCommunicationAggregator(Aggregator): """Clients do not communicate. Each client work locally """ def aggregate(self, agg_info): + """ + Aggregation function. + + Args: + agg_info: information to be aggregated. 
+ """ # do nothing return {} diff --git a/federatedscope/core/aggregators/asyn_clients_avg_aggregator.py b/federatedscope/core/aggregators/asyn_clients_avg_aggregator.py index 39d33a737..7818e4fc0 100644 --- a/federatedscope/core/aggregators/asyn_clients_avg_aggregator.py +++ b/federatedscope/core/aggregators/asyn_clients_avg_aggregator.py @@ -4,8 +4,9 @@ class AsynClientsAvgAggregator(ClientsAvgAggregator): - """The aggregator used in asynchronous training, which discounts the - staled model updates + """ + The aggregator used in asynchronous training, which discounts the \ + staled model updates """ def __init__(self, model=None, device='cpu', config=None): super(AsynClientsAvgAggregator, self).__init__(model, device, config) @@ -15,9 +16,10 @@ def aggregate(self, agg_info): To preform aggregation Arguments: - agg_info (dict): the feedbacks from clients - :returns: the aggregated results - :rtype: dict + agg_info (dict): the feedbacks from clients + + Returns: + dict: the aggregated results """ models = agg_info["client_feedback"] @@ -39,16 +41,19 @@ def aggregate(self, agg_info): def discount_func(self, staleness): """ - Served as an example, we discount the model update with staleness \tau - as: (1.0/((1.0+\tau)**factor)), - which has been used in previous studies such as FedAsync (Asynchronous - Federated Optimization) and FedBuff + Served as an example, we discount the model update with staleness tau \ + as: ``(1.0/((1.0+\tau)**factor))``, \ + which has been used in previous studies such as FedAsync ( \ + Asynchronous Federated Optimization) and FedBuff \ (Federated Learning with Buffered Asynchronous Aggregation). """ return (1.0 / ((1.0 + staleness)**self.cfg.asyn.staleness_discount_factor)) def _para_weighted_avg(self, models, recover_fun=None, staleness=None): + """ + Calculates the weighted average of models. + """ training_set_size = 0 for i in range(len(models)): sample_size, _ = models[i] diff --git a/federatedscope/core/aggregators/clients_avg_aggregator.py b/federatedscope/core/aggregators/clients_avg_aggregator.py index 21ac60c1e..53aa953fa 100644 --- a/federatedscope/core/aggregators/clients_avg_aggregator.py +++ b/federatedscope/core/aggregators/clients_avg_aggregator.py @@ -5,9 +5,10 @@ class ClientsAvgAggregator(Aggregator): - """Implementation of vanilla FedAvg refer to `Communication-efficient - learning of deep networks from decentralized data` [McMahan et al., 2017] - (http://proceedings.mlr.press/v54/mcmahan17a.html) + """ + Implementation of vanilla FedAvg refer to 'Communication-efficient \ + learning of deep networks from decentralized data' [McMahan et al., 2017] \ + http://proceedings.mlr.press/v54/mcmahan17a.html """ def __init__(self, model=None, device='cpu', config=None): super(Aggregator, self).__init__() @@ -20,9 +21,10 @@ def aggregate(self, agg_info): To preform aggregation Arguments: - agg_info (dict): the feedbacks from clients - :returns: the aggregated results - :rtype: dict + agg_info (dict): the feedbacks from clients + + Returns: + dict: the aggregated results """ models = agg_info["client_feedback"] @@ -33,10 +35,10 @@ def aggregate(self, agg_info): return avg_model def update(self, model_parameters): - ''' + """ Arguments: model_parameters (dict): PyTorch Module object's state_dict. 
- ''' + """ self.model.load_state_dict(model_parameters, strict=False) def save_model(self, path, cur_round=-1): @@ -56,6 +58,9 @@ def load_model(self, path): raise ValueError("The file {} does NOT exist".format(path)) def _para_weighted_avg(self, models, recover_fun=None): + """ + Calculates the weighted average of models. + """ training_set_size = 0 for i in range(len(models)): sample_size, _ = models[i] @@ -93,6 +98,9 @@ def _para_weighted_avg(self, models, recover_fun=None): class OnlineClientsAvgAggregator(ClientsAvgAggregator): + """ + Implementation of online aggregation of FedAvg. + """ def __init__(self, model=None, device='cpu', @@ -102,6 +110,9 @@ def __init__(self, self.src_device = src_device def reset(self): + """ + Reset the state of the model to its initial state + """ self.maintained = self.model.state_dict() for key in self.maintained: self.maintained[key].data = torch.zeros_like( @@ -109,6 +120,9 @@ def reset(self): self.cnt = 0 def inc(self, content): + """ + Increment the model weight by the given content. + """ if isinstance(content, tuple): sample_size, model_params = content for key in self.maintained: @@ -123,4 +137,7 @@ def inc(self, content): "{} is not a tuple (sample_size, model_para)".format(content)) def aggregate(self, agg_info): + """ + Returns the aggregated value + """ return self.maintained diff --git a/federatedscope/core/aggregators/fedopt_aggregator.py b/federatedscope/core/aggregators/fedopt_aggregator.py index 47e1725a3..e6e63cab5 100644 --- a/federatedscope/core/aggregators/fedopt_aggregator.py +++ b/federatedscope/core/aggregators/fedopt_aggregator.py @@ -5,10 +5,9 @@ class FedOptAggregator(ClientsAvgAggregator): - """Implementation of FedOpt refer to `Adaptive Federated Optimization` [ - Reddi et al., 2021] - (https://openreview.net/forum?id=LkFG3lB13U5) - + """ + Implementation of FedOpt refer to `Adaptive Federated Optimization` \ + [Reddi et al., 2021](https://openreview.net/forum?id=LkFG3lB13U5) """ def __init__(self, config, model, device='cpu'): super(FedOptAggregator, self).__init__(model, device, config) @@ -16,6 +15,9 @@ def __init__(self, config, model, device='cpu'): **config.fedopt.optimizer) def aggregate(self, agg_info): + """ + To preform FedOpt aggregation. 
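To make the aggregator interface documented above concrete, here is a minimal sketch of a custom aggregator; ``UniformAvgAggregator`` is hypothetical (and ``ClientsAvgAggregator`` is assumed to be exported from ``federatedscope.core.aggregators``), while the ``(sample_size, model_parameters)`` layout of ``agg_info['client_feedback']`` follows the built-in aggregators in this patch::

    from federatedscope.core.aggregators import ClientsAvgAggregator


    class UniformAvgAggregator(ClientsAvgAggregator):
        # Hypothetical: average client models with equal weights,
        # ignoring the per-client sample sizes
        def aggregate(self, agg_info):
            models = agg_info["client_feedback"]
            _, first_params = models[0]
            avg_model = {}
            for key in first_params:
                avg_model[key] = sum(params[key]
                                     for _, params in models) / len(models)
            return avg_model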
+ """ new_model = super().aggregate(agg_info) model = self.model.cpu().state_dict() diff --git a/federatedscope/core/aggregators/server_clients_interpolate_aggregator.py b/federatedscope/core/aggregators/server_clients_interpolate_aggregator.py index 200de2543..327407c3b 100644 --- a/federatedscope/core/aggregators/server_clients_interpolate_aggregator.py +++ b/federatedscope/core/aggregators/server_clients_interpolate_aggregator.py @@ -2,9 +2,9 @@ class ServerClientsInterpolateAggregator(ClientsAvgAggregator): - """" - # conduct aggregation by interpolating global model from server and - local models from clients + """ + conduct aggregation by interpolating global model from server and \ + local models from clients """ def __init__(self, model=None, device='cpu', config=None, beta=1.0): super(ServerClientsInterpolateAggregator, @@ -12,6 +12,9 @@ def __init__(self, model=None, device='cpu', config=None, beta=1.0): self.beta = beta # the weight for local models used in interpolation def aggregate(self, agg_info): + """ + Returns the aggregated value + """ models = agg_info["client_feedback"] global_model = self.model elem_each_client = next(iter(models)) diff --git a/federatedscope/core/configs/config.py b/federatedscope/core/configs/config.py index 48b71b421..41b3b8195 100644 --- a/federatedscope/core/configs/config.py +++ b/federatedscope/core/configs/config.py @@ -23,11 +23,10 @@ def set_help_info(cn_node, help_info_dict, prefix=""): class CN(CfgNode): """ - An extended configuration system based on [yacs]( - https://github.com/rbgirshick/yacs). - The two-level tree structure consists of several internal dict-like - containers to allow simple key-value access and management. - + An extended configuration system based on [yacs]( \ + https://github.com/rbgirshick/yacs). \ + The two-level tree structure consists of several internal dict-like \ + containers to allow simple key-value access and management. """ def __init__(self, init_dict=None, key_list=None, new_allowed=False): init_dict = super().__init__(init_dict, key_list, new_allowed) @@ -59,6 +58,9 @@ def __delattr__(self, name): raise AttributeError(name) def clear_aux_info(self): + """ + Clears all the auxiliary information of the CN object. + """ if hasattr(self, "__cfg_check_funcs__"): delattr(self, "__cfg_check_funcs__") if hasattr(self, "__help_info__"): @@ -71,10 +73,11 @@ def clear_aux_info(self): def print_help(self, arg_name=""): """ - print help info for a specific given `arg_name` or - for all arguments if not given `arg_name` - :param arg_name: - :return: + print help info for a specific given ``arg_name`` or \ + for all arguments if not given ``arg_name`` + + Args: + arg_name: name of specific args """ if arg_name != "" and arg_name in self.__help_info__: print(f" --{arg_name} \t {self.__help_info__[arg_name]}") @@ -83,15 +86,22 @@ def print_help(self, arg_name=""): print(f" --{k} \t {v}") def register_cfg_check_fun(self, cfg_check_fun): + """ + Register a function that checks the configuration node. + + Args: + cfg_check_fun: function for validation the correctness of cfg. + """ self.__cfg_check_funcs__.append(cfg_check_fun) def merge_from_file(self, cfg_filename, check_cfg=True): """ - load configs from a yaml file, another cfg instance or a list - stores the keys and values. + load configs from a yaml file, another cfg instance or a list \ + stores the keys and values. 
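A short usage sketch of the configuration interface documented here; the YAML file name is hypothetical, the concrete keys are only examples, and ``global_cfg`` is assumed to be the module-level default instance::

    from federatedscope.core.configs.config import global_cfg

    cfg = global_cfg.clone()              # start from the default options
    cfg.merge_from_file('my_exp.yaml')    # hypothetical YAML path
    cfg.merge_from_list(['data.type', 'toy', 'eval.freq', 5])
    cfg.freeze()                          # immutable, dumped to outdir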
- :param cfg_filename (string): - :return: + Args: + cfg_filename: file name of yaml file + check_cfg: whether enable ``assert_cfg()`` """ cfg_check_funcs = copy.copy(self.__cfg_check_funcs__) with open(cfg_filename, "r") as f: @@ -104,12 +114,12 @@ def merge_from_file(self, cfg_filename, check_cfg=True): def merge_from_other_cfg(self, cfg_other, check_cfg=True): """ - load configs from another cfg instance + load configs from another cfg instance - :param cfg_other (CN): - :return: + Args: + cfg_other: other cfg to be merged + check_cfg: whether enable ``assert_cfg()`` """ - cfg_check_funcs = copy.copy(self.__cfg_check_funcs__) _merge_a_into_b(cfg_other, self, self, []) self.__cfg_check_funcs__.clear() @@ -119,12 +129,13 @@ def merge_from_other_cfg(self, cfg_other, check_cfg=True): def merge_from_list(self, cfg_list, check_cfg=True): """ - load configs from a list stores the keys and values. - modified `merge_from_list` in `yacs.config.py` to allow adding - new keys if `is_new_allowed()` returns True + load configs from a list stores the keys and values. \ + modified ``merge_from_list`` in ``yacs.config.py`` to allow adding \ + new keys if ``is_new_allowed()`` returns True \ - :param cfg_list (list): - :return: + Args: + cfg_list: list of pairs of cfg name and value + check_cfg: whether enable ``assert_cfg()`` """ cfg_check_funcs = copy.copy(self.__cfg_check_funcs__) super().merge_from_list(cfg_list) @@ -135,9 +146,10 @@ def merge_from_list(self, cfg_list, check_cfg=True): def assert_cfg(self, check_cfg=True): """ - check the validness of the configuration instance + check the validness of the configuration instance - :return: + Args: + check_cfg: whether enable checks """ if check_cfg: for check_func in self.__cfg_check_funcs__: @@ -145,10 +157,8 @@ def assert_cfg(self, check_cfg=True): def clean_unused_sub_cfgs(self): """ - Clean the un-used secondary-level CfgNode, whose `.use` - attribute is `True` - - :return: + Clean the un-used secondary-level CfgNode, whose ``.use`` \ + attribute is ``True`` """ for v in self.values(): if isinstance(v, CfgNode) or isinstance(v, CN): @@ -162,6 +172,9 @@ def clean_unused_sub_cfgs(self): del v[k] def check_required_args(self): + """ + Check required arguments. + """ for k, v in self.items(): if isinstance(v, CN): v.check_required_args() @@ -170,11 +183,10 @@ def check_required_args(self): def de_arguments(self): """ - some config values are managed via `Argument` class, this function - is used to make these values clean without the `Argument` class, - such that the potential type-specific methods work correctly, - e.g., len(cfg.federate.method) for a string config - :return: + some config values are managed via ``Argument`` class, this function \ + is used to make these values clean without the ``Argument`` class, \ + such that the potential type-specific methods work correctly, \ + e.g., ``len(cfg.federate.method)`` for a string config """ for k, v in copy.deepcopy(self).items(): if isinstance(v, CN): @@ -183,6 +195,12 @@ def de_arguments(self): self[k] = v.value def ready_for_run(self, check_cfg=True): + """ + Check and cleans up the internal state of cfg and save cfg. 
+ + Args: + check_cfg: whether enable ``assert_cfg()`` + """ self.assert_cfg(check_cfg) self.clean_unused_sub_cfgs() self.check_required_args() @@ -191,12 +209,10 @@ def ready_for_run(self, check_cfg=True): def freeze(self, inform=True, save=True, check_cfg=True): """ - 1) make the cfg attributes immutable; - 2) if save=True, save the frozen cfg_check_funcs into - "self.outdir/config.yaml" for better reproducibility; - 3) if self.wandb.use=True, update the frozen config - - :return: + (1) make the cfg attributes immutable; + (2) if ``save==True``, save the frozen cfg_check_funcs into \ + ``self.outdir/config.yaml`` for better reproducibility; + (3) if ``self.wandb.use==True``, update the frozen config """ self.ready_for_run(check_cfg) super(CN, self).freeze() @@ -243,15 +259,14 @@ def freeze(self, inform=True, save=True, check_cfg=True): def init_global_cfg(cfg): - r''' + """ This function sets the default config value. - 1) Note that for an experiment, only part of the arguments will be used - The remaining unused arguments won't affect anything. - So feel free to register any argument in graphgym.contrib.config - 2) We support more than one levels of configs, e.g., cfg.dataset.name - :return: configuration use by the experiment. - ''' + (1) Note that for an experiment, only part of the arguments will be used \ + The remaining unused arguments won't affect anything. \ + So feel free to register any argument in graphgym.contrib.config + (2) We support more than one levels of configs, e.g., cfg.dataset.name + """ # ---------------------------------------------------------------------- # # Basic options, first level configs diff --git a/federatedscope/core/configs/yacs_config.py b/federatedscope/core/configs/yacs_config.py index cd1c6ae00..8676986b7 100644 --- a/federatedscope/core/configs/yacs_config.py +++ b/federatedscope/core/configs/yacs_config.py @@ -98,7 +98,7 @@ def __repr__(self): class CfgNode(dict): """ - CfgNode represents an internal node in the configuration tree. It's a + CfgNode represents an internal node in the configuration tree. It's a \ simple dict-like container that allows for attribute-based access to keys. """ @@ -157,8 +157,8 @@ def _create_config_tree_from_dict(cls, dic, key_list): Any dict-like objects inside dict will be treated as a new CfgNode. Args: - dic (dict): - key_list (list[str]): a list of names which index this CfgNode + dic (dict): ``dict`` to be converted + key_list (list[str]): a list of names which index this CfgNode \ from the root. Currently only used for logging purposes. """ dic = copy.deepcopy(dic) @@ -256,8 +256,9 @@ def merge_from_other_cfg(self, cfg_other): _merge_a_into_b(cfg_other, self, self, []) def merge_from_list(self, cfg_list): - """Merge config (keys, values) in a list (e.g., from command line) into - this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`. + """ + Merge config (keys, values) in a list (e.g., from command line) \ + into this CfgNode. For example, ``cfg_list = ['FOO.BAR', 0.5]``. """ _assert_with_logging( len(cfg_list) % 2 == 0, @@ -297,7 +298,8 @@ def is_frozen(self): return self.__dict__[CfgNode.IMMUTABLE] def _immutable(self, is_immutable): - """Set immutability to is_immutable and recursively apply the setting + """ + Set immutability to is_immutable and recursively apply the setting \ to all nested CfgNodes. """ self.__dict__[CfgNode.IMMUTABLE] = is_immutable @@ -314,8 +316,9 @@ def clone(self): return copy.deepcopy(self) def register_deprecated_key(self, key): - """Register key (e.g. 
`FOO.BAR`) a deprecated option. - When merging deprecated keys a warning is generated and the key is + """ + Register key (e.g. `FOO.BAR`) a deprecated option. \ + When merging deprecated keys a warning is generated and the key is \ ignored. """ _assert_with_logging( @@ -325,9 +328,10 @@ def register_deprecated_key(self, key): self.__dict__[CfgNode.DEPRECATED_KEYS].add(key) def register_renamed_key(self, old_name, new_name, message=None): - """Register a key as having been renamed from `old_name` to `new_name`. - When merging a renamed key, an exception is thrown alerting to user to - the fact that the key has been renamed. + """ + Register a key as having been renamed from ``old_name`` \ + to `new_name`. When merging a renamed key, an exception is thrown \ + alerting to user to the fact that the key has been renamed. """ _assert_with_logging( old_name not in self.__dict__[CfgNode.RENAMED_KEYS], @@ -367,7 +371,7 @@ def is_new_allowed(self): def set_new_allowed(self, is_new_allowed): """ - Set this config (and recursively its subconfigs) to allow merging + Set this config (and recursively its subconfigs) to allow merging \ new keys from other configs. """ self.__dict__[CfgNode.NEW_ALLOWED] = is_new_allowed @@ -443,12 +447,12 @@ def _load_cfg_py_source(cls, filename): @classmethod def _decode_cfg_value(cls, value): """ - Decodes a raw config value (e.g., from a yaml config files or command + Decodes a raw config value (e.g., from a yaml config files or command \ line argument) into a Python object. - If the value is a dict, it will be interpreted as a new CfgNode. - If the value is a str, it will be evaluated as literals. - Otherwise it is returned as-is. + (1) If the value is a dict, it will be interpreted as a new CfgNode. + (2) If the value is a str, it will be evaluated as literals. + (3) Otherwise it is returned as-is. """ # Configs parsed from raw yaml will contain dictionary keys that need # to be converted to CfgNode objects @@ -491,9 +495,9 @@ def _valid_type(value, allow_cfg_node=False): def _merge_a_into_b(a, b, root, key_list): """ - [Modified from yacs, to allow int <-> float conversation] + [Modified from yacs, to allow int <-> float conversation] - Merge config dictionary a into config dictionary b, clobbering the + Merge config dictionary a into config dictionary b, clobbering the \ options in b whenever they are also specified in a. """ _assert_with_logging( @@ -537,11 +541,11 @@ def _merge_a_into_b(a, b, root, key_list): def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): """ - [Modified from yacs, to allow int <-> float conversation] + [Modified from yacs, to allow int <-> float conversation] - Checks that `replacement`, which is intended to replace `original` is of - the right type. The type is correct if it matches exactly or is one of a - few cases in which the type can be easily coerced. + Checks that ``replacement``, which is intended to replace \ + ``original`` is of the right type. The type is correct if it matches \ + exactly or is one of a few cases in which the type can be easily coerced. """ original_type = type(original) replacement_type = type(replacement) diff --git a/federatedscope/core/data/base_data.py b/federatedscope/core/data/base_data.py index ec5d67e87..b2691b8b4 100644 --- a/federatedscope/core/data/base_data.py +++ b/federatedscope/core/data/base_data.py @@ -7,7 +7,7 @@ class StandaloneDataDict(dict): """ - `StandaloneDataDict` maintain several `ClientData`. + `StandaloneDataDict` maintain several `ClientData`. 
""" def __init__(self, datadict, global_cfg): """ diff --git a/federatedscope/core/splitters/base_splitter.py b/federatedscope/core/splitters/base_splitter.py index 1e80bb986..c048c51c3 100644 --- a/federatedscope/core/splitters/base_splitter.py +++ b/federatedscope/core/splitters/base_splitter.py @@ -3,13 +3,14 @@ class BaseSplitter(abc.ABC): - def __init__(self, client_num): - """ - This is an abstract base class for all splitter. + """ + This is an abstract base class for all splitter, which is not \ + implemented with ``__call__()``. - Args: - client_num: Divide the dataset into `client_num` pieces. - """ + Attributes: + client_num: Divide the dataset into ``client_num`` pieces. + """ + def __init__(self, client_num): self.client_num = client_num @abc.abstractmethod diff --git a/federatedscope/core/splitters/generic/iid_splitter.py b/federatedscope/core/splitters/generic/iid_splitter.py index a8032bdae..c802b41db 100644 --- a/federatedscope/core/splitters/generic/iid_splitter.py +++ b/federatedscope/core/splitters/generic/iid_splitter.py @@ -3,6 +3,12 @@ class IIDSplitter(BaseSplitter): + """ + This splitter split dataset randomly . + + Args: + client_num: the dataset will be split into ``client_num`` pieces + """ def __init__(self, client_num): super(IIDSplitter, self).__init__(client_num) diff --git a/federatedscope/core/splitters/generic/lda_splitter.py b/federatedscope/core/splitters/generic/lda_splitter.py index b32d5e07f..8a6ba102f 100644 --- a/federatedscope/core/splitters/generic/lda_splitter.py +++ b/federatedscope/core/splitters/generic/lda_splitter.py @@ -5,6 +5,15 @@ class LDASplitter(BaseSplitter): + """ + This splitter split dataset with LDA. + + Args: + client_num: the dataset will be split into ``client_num`` pieces + alpha (float): Partition hyperparameter in LDA, smaller alpha \ + generates more extreme heterogeneous scenario see \ + ``np.random.dirichlet`` + """ def __init__(self, client_num, alpha=0.5): self.alpha = alpha super(LDASplitter, self).__init__(client_num) diff --git a/federatedscope/core/splitters/graph/louvain_splitter.py b/federatedscope/core/splitters/graph/louvain_splitter.py index 908ae9a43..3a170a1ac 100644 --- a/federatedscope/core/splitters/graph/louvain_splitter.py +++ b/federatedscope/core/splitters/graph/louvain_splitter.py @@ -10,13 +10,12 @@ class LouvainSplitter(BaseTransform, BaseSplitter): - r""" + """ Split Data into small data via louvain algorithm. Args: - client_num (int): Split data into client_num of pieces. - delta (int): The gap between the number of nodes on the each client. - + client_num (int): Split data into ``client_num`` of pieces. + delta (int): The gap between the number of nodes on each client. """ def __init__(self, client_num, delta=20): self.delta = delta diff --git a/federatedscope/core/splitters/graph/randchunk_splitter.py b/federatedscope/core/splitters/graph/randchunk_splitter.py index 07e2e93cb..97a58ddba 100644 --- a/federatedscope/core/splitters/graph/randchunk_splitter.py +++ b/federatedscope/core/splitters/graph/randchunk_splitter.py @@ -5,19 +5,16 @@ class RandChunkSplitter(BaseTransform, BaseSplitter): + """ + Split graph-level dataset via random chunk strategy. + + Arguments: + dataset (List or PyG.dataset): The graph-level datasets. + """ def __init__(self, client_num): BaseSplitter.__init__(self, client_num) def __call__(self, dataset, **kwargs): - r"""Split dataset via random chunk. - - Arguments: - dataset (List or PyG.dataset): The datasets. 
- - Returns: - data_list (List(List(PyG.data))): Splited dataset via random - chunk split. - """ data_list = [] dataset = [ds for ds in dataset] num_graph = len(dataset) diff --git a/federatedscope/core/splitters/graph/random_splitter.py b/federatedscope/core/splitters/graph/random_splitter.py index a3c12be1e..bf21f3559 100644 --- a/federatedscope/core/splitters/graph/random_splitter.py +++ b/federatedscope/core/splitters/graph/random_splitter.py @@ -12,18 +12,17 @@ class RandomSplitter(BaseTransform, BaseSplitter): - r""" + """ Split Data into small data via random sampling. Args: client_num (int): Split data into client_num of pieces. - sampling_rate (str): Samples of the unique nodes for each client, - eg. '0.2,0.2,0.2'. - overlapping_rate(float): Additional samples of overlapping data, - eg. '0.4' - drop_edge(float): Drop edges (drop_edge / client_num) for each - client whthin overlapping part. - + sampling_rate (str): Samples of the unique nodes for each client, \ + eg. ``'0.2,0.2,0.2'`` + overlapping_rate(float): Additional samples of overlapping data, \ + eg. ``'0.4'`` + drop_edge(float): Drop edges (drop_edge / client_num) for each \ + client within overlapping part. """ def __init__(self, client_num, diff --git a/federatedscope/core/splitters/graph/reltype_splitter.py b/federatedscope/core/splitters/graph/reltype_splitter.py index 2452addbd..142b1c239 100644 --- a/federatedscope/core/splitters/graph/reltype_splitter.py +++ b/federatedscope/core/splitters/graph/reltype_splitter.py @@ -10,14 +10,15 @@ class RelTypeSplitter(BaseTransform, BaseSplitter): - r""" - Split Data into small data via dirichlet distribution to + """ + Split Data into small data via dirichlet distribution to \ generate non-i.i.d data split. Arguments: client_num (int): Split data into client_num of pieces. - alpha (float): parameter controlling the identicalness among clients. - + alpha (float): Partition hyperparameter in LDA, smaller alpha \ + generates more extreme heterogeneous scenario see \ + ``np.random.dirichlet`` """ def __init__(self, client_num, alpha=0.5, realloc_mask=False): BaseSplitter.__init__(self, client_num) diff --git a/federatedscope/core/splitters/graph/scaffold_lda_splitter.py b/federatedscope/core/splitters/graph/scaffold_lda_splitter.py index 87b119a3b..f45d2fa59 100644 --- a/federatedscope/core/splitters/graph/scaffold_lda_splitter.py +++ b/federatedscope/core/splitters/graph/scaffold_lda_splitter.py @@ -16,15 +16,15 @@ class GenFeatures: - r"""Implementation of 'CanonicalAtomFeaturizer' and - 'CanonicalBondFeaturizer' in DGL. + r"""Implementation of ``CanonicalAtomFeaturizer`` and + ``CanonicalBondFeaturizer`` in DGL. \ Source: https://lifesci.dgl.ai/_modules/dgllife/utils/featurizers.html Arguments: data: PyG.data in PyG.dataset. Returns: - data: PyG.data, data passing featurizer. + PyG.data: data passing featurizer. """ def __init__(self): @@ -151,16 +151,18 @@ def gen_scaffold_lda_split(dataset, client_num=5, alpha=0.1): class ScaffoldLdaSplitter(BaseSplitter): - r"""First adopt scaffold splitting and then assign the samples to + """ + First adopt scaffold splitting and then assign the samples to \ clients according to Latent Dirichlet Allocation. Arguments: dataset (List or PyG.dataset): The molecular datasets. - alpha (float): Partition hyperparameter in LDA, smaller alpha - generates more extreme heterogeneous scenario. 
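As a small sketch of how a splitter is applied (the toy dataset below is made up, and ``LDASplitter`` is assumed to be exported from ``federatedscope.core.splitters.generic`` as registered in ``doc/source/core.rst`` above)::

    from federatedscope.core.splitters.generic import LDASplitter

    # toy centralized dataset: (feature, label) pairs with 10 classes
    dataset = [(i, i % 10) for i in range(1000)]

    splitter = LDASplitter(client_num=5, alpha=0.5)
    per_client = splitter(dataset)   # one list of samples per client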
+ alpha (float): Partition hyperparameter in LDA, smaller alpha \ + generates more extreme heterogeneous scenario see \ + ``np.random.dirichlet`` Returns: - data_list (List(List(PyG.data))): Splited dataset via scaffold split. + List(List(PyG.data)): data_list of split dataset via scaffold split. """ def __init__(self, client_num, alpha): diff --git a/federatedscope/core/splitters/graph/scaffold_splitter.py b/federatedscope/core/splitters/graph/scaffold_splitter.py index db41779f6..169cdb585 100644 --- a/federatedscope/core/splitters/graph/scaffold_splitter.py +++ b/federatedscope/core/splitters/graph/scaffold_splitter.py @@ -50,19 +50,17 @@ def gen_scaffold_split(dataset, client_num=5): class ScaffoldSplitter(BaseSplitter): + """ + Split molecular via scaffold. This splitter will sort all moleculars, and \ + split them into several parts. + + Arguments: + client_num (int): Split data into client_num of pieces. + """ def __init__(self, client_num): super(ScaffoldSplitter, self).__init__(client_num) def __call__(self, dataset, **kwargs): - r"""Split dataset with smiles string into scaffold split - - Arguments: - dataset (List or PyG.dataset): The molecular datasets. - - Returns: - data_list (List(List(PyG.data))): Splited dataset via scaffold - split. - """ dataset = [ds for ds in dataset] idx_slice = gen_scaffold_split(dataset) data_list = [[dataset[idx] for idx in idxs] for idxs in idx_slice] From 4e760096cb60c5af16f70b2e6c64b315de62ab14 Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Tue, 1 Nov 2022 19:22:40 +0800 Subject: [PATCH 13/21] add docstring for data --- doc/source/core.rst | 10 +- federatedscope/core/data/base_data.py | 54 ++++++---- federatedscope/core/data/base_translator.py | 23 ++-- federatedscope/core/data/dummy_translator.py | 12 ++- federatedscope/core/data/utils.py | 106 +++++++++++++------ 5 files changed, 135 insertions(+), 70 deletions(-) diff --git a/doc/source/core.rst b/doc/source/core.rst index 92e01dd9f..4c5bb6598 100644 --- a/doc/source/core.rst +++ b/doc/source/core.rst @@ -25,9 +25,17 @@ federatedscope.core.trainers federatedscope.core.data ----------------------- -.. automodule:: federatedscope.core.data +.. automodule:: federatedscope.core.data.base_data :members: :private-members: +.. automodule:: federatedscope.core.data.base_translator + :members: + :private-members: +.. automodule:: federatedscope.core.data.dummy_translator + :members: + :private-members: +.. automodule:: federatedscope.core.data.utils + :members: federatedscope.core.splitters ----------------------- diff --git a/federatedscope/core/data/base_data.py b/federatedscope/core/data/base_data.py index b2691b8b4..499cd344a 100644 --- a/federatedscope/core/data/base_data.py +++ b/federatedscope/core/data/base_data.py @@ -7,7 +7,14 @@ class StandaloneDataDict(dict): """ - `StandaloneDataDict` maintain several `ClientData`. + ``StandaloneDataDict`` maintain several ``ClientData``, only used in \ + ``Standalone`` mode to be passed to ``Runner``, which will conduct \ + several preprocess based on ``global_cfg``, see ``preprocess()`` \ + for details. + + Args: + datadict: ``Dict`` with ``client_id`` as key, ``ClientData`` as value. + global_cfg: global ``CfgNode`` """ def __init__(self, datadict, global_cfg): """ @@ -23,11 +30,12 @@ def __init__(self, datadict, global_cfg): def resetup(self, global_cfg, client_cfgs=None): """ - Resetup new configs for `ClientData`, which might be used in HPO. 
+ Reset-up new configs for ``ClientData``, when the configs change \ + which might be used in HPO. Args: - global_cfg: enable new config for `ClientData` - client_cfgs: enable new client-specific config for `ClientData` + global_cfg: enable new config for ``ClientData`` + client_cfgs: enable new client-specific config for ``ClientData`` """ self.global_cfg, self.client_cfgs = global_cfg, client_cfgs for client_id, client_data in self.items(): @@ -46,9 +54,11 @@ def resetup(self, global_cfg, client_cfgs=None): def preprocess(self, datadict): """ - Preprocess for StandaloneDataDict for: - 1. Global evaluation (merge test data). - 2. Global mode (train with centralized setting, merge all data). + Preprocess for: + + (1) Global evaluation (merge test data). + (2) Global mode (train with centralized setting, merge all data). + (3) Apply data attack algorithms. Args: datadict: dict with `client_id` as key, `ClientData` as value. @@ -82,8 +92,7 @@ def preprocess(self, datadict): def attack(self, datadict): """ - Apply attack to `StandaloneDataDict`. - + Apply attack to ``StandaloneDataDict``. """ if 'backdoor' in self.global_cfg.attack.attack_method and 'edge' in \ self.global_cfg.attack.trigger_type: @@ -125,20 +134,20 @@ def attack(self, datadict): class ClientData(dict): """ - `ClientData` converts dataset to train/val/test DataLoader. - Key `data` in `ClientData` is the raw dataset. + ``ClientData`` converts split data to ``DataLoader``. + + Args: + loader: ``Dataloader`` class or data dict which have been built + client_cfg: client-specific ``CfgNode`` + data: raw dataset, which will stay raw + train: train dataset, which will be converted to ``Dataloader`` + val: valid dataset, which will be converted to ``Dataloader`` + test: test dataset, which will be converted to ``Dataloader`` + + Note: + Key ``data`` in ``ClientData`` is the raw dataset. """ def __init__(self, client_cfg, train=None, val=None, test=None, **kwargs): - """ - - Args: - loader: Dataloader class or data dict which have been built - client_cfg: client-specific CfgNode - data: raw dataset, which will stay raw - train: train dataset, which will be converted to DataLoader - val: valid dataset, which will be converted to DataLoader - test: test dataset, which will be converted to DataLoader - """ self.client_cfg = None self.train = train self.val = val @@ -151,12 +160,13 @@ def __init__(self, client_cfg, train=None, val=None, test=None, **kwargs): def setup(self, new_client_cfg=None): """ + Set up ``DataLoader`` in ``ClientData`` with new configurations. Args: new_client_cfg: new client-specific CfgNode Returns: - Status: indicate whether the client_cfg is updated + Bool: Status for indicating whether the client_cfg is updated """ # if `batch_size` or `shuffle` change, reinstantiate DataLoader if self.client_cfg is not None: diff --git a/federatedscope/core/data/base_translator.py b/federatedscope/core/data/base_translator.py index 8dd71f0c9..9c39b1fab 100644 --- a/federatedscope/core/data/base_translator.py +++ b/federatedscope/core/data/base_translator.py @@ -9,8 +9,14 @@ class BaseDataTranslator: """ - Perform process: - Dataset -> ML split -> FL split -> Data (passed to Runner) + Translator is a tool to convert a centralized dataset to \ + ``StandaloneDataDict``, which is the input data of runner. 
+ + Notes: + Translator is consist of several stages: + + Dataset -> ML split (``split_train_val_test()``) -> \ + FL split (``split_to_client()``) -> ``StandaloneDataDict`` """ def __init__(self, global_cfg, client_cfgs=None): @@ -27,7 +33,6 @@ def __init__(self, global_cfg, client_cfgs=None): def __call__(self, dataset): """ - Args: dataset: `torch.utils.data.Dataset`, `List` of (feature, label) or split dataset tuple of (train, val, test) or Tuple of @@ -47,8 +52,8 @@ def split(self, dataset): Perform ML split and FL split. Returns: - dict of `ClientData` with client_idx as key. - + dict of ``ClientData`` with client_idx as key to build \ + ``StandaloneDataDict`` """ train, val, test = self.split_train_val_test(dataset) datadict = self.split_to_client(train, val, test) @@ -59,8 +64,7 @@ def split_train_val_test(self, dataset): Split dataset to train, val, test if not provided. Returns: - split_data (List): List of split dataset, [train, val, test] - + List: List of split dataset, like ``[train, val, test]`` """ splits = self.global_cfg.data.splits if isinstance(dataset, tuple): @@ -83,11 +87,10 @@ def split_train_val_test(self, dataset): def split_to_client(self, train, val, test): """ - Split dataset to clients and build `ClientData`. + Split dataset to clients and build ``ClientData``. Returns: - data_dict (dict): dict of `ClientData` with client_idx as key. - + dict: dict of ``ClientData`` with ``client_idx`` as key. """ # Initialization diff --git a/federatedscope/core/data/dummy_translator.py b/federatedscope/core/data/dummy_translator.py index 640a80ec3..464733c60 100644 --- a/federatedscope/core/data/dummy_translator.py +++ b/federatedscope/core/data/dummy_translator.py @@ -4,10 +4,18 @@ class DummyDataTranslator(BaseDataTranslator): """ - DummyDataTranslator convert FL dataset to DataLoader. - Do not perform FL split. + ``DummyDataTranslator`` convert datadict to ``StandaloneDataDict``. \ + Compared to ``core.data.base_translator.BaseDataTranslator``, it do not \ + perform FL split. """ def split(self, dataset): + """ + Perform ML split + + Returns: + dict of ``ClientData`` with client_idx as key to build \ + ``StandaloneDataDict`` + """ if not isinstance(dataset, dict): raise TypeError(f'Not support data type {type(dataset)}') datadict = {} diff --git a/federatedscope/core/data/utils.py b/federatedscope/core/data/utils.py index 763279db4..b5e32b6ca 100644 --- a/federatedscope/core/data/utils.py +++ b/federatedscope/core/data/utils.py @@ -34,6 +34,16 @@ def __repr__(self): def load_dataset(config): + """ + Loads the dataset for the given config from branches + + Args: + config: configurations for FL, see ``federatedscope.core.configs`` + + Notes: + See https://federatedscope.io/docs/datazoo/ for all available data. + """ + if config.data.type.lower() == 'toy': from federatedscope.tabular.dataloader.toy import load_toy_data dataset, modified_config = load_toy_data(config) @@ -86,26 +96,17 @@ def load_dataset(config): def load_external_data(config=None): - r""" Based on the configuration file, this function imports external - datasets and applies train/valid/test splits and split by some specific - `splitter` into the standard FederatedScope input data format. + """ + Based on the configuration file, this function imports external \ + datasets and applies train/valid/test. Args: config: `CN` from `federatedscope/core/configs/config.py` Returns: - data_local_dict: dict of split dataloader. 
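A rough sketch of the translator call pattern described above; the splitter name, the toy dataset and the reliance on default config values are assumptions rather than guarantees::

    import numpy as np

    from federatedscope.core.configs.config import global_cfg
    from federatedscope.core.data.base_translator import BaseDataTranslator

    cfg = global_cfg.clone()
    cfg.federate.client_num = 5
    cfg.data.splitter = 'iid'            # assumed splitter name
    cfg.data.splits = [0.8, 0.1, 0.1]

    # toy centralized dataset of (feature, label) pairs
    dataset = [(np.random.rand(8).astype('float32'), i % 2)
               for i in range(200)]

    translator = BaseDataTranslator(cfg)
    fl_data = translator(dataset)        # {client_id: ClientData}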
- Format: - { - 'client_id': { - 'train': DataLoader(), - 'test': DataLoader(), - 'val': DataLoader() - } - } - modified_config: `CN` from `federatedscope/core/configs/config.py`, + (data, modified_config): tuple of ML split dataset, \ + and `CN` from `federatedscope/core/configs/config.py`, \ which might be modified in the function. - """ import torch @@ -506,6 +507,17 @@ def load_openml_data(tid, splits=None, config=None): def convert_data_mode(data, config): + """ + Convert ``StandaloneDataDict`` to ``ClientData`` in ``distributed`` mode. + + Args: + data: ``StandaloneDataDict`` + config: configuration of FL course, see `federatedscope.core.configs` + + Returns: + ``StandaloneDataDict`` in ``standalone`` mode, or ``ClientData`` in \ + ``distributed`` mode. + """ if config.federate.mode.lower() == 'standalone': return data else: @@ -524,12 +536,30 @@ def convert_data_mode(data, config): def get_func_args(func): + """ + Get the set of arguments that the function expects. + Args: + func: function + + Returns: + Arguments that the function expects + """ sign = inspect.signature(func).parameters.values() sign = set([val.name for val in sign]) return sign def filter_dict(func, kwarg): + """ + Filters out the common keys of kwarg that are not in kwarg. + + Args: + func: function to be filtered + kwarg: dict to filter + + Returns: + Filtered dict of arguments of the function. + """ sign = get_func_args(func) common_args = sign.intersection(kwarg.keys()) filtered_dict = {key: kwarg[key] for key in common_args} @@ -538,12 +568,16 @@ def filter_dict(func, kwarg): def merge_data(all_data, merged_max_data_id=None, specified_dataset_name=None): """ - Merge data from client 1 to `merged_max_data_id` contained in given - `all_data`. - :param all_data: - :param merged_max_data_id: - :param specified_dataset_name: - :return: + Merge data from client 1 to ``merged_max_data_id`` contained in given \ + ``all_data``. + + Args: + all_data: ``StandaloneDataDict`` + merged_max_data_id: max merged data index + specified_dataset_name: split name to be merged + + Returns: + Merged data. """ import torch.utils.data from federatedscope.core.data.wrap_dataset import WrapDataset @@ -631,18 +665,21 @@ def save_local_data(dir_path, val_data=None, val_targets=None): r""" + Save data to disk. Source: \ https://github.com/omarfoq/FedEM/blob/main/data/femnist/generate_data.py - save (`train_data`, `train_targets`) in {dir_path}/train.pt, - (`val_data`, `val_targets`) in {dir_path}/val.pt - and (`test_data`, `test_targets`) in {dir_path}/test.pt - :param dir_path: - :param train_data: - :param train_targets: - :param test_data: - :param test_targets: - :param val_data: - :param val_targets + Args: + train_data: x of train data + train_targets: y of train data + test_data: x of test data + test_targets: y of test data + val_data: x of validation data + val_targets:y of validation data + + Note: + save ``(`train_data`, `train_targets`)`` in ``{dir_path}/train.pt``, \ + ``(`val_data`, `val_targets`)`` in ``{dir_path}/val.pt`` \ + and ``(`test_data`, `test_targets`)`` in ``{dir_path}/test.pt`` """ import torch if (train_data is not None) and (train_targets is not None): @@ -656,17 +693,16 @@ def save_local_data(dir_path, def download_url(url: str, folder='folder'): - r"""Downloads the content of an url to a folder. - - Modified from `https://github.com/pyg-team/pytorch_geometric/blob/master - /torch_geometric/data/download.py` + """ + Downloads the content of an url to a folder. 
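The ``filter_dict`` helper documented above keeps only the keyword arguments a callable actually accepts; ``make_loader`` and its parameters below are made up for illustration::

    from federatedscope.core.data.utils import filter_dict

    def make_loader(dataset, batch_size, shuffle):
        pass

    kwargs = {'batch_size': 32, 'shuffle': True, 'num_workers': 4}
    filtered = filter_dict(make_loader, kwargs)
    # filtered == {'batch_size': 32, 'shuffle': True}; keys the callable
    # does not accept are dropped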
Modified from \ + https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric Args: url (string): The url of target file. folder (string): The target folder. Returns: - path (string): File path of downloaded files. + string: File path of downloaded files. """ file = url.rpartition('/')[2] From a761e289df60eb63f81332b2e875008a5eecb817 Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Tue, 1 Nov 2022 19:23:59 +0800 Subject: [PATCH 14/21] fix typo --- federatedscope/core/data/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/federatedscope/core/data/utils.py b/federatedscope/core/data/utils.py index b5e32b6ca..3fd13c245 100644 --- a/federatedscope/core/data/utils.py +++ b/federatedscope/core/data/utils.py @@ -538,8 +538,9 @@ def convert_data_mode(data, config): def get_func_args(func): """ Get the set of arguments that the function expects. + Args: - func: function + func: function to be analysis Returns: Arguments that the function expects From 2e7ee2dfd9a62b0f325f49ccd52cd77fe064ada7 Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Tue, 1 Nov 2022 19:24:24 +0800 Subject: [PATCH 15/21] fix typo --- federatedscope/core/data/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federatedscope/core/data/utils.py b/federatedscope/core/data/utils.py index 3fd13c245..1163575ac 100644 --- a/federatedscope/core/data/utils.py +++ b/federatedscope/core/data/utils.py @@ -40,7 +40,7 @@ def load_dataset(config): Args: config: configurations for FL, see ``federatedscope.core.configs`` - Notes: + Note: See https://federatedscope.io/docs/datazoo/ for all available data. """ From 68cd6d0f4404d6cfff745a47decd5dfa57b24ec3 Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Tue, 1 Nov 2022 20:27:30 +0800 Subject: [PATCH 16/21] add doc string for monitor --- federatedscope/core/monitors/early_stopper.py | 55 ++++-- .../core/monitors/metric_calculator.py | 60 +++++- federatedscope/core/monitors/monitor.py | 173 ++++++++++++------ 3 files changed, 208 insertions(+), 80 deletions(-) diff --git a/federatedscope/core/monitors/early_stopper.py b/federatedscope/core/monitors/early_stopper.py index 48fb146d8..fe9688821 100644 --- a/federatedscope/core/monitors/early_stopper.py +++ b/federatedscope/core/monitors/early_stopper.py @@ -5,28 +5,24 @@ # TODO: make this as a sub-module of monitor class class EarlyStopper(object): """ - Track the history of metric (e.g., validation loss), - check whether should stop (training) process if the metric doesn't - improve after a given patience. + Track the history of metric (e.g., validation loss), \ + check whether should stop (training) process if the metric doesn't \ + improve after a given patience. + + Args: + patience (int): (Default: 5) How long to wait after last time the \ + monitored metric improved. Note that the \ + ``actual_checking_round = patience * cfg.eval.freq`` + delta (float): (Default: 0) Minimum change in the monitored metric to \ + indicate an improvement. + improve_indicator_mode (str): Early stop when no improve to \ + last ``patience`` round, in ``['mean', 'best']`` """ def __init__(self, patience=5, delta=0, improve_indicator_mode='best', the_larger_the_better=True): - """ - Args: - patience (int): How long to wait after last time the monitored - metric improved. - Note that the - actual_checking_round = patience * cfg.eval.freq - Default: 5 - delta (float): Minimum change in the monitored metric to - indicate an improvement. 
- Default: 0 - improve_indicator_mode (str): Early stop when no improve to - last `patience` round, in ['mean', 'best'] - """ assert 0 <= patience == int( patience ), "Please use a non-negtive integer to indicate the patience" @@ -47,10 +43,28 @@ def __init__(self, self.improvement_operator = operator.add def __track_and_check_dummy(self, new_result): + """ + Dummy stopper, always return false + + Args: + new_result: + + Returns: + False + """ self.early_stopped = False return self.early_stopped def __track_and_check_best(self, history_result): + """ + Tracks the best result and checks whether the patience is exceeded. + + Args: + history_result: results of all evaluation round + + Returns: + Bool: whether stop + """ new_result = history_result[-1] if self.best_metric is None: self.best_metric = new_result @@ -91,6 +105,15 @@ def __track_and_check_mean(self, history_result): return self.early_stopped def track_and_check(self, new_result): + """ + Checks the new result and if it improves it returns True. + + Args: + new_result: new evaluation result + + Returns: + Bool: whether stop + """ track_method = self.__track_and_check_dummy # do nothing if self.patience == 0: diff --git a/federatedscope/core/monitors/metric_calculator.py b/federatedscope/core/monitors/metric_calculator.py index ed3ec4b3b..fb8e89f3c 100644 --- a/federatedscope/core/monitors/metric_calculator.py +++ b/federatedscope/core/monitors/metric_calculator.py @@ -17,8 +17,14 @@ class MetricCalculator(object): - def __init__(self, eval_metric: Union[Set[str], List[str], str]): + """ + Initializes the metric functions for the monitor. Use ``eval(ctx)`` \ + to get evaluation results. + Args: + eval_metric: set of metric names + """ + def __init__(self, eval_metric: Union[Set[str], List[str], str]): # Add personalized metrics if isinstance(eval_metric, str): eval_metric = {eval_metric} @@ -29,6 +35,53 @@ def __init__(self, eval_metric: Union[Set[str], List[str], str]): self.eval_metric = self.get_metric_funcs(eval_metric) def get_metric_funcs(self, eval_metric): + """ + Build metrics for evaluation. 
+ Args: + self: write your description + eval_metric: write your description + + Returns: + A metric calculator dict, such as \ + ``{'loss': (eval_loss, False), 'acc': (eval_acc, True), ...}`` + + Note: + The key-value pairs of built-in metric and related funcs and \ + ``the_larger_the_better`` sign is shown below: + ================= ============================================= = + Metric name Source \ + The larger the better + ================= ============================================= = + ``loss`` ``monitors.metric_calculator.eval_loss`` \ + False + ``avg_loss`` ``monitors.metric_calculator.eval_avg_loss`` \ + False + ``total`` ``monitors.metric_calculator.eval_total`` \ + False + ``correct`` ``monitors.metric_calculator.eval_correct`` \ + True + ``acc`` ``monitors.metric_calculator.eval_acc`` \ + True + ``ap`` ``monitors.metric_calculator.eval_ap`` \ + True + ``f1`` ``monitors.metric_calculator.eval_f1_score`` \ + True + ``roc_auc`` ``monitors.metric_calculator.eval_roc_auc`` \ + True + ``rmse`` ``monitors.metric_calculator.eval_rmse`` \ + False + ``mse`` ``monitors.metric_calculator.eval_mse`` \ + False + ``loss_regular`` ``monitors.metric_calculator.eval_regular`` \ + False + ``imp_ratio`` ``monitors.metric_calculator.eval_imp_ratio`` \ + True + ``std`` ``None`` \ + False + ``hits@{n}`` ``monitors.metric_calculator.eval_hits`` \ + True + ================= ============================================= = + """ metric_buildin = { metric: SUPPORT_METRICS[metric] for metric in {'loss', 'avg_loss', 'total'} | eval_metric @@ -50,10 +103,11 @@ def eval(self, ctx): return results def _check_and_parse(self, ctx): - """Check the format of the prediction and labels + """ + Check the format of the prediction and labels Args: - ctx: + ctx: context of trainer, see ``core.trainers.context`` Returns: y_true: The ground truth labels diff --git a/federatedscope/core/monitors/monitor.py b/federatedscope/core/monitors/monitor.py index 3050e8e56..a20bbbe46 100644 --- a/federatedscope/core/monitors/monitor.py +++ b/federatedscope/core/monitors/monitor.py @@ -25,10 +25,25 @@ class Monitor(object): """ - Provide the monitoring functionalities such as formatting the - evaluation results into diverse metrics. - Besides the prediction related performance, the monitor also can - track efficiency related metrics for a worker + Provide the monitoring functionalities such as formatting the \ + evaluation results into diverse metrics. \ + Besides the prediction related performance, the monitor also can \ + track efficiency related metrics for a worker + + Args: + cfg: a cfg node object + monitored_object: object to be monitored + + Attributes: + log_res_best: best ever seen results + outdir: output directory + use_wandb: whether use ``wandb`` + wandb_online_track: whether use ``wandb`` to track online + monitored_object: object to be monitored + metric_calculator: metric calculator, / + see ``core.monitors.metric_calculator`` + round_wise_update_key: key to decide which result of evaluation \ + round is better """ SUPPORTED_FORMS = ['weighted_avg', 'avg', 'fairness', 'raw'] @@ -93,20 +108,38 @@ def __init__(self, cfg, monitored_object=None): exit() def eval(self, ctx): + """ + Evaluates the given context with ``metric_calculator``. + + Args: + ctx: context of trainer, see ``core.trainers.context`` + + Returns: + Evaluation results. + """ results = self.metric_calculator.eval(ctx) return results def global_converged(self): + """ + Calculate wall time and round when global convergence has been reached. 
+ """ self.global_convergence_wall_time = datetime.datetime.now( ) - self.fl_begin_wall_time self.global_convergence_round = self.monitored_object.state def local_converged(self): + """ + Calculate wall time and round when local convergence has been reached. + """ self.local_convergence_wall_time = datetime.datetime.now( ) - self.fl_begin_wall_time self.local_convergence_round = self.monitored_object.state def finish_fl(self): + """ + When FL finished, write system metrics to file. + """ self.fl_end_wall_time = datetime.datetime.now( ) - self.fl_begin_wall_time @@ -143,9 +176,8 @@ def merge_system_metrics_simulation_mode(self, file_io=True, from_global_monitors=False): """ - average the system metrics recorded in "system_metrics.json" by - all workers - :return: + Average the system metrics recorded in ``system_metrics.json`` by \ + all workers """ all_sys_metrics = defaultdict(list) @@ -236,6 +268,9 @@ def merge_system_metrics_simulation_mode(self, def save_formatted_results(self, formatted_res, save_file_name="eval_results.log"): + """ + Save formatted results to a file. + """ line = str(formatted_res) + "\n" if save_file_name != "": with open(os.path.join(self.outdir, save_file_name), @@ -254,6 +289,9 @@ def save_formatted_results(self, exit() def finish_fed_runner(self, fl_mode=None): + """ + Finish the Fed runner. + """ self.compress_raw_res_file() if fl_mode == "standalone": self.merge_system_metrics_simulation_mode() @@ -289,6 +327,9 @@ def finish_fed_runner(self, fl_mode=None): wandb.summary[k] = v def compress_raw_res_file(self): + """ + Compress the raw res file to be written to disk. + """ old_f_name = os.path.join(self.outdir, "eval_results.raw") if os.path.exists(old_f_name): logger.info( @@ -306,7 +347,7 @@ def format_eval_res(self, forms=None, return_raw=False): """ - format the evaluation results from trainer.ctx.eval_results + Format the evaluation results from ``trainer.ctx.eval_results`` Args: results (dict): a dict to store the evaluation results {metric: @@ -317,41 +358,45 @@ def format_eval_res(self, return_raw (bool): return either raw results, or other results Returns: - round_formatted_results (dict): a formatted results with - different forms and roles, - e.g., - { - 'Role': 'Server #', - 'Round': 200, - 'Results_weighted_avg': { - 'test_avg_loss': 0.58, 'test_acc': 0.67, 'test_correct': - 3356, 'test_loss': 2892, 'test_total': 5000 - }, - 'Results_avg': { - 'test_avg_loss': 0.57, 'test_acc': 0.67, 'test_correct': - 3356, 'test_loss': 2892, 'test_total': 5000 - }, - 'Results_fairness': { - 'test_total': 33.99, 'test_correct': 27.185, - 'test_avg_loss_std': 0.433551, - 'test_avg_loss_bottom_decile': 0.356503, - 'test_avg_loss_top_decile': 1.212492, - 'test_avg_loss_min': 0.198317, 'test_avg_loss_max': 3.603567, - 'test_avg_loss_bottom10%': 0.276681, 'test_avg_loss_top10%': - 1.686649, - 'test_avg_loss_cos1': 0.867932, 'test_avg_loss_entropy': 5.164172, - 'test_loss_std': 13.686828, 'test_loss_bottom_decile': 11.822035, - 'test_loss_top_decile': 39.727236, 'test_loss_min': 7.337724, - 'test_loss_max': 100.899873, 'test_loss_bottom10%': 9.618685, - 'test_loss_top10%': 54.96769, 'test_loss_cos1': 0.880356, - 'test_loss_entropy': 5.175803, 'test_acc_std': 0.123823, - 'test_acc_bottom_decile': 0.676471, 'test_acc_top_decile': - 0.916667, - 'test_acc_min': 0.071429, 'test_acc_max': 0.972973, - 'test_acc_bottom10%': 0.527482, 'test_acc_top10%': 0.94486, - 'test_acc_cos1': 0.988134, 'test_acc_entropy': 5.283755 - }, + dict: round_formatted_results, a formatted results with 
\ + different forms and roles + + Note: + Example of return value: + ``` + { \ + 'Role': 'Server #', \ + 'Round': 200, \ + 'Results_weighted_avg': { \ + 'test_avg_loss': 0.58, 'test_acc': 0.67, 'test_correct': \ + 3356, 'test_loss': 2892, 'test_total': 5000 \ + }, \ + 'Results_avg': { \ + 'test_avg_loss': 0.57, 'test_acc': 0.67, 'test_correct': \ + 3356, 'test_loss': 2892, 'test_total': 5000 \ + }, \ + 'Results_fairness': { \ + 'test_total': 33.99, 'test_correct': 27.185, \ + 'test_avg_loss_std': 0.433551, \ + 'test_avg_loss_bottom_decile': 0.356503, \ + 'test_avg_loss_top_decile': 1.212492, \ + 'test_avg_loss_min': 0.198317, 'test_avg_loss_max': 3.603567, \ + 'test_avg_loss_bottom10%': 0.276681, 'test_avg_loss_top10%': \ + 1.686649, \ + 'test_avg_loss_cos1': 0.8679, 'test_avg_loss_entropy': 5.1641, \ + 'test_loss_std': 13.686828, 'test_loss_bottom_decile': 11.8220, \ + 'test_loss_top_decile': 39.727236, 'test_loss_min': 7.337724, \ + 'test_loss_max': 100.899873, 'test_loss_bottom10%': 9.618685, \ + 'test_loss_top10%': 54.96769, 'test_loss_cos1': 0.880356, \ + 'test_loss_entropy': 5.175803, 'test_acc_std': 0.123823, \ + 'test_acc_bottom_decile': 0.676471, 'test_acc_top_decile': \ + 0.916667, \ + 'test_acc_min': 0.071429, 'test_acc_max': 0.972973, \ + 'test_acc_bottom10%': 0.527482, 'test_acc_top10%': 0.94486, \ + 'test_acc_cos1': 0.988134, 'test_acc_entropy': 5.283755 \ + }, \ } + ``` """ if forms is None: forms = ['weighted_avg', 'avg', 'fairness', 'raw'] @@ -433,15 +478,16 @@ def format_eval_res(self, round_formatted_results def calc_blocal_dissim(self, last_model, local_updated_models): - ''' + """ Arguments: last_model (dict): the state of last round. - local_updated_models (list): each element is ooxx. + local_updated_models (list): each element is model. + Returns: - b_local_dissimilarity (dict): the measurements proposed in - "Tian Li, Anit Kumar Sahu, Manzil Zaheer, and et al. Federated + dict: b_local_dissimilarity, the measurements proposed in \ + "Tian Li, Anit Kumar Sahu, Manzil Zaheer, and et al. Federated \ Optimization in Heterogeneous Networks". - ''' + """ # for k, v in last_model.items(): # print(k, v) # for i, elem in enumerate(local_updated_models): @@ -481,6 +527,9 @@ def calc_blocal_dissim(self, last_model, local_updated_models): return b_local_dissimilarity def convert_size(self, size_bytes): + """ + Convert bytes to human-readable size. 
+ """ import math if size_bytes <= 0: return str(size_bytes) @@ -492,11 +541,11 @@ def convert_size(self, size_bytes): def track_model_size(self, models): """ - calculate the total model size given the models hold by the - worker/trainer + calculate the total model size given the models hold by the \ + worker/trainer - :param models: torch.nn.Module or list of torch.nn.Module - :return: + Args + models: torch.nn.Module or list of torch.nn.Module """ if self.total_model_size != 0: logger.warning( @@ -514,13 +563,9 @@ def track_model_size(self, models): def track_avg_flops(self, flops, sample_num=1): """ - update the average flops for forwarding each data sample, - for most models and tasks, - the averaging is not needed as the input shape is fixed - - :param flops: flops/ - :param sample_num: - :return: + update the average flops for forwarding each data sample, \ + for most models and tasks, \ + the averaging is not needed as the input shape is fixed """ self.flops_per_sample = (self.flops_per_sample * self.flop_count + @@ -528,16 +573,22 @@ def track_avg_flops(self, flops, sample_num=1): self.flop_count += 1 def track_upload_bytes(self, bytes): + """ + Track the number of bytes uploaded. + """ self.total_upload_bytes += bytes def track_download_bytes(self, bytes): + """ + Track the number of bytes downloaded. + """ self.total_download_bytes += bytes def update_best_result(self, best_results, new_results, results_type): """ - update best evaluation results. - by default, the update is based on validation loss with - `round_wise_update_key="val_loss" ` + Update best evaluation results. \ + by default, the update is based on validation loss with \ + ``round_wise_update_key="val_loss" `` """ update_best_this_round = False if not isinstance(new_results, dict): From ab345a56238db4c401398635fc1ce07b34d9e637 Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Wed, 2 Nov 2022 11:53:21 +0800 Subject: [PATCH 17/21] fix docstring in app --- doc/source/cv.rst | 6 +++++- federatedscope/cv/dataloader/dataloader.py | 22 +++++++++++++------- federatedscope/cv/dataset/leaf.py | 5 +++-- federatedscope/cv/trainer/trainer.py | 3 +++ federatedscope/nlp/dataloader/dataloader.py | 23 ++++++++++++++------- federatedscope/nlp/trainer/trainer.py | 3 +++ 6 files changed, 43 insertions(+), 19 deletions(-) diff --git a/doc/source/cv.rst b/doc/source/cv.rst index 12d1c7e31..72db5e579 100644 --- a/doc/source/cv.rst +++ b/doc/source/cv.rst @@ -4,7 +4,11 @@ Federated Computer Vision Module References federatedscope.cv.dataset ----------------------- -.. automodule:: federatedscope.cv.dataset +.. automodule:: federatedscope.cv.dataset.leaf + :members: + :private-members: + +.. automodule:: federatedscope.cv.dataset.leaf_cv :members: :private-members: diff --git a/federatedscope/cv/dataloader/dataloader.py b/federatedscope/cv/dataloader/dataloader.py index 9ea27e492..edb736a20 100644 --- a/federatedscope/cv/dataloader/dataloader.py +++ b/federatedscope/cv/dataloader/dataloader.py @@ -3,14 +3,20 @@ def load_cv_dataset(config=None): - r""" - return { - 'client_id': { - 'train': DataLoader(), - 'test': DataLoader(), - 'val': DataLoader() - } - } + """ + Return the dataset of ``femnist`` or ``celeba``. + + Args: + config: configurations for FL, see ``federatedscope.core.configs`` + + Returns: + FL dataset dict, with ``client_id`` as key. 
+ + Note: + ``load_cv_dataset()`` will return a dict as shown below: + ``` + {'client_id': {'train': dataset, 'test': dataset, 'val': dataset}} + ``` """ splits = config.data.splits diff --git a/federatedscope/cv/dataset/leaf.py b/federatedscope/cv/dataset/leaf.py index eb253e2d5..7b809da6a 100644 --- a/federatedscope/cv/dataset/leaf.py +++ b/federatedscope/cv/dataset/leaf.py @@ -18,7 +18,8 @@ def is_exists(path, names): class LEAF(Dataset): - """Base class for LEAF dataset from "LEAF: A Benchmark for Federated Settings" + """ + Base class for LEAF dataset from "LEAF: A Benchmark for Federated Settings" Arguments: root (str): root path. @@ -91,7 +92,7 @@ def process(self): class LocalDataset(Dataset): """ - Convert data list to torch Dataset to save memory usage. + Convert data list to torch Dataset to save memory usage. """ def __init__(self, Xs, diff --git a/federatedscope/cv/trainer/trainer.py b/federatedscope/cv/trainer/trainer.py index b00ebd66a..1ae8e021c 100644 --- a/federatedscope/cv/trainer/trainer.py +++ b/federatedscope/cv/trainer/trainer.py @@ -3,6 +3,9 @@ class CVTrainer(GeneralTorchTrainer): + """ + ``CVTrainer`` is the same as ``core.trainers.GeneralTorchTrainer``. + """ pass diff --git a/federatedscope/nlp/dataloader/dataloader.py b/federatedscope/nlp/dataloader/dataloader.py index d950a2a63..ac28666dc 100644 --- a/federatedscope/nlp/dataloader/dataloader.py +++ b/federatedscope/nlp/dataloader/dataloader.py @@ -5,14 +5,21 @@ def load_nlp_dataset(config=None): - r""" - return { - 'client_id': { - 'train': DataLoader(), - 'test': DataLoader(), - 'val': DataLoader() - } - } + """ + Return the dataset of ``shakespeare``, ``subreddit``, ``twitter``, \ + or ``synthetic``. + + Args: + config: configurations for FL, see ``federatedscope.core.configs`` + + Returns: + FL dataset dict, with ``client_id`` as key. + + Note: + ``load_nlp_dataset()`` will return a dict as shown below: + ``` + {'client_id': {'train': dataset, 'test': dataset, 'val': dataset}} + ``` """ splits = config.data.splits diff --git a/federatedscope/nlp/trainer/trainer.py b/federatedscope/nlp/trainer/trainer.py index 2e640e25d..6be8115ba 100644 --- a/federatedscope/nlp/trainer/trainer.py +++ b/federatedscope/nlp/trainer/trainer.py @@ -4,6 +4,9 @@ class NLPTrainer(GeneralTorchTrainer): + """ + ``NLPTrainer`` is used for text data. 
+ """ def _hook_on_batch_forward(self, ctx): x, label = [move_to(_, ctx.device) for _ in ctx.data_batch] if isinstance(x, dict): From 3303734158c67a40b2a4560338d028b3fe1bd23b Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Fri, 4 Nov 2022 17:04:12 +0800 Subject: [PATCH 18/21] Add ut for trainer and fix some typos --- ...xtra_dependencies_torch1.10-application.sh | 1 - .../contrib/trainer/torch_example.py | 1 - federatedscope/core/data/base_translator.py | 2 +- .../core/splitters/generic/iid_splitter.py | 3 +- federatedscope/core/trainers/trainer.py | 1 + tests/test_trainer_property.py | 81 +++++++++++++++++++ 6 files changed, 85 insertions(+), 4 deletions(-) create mode 100644 tests/test_trainer_property.py diff --git a/environment/extra_dependencies_torch1.10-application.sh b/environment/extra_dependencies_torch1.10-application.sh index 3e7963f9b..f6a39f746 100644 --- a/environment/extra_dependencies_torch1.10-application.sh +++ b/environment/extra_dependencies_torch1.10-application.sh @@ -8,7 +8,6 @@ conda install -y nltk # Speech and NLP conda install -y sentencepiece textgrid typeguard -c conda-forge conda install -y transformers==4.16.2 tokenizers==0.10.3 datasets -c huggingface -c conda-forge -conda install -y torchtext==0.9.0 -c pytorch # Tabular conda install -y openml==0.12.2 diff --git a/federatedscope/contrib/trainer/torch_example.py b/federatedscope/contrib/trainer/torch_example.py index 09f21b997..dbdf938b7 100644 --- a/federatedscope/contrib/trainer/torch_example.py +++ b/federatedscope/contrib/trainer/torch_example.py @@ -83,7 +83,6 @@ def evaluate(self, target_data_split_name='test'): def update(self, model_parameters, strict=False): self.model.load_state_dict(model_parameters, strict) - return self.get_model_para() def get_model_para(self): return self.model.cpu().state_dict() diff --git a/federatedscope/core/data/base_translator.py b/federatedscope/core/data/base_translator.py index 9c39b1fab..6217a7c2f 100644 --- a/federatedscope/core/data/base_translator.py +++ b/federatedscope/core/data/base_translator.py @@ -13,7 +13,7 @@ class BaseDataTranslator: ``StandaloneDataDict``, which is the input data of runner. Notes: - Translator is consist of several stages: + The ``Translator`` is consist of several stages: Dataset -> ML split (``split_train_val_test()``) -> \ FL split (``split_to_client()``) -> ``StandaloneDataDict`` diff --git a/federatedscope/core/splitters/generic/iid_splitter.py b/federatedscope/core/splitters/generic/iid_splitter.py index c802b41db..1f28fa151 100644 --- a/federatedscope/core/splitters/generic/iid_splitter.py +++ b/federatedscope/core/splitters/generic/iid_splitter.py @@ -4,7 +4,8 @@ class IIDSplitter(BaseSplitter): """ - This splitter split dataset randomly . + This splitter splits dataset following the independent and identically \ + distribution. 
Args: client_num: the dataset will be split into ``client_num`` pieces diff --git a/federatedscope/core/trainers/trainer.py b/federatedscope/core/trainers/trainer.py index ad63c02aa..689f64abe 100644 --- a/federatedscope/core/trainers/trainer.py +++ b/federatedscope/core/trainers/trainer.py @@ -77,6 +77,7 @@ def cfg(self): @cfg.setter def cfg(self, new_cfg): self._cfg = new_cfg + self.ctx.cfg = new_cfg self._setup_data_related_var_in_ctx(self.ctx) def parse_data(self, data): diff --git a/tests/test_trainer_property.py b/tests/test_trainer_property.py new file mode 100644 index 000000000..881018856 --- /dev/null +++ b/tests/test_trainer_property.py @@ -0,0 +1,81 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from federatedscope.core.auxiliaries.data_builder import get_data +from federatedscope.core.auxiliaries.utils import setup_seed +from federatedscope.core.auxiliaries.logging import update_logger +from federatedscope.core.configs.config import global_cfg +from federatedscope.core.auxiliaries.runner_builder import get_runner +from federatedscope.core.auxiliaries.worker_builder import get_server_cls, \ + get_client_cls + + +class TrainerCfgTest(unittest.TestCase): + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + def set_config_trainer_cfg_test(self, cfg): + backup_cfg = cfg.clone() + + import torch + cfg.use_gpu = torch.cuda.is_available() + cfg.eval.freq = 10 + cfg.eval.metrics = ['acc', 'loss_regular'] + + cfg.federate.mode = 'standalone' + cfg.data.root = 'test_data/' + cfg.data.type = 'femnist' + cfg.data.splits = [0.6, 0.2, 0.2] + cfg.dataloader.batch_size = 10 + cfg.data.subsample = 0.05 + cfg.data.transform = [['ToTensor'], + [ + 'Normalize', { + 'mean': [0.1307], + 'std': [0.3081] + } + ]] + + cfg.model.type = 'convnet2' + cfg.model.hidden = 2048 + cfg.model.out_channels = 62 + + cfg.train.optimizer.lr = 0.001 + cfg.train.optimizer.weight_decay = 0.0 + cfg.train.batch_or_epoch = 'epoch' + cfg.grad.grad_clip = 5.0 + + cfg.criterion.type = 'CrossEntropyLoss' + cfg.trainer.type = 'cvtrainer' + cfg.seed = 123 + + return backup_cfg + + def test_trainer_cfg(self): + init_cfg = global_cfg.clone() + backup_cfg = self.set_config_trainer_cfg_test(init_cfg) + setup_seed(init_cfg.seed) + update_logger(init_cfg) + + data, modified_cfg = get_data(init_cfg.clone()) + init_cfg.merge_from_other_cfg(modified_cfg) + self.assertIsNotNone(data) + + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) + self.assertIsNotNone(Fed_runner) + + num_train_batch = Fed_runner.client[1].trainer.ctx.num_train_batch + new_cfg = init_cfg.clone() + new_cfg.dataloader.batch_size = 64 + Fed_runner.client[1].trainer.cfg = new_cfg + new_num_train_batch = Fed_runner.client[1].trainer.ctx.num_train_batch + self.assertLess(new_num_train_batch, num_train_batch) + + init_cfg.merge_from_other_cfg(backup_cfg) + + +if __name__ == '__main__': + unittest.main() From b03d7202602cffba7704641e80b43b85fada28fb Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Wed, 9 Nov 2022 20:28:28 +0800 Subject: [PATCH 19/21] fix typo --- federatedscope/autotune/algos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federatedscope/autotune/algos.py b/federatedscope/autotune/algos.py index 40e85fc21..a21f2ac95 100644 --- a/federatedscope/autotune/algos.py +++ b/federatedscope/autotune/algos.py @@ -57,7 +57,7 @@ def run(self): 
server_class=get_server_cls(self._trial_cfg), client_class=get_client_cls(self._trial_cfg), config=self._trial_cfg.clone(), - client_config=client_cfgs) + client_config=self._client_cfgs) results = Fed_runner.run() key1, key2 = self._trial_cfg.hpo.metric.split('.') self._returns['perf'] = results[key1][key2] From bc3a850d2df72c8bfe19364fa30623f078f7cb0d Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Thu, 10 Nov 2022 10:33:14 +0800 Subject: [PATCH 20/21] fix bug caused by merge --- federatedscope/cl/trainer/trainer.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/federatedscope/cl/trainer/trainer.py b/federatedscope/cl/trainer/trainer.py index 2232815ff..dbe2dcf55 100644 --- a/federatedscope/cl/trainer/trainer.py +++ b/federatedscope/cl/trainer/trainer.py @@ -1,15 +1,9 @@ -from federatedscope.core.auxiliaries.enums import MODE +import torch from federatedscope.register import register_trainer -from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer -from federatedscope.core.auxiliaries.scheduler_builder import get_scheduler from federatedscope.core.trainers import GeneralTorchTrainer -from federatedscope.core.trainers.context import Context from federatedscope.core.trainers.context import CtxVar -from federatedscope.core.auxiliaries.enums import LIFECYCLE +from federatedscope.core.trainers.enums import LIFECYCLE, MODE from federatedscope.core.auxiliaries import utils -import torch -import numpy as np -import copy class CLTrainer(GeneralTorchTrainer): From 6a5ea837e865342a82e89045e6e516d4dc8b4897 Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Thu, 10 Nov 2022 12:04:26 +0800 Subject: [PATCH 21/21] fix utils --- federatedscope/cl/trainer/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/federatedscope/cl/trainer/trainer.py b/federatedscope/cl/trainer/trainer.py index dbe2dcf55..ec3f1c557 100644 --- a/federatedscope/cl/trainer/trainer.py +++ b/federatedscope/cl/trainer/trainer.py @@ -3,7 +3,7 @@ from federatedscope.core.trainers import GeneralTorchTrainer from federatedscope.core.trainers.context import CtxVar from federatedscope.core.trainers.enums import LIFECYCLE, MODE -from federatedscope.core.auxiliaries import utils +from federatedscope.core.trainers.utils import move_to class CLTrainer(GeneralTorchTrainer): @@ -35,7 +35,7 @@ def get_train_pred_embedding(self): return [self.z1, self.z2] def _hook_on_batch_forward(self, ctx): - x, label = [utils.move_to(_, ctx.device) for _ in ctx.data_batch] + x, label = [move_to(_, ctx.device) for _ in ctx.data_batch] x1, x2 = x[0], x[1] if ctx.cur_mode in [MODE.TRAIN]: self.batches_aug_data_1 = x1
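
The patches above only document the monitor interfaces, so a minimal usage sketch of the ``EarlyStopper`` described in [PATCH 16/21] follows. It assumes FederatedScope is installed and relies only on the constructor arguments and the ``track_and_check()`` return value spelled out in the new docstrings; the loss values are made up for illustration.

```
# Sketch: stop a (hypothetical) training loop once the validation loss has
# not improved for `patience` consecutive evaluation rounds.
from federatedscope.core.monitors.early_stopper import EarlyStopper

# val_loss is the-smaller-the-better, hence the_larger_the_better=False.
stopper = EarlyStopper(patience=3,
                       delta=0,
                       improve_indicator_mode='best',
                       the_larger_the_better=False)

val_losses = [0.90, 0.72, 0.70, 0.71, 0.73, 0.74, 0.75, 0.76]
for rnd, loss in enumerate(val_losses):
    if stopper.track_and_check(loss):  # True once the patience is exhausted
        print(f'Early stopping triggered at evaluation round {rnd}')
        break
```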
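
Similarly, a small sketch of the ``MetricCalculator`` documented in the same patch. The ``(metric_func, the_larger_the_better)`` tuple layout of ``eval_metric`` is taken from the docstring's return note; running the full ``eval(ctx)`` path is omitted here because it needs a trainer context.

```
# Sketch: build the evaluation metrics and inspect their ordering flags.
from federatedscope.core.monitors.metric_calculator import MetricCalculator

calculator = MetricCalculator({'acc', 'f1'})
# 'loss', 'avg_loss' and 'total' are always included on top of the request.
for name, (metric_func, larger_is_better) in calculator.eval_metric.items():
    print(f'{name}: the-larger-the-better={larger_is_better}')
```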
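
To make the large ``format_eval_res()`` example more concrete, the snippet below re-derives a ``Results_weighted_avg`` entry from two fabricated client reports. Weighting by each client's ``test_total`` is an assumption made for illustration only and is not copied from the library's internals.

```
# Sketch: aggregate per-client evaluation results into a weighted average.
client_results = {
    1: {'test_acc': 0.60, 'test_avg_loss': 0.70, 'test_total': 200},
    2: {'test_acc': 0.75, 'test_avg_loss': 0.50, 'test_total': 300},
}

num_total = sum(res['test_total'] for res in client_results.values())
weighted_avg = {
    key: sum(res[key] * res['test_total']
             for res in client_results.values()) / num_total
    for key in ('test_acc', 'test_avg_loss')
}
print({'Role': 'Server #', 'Results_weighted_avg': weighted_avg})
# -> {'test_acc': 0.69, 'test_avg_loss': 0.58}
```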
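
Finally, the docstring of ``calc_blocal_dissim()`` now cites Li et al., "Federated Optimization in Heterogeneous Networks"; for readers without the paper at hand, the B-local dissimilarity defined there is, roughly,

```
B(w) = \sqrt{ \frac{\mathbb{E}_k\left[\lVert \nabla F_k(w) \rVert^2\right]}
                   {\lVert \nabla f(w) \rVert^2} },
\qquad f(w) = \mathbb{E}_k\left[ F_k(w) \right],
```

i.e. the local objectives ``F_k`` are B-locally dissimilar at ``w`` when the averaged squared gradient norm is at most ``B^2`` times the squared norm of the global gradient. Since the method only receives model states (``last_model`` and ``local_updated_models``) rather than exact gradients, its output should be read as an empirical approximation of this quantity.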