Skip to content

Commit

Permalink
Changed decision_tree_new to decision_tree_optimized
Browse files Browse the repository at this point in the history
  • Loading branch information
sandy9999 committed Jul 16, 2024
1 parent f7ed22e commit 1788401
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 2 deletions.
121 changes: 121 additions & 0 deletions Compiler/decision_tree_new.py → Compiler/decision_tree_optimized.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,28 @@ def _(i):
def _(k):
self.train_layer(k)
return self.get_tree(h, self.label)

def train_with_testing(self, *test_set, output=False):
    """ Train the tree layer by layer and report accuracy after each one.

    After every trained layer the tree built so far is evaluated on the
    training data held by the trainer and, if a test set was passed,
    on that test set as well.

    :param test_set: optional test data, forwarded unchanged as the
      positional ``y, x`` arguments of :py:func:`test_decision_tree`
    :param output: output tree after every level
    :returns: tree after the last trained layer

    """
    n_layers = len(self.nids)
    for height in range(1, n_layers + 1):
        # Layers are indexed from zero, the resulting tree height from one.
        self.train_layer(height - 1)
        tree = self.get_tree(height, self.label)
        if output:
            output_decision_tree(tree)
        # Training data is always evaluated; test data only when supplied.
        datasets = [('train', (self.y, self.x))]
        if test_set:
            datasets.append(('test', test_set))
        for name, data in datasets:
            test_decision_tree(name, tree, *data, n_threads=self.n_threads)
    return tree

def get_tree(self, h, Label):
Layer = [None] * (h + 1)
Expand All @@ -449,6 +471,69 @@ def output_decision_tree(layers):
for j, x in enumerate(('NID', 'result')):
print_ln(' %s: %s', x, util.reveal(layers[-1][j]))

def pick(bits, x):
    """ Compute the inner product of ``bits`` and ``x``.

    Callers in this file pass selection bits (the result of ``equal`` or
    ``oram.demux``), so the inner product presumably yields the single
    selected entry of ``x`` -- this relies on ``bits`` being one-hot,
    which is not checked here.

    :param bits: selection bits, one per entry of ``x``
    :param x: values to select from
    :returns: ``sum(bits[i] * x[i])``

    """
    if len(bits) == 1:
        return bits[0] * x[0]
    try:
        # Prefer the element type's own dot product when it offers one.
        return x[0].dot_product(bits, x)
    except Exception:
        # Fall back to the element-wise form for types without a usable
        # ``dot_product``. Only ``Exception`` is caught so that
        # KeyboardInterrupt/SystemExit still propagate to the caller.
        return sum(aa * bb for aa, bb in zip(bits, x))

def run_decision_tree(layers, data):
    """ Evaluate a decision tree on a single sample.

    Walks the tree from the top layer down, tracking the current node
    identifier, and finally reads the result stored in the last layer.

    :param layers: tree output by :py:class:`TreeTrainer`
    :param data: sample data (:py:class:`~Compiler.types.Array`)
    :returns: binary label

    """
    height = len(layers) - 1
    node = 1
    for depth, level in enumerate(layers[:-1]):
        assert len(level) == 3
        assert all(len(part) <= 2 ** depth for part in level)
        nids, attributes, thresholds = level
        # Selection bits marking the entry of this layer that matches
        # the current node identifier.
        selector = nids.equal(node, depth)
        threshold = pick(selector, thresholds)
        attribute = pick(selector, attributes)
        if not attribute.is_clear:
            # Non-clear attribute index: demultiplex it into selection
            # bits and pick from the sample instead of indexing directly.
            demuxed = oram.demux(
                attribute.bit_decompose(util.log2(len(data))))
            value = pick(demuxed, data)
        else:
            value = data[attribute]
        branch = 2 * value < threshold
        node += branch * 2 ** depth
    leaf_selector = layers[height][0].equal(node, height)
    return pick(leaf_selector, layers[height][1])

def test_decision_tree(name, layers, y, x, n_threads=None, time=False):
    """ Run a tree on every sample in the clear and print its accuracy.

    The tree, labels and samples are revealed first. The printed line
    shows the overall number of correct guesses followed by the
    per-label counts for labels 0 and 1.

    :param name: tag printed at the start of the output line
    :param layers: tree as list of layers
    :param y: labels, one per sample
    :param x: sample data by attribute (transposed here before use)
    :param n_threads: passed on to ``for_range_multithread``
    :param time: whether to wrap the evaluation in timer 100

    """
    if time:
        start_timer(100)
    n_samples = len(y)
    samples = x.transpose().reveal()
    labels = y.reveal()
    guess = regint.Array(n_samples)
    truth = regint.Array(n_samples)
    correct = regint.Array(2)
    parts = regint.Array(2)
    clear_layers = []
    for layer in layers:
        clear_layers.append(
            [Array.create_from(util.reveal(part)) for part in layer])

    @for_range_multithread(n_threads, 1, n_samples)
    def _(i):
        tree = [[part[:] for part in layer] for layer in clear_layers]
        guess[i] = run_decision_tree(tree, samples[i]).reveal()
        truth[i] = labels[i].reveal()

    @for_range(n_samples)
    def _(i):
        # parts counts samples per true label, correct the hits among them.
        parts[truth[i]] += 1
        hit = (guess[i].bit_xor(truth[i]).bit_not())
        correct[truth[i]] += hit

    print_ln('%s for height %s: %s/%s (%s/%s, %s/%s)', name,
             len(clear_layers) - 1, sum(correct), n_samples,
             correct[0], parts[0], correct[1], parts[1])
    if time:
        stop_timer(100)

class TreeClassifier:
""" Tree classification that uses
:py:class:`TreeTrainer` internally.
Expand Down Expand Up @@ -482,3 +567,39 @@ def fit(self, X, y, attr_types=None):

def output(self):
    """ Output the tree stored in ``self.tree``. """
    current_tree = self.tree
    output_decision_tree(current_tree)

def fit_with_testing(self, X_train, y_train, X_test, y_test,
                     attr_types=None, output_trees=False, debug=False):
    """ Train tree with accuracy output after every level.

    The resulting tree is stored in ``self.tree``.

    :param X_train: training data with row-wise samples (sint/sfix matrix)
    :param y_train: training binary labels (sint list/array)
    :param X_test: testing data with row-wise samples (sint/sfix matrix)
    :param y_test: testing binary labels (sint list/array)
    :param attr_types: attributes types (list of 'b'/'c' for
      binary/continuous; default is all continuous)
    :param output_trees: output tree after every level
    :param debug: output debugging information (values above 1
      also enable ``debug_threading`` on the trainer)

    """
    # The trainer takes data by attribute, hence the transposition.
    train_by_attribute = X_train.transpose()
    attr_lengths = self.get_attr_lengths(attr_types)
    trainer = TreeTrainer(train_by_attribute, y_train, self.max_depth,
                          attr_lengths=attr_lengths,
                          n_threads=self.n_threads)
    for flag in ('debug', 'debug_gini'):
        setattr(trainer, flag, debug)
    trainer.debug_threading = debug > 1
    test_by_attribute = X_test.transpose()
    self.tree = trainer.train_with_testing(y_test, test_by_attribute,
                                           output=output_trees)

def predict(self, X):
    """ Use tree for prediction.

    Runs the tree stored in ``self.tree`` on every row of ``X``.

    :param X: sample data with row-wise samples (sint/sfix matrix)
    :returns: sint array with one entry per sample

    """
    n_samples = len(X)
    predictions = sint.Array(n_samples)

    @for_range(n_samples)
    def _(i):
        predictions[i] = run_decision_tree(self.tree, X[i])

    return predictions
2 changes: 1 addition & 1 deletion Programs/Source/breast_tree.mpc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ y_test = sint.input_tensor_via(0, y_test)

sfix.set_precision_from_args(program)

from Compiler.decision_tree import TreeClassifier
from Compiler.decision_tree_optimized import TreeClassifier

tree = TreeClassifier(max_depth=5, n_threads=2)

Expand Down
2 changes: 1 addition & 1 deletion Programs/Source/custom_data_dt.mpc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ df_y = Array.create_from(df_y[0])
program.set_bit_length(32)
sfix.set_precision(16, 31)

from Compiler.decision_tree_new import TreeClassifier
from Compiler.decision_tree_optimized import TreeClassifier

tree = TreeClassifier(max_depth=int(program.args[3]), n_threads=4)

Expand Down

0 comments on commit 1788401

Please sign in to comment.