diff --git a/Compiler/decision_tree_optimized.py b/Compiler/decision_tree_optimized.py index 4e6fdbf2f..4ac9e9f39 100644 --- a/Compiler/decision_tree_optimized.py +++ b/Compiler/decision_tree_optimized.py @@ -1,7 +1,7 @@ from Compiler.types import * from Compiler.sorting import * from Compiler.library import * -from Compiler.decision_tree import PrefixSum, PrefixSumR, PrefixSum_inv, PrefixSumR_inv, SortPerm, GroupSum, GroupPrefixSum, GroupFirstOne +from Compiler.decision_tree import get_type, PrefixSum, PrefixSumR, PrefixSum_inv, PrefixSumR_inv, SortPerm, GroupSum, GroupPrefixSum, GroupFirstOne, output_decision_tree, pick, run_decision_tree, test_decision_tree from Compiler import util, oram from itertools import accumulate @@ -11,18 +11,6 @@ debug_split = False max_leaves = None -def get_type(x): - if isinstance(x, (Array, SubMultiArray)): - return x.value_type - elif isinstance(x, (tuple, list)): - x = x[0] + x[-1] - if util.is_constant(x): - return cint - else: - return type(x) - else: - return type(x) - def GetSortPerm(keys, *to_sort, n_bits=None, time=False): """ Compute and return secret shared permutation that stably sorts :param keys. @@ -36,7 +24,7 @@ def GetSortPerm(keys, *to_sort, n_bits=None, time=False): res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x for x in to_sort) res = res.transpose() - return radix_sort_permutation_from_matrix(bs, res) + return radix_sort_from_matrix(bs, res) def ApplyPermutation(perm, x): res = Array.create_from(x) @@ -374,81 +362,6 @@ def get_tree(self, h, Label): def DecisionTreeTraining(x, y, h, binary=False): return TreeTrainer(x, y, h, binary=binary).train() -def output_decision_tree(layers): - """ Print decision tree output by :py:class:`TreeTrainer`. """ - - print_ln('full model %s', util.reveal(layers)) - for i, layer in enumerate(layers[:-1]): - print_ln('level %s:', i) - for j, x in enumerate(('NID', 'AID', 'Thr')): - print_ln(' %s: %s', x, util.reveal(layer[j])) - print_ln('leaves:') - for j, x in enumerate(('NID', 'result')): - print_ln(' %s: %s', x, util.reveal(layers[-1][j])) - -def pick(bits, x): - if len(bits) == 1: - return bits[0] * x[0] - else: - try: - return x[0].dot_product(bits, x) - except: - return sum(aa * bb for aa, bb in zip(bits, x)) - -def run_decision_tree(layers, data): - """ Run decision tree against sample data. - - :param layers: tree output by :py:class:`TreeTrainer` - :param data: sample data (:py:class:`~Compiler.types.Array`) - :returns: binary label - - """ - h = len(layers) - 1 - index = 1 - for k, layer in enumerate(layers[:-1]): - assert len(layer) == 3 - for x in layer: - assert len(x) <= 2 ** k - bits = layer[0].equal(index, k) - threshold = pick(bits, layer[2]) - key_index = pick(bits, layer[1]) - if key_index.is_clear: - key = data[key_index] - else: - key = pick( - oram.demux(key_index.bit_decompose(util.log2(len(data)))), data) - child = 2 * key < threshold - index += child * 2 ** k - bits = layers[h][0].equal(index, h) - return pick(bits, layers[h][1]) - -def test_decision_tree(name, layers, y, x, n_threads=None, time=False): - if time: - start_timer(100) - n = len(y) - x = x.transpose().reveal() - y = y.reveal() - guess = regint.Array(n) - truth = regint.Array(n) - correct = regint.Array(2) - parts = regint.Array(2) - layers = [[Array.create_from(util.reveal(x)) for x in layer] - for layer in layers] - @for_range_multithread(n_threads, 1, n) - def _(i): - guess[i] = run_decision_tree([[part[:] for part in layer] - for layer in layers], x[i]).reveal() - truth[i] = y[i].reveal() - @for_range(n) - def _(i): - parts[truth[i]] += 1 - c = (guess[i].bit_xor(truth[i]).bit_not()) - correct[truth[i]] += c - print_ln('%s for height %s: %s/%s (%s/%s, %s/%s)', name, len(layers) - 1, - sum(correct), n, correct[0], parts[0], correct[1], parts[1]) - if time: - stop_timer(100) - class TreeClassifier: """ Tree classification that uses :py:class:`TreeTrainer` internally. diff --git a/Compiler/sorting.py b/Compiler/sorting.py index e146baea4..39e512681 100644 --- a/Compiler/sorting.py +++ b/Compiler/sorting.py @@ -73,24 +73,4 @@ def _(): @library.else_ def _(): reveal_sort(h, D, reverse=True) - -def radix_sort_permutation_from_matrix(bs, D): - n = len(D) - for b in bs: - assert(len(b) == n) - B = types.sint.Matrix(n, 2) - h = types.Array.create_from(types.sint(types.regint.inc(n))) - @library.for_range(len(bs)) - def _(i): - b = bs[i] - B.set_column(0, 1 - b.get_vector()) - B.set_column(1, b.get_vector()) - c = types.Array.create_from(dest_comp(B)) - reveal_sort(c, h, reverse=False) - @library.if_e(i < len(bs) - 1) - def _(): - reveal_sort(h, bs[i + 1], reverse=True) - @library.else_ - def _(): - reveal_sort(h, D, reverse=True) return h