Skip to content

Commit

Permalink
Add docs for scripts of tree-based model (#562)
Browse files Browse the repository at this point in the history
  • Loading branch information
qbc2016 authored Mar 31, 2023
1 parent 05a0eac commit fde3548
Show file tree
Hide file tree
Showing 23 changed files with 48 additions and 93 deletions.
3 changes: 2 additions & 1 deletion federatedscope/core/configs/cfg_fl_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def extend_fl_setting_cfg(cfg):
# ---------------------------------------------------------------------- #
cfg.vertical = CN()
cfg.vertical.use = False
cfg.vertical.mode = 'order_based' # ['order_based', 'label_based']
cfg.vertical.mode = 'feature_gathering'
# ['feature_gathering', 'label_scattering']
cfg.vertical.dims = [5, 10] # TODO: we need to explain dims
cfg.vertical.encryption = 'paillier'
cfg.vertical.key_size = 3072
Expand Down
20 changes: 20 additions & 0 deletions federatedscope/vertical_fl/tree_based_models/baseline/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
Here we give some hand-on examples, and roughly show their characteristics.

For more details, please see the specific configuration in each file.

| Yaml | Dataset | Model _type | Algo | Protect_method | Eval_protection |
| --------------------------------------------------- | -------------- | --------------------- | -------- | -------------- | --------------- |
| `gbdt_feature_gathering_on_abalone.yaml` | Abalone (reg.) | `'feature_gathering'` | `'gbdt'` | None | None |
| `gbdt_feature_gathering_on_adult.yaml` | Adult (clas.) | `'feature_gathering'` | `'gbdt'` | None | None |
| `gbdt_label_scattering_on_adult.yaml` | Adult (clas.) | `'label_scattering'` | `'gbdt'` | `'he'` | None |
| `rf_feature_gathering_on_abalone.yaml` | Abalone (reg.) | `'feature_gathering'` | `'rf'` | None | None |
| `rf_feature_gathering_on_adult.yaml` | Adult (clas.) | `'feature_gathering'` | `'rf'` | None | None |
| `rf_label_scattering_on_adult.yaml` | Adult (clas.) | `'label_scattering'` | `'rf'` | `'he'` | None |
| `xgb_feature_gathering_on_abalone.yaml` | Abalone (reg.) | `'feature_gathering'` | `'xgb'` | None | None |
| `xgb_feature_gathering_on_adult.yaml` | Adult (clas.) | `'feature_gathering'` | `'xgb'` | None | None |
| `xgb_feature_gathering_dp_on_abalone.yaml` | Abalone (reg.) | `'feature_gathering'` | `'xgb'` | None | None |
| `xgb_feature_gathering_dp_on_adult.yaml` | Adult (clas.) | `'feature_gathering'` | `'xgb'` | `'dp'` | None |
| `xgb_feature_gathering_op_boost_on_adult.yaml` | Adult (clas.) | `'feature_gathering'` | `'xgb'` | `'op_boost'` | None |
| `xgb_label_scattering_on_abalone.yaml` | Abalone (reg.) | `'label_scattering'` | `'xgb'` | `'he'` | None |
| `xgb_label_scattering_on_adult.yaml` | Adult (clas.) | `'label_scattering'` | `'xgb'` | `'he'` | None |
| `xgb_feature_gathering_dp_on_adult_by_he_eval.yaml` | Adult (clas.) | `'feature_gathering'` | `'xgb'` | `'he'` | `'he'` |
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ vertical:
use: True
dims: [4, 8]
algo: 'rf'
data_size_for_debug: 1500
data_size_for_debug: 2000
feature_subsample_ratio: 0.5
eval:
freq: 3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ vertical:
use: True
dims: [7, 14]
algo: 'rf'
data_size_for_debug: 1500
data_size_for_debug: 2000
feature_subsample_ratio: 0.8
eval:
freq: 3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ vertical:
use: True
dims: [ 7, 14 ]
algo: 'rf'
mode: 'label_based'
mode: 'label_scattering'
data_size_for_debug: 2000
feature_subsample_ratio: 0.4
protect_object: 'grad_and_hess'
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ train:
eta: 0.19
vertical:
use: True
mode: 'label_based'
mode: 'label_scattering'
dims: [6, 10]
algo: 'xgb'
data_size_for_debug: 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ train:
eta: 0.5
vertical:
use: True
mode: 'label_based'
mode: 'label_scattering'
dims: [7, 14]
algo: 'xgb'
data_size_for_debug: 2000
Expand Down
5 changes: 4 additions & 1 deletion federatedscope/vertical_fl/tree_based_models/model/Tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,14 @@ def cal_gain(self, split_idx, y, indicator):
else:
raise ValueError(f'Task type error: {self.task_type}')

def cal_gain_for_rf_label_base(self, node_num, split_idx, y, indicator):
def cal_gain_for_rf_label_scattering(self, node_num, split_idx, y,
indicator):

y_left_children_label_sum = np.sum(y[:split_idx])
y_right_children_label_sum = np.sum(y[split_idx:])
left_children_num = np.sum(indicator[:split_idx])
right_children_num = np.sum(indicator[split_idx:])

if self.task_type == 'classification':
if np.sum(indicator) == np.sum(y) or np.sum(y) == 0:
return 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def get_best_gain_from_msg(self, msg, tree_num=None, node_num=None):
continue

gain = self.model[
tree_num].cal_gain_for_rf_label_base(
tree_num].cal_gain_for_rf_label_scattering(
node_num, value_idx, label, indicator)
if gain < best_gain:
client_has_max_gain = client_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def train(self, training_info=None, tree_num=0, node_num=None):
# Start to build a tree
if node_num is None:
if training_info is not None and \
self.cfg.vertical.mode == 'order_based':
self.cfg.vertical.mode == 'feature_gathering':
self.merged_feature_order, self.extra_info = \
self._parse_training_info(training_info)
return self._compute_for_root(tree_num=tree_num)
Expand Down Expand Up @@ -274,7 +274,7 @@ def _compute_for_node(self, tree_num, node_num):
return self._compute_for_node(tree_num, node_num + 1)
# Calculate best gain
else:
if self.cfg.vertical.mode == 'order_based':
if self.cfg.vertical.mode == 'feature_gathering':
improved_flag, split_ref, _ = self._get_best_gain(
tree_num, node_num)
if improved_flag:
Expand All @@ -290,7 +290,7 @@ def _compute_for_node(self, tree_num, node_num):
else:
self._set_weight_and_status(tree_num, node_num)
return self._compute_for_node(tree_num, node_num + 1)
elif self.cfg.vertical.mode == 'label_based':
elif self.cfg.vertical.mode == 'label_scattering':
results = (self.model[tree_num][node_num].grad,
self.model[tree_num][node_num].hess,
self.model[tree_num][node_num].indicator, tree_num,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,15 @@ def callback_func_for_data_sample(self, message: Message):
index=batch_index)
self.feature_order = feature_order_info['feature_order']

if self._cfg.vertical.mode == 'order_based':
if self._cfg.vertical.mode == 'feature_gathering':
training_info = feature_order_info
elif self._cfg.vertical.mode == 'label_based':
elif self._cfg.vertical.mode == 'label_scattering':
training_info = 'dummy_info'
else:
raise TypeError(f'The expected types of vertical.mode include '
f'["label_based", "order_based"], but got '
f'{self._cfg.vertical.mode}.')
raise TypeError(
f'The expected types of vertical.mode include '
f'["label_scattering", "feature_gathering"], but got '
f'{self._cfg.vertical.mode}.')

self.comm_manager.send(
Message(msg_type='training_info',
Expand All @@ -122,7 +123,7 @@ def start_a_new_training_round(self,
self.msg_buffer['train'].clear()
self.feature_order = feature_order_info['feature_order']
self.msg_buffer['train'][self.ID] = feature_order_info \
if self._cfg.vertical.mode == 'order_based' else 'dummy_info'
if self._cfg.vertical.mode == 'feature_gathering' else 'dummy_info'
self.state = tree_num
receiver = [
each for each in list(self.comm_manager.neighbors.keys())
Expand Down
12 changes: 6 additions & 6 deletions tests/test_tree_based_model_for_vfl.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def set_config_for_rf_base(self, cfg):

return backup_cfg

def set_config_for_rf_label_base(self, cfg):
def set_config_for_rf_label_scattering(self, cfg):
backup_cfg = cfg.clone()

import torch
Expand Down Expand Up @@ -270,7 +270,7 @@ def set_config_for_xgb_dp_too_large_noise(self, cfg):

return backup_cfg

def set_config_for_label_based_xgb(self, cfg):
def set_config_for_label_scattering_xgb(self, cfg):
backup_cfg = cfg.clone()

import torch
Expand Down Expand Up @@ -298,7 +298,7 @@ def set_config_for_label_based_xgb(self, cfg):
cfg.vertical.use = True
cfg.vertical.dims = [7, 14]
cfg.vertical.algo = 'xgb'
cfg.vertical.mode = 'label_based'
cfg.vertical.mode = 'label_scattering'
cfg.vertical.protect_object = 'grad_and_hess'
cfg.vertical.protect_method = 'he'
cfg.vertical.data_size_for_debug = 2000
Expand Down Expand Up @@ -473,7 +473,7 @@ def test_RF_Base(self):

def test_RF_lable_Base(self):
init_cfg = global_cfg.clone()
backup_cfg = self.set_config_for_rf_label_base(init_cfg)
backup_cfg = self.set_config_for_rf_label_scattering(init_cfg)
setup_seed(init_cfg.seed)
update_logger(init_cfg, True)

Expand Down Expand Up @@ -533,9 +533,9 @@ def test_XGB_use_dp_too_large_noise(self):
print(test_results)
self.assertLess(test_results['server_global_eval']['test_acc'], 0.76)

def test_label_based_XGB(self):
def test_label_scattering_XGB(self):
init_cfg = global_cfg.clone()
backup_cfg = self.set_config_for_label_based_xgb(init_cfg)
backup_cfg = self.set_config_for_label_scattering_xgb(init_cfg)
setup_seed(init_cfg.seed)
update_logger(init_cfg, True)

Expand Down

0 comments on commit fde3548

Please sign in to comment.