From b039520fd7f35799876ea8a621f3b651387817b8 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Sat, 6 Jul 2024 11:55:53 -0400 Subject: [PATCH] Decrease diff Signed-off-by: Adam Li --- sklearn/tree/_partitioner.pxd | 13 +- sklearn/tree/_partitioner.pyx | 28 ++-- sklearn/tree/_sort.pxd | 13 -- sklearn/tree/_sort.pyx | 123 ---------------- sklearn/tree/_splitter.pxd | 17 +++ sklearn/tree/_splitter.pyx | 265 +++++++++++++++++++++++++++++++++- sklearn/tree/meson.build | 3 - 7 files changed, 304 insertions(+), 158 deletions(-) delete mode 100644 sklearn/tree/_sort.pxd delete mode 100644 sklearn/tree/_sort.pyx diff --git a/sklearn/tree/_partitioner.pxd b/sklearn/tree/_partitioner.pxd index fd4e7c721424b..39cb006f551b7 100644 --- a/sklearn/tree/_partitioner.pxd +++ b/sklearn/tree/_partitioner.pxd @@ -10,7 +10,7 @@ cdef float32_t EXTRACT_NNZ_SWITCH = 0.1 # functions. The alternative would have been to use inheritance-based polymorphism # but it would have resulted in a ~10% overall tree fitting performance # degradation caused by the overhead frequent virtual method lookups. -#ctypedef fused Partitioner: +# ctypedef fused Partitioner: # DensePartitioner # SparsePartitioner @@ -67,8 +67,15 @@ cdef class Partitioner: float32_t* min_feature_value_out, float32_t* max_feature_value_out, ) noexcept nogil - inline void next_p(self, intp_t* p_prev, intp_t* p) noexcept nogil - inline intp_t partition_samples(self, float64_t current_threshold) noexcept nogil + inline void next_p( + self, + intp_t* p_prev, + intp_t* p + ) noexcept nogil + inline intp_t partition_samples( + self, + float64_t current_threshold + ) noexcept nogil inline void partition_samples_final( self, intp_t best_pos, diff --git a/sklearn/tree/_partitioner.pyx b/sklearn/tree/_partitioner.pyx index 7f21e716272f4..9276e9eb0bab8 100644 --- a/sklearn/tree/_partitioner.pyx +++ b/sklearn/tree/_partitioner.pyx @@ -6,14 +6,14 @@ from scipy.sparse import issparse import numpy as np -from ._sort cimport sort, sparse_swap, swap, FEATURE_THRESHOLD +from ._splitter cimport sort, sparse_swap, FEATURE_THRESHOLD cdef class Partitioner: cdef: inline void init_node_split(self, intp_t start, intp_t end) noexcept nogil: self._init_node_split(self, start, end) - + inline void sort_samples_and_feature_values( self, intp_t current_feature @@ -33,7 +33,7 @@ cdef class Partitioner: inline intp_t partition_samples(self, float64_t current_threshold) noexcept nogil: return self._partition_samples(self, current_threshold) - + inline void partition_samples_final( self, intp_t best_pos, @@ -536,22 +536,22 @@ cdef inline void sparse_extract_nnz(SparsePartitioner self, intp_t feature) noex if ((1 - self.is_samples_sorted) * n_samples * log(n_samples) + n_samples * log(n_indices) < EXTRACT_NNZ_SWITCH * n_indices): extract_nnz_binary_search(X_indices, X_data, - indptr_start, indptr_end, - samples, self.start, self.end, - index_to_samples, - feature_values, - &self.end_negative, &self.start_positive, - sorted_samples, &self.is_samples_sorted) + indptr_start, indptr_end, + samples, self.start, self.end, + index_to_samples, + feature_values, + &self.end_negative, &self.start_positive, + sorted_samples, &self.is_samples_sorted) # Using an index to samples technique to extract non zero values # index_to_samples is a mapping from X_indices to samples else: extract_nnz_index_to_samples(X_indices, X_data, - indptr_start, indptr_end, - samples, self.start, self.end, - index_to_samples, - feature_values, - &self.end_negative, &self.start_positive) + indptr_start, indptr_end, + samples, self.start, self.end, + index_to_samples, + feature_values, + &self.end_negative, &self.start_positive) cdef int compare_SIZE_t(const void* a, const void* b) noexcept nogil: diff --git a/sklearn/tree/_sort.pxd b/sklearn/tree/_sort.pxd deleted file mode 100644 index 5a0b3d20d0f35..0000000000000 --- a/sklearn/tree/_sort.pxd +++ /dev/null @@ -1,13 +0,0 @@ -from ..utils._typedefs cimport float32_t, float64_t, intp_t, int8_t, int32_t, uint32_t - - -# Mitigate precision differences between 32 bit and 64 bit -cdef float32_t FEATURE_THRESHOLD = 1e-7 - -# Sort n-element arrays pointed to by feature_values and samples, simultaneously, -# by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997). -cdef void sort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil - -cdef void swap(float32_t* feature_values, intp_t* samples, intp_t i, intp_t j) noexcept nogil -cdef void sparse_swap(intp_t[::1] index_to_samples, intp_t[::1] samples, - intp_t pos_1, intp_t pos_2) noexcept nogil diff --git a/sklearn/tree/_sort.pyx b/sklearn/tree/_sort.pyx deleted file mode 100644 index 9a9db6edf6e00..0000000000000 --- a/sklearn/tree/_sort.pyx +++ /dev/null @@ -1,123 +0,0 @@ -from ._utils cimport log - - -cdef inline void sparse_swap(intp_t[::1] index_to_samples, intp_t[::1] samples, - intp_t pos_1, intp_t pos_2) noexcept nogil: - """Swap sample pos_1 and pos_2 preserving sparse invariant.""" - samples[pos_1], samples[pos_2] = samples[pos_2], samples[pos_1] - index_to_samples[samples[pos_1]] = pos_1 - index_to_samples[samples[pos_2]] = pos_2 - - -# Sort n-element arrays pointed to by feature_values and samples, simultaneously, -# by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997). -cdef inline void sort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: - if n == 0: - return - cdef intp_t maxd = 2 * log(n) - introsort(feature_values, samples, n, maxd) - - -# Introsort with median of 3 pivot selection and 3-way partition function -# (robust to repeated elements, e.g. lots of zero features). -cdef void introsort(float32_t* feature_values, intp_t *samples, - intp_t n, intp_t maxd) noexcept nogil: - cdef float32_t pivot - cdef intp_t i, l, r - - while n > 1: - if maxd <= 0: # max depth limit exceeded ("gone quadratic") - heapsort(feature_values, samples, n) - return - maxd -= 1 - - pivot = median3(feature_values, n) - - # Three-way partition. - i = l = 0 - r = n - while i < r: - if feature_values[i] < pivot: - swap(feature_values, samples, i, l) - i += 1 - l += 1 - elif feature_values[i] > pivot: - r -= 1 - swap(feature_values, samples, i, r) - else: - i += 1 - - introsort(feature_values, samples, l, maxd) - feature_values += r - samples += r - n -= r - - -cdef void heapsort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: - cdef intp_t start, end - - # heapify - start = (n - 2) / 2 - end = n - while True: - sift_down(feature_values, samples, start, end) - if start == 0: - break - start -= 1 - - # sort by shrinking the heap, putting the max element immediately after it - end = n - 1 - while end > 0: - swap(feature_values, samples, 0, end) - sift_down(feature_values, samples, 0, end) - end = end - 1 - - -cdef inline float32_t median3(float32_t* feature_values, intp_t n) noexcept nogil: - # Median of three pivot selection, after Bentley and McIlroy (1993). - # Engineering a sort function. SP&E. Requires 8/3 comparisons on average. - cdef float32_t a = feature_values[0], b = feature_values[n / 2], c = feature_values[n - 1] - if a < b: - if b < c: - return b - elif a < c: - return c - else: - return a - elif b < c: - if a < c: - return a - else: - return c - else: - return b - - -cdef inline void swap(float32_t* feature_values, intp_t* samples, - intp_t i, intp_t j) noexcept nogil: - # Helper for sort - feature_values[i], feature_values[j] = feature_values[j], feature_values[i] - samples[i], samples[j] = samples[j], samples[i] - - -cdef inline void sift_down(float32_t* feature_values, intp_t* samples, - intp_t start, intp_t end) noexcept nogil: - # Restore heap order in feature_values[start:end] by moving the max element to start. - cdef intp_t child, maxind, root - - root = start - while True: - child = root * 2 + 1 - - # find max of root, left child, right child - maxind = root - if child < end and feature_values[maxind] < feature_values[child]: - maxind = child - if child + 1 < end and feature_values[maxind] < feature_values[child + 1]: - maxind = child + 1 - - if maxind == root: - break - else: - swap(feature_values, samples, root, maxind) - root = maxind diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index a55cf2786cbef..968d90e0dc2c9 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -10,6 +10,10 @@ from ._tree cimport ParentInfo from ..utils._typedefs cimport float32_t, float64_t, intp_t, int8_t, int32_t, uint32_t +# Mitigate precision differences between 32 bit and 64 bit +cdef float32_t FEATURE_THRESHOLD = 1e-7 + + cdef struct SplitRecord: # Data to track sample split intp_t feature # Which feature to split on. @@ -132,3 +136,16 @@ cdef void shift_missing_values_to_left_if_required( intp_t[::1] samples, intp_t end, ) noexcept nogil + + +# Sort n-element arrays pointed to by feature_values and samples, simultaneously, +# by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997). +cdef void sort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil + +cdef void swap(float32_t* feature_values, intp_t* samples, intp_t i, intp_t j) noexcept nogil +cdef void sparse_swap( + intp_t[::1] index_to_samples, + intp_t[::1] samples, + intp_t pos_1, + intp_t pos_2 +) noexcept nogil diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index eb08ec34ea2a2..48448978c061b 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -1,10 +1,11 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause +from libc.stdlib cimport qsort from libc.string cimport memcpy from ._criterion cimport Criterion -from ._sort cimport FEATURE_THRESHOLD +from ._utils cimport log from ._utils cimport rand_int from ._utils cimport rand_uniform from ._utils cimport RAND_R_MAX @@ -15,7 +16,6 @@ import numpy as np cdef float64_t INFINITY = np.inf - cdef inline void _init_split(SplitRecord* self, intp_t start_pos) noexcept nogil: self.impurity_left = INFINITY self.impurity_right = INFINITY @@ -663,6 +663,120 @@ cdef inline intp_t node_split_best( return 0 +# Sort n-element arrays pointed to by feature_values and samples, simultaneously, +# by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997). +cdef inline void sort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: + if n == 0: + return + cdef intp_t maxd = 2 * log(n) + introsort(feature_values, samples, n, maxd) + + +cdef inline void swap(float32_t* feature_values, intp_t* samples, + intp_t i, intp_t j) noexcept nogil: + # Helper for sort + feature_values[i], feature_values[j] = feature_values[j], feature_values[i] + samples[i], samples[j] = samples[j], samples[i] + + +cdef inline float32_t median3(float32_t* feature_values, intp_t n) noexcept nogil: + # Median of three pivot selection, after Bentley and McIlroy (1993). + # Engineering a sort function. SP&E. Requires 8/3 comparisons on average. + cdef float32_t a = feature_values[0], b = feature_values[n / 2], c = feature_values[n - 1] + if a < b: + if b < c: + return b + elif a < c: + return c + else: + return a + elif b < c: + if a < c: + return a + else: + return c + else: + return b + + +# Introsort with median of 3 pivot selection and 3-way partition function +# (robust to repeated elements, e.g. lots of zero features). +cdef void introsort(float32_t* feature_values, intp_t *samples, + intp_t n, intp_t maxd) noexcept nogil: + cdef float32_t pivot + cdef intp_t i, l, r + + while n > 1: + if maxd <= 0: # max depth limit exceeded ("gone quadratic") + heapsort(feature_values, samples, n) + return + maxd -= 1 + + pivot = median3(feature_values, n) + + # Three-way partition. + i = l = 0 + r = n + while i < r: + if feature_values[i] < pivot: + swap(feature_values, samples, i, l) + i += 1 + l += 1 + elif feature_values[i] > pivot: + r -= 1 + swap(feature_values, samples, i, r) + else: + i += 1 + + introsort(feature_values, samples, l, maxd) + feature_values += r + samples += r + n -= r + + +cdef inline void sift_down(float32_t* feature_values, intp_t* samples, + intp_t start, intp_t end) noexcept nogil: + # Restore heap order in feature_values[start:end] by moving the max element to start. + cdef intp_t child, maxind, root + + root = start + while True: + child = root * 2 + 1 + + # find max of root, left child, right child + maxind = root + if child < end and feature_values[maxind] < feature_values[child]: + maxind = child + if child + 1 < end and feature_values[maxind] < feature_values[child + 1]: + maxind = child + 1 + + if maxind == root: + break + else: + swap(feature_values, samples, root, maxind) + root = maxind + + +cdef void heapsort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: + cdef intp_t start, end + + # heapify + start = (n - 2) / 2 + end = n + while True: + sift_down(feature_values, samples, start, end) + if start == 0: + break + start -= 1 + + # sort by shrinking the heap, putting the max element immediately after it + end = n - 1 + while end > 0: + swap(feature_values, samples, 0, end) + sift_down(feature_values, samples, 0, end) + end = end - 1 + + cdef inline int node_split_random( Splitter splitter, Partitioner partitioner, @@ -960,3 +1074,150 @@ cdef class RandomSparseSplitter(Splitter): split, parent_record, ) + + +cdef int compare_SIZE_t(const void* a, const void* b) noexcept nogil: + """Comparison function for sort. + This must return an `int` as it is used by stdlib's qsort, which expects + an `int` return value. + """ + return ((a)[0] - (b)[0]) + + +cdef inline void binary_search(const int32_t[::1] sorted_array, + int32_t start, int32_t end, + intp_t value, intp_t* index, + int32_t* new_start) noexcept nogil: + """Return the index of value in the sorted array. + If not found, return -1. new_start is the last pivot + 1 + """ + cdef int32_t pivot + index[0] = -1 + while start < end: + pivot = start + (end - start) / 2 + + if sorted_array[pivot] == value: + index[0] = pivot + start = pivot + 1 + break + + if sorted_array[pivot] < value: + start = pivot + 1 + else: + end = pivot + new_start[0] = start + + +cdef inline void extract_nnz_index_to_samples(const int32_t[::1] X_indices, + const float32_t[::1] X_data, + int32_t indptr_start, + int32_t indptr_end, + intp_t[::1] samples, + intp_t start, + intp_t end, + intp_t[::1] index_to_samples, + float32_t[::1] feature_values, + intp_t* end_negative, + intp_t* start_positive) noexcept nogil: + """Extract and partition values for a feature using index_to_samples. + Complexity is O(indptr_end - indptr_start). + """ + cdef int32_t k + cdef intp_t index + cdef intp_t end_negative_ = start + cdef intp_t start_positive_ = end + + for k in range(indptr_start, indptr_end): + if start <= index_to_samples[X_indices[k]] < end: + if X_data[k] > 0: + start_positive_ -= 1 + feature_values[start_positive_] = X_data[k] + index = index_to_samples[X_indices[k]] + sparse_swap(index_to_samples, samples, index, start_positive_) + + elif X_data[k] < 0: + feature_values[end_negative_] = X_data[k] + index = index_to_samples[X_indices[k]] + sparse_swap(index_to_samples, samples, index, end_negative_) + end_negative_ += 1 + + # Returned values + end_negative[0] = end_negative_ + start_positive[0] = start_positive_ + + +cdef inline void extract_nnz_binary_search(const int32_t[::1] X_indices, + const float32_t[::1] X_data, + int32_t indptr_start, + int32_t indptr_end, + intp_t[::1] samples, + intp_t start, + intp_t end, + intp_t[::1] index_to_samples, + float32_t[::1] feature_values, + intp_t* end_negative, + intp_t* start_positive, + intp_t[::1] sorted_samples, + bint* is_samples_sorted) noexcept nogil: + """Extract and partition values for a given feature using binary search. + If n_samples = end - start and n_indices = indptr_end - indptr_start, + the complexity is + O((1 - is_samples_sorted[0]) * n_samples * log(n_samples) + + n_samples * log(n_indices)). + """ + cdef intp_t n_samples + + if not is_samples_sorted[0]: + n_samples = end - start + memcpy(&sorted_samples[start], &samples[start], + n_samples * sizeof(intp_t)) + qsort(&sorted_samples[start], n_samples, sizeof(intp_t), + compare_SIZE_t) + is_samples_sorted[0] = 1 + + while (indptr_start < indptr_end and + sorted_samples[start] > X_indices[indptr_start]): + indptr_start += 1 + + while (indptr_start < indptr_end and + sorted_samples[end - 1] < X_indices[indptr_end - 1]): + indptr_end -= 1 + + cdef intp_t p = start + cdef intp_t index + cdef intp_t k + cdef intp_t end_negative_ = start + cdef intp_t start_positive_ = end + + while (p < end and indptr_start < indptr_end): + # Find index of sorted_samples[p] in X_indices + binary_search(X_indices, indptr_start, indptr_end, + sorted_samples[p], &k, &indptr_start) + + if k != -1: + # If k != -1, we have found a non zero value + + if X_data[k] > 0: + start_positive_ -= 1 + feature_values[start_positive_] = X_data[k] + index = index_to_samples[X_indices[k]] + sparse_swap(index_to_samples, samples, index, start_positive_) + + elif X_data[k] < 0: + feature_values[end_negative_] = X_data[k] + index = index_to_samples[X_indices[k]] + sparse_swap(index_to_samples, samples, index, end_negative_) + end_negative_ += 1 + p += 1 + + # Returned values + end_negative[0] = end_negative_ + start_positive[0] = start_positive_ + + +cdef inline void sparse_swap(intp_t[::1] index_to_samples, intp_t[::1] samples, + intp_t pos_1, intp_t pos_2) noexcept nogil: + """Swap sample pos_1 and pos_2 preserving sparse invariant.""" + samples[pos_1], samples[pos_2] = samples[pos_2], samples[pos_1] + index_to_samples[samples[pos_1]] = pos_1 + index_to_samples[samples[pos_2]] = pos_2 diff --git a/sklearn/tree/meson.build b/sklearn/tree/meson.build index 8ed696cd2481e..04d1d5f353d02 100644 --- a/sklearn/tree/meson.build +++ b/sklearn/tree/meson.build @@ -2,9 +2,6 @@ tree_extension_metadata = { '_tree': {'sources': ['_tree.pyx'], 'override_options': ['cython_language=cpp', 'optimization=3']}, - '_sort': - {'sources': ['_sort.pyx'], - 'override_options': ['cython_language=cpp', 'optimization=3']}, '_splitter': {'sources': ['_splitter.pyx'], 'override_options': ['cython_language=cpp', 'optimization=3']},