diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index bba25ddb0e..4049c58926 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -400,6 +400,8 @@ if(BUILD_CUML_CPP_LIBRARY) src/fil/fil.cu src/fil/infer.cu src/glm/glm.cu + src/genetic/genetic.cu + src/genetic/node.cu src/holtwinters/holtwinters.cu src/kmeans/kmeans.cu src/knn/knn.cu diff --git a/cpp/include/cuml/genetic/genetic.h b/cpp/include/cuml/genetic/genetic.h new file mode 100644 index 0000000000..a08b70d497 --- /dev/null +++ b/cpp/include/cuml/genetic/genetic.h @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "node.h" +#include "program.h" + +#include +#include +#include + +namespace cuml { +namespace genetic { + +/** Type of initialization of the member programs in the population */ +enum class init_method_t : uint32_t { + /** random nodes chosen, allowing shorter or asymmetrical trees */ + grow, + /** growing till a randomly chosen depth */ + full, + /** 50% of the population on `grow` and the rest with `full` */ + half_and_half, +}; // enum class init_method_t + +/** fitness metric types */ +enum class metric_t : uint32_t { + /** mean absolute error (regression-only) */ + mae, + /** mean squared error (regression-only) */ + mse, + /** root mean squared error (regression-only) */ + rmse, + /** pearson product-moment coefficient (regression and transformation) */ + pearson, + /** spearman's rank-order coefficient (regression and transformation) */ + spearman, + /** binary cross-entropy loss (classification-only) */ + logloss, +}; // enum class metric_t + +enum class transformer_t : uint32_t { + /** sigmoid function */ + sigmoid, +}; // enum class transformer_t + +/** + * @brief contains all the hyper-parameters for training + * + * @note Unless otherwise mentioned, all the parameters below are applicable to + * all of classification, regression and transformation. + */ +struct param { + /** number of programs in each generation */ + int population_size = 1000; + /** + * number of fittest programs to compare during correlation + * (transformation-only) + */ + int hall_of_fame = 100; + /** + * number of fittest programs to return from `hall_of_fame` top programs + * (transformation-only) + */ + int n_components = 10; + /** number of generations to evolve */ + int generations = 20; + /** + * number of programs that compete in the tournament to become part of next + * generation + */ + int tournament_size = 20; + /** metric threshold used for early stopping */ + float stopping_criteria = 0.0f; + /** minimum/maximum value for `constant` nodes */ + float const_range[2] = {-1.0f, 1.0f}; + /** minimum/maximum depth of programs after initialization */ + int init_depth[2] = {2, 6}; + /** initialization method */ + init_method_t init_method = init_method_t::half_and_half; + /** list of functions to choose from */ + std::vector function_set{node::type::add, node::type::mul, + node::type::div, node::type::sub}; + /** transformation function to class probabilities (classification-only) */ + transformer_t transformer = transformer_t::sigmoid; + /** fitness metric */ + metric_t metric = metric_t::mae; + /** penalization factor for large programs */ + float parsimony_coefficient = 0.001f; + /** crossover mutation probability of the tournament winner */ + float p_crossover = 0.9f; + /** subtree mutation probability of the tournament winner*/ + float p_subtree_mutation = 0.01f; + /** hoist mutation probability of the tournament winner */ + float p_hoist_mutation = 0.01f; + /** point mutation probabiilty of the tournament winner */ + float p_point_mutation = 0.01f; + /** point replace probabiility for point mutations */ + float p_point_replace = 0.05f; + /** subsampling factor */ + float max_samples = 1.0f; + /** list of feature names for generating syntax trees from the programs */ + std::vector feature_names; + ///@todo: feature_names + ///@todo: verbose + /** random seed used for RNG */ + uint64_t random_state = 0ull; + + /** Computes the probability of 'reproduction' */ + float p_reproduce() const; + + /** maximum possible number of programs */ + int max_programs() const; +}; // struct param + +} // namespace genetic +} // namespace cuml diff --git a/cpp/include/cuml/genetic/node.h b/cpp/include/cuml/genetic/node.h new file mode 100644 index 0000000000..6a657e86a4 --- /dev/null +++ b/cpp/include/cuml/genetic/node.h @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace cuml { +namespace genetic { + +/** + * @brief Represents a node in the syntax tree. + * + * @code{.cpp} + * // A non-terminal (aka function) node + * node func_node{node::type::sub}; + * // A constant node + * float const_value = 2.f; + * node const_node{const_value}; + * // A variable (aka feature) node + * node var_node{20}; + * @endcode + */ +struct node { + /** + * @brief All possible types of nodes. For simplicity, all the terminal and + * non-terminal types are clubbed together + */ + enum class type : uint32_t { + variable = 0, + constant, + + // note: keep the case statements in alphabetical order under each category + // of operators. + functions_begin, + // different binary function types follow + binary_begin = functions_begin, + add = binary_begin, + atan2, + div, + fdim, + max, + min, + mul, + pow, + sub, + binary_end = sub, // keep this to be the last binary function in the list + // different unary function types follow + unary_begin, + abs = unary_begin, + acos, + acosh, + asin, + asinh, + atan, + atanh, + cbrt, + cos, + cosh, + cube, + exp, + inv, + log, + neg, + rcbrt, + rsqrt, + sin, + sinh, + sq, + sqrt, + tan, + tanh, + unary_end = tanh, // keep this to be the last unary function in the list + functions_end = unary_end, + }; // enum type + + /** + * @brief Construct a function node + * + * @param[in] ft function type + */ + explicit node(type ft); + + /** + * @brief Construct a variable node + * + * @param[in] fid feature id that represents the variable + */ + explicit node(int fid); + + /** + * @brief Construct a constant node + * + * @param[in] val constant value + */ + explicit node(float val); + + /** + * @param[in] src source node to be copied + */ + explicit node(const node& src); + + /** + * @brief assignment operator + * + * @param[in] src source node to be copied + * + * @return current node reference + */ + node& operator=(const node& src); + + /** whether the current is either a variable or a constant */ + bool is_terminal() const; + + /** whether the current node is a function */ + bool is_nonterminal() const; + + /** Get the arity of the node. If it is a terminal, then a 0 is returned */ + int arity() const; + + /** + * @brief Helper method to get node type from input string + * + * @param[in] ntype node type in string. Possible strings correlate one-to-one + * with the enum values for `type` + * + * @return `type` + */ + static type from_str(const std::string& ntype); + + /** constant used to represent invalid feature id */ + static const int kInvalidFeatureId; + + /** node type */ + type t; + union { + /** + * if the node is `variable` type, then this is the column id to be used to + * fetch its value, from the input dataset + */ + int fid; + /** if the node is `constant` type, then this is the value of the node */ + float val; + } u; +}; // struct node + +} // namespace genetic +} // namespace cuml diff --git a/cpp/include/cuml/genetic/program.h b/cpp/include/cuml/genetic/program.h new file mode 100644 index 0000000000..ceec22a7e5 --- /dev/null +++ b/cpp/include/cuml/genetic/program.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "node.h" + +namespace cuml { +namespace genetic { + +/** + * @brief The main data structure to store the AST that represents a program + * in the current generation + */ +struct program { + /** + * the AST. It is stored in the reverse of DFS-right-child-first order. In + * other words, construct a regular AST in the form of depth-first, but + * instead of storing the left child first, store the right child and so on. + * Now take the resulting 1D array and reverse it. + * + * @note The pointed memory buffer is NOT owned by this class and further it + * is assumed to be a zero-copy (aka pinned memory) buffer, atleast in + * this initial version + */ + node* nodes; + /** total number of nodes in this AST */ + int len; + /** maximum depth of this AST */ + int depth; +}; // struct program + +} // namespace genetic +} // namespace cuml diff --git a/cpp/src/genetic/genetic.cu b/cpp/src/genetic/genetic.cu new file mode 100644 index 0000000000..fa9dba9987 --- /dev/null +++ b/cpp/src/genetic/genetic.cu @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "genetic.cuh" +namespace cuml { +namespace genetic { + +float param::p_reproduce() const { return detail::p_reproduce(*this); } + +int param::max_programs() const { return detail::max_programs(*this); } + +} // namespace genetic +} // namespace cuml diff --git a/cpp/src/genetic/genetic.cuh b/cpp/src/genetic/genetic.cuh new file mode 100644 index 0000000000..67058c3677 --- /dev/null +++ b/cpp/src/genetic/genetic.cuh @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace cuml { +namespace genetic { +namespace detail { + +HDI float p_reproduce(const param& p) { + auto sum = p.p_crossover + p.p_subtree_mutation + p.p_hoist_mutation + + p.p_point_mutation; + auto ret = 1.f - sum; + return fmaxf(0.f, fminf(ret, 1.f)); +} + +HDI int max_programs(const param& p) { + // in the worst case every generation's top program ends up reproducing, + // thereby adding another program into the population + return p.population_size + p.generations; +} + +} // namespace detail +} // namespace genetic +} // namespace cuml diff --git a/cpp/src/genetic/node.cu b/cpp/src/genetic/node.cu new file mode 100644 index 0000000000..a1668998b7 --- /dev/null +++ b/cpp/src/genetic/node.cu @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "node.cuh" + +namespace cuml { +namespace genetic { + +const int node::kInvalidFeatureId = -1; + +node::node(node::type ft) : t(ft) { + ASSERT(is_nonterminal(), + "node: ctor with `type` argument expects functions type only!"); + u.fid = kInvalidFeatureId; +} + +node::node(int fid) : t(node::type::variable) { u.fid = fid; } + +node::node(float val) : t(node::type::constant) { u.val = val; } + +node::node(const node& src) : t(src.t), u(src.u) {} + +node& node::operator=(const node& src) { + t = src.t; + u = src.u; + return *this; +} + +bool node::is_terminal() const { return detail::is_terminal(t); } + +bool node::is_nonterminal() const { return detail::is_nonterminal(t); } + +int node::arity() const { return detail::arity(t); } + +#define CASE(str, val) \ + if (#val == str) return node::type::val +node::type node::from_str(const std::string& ntype) { + CASE(ntype, variable); + CASE(ntype, constant); + // note: keep the case statements in alphabetical order under each category of + // operators. + // binary operators + CASE(ntype, add); + CASE(ntype, atan2); + CASE(ntype, div); + CASE(ntype, fdim); + CASE(ntype, max); + CASE(ntype, min); + CASE(ntype, mul); + CASE(ntype, pow); + CASE(ntype, sub); + // unary operators + CASE(ntype, abs); + CASE(ntype, acos); + CASE(ntype, asin); + CASE(ntype, atan); + CASE(ntype, acosh); + CASE(ntype, asinh); + CASE(ntype, atanh); + CASE(ntype, cbrt); + CASE(ntype, cos); + CASE(ntype, cosh); + CASE(ntype, cube); + CASE(ntype, exp); + CASE(ntype, inv); + CASE(ntype, log); + CASE(ntype, neg); + CASE(ntype, rcbrt); + CASE(ntype, rsqrt); + CASE(ntype, sq); + CASE(ntype, sqrt); + CASE(ntype, sin); + CASE(ntype, sinh); + CASE(ntype, tan); + CASE(ntype, tanh); + ASSERT(false, "node::from_str: Bad type passed '%s'!", ntype.c_str()); +} +#undef CASE + +} // namespace genetic +} // namespace cuml diff --git a/cpp/src/genetic/node.cuh b/cpp/src/genetic/node.cuh new file mode 100644 index 0000000000..6763bb0c0f --- /dev/null +++ b/cpp/src/genetic/node.cuh @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace cuml { +namespace genetic { +namespace detail { + +static constexpr float MIN_VAL = 0.001f; + +HDI bool is_terminal(node::type t) { + return t == node::type::variable || t == node::type::constant; +} + +HDI bool is_nonterminal(node::type t) { return !is_terminal(t); } + +HDI int arity(node::type t) { + if (node::type::unary_begin <= t && t <= node::type::unary_end) { + return 1; + } + if (node::type::binary_begin <= t && t <= node::type::binary_end) { + return 2; + } + return 0; +} + +// `data` assumed to be stored in col-major format +DI float evaluate_node(const node& n, const float* data, size_t stride, + float inval, float inval1) { + if (n.t == node::type::constant) { + return n.u.val; + } else if (n.t == node::type::variable) { + return n.u.fid != node::kInvalidFeatureId ? data[n.u.fid * stride] : 0.f; + } else { + auto abs_inval = fabsf(inval), abs_inval1 = fabsf(inval1); + auto small = abs_inval < MIN_VAL; + // note: keep the case statements in alphabetical order under each category + // of operators. + switch (n.t) { + // binary operators + case node::type::add: + return inval + inval1; + case node::type::atan2: + return atan2f(inval, inval1); + case node::type::div: + return abs_inval1 < MIN_VAL ? 1.f : fdividef(inval, inval1); + case node::type::fdim: + return fdimf(inval, inval1); + case node::type::max: + return fmaxf(inval, inval1); + case node::type::min: + return fminf(inval, inval1); + case node::type::mul: + return inval * inval1; + case node::type::pow: + return powf(inval, inval1); + case node::type::sub: + return inval - inval1; + // unary operators + case node::type::abs: + return abs_inval; + case node::type::acos: + return acosf(inval); + case node::type::acosh: + return acoshf(inval); + case node::type::asin: + return asinf(inval); + case node::type::asinh: + return asinhf(inval); + case node::type::atan: + return atanf(inval); + case node::type::atanh: + return atanhf(inval); + case node::type::cbrt: + return cbrtf(inval); + case node::type::cos: + return cosf(inval); + case node::type::cosh: + return coshf(inval); + case node::type::cube: + return inval * inval * inval; + case node::type::exp: + return expf(inval); + case node::type::inv: + return abs_inval < MIN_VAL ? 0.f : 1.f / inval; + case node::type::log: + return abs_inval < MIN_VAL ? 0.f : logf(abs_inval); + case node::type::neg: + return -inval; + case node::type::rcbrt: + return rcbrtf(inval); + case node::type::rsqrt: + return rsqrtf(abs_inval); + case node::type::sin: + return sinf(inval); + case node::type::sinh: + return sinhf(inval); + case node::type::sq: + return inval * inval; + case node::type::sqrt: + return sqrtf(abs_inval); + case node::type::tan: + return tanf(inval); + case node::type::tanh: + return tanhf(inval); + // shouldn't reach here! + default: + return 0.f; + }; + } +} + +} // namespace detail +} // namespace genetic +} // namespace cuml diff --git a/cpp/src/genetic/reg_stack.cuh b/cpp/src/genetic/reg_stack.cuh new file mode 100644 index 0000000000..e0f6762c00 --- /dev/null +++ b/cpp/src/genetic/reg_stack.cuh @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace cuml { +namespace genetic { + +/** + * @brief A fixed capacity stack on device currently used for AST evaluation + * + * The idea is to use only the registers to store the elements of the stack, + * thereby achieving the best performance. + * + * @tparam DataT data type of the stack elements + * @tparam MaxSize max capacity of the stack + */ +template +struct stack { + explicit HDI stack() : elements_(0) { +#pragma unroll + for (int i = 0; i < MaxSize; ++i) { + regs_[i] = DataT(0); + } + } + + /** Checks if the stack is empty */ + HDI bool empty() const { return elements_ == 0; } + + /** Current number of elements in the stack */ + HDI int size() const { return elements_; } + + /** Checks if the number of elements in the stack equal its capacity */ + HDI bool full() const { return elements_ == MaxSize; } + + /** + * @brief Pushes the input element to the top of the stack + * + * @param[in] val input element to be pushed + * + * @note If called when the stack is already full, then it is a no-op! To keep + * the device-side logic simpler, it has been designed this way. Trying + * to push more than `MaxSize` elements leads to all sorts of incorrect + * behavior. + */ + HDI void push(DataT val) { +#pragma unroll + for (int i = 0; i < MaxSize; ++i) { + if (elements_ == i) { + regs_[i] = val; + ++elements_; + } + } + } + + /** + * @brief Pops the top element from the stack + * + * @return pops the element and returns it, if already reached bottom, then it + * returns zero. + * + * @note If called when the stack is already empty, then it just returns a + * value of zero! To keep the device-side logic simpler, it has been + * designed this way. Trying to pop beyond the bottom of the stack leads + * to all sorts of incorrect behavior. + */ + HDI DataT pop() { +#pragma unroll + for (int i = 0; i < MaxSize; ++i) { + if (elements_ - 1 == i) { + --elements_; + return regs_[i]; + } + } + // shouldn't reach here! + return DataT(0); + } + + private: + int elements_; + DataT regs_[MaxSize]; +}; // struct stack + +} // namespace genetic +} // namespace cuml diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 9130c5e450..e3b8513178 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -51,13 +51,15 @@ if(BUILD_CUML_TESTS) sg/decisiontree_batchedlevel_algo.cu sg/decisiontree_batchedlevel_unittest.cu sg/fil_test.cu - sg/multi_sum_test.cu + sg/genetic/node_test.cpp + sg/genetic/param_test.cu sg/handle_test.cu sg/holtwinters_test.cu sg/kmeans_test.cu sg/knn_test.cu sg/lars_test.cu sg/logger.cpp + sg/multi_sum_test.cu sg/nvtx_test.cpp sg/ols.cu sg/pca_test.cu diff --git a/cpp/test/sg/genetic/node_test.cpp b/cpp/test/sg/genetic/node_test.cpp new file mode 100644 index 0000000000..127623ca6d --- /dev/null +++ b/cpp/test/sg/genetic/node_test.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace cuml { +namespace genetic { + +TEST(Genetic, node_test) { + node feature(1); + ASSERT_EQ(feature.t, node::type::variable); + ASSERT_TRUE(feature.is_terminal()); + ASSERT_FALSE(feature.is_nonterminal()); + ASSERT_EQ(feature.arity(), 0); + ASSERT_EQ(feature.u.fid, 1); + + node constval(0.1f); + ASSERT_EQ(constval.t, node::type::constant); + ASSERT_TRUE(constval.is_terminal()); + ASSERT_FALSE(constval.is_nonterminal()); + ASSERT_EQ(constval.arity(), 0); + ASSERT_EQ(constval.u.val, 0.1f); + + node func1(node::type::add); + ASSERT_EQ(func1.t, node::type::add); + ASSERT_FALSE(func1.is_terminal()); + ASSERT_TRUE(func1.is_nonterminal()); + ASSERT_EQ(func1.arity(), 2); + ASSERT_EQ(func1.u.fid, node::kInvalidFeatureId); + + node func2(node::type::cosh); + ASSERT_EQ(func2.t, node::type::cosh); + ASSERT_FALSE(func2.is_terminal()); + ASSERT_TRUE(func2.is_nonterminal()); + ASSERT_EQ(func2.arity(), 1); + ASSERT_EQ(func2.u.fid, node::kInvalidFeatureId); +} + +TEST(Genetic, node_from_str) { + ASSERT_EQ(node::from_str("add"), node::type::add); + ASSERT_EQ(node::from_str("tanh"), node::type::tanh); + ASSERT_THROW(node::from_str("bad_type"), raft::exception); +} + +TEST(Genetic, node_constants) { ASSERT_EQ(node::kInvalidFeatureId, -1); } + +} // namespace genetic +} // namespace cuml diff --git a/cpp/test/sg/genetic/param_test.cu b/cpp/test/sg/genetic/param_test.cu new file mode 100644 index 0000000000..9507e2bdb7 --- /dev/null +++ b/cpp/test/sg/genetic/param_test.cu @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "../../prims/test_utils.h" + +namespace cuml { +namespace genetic { + +TEST(Genetic, ParamTest) { + param p; + ASSERT_EQ(p.population_size, 1000); + ASSERT_EQ(p.hall_of_fame, 100); + ASSERT_EQ(p.n_components, 10); + ASSERT_EQ(p.generations, 20); + ASSERT_EQ(p.tournament_size, 20); + ASSERT_EQ(p.stopping_criteria, 0.0f); + ASSERT_EQ(p.const_range[0], -1.0f); + ASSERT_EQ(p.const_range[1], 1.0f); + ASSERT_EQ(p.init_depth[0], 2); + ASSERT_EQ(p.init_depth[1], 6); + ASSERT_EQ(p.init_method, init_method_t::half_and_half); + ASSERT_EQ(p.function_set.size(), 4u); + ASSERT_EQ(p.function_set[0], node::type::add); + ASSERT_EQ(p.function_set[1], node::type::mul); + ASSERT_EQ(p.function_set[2], node::type::div); + ASSERT_EQ(p.function_set[3], node::type::sub); + ASSERT_EQ(p.transformer, transformer_t::sigmoid); + ASSERT_EQ(p.metric, metric_t::mae); + ASSERT_EQ(p.parsimony_coefficient, 0.001f); + ASSERT_EQ(p.p_crossover, 0.9f); + ASSERT_EQ(p.p_subtree_mutation, 0.01f); + ASSERT_EQ(p.p_hoist_mutation, 0.01f); + ASSERT_EQ(p.p_point_mutation, 0.01f); + ASSERT_EQ(p.p_point_replace, 0.05f); + ASSERT_EQ(p.max_samples, 1.0f); + ASSERT_EQ(p.feature_names.size(), 0u); + ASSERT_EQ(p.random_state, 0ull); +} + +TEST(Genetic, p_reproduce) { + param p; + auto ret = p.p_reproduce(); + ASSERT_TRUE( + raft::match(p.p_reproduce(), 0.07f, raft::CompareApprox(0.0001f))); +} + +} // namespace genetic +} // namespace cuml