label-based Tree model for vfl (#528)
- add label-based xgb
- add label protection method (HE)
- add unittest and baseline yaml
- remove self-loop in comm_manager.neighbors
- distinguish msg_buffer for train and eval
xieyxclack authored Feb 22, 2023
1 parent 4b9d70e commit 2efb172
Showing 12 changed files with 553 additions and 125 deletions.
7 changes: 5 additions & 2 deletions federatedscope/core/configs/cfg_fl_setting.py
@@ -75,13 +75,16 @@ def extend_fl_setting_cfg(cfg):
     # ---------------------------------------------------------------------- #
     cfg.vertical = CN()
     cfg.vertical.use = False
+    cfg.vertical.mode = 'order_based'  # ['order_based', 'label_based']
     cfg.vertical.dims = [5, 10]  # TODO: we need to explain dims
     cfg.vertical.encryption = 'paillier'
     cfg.vertical.key_size = 3072
     cfg.vertical.algo = 'lr'  # ['lr', 'xgb', 'gbdt', 'rf']
     cfg.vertical.feature_subsample_ratio = 1.0
-    cfg.vertical.protect_object = ''  # feature_order, TODO: add more
-    cfg.vertical.protect_method = ''  # dp, op_boost
+    cfg.vertical.protect_object = ''  # [feature_order, grad_and_hess]
+    cfg.vertical.protect_method = ''
+    # [dp, op_boost] for protect_object = feature_order
+    # [he] for protect_object = grad_and_hess
     cfg.vertical.protect_args = []
     # Default values for 'dp': {'bucket_num':100, 'epsilon':None}
     # Default values for 'op_boost': {'algo':'global', 'lower_bound':1,
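
For reference, the options above combine as follows for the new mode; a minimal sketch in the cfg style of this file, where the concrete values are illustrative assumptions rather than the shipped baseline yaml:

cfg.vertical.use = True
cfg.vertical.mode = 'label_based'
cfg.vertical.algo = 'xgb'  # label-based tree model added in this commit
cfg.vertical.protect_object = 'grad_and_hess'
cfg.vertical.protect_method = 'he'
cfg.vertical.protect_args = [{'bucket_num': 100}]  # 'he' default bucket_num
cfg.vertical.key_size = 3072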
10 changes: 6 additions & 4 deletions federatedscope/vertical_fl/trainer/__init__.py
@@ -1,10 +1,12 @@
 from federatedscope.vertical_fl.trainer.trainer import VerticalTrainer
-from federatedscope.vertical_fl.trainer.feature_order_protected_trainer \
-    import createFeatureOrderProtectedTrainer
 from federatedscope.vertical_fl.trainer.random_forest_trainer import \
     RandomForestTrainer
+from federatedscope.vertical_fl.trainer.feature_order_protected_trainer \
+    import createFeatureOrderProtectedTrainer
+from federatedscope.vertical_fl.trainer.label_protected_trainer import \
+    createLabelProtectedTrainer
 
 __all__ = [
-    'VerticalTrainer', 'createFeatureOrderProtectedTrainer',
-    'RandomForestTrainer'
+    'VerticalTrainer', 'RandomForestTrainer',
+    'createFeatureOrderProtectedTrainer', 'createLabelProtectedTrainer'
 ]
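
The new factory mirrors createFeatureOrderProtectedTrainer: it takes a trainer class plus the usual (model, data, device, config, monitor) arguments and returns a wrapped instance. A usage sketch, assuming model, data, device, config, and monitor are already constructed:

from federatedscope.vertical_fl.trainer import VerticalTrainer, \
    createLabelProtectedTrainer

# Wrap the base trainer class so that split finding runs on
# HE-encrypted gradients (see label_protected_trainer.py below)
trainer = createLabelProtectedTrainer(VerticalTrainer, model, data,
                                      device, config, monitor)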
federatedscope/vertical_fl/trainer/feature_order_protected_trainer.py
@@ -1,4 +1,5 @@
 import numpy as np
+from federatedscope.vertical_fl.trainer.utils import bucketize
 
 
 def createFeatureOrderProtectedTrainer(cls, model, data, device, config,
@@ -42,14 +43,6 @@ def get_feature_value(self, feature_idx, value_idx):
 
             return self.split_value[feature_idx][value_idx]
 
-        def _bucketize(self, feature_order, bucket_size, bucket_num):
-            bucketized_feature_order = list()
-            for bucket_idx in range(bucket_num):
-                start = bucket_idx * bucket_size
-                end = min((bucket_idx + 1) * bucket_size, len(feature_order))
-                bucketized_feature_order.append(feature_order[start:end])
-            return bucketized_feature_order
-
         def _processed_data(self, data):
             min_value = np.min(data, axis=0)
             max_value = np.max(data, axis=0)
@@ -168,7 +161,7 @@ def _protect_via_dp(self, raw_feature_order, data):
             self.split_value = []
 
             for feature_idx in range(len(raw_feature_order)):
-                bucketized_feature_order = self._bucketize(
+                bucketized_feature_order = bucketize(
                     raw_feature_order[feature_idx], bucket_size,
                     self.bucket_num)
                 noisy_bucketizd_feature_order = [
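
Both trainers now share a bucketize helper from federatedscope/vertical_fl/trainer/utils.py. That utils module is not shown in this diff, but judging from the _bucketize method deleted above it presumably reads:

def bucketize(feature_order, bucket_size, bucket_num):
    # Split an ordered index array into bucket_num contiguous buckets;
    # the last bucket absorbs the remainder when sizes do not divide evenly
    bucketized_feature_order = list()
    for bucket_idx in range(bucket_num):
        start = bucket_idx * bucket_size
        end = min((bucket_idx + 1) * bucket_size, len(feature_order))
        bucketized_feature_order.append(feature_order[start:end])
    return bucketized_feature_order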
193 changes: 193 additions & 0 deletions federatedscope/vertical_fl/trainer/label_protected_trainer.py
@@ -0,0 +1,193 @@
+import numpy as np
+from federatedscope.vertical_fl.trainer.utils import bucketize
+
+
+def createLabelProtectedTrainer(cls, model, data, device, config, monitor):
+    class LabelProtectedTrainer(cls):
+        def __init__(self, model, data, device, config, monitor):
+            super(LabelProtectedTrainer,
+                  self).__init__(model, data, device, config, monitor)
+
+            assert config.vertical.protect_method != '', \
+                "Please specify the method for protecting labels"
+            args = config.vertical.protect_args[0] if len(
+                config.vertical.protect_args) > 0 else {}
+
+            if config.vertical.protect_method == 'he':
+                self.bucket_num = args.get('bucket_num', 100)
+                self.split_value = None
+                from federatedscope.vertical_fl.Paillier import \
+                    abstract_paillier
+                keys = abstract_paillier.generate_paillier_keypair(
+                    n_length=self.cfg.vertical.key_size)
+                self.public_key, self.private_key = keys
+            else:
+                raise ValueError(
+                    f"The method {config.vertical.protect_method} is not supported")
+
+        def _bucketize(self, raw_feature_order, data):
+            bucket_size = int(
+                np.ceil(self.cfg.dataloader.batch_size / self.bucket_num))
+            split_position = list()
+            self.split_value = list()
+
+            for feature_idx in range(len(raw_feature_order)):
+                bucketized_feature_order = bucketize(
+                    raw_feature_order[feature_idx], bucket_size,
+                    self.bucket_num)
+
+                # Save split positions (instance number within buckets)
+                # We exclude the endpoints to avoid empty sub-trees
+                _split_position = list()
+                _split_value = dict()
+                accumu_num = 0
+                for bucket_idx, each_bucket in enumerate(
+                        bucketized_feature_order):
+                    instance_num = len(each_bucket)
+                    # Skip the empty bucket
+                    if instance_num != 0:
+                        # Skip the endpoints
+                        if bucket_idx != self.bucket_num - 1:
+                            _split_position.append(accumu_num + instance_num)
+
+                        # Save split values: average of the max value of the
+                        # (j-1)-th bucket and the min value of the j-th bucket
+                        max_value = data[bucketized_feature_order[bucket_idx]
+                                         [-1]][feature_idx]
+                        min_value = data[bucketized_feature_order[bucket_idx]
+                                         [0]][feature_idx]
+                        if accumu_num == 0:
+                            _split_value[accumu_num +
+                                         instance_num] = max_value / 2.0
+                        elif bucket_idx == self.bucket_num - 1:
+                            _split_value[accumu_num] += min_value / 2.0
+                        else:
+                            _split_value[accumu_num] += min_value / 2.0
+                            _split_value[accumu_num +
+                                         instance_num] = max_value / 2.0
+
+                        accumu_num += instance_num
+
+                split_position.append(_split_position)
+                self.split_value.append(_split_value)
+
+            extra_info = {'split_position': split_position}
+
+            return {
+                'feature_order': raw_feature_order,
+                'extra_info': extra_info
+            }
+
+        def _get_feature_order_info(self, data):
+            num_of_feature = data.shape[1]
+            feature_order = [0] * num_of_feature
+            for i in range(num_of_feature):
+                feature_order[i] = data[:, i].argsort()
+            return self._bucketize(feature_order, data)
+
+        def get_feature_value(self, feature_idx, value_idx):
+            if not hasattr(self, 'split_value') or self.split_value is None:
+                return super().get_feature_value(feature_idx=feature_idx,
+                                                 value_idx=value_idx)
+
+            return self.split_value[feature_idx][value_idx]
+
+        def get_abs_value_idx(self, feature_idx, value_idx):
+            if self.extra_info is not None and self.extra_info.get(
+                    'split_position', None) is not None:
+                return self.extra_info['split_position'][feature_idx][
+                    value_idx]
+            else:
+                return value_idx
+
+        def _compute_for_node(self, tree_num, node_num):
+
+            # All the nodes have been traversed
+            if node_num >= 2**self.model.max_depth - 1:
+                self._predict(tree_num)
+                return 'train_finish', None
+            elif self.model[tree_num][node_num].status == 'off':
+                return self._compute_for_node(tree_num, node_num + 1)
+            # The leaf node
+            elif node_num >= 2**(self.model.max_depth - 1) - 1:
+                self._set_weight_and_status(tree_num, node_num)
+                return self._compute_for_node(tree_num, node_num + 1)
+            # Calculate sum of grad and hess based on the encrypted results
+            else:
+                en_grad = [
+                    self.public_key.encrypt(x)
+                    for x in self.model[tree_num][node_num].grad
+                ]
+                if self.model[tree_num][node_num].hess is not None:
+                    en_hess = [
+                        self.public_key.encrypt(x)
+                        for x in self.model[tree_num][node_num].hess
+                    ]
+                else:
+                    en_hess = None
+                results = (en_grad, en_hess, tree_num, node_num)
+
+                return 'call_for_local_gain', results
+
+        def _get_best_gain(self, tree_num, node_num, grad, hess):
+            # We can only get partial sum since the grad/hess is encrypted
+
+            if self.merged_feature_order is None:
+                self.merged_feature_order = self.client_feature_order
+            if self.extra_info is None:
+                self.extra_info = self.client_extra_info
+
+            feature_num = len(self.merged_feature_order)
+            split_position = self.extra_info.get('split_position')
+            sum_of_grad = list()
+            sum_of_hess = list()
+
+            for feature_idx in range(feature_num):
+                ordered_g, ordered_h = self._get_ordered_gh(
+                    tree_num, node_num, feature_idx, grad, hess)
+                start_idx = 0
+                _sum_of_grad = list()
+                _sum_of_hess = list()
+                for value_idx in split_position[feature_idx]:
+                    _sum_of_grad.append(np.sum(ordered_g[start_idx:value_idx]))
+                    _sum_of_hess.append(np.sum(ordered_h[start_idx:value_idx]))
+                    start_idx = value_idx
+                _sum_of_grad.append(np.sum(ordered_g[start_idx:]))
+                _sum_of_hess.append(np.sum(ordered_h[start_idx:]))
+                sum_of_grad.append(_sum_of_grad)
+                sum_of_hess.append(_sum_of_hess)
+
+            results = {'sum_of_grad': sum_of_grad, 'sum_of_hess': sum_of_hess}
+            return False, results, None
+
+        def get_best_gain_from_msg(self, msg, tree_num=None, node_num=None):
+            client_has_max_gain = None
+            best_gain = None
+            split_ref = {}
+            for client_id, local_gain in msg.items():
+                _, _, split_info = local_gain
+                sum_of_grad = split_info['sum_of_grad']
+                sum_of_hess = split_info['sum_of_hess']
+                for feature_idx in range(len(sum_of_grad)):
+                    grad = [
+                        self.private_key.decrypt(x)
+                        for x in sum_of_grad[feature_idx]
+                    ]
+                    hess = [
+                        self.private_key.decrypt(x)
+                        for x in sum_of_hess[feature_idx]
+                    ]
+
+                    for value_idx in range(1, len(grad)):
+                        gain = self.model[tree_num].cal_gain(
+                            grad, hess, value_idx, node_num)
+
+                        if best_gain is None or gain > best_gain:
+                            client_has_max_gain = client_id
+                            best_gain = gain
+                            split_ref['feature_idx'] = feature_idx
+                            split_ref['value_idx'] = value_idx
+
+            return best_gain, client_has_max_gain, split_ref
+
+    return LabelProtectedTrainer(model, data, device, config, monitor)
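
The trainer above leans on the additive homomorphism of Paillier: feature owners sum encrypted per-bucket gradients and Hessians without learning them, and only the label owner decrypts the aggregates in get_best_gain_from_msg. A toy illustration of that property with the keypair API used in this diff (a standalone sketch, not repository code):

import numpy as np
from federatedscope.vertical_fl.Paillier import abstract_paillier

public_key, private_key = abstract_paillier.generate_paillier_keypair(
    n_length=3072)

grad = np.array([0.5, -1.2, 0.3, 0.9])  # toy per-instance gradients
en_grad = [public_key.encrypt(float(g)) for g in grad]

# Ciphertexts support '+', so np.sum aggregates without decryption,
# exactly as _get_best_gain does on the feature owner's side
en_sum = np.sum(en_grad[0:2])
assert abs(private_key.decrypt(en_sum) - grad[0:2].sum()) < 1e-8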
4 changes: 2 additions & 2 deletions federatedscope/vertical_fl/trainer/random_forest_trainer.py
@@ -34,7 +34,7 @@ def _get_ordered_indicator_and_label(self, tree_num, node_num,
         ordered_label = self.model[tree_num][node_num].label[order]
         return ordered_indicator, ordered_label
 
-    def _get_best_gain(self, tree_num, node_num):
+    def _get_best_gain(self, tree_num, node_num, grad=None, hess=None):
         if self.cfg.criterion.type == 'CrossEntropyLoss':
             default_gain = 1
         elif 'Regression' in self.cfg.criterion.type:
@@ -80,4 +80,4 @@ def _get_best_gain(self, tree_num, node_num):
                     split_ref['feature_idx'] = feature_idx
                     split_ref['value_idx'] = value_idx
 
-        return best_gain < default_gain, split_ref
+        return best_gain < default_gain, split_ref, best_gain
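
With this change, _get_best_gain has a uniform three-value interface across trainers: the random-forest trainer computes gains in plaintext and returns them directly, while the label-protected trainer above returns (False, partial_sums, None) and leaves the actual gain computation to the label owner in get_best_gain_from_msg.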