Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename models and add readme file #562

Merged
merged 7 commits into from
Mar 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion federatedscope/core/configs/cfg_fl_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def extend_fl_setting_cfg(cfg):
# ---------------------------------------------------------------------- #
cfg.vertical = CN()
cfg.vertical.use = False
cfg.vertical.mode = 'order_based' # ['order_based', 'label_based']
cfg.vertical.mode = 'feature_gathering'
# ['feature_gathering', 'label_scattering']
cfg.vertical.dims = [5, 10] # TODO: we need to explain dims
cfg.vertical.encryption = 'paillier'
cfg.vertical.key_size = 3072
Expand Down
20 changes: 20 additions & 0 deletions federatedscope/vertical_fl/tree_based_models/baseline/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
Here we give some hand-on examples, and roughly show their characteristics.

For more details, please see the specific configuration in each file.

| Yaml | Dataset | Model _type | Algo | Protect_method | Eval_protection |
| --------------------------------------------------- | -------------- | --------------------- | -------- | -------------- | --------------- |
| `gbdt_feature_gathering_on_abalone.yaml` | Abalone (reg.) | `'feature_gathering'` | `'gbdt'` | None | None |
| `gbdt_feature_gathering_on_adult.yaml` | Adult (clas.) | `'feature_gathering'` | `'gbdt'` | None | None |
| `gbdt_label_scattering_on_adult.yaml` | Adult (clas.) | `'label_scattering'` | `'gbdt'` | `'he'` | None |
| `rf_feature_gathering_on_abalone.yaml` | Abalone (reg.) | `'feature_gathering'` | `'rf'` | None | None |
| `rf_feature_gathering_on_adult.yaml` | Adult (clas.) | `'feature_gathering'` | `'rf'` | None | None |
| `rf_label_scattering_on_adult.yaml` | Adult (clas.) | `'label_scattering'` | `'rf'` | `'he'` | None |
| `xgb_feature_gathering_on_abalone.yaml` | Abalone (reg.) | `'feature_gathering'` | `'xgb'` | None | None |
| `xgb_feature_gathering_on_adult.yaml` | Adult (clas.) | `'feature_gathering'` | `'xgb'` | None | None |
| `xgb_feature_gathering_dp_on_abalone.yaml` | Abalone (reg.) | `'feature_gathering'` | `'xgb'` | None | None |
| `xgb_feature_gathering_dp_on_adult.yaml` | Adult (clas.) | `'feature_gathering'` | `'xgb'` | `'dp'` | None |
| `xgb_feature_gathering_op_boost_on_adult.yaml` | Adult (clas.) | `'feature_gathering'` | `'xgb'` | `'op_boost'` | None |
| `xgb_label_scattering_on_abalone.yaml` | Abalone (reg.) | `'label_scattering'` | `'xgb'` | `'he'` | None |
| `xgb_label_scattering_on_adult.yaml` | Adult (clas.) | `'label_scattering'` | `'xgb'` | `'he'` | None |
| `xgb_feature_gathering_dp_on_adult_by_he_eval.yaml` | Adult (clas.) | `'feature_gathering'` | `'xgb'` | `'he'` | `'he'` |
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ vertical:
use: True
dims: [4, 8]
algo: 'rf'
data_size_for_debug: 1500
data_size_for_debug: 2000
feature_subsample_ratio: 0.5
eval:
freq: 3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ vertical:
use: True
dims: [7, 14]
algo: 'rf'
data_size_for_debug: 1500
data_size_for_debug: 2000
feature_subsample_ratio: 0.8
eval:
freq: 3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ vertical:
use: True
dims: [ 7, 14 ]
algo: 'rf'
mode: 'label_based'
mode: 'label_scattering'
data_size_for_debug: 2000
feature_subsample_ratio: 0.4
protect_object: 'grad_and_hess'
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ train:
eta: 0.19
vertical:
use: True
mode: 'label_based'
mode: 'label_scattering'
dims: [6, 10]
algo: 'xgb'
data_size_for_debug: 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ train:
eta: 0.5
vertical:
use: True
mode: 'label_based'
mode: 'label_scattering'
dims: [7, 14]
algo: 'xgb'
data_size_for_debug: 2000
Expand Down
5 changes: 4 additions & 1 deletion federatedscope/vertical_fl/tree_based_models/model/Tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,14 @@ def cal_gain(self, split_idx, y, indicator):
else:
raise ValueError(f'Task type error: {self.task_type}')

def cal_gain_for_rf_label_base(self, node_num, split_idx, y, indicator):
def cal_gain_for_rf_label_scattering(self, node_num, split_idx, y,
indicator):

y_left_children_label_sum = np.sum(y[:split_idx])
y_right_children_label_sum = np.sum(y[split_idx:])
left_children_num = np.sum(indicator[:split_idx])
right_children_num = np.sum(indicator[split_idx:])

if self.task_type == 'classification':
if np.sum(indicator) == np.sum(y) or np.sum(y) == 0:
return 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def get_best_gain_from_msg(self, msg, tree_num=None, node_num=None):
continue

gain = self.model[
tree_num].cal_gain_for_rf_label_base(
tree_num].cal_gain_for_rf_label_scattering(
node_num, value_idx, label, indicator)
if gain < best_gain:
client_has_max_gain = client_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def train(self, training_info=None, tree_num=0, node_num=None):
# Start to build a tree
if node_num is None:
if training_info is not None and \
self.cfg.vertical.mode == 'order_based':
self.cfg.vertical.mode == 'feature_gathering':
self.merged_feature_order, self.extra_info = \
self._parse_training_info(training_info)
return self._compute_for_root(tree_num=tree_num)
Expand Down Expand Up @@ -274,7 +274,7 @@ def _compute_for_node(self, tree_num, node_num):
return self._compute_for_node(tree_num, node_num + 1)
# Calculate best gain
else:
if self.cfg.vertical.mode == 'order_based':
if self.cfg.vertical.mode == 'feature_gathering':
improved_flag, split_ref, _ = self._get_best_gain(
tree_num, node_num)
if improved_flag:
Expand All @@ -290,7 +290,7 @@ def _compute_for_node(self, tree_num, node_num):
else:
self._set_weight_and_status(tree_num, node_num)
return self._compute_for_node(tree_num, node_num + 1)
elif self.cfg.vertical.mode == 'label_based':
elif self.cfg.vertical.mode == 'label_scattering':
results = (self.model[tree_num][node_num].grad,
self.model[tree_num][node_num].hess,
self.model[tree_num][node_num].indicator, tree_num,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,15 @@ def callback_func_for_data_sample(self, message: Message):
index=batch_index)
self.feature_order = feature_order_info['feature_order']

if self._cfg.vertical.mode == 'order_based':
if self._cfg.vertical.mode == 'feature_gathering':
training_info = feature_order_info
elif self._cfg.vertical.mode == 'label_based':
elif self._cfg.vertical.mode == 'label_scattering':
training_info = 'dummy_info'
else:
raise TypeError(f'The expected types of vertical.mode include '
f'["label_based", "order_based"], but got '
f'{self._cfg.vertical.mode}.')
raise TypeError(
f'The expected types of vertical.mode include '
f'["label_scattering", "feature_gathering"], but got '
f'{self._cfg.vertical.mode}.')

self.comm_manager.send(
Message(msg_type='training_info',
Expand All @@ -122,7 +123,7 @@ def start_a_new_training_round(self,
self.msg_buffer['train'].clear()
self.feature_order = feature_order_info['feature_order']
self.msg_buffer['train'][self.ID] = feature_order_info \
if self._cfg.vertical.mode == 'order_based' else 'dummy_info'
if self._cfg.vertical.mode == 'feature_gathering' else 'dummy_info'
self.state = tree_num
receiver = [
each for each in list(self.comm_manager.neighbors.keys())
Expand Down
12 changes: 6 additions & 6 deletions tests/test_tree_based_model_for_vfl.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def set_config_for_rf_base(self, cfg):

return backup_cfg

def set_config_for_rf_label_base(self, cfg):
def set_config_for_rf_label_scattering(self, cfg):
backup_cfg = cfg.clone()

import torch
Expand Down Expand Up @@ -270,7 +270,7 @@ def set_config_for_xgb_dp_too_large_noise(self, cfg):

return backup_cfg

def set_config_for_label_based_xgb(self, cfg):
def set_config_for_label_scattering_xgb(self, cfg):
backup_cfg = cfg.clone()

import torch
Expand Down Expand Up @@ -298,7 +298,7 @@ def set_config_for_label_based_xgb(self, cfg):
cfg.vertical.use = True
cfg.vertical.dims = [7, 14]
cfg.vertical.algo = 'xgb'
cfg.vertical.mode = 'label_based'
cfg.vertical.mode = 'label_scattering'
cfg.vertical.protect_object = 'grad_and_hess'
cfg.vertical.protect_method = 'he'
cfg.vertical.data_size_for_debug = 2000
Expand Down Expand Up @@ -473,7 +473,7 @@ def test_RF_Base(self):

def test_RF_lable_Base(self):
init_cfg = global_cfg.clone()
backup_cfg = self.set_config_for_rf_label_base(init_cfg)
backup_cfg = self.set_config_for_rf_label_scattering(init_cfg)
setup_seed(init_cfg.seed)
update_logger(init_cfg, True)

Expand Down Expand Up @@ -533,9 +533,9 @@ def test_XGB_use_dp_too_large_noise(self):
print(test_results)
self.assertLess(test_results['server_global_eval']['test_acc'], 0.76)

def test_label_based_XGB(self):
def test_label_scattering_XGB(self):
init_cfg = global_cfg.clone()
backup_cfg = self.set_config_for_label_based_xgb(init_cfg)
backup_cfg = self.set_config_for_label_scattering_xgb(init_cfg)
setup_seed(init_cfg.seed)
update_logger(init_cfg, True)

Expand Down