From 6c19980396944b70716ea8c7e9035b52a5782b54 Mon Sep 17 00:00:00 2001 From: qbc Date: Fri, 9 Dec 2022 12:11:23 +0800 Subject: [PATCH 1/3] distribute mode for xgb --- .../vertical_fl/xgb_base/worker/Test_base.py | 10 ++-- .../vertical_fl/xgb_base/worker/XGBClient.py | 16 ++++++- .../vertical_fl/xgb_base/worker/XGBServer.py | 9 ++++ .../distributed_xgb_client_1.yaml | 46 +++++++++++++++++++ .../distributed_xgb_client_2.yaml | 46 +++++++++++++++++++ .../distributed_xgb_server.yaml | 44 ++++++++++++++++++ .../run_distributed_xgb.sh | 18 ++++++++ 7 files changed, 183 insertions(+), 6 deletions(-) create mode 100644 scripts/distributed_scripts/distributed_configs/distributed_xgb_client_1.yaml create mode 100644 scripts/distributed_scripts/distributed_configs/distributed_xgb_client_2.yaml create mode 100644 scripts/distributed_scripts/distributed_configs/distributed_xgb_server.yaml create mode 100644 scripts/distributed_scripts/run_distributed_xgb.sh diff --git a/federatedscope/vertical_fl/xgb_base/worker/Test_base.py b/federatedscope/vertical_fl/xgb_base/worker/Test_base.py index 66734a5f5..6a34bbc24 100644 --- a/federatedscope/vertical_fl/xgb_base/worker/Test_base.py +++ b/federatedscope/vertical_fl/xgb_base/worker/Test_base.py @@ -43,10 +43,9 @@ def test_for_node(self, tree_num, node_num): if node_num >= 2**self.client.max_tree_depth - 1: if tree_num + 1 < self.client.num_of_trees: # TODO: add feedback during training - self.client.state += 1 - logger.info( - f'----------- Starting a new training round (Round ' - f'#{self.client.state}) -------------') + # self.client.state += 1 + logger.info(f'----------- Building a new tree (Tree ' + f'#{tree_num + 1}) -------------') # build the next tree self.client.fs.compute_for_root(tree_num + 1) @@ -67,8 +66,9 @@ def test_for_node(self, tree_num, node_num): each for each in list( self.client.comm_manager.neighbors.keys()) if each != self.client.server_id + and each != self.client.ID ], - content=None)) + content='None')) 
self.client.comm_manager.send( Message(msg_type='feature_importance', sender=self.client.ID, diff --git a/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py b/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py index a62d57774..377538dcd 100644 --- a/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py +++ b/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py @@ -43,6 +43,8 @@ def __init__(self, self.num_of_trees = None self.max_tree_depth = None + self.federate_mode = config.federate.mode + self.bin_num = config.train.optimizer.bin_num self.batch_size = config.data.batch_size @@ -103,6 +105,8 @@ def __init__(self, self.register_handlers('send_feature_importance', self.callback_func_for_send_feature_importance) + self.register_handlers('finish', self.callback_func_for_finish) + # save the order of values in each feature def order_feature(self, data): for j in range(data.shape[1]): @@ -125,6 +129,13 @@ def sample_data(self, index=None): def callback_func_for_model_para(self, message: Message): self.lambda_, self.gamma, self.num_of_trees, self.max_tree_depth \ = message.content + + if self.federate_mode == 'distributed': + self.comm_manager.add_neighbors(neighbor_id=self.ID, + address={ + 'host': self.comm_manager.host, + 'port': self.comm_manager.port + }) self.tree_list = [ Tree(self.max_tree_depth).tree for _ in range(self.num_of_trees) ] @@ -133,7 +144,7 @@ def callback_func_for_model_para(self, message: Message): # init y_hat self.y_hat = np.random.uniform(low=0.0, high=1.0, size=len(self.y)) # self.y_hat = np.zeros(len(self.y)) - logger.info(f'----------- Starting a new training round (Round ' + logger.info(f'---------- Building a new tree (Tree ' f'#{self.state}) -------------') self.comm_manager.send( Message( @@ -217,3 +228,6 @@ def callback_func_for_send_feature_importance(self, message: Message): state=self.state, receiver=self.server_id, content=self.feature_importance)) + + def callback_func_for_finish(self, message: Message): + pass diff 
--git a/federatedscope/vertical_fl/xgb_base/worker/XGBServer.py b/federatedscope/vertical_fl/xgb_base/worker/XGBServer.py index 4b99302de..b2374ae89 100644 --- a/federatedscope/vertical_fl/xgb_base/worker/XGBServer.py +++ b/federatedscope/vertical_fl/xgb_base/worker/XGBServer.py @@ -86,7 +86,16 @@ def callback_func_for_feature_importance(self, message: Message): rnd=self.tree_num, role='Server #', forms=self._cfg.eval.report) + formatted_logs['feature_importance'] = self.feature_importance_dict logger.info(formatted_logs) + self.comm_manager.send( + Message(msg_type='finish', + sender=self.ID, + receiver=list( + self.comm_manager.get_neighbors().keys()), + state=self.state, + content='None')) + self.state = self.total_round_num + 1 def callback_func_for_test_result(self, message: Message): self.tree_num, self.metrics = message.content diff --git a/scripts/distributed_scripts/distributed_configs/distributed_xgb_client_1.yaml b/scripts/distributed_scripts/distributed_configs/distributed_xgb_client_1.yaml new file mode 100644 index 000000000..7558ce683 --- /dev/null +++ b/scripts/distributed_scripts/distributed_configs/distributed_xgb_client_1.yaml @@ -0,0 +1,46 @@ +use_gpu: False +device: 0 +early_stop: + patience: 5 +seed: 12345 +federate: + client_num: 2 + mode: 'distributed' + make_global_eval: False + online_aggr: False + total_round_num: 20 +distribute: + use: True + server_host: '127.0.0.1' + server_port: 50051 + client_host: '127.0.0.1' + client_port: 50052 + role: 'client' + data_idx: 1 +data: + root: data/ + type: credit + splits: [0.8, 0.2] +dataloader: + type: raw + batch_size: 2000 +model: + type: lr +train: + optimizer: + bin_num: 100 + lambda_: 0.1 + gamma: 0 + num_of_trees: 10 + max_tree_depth: 3 +xgb_base: + use: True + use_bin: True + dims: [5, 10] +criterion: + type: CrossEntropyLoss +trainer: + type: none +eval: + freq: 3 + best_res_update_round_wise_key: test_loss \ No newline at end of file diff --git 
a/scripts/distributed_scripts/distributed_configs/distributed_xgb_client_2.yaml b/scripts/distributed_scripts/distributed_configs/distributed_xgb_client_2.yaml new file mode 100644 index 000000000..30e4f3bee --- /dev/null +++ b/scripts/distributed_scripts/distributed_configs/distributed_xgb_client_2.yaml @@ -0,0 +1,46 @@ +use_gpu: False +device: 0 +early_stop: + patience: 5 +seed: 12345 +federate: + client_num: 2 + mode: 'distributed' + make_global_eval: False + online_aggr: False + total_round_num: 20 +distribute: + use: True + server_host: '127.0.0.1' + server_port: 50051 + client_host: '127.0.0.1' + client_port: 50053 + role: 'client' + data_idx: 2 +data: + root: data/ + type: credit + splits: [0.8, 0.2] +dataloader: + type: raw + batch_size: 2000 +model: + type: lr +train: + optimizer: + bin_num: 100 + lambda_: 0.1 + gamma: 0 + num_of_trees: 10 + max_tree_depth: 3 +xgb_base: + use: True + use_bin: True + dims: [5, 10] +criterion: + type: CrossEntropyLoss +trainer: + type: none +eval: + freq: 3 + best_res_update_round_wise_key: test_loss \ No newline at end of file diff --git a/scripts/distributed_scripts/distributed_configs/distributed_xgb_server.yaml b/scripts/distributed_scripts/distributed_configs/distributed_xgb_server.yaml new file mode 100644 index 000000000..02c83d82d --- /dev/null +++ b/scripts/distributed_scripts/distributed_configs/distributed_xgb_server.yaml @@ -0,0 +1,44 @@ +use_gpu: False +device: 0 +early_stop: + patience: 5 +seed: 12345 +federate: + client_num: 2 + mode: 'distributed' + make_global_eval: False + online_aggr: False + total_round_num: 20 +distribute: + use: True + server_host: '127.0.0.1' + server_port: 50051 + role: 'server' + data_idx: 0 +data: + root: data/ + type: credit + splits: [0.8, 0.2] +dataloader: + type: raw + batch_size: 2000 +model: + type: lr +train: + optimizer: + bin_num: 100 + lambda_: 0.1 + gamma: 0 + num_of_trees: 10 + max_tree_depth: 3 +xgb_base: + use: True + use_bin: True + dims: [5, 10] +criterion: + type: 
CrossEntropyLoss +trainer: + type: none +eval: + freq: 3 + best_res_update_round_wise_key: test_loss \ No newline at end of file diff --git a/scripts/distributed_scripts/run_distributed_xgb.sh b/scripts/distributed_scripts/run_distributed_xgb.sh new file mode 100644 index 000000000..87f531968 --- /dev/null +++ b/scripts/distributed_scripts/run_distributed_xgb.sh @@ -0,0 +1,18 @@ +set -e + +cd .. + +echo "Test distributed mode with XGB..." + +### server owns global test data +# python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_server.yaml & +### server doesn't own data +python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_xgb_server.yaml & +sleep 2 + +# clients +python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_xgb_client_1.yaml & +sleep 2 +python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_xgb_client_2.yaml & +sleep 2 + From 091d207c2d5289c2d181078fe82bb059dde515eb Mon Sep 17 00:00:00 2001 From: qbc Date: Fri, 9 Dec 2022 13:57:09 +0800 Subject: [PATCH 2/3] minor changes --- federatedscope/vertical_fl/xgb_base/worker/Test_base.py | 1 - federatedscope/vertical_fl/xgb_base/worker/XGBClient.py | 2 -- federatedscope/vertical_fl/xgb_base/worker/XGBServer.py | 1 + 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/federatedscope/vertical_fl/xgb_base/worker/Test_base.py b/federatedscope/vertical_fl/xgb_base/worker/Test_base.py index 6a34bbc24..8c4e2c6b5 100644 --- a/federatedscope/vertical_fl/xgb_base/worker/Test_base.py +++ b/federatedscope/vertical_fl/xgb_base/worker/Test_base.py @@ -43,7 +43,6 @@ def test_for_node(self, tree_num, node_num): if node_num >= 2**self.client.max_tree_depth - 1: if tree_num + 1 < self.client.num_of_trees: # TODO: add feedback during training - # self.client.state += 1 logger.info(f'----------- Building a new tree (Tree ' f'#{tree_num + 1}) -------------') 
# build the next tree diff --git a/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py b/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py index 377538dcd..4af1ac21b 100644 --- a/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py +++ b/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py @@ -104,7 +104,6 @@ def __init__(self, self.callback_func_for_compute_next_node) self.register_handlers('send_feature_importance', self.callback_func_for_send_feature_importance) - self.register_handlers('finish', self.callback_func_for_finish) # save the order of values in each feature @@ -143,7 +142,6 @@ def callback_func_for_model_para(self, message: Message): self.batch_index, self.x, self.y = self.sample_data() # init y_hat self.y_hat = np.random.uniform(low=0.0, high=1.0, size=len(self.y)) - # self.y_hat = np.zeros(len(self.y)) logger.info(f'---------- Building a new tree (Tree ' f'#{self.state}) -------------') self.comm_manager.send( diff --git a/federatedscope/vertical_fl/xgb_base/worker/XGBServer.py b/federatedscope/vertical_fl/xgb_base/worker/XGBServer.py index b2374ae89..80eb3f275 100644 --- a/federatedscope/vertical_fl/xgb_base/worker/XGBServer.py +++ b/federatedscope/vertical_fl/xgb_base/worker/XGBServer.py @@ -95,6 +95,7 @@ def callback_func_for_feature_importance(self, message: Message): self.comm_manager.get_neighbors().keys()), state=self.state, content='None')) + # move state past total_round_num so the server stops running self.state = self.total_round_num + 1 def callback_func_for_test_result(self, message: Message): From 03933eedf5280f9aaa50de9d2c7b782cc8e0df95 Mon Sep 17 00:00:00 2001 From: qbc Date: Fri, 9 Dec 2022 16:58:02 +0800 Subject: [PATCH 3/3] modified according to comments --- federatedscope/vertical_fl/xgb_base/worker/XGBClient.py | 2 ++ .../distributed_configs/distributed_xgb_client_1.yaml | 2 +- .../distributed_configs/distributed_xgb_client_2.yaml | 2 +- .../distributed_configs/distributed_xgb_server.yaml | 2 +- scripts/distributed_scripts/run_distributed_xgb.sh | 4 +--- 
5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py b/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py index 4af1ac21b..ce6b717aa 100644 --- a/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py +++ b/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py @@ -129,6 +129,8 @@ def callback_func_for_model_para(self, message: Message): self.lambda_, self.gamma, self.num_of_trees, self.max_tree_depth \ = message.content + # The client adds its own ID and address to comm_manager.neighbors + # so that it can send messages to, and receive them from, itself if self.federate_mode == 'distributed': self.comm_manager.add_neighbors(neighbor_id=self.ID, address={ diff --git a/scripts/distributed_scripts/distributed_configs/distributed_xgb_client_1.yaml b/scripts/distributed_scripts/distributed_configs/distributed_xgb_client_1.yaml index 7558ce683..8bdc38166 100644 --- a/scripts/distributed_scripts/distributed_configs/distributed_xgb_client_1.yaml +++ b/scripts/distributed_scripts/distributed_configs/distributed_xgb_client_1.yaml @@ -1,4 +1,4 @@ -use_gpu: False +use_gpu: True device: 0 early_stop: patience: 5 diff --git a/scripts/distributed_scripts/distributed_configs/distributed_xgb_client_2.yaml b/scripts/distributed_scripts/distributed_configs/distributed_xgb_client_2.yaml index 30e4f3bee..0ad914743 100644 --- a/scripts/distributed_scripts/distributed_configs/distributed_xgb_client_2.yaml +++ b/scripts/distributed_scripts/distributed_configs/distributed_xgb_client_2.yaml @@ -1,4 +1,4 @@ -use_gpu: False +use_gpu: True device: 0 early_stop: patience: 5 diff --git a/scripts/distributed_scripts/distributed_configs/distributed_xgb_server.yaml b/scripts/distributed_scripts/distributed_configs/distributed_xgb_server.yaml index 02c83d82d..5a4342697 100644 --- a/scripts/distributed_scripts/distributed_configs/distributed_xgb_server.yaml +++ b/scripts/distributed_scripts/distributed_configs/distributed_xgb_server.yaml @@ -1,4 
+1,4 @@ -use_gpu: False +use_gpu: True device: 0 early_stop: patience: 5 diff --git a/scripts/distributed_scripts/run_distributed_xgb.sh b/scripts/distributed_scripts/run_distributed_xgb.sh index 87f531968..a85be5bc7 100644 --- a/scripts/distributed_scripts/run_distributed_xgb.sh +++ b/scripts/distributed_scripts/run_distributed_xgb.sh @@ -4,9 +4,7 @@ cd .. echo "Test distributed mode with XGB..." -### server owns global test data -# python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_server.yaml & -### server doesn't own data +### server python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_xgb_server.yaml & sleep 2