diff --git a/federatedscope/vertical_fl/model/Tree.py b/federatedscope/vertical_fl/model/Tree.py index 9f9a18789..bf8ccc65c 100644 --- a/federatedscope/vertical_fl/model/Tree.py +++ b/federatedscope/vertical_fl/model/Tree.py @@ -41,6 +41,14 @@ def split_childern(self, data, feature_value): def set_status(self, node_num, status='off'): self.tree[node_num].status = status + def check_empty_child(self, node_num, split_idx, order): + indicator = self.tree[node_num].indicator[order] + if np.sum(indicator[:split_idx]) == 0 or np.sum( + indicator[split_idx:]) == 0: + return True + + return False + class XGBTree(Tree): def __init__(self, max_depth, lambda_, gamma): diff --git a/federatedscope/vertical_fl/trainer/label_protected_trainer.py b/federatedscope/vertical_fl/trainer/label_protected_trainer.py index 2ba968ea2..569c5bfc1 100644 --- a/federatedscope/vertical_fl/trainer/label_protected_trainer.py +++ b/federatedscope/vertical_fl/trainer/label_protected_trainer.py @@ -95,8 +95,8 @@ def get_feature_value(self, feature_idx, value_idx): def get_abs_value_idx(self, feature_idx, value_idx): if self.extra_info is not None and self.extra_info.get( 'split_position', None) is not None: - return self.extra_info['split_position'][feature_idx][ - value_idx] + return self.extra_info['split_position'][feature_idx][value_idx + - 1] else: return value_idx diff --git a/federatedscope/vertical_fl/trainer/random_forest_trainer.py b/federatedscope/vertical_fl/trainer/random_forest_trainer.py index be60a12d5..558b34d65 100644 --- a/federatedscope/vertical_fl/trainer/random_forest_trainer.py +++ b/federatedscope/vertical_fl/trainer/random_forest_trainer.py @@ -57,14 +57,6 @@ def _get_best_gain(self, tree_num, node_num, grad=None, hess=None): if split_position is None: # The left/right sub-tree cannot be empty split_position = activate_idx[:, 1:] - else: - active_split_position = list() - for idx, each_split_position in enumerate(split_position): - active_split_position.append([ - x for x in each_split_position - if x in activate_idx[idx, 1:] - ]) - split_position = active_split_position for feature_idx in range(feature_num): if len(split_position[feature_idx]) == 0: @@ -72,7 +64,11 @@ def _get_best_gain(self, tree_num, node_num, grad=None, hess=None): ordered_indicator, ordered_label = \ self._get_ordered_indicator_and_label( tree_num, node_num, feature_idx) + order = self.merged_feature_order[feature_idx] for value_idx in split_position[feature_idx]: + if self.model[tree_num].check_empty_child( + node_num, value_idx, order): + continue gain = self.model[tree_num].cal_gain(value_idx, ordered_label, ordered_indicator) if gain < best_gain: diff --git a/federatedscope/vertical_fl/trainer/trainer.py b/federatedscope/vertical_fl/trainer/trainer.py index 5bc263999..8eaa7a620 100644 --- a/federatedscope/vertical_fl/trainer/trainer.py +++ b/federatedscope/vertical_fl/trainer/trainer.py @@ -194,19 +194,15 @@ def _get_best_gain(self, tree_num, node_num, grad=None, hess=None): if split_position is None: # The left/right sub-tree cannot be empty split_position = activate_idx[:, 1:] - else: - active_split_position = list() - for idx, each_split_position in enumerate(split_position): - active_split_position.append([ - x for x in each_split_position - if x in activate_idx[idx, 1:] - ]) - split_position = active_split_position for feature_idx in range(feature_num): ordered_g, ordered_h = self._get_ordered_gh( tree_num, node_num, feature_idx, grad, hess) + order = self.merged_feature_order[feature_idx] for value_idx in split_position[feature_idx]: + if self.model[tree_num].check_empty_child( + node_num, value_idx, order): + continue gain = self.model[tree_num].cal_gain(ordered_g, ordered_h, value_idx, node_num)