diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index c8a8f22599..1d6cb6f03b 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -268,6 +268,7 @@ if(BUILD_CUML_CPP_LIBRARY)
     src/fil/infer.cu
     src/glm/glm.cu
     src/genetic/genetic.cu
+    src/genetic/program.cu
     src/genetic/node.cu
     src/hdbscan/hdbscan.cu
     src/hdbscan/condensed_hierarchy.cu
diff --git a/cpp/examples/CMakeLists.txt b/cpp/examples/CMakeLists.txt
index 6b9ab42b8d..526500a649 100644
--- a/cpp/examples/CMakeLists.txt
+++ b/cpp/examples/CMakeLists.txt
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2019, NVIDIA CORPORATION.
+# Copyright (c) 2019-2021, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,3 +16,4 @@
 add_subdirectory(kmeans)
 add_subdirectory(dbscan)
+add_subdirectory(symreg)
diff --git a/cpp/examples/symreg/CMakeLists.txt b/cpp/examples/symreg/CMakeLists.txt
new file mode 100644
index 0000000000..66dd39a49c
--- /dev/null
+++ b/cpp/examples/symreg/CMakeLists.txt
@@ -0,0 +1,19 @@
+#=============================================================================
+# Copyright (c) 2021, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#=============================================================================
+
+add_executable(symreg_example symreg_example.cpp)
+target_include_directories(symreg_example PRIVATE ${CUML_INCLUDE_DIRECTORIES})
+target_link_libraries(symreg_example cuml++)
diff --git a/cpp/examples/symreg/CMakeLists_standalone.txt b/cpp/examples/symreg/CMakeLists_standalone.txt
new file mode 100644
index 0000000000..e79a215cca
--- /dev/null
+++ b/cpp/examples/symreg/CMakeLists_standalone.txt
@@ -0,0 +1,33 @@
+#
+# Copyright (c) 2021, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
+include(ExternalProject)
+
+project(symreg_example VERSION 0.1.0 LANGUAGES CXX CUDA)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+find_package(CUDAToolkit)
+find_package(cuml)
+
+add_executable(symreg_example symreg_example.cpp)
+
+# Need to set linker language to CUDA to link the CUDA Runtime
+set_target_properties(symreg_example PROPERTIES LINKER_LANGUAGE "CUDA")
+
+# Link cuml and cudart
+target_link_libraries(symreg_example cuml::cuml++ CUDA::cudart)
\ No newline at end of file
diff --git a/cpp/examples/symreg/README.md b/cpp/examples/symreg/README.md
new file mode 100644
index 0000000000..52581eb627
--- /dev/null
+++ b/cpp/examples/symreg/README.md
@@ -0,0 +1,87 @@
+# Symbolic regression
+This subfolder contains an example of how to perform symbolic regression in cuML (from C++).
+There are two `CMakeLists.txt` files in this folder:
+1. `CMakeLists.txt` (default) which is included when building cuML
+2. `CMakeLists_standalone.txt` as an example for a standalone project linking to `libcuml.so`
+
+## Build
+`symreg_example` is built as a part of cuML. To build it as a standalone executable, do
+```bash
+$ cmake .. -DCUML_LIBRARY_DIR=/path/to/directory/with/libcuml.so -DCUML_INCLUDE_DIR=/path/to/cuml/headers
+```
+Then build with `make` or `ninja`
+```
+$ make
+Scanning dependencies of target raft
+[ 10%] Creating directories for 'raft'
+[ 20%] Performing download step (git clone) for 'raft'
+Cloning into 'raft'...
+[ 30%] Performing update step for 'raft'
+[ 40%] No patch step for 'raft'
+[ 50%] No configure step for 'raft'
+[ 60%] No build step for 'raft'
+[ 70%] No install step for 'raft'
+[ 80%] Completed 'raft'
+[ 80%] Built target raft
+Scanning dependencies of target symreg_example
+[ 90%] Building CXX object CMakeFiles/symreg_example.dir/symreg_example.cpp.o
+[100%] Linking CUDA executable symreg_example
+[100%] Built target symreg_example
+```
+`CMakeLists_standalone.txt` also fetches a minimal set of header dependencies (namely [raft](https://github.com/rapidsai/raft) and [cub](https://github.com/NVIDIA/cub)) if they are not detected on the system.
+
+## Run
+
+1. Generate a toy training and test dataset
+```
+$ python prepare_input.py
+Training set has n_rows=250 n_cols=2
+Test set has n_rows=50 n_cols=2
+Wrote 500 values to train_data.txt
+Wrote 100 values to test_data.txt
+Wrote 250 values to train_labels.txt
+Wrote 50 values to test_labels.txt
+```
+
+2. Run the symbolic regressor using the 4 files as inputs. An example invocation is given below:
+```bash
+$ ./symreg_example -n_cols 2 \
+                   -n_train_rows 250 \
+                   -n_test_rows 50 \
+                   -random_state 21 \
+                   -population_size 4000 \
+                   -generations 20 \
+                   -stopping_criteria 0.01 \
+                   -p_crossover 0.7 \
+                   -p_subtree 0.1 \
+                   -p_hoist 0.05 \
+                   -p_point 0.1 \
+                   -parsimony_coeff 0.01
+```
+
+3. The corresponding output for the above invocation is given below:
+
+```
+Reading input with 250 rows and 2 columns from train_data.txt.
+Reading input with 250 rows from train_labels.txt.
+Reading input with 50 rows and 2 columns from test_data.txt.
+Reading input with 50 rows from test_labels.txt.
+***************************************
+Allocating device memory...
+Allocation time = 0.259072ms
+***************************************
+Beginning training on given dataset...
+Finished training for 4 generations.
+        Best AST index :            1855
+        Best AST depth :               3
+        Best AST length :             13
+        Best AST equation :( add( sub( mult( X0, X0) , div( X1, X1) ) , sub( X1, mult( X1, X1) ) ) )
+Training time = 626.658ms
+***************************************
+Beginning Inference on Test dataset...
+Inference score on test set = 5.29271e-08
+Inference time = 0.35248ms
+Some Predicted test values:
+-1.65061;-1.64081;-0.91711;-2.28976;-0.280688;
+Corresponding Actual test values:
+-1.65061;-1.64081;-0.91711;-2.28976;-0.280688;
+```
\ No newline at end of file
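A note on the probabilities passed above: the four mutation probabilities must sum to a value in (0, 1), which the example program validates before training, and the remainder becomes the reproduction probability computed by `param::p_reproduce()` (defined later in `genetic.cu`). A worked check with the README's values:

```cpp
// Worked example with the values from the invocation above:
float p_crossover = 0.7f, p_subtree = 0.1f, p_hoist = 0.05f, p_point = 0.1f;
float p_sum       = p_crossover + p_subtree + p_hoist + p_point;  // 0.95; must lie in (0, 1)
float p_reproduce = 1.0f - p_sum;  // 0.05: chance the tournament winner is copied unchanged
```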
diff --git a/cpp/examples/symreg/prepare_input.py b/cpp/examples/symreg/prepare_input.py
new file mode 100644
index 0000000000..4e53e131e9
--- /dev/null
+++ b/cpp/examples/symreg/prepare_input.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2021, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+
+rng = np.random.RandomState(seed=2021)
+
+# Training samples
+X_train = rng.uniform(-1, 1, 500).reshape(250, 2)
+y_train = X_train[:, 0]**2 - X_train[:, 1]**2 + X_train[:, 1] - 1
+
+# Testing samples
+X_test = rng.uniform(-1, 1, 100).reshape(50, 2)
+y_test = X_test[:, 0]**2 - X_test[:, 1]**2 + X_test[:, 1] - 1
+
+print("Training set has n_rows=%d n_cols=%d" % (X_train.shape))
+print("Test set has n_rows=%d n_cols=%d" % (X_test.shape))
+
+train_data = "train_data.txt"
+test_data = "test_data.txt"
+train_labels = "train_labels.txt"
+test_labels = "test_labels.txt"
+
+# Save all datasets in col-major format
+np.savetxt(train_data, X_train.T, fmt='%.7f')
+np.savetxt(test_data, X_test.T, fmt='%.7f')
+np.savetxt(train_labels, y_train, fmt='%.7f')
+np.savetxt(test_labels, y_test, fmt='%.7f')
+
+print("Wrote %d values to %s" % (X_train.size, train_data))
+print("Wrote %d values to %s" % (X_test.size, test_data))
+print("Wrote %d values to %s" % (y_train.size, train_labels))
+print("Wrote %d values to %s" % (y_test.size, test_labels))
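Because the script writes `X.T` with `savetxt`, the values stream out feature by feature, so the flat buffer that the example's `parse_col_major` reads back is column-major. A minimal indexing sketch (`cell` is a hypothetical helper, not part of the example):

```cpp
#include <vector>

// Element (row, col) of an n_rows x n_cols matrix stored in the flat
// col-major buffer produced by prepare_input.py / parse_col_major.
inline float cell(const std::vector<float>& buf, int row, int col, int n_rows)
{
  return buf[col * n_rows + row];  // columns are contiguous
}
```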
diff --git a/cpp/examples/symreg/symreg_example.cpp b/cpp/examples/symreg/symreg_example.cpp
new file mode 100644
index 0000000000..a33d9af8bf
--- /dev/null
+++ b/cpp/examples/symreg/symreg_example.cpp
@@ -0,0 +1,347 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include <fstream>
+#include <iomanip>
+#include <iostream>
+#include <iterator>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include <cuda_runtime.h>
+#include <cuml/genetic/common.h>
+#include <cuml/genetic/genetic.h>
+#include <cuml/genetic/program.h>
+
+#include <raft/handle.hpp>
+#include <raft/mr/device/allocator.hpp>
+#include <rmm/device_scalar.hpp>
+#include <rmm/device_uvector.hpp>
+
+// Namespace alias
+namespace cg = cuml::genetic;
+
+#ifndef CUDA_RT_CALL
+#define CUDA_RT_CALL(call)                                                     \
+  {                                                                            \
+    cudaError_t cudaStatus = call;                                             \
+    if (cudaSuccess != cudaStatus)                                             \
+      fprintf(stderr,                                                          \
+              "ERROR: CUDA RT call \"%s\" in line %d of file %s failed with "  \
+              "%s (%d).\n",                                                    \
+              #call,                                                           \
+              __LINE__,                                                        \
+              __FILE__,                                                        \
+              cudaGetErrorString(cudaStatus),                                  \
+              cudaStatus);                                                     \
+  }
+#endif  // CUDA_RT_CALL
+
+template <typename T>
+T get_argval(char** begin, char** end, const std::string& arg, const T default_val)
+{
+  T argval   = default_val;
+  char** itr = std::find(begin, end, arg);
+  if (itr != end && ++itr != end) {
+    std::istringstream inbuf(*itr);
+    inbuf >> argval;
+  }
+  return argval;
+}
+
+template <typename T>
+int parse_col_major(const std::string fname,
+                    std::vector<T>& vec,
+                    const int n_rows,
+                    const int n_cols)
+{
+  std::ifstream is(fname);
+  if (!is.is_open()) {
+    std::cerr << "ERROR: Could not open file " << fname << std::endl;
+    return 1;
+  }
+
+  std::istream_iterator<T> start(is), end;
+  vec.reserve(n_rows * n_cols);
+  vec.assign(start, end);
+  return 0;
+}
+
+int main(int argc, char* argv[])
+{
+  // Training hyperparameters (contains default values)
+  cg::param params;
+
+  // CUDA events to track execution time of various components
+  cudaEvent_t start, stop;
+  CUDA_RT_CALL(cudaEventCreate(&start));
+  CUDA_RT_CALL(cudaEventCreate(&stop));
+
+  // Process training arguments
+  const int population_size =
+    get_argval(argv, argv + argc, "-population_size", params.population_size);
+  const uint64_t random_state = get_argval(argv, argv + argc, "-random_state", params.random_state);
+  const int num_generations   = get_argval(argv, argv + argc, "-generations", params.generations);
+  const float stop_criterion =
+    get_argval(argv, argv + argc, "-stopping_criteria", params.stopping_criteria);
+  const float p_crossover = get_argval(argv, argv + argc, "-p_crossover", params.p_crossover);
+  const float p_subtree   = get_argval(argv, argv + argc, "-p_subtree", params.p_subtree_mutation);
+  const float p_hoist     = get_argval(argv, argv + argc, "-p_hoist", params.p_hoist_mutation);
+  const float p_point     = get_argval(argv, argv + argc, "-p_point", params.p_point_mutation);
+  const float p_point_replace =
+    get_argval(argv, argv + argc, "-p_point_replace", params.p_point_replace);
+  const float parsimony_coeff =
+    get_argval(argv, argv + argc, "-parsimony_coeff", params.parsimony_coefficient);
+  const std::string metric = get_argval(argv,
+                                        argv + argc,
+                                        "-metric",
+                                        std::string("mae"));  // mean absolute error is the default
+
+  // Process dataset-specific arguments
+  const int n_cols       = get_argval(argv, argv + argc, "-n_cols", 0);
+  const int n_train_rows = get_argval(argv, argv + argc, "-n_train_rows", 0);
+  const int n_test_rows  = get_argval(argv, argv + argc, "-n_test_rows", 0);
+
+  const std::string fX_train =
+    get_argval(argv, argv + argc, "-train_data", std::string("train_data.txt"));
+  const std::string fy_train =
+    get_argval(argv, argv + argc, "-train_labels", std::string("train_labels.txt"));
+  const std::string fX_test =
+    get_argval(argv, argv + argc, "-test_data", std::string("test_data.txt"));
+  const std::string fy_test =
+    get_argval(argv, argv + argc, "-test_labels", std::string("test_labels.txt"));
+
+  // Optionally accept files containing sample weights - if none are specified,
+  // then we assume a uniform distribution
+  const std::string fw_train =
+    get_argval(argv, argv + argc, "-train_weights", std::string("train_weights.txt"));
+  const std::string fw_test =
+    get_argval(argv, argv + argc, "-test_weights", std::string("test_weights.txt"));
+
+  std::vector<float> X_train;
+  std::vector<float> X_test;
+  std::vector<float> y_train;
+  std::vector<float> y_test;
+  std::vector<float> w_train;
+  std::vector<float> w_test;
+
+  // Read input
+  if (parse_col_major(fX_train, X_train, n_train_rows, n_cols)) return 1;
+  if (parse_col_major(fX_test, X_test, n_test_rows, n_cols)) return 1;
+  if (parse_col_major(fy_train, y_train, n_train_rows, 1)) return 1;
+  if (parse_col_major(fy_test, y_test, n_test_rows, 1)) return 1;
+  if (parse_col_major(fw_train, w_train, n_train_rows, 1)) {
+    std::cerr << "Defaulting to uniform training weights" << std::endl;
+    w_train.resize(n_train_rows, 1.0f);
+  }
+  if (parse_col_major(fw_test, w_test, n_test_rows, 1)) {
+    std::cerr << "Defaulting to uniform test weights" << std::endl;
+    w_test.resize(n_test_rows, 1.0f);
+  }
+
+  // Check for a valid mutation probability distribution
+  float p_sum = p_crossover + p_hoist + p_point + p_subtree;
+  if (p_sum >= 1.0f || p_sum <= 0.0f) {
+    std::cerr << "ERROR: Invalid mutation probabilities provided" << std::endl
+              << "Probability sum for crossover, subtree, hoist and point mutations is " << p_sum
+              << std::endl;
+    return 1;
+  }
+
+  // Check that p_point_replace lies in [0, 1]
+  if (p_point_replace > 1.0f || p_point_replace < 0.0f) {
+    std::cerr << "ERROR: Invalid value for point replacement probability" << std::endl;
+    return 1;
+  }
+
+  // Set all training parameters
+  params.num_features          = n_cols;
+  params.population_size       = population_size;
+  params.random_state          = random_state;
+  params.generations           = num_generations;
+  params.stopping_criteria     = stop_criterion;
+  params.p_crossover           = p_crossover;
+  params.p_subtree_mutation    = p_subtree;
+  params.p_hoist_mutation      = p_hoist;
+  params.p_point_mutation      = p_point;
+  params.p_point_replace       = p_point_replace;
+  params.parsimony_coefficient = parsimony_coeff;
+
+  // Set the training metric
+  if (metric == "mae") {
+    params.metric = cg::metric_t::mae;
+  } else if (metric == "mse") {
+    params.metric = cg::metric_t::mse;
+  } else if (metric == "rmse") {
+    params.metric = cg::metric_t::rmse;
+  } else {
+    std::cerr << "ERROR: Invalid metric specified for regression (can only be "
+                 "mae, mse or rmse) "
+              << std::endl;
+    return 1;
+  }
+
+  /* ======================= Begin GPU memory allocation ======================= */
+  std::cout << "***************************************" << std::endl;
+  raft::handle_t handle;
+  std::shared_ptr<raft::mr::device::allocator> allocator(
+    new raft::mr::device::default_allocator());
+
+  cudaStream_t stream;
+  CUDA_RT_CALL(cudaStreamCreate(&stream));
+  handle.set_stream(stream);
+
+  // Begin recording time
+  cudaEventRecord(start, stream);
+
+  rmm::device_uvector<float> dX_train(n_cols * n_train_rows, stream);
+  rmm::device_uvector<float> dy_train(n_train_rows, stream);
+  rmm::device_uvector<float> dw_train(n_train_rows, stream);
+  rmm::device_uvector<float> dX_test(n_cols * n_test_rows, stream);
+  rmm::device_uvector<float> dy_test(n_test_rows, stream);
+  rmm::device_uvector<float> dw_test(n_test_rows, stream);
+  rmm::device_uvector<float> dy_pred(n_test_rows, stream);
+  rmm::device_scalar<float> d_score{stream};
+
+  cg::program_t d_finalprogs;  // pointer to last generation ASTs on device
+
+  CUDA_RT_CALL(cudaMemcpyAsync(dX_train.data(),
+                               X_train.data(),
+                               sizeof(float) * dX_train.size(),
+                               cudaMemcpyHostToDevice,
+                               stream));
+
+  CUDA_RT_CALL(cudaMemcpyAsync(dy_train.data(),
+                               y_train.data(),
+                               sizeof(float) * dy_train.size(),
+                               cudaMemcpyHostToDevice,
+                               stream));
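+  // Note: every copy in this block is asynchronous with respect to the host
+  // but ordered on `stream`; the cudaEventRecord(stop, stream) further down
+  // therefore only fires after all of these transfers have completed.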
+  CUDA_RT_CALL(cudaMemcpyAsync(dw_train.data(),
+                               w_train.data(),
+                               sizeof(float) * dw_train.size(),
+                               cudaMemcpyHostToDevice,
+                               stream));
+
+  CUDA_RT_CALL(cudaMemcpyAsync(
+    dX_test.data(), X_test.data(), sizeof(float) * dX_test.size(), cudaMemcpyHostToDevice, stream));
+
+  CUDA_RT_CALL(cudaMemcpyAsync(
+    dy_test.data(), y_test.data(), sizeof(float) * dy_test.size(), cudaMemcpyHostToDevice, stream));
+
+  CUDA_RT_CALL(cudaMemcpyAsync(
+    dw_test.data(), w_test.data(), sizeof(float) * n_test_rows, cudaMemcpyHostToDevice, stream));
+
+  // Allocate device memory for the final generation of ASTs
+  raft::allocate(d_finalprogs, params.population_size, stream);
+
+  std::vector<std::vector<cg::program>> history;
+  history.reserve(params.generations);
+
+  cudaEventRecord(stop, stream);
+  cudaEventSynchronize(stop);
+  float alloc_time;
+  cudaEventElapsedTime(&alloc_time, start, stop);
+
+  std::cout << "Allocated device memory in " << std::setw(10) << alloc_time << "ms" << std::endl;
+
+  /* ======================= Begin training ======================= */
+
+  std::cout << "***************************************" << std::endl;
+  std::cout << std::setw(30) << "Beginning training for " << std::setw(15) << params.generations
+            << " generations" << std::endl;
+  cudaEventRecord(start, stream);
+
+  cg::symFit(handle,
+             dX_train.data(),
+             dy_train.data(),
+             dw_train.data(),
+             n_train_rows,
+             n_cols,
+             params,
+             d_finalprogs,
+             history);
+
+  cudaEventRecord(stop, stream);
+  cudaEventSynchronize(stop);
+  float training_time;
+  cudaEventElapsedTime(&training_time, start, stop);
+
+  int n_gen = params.num_epochs;
+  std::cout << std::setw(30) << "Convergence achieved in " << std::setw(15) << n_gen
+            << " generations." << std::endl;
+
+  // Find the index of the best program
+  int best_idx      = 0;
+  float opt_fitness = history.back()[0].raw_fitness_;
+
+  // For all 3 loss functions - smaller is better
+  for (int i = 1; i < params.population_size; ++i) {
+    if (history.back()[i].raw_fitness_ < opt_fitness) {
+      best_idx    = i;
+      opt_fitness = history.back()[i].raw_fitness_;
+    }
+  }
+
+  std::string eqn = cg::stringify(history.back()[best_idx]);
+  std::cout << std::setw(30) << "Best AST depth " << std::setw(15) << history.back()[best_idx].depth
+            << std::endl;
+  std::cout << std::setw(30) << "Best AST length " << std::setw(15) << history.back()[best_idx].len
+            << std::endl;
+  std::cout << std::setw(30) << "Best AST equation " << std::setw(15) << eqn << std::endl;
+  std::cout << "Training time = " << training_time << "ms" << std::endl;
+
+  /* ======================= Begin testing ======================= */
+
+  std::cout << "***************************************" << std::endl;
+  std::cout << "Beginning Inference on test dataset " << std::endl;
+  cudaEventRecord(start, stream);
+  cuml::genetic::symRegPredict(
+    handle, dX_test.data(), n_test_rows, d_finalprogs + best_idx, dy_pred.data());
+
+  std::vector<float> hy_pred(n_test_rows, 0.0f);
+  CUDA_RT_CALL(cudaMemcpy(
+    hy_pred.data(), dy_pred.data(), n_test_rows * sizeof(float), cudaMemcpyDeviceToHost));
+
+  cuml::genetic::compute_metric(
+    handle, n_test_rows, 1, dy_test.data(), dy_pred.data(), dw_test.data(), d_score.data(), params);
+
+  cudaEventRecord(stop, stream);
+  cudaEventSynchronize(stop);
+  float inference_time;
+  cudaEventElapsedTime(&inference_time, start, stop);
+
+  // Output fitness score
+  std::cout << "Inference score = " << d_score.value(stream) << std::endl;
+  std::cout << "Inference time = " << inference_time << "ms" << std::endl;
+
+  std::cout << "Some Predicted test values:" << std::endl;
+  std::copy(hy_pred.begin(), hy_pred.begin() + 5, std::ostream_iterator<float>(std::cout, ";"));
+  std::cout << std::endl;
+
+  std::cout << "Corresponding Actual test values:" << std::endl;
+  std::copy(y_test.begin(), y_test.begin() + 5, std::ostream_iterator<float>(std::cout, ";"));
+  std::cout << std::endl;
+
+  /* ======================= Reset data ======================= */
+
+  raft::deallocate(d_finalprogs, stream);
+  CUDA_RT_CALL(cudaEventDestroy(start));
+  CUDA_RT_CALL(cudaEventDestroy(stop));
+  CUDA_RT_CALL(cudaStreamDestroy(stream));
+  return 0;
+}
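One subtlety worth calling out here: per `symFit`'s documentation (in `genetic.h` below), each final program's `nodes` buffer is a separate device allocation that the caller is expected to release after prediction, while `raft::deallocate(d_finalprogs, stream)` above frees only the array of `program` structs. A possible cleanup sketch, reusing the staging pattern that `parallel_evolve` itself uses (a hypothetical helper, not part of this patch):

```cpp
// Sketch: free the per-program device node buffers allocated by symFit,
// using the lengths mirrored in the host-side copy of the last generation.
void free_final_nodes(const raft::handle_t& h,
                      cuml::genetic::program_t d_progs,
                      const std::vector<cuml::genetic::program>& h_last_gen)
{
  cudaStream_t stream = h.get_stream();
  auto* mr            = rmm::mr::get_current_device_resource();
  for (std::size_t i = 0; i < h_last_gen.size(); ++i) {
    cuml::genetic::program tmp;                      // host staging struct
    raft::copy(&tmp, d_progs + i, 1, stream);        // fetch device-side struct
    CUDA_RT_CALL(cudaStreamSynchronize(stream));     // tmp.nodes valid after this
    mr->deallocate(tmp.nodes, h_last_gen[i].len * sizeof(cuml::genetic::node), stream);
    tmp.nodes = nullptr;                             // avoid freeing a device pointer on the host
  }
}
```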
diff --git a/cpp/include/cuml/genetic/common.h b/cpp/include/cuml/genetic/common.h
new file mode 100644
index 0000000000..c305ff75cb
--- /dev/null
+++ b/cpp/include/cuml/genetic/common.h
@@ -0,0 +1,191 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "node.h"
+
+#include <cstdint>
+#include <map>
+#include <string>
+#include <vector>
+
+namespace cuml {
+namespace genetic {
+
+/** fitness metric types */
+enum class metric_t : uint32_t {
+  /** mean absolute error (regression-only) */
+  mae,
+
+  /** mean squared error (regression-only) */
+  mse,
+
+  /** root mean squared error (regression-only) */
+  rmse,
+
+  /** pearson product-moment coefficient (regression and transformation) */
+  pearson,
+
+  /** spearman's rank-order coefficient (regression and transformation) */
+  spearman,
+
+  /** binary cross-entropy loss (classification-only) */
+  logloss,
+};  // enum class metric_t
+
+/** Type of initialization of the member programs in the population */
+enum class init_method_t : uint32_t {
+  /** random nodes chosen, allowing shorter or asymmetrical trees */
+  grow,
+
+  /** growing till a randomly chosen depth */
+  full,
+
+  /** 50% of the population on `grow` and the rest with `full` */
+  half_and_half,
+};  // enum class init_method_t
+
+enum class transformer_t : uint32_t {
+  /** sigmoid function */
+  sigmoid,
+};  // enum class transformer_t
+
+/** Mutation types for a program */
+enum class mutation_t : uint32_t {
+  /** Placeholder for first generation programs */
+  none,
+
+  /** Crossover mutations */
+  crossover,
+
+  /** Subtree mutations */
+  subtree,
+
+  /** Hoist mutations */
+  hoist,
+
+  /** Point mutations */
+  point,
+
+  /** Program reproduction */
+  reproduce
+};  // enum class mutation_t
+
+/**
+ * @brief contains all the hyper-parameters for training
+ *
+ * @note Unless otherwise mentioned, all the parameters below are applicable to
+ * all of classification, regression and transformation.
+ */
+struct param {
+  /** number of programs in each generation */
+  int population_size = 1000;
+
+  /**
+   * number of fittest programs to compare during correlation
+   * (transformation-only)
+   */
+  int hall_of_fame = 100;
+
+  /**
+   * number of fittest programs to return from `hall_of_fame` top programs
+   * (transformation-only)
+   */
+  int n_components = 10;
+
+  /** number of generations to evolve */
+  int generations = 20;
+
+  /**
+   * number of programs that compete in the tournament to become part of next
+   * generation
+   */
+  int tournament_size = 20;
+
+  /** metric threshold used for early stopping */
+  float stopping_criteria = 0.0f;
+
+  /** minimum/maximum value for `constant` nodes */
+  float const_range[2] = {-1.0f, 1.0f};
+
+  /** minimum/maximum depth of programs after initialization */
+  int init_depth[2] = {2, 6};
+
+  /** initialization method */
+  init_method_t init_method = init_method_t::half_and_half;
+
+  /** list of functions to choose from */
+  std::vector<node::type> function_set{
+    node::type::add, node::type::mul, node::type::div, node::type::sub};
+
+  /** map of functions ordered by their arity */
+  std::map<int, std::vector<node::type>> arity_set{
+    {2, {node::type::add, node::type::mul, node::type::div, node::type::sub}}};
+
+  /** transformation function to class probabilities (classification-only) */
+  transformer_t transformer = transformer_t::sigmoid;
+
+  /** fitness metric */
+  metric_t metric = metric_t::mae;
+
+  /** penalization factor for large programs */
+  float parsimony_coefficient = 0.001f;
+
+  /** crossover mutation probability of the tournament winner */
+  float p_crossover = 0.9f;
+
+  /** subtree mutation probability of the tournament winner */
+  float p_subtree_mutation = 0.01f;
+
+  /** hoist mutation probability of the tournament winner */
+  float p_hoist_mutation = 0.01f;
+
+  /** point mutation probability of the tournament winner */
+  float p_point_mutation = 0.01f;
+
+  /** point replace probability for point mutations */
+  float p_point_replace = 0.05f;
+
+  /** subsampling factor */
+  float max_samples = 1.0f;
+
+  /** Terminal ratio for node selection during grow initialization.
+   *  0 -> auto-selection */
+  float terminalRatio = 0.0f;
+
+  /** list of feature names for generating syntax trees from the programs */
+  std::vector<std::string> feature_names;
+
+  /** number of features in current dataset */
+  int num_features;
+  ///@todo: feature_names
+  ///@todo: verbose
+
+  /** random seed used for RNG */
+  uint64_t random_state = 0UL;
+
+  /** Number of epochs for which the algorithm ran */
+  int num_epochs = 0;
+
+  /** Low memory flag for program history */
+  bool low_memory = false;
+
+  /** Computes the probability of 'reproduction' */
+  float p_reproduce() const;
+
+  /** maximum possible number of programs */
+  int max_programs() const;
+
+  /** criterion for scoring based on the metric used */
+  int criterion() const;
+};  // struct param
+
+}  // namespace genetic
+}  // namespace cuml
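A small usage sketch of the struct above (the `criterion()` and `p_reproduce()` implementations appear later in `genetic.cu`):

```cpp
cuml::genetic::param p;                        // defaults as listed above
p.metric = cuml::genetic::metric_t::spearman;  // correlation: larger is fitter
// p.criterion() == 1 -> tournaments maximize the adjusted fitness;
// for mae/mse/rmse/logloss, criterion() == 0 and tournaments minimize it.
float leftover = p.p_reproduce();  // 1 - (0.9 + 0.01 + 0.01 + 0.01) = 0.07 by default
```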
diff --git a/cpp/include/cuml/genetic/genetic.h b/cpp/include/cuml/genetic/genetic.h
index 64349986d2..03c229ffe1 100644
--- a/cpp/include/cuml/genetic/genetic.h
+++ b/cpp/include/cuml/genetic/genetic.h
@@ -16,115 +16,122 @@
 
 #pragma once
 
-#include "node.h"
+#include <raft/handle.hpp>
+#include "common.h"
 #include "program.h"
 
-#include <cstdint>
-#include <string>
-#include <vector>
-
 namespace cuml {
 namespace genetic {
 
-/** Type of initialization of the member programs in the population */
-enum class init_method_t : uint32_t {
-  /** random nodes chosen, allowing shorter or asymmetrical trees */
-  grow,
-  /** growing till a randomly chosen depth */
-  full,
-  /** 50% of the population on `grow` and the rest with `full` */
-  half_and_half,
-};  // enum class init_method_t
+/**
+ * @brief Visualize an AST
+ *
+ * @param prog host object containing the AST
+ * @return String representation of the AST
+ */
+std::string stringify(const program& prog);
 
-/** fitness metric types */
-enum class metric_t : uint32_t {
-  /** mean absolute error (regression-only) */
-  mae,
-  /** mean squared error (regression-only) */
-  mse,
-  /** root mean squared error (regression-only) */
-  rmse,
-  /** pearson product-moment coefficient (regression and transformation) */
-  pearson,
-  /** spearman's rank-order coefficient (regression and transformation) */
-  spearman,
-  /** binary cross-entropy loss (classification-only) */
-  logloss,
-};  // enum class metric_t
+/**
+ * @brief Fit either a regressor, classifier or a transformer to the given dataset
+ *
+ * @param handle cuML handle
+ * @param input device pointer to the feature matrix
+ * @param labels device pointer to the label vector of length n_rows
+ * @param sample_weights device pointer to the sample weights of length n_rows
+ * @param n_rows number of rows of the feature matrix
+ * @param n_cols number of columns of the feature matrix
+ * @param params host struct containing hyperparameters needed for training
+ * @param final_progs device pointer to the final generation of programs (sorted by decreasing
+ *                    fitness)
+ * @param history host vector containing the list of all programs in every generation
+ *                (sorted by decreasing fitness)
+ *
+ * @note This module allocates extra device memory for the nodes of the last generation, pointed
+ * to by `final_progs[i].nodes` for each program `i` in `final_progs`. The amount of memory
+ * allocated is found at runtime, and is `final_progs[i].len * sizeof(node)` for each program `i`.
+ * The reason this isn't deallocated within the function is that the resulting memory is needed
+ * for executing predictions in the `symRegPredict`, `symClfPredict`, `symClfPredictProbs` and
+ * `symTransform` functions. This device memory is expected to be explicitly deallocated by the
+ * caller AFTER calling the predict function.
+ */
+void symFit(const raft::handle_t& handle,
+            const float* input,
+            const float* labels,
+            const float* sample_weights,
+            const int n_rows,
+            const int n_cols,
+            param& params,
+            program_t& final_progs,
+            std::vector<std::vector<program>>& history);
 
-enum class transformer_t : uint32_t {
-  /** sigmoid function */
-  sigmoid,
-};  // enum class transformer_t
+/**
+ * @brief Make predictions for a symbolic regressor
+ *
+ * @param handle cuML handle
+ * @param input device pointer to feature matrix
+ * @param n_rows number of rows of the feature matrix
+ * @param best_prog device pointer to best AST fit during training
+ * @param output device pointer to output values
+ */
+void symRegPredict(const raft::handle_t& handle,
+                   const float* input,
+                   const int n_rows,
+                   const program_t& best_prog,
+                   float* output);
 
 /**
- * @brief contains all the hyper-parameters for training
+ * @brief Probability prediction for a symbolic classifier. If a transformer (like sigmoid) is
+ * specified, then it is applied on the output before returning it.
  *
- * @note Unless otherwise mentioned, all the parameters below are applicable to
- * all of classification, regression and transformation.
+ * @param handle cuML handle
+ * @param input device pointer to feature matrix
+ * @param n_rows number of rows of the feature matrix
+ * @param params host struct containing training hyperparameters
+ * @param best_prog The best program obtained during training. Inferences are made using this
+ * @param output device pointer to output probability (in col-major format)
  */
-struct param {
-  /** number of programs in each generation */
-  int population_size = 1000;
-  /**
-   * number of fittest programs to compare during correlation
-   * (transformation-only)
-   */
-  int hall_of_fame = 100;
-  /**
-   * number of fittest programs to return from `hall_of_fame` top programs
-   * (transformation-only)
-   */
-  int n_components = 10;
-  /** number of generations to evolve */
-  int generations = 20;
-  /**
-   * number of programs that compete in the tournament to become part of next
-   * generation
-   */
-  int tournament_size = 20;
-  /** metric threshold used for early stopping */
-  float stopping_criteria = 0.0f;
-  /** minimum/maximum value for `constant` nodes */
-  float const_range[2] = {-1.0f, 1.0f};
-  /** minimum/maximum depth of programs after initialization */
-  int init_depth[2] = {2, 6};
-  /** initialization method */
-  init_method_t init_method = init_method_t::half_and_half;
-  /** list of functions to choose from */
-  std::vector<node::type> function_set{
-    node::type::add, node::type::mul, node::type::div, node::type::sub};
-  /** transformation function to class probabilities (classification-only) */
-  transformer_t transformer = transformer_t::sigmoid;
-  /** fitness metric */
-  metric_t metric = metric_t::mae;
-  /** penalization factor for large programs */
-  float parsimony_coefficient = 0.001f;
-  /** crossover mutation probability of the tournament winner */
-  float p_crossover = 0.9f;
-  /** subtree mutation probability of the tournament winner*/
-  float p_subtree_mutation = 0.01f;
-  /** hoist mutation probability of the tournament winner */
-  float p_hoist_mutation = 0.01f;
-  /** point mutation probabiilty of the tournament winner */
-  float p_point_mutation = 0.01f;
-  /** point replace probabiility for point mutations */
-  float p_point_replace = 0.05f;
-  /** subsampling factor */
-  float max_samples = 1.0f;
-  /** list of feature names for generating syntax trees from the programs */
-  std::vector<std::string> feature_names;
-  ///@todo: feature_names
-  ///@todo: verbose
-  /** random seed used for RNG */
-  uint64_t random_state = 0ull;
+void symClfPredictProbs(const raft::handle_t& handle,
+                        const float* input,
+                        const int n_rows,
+                        const param& params,
+                        const program_t& best_prog,
+                        float* output);
 
-  /** Computes the probability of 'reproduction' */
-  float p_reproduce() const;
+/**
+ * @brief Return predictions for a binary classification program defining the decision boundary
+ *
+ * @param handle cuML handle
+ * @param input device pointer to feature matrix
+ * @param n_rows number of rows of the feature matrix
+ * @param params host struct containing training hyperparameters
+ * @param best_prog Best program obtained after training
+ * @param output Device pointer to output predictions
+ */
+void symClfPredict(const raft::handle_t& handle,
+                   const float* input,
+                   const int n_rows,
+                   const param& params,
+                   const program_t& best_prog,
+                   float* output);
 
-  /** maximum possible number of programs */
-  int max_programs() const;
-};  // struct param
+/**
+ * @brief Transform the values in the input feature matrix according to the supplied programs
+ *
+ * @param handle cuML handle
+ * @param input device pointer to feature matrix
+ * @param params Hyperparameters used during training
+ * @param final_progs List of ASTs used for generating new features
+ * @param n_rows number of rows of the feature matrix
+ * @param n_cols number of columns of the feature matrix
+ * @param output device pointer to transformed input
+ */
+void symTransform(const raft::handle_t& handle,
+                  const float* input,
+                  const param& params,
+                  const program_t& final_progs,
+                  const int n_rows,
+                  const int n_cols,
+                  float* output);
 
 }  // namespace genetic
 }  // namespace cuml
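For orientation, the full call sequence implied by these declarations is small. A minimal sketch, assuming a `raft::handle_t handle`, device buffers `d_X`/`d_y`/`d_w` (and test buffers), and a populated `param params`; see `cpp/examples/symreg` for a complete program:

```cpp
// Allocate the device array of program structs, then fit and predict.
std::vector<std::vector<cuml::genetic::program>> history;
cuml::genetic::program_t d_progs;
raft::allocate(d_progs, params.population_size, handle.get_stream());

cuml::genetic::symFit(handle, d_X, d_y, d_w, n_rows, n_cols, params, d_progs, history);

// history.back() mirrors the final generation on the host; pick an index
// (e.g. by comparing raw_fitness_) and predict with the matching device pointer:
cuml::genetic::symRegPredict(handle, d_X_test, n_test_rows, d_progs + best_idx, d_out);
```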
diff --git a/cpp/include/cuml/genetic/node.h b/cpp/include/cuml/genetic/node.h
index e599360d5c..b7dc4da0d4 100644
--- a/cpp/include/cuml/genetic/node.h
+++ b/cpp/include/cuml/genetic/node.h
@@ -88,6 +88,11 @@ struct node {
     functions_end = unary_end,
   };  // enum type
 
+  /**
+   * @brief Default constructor for node
+   */
+  explicit node();
+
   /**
    * @brief Construct a function node
    *
diff --git a/cpp/include/cuml/genetic/program.h b/cpp/include/cuml/genetic/program.h
index a24cd68c5d..e782194cae 100644
--- a/cpp/include/cuml/genetic/program.h
+++ b/cpp/include/cuml/genetic/program.h
@@ -16,11 +16,13 @@
 
 #pragma once
 
+#include <raft/handle.hpp>
+#include <random>
+#include "common.h"
+
 namespace cuml {
 namespace genetic {
 
-struct node;
-
 /**
  * @brief The main data structure to store the AST that represents a program
  * in the current generation
@@ -36,12 +38,251 @@ struct program {
  * is assumed to be a zero-copy (aka pinned memory) buffer, atleast in
  * this initial version
  */
+
+  /**
+   * Default constructor
+   */
+  explicit program();
+
+  /**
+   * @brief Destroy the program object
+   */
+  ~program();
+
+  /**
+   * @brief Copy constructor for a new program object
+   *
+   * @param src
+   */
+  explicit program(const program& src);
+
+  /**
+   * @brief assignment operator
+   *
+   * @param[in] src source program to be copied
+   *
+   * @return current program reference
+   */
+  program& operator=(const program& src);
+
   node* nodes;
   /** total number of nodes in this AST */
   int len;
   /** maximum depth of this AST */
   int depth;
+  /** fitness score of current AST */
+  float raw_fitness_;
+  /** fitness metric used for current AST */
+  metric_t metric;
+  /** mutation type responsible for production */
+  mutation_t mut_type;
 };  // struct program
 
+/** program_t is a shorthand for device programs */
+typedef program* program_t;
+
+/**
+ * @brief Calls the execution kernel to evaluate all programs on the given dataset
+ *
+ * @param h cuML handle
+ * @param d_progs Device pointer to programs
+ * @param n_rows Number of rows in the input dataset
+ * @param n_progs Total number of programs being evaluated
+ * @param data Device pointer to input dataset (in col-major format)
+ * @param y_pred Device pointer to output of program evaluation
+ */
+void execute(const raft::handle_t& h,
+             const program_t& d_progs,
+             const int n_rows,
+             const int n_progs,
+             const float* data,
+             float* y_pred);
+
+/**
+ * @brief Compute the loss based on the metric specified in the training hyperparameters.
+ * It performs a batched computation for all programs in one shot.
+ *
+ * @param h cuML handle
+ * @param n_rows The number of labels/rows in the expected output
+ * @param n_progs The number of programs being batched
+ * @param y Device pointer to the expected output (SIZE = n_samples)
+ * @param y_pred Device pointer to the predicted output (SIZE = n_samples * n_progs)
+ * @param w Device pointer to sample weights (SIZE = n_samples)
+ * @param score Device pointer to final score (SIZE = n_progs)
+ * @param params Training hyperparameters
+ */
+void compute_metric(const raft::handle_t& h,
+                    int n_rows,
+                    int n_progs,
+                    const float* y,
+                    const float* y_pred,
+                    const float* w,
+                    float* score,
+                    const param& params);
+
+/**
+ * @brief Computes the fitness scores for a single program on the given dataset
+ *
+ * @param h cuML handle
+ * @param d_prog Device pointer to program
+ * @param score Device pointer to fitness vals
+ * @param params Training hyperparameters
+ * @param n_rows Number of rows in the input dataset
+ * @param data Device pointer to input dataset
+ * @param y Device pointer to input labels
+ * @param sample_weights Device pointer to sample weights
+ */
+void find_fitness(const raft::handle_t& h,
+                  program_t& d_prog,
+                  float* score,
+                  const param& params,
+                  const int n_rows,
+                  const float* data,
+                  const float* y,
+                  const float* sample_weights);
+
+/**
+ * @brief Computes the fitness scores for all programs on the given dataset
+ *
+ * @param h cuML handle
+ * @param n_progs Batch size (number of programs)
+ * @param d_progs Device pointer to list of programs
+ * @param score Device pointer to fitness vals computed for all programs
+ * @param params Training hyperparameters
+ * @param n_rows Number of rows in the input dataset
+ * @param data Device pointer to input dataset
+ * @param y Device pointer to input labels
+ * @param sample_weights Device pointer to sample weights
+ */
+void find_batched_fitness(const raft::handle_t& h,
+                          int n_progs,
+                          program_t& d_progs,
+                          float* score,
+                          const param& params,
+                          const int n_rows,
+                          const float* data,
+                          const float* y,
+                          const float* sample_weights);
+
+/**
+ * @brief Computes and sets the fitness scores for a single program on the given dataset
+ *
+ * @param h cuML handle
+ * @param d_prog Device pointer to program
+ * @param h_prog Host program object
+ * @param params Training hyperparameters
+ * @param n_rows Number of rows in the input dataset
+ * @param data Device pointer to input dataset
+ * @param y Device pointer to input labels
+ * @param sample_weights Device pointer to sample weights
+ */
+void set_fitness(const raft::handle_t& h,
+                 program_t& d_prog,
+                 program& h_prog,
+                 const param& params,
+                 const int n_rows,
+                 const float* data,
+                 const float* y,
+                 const float* sample_weights);
+
+/**
+ * @brief
Computes and sets the fitness scores for all programs on the given dataset
+ *
+ * @param h cuML handle
+ * @param n_progs Batch size
+ * @param d_progs Device pointer to list of programs
+ * @param h_progs Host vector of programs corresponding to d_progs
+ * @param params Training hyperparameters
+ * @param n_rows Number of rows in the input dataset
+ * @param data Device pointer to input dataset
+ * @param y Device pointer to input labels
+ * @param sample_weights Device pointer to sample weights
+ */
+void set_batched_fitness(const raft::handle_t& h,
+                         int n_progs,
+                         program_t& d_progs,
+                         std::vector<program>& h_progs,
+                         const param& params,
+                         const int n_rows,
+                         const float* data,
+                         const float* y,
+                         const float* sample_weights);
+
+/**
+ * @brief Returns the precomputed fitness score of a program on the host,
+ * after accounting for parsimony
+ *
+ * @param prog The host program
+ * @param params Training hyperparameters
+ * @return Fitness score corresponding to trained program
+ */
+float get_fitness(const program& prog, const param& params);
+
+/**
+ * @brief Evaluates and returns the depth of the current program.
+ *
+ * @param p_out The given program
+ * @return The depth of the current program
+ */
+int get_depth(const program& p_out);
+
+/**
+ * @brief Build a random program with depth at most 10
+ *
+ * @param p_out The output program
+ * @param params Training hyperparameters
+ * @param rng RNG to decide nodes to add
+ */
+void build_program(program& p_out, const param& params, std::mt19937& rng);
+
+/**
+ * @brief Perform a point mutation on the given program (AST)
+ *
+ * @param prog The input program
+ * @param p_out The result program
+ * @param params Training hyperparameters
+ * @param rng RNG to decide nodes to mutate
+ */
+void point_mutation(const program& prog, program& p_out, const param& params, std::mt19937& rng);
+
+/**
+ * @brief Perform a 'hoisted' crossover mutation using the parent and donor programs.
+ * The donor subtree selected is hoisted to enforce our constraints on total depth
+ *
+ * @param prog The input program
+ * @param donor The donor program
+ * @param p_out The result program
+ * @param params Training hyperparameters
+ * @param rng RNG for subtree selection
+ */
+void crossover(const program& prog,
+               const program& donor,
+               program& p_out,
+               const param& params,
+               std::mt19937& rng);
+
+/**
+ * @brief Performs a crossover mutation with a randomly built new program.
+ * Since crossover is 'hoisted', this will ensure that depth constraints
+ * are not violated.
+ *
+ * @param prog The input program
+ * @param p_out The result mutated program
+ * @param params Training hyperparameters
+ * @param rng RNG to control subtree selection and temporary program addition
+ */
+void subtree_mutation(const program& prog, program& p_out, const param& params, std::mt19937& rng);
+
+/**
+ * @brief Perform a hoist mutation on a random subtree of the given program
+ * (replace a subtree with a subtree of a subtree)
+ *
+ * @param prog The input program
+ * @param p_out The output program
+ * @param params Training hyperparameters
+ * @param rng RNG to control subtree selection
+ */
+void hoist_mutation(const program& prog, program& p_out, const param& params, std::mt19937& rng);
 }  // namespace genetic
 }  // namespace cuml
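Since `get_fitness` "accounts for parsimony", it helps to see the arithmetic used by the tournament kernel later in this patch: the penalty is `parsimony_coefficient * len * (2 * criterion - 1)`, subtracted from `raw_fitness_`, so bloat always hurts regardless of whether the metric is minimized or maximized. A host-side mirror with a worked example:

```cpp
// Host-side mirror of the adjusted fitness used by batched_tournament_kernel.
float adjusted_fitness(const cuml::genetic::program& p, const cuml::genetic::param& par)
{
  int criterion = par.criterion();  // 0: smaller is better, 1: larger is better
  float penalty = par.parsimony_coefficient * p.len * (2 * criterion - 1);
  return p.raw_fitness_ - penalty;
}
// e.g. mae (criterion 0), raw_fitness_ = 0.20, len = 13, parsimony = 0.01:
//      adjusted = 0.20 - 0.01 * 13 * (-1) = 0.33, i.e. worse, as intended.
```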
diff --git a/cpp/src/genetic/genetic.cuh b/cpp/src/genetic/constants.h
similarity index 52%
rename from cpp/src/genetic/genetic.cuh
rename to cpp/src/genetic/constants.h
index 6423dc2d62..5e793a6604 100644
--- a/cpp/src/genetic/genetic.cuh
+++ b/cpp/src/genetic/constants.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2021, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -14,29 +14,18 @@
 * limitations under the License.
 */
 
-#pragma once
+/** @file constants.h Common GPU functionality + constants for all operations */
 
-#include <cuml/genetic/genetic.h>
-#include <raft/cuda_utils.cuh>
+#pragma once
 
 namespace cuml {
 namespace genetic {
-namespace detail {
 
-HDI float p_reproduce(const param& p)
-{
-  auto sum = p.p_crossover + p.p_subtree_mutation + p.p_hoist_mutation + p.p_point_mutation;
-  auto ret = 1.f - sum;
-  return fmaxf(0.f, fminf(ret, 1.f));
-}
+// Max number of threads per block to use with tournament and evaluation kernels
+const int GENE_TPB = 256;
 
-HDI int max_programs(const param& p)
-{
-  // in the worst case every generation's top program ends up reproducing,
-  // thereby adding another program into the population
-  return p.population_size + p.generations;
-}
+// Max size of stack used for AST evaluation
+const int MAX_STACK_SIZE = 20;
 
-}  // namespace detail
 }  // namespace genetic
-}  // namespace cuml
+}  // namespace cuml
\ No newline at end of file
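`MAX_STACK_SIZE` bounds the explicit value stack used by the device evaluation kernels, which in turn bounds the depth of evaluable ASTs. As a rough host-side illustration of the idea (a sketch that assumes the prefix node ordering implied by `stringify` in `genetic.cu`; the real device kernel lives in `program.cu`, which is not shown in this diff, and its protected-operator semantics may differ):

```cpp
// Evaluate a prefix-ordered AST over scalars with an explicit value stack.
float eval_prefix(const cuml::genetic::node* nodes, int len, const float* x)
{
  float stack[20];  // mirrors MAX_STACK_SIZE
  int top = -1;
  for (int i = len - 1; i >= 0; --i) {  // prefix order -> walk backwards
    const auto& n = nodes[i];
    if (n.is_terminal()) {
      stack[++top] = (n.t == cuml::genetic::node::type::variable) ? x[n.u.fid] : n.u.val;
    } else {  // binary operators only, for brevity
      float a = stack[top--], b = stack[top--];
      switch (n.t) {
        case cuml::genetic::node::type::add: stack[++top] = a + b; break;
        case cuml::genetic::node::type::sub: stack[++top] = a - b; break;
        case cuml::genetic::node::type::mul: stack[++top] = a * b; break;
        default: stack[++top] = (b != 0.f) ? a / b : 1.f;  // assumed protected division
      }
    }
  }
  return stack[top];
}
```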
diff --git a/cpp/src/genetic/fitness.cuh b/cpp/src/genetic/fitness.cuh
new file mode 100644
index 0000000000..fa32e198c1
--- /dev/null
+++ b/cpp/src/genetic/fitness.cuh
@@ -0,0 +1,389 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cuml/genetic/common.h>
+#include <cuml/genetic/program.h>
+#include <raft/cuda_utils.cuh>
+#include <raft/cudart_utils.h>
+#include <raft/linalg/matrix_vector_op.cuh>
+#include <raft/linalg/strided_reduction.cuh>
+#include <raft/matrix/math.cuh>
+#include <raft/stats/mean.cuh>
+#include <raft/stats/mean_center.cuh>
+#include <raft/stats/sum.cuh>
+
+#include <rmm/device_scalar.hpp>
+#include <rmm/device_uvector.hpp>
+
+#include <thrust/adjacent_difference.h>
+#include <thrust/copy.h>
+#include <thrust/device_ptr.h>
+#include <thrust/device_vector.h>
+#include <thrust/execution_policy.h>
+#include <thrust/iterator/permutation_iterator.h>
+#include <thrust/scan.h>
+#include <thrust/sequence.h>
+#include <thrust/sort.h>
+
+#include <rmm/exec_policy.hpp>
+
+namespace cuml {
+namespace genetic {
+
+template <typename math_t>
+void weightedPearson(const raft::handle_t& h,
+                     const uint64_t n_samples,
+                     const uint64_t n_progs,
+                     const math_t* Y,
+                     const math_t* X,
+                     const math_t* W,
+                     math_t* out)
+{
+  // Find Pearson's correlation coefficient
+
+  cudaStream_t stream = h.get_stream();
+
+  rmm::device_uvector<math_t> corr(n_samples * n_progs, stream);
+
+  rmm::device_uvector<math_t> y_tmp(n_samples, stream);
+  rmm::device_uvector<math_t> x_tmp(n_samples * n_progs, stream);
+
+  rmm::device_scalar<math_t> y_mu(stream);            // output mean
+  rmm::device_uvector<math_t> x_mu(n_progs, stream);  // predicted output mean
+
+  rmm::device_uvector<math_t> y_diff(n_samples, stream);  // normalized output
+  rmm::device_uvector<math_t> x_diff(n_samples * n_progs,
+                                     stream);  // normalized predicted output
+
+  rmm::device_uvector<math_t> y_std(1, stream);  // output stddev
+  rmm::device_uvector<math_t> x_std(n_progs,
+                                    stream);  // predicted output stddev
+
+  rmm::device_scalar<math_t> dWS(stream);  // sample weight sum
+  math_t N = (math_t)n_samples;
+
+  // Sum of weights
+  raft::stats::sum(dWS.data(), W, (uint64_t)1, n_samples, false, stream);
+  math_t WS = dWS.value(stream);
+
+  // Find y_mu
+  raft::linalg::matrixVectorOp(
+    y_tmp.data(),
+    Y,
+    W,
+    (uint64_t)1,
+    n_samples,
+    false,
+    false,
+    [N, WS] __device__(math_t y, math_t w) { return N * w * y / WS; },
+    stream);
+
+  raft::stats::mean(y_mu.data(), y_tmp.data(), (uint64_t)1, n_samples, false, false, stream);
+
+  // Find x_mu
+  raft::linalg::matrixVectorOp(
+    x_tmp.data(),
+    X,
+    W,
+    n_progs,
+    n_samples,
+    false,
+    true,
+    [N, WS] __device__(math_t x, math_t w) { return N * w * x / WS; },
+    stream);
+
+  raft::stats::mean(x_mu.data(), x_tmp.data(), n_progs, n_samples, false, false, stream);
+
+  // Find y_diff
+  raft::stats::meanCenter(
+    y_diff.data(), Y, y_mu.data(), (uint64_t)1, n_samples, false, true, stream);
+
+  // Find x_diff
+  raft::stats::meanCenter(x_diff.data(), X, x_mu.data(), n_progs, n_samples, false, true, stream);
+
+  // Find y_std
+  raft::linalg::stridedReduction(
+    y_std.data(),
+    y_diff.data(),
+    (uint64_t)1,
+    n_samples,
+    (math_t)0,
+    stream,
+    false,
+    [W] __device__(math_t v, int i) { return v * v * W[i]; },
+    raft::Sum<math_t>(),
+    [] __device__(math_t in) { return raft::mySqrt(in); });
+  math_t HYstd = y_std.element(0, stream);
+
+  // Find x_std
+  raft::linalg::stridedReduction(
+    x_std.data(),
+    x_diff.data(),
+    n_progs,
+    n_samples,
+    (math_t)0,
+    stream,
+    false,
+    [W] __device__(math_t v, int i) { return v * v * W[i]; },
+    raft::Sum<math_t>(),
+    [] __device__(math_t in) { return raft::mySqrt(in); });
+
+  // Cross covariance
+  raft::linalg::matrixVectorOp(
+    corr.data(),
+    x_diff.data(),
+    y_diff.data(),
+    W,
+    n_progs,
+    n_samples,
+    false,
+    false,
+    [N, HYstd] __device__(math_t xd, math_t yd, math_t w) { return N * w * xd * yd / HYstd; },
+    stream);
+
+  // Find Correlation coeff
+  raft::linalg::matrixVectorOp(
+    corr.data(),
+    corr.data(),
+    x_std.data(),
+    n_progs,
+    n_samples,
+    false,
+    true,
+    [] __device__(math_t c, math_t xd) { return c / xd; },
+    stream);
+
+  raft::stats::mean(out, corr.data(), n_progs, n_samples, false, false, stream);
+}
+
+struct rank_functor {
+  template <typename math_t>
+  __host__ __device__ math_t operator()(math_t data)
+  {
+    if (data == 0)
+      return 0;
+    else
+      return 1;
+  }
+};
+
+template <typename math_t>
+void weightedSpearman(const raft::handle_t& h,
+                      const uint64_t n_samples,
+                      const uint64_t n_progs,
+                      const math_t* Y,
+                      const math_t* Y_pred,
+                      const math_t* W,
+                      math_t* out)
+{
+  cudaStream_t stream = h.get_stream();
+
+  // Get ranks for Y
+  thrust::device_vector<math_t> Ycopy(Y, Y + n_samples);
+  thrust::device_vector<int> rank_idx(n_samples, 0);
+  thrust::device_vector<math_t> rank_diff(n_samples, 0);
+  thrust::device_vector<math_t> Yrank(n_samples, 0);
+
+  auto exec_policy = rmm::exec_policy(stream);
+
+  thrust::sequence(exec_policy, rank_idx.begin(), rank_idx.end(), 0);
+  thrust::sort_by_key(exec_policy, Ycopy.begin(), Ycopy.end(), rank_idx.begin());
+  thrust::adjacent_difference(exec_policy, Ycopy.begin(), Ycopy.end(), rank_diff.begin());
+  thrust::transform(
+    exec_policy, rank_diff.begin(), rank_diff.end(), rank_diff.begin(), rank_functor());
+  rank_diff[0] = 1;
+  thrust::inclusive_scan(exec_policy, rank_diff.begin(), rank_diff.end(), rank_diff.begin());
+  thrust::copy(rank_diff.begin(),
+               rank_diff.end(),
+               thrust::make_permutation_iterator(Yrank.begin(), rank_idx.begin()));
+
+  // Get ranks for Y_pred
+  // TODO: Find a better way to batch this
+  thrust::device_vector<math_t> Ypredcopy(Y_pred, Y_pred + n_samples * n_progs);
+  thrust::device_vector<math_t> Ypredrank(n_samples * n_progs, 0);
+  thrust::device_ptr<math_t> Ypredptr     = thrust::device_pointer_cast(Ypredcopy.data());
+  thrust::device_ptr<math_t> Ypredrankptr = thrust::device_pointer_cast(Ypredrank.data());
+
+  for (std::size_t i = 0; i < n_progs; ++i) {
+    thrust::sequence(exec_policy, rank_idx.begin(), rank_idx.end(), 0);
+    thrust::sort_by_key(
+      exec_policy, Ypredptr + (i * n_samples), Ypredptr + ((i + 1) * n_samples), rank_idx.begin());
+    thrust::adjacent_difference(
+      exec_policy, Ypredptr + (i * n_samples), Ypredptr + ((i + 1) * n_samples), rank_diff.begin());
+    thrust::transform(
+      exec_policy, rank_diff.begin(), rank_diff.end(), rank_diff.begin(), rank_functor());
+    rank_diff[0] = 1;
+    thrust::inclusive_scan(exec_policy, rank_diff.begin(), rank_diff.end(), rank_diff.begin());
+    thrust::copy(
+      rank_diff.begin(),
+      rank_diff.end(),
+      thrust::make_permutation_iterator(Ypredrankptr + (i * n_samples), rank_idx.begin()));
+  }
+
+  // Compute Pearson's coefficient on the ranks
+  weightedPearson(h,
+                  n_samples,
+                  n_progs,
+                  thrust::raw_pointer_cast(Yrank.data()),
+                  thrust::raw_pointer_cast(Ypredrank.data()),
+                  W,
+                  out);
+}
+
+template <typename math_t>
+void meanAbsoluteError(const raft::handle_t& h,
+                       const uint64_t n_samples,
+                       const uint64_t n_progs,
+                       const math_t* Y,
+                       const math_t* Y_pred,
+                       const math_t* W,
+                       math_t* out)
+{
+  cudaStream_t stream = h.get_stream();
+  rmm::device_uvector<math_t> error(n_samples * n_progs, stream);
+  rmm::device_scalar<math_t> dWS(stream);
+  math_t N = (math_t)n_samples;
+
+  // Weight Sum
+  raft::stats::sum(dWS.data(), W, (uint64_t)1, n_samples, false, stream);
+  math_t WS = dWS.value(stream);
+
+  // Compute absolute differences
+  raft::linalg::matrixVectorOp(
+    error.data(),
+    Y_pred,
+    Y,
+    W,
+    n_progs,
+    n_samples,
+    false,
+    false,
+    [N, WS] __device__(math_t y_p, math_t y, math_t w) {
+      return N * w * raft::myAbs(y - y_p) / WS;
+    },
+    stream);
+
+  // Average along rows
+  raft::stats::mean(out, error.data(), n_progs, n_samples, false, false, stream);
+}
+
+template <typename math_t>
+void meanSquareError(const raft::handle_t& h,
+                     const uint64_t n_samples,
+                     const uint64_t n_progs,
+                     const math_t* Y,
+                     const math_t* Y_pred,
+                     const math_t* W,
+                     math_t* out)
+{
+  cudaStream_t stream = h.get_stream();
+  rmm::device_uvector<math_t> error(n_samples * n_progs, stream);
+  rmm::device_scalar<math_t> dWS(stream);
+  math_t N = (math_t)n_samples;
+
+  // Weight Sum
+  raft::stats::sum(dWS.data(), W, (uint64_t)1, n_samples, false, stream);
+  math_t WS = dWS.value(stream);
+
+  // Compute square differences
+  raft::linalg::matrixVectorOp(
+    error.data(),
+    Y_pred,
+    Y,
+    W,
+    n_progs,
+    n_samples,
+    false,
+    false,
+    [N, WS] __device__(math_t y_p, math_t y, math_t w) {
+      return N * w * (y_p - y) * (y_p - y) / WS;
+    },
+    stream);
+
+  // Average along rows
+  raft::stats::mean(out, error.data(), n_progs, n_samples, false, false, stream);
+}
+
+template <typename math_t>
+void rootMeanSquareError(const raft::handle_t& h,
+                         const uint64_t n_samples,
+                         const uint64_t n_progs,
+                         const math_t* Y,
+                         const math_t* Y_pred,
+                         const math_t* W,
+                         math_t* out)
+{
+  cudaStream_t stream = h.get_stream();
+
+  // Find MSE
+  meanSquareError(h, n_samples, n_progs, Y, Y_pred, W, out);
+
+  // Take sqrt on all entries
+  raft::matrix::seqRoot(out, n_progs, stream);
+}
+
+template <typename math_t>
+void logLoss(const raft::handle_t& h,
+             const uint64_t n_samples,
+             const uint64_t n_progs,
+             const math_t* Y,
+             const math_t* Y_pred,
+             const math_t* W,
+             math_t* out)
+{
+  cudaStream_t stream = h.get_stream();
+  // Logistic error per sample
+  rmm::device_uvector<math_t> error(n_samples * n_progs, stream);
+  rmm::device_scalar<math_t> dWS(stream);
+  math_t N = (math_t)n_samples;
+
+  // Weight Sum
+  raft::stats::sum(dWS.data(), W, (uint64_t)1, n_samples, false, stream);
+  math_t WS = dWS.value(stream);
+
+  // Compute logistic loss as described in
+  // http://fa.bianp.net/blog/2019/evaluate_logistic/
+  // in an attempt to avoid encountering nan values. Modified for weighted logistic regression.
+  raft::linalg::matrixVectorOp(
+    error.data(),
+    Y_pred,
+    Y,
+    W,
+    n_progs,
+    n_samples,
+    false,
+    false,
+    [N, WS] __device__(math_t yp, math_t y, math_t w) {
+      math_t logsig;
+      if (yp < -33.3)
+        logsig = yp;
+      else if (yp <= -18)
+        logsig = yp - expf(yp);
+      else if (yp <= 37)
+        logsig = -log1pf(expf(-yp));
+      else
+        logsig = -expf(-yp);
+
+      return ((1 - y) * yp - logsig) * (N * w / WS);
+    },
+    stream);
+
+  // Take average along rows
+  raft::stats::mean(out, error.data(), n_progs, n_samples, false, false, stream);
+}
+
+}  // namespace genetic
+}  // namespace cuml
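All of the loss kernels above share one pattern: each per-sample term is scaled by `N * w_i / WS` and then row-averaged, which collapses to the familiar weighted mean `sum(w_i * e_i) / sum(w_i)`. A host-side cross-check sketch for the MAE case (not part of the patch):

```cpp
#include <cmath>
#include <vector>

// Host reference for meanAbsoluteError: sum(w_i * |y_i - yp_i|) / sum(w_i)
float weighted_mae_ref(const std::vector<float>& y,
                       const std::vector<float>& yp,
                       const std::vector<float>& w)
{
  float num = 0.f, den = 0.f;
  for (std::size_t i = 0; i < y.size(); ++i) {
    num += w[i] * std::fabs(y[i] - yp[i]);
    den += w[i];
  }
  return num / den;
}
```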
diff --git a/cpp/src/genetic/genetic.cu b/cpp/src/genetic/genetic.cu
index fa9dba9987..3c977d244c 100644
--- a/cpp/src/genetic/genetic.cu
+++ b/cpp/src/genetic/genetic.cu
@@ -14,13 +14,559 @@
 * limitations under the License.
 */
 
-#include "genetic.cuh"
+#include <cuml/genetic/common.h>
+#include <cuml/genetic/genetic.h>
+#include <cuml/genetic/node.h>
+#include <cuml/genetic/program.h>
+#include "constants.h"
+#include "node.cuh"
+
+#include <raft/cudart_utils.h>
+#include <raft/random/rng.cuh>
+#include <raft/random/rng_impl.cuh>
+#include <rmm/device_uvector.hpp>
+#include <rmm/mr/device/per_device_resource.hpp>
+
+#include <numeric>
+#include <random>
+#include <stack>
+#include <string>
+
+#include <climits>
+#include <cstdint>
 
 namespace cuml {
 namespace genetic {
 
-float param::p_reproduce() const { return detail::p_reproduce(*this); }
+/**
+ * @brief Simultaneously execute tournaments for all programs.
+ * The fitness values being compared are adjusted for bloat (program length),
+ * using the given parsimony coefficient.
+ *
+ * @param progs Device pointer to programs
+ * @param win_indices Winning indices for every tournament
+ * @param seeds Init seeds for choice selection
+ * @param n_progs Number of programs
+ * @param n_tours No of tournaments to be conducted
+ * @param tour_size No of programs considered per tournament (@c <= n_progs)
+ * @param criterion Selection criterion for choices (min/max)
+ * @param parsimony Parsimony coefficient to account for bloat
+ */
+__global__ void batched_tournament_kernel(const program_t progs,
+                                          int* win_indices,
+                                          const int* seeds,
+                                          const int n_progs,
+                                          const int n_tours,
+                                          const int tour_size,
+                                          const int criterion,
+                                          const float parsimony)
+{
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= n_tours) return;
+
+  raft::random::detail::PhiloxGenerator rng(seeds[idx], idx, 0);
+
+  int r;
+  rng.next(r);
+
+  // Define optima values
+  int opt           = r % n_progs;
+  float opt_penalty = parsimony * progs[opt].len * (2 * criterion - 1);
+  float opt_score   = progs[opt].raw_fitness_ - opt_penalty;
+
+  for (int s = 1; s < tour_size; ++s) {
+    rng.next(r);
+    int curr           = r % n_progs;
+    float curr_penalty = parsimony * progs[curr].len * (2 * criterion - 1);
+    float curr_score   = progs[curr].raw_fitness_ - curr_penalty;
+
+    // Eliminate thread divergence - b takes values in {0,1}
+    // All threads have the same criterion but mostly have different 'b'
+    int b = (opt_score < curr_score);
+    if (criterion) {
+      opt         = (1 - b) * opt + b * curr;
+      opt_penalty = (1 - b) * opt_penalty + b * curr_penalty;
+      opt_score   = (1 - b) * opt_score + b * curr_score;
+    } else {
+      opt         = b * opt + (1 - b) * curr;
+      opt_penalty = b * opt_penalty + (1 - b) * curr_penalty;
+      opt_score   = b * opt_score + (1 - b) * curr_score;
+    }
+  }
+
+  // Set win index
+  win_indices[idx] = opt;
+}
+
+/**
+ * @brief Driver function for evolving a generation of programs
+ *
+ * @param h cuML handle
+ * @param h_oldprogs previous generation host programs
+ * @param d_oldprogs previous generation device programs
+ * @param h_nextprogs next generation host programs
+ * @param d_nextprogs next generation device programs
+ * @param n_samples No of samples in input dataset
+ * @param data Device pointer to input dataset
+ * @param y Device pointer to input predictions
+ * @param sample_weights Device pointer to input weights
+ * @param params Training hyperparameters
+ * @param generation Current generation id
+ * @param seed Random seed for generators
+ */
+void parallel_evolve(const raft::handle_t& h,
+                     const std::vector<program>& h_oldprogs,
+                     const program_t& d_oldprogs,
+                     std::vector<program>& h_nextprogs,
+                     program_t& d_nextprogs,
+                     const int n_samples,
+                     const float* data,
+                     const float* y,
+                     const float* sample_weights,
+                     const param& params,
+                     const int generation,
+                     const int seed)
+{
+  cudaStream_t stream = h.get_stream();
+  auto n_progs        = params.population_size;
+  auto tour_size      = params.tournament_size;
+  auto n_tours        = n_progs;  // at least num_progs tournaments
+
+  // Seed engines
+  std::mt19937 h_gen(seed);    // CPU rng
+  raft::random::Rng d_gen(seed);  // GPU rng
+
+  std::uniform_real_distribution<float> dist_U(0.0f, 1.0f);
+
+  // Build, Mutate and Run Tournaments
+
+  if (generation == 1) {
+    // Build random programs for the first generation
+    for (auto i = 0; i < n_progs; ++i) {
+      build_program(h_nextprogs[i], params, h_gen);
+    }
+
+  } else {
+    // Set mutation type
+    float mut_probs[4];
+    mut_probs[0] = params.p_crossover;
+    mut_probs[1] = params.p_subtree_mutation;
+    mut_probs[2] = params.p_hoist_mutation;
+    mut_probs[3] = params.p_point_mutation;
params.p_point_mutation; + std::partial_sum(mut_probs, mut_probs + 4, mut_probs); + + for (auto i = 0; i < n_progs; ++i) { + float prob = dist_U(h_gen); + + if (prob < mut_probs[0]) { + h_nextprogs[i].mut_type = mutation_t::crossover; + n_tours++; + } else if (prob < mut_probs[1]) { + h_nextprogs[i].mut_type = mutation_t::subtree; + } else if (prob < mut_probs[2]) { + h_nextprogs[i].mut_type = mutation_t::hoist; + } else if (prob < mut_probs[3]) { + h_nextprogs[i].mut_type = mutation_t::point; + } else { + h_nextprogs[i].mut_type = mutation_t::reproduce; + } + } + + // Run tournaments + rmm::device_uvector tour_seeds(n_tours, stream); + rmm::device_uvector d_win_indices(n_tours, stream); + d_gen.uniformInt(tour_seeds.data(), n_tours, 1, INT_MAX, stream); + + auto criterion = params.criterion(); + dim3 nblks(raft::ceildiv(n_tours, GENE_TPB), 1, 1); + batched_tournament_kernel<<>>(d_oldprogs, + d_win_indices.data(), + tour_seeds.data(), + n_progs, + n_tours, + tour_size, + criterion, + params.parsimony_coefficient); + + CUDA_CHECK(cudaPeekAtLastError()); + + // Make sure tournaments have finished running before copying win indices + CUDA_CHECK(cudaStreamSynchronize(stream)); + + // Perform host mutations + + auto donor_pos = n_progs; + for (auto pos = 0; pos < n_progs; ++pos) { + auto parent_index = d_win_indices.element(pos, stream); + + if (h_nextprogs[pos].mut_type == mutation_t::crossover) { + // Get secondary index + auto donor_index = d_win_indices.element(donor_pos, stream); + donor_pos++; + crossover( + h_oldprogs[parent_index], h_oldprogs[donor_index], h_nextprogs[pos], params, h_gen); + } else if (h_nextprogs[pos].mut_type == mutation_t::subtree) { + subtree_mutation(h_oldprogs[parent_index], h_nextprogs[pos], params, h_gen); + } else if (h_nextprogs[pos].mut_type == mutation_t::hoist) { + hoist_mutation(h_oldprogs[parent_index], h_nextprogs[pos], params, h_gen); + } else if (h_nextprogs[pos].mut_type == mutation_t::point) { + point_mutation(h_oldprogs[parent_index], h_nextprogs[pos], params, h_gen); + } else if (h_nextprogs[pos].mut_type == mutation_t::reproduce) { + h_nextprogs[pos] = h_oldprogs[parent_index]; + } else { + // Should not come here + } + } + } + + /* Memcpy individual host nodes to device and destroy previous generation device nodes + TODO: Find a better way to do this. 
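+     Each host program is deep-copied into a temporary whose nodes pointer is then
+     re-seated to a fresh rmm device allocation; the node array and the struct are
+     copied to the device, and the previous generation's device nodes are freed.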
*/ + for (auto i = 0; i < n_progs; ++i) { + program tmp(h_nextprogs[i]); + delete[] tmp.nodes; + + // Set current generation device nodes + tmp.nodes = (node*)rmm::mr::get_current_device_resource()->allocate( + h_nextprogs[i].len * sizeof(node), stream); + raft::copy(tmp.nodes, h_nextprogs[i].nodes, h_nextprogs[i].len, stream); + raft::copy(d_nextprogs + i, &tmp, 1, stream); + + if (generation > 1) { + // Free device memory allocated to program nodes in previous generation + raft::copy(&tmp, d_oldprogs + i, 1, stream); + rmm::mr::get_current_device_resource()->deallocate( + tmp.nodes, h_nextprogs[i].len * sizeof(node), stream); + } + + tmp.nodes = nullptr; + } + + // Make sure all copying is done + CUDA_CHECK(cudaStreamSynchronize(stream)); + + // Update raw fitness for all programs + set_batched_fitness( + h, n_progs, d_nextprogs, h_nextprogs, params, n_samples, data, y, sample_weights); +} + +float param::p_reproduce() const +{ + auto sum = + this->p_crossover + this->p_subtree_mutation + this->p_hoist_mutation + this->p_point_mutation; + auto ret = 1.f - sum; + return fmaxf(0.f, fminf(ret, 1.f)); +} + +int param::max_programs() const +{ + // in the worst case every generation's top program ends up reproducing, + // thereby adding another program into the population + return this->population_size + this->generations; +} + +int param::criterion() const +{ + // Returns 0 if a smaller value is preferred and 1 for the opposite + switch (this->metric) { + case metric_t::mse: return 0; + case metric_t::logloss: return 0; + case metric_t::mae: return 0; + case metric_t::rmse: return 0; + case metric_t::pearson: return 1; + case metric_t::spearman: return 1; + default: return -1; + } +} + +std::string stringify(const program& prog) +{ + std::string eqn = "( "; + std::string delim = ""; + std::stack ar_stack; + ar_stack.push(0); + + for (int i = 0; i < prog.len; ++i) { + if (prog.nodes[i].is_terminal()) { + eqn += delim; + if (prog.nodes[i].t == node::type::variable) { + // variable + eqn += "X"; + eqn += std::to_string(prog.nodes[i].u.fid); + } else { + // const + eqn += std::to_string(prog.nodes[i].u.val); + } + + int end_elem = ar_stack.top(); + ar_stack.pop(); + ar_stack.push(end_elem - 1); + while (ar_stack.top() == 0) { + ar_stack.pop(); + eqn += ") "; + if (ar_stack.empty()) { break; } + end_elem = ar_stack.top(); + ar_stack.pop(); + ar_stack.push(end_elem - 1); + } + delim = ", "; + } else { + ar_stack.push(prog.nodes[i].arity()); + eqn += delim; + switch (prog.nodes[i].t) { + // binary operators + case node::type::add: eqn += "add("; break; + case node::type::atan2: eqn += "atan2("; break; + case node::type::div: eqn += "div("; break; + case node::type::fdim: eqn += "fdim("; break; + case node::type::max: eqn += "max("; break; + case node::type::min: eqn += "min("; break; + case node::type::mul: eqn += "mult("; break; + case node::type::pow: eqn += "pow("; break; + case node::type::sub: eqn += "sub("; break; + // unary operators + case node::type::abs: eqn += "abs("; break; + case node::type::acos: eqn += "acos("; break; + case node::type::acosh: eqn += "acosh("; break; + case node::type::asin: eqn += "asin("; break; + case node::type::asinh: eqn += "asinh("; break; + case node::type::atan: eqn += "atan("; break; + case node::type::atanh: eqn += "atanh("; break; + case node::type::cbrt: eqn += "cbrt("; break; + case node::type::cos: eqn += "cos("; break; + case node::type::cosh: eqn += "cosh("; break; + case node::type::cube: eqn += "cube("; break; + case node::type::exp: eqn += "exp("; 
break;
+        case node::type::inv: eqn += "inv("; break;
+        case node::type::log: eqn += "log("; break;
+        case node::type::neg: eqn += "neg("; break;
+        case node::type::rcbrt: eqn += "rcbrt("; break;
+        case node::type::rsqrt: eqn += "rsqrt("; break;
+        case node::type::sin: eqn += "sin("; break;
+        case node::type::sinh: eqn += "sinh("; break;
+        case node::type::sq: eqn += "sq("; break;
+        case node::type::sqrt: eqn += "sqrt("; break;
+        case node::type::tan: eqn += "tan("; break;
+        case node::type::tanh: eqn += "tanh("; break;
+        default: break;
+      }
+      eqn += " ";
+      delim = "";
+    }
+  }
+
+  eqn += ")";
+  return eqn;
+}
+
+void symFit(const raft::handle_t& handle,
+            const float* input,
+            const float* labels,
+            const float* sample_weights,
+            const int n_rows,
+            const int n_cols,
+            param& params,
+            program_t& final_progs,
+            std::vector<std::vector<program>>& history)
+{
+  cudaStream_t stream = handle.get_stream();
+
+  // Update arity map in params - need to do this only here, as all operations
+  // will call Fit at least once
+  for (auto f : params.function_set) {
+    int ar = 1;
+    if (node::type::binary_begin <= f && f <= node::type::binary_end) { ar = 2; }
+
+    if (params.arity_set.find(ar) == params.arity_set.end()) {
+      // Create map entry for current arity
+      std::vector<node::type> vec_f(1, f);
+      params.arity_set.insert(std::make_pair(ar, vec_f));
+    } else {
+      // Insert into map
+      std::vector<node::type> vec_f = params.arity_set.at(ar);
+      if (std::find(vec_f.begin(), vec_f.end(), f) == vec_f.end()) {
+        params.arity_set.at(ar).push_back(f);
+      }
+    }
+  }
+
+  // Check terminalRatio to dynamically set it
+  bool growAuto = (params.terminalRatio == 0.0f);
+  if (growAuto) {
+    params.terminalRatio =
+      1.0f * params.num_features / (params.num_features + params.function_set.size());
+  }
+
+  /* Initializations */
+
+  std::vector<program> h_currprogs(params.population_size);
+  std::vector<program> h_nextprogs(params.population_size);
+
+  std::vector<float> h_fitness(params.population_size, 0.0f);
+
+  program_t d_currprogs;  // pointer to current programs
+  d_currprogs = (program_t)rmm::mr::get_current_device_resource()->allocate(
+    params.population_size * sizeof(program), stream);
+  program_t d_nextprogs = final_progs;  // Reuse memory already allocated for final_progs
+  final_progs = nullptr;
+
+  std::mt19937_64 h_gen_engine(params.random_state);
+  std::uniform_int_distribution<int> seed_dist;
+
+  /* Begin training */
+  auto gen = 0;
+  params.num_epochs = 0;
+
+  while (gen < params.generations) {
+    // Generate an init seed
+    auto init_seed = seed_dist(h_gen_engine);
+
+    // Evolve current generation
+    parallel_evolve(handle,
+                    h_currprogs,
+                    d_currprogs,
+                    h_nextprogs,
+                    d_nextprogs,
+                    n_rows,
+                    input,
+                    labels,
+                    sample_weights,
+                    params,
+                    (gen + 1),
+                    init_seed);
+
+    // Update epochs
+    ++params.num_epochs;
+
+    // Update h_currprogs (deep copy)
+    h_currprogs = h_nextprogs;
+
+    // Update evolution history, depending on the low memory flag
+    if (!params.low_memory || gen == 0) {
+      history.push_back(h_currprogs);
+    } else {
+      history.back() = h_currprogs;
+    }
+
+    // Swap d_currprogs (to preserve device memory)
+    program_t d_tmp = d_currprogs;
+    d_currprogs = d_nextprogs;
+    d_nextprogs = d_tmp;
+
+    // Update fitness array [host] and compute stopping criterion
+    auto crit = params.criterion();
+    h_fitness[0] = h_currprogs[0].raw_fitness_;
+    auto opt_fit = h_fitness[0];
+
+    for (auto i = 1; i < params.population_size; ++i) {
+      h_fitness[i] = h_currprogs[i].raw_fitness_;
+
+      if (crit == 0) {
+        opt_fit = std::min(opt_fit, h_fitness[i]);
+      } else {
+        opt_fit = std::max(opt_fit, h_fitness[i]);
+      }
} + + // Check for stop criterion + if ((crit == 0 && opt_fit <= params.stopping_criteria) || + (crit == 1 && opt_fit >= params.stopping_criteria)) { + CUML_LOG_DEBUG( + "Early stopping criterion reached in Generation #%d, fitness=%f", (gen + 1), opt_fit); + break; + } + + // Update generation + ++gen; + } + + // Set final generation programs + final_progs = d_currprogs; + + // Reset automatic growth parameter + if (growAuto) { params.terminalRatio = 0.0f; } + + // Deallocate the previous generation device memory + rmm::mr::get_current_device_resource()->deallocate( + d_nextprogs, params.population_size * sizeof(program), stream); + d_currprogs = nullptr; + d_nextprogs = nullptr; +} + +void symRegPredict(const raft::handle_t& handle, + const float* input, + const int n_rows, + const program_t& best_prog, + float* output) +{ + // Assume best_prog is on device + execute(handle, best_prog, n_rows, 1, input, output); +} + +void symClfPredictProbs(const raft::handle_t& handle, + const float* input, + const int n_rows, + const param& params, + const program_t& best_prog, + float* output) +{ + cudaStream_t stream = handle.get_stream(); + + // Assume output is of shape [n_rows, 2] in colMajor format + execute(handle, best_prog, n_rows, 1, input, output); + + // Apply 2 map operations to get probabilities! + // TODO: Modification needed for n_classes + if (params.transformer == transformer_t::sigmoid) { + raft::linalg::unaryOp( + output + n_rows, + output, + n_rows, + [] __device__(float in) { return 1.0f / (1.0f + expf(-in)); }, + stream); + raft::linalg::unaryOp( + output, output + n_rows, n_rows, [] __device__(float in) { return 1.0f - in; }, stream); + } else { + // Only sigmoid supported for now + } +} + +void symClfPredict(const raft::handle_t& handle, + const float* input, + const int n_rows, + const param& params, + const program_t& best_prog, + float* output) +{ + cudaStream_t stream = handle.get_stream(); + + // Memory for probabilities + rmm::device_uvector probs(2 * n_rows, stream); + symClfPredictProbs(handle, input, n_rows, params, best_prog, probs.data()); + + // Take argmax along columns + // TODO: Further modification needed for n_classes + raft::linalg::binaryOp( + output, + probs.data(), + probs.data() + n_rows, + n_rows, + [] __device__(float p0, float p1) { return 1.0f * (p0 <= p1); }, + stream); +} -int param::max_programs() const { return detail::max_programs(*this); } +void symTransform(const raft::handle_t& handle, + const float* input, + const param& params, + const program_t& final_progs, + const int n_rows, + const int n_cols, + float* output) +{ + cudaStream_t stream = handle.get_stream(); + // Execute final_progs(ordered by fitness) on input + // output of size [n_rows,hall_of_fame] + execute(handle, final_progs, n_rows, params.n_components, input, output); +} } // namespace genetic } // namespace cuml diff --git a/cpp/src/genetic/node.cu b/cpp/src/genetic/node.cu index 8884763f81..fb7f79020f 100644 --- a/cpp/src/genetic/node.cu +++ b/cpp/src/genetic/node.cu @@ -22,6 +22,8 @@ namespace genetic { const int node::kInvalidFeatureId = -1; +node::node() {} + node::node(node::type ft) : t(ft) { ASSERT(is_nonterminal(), "node: ctor with `type` argument expects functions type only!"); diff --git a/cpp/src/genetic/node.cuh b/cpp/src/genetic/node.cuh index ac4f49101d..b999940b41 100644 --- a/cpp/src/genetic/node.cuh +++ b/cpp/src/genetic/node.cuh @@ -40,52 +40,52 @@ HDI int arity(node::type t) } // `data` assumed to be stored in col-major format -DI float evaluate_node(const 
node& n, const float* data, size_t stride, float inval, float inval1) +DI float evaluate_node( + const node& n, const float* data, const uint64_t stride, const uint64_t idx, const float* in) { if (n.t == node::type::constant) { return n.u.val; } else if (n.t == node::type::variable) { - return n.u.fid != node::kInvalidFeatureId ? data[n.u.fid * stride] : 0.f; + return data[(stride * n.u.fid) + idx]; } else { - auto abs_inval = fabsf(inval), abs_inval1 = fabsf(inval1); - auto small = abs_inval < MIN_VAL; + auto abs_inval = fabsf(in[0]), abs_inval1 = fabsf(in[1]); // note: keep the case statements in alphabetical order under each category // of operators. switch (n.t) { // binary operators - case node::type::add: return inval + inval1; - case node::type::atan2: return atan2f(inval, inval1); - case node::type::div: return abs_inval1 < MIN_VAL ? 1.f : fdividef(inval, inval1); - case node::type::fdim: return fdimf(inval, inval1); - case node::type::max: return fmaxf(inval, inval1); - case node::type::min: return fminf(inval, inval1); - case node::type::mul: return inval * inval1; - case node::type::pow: return powf(inval, inval1); - case node::type::sub: return inval - inval1; + case node::type::add: return in[0] + in[1]; + case node::type::atan2: return atan2f(in[0], in[1]); + case node::type::div: return abs_inval1 < MIN_VAL ? 1.0f : fdividef(in[0], in[1]); + case node::type::fdim: return fdimf(in[0], in[1]); + case node::type::max: return fmaxf(in[0], in[1]); + case node::type::min: return fminf(in[0], in[1]); + case node::type::mul: return in[0] * in[1]; + case node::type::pow: return powf(in[0], in[1]); + case node::type::sub: return in[0] - in[1]; // unary operators case node::type::abs: return abs_inval; - case node::type::acos: return acosf(inval); - case node::type::acosh: return acoshf(inval); - case node::type::asin: return asinf(inval); - case node::type::asinh: return asinhf(inval); - case node::type::atan: return atanf(inval); - case node::type::atanh: return atanhf(inval); - case node::type::cbrt: return cbrtf(inval); - case node::type::cos: return cosf(inval); - case node::type::cosh: return coshf(inval); - case node::type::cube: return inval * inval * inval; - case node::type::exp: return expf(inval); - case node::type::inv: return abs_inval < MIN_VAL ? 0.f : 1.f / inval; + case node::type::acos: return acosf(in[0]); + case node::type::acosh: return acoshf(in[0]); + case node::type::asin: return asinf(in[0]); + case node::type::asinh: return asinhf(in[0]); + case node::type::atan: return atanf(in[0]); + case node::type::atanh: return atanhf(in[0]); + case node::type::cbrt: return cbrtf(in[0]); + case node::type::cos: return cosf(in[0]); + case node::type::cosh: return coshf(in[0]); + case node::type::cube: return in[0] * in[0] * in[0]; + case node::type::exp: return expf(in[0]); + case node::type::inv: return abs_inval < MIN_VAL ? 0.f : 1.f / in[0]; case node::type::log: return abs_inval < MIN_VAL ? 
0.f : logf(abs_inval); - case node::type::neg: return -inval; - case node::type::rcbrt: return rcbrtf(inval); + case node::type::neg: return -in[0]; + case node::type::rcbrt: return rcbrtf(in[0]); case node::type::rsqrt: return rsqrtf(abs_inval); - case node::type::sin: return sinf(inval); - case node::type::sinh: return sinhf(inval); - case node::type::sq: return inval * inval; + case node::type::sin: return sinf(in[0]); + case node::type::sinh: return sinhf(in[0]); + case node::type::sq: return in[0] * in[0]; case node::type::sqrt: return sqrtf(abs_inval); - case node::type::tan: return tanf(inval); - case node::type::tanh: return tanhf(inval); + case node::type::tan: return tanf(in[0]); + case node::type::tanh: return tanhf(in[0]); // shouldn't reach here! default: return 0.f; }; diff --git a/cpp/src/genetic/program.cu b/cpp/src/genetic/program.cu new file mode 100644 index 0000000000..9f2a104430 --- /dev/null +++ b/cpp/src/genetic/program.cu @@ -0,0 +1,563 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "constants.h" +#include "fitness.cuh" +#include "node.cuh" +#include "reg_stack.cuh" + +namespace cuml { +namespace genetic { + +/** + * Execution kernel for a single program. We assume that the input data + * is stored in column major format. 
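+ *
+ * Each thread evaluates one (program, row) pair: blockIdx.y selects the program
+ * and the x-dimension covers the dataset rows. The prefix-ordered AST is walked
+ * from its last node to its first, pushing intermediate results onto a small
+ * register stack, so the value of the root is on top of the stack at the end.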
+ */ +template +__global__ void execute_kernel(const program_t d_progs, + const float* data, + float* y_pred, + const uint64_t n_rows) +{ + uint64_t pid = blockIdx.y; // current program + uint64_t row_id = blockIdx.x * blockDim.x + threadIdx.x; // current dataset row + + if (row_id >= n_rows) { return; } + + stack eval_stack; // Maintain stack only for remaining threads + program_t curr_p = d_progs + pid; // Current program + + int end = curr_p->len - 1; + node* curr_node = curr_p->nodes + end; + + float res = 0.0f; + float in[2] = {0.0f, 0.0f}; + + while (end >= 0) { + if (detail::is_nonterminal(curr_node->t)) { + int ar = detail::arity(curr_node->t); + in[0] = eval_stack.pop(); // Min arity of function is 1 + if (ar > 1) in[1] = eval_stack.pop(); + } + res = detail::evaluate_node(*curr_node, data, n_rows, row_id, in); + eval_stack.push(res); + curr_node--; + end--; + } + + // Outputs stored in col-major format + y_pred[pid * n_rows + row_id] = eval_stack.pop(); +} + +program::program() + : len(0), + depth(0), + raw_fitness_(0.0f), + metric(metric_t::mse), + mut_type(mutation_t::none), + nodes(nullptr) +{ +} + +program::~program() { delete[] nodes; } + +program::program(const program& src) + : len(src.len), + depth(src.depth), + raw_fitness_(src.raw_fitness_), + metric(src.metric), + mut_type(src.mut_type) +{ + nodes = new node[len]; + std::copy(src.nodes, src.nodes + src.len, nodes); +} + +program& program::operator=(const program& src) +{ + len = src.len; + depth = src.depth; + raw_fitness_ = src.raw_fitness_; + metric = src.metric; + mut_type = src.mut_type; + + // Copy nodes + delete[] nodes; + nodes = new node[len]; + std::copy(src.nodes, src.nodes + src.len, nodes); + + return *this; +} + +void compute_metric(const raft::handle_t& h, + int n_rows, + int n_progs, + const float* y, + const float* y_pred, + const float* w, + float* score, + const param& params) +{ + // Call appropriate metric function based on metric defined in params + if (params.metric == metric_t::pearson) { + weightedPearson(h, n_rows, n_progs, y, y_pred, w, score); + } else if (params.metric == metric_t::spearman) { + weightedSpearman(h, n_rows, n_progs, y, y_pred, w, score); + } else if (params.metric == metric_t::mae) { + meanAbsoluteError(h, n_rows, n_progs, y, y_pred, w, score); + } else if (params.metric == metric_t::mse) { + meanSquareError(h, n_rows, n_progs, y, y_pred, w, score); + } else if (params.metric == metric_t::rmse) { + rootMeanSquareError(h, n_rows, n_progs, y, y_pred, w, score); + } else if (params.metric == metric_t::logloss) { + logLoss(h, n_rows, n_progs, y, y_pred, w, score); + } else { + // This should not be reachable + } +} + +void execute(const raft::handle_t& h, + const program_t& d_progs, + const int n_rows, + const int n_progs, + const float* data, + float* y_pred) +{ + cudaStream_t stream = h.get_stream(); + + dim3 blks(raft::ceildiv(n_rows, GENE_TPB), n_progs, 1); + execute_kernel<<>>(d_progs, data, y_pred, (uint64_t)n_rows); + CUDA_CHECK(cudaPeekAtLastError()); +} + +void find_fitness(const raft::handle_t& h, + program_t& d_prog, + float* score, + const param& params, + const int n_rows, + const float* data, + const float* y, + const float* sample_weights) +{ + cudaStream_t stream = h.get_stream(); + + // Compute predicted values + rmm::device_uvector y_pred(n_rows, stream); + execute(h, d_prog, n_rows, 1, data, y_pred.data()); + + // Compute error + compute_metric(h, n_rows, 1, y, y_pred.data(), sample_weights, score, params); +} + +void find_batched_fitness(const raft::handle_t& 
h, + int n_progs, + program_t& d_progs, + float* score, + const param& params, + const int n_rows, + const float* data, + const float* y, + const float* sample_weights) +{ + cudaStream_t stream = h.get_stream(); + + rmm::device_uvector y_pred((uint64_t)n_rows * (uint64_t)n_progs, stream); + execute(h, d_progs, n_rows, n_progs, data, y_pred.data()); + + // Compute error + compute_metric(h, n_rows, n_progs, y, y_pred.data(), sample_weights, score, params); +} + +void set_fitness(const raft::handle_t& h, + program_t& d_prog, + program& h_prog, + const param& params, + const int n_rows, + const float* data, + const float* y, + const float* sample_weights) +{ + cudaStream_t stream = h.get_stream(); + + rmm::device_uvector score(1, stream); + + find_fitness(h, d_prog, score.data(), params, n_rows, data, y, sample_weights); + + // Update host and device score for program + CUDA_CHECK(cudaMemcpyAsync( + &d_prog[0].raw_fitness_, score.data(), sizeof(float), cudaMemcpyDeviceToDevice, stream)); + h_prog.raw_fitness_ = score.front_element(stream); +} + +void set_batched_fitness(const raft::handle_t& h, + int n_progs, + program_t& d_progs, + std::vector& h_progs, + const param& params, + const int n_rows, + const float* data, + const float* y, + const float* sample_weights) +{ + cudaStream_t stream = h.get_stream(); + + rmm::device_uvector score(n_progs, stream); + + find_batched_fitness(h, n_progs, d_progs, score.data(), params, n_rows, data, y, sample_weights); + + // Update scores on host and device + // TODO: Find a way to reduce the number of implicit memory transfers + for (auto i = 0; i < n_progs; ++i) { + CUDA_CHECK(cudaMemcpyAsync(&d_progs[i].raw_fitness_, + score.element_ptr(i), + sizeof(float), + cudaMemcpyDeviceToDevice, + stream)); + h_progs[i].raw_fitness_ = score.element(i, stream); + } +} + +float get_fitness(const program& prog, const param& params) +{ + int crit = params.criterion(); + float penalty = params.parsimony_coefficient * prog.len * (2 * crit - 1); + return (prog.raw_fitness_ - penalty); +} + +/** + * @brief Get a random subtree of the current program nodes (on CPU) + * + * @param pnodes AST represented as a list of nodes + * @param len The total number of nodes in the AST + * @param rng Random number generator for subtree selection + * @return A tuple [first,last) which contains the required subtree + */ +std::pair get_subtree(node* pnodes, int len, std::mt19937& rng) +{ + int start, end; + start = end = 0; + + // Specify RNG + std::uniform_real_distribution dist_uniform(0.0f, 1.0f); + float bound = dist_uniform(rng); + + // Specify subtree start probs acc to Koza's selection approach + std::vector node_probs(len, 0.1); + float sum = 0.1 * len; + + for (int i = 0; i < len; ++i) { + if (pnodes[i].is_nonterminal()) { + node_probs[i] = 0.9; + sum += 0.8; + } + } + + // Normalize vector + for (int i = 0; i < len; ++i) { + node_probs[i] /= sum; + } + + // Compute cumulative sum + std::partial_sum(node_probs.begin(), node_probs.end(), node_probs.begin()); + + start = std::lower_bound(node_probs.begin(), node_probs.end(), bound) - node_probs.begin(); + end = start; + + // Iterate until all function arguments are satisfied in current subtree + int num_args = 1; + while (num_args > end - start) { + node curr; + curr = pnodes[end]; + if (curr.is_nonterminal()) num_args += curr.arity(); + ++end; + } + + return std::make_pair(start, end); +} + +int get_depth(const program& p_out) +{ + int depth = 0; + std::stack arity_stack; + for (auto i = 0; i < p_out.len; ++i) { + node 
curr(p_out.nodes[i]); + + // Update depth + int sz = arity_stack.size(); + depth = std::max(depth, sz); + + // Update stack + if (curr.is_nonterminal()) { + arity_stack.push(curr.arity()); + } else { + // Only triggered for a depth 0 node + if (arity_stack.empty()) break; + + int e = arity_stack.top(); + arity_stack.pop(); + arity_stack.push(e - 1); + + while (arity_stack.top() == 0) { + arity_stack.pop(); + if (arity_stack.empty()) break; + + e = arity_stack.top(); + arity_stack.pop(); + arity_stack.push(e - 1); + } + } + } + + return depth; +} + +void build_program(program& p_out, const param& params, std::mt19937& rng) +{ + // Define data structures needed for tree + std::stack arity_stack; + std::vector nodelist; + nodelist.reserve(1 << (MAX_STACK_SIZE)); + + // Specify Distributions with parameters + std::uniform_int_distribution dist_function(0, params.function_set.size() - 1); + std::uniform_int_distribution dist_initDepth(params.init_depth[0], params.init_depth[1]); + std::uniform_int_distribution dist_terminalChoice(0, params.num_features); + std::uniform_real_distribution dist_constVal(params.const_range[0], params.const_range[1]); + std::bernoulli_distribution dist_nodeChoice(params.terminalRatio); + std::bernoulli_distribution dist_coinToss(0.5); + + // Initialize nodes + int max_depth = dist_initDepth(rng); + node::type func = params.function_set[dist_function(rng)]; + node curr_node(func); + nodelist.push_back(curr_node); + arity_stack.push(curr_node.arity()); + + init_method_t method = params.init_method; + if (method == init_method_t::half_and_half) { + // Choose either grow or full for this tree + bool choice = dist_coinToss(rng); + method = choice ? init_method_t::grow : init_method_t::full; + } + + // Fill tree + while (!arity_stack.empty()) { + int depth = arity_stack.size(); + p_out.depth = std::max(depth, p_out.depth); + bool node_choice = dist_nodeChoice(rng); + + if ((node_choice == false || method == init_method_t::full) && depth < max_depth) { + // Add a function to node list + curr_node = node(params.function_set[dist_function(rng)]); + nodelist.push_back(curr_node); + arity_stack.push(curr_node.arity()); + } else { + // Add terminal + int terminal_choice = dist_terminalChoice(rng); + if (terminal_choice == params.num_features) { + // Add constant + float val = dist_constVal(rng); + curr_node = node(val); + } else { + // Add variable + int fid = terminal_choice; + curr_node = node(fid); + } + + // Modify nodelist + nodelist.push_back(curr_node); + + // Modify stack + int e = arity_stack.top(); + arity_stack.pop(); + arity_stack.push(e - 1); + while (arity_stack.top() == 0) { + arity_stack.pop(); + if (arity_stack.empty()) { break; } + + e = arity_stack.top(); + arity_stack.pop(); + arity_stack.push(e - 1); + } + } + } + + // Set new program parameters - need to do a copy as + // nodelist will be deleted using RAII semantics + p_out.nodes = new node[nodelist.size()]; + std::copy(nodelist.begin(), nodelist.end(), p_out.nodes); + + p_out.len = nodelist.size(); + p_out.metric = params.metric; + p_out.raw_fitness_ = 0.0f; +} + +void point_mutation(const program& prog, program& p_out, const param& params, std::mt19937& rng) +{ + // deep-copy program + p_out = prog; + + // Specify RNGs + std::uniform_real_distribution dist_uniform(0.0f, 1.0f); + std::uniform_int_distribution dist_terminalChoice(0, params.num_features); + std::uniform_real_distribution dist_constantVal(params.const_range[0], + params.const_range[1]); + + // Fill with uniform numbers + std::vector 
node_probs(p_out.len); + std::generate( + node_probs.begin(), node_probs.end(), [&dist_uniform, &rng] { return dist_uniform(rng); }); + + // Mutate nodes + int len = p_out.len; + for (int i = 0; i < len; ++i) { + node curr(prog.nodes[i]); + + if (node_probs[i] < params.p_point_replace) { + if (curr.is_terminal()) { + int choice = dist_terminalChoice(rng); + + if (choice == params.num_features) { + // Add a randomly generated constant + curr = node(dist_constantVal(rng)); + } else { + // Add a variable with fid=choice + curr = node(choice); + } + } else if (curr.is_nonterminal()) { + // Replace current function with another function of the same arity + int ar = curr.arity(); + // CUML_LOG_DEBUG("Arity is %d, curr function is + // %d",ar,static_cast::type>(curr.t)); + std::vector fset = params.arity_set.at(ar); + std::uniform_int_distribution<> dist_fset(0, fset.size() - 1); + int choice = dist_fset(rng); + curr = node(fset[choice]); + } + + // Update p_out with updated value + p_out.nodes[i] = curr; + } + } +} + +void crossover( + const program& prog, const program& donor, program& p_out, const param& params, std::mt19937& rng) +{ + // Get a random subtree of prog to replace + std::pair prog_slice = get_subtree(prog.nodes, prog.len, rng); + int prog_start = prog_slice.first; + int prog_end = prog_slice.second; + + // Set metric of output program + p_out.metric = prog.metric; + + // MAX_STACK_SIZE can only handle tree of depth MAX_STACK_SIZE - max(func_arity=2) + 1 + // Thus we continuously hoist the donor subtree. + // Actual indices in donor + int donor_start = 0; + int donor_end = donor.len; + int output_depth = 0; + int iter = 0; + do { + ++iter; + // Get donor subtree + std::pair donor_slice = + get_subtree(donor.nodes + donor_start, donor_end - donor_start, rng); + + // Get indices w.r.t current subspace [donor_start,donor_end) + int donor_substart = donor_slice.first; + int donor_subend = donor_slice.second; + + // Update relative indices to global indices + donor_substart += donor_start; + donor_subend += donor_start; + + // Update to new subspace + donor_start = donor_substart; + donor_end = donor_subend; + + // Evolve on current subspace + p_out.len = (prog_start) + (donor_end - donor_start) + (prog.len - prog_end); + delete[] p_out.nodes; + p_out.nodes = new node[p_out.len]; + + // Copy slices using std::copy + std::copy(prog.nodes, prog.nodes + prog_start, p_out.nodes); + std::copy(donor.nodes + donor_start, donor.nodes + donor_end, p_out.nodes + prog_start); + std::copy(prog.nodes + prog_end, + prog.nodes + prog.len, + p_out.nodes + (prog_start) + (donor_end - donor_start)); + + output_depth = get_depth(p_out); + } while (output_depth >= MAX_STACK_SIZE); + + // Set the depth of the final program + p_out.depth = output_depth; +} + +void subtree_mutation(const program& prog, program& p_out, const param& params, std::mt19937& rng) +{ + // Generate a random program and perform crossover + program new_program; + build_program(new_program, params, rng); + crossover(prog, new_program, p_out, params, rng); +} + +void hoist_mutation(const program& prog, program& p_out, const param& params, std::mt19937& rng) +{ + // Replace program subtree with a random sub-subtree + + std::pair prog_slice = get_subtree(prog.nodes, prog.len, rng); + int prog_start = prog_slice.first; + int prog_end = prog_slice.second; + + std::pair sub_slice = get_subtree(prog.nodes + prog_start, prog_end - prog_start, rng); + int sub_start = sub_slice.first; + int sub_end = sub_slice.second; + + // Update subtree 
indices to global indices + sub_start += prog_start; + sub_end += prog_start; + + p_out.len = (prog_start) + (sub_end - sub_start) + (prog.len - prog_end); + p_out.nodes = new node[p_out.len]; + p_out.metric = prog.metric; + + // Copy node slices using std::copy + std::copy(prog.nodes, prog.nodes + prog_start, p_out.nodes); + std::copy(prog.nodes + sub_start, prog.nodes + sub_end, p_out.nodes + prog_start); + std::copy(prog.nodes + prog_end, + prog.nodes + prog.len, + p_out.nodes + (prog_start) + (sub_end - sub_start)); + + // Update depth + p_out.depth = get_depth(p_out); +} + +} // namespace genetic +} // namespace cuml \ No newline at end of file diff --git a/cpp/src/genetic/reg_stack.cuh b/cpp/src/genetic/reg_stack.cuh index 1c3bb34cb3..4696f6e975 100644 --- a/cpp/src/genetic/reg_stack.cuh +++ b/cpp/src/genetic/reg_stack.cuh @@ -18,6 +18,13 @@ #include +#ifndef CUDA_PRAGMA_UNROLL +#ifdef __CUDA_ARCH__ +#define CUDA_PRAGMA_UNROLL _Pragma("unroll") +#else +#define CUDA_PRAGMA_UNROLL +#endif // __CUDA_ARCH__ +#endif // CUDA_PRAGMA_UNROLL namespace cuml { namespace genetic { @@ -34,7 +41,7 @@ template struct stack { explicit HDI stack() : elements_(0) { -#pragma unroll + CUDA_PRAGMA_UNROLL for (int i = 0; i < MaxSize; ++i) { regs_[i] = DataT(0); } @@ -61,17 +68,17 @@ struct stack { */ HDI void push(DataT val) { -#pragma unroll - for (int i = 0; i < MaxSize; ++i) { + CUDA_PRAGMA_UNROLL + for (int i = MaxSize - 1; i >= 0; --i) { if (elements_ == i) { - regs_[i] = val; ++elements_; + regs_[i] = val; } } } /** - * @brief Pops the top element from the stack + * @brief Lazily pops the top element from the stack * * @return pops the element and returns it, if already reached bottom, then it * returns zero. @@ -83,14 +90,14 @@ struct stack { */ HDI DataT pop() { -#pragma unroll + CUDA_PRAGMA_UNROLL for (int i = 0; i < MaxSize; ++i) { - if (elements_ - 1 == i) { - --elements_; + if (elements_ == (i + 1)) { + elements_--; return regs_[i]; } } - // shouldn't reach here! + return DataT(0); } diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 6823d6d6e7..e689bfd1e4 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -53,6 +53,8 @@ if(BUILD_CUML_TESTS) sg/fnv_hash_test.cpp sg/genetic/node_test.cpp sg/genetic/param_test.cu + sg/genetic/program_test.cu + sg/genetic/evolution_test.cu sg/hdbscan_test.cu sg/holtwinters_test.cu sg/kmeans_test.cu diff --git a/cpp/test/sg/genetic/evolution_test.cu b/cpp/test/sg/genetic/evolution_test.cu new file mode 100644 index 0000000000..0f718120c8 --- /dev/null +++ b/cpp/test/sg/genetic/evolution_test.cu @@ -0,0 +1,361 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cuml { +namespace genetic { + +/** + * @brief Tests the training and inference of the symbolic regressor, classifier and transformer + * on y = 0.5X[0] + 0.4 X[1] + * + */ +class GeneticEvolutionTest : public ::testing::Test { + public: + GeneticEvolutionTest() + : d_train(0, cudaStream_t(0)), + d_trainlab(0, cudaStream_t(0)), + d_test(0, cudaStream_t(0)), + d_testlab(0, cudaStream_t(0)), + d_trainwts(0, cudaStream_t(0)), + d_testwts(0, cudaStream_t(0)) + { + } + + protected: + void SetUp() override + { + ML::Logger::get().setLevel(CUML_LEVEL_INFO); + CUDA_CHECK(cudaStreamCreate(&stream)); + handle.set_stream(stream); + + // Set training param vals + hyper_params.population_size = 5000; + hyper_params.num_features = n_cols; + hyper_params.random_state = 11; + hyper_params.generations = 20; + hyper_params.stopping_criteria = 0.01; + hyper_params.p_crossover = 0.7; + hyper_params.p_subtree_mutation = 0.1; + hyper_params.p_hoist_mutation = 0.05; + hyper_params.p_point_mutation = 0.1; + hyper_params.parsimony_coefficient = 0.01; + + // Initialize weights + h_trainwts.resize(n_tr_rows, 1.0f); + h_testwts.resize(n_tst_rows, 1.0f); + + // resize device memory + d_train.resize(n_cols * n_tr_rows, stream); + d_trainlab.resize(n_tr_rows, stream); + d_test.resize(n_cols * n_tst_rows, stream); + d_testlab.resize(n_tst_rows, stream); + d_trainwts.resize(n_tr_rows, stream); + d_testwts.resize(n_tst_rows, stream); + + // Memcpy HtoD + CUDA_CHECK(cudaMemcpyAsync(d_train.data(), + h_train.data(), + n_cols * n_tr_rows * sizeof(float), + cudaMemcpyHostToDevice, + stream)); + CUDA_CHECK(cudaMemcpyAsync(d_trainlab.data(), + h_trainlab.data(), + n_tr_rows * sizeof(float), + cudaMemcpyHostToDevice, + stream)); + CUDA_CHECK(cudaMemcpyAsync(d_test.data(), + h_test.data(), + n_cols * n_tst_rows * sizeof(float), + cudaMemcpyHostToDevice, + stream)); + CUDA_CHECK(cudaMemcpyAsync(d_testlab.data(), + h_testlab.data(), + n_tst_rows * sizeof(float), + cudaMemcpyHostToDevice, + stream)); + CUDA_CHECK(cudaMemcpyAsync(d_trainwts.data(), + h_trainwts.data(), + n_tr_rows * sizeof(float), + cudaMemcpyHostToDevice, + stream)); + CUDA_CHECK(cudaMemcpyAsync(d_testwts.data(), + h_testwts.data(), + n_tst_rows * sizeof(float), + cudaMemcpyHostToDevice, + stream)); + } + + void TearDown() override { CUDA_CHECK(cudaStreamDestroy(stream)); } + + raft::handle_t handle; + cudaStream_t stream; + param hyper_params; + + // Some mini-dataset constants + const int n_tr_rows = 250; + const int n_tst_rows = 50; + const int n_cols = 2; + const float tolerance = 0.025f; // assuming upto 2.5% tolerance for results(for now) + + // Contains synthetic Data + // y = + std::vector h_train = { + 0.2119566, -0.7221057, 0.9944866, -0.6420138, 0.3243210, -0.8062112, 0.9247920, -0.8267401, + 0.2330494, 0.1486086, -0.0957095, 0.1386102, 0.1674080, 0.0356288, 0.4644501, 0.3442579, + 0.6560287, 0.2349779, -0.3978628, 0.1793082, -0.1155355, 0.0176618, 0.8318791, 0.7813108, + 0.2736598, 0.6475824, -0.3849131, -0.4696701, -0.6907704, 0.2952283, -0.8723270, -0.3355115, + -0.0523054, -0.8182662, 0.5539537, -0.8737933, 0.5849895, -0.2579604, 0.3574578, -0.1654855, + -0.2554073, 0.3591112, 0.9403976, -0.3390219, 0.6517981, 0.6465558, 0.4370021, -0.0079799, + 0.2970910, 0.2452746, -0.7523201, -0.0951637, 0.6400041, -0.5386036, 0.4352954, -0.2126355, + 0.6203773, 0.7159789, -0.6823127, 0.4670905, -0.4666402, 
0.0071169, 0.5038485, -0.5780727, + 0.7944591, 0.6328644, 0.1813934, 0.2653100, -0.1671608, 0.8108285, 0.3609906, -0.5820257, + 0.0447571, 0.7247062, 0.3546630, 0.5908147, -0.1850210, 0.8889677, 0.4725176, 0.2190818, + 0.1944676, -0.1650774, 0.5239485, 0.4871244, 0.8803309, 0.3119077, -0.1502819, 0.2140640, + -0.3925484, 0.1745171, -0.0332719, 0.9880465, 0.5828160, 0.3987538, 0.4770127, -0.4151363, + -0.9899210, 0.7880531, -0.3253276, -0.4564783, -0.9825586, -0.0729553, 0.7512086, 0.3045725, + -0.5038860, -0.9412159, -0.8188231, -0.3728235, 0.2280060, -0.4212141, -0.2424457, -0.5574245, + -0.5845115, 0.7049432, -0.5244312, -0.0405502, -0.2238990, 0.6347900, 0.9998363, 0.3580613, + 0.0199144, -0.1971139, 0.8036406, 0.7131155, 0.5613965, 0.3835140, 0.0717551, 0.0463067, + 0.5255786, 0.0928743, 0.1386557, -0.7212757, 0.3051646, 0.2635859, -0.5229289, -0.8547997, + 0.6653103, -0.1116264, 0.2930650, 0.5135837, 0.7412015, -0.3735900, -0.9826624, -0.6185324, + -0.8464018, -0.4180478, 0.7254488, -0.5188612, -0.3333993, 0.8999060, -0.6015426, -0.6545046, + 0.6795465, -0.5157862, 0.4536161, -0.7564244, -0.0614987, 0.9840064, 0.3975551, 0.8684530, + 0.6091788, 0.2544823, -0.9745569, -0.1815226, -0.1521985, 0.8436312, -0.9446849, -0.2546227, + 0.9108996, -0.2374187, -0.8820541, -0.2937101, 0.2558129, 0.7706293, 0.1066034, -0.7223888, + -0.6807924, -0.5187497, -0.3461997, 0.3319379, -0.5073046, 0.0713026, 0.4598049, -0.9708425, + -0.2323956, 0.3963093, -0.9132538, -0.2047350, 0.1162403, -0.6301352, -0.1114944, -0.4411873, + -0.7517651, 0.9942231, 0.6387486, -0.3516690, 0.2925287, 0.8415794, -0.2203800, 0.1182607, + -0.5032156, 0.4939238, 0.9852490, -0.8617036, -0.8945347, 0.1789286, -0.1909516, 0.2587640, + -0.2992706, 0.6049703, -0.1238372, 0.8297717, -0.3196876, 0.9792059, 0.7898732, 0.8210509, + -0.5545098, -0.5691904, -0.7678227, -0.9643255, -0.1002291, -0.4273028, -0.6697328, -0.3049299, + -0.0368014, 0.4804423, -0.6646156, 0.5903011, -0.1700153, -0.6397213, 0.9845422, -0.5159376, + 0.1589690, -0.3279489, -0.1498093, -0.9002322, 0.1960990, 0.3850992, 0.4812583, -0.1506606, + -0.0863564, -0.4061224, -0.3599582, -0.2919797, -0.5094189, 0.7824159, 0.3322580, -0.3275573, + -0.9909980, -0.5806390, 0.4667387, -0.3746538, -0.7436752, 0.5058509, 0.5686203, -0.8828574, + 0.2331149, 0.1225447, 0.9276860, -0.2576783, -0.5962995, -0.6098081, -0.0473731, 0.6461973, + -0.8618875, 0.2869696, -0.5910612, 0.2354020, 0.7434812, 0.9635402, -0.7473646, -0.1364276, + 0.4180313, 0.1777712, -0.3155821, -0.3896985, -0.5973547, 0.3018475, -0.2226010, 0.6965982, + -0.1711176, 0.4426420, 0.5972827, 0.7491136, 0.5431328, 0.1888770, -0.4517326, 0.7062291, + 0.5087549, -0.3582025, -0.4492956, 0.1632529, -0.1689859, 0.9334283, -0.3891996, 0.1138209, + 0.7598738, 0.0241726, -0.3133468, -0.0708007, 0.9602417, -0.7650007, -0.6497396, 0.4096349, + -0.7035034, 0.6052362, 0.5920056, -0.4065195, 0.3722862, -0.7039886, -0.2351859, 0.3143256, + -0.8650362, 0.3481469, 0.5242298, 0.2190642, 0.7090682, 0.7368234, 0.3148258, -0.8396302, + -0.8332214, 0.6766308, 0.4428585, 0.5376374, 0.1104256, -0.9560977, 0.8913012, 0.2302127, + -0.7445556, -0.8753514, -0.1434969, 0.7423451, -0.9627953, 0.7919458, -0.8590292, -0.2405730, + 0.0733800, -0.1964383, 0.3429065, -0.5199867, -0.6148949, -0.4645573, -0.1036227, 0.1915514, + 0.4981042, -0.3142545, -0.1360139, 0.5123143, -0.8319357, 0.2593685, -0.6637208, 0.8695423, + -0.4745009, -0.4598881, 0.2561057, 0.8682946, 0.7572707, -0.2405597, -0.6909520, -0.2329739, + -0.3544887, 0.5916605, -0.5483196, 
0.3634111, 0.0485800, 0.1492287, -0.0361141, 0.6510856, + 0.9754849, -0.1871928, 0.7787021, -0.6019276, 0.2416331, -0.1160285, 0.8894659, 0.9423820, + -0.7052383, -0.8790381, -0.7129928, 0.5332075, -0.5728216, -0.9184565, 0.0437820, 0.3580015, + -0.7459742, -0.6401960, -0.7465842, -0.0257084, 0.7586666, 0.3472861, 0.3226733, -0.8356623, + 0.9038333, 0.9519323, 0.6794367, -0.4118270, -0.1475553, 0.1638173, 0.7039975, 0.0782125, + -0.6468386, -0.4905404, -0.0657285, -0.9094056, -0.1691999, 0.9545628, 0.5260556, 0.0704832, + 0.9559255, 0.4109315, 0.0437353, 0.1975988, -0.2173066, 0.4840004, -0.9305912, 0.6281645, + -0.2873839, -0.0092089, -0.7423917, -0.5064726, 0.2959957, 0.3744118, -0.2324660, 0.6419766, + 0.0482254, 0.0711853, -0.0668010, -0.6056250, -0.6424942, 0.5091138, -0.7920839, -0.3631541, + 0.2925649, 0.8553973, -0.5368195, -0.8043768, 0.6299060, -0.7402435, 0.7831608, -0.4979353, + -0.7786197, 0.1855255, -0.7243119, 0.7581270, 0.7850708, -0.6414960, -0.4423507, -0.4211898, + 0.8494025, 0.3603602, -0.3777632, 0.3322407, -0.0483915, -0.8515641, -0.9453503, -0.4536391, + -0.1080792, 0.5246211, 0.2128397, -0.0146389, -0.7508293, -0.0058518, 0.5420505, 0.1439000, + 0.1900943, 0.0454271, 0.3117409, 0.1234926, -0.1166942, 0.2856016, 0.8390452, 0.8877837, + 0.0886838, -0.7009126, -0.5130350, -0.0999212, 0.3338176, -0.3013774, 0.3526511, 0.9518843, + 0.5853393, -0.1422507, -0.9768327, -0.5915277, 0.9691055, 0.4186211, 0.7512146, 0.5220292, + -0.1700221, 0.5423641, 0.5864487, -0.7437551, -0.5076052, -0.8304062, 0.4895252, 0.7349310, + 0.7687441, 0.6319372, 0.7462888, 0.2358095}; + + std::vector h_trainlab = { + -0.7061807, -0.9935827, -1.3077246, -0.3378525, -0.6495246, -2.0123182, 0.0340125, -0.2089733, + -0.8786033, -1.3019919, -1.9427123, -1.9624611, -1.0215918, -0.7701042, -2.3890236, -0.6768685, + -1.5100409, -0.7647975, -0.6509883, -0.9327181, -2.2925701, -1.1547282, -0.0646960, -0.2433849, + -1.3402845, -1.1222004, -1.8060292, -0.5686744, -0.7949885, -0.7014911, -0.4394445, -0.6407220, + -0.7567281, -0.1424980, -0.4449957, -0.0832827, -1.3135824, -0.7259869, -0.6223005, -1.4591261, + -1.5859294, -0.7344378, -0.3131946, -0.8229243, -1.1158352, -0.4810999, -0.6265636, -0.9763480, + -1.3232699, -1.0156538, -0.3958369, -2.3411706, -1.6622960, -0.4680720, -2.0089384, -0.7158608, + -0.3735971, -1.0591518, -0.3007601, -1.9814152, -1.0727452, -0.7844243, -2.3594606, -0.4388914, + -0.1194218, -0.4284076, -0.7608060, -0.7356959, -0.7563467, -1.8871661, -2.3971652, -0.4424445, + -0.7512620, -0.2262175, -0.7759824, -2.5211585, -0.8688839, -0.0325217, -2.0756457, -2.5935947, + -1.1262706, -0.7814806, -2.6152479, -0.5979422, -1.8219779, -1.2011619, -0.9094200, -1.1892029, + -0.6205842, -1.7599165, -1.9918835, -0.7041349, -0.7746859, -0.6861359, -0.5224625, -1.2406723, + -0.1745701, -0.1291239, -2.4182146, -0.5995310, -1.1388247, -0.8812391, -1.1353377, -1.5786207, + -0.5555833, 0.0002464, -0.1457169, -1.1594313, -2.1163798, -1.1098294, -1.4213709, -0.4476795, + -1.5073204, -0.2717116, -0.6787519, -0.8713962, -0.9872876, -0.3698685, 0.0235867, -1.0940261, + -0.8272783, -1.9253905, -0.1709152, -0.6209573, -0.5865176, -0.7986188, -2.1974506, -2.6496017, + -1.9451187, -0.7424771, -1.8817208, -2.2417800, -0.8650095, -0.7006861, -2.0289972, -1.3193644, + -1.8613344, -1.0139089, -0.7310213, -0.5095533, -0.2320652, -2.3944243, 0.0525441, -0.5716605, + -0.0658016, -1.4066644, -0.6430519, -0.5938018, -0.6804599, -0.1180739, -1.7033852, -1.3027941, + -0.6082652, -2.4703887, -0.9920609, -0.3844494, 
-0.7468968, 0.0337840, -0.7998180, -0.0037226, + -0.5870786, -0.7766853, -0.3147676, -0.7173055, -2.7734269, -0.0547125, -0.4775438, -0.9444610, + -1.4637991, -1.7066195, -0.0135983, -0.6795068, -1.2210661, -0.1762879, -0.9427360, -0.4120364, + -0.6077851, -1.7033054, -1.9354388, -0.6399003, -2.1621227, -1.4899510, -0.5816087, 0.0662278, + -1.7709871, -2.2943379, 0.0671570, -2.2462875, -0.8166682, -1.3488045, -2.3724372, -0.6542480, + -1.6837887, 0.1718501, -0.4232655, -1.9293420, -1.5524519, -0.8903348, -0.8235148, -0.7555137, + -1.2672423, -0.5341824, -0.0800176, -1.8341924, -2.0388451, -1.6274120, -1.0832978, -0.6836474, + -0.7428981, -0.6488642, -2.2992384, -0.3173651, -0.6495681, 0.0820371, -0.2221419, -0.2825119, + -0.4779604, -0.5677801, -0.5407600, 0.1339569, -0.8549058, -0.7177885, -0.4706391, -2.0992089, + -1.7748856, -0.8790807, -0.3359026, -1.0437502, -0.7428065, -0.5449560, 0.2120406, -0.8962944, + -2.9057635, -1.8338823, -0.9476171, 0.0537955, -0.7746540, -0.6021839, -0.9673201, -0.7290961, + -0.7500160, -2.1319913, -1.6356984, -2.4347284, -0.4906021, -0.1930180, -0.7118280, -0.6601136, + 0.1714188, -0.4826550}; + + std::vector h_test = { + 0.6506153, -0.2861214, -0.4207479, -0.0879224, 0.6963105, 0.7591472, -0.9145728, 0.3606104, + 0.5918564, -0.5548665, -0.4487113, 0.0824032, 0.4425484, -0.9139633, -0.7823172, 0.0768981, + 0.0922035, -0.0138858, 0.9646097, 0.2624208, -0.7190498, -0.6117298, -0.8807327, 0.2868101, + -0.8899322, 0.9853774, -0.5898669, 0.6281458, 0.5219784, -0.5437135, -0.2806136, -0.0927834, + -0.2291698, 0.0450774, 0.4253027, 0.6545525, 0.7031374, -0.3601150, 0.0715214, -0.9844534, + -0.8571354, -0.8157709, -0.6361769, -0.5510336, 0.4286138, 0.8863587, -0.7481151, -0.6144726, + -0.7920206, -0.2917536, -0.6506116, -0.4862449, -0.0866336, -0.7439836, 0.3753550, 0.2632956, + -0.2270555, 0.1109649, -0.6320683, 0.0280535, 0.6881603, 0.8163167, 0.1781434, -0.8063828, + 0.8032009, -0.6779581, -0.8654890, -0.5322430, 0.3786414, 0.0546245, -0.5542659, 0.6897840, + -0.1039676, -0.0343101, 0.4219748, -0.4535081, 0.7228620, 0.3873561, 0.1427819, -0.2881901, + 0.5431166, -0.0090170, -0.8354108, -0.0099369, -0.5904349, 0.2928394, 0.3634137, -0.7485119, + -0.5442900, 0.4072478, -0.4909732, 0.0737537, -0.0973075, -0.0848911, 0.7041450, 0.3288523, + -0.5264588, -0.5135713, 0.5130192, -0.0708379}; + + std::vector h_testlab = { + -1.6506068, -1.6408135, -0.9171102, -2.2897648, -0.2806881, -0.2297245, -0.4421663, -0.7713085, + -1.6812845, -0.6648566, -0.5840624, -0.8432659, -0.6577426, -1.6213072, -0.2299105, -2.1316719, + -2.6060586, -1.8153329, 0.1657440, -0.8794947, -1.3444440, -0.4118046, -0.3390867, -0.9532273, + 0.0358915, -0.6882091, -0.4517245, -0.3681215, -0.6051433, -1.0756192, -0.6731151, -1.0004896, + -2.4808031, -1.0080036, -1.7581659, -0.3644765, -0.2742536, -2.1790992, -1.8354263, 0.2105456, + -0.9973469, -0.2662037, -0.7020552, -0.7884595, -0.6079654, 0.0063403, -1.2439414, -1.3997503, + -0.1228729, -0.9907357 + + }; + + std::vector h_trainwts; + std::vector h_testwts; + + rmm::device_uvector d_train; + rmm::device_uvector d_trainlab; + rmm::device_uvector d_test; + rmm::device_uvector d_testlab; + rmm::device_uvector d_trainwts; + rmm::device_uvector d_testwts; +}; + +TEST_F(GeneticEvolutionTest, SymReg) +{ + raft::CompareApprox compApprox(tolerance); + program_t final_progs; + final_progs = (program_t)rmm::mr::get_current_device_resource()->allocate( + hyper_params.population_size * sizeof(program), stream); + std::vector> history; + 
history.reserve(hyper_params.generations); + + cudaEvent_t start, stop; + CUDA_CHECK(cudaEventCreate(&start)); + CUDA_CHECK(cudaEventCreate(&stop)); + + cudaEventRecord(start, stream); + + symFit(handle, + d_train.data(), + d_trainlab.data(), + d_trainwts.data(), + n_tr_rows, + n_cols, + hyper_params, + final_progs, + history); + + cudaEventRecord(stop, stream); + cudaEventSynchronize(stop); + float training_time; + cudaEventElapsedTime(&training_time, start, stop); + + int n_gen = history.size(); + std::cout << "Finished training for " << n_gen << " generations." << std::endl; + + // Find index of best program + int best_idx = 0; + float opt_fitness = history[n_gen - 1][0].raw_fitness_; + + // For all 3 loss functions - min is better + for (int i = 1; i < hyper_params.population_size; ++i) { + if (history[n_gen - 1][i].raw_fitness_ < opt_fitness) { + best_idx = i; + opt_fitness = history[n_gen - 1][i].raw_fitness_; + } + } + + std::string eqn = stringify(history[n_gen - 1][best_idx]); + CUML_LOG_DEBUG("Best Index = %d", best_idx); + std::cout << "Raw fitness score on train set is " << history[n_gen - 1][best_idx].raw_fitness_ + << std::endl; + std::cout << "Best AST equation is : " << eqn << std::endl; + + // Predict values for test dataset + rmm::device_uvector d_predlabels(n_tst_rows, stream); + + cudaEventRecord(start, stream); + + cuml::genetic::symRegPredict( + handle, d_test.data(), n_tst_rows, final_progs + best_idx, d_predlabels.data()); + + std::vector h_predlabels(n_tst_rows, 0.0f); + CUDA_CHECK(cudaMemcpy( + h_predlabels.data(), d_predlabels.data(), n_tst_rows * sizeof(float), cudaMemcpyDeviceToHost)); + + cudaEventRecord(stop, stream); + cudaEventSynchronize(stop); + float inference_time; + cudaEventElapsedTime(&inference_time, start, stop); + + // deallocate the nodes allocated for the last generation inside SymFit + for (auto i = 0; i < hyper_params.population_size; ++i) { + program tmp = program(); + raft::copy(&tmp, final_progs + i, 1, stream); + rmm::mr::get_current_device_resource()->deallocate(tmp.nodes, tmp.len * sizeof(node), stream); + tmp.nodes = nullptr; + } + // deallocate the final programs from device memory + rmm::mr::get_current_device_resource()->deallocate( + final_progs, hyper_params.population_size * sizeof(program), stream); + + ASSERT_TRUE(compApprox(history[n_gen - 1][best_idx].raw_fitness_, 0.0036f)); + std::cout << "Some Predicted test values:" << std::endl; + std::copy( + h_predlabels.begin(), h_predlabels.begin() + 10, std::ostream_iterator(std::cout, ";")); + std::cout << std::endl; + + std::cout << "Some Actual test values:" << std::endl; + std::copy( + h_testlab.begin(), h_testlab.begin() + 10, std::ostream_iterator(std::cout, ";")); + std::cout << std::endl; + + std::cout << "Training time = " << training_time << " ms" << std::endl; + std::cout << "Inference time = " << inference_time << " ms" << std::endl; +} + +} // namespace genetic +} // namespace cuml diff --git a/cpp/test/sg/genetic/param_test.cu b/cpp/test/sg/genetic/param_test.cu index 3941ba869a..32a1b5c6f3 100644 --- a/cpp/test/sg/genetic/param_test.cu +++ b/cpp/test/sg/genetic/param_test.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include +#include #include #include "../../prims/test_utils.h" @@ -41,6 +41,8 @@ TEST(Genetic, ParamTest) ASSERT_EQ(p.function_set[2], node::type::div); ASSERT_EQ(p.function_set[3], node::type::sub); ASSERT_EQ(p.transformer, transformer_t::sigmoid); + ASSERT_EQ(p.arity_set[2][0], node::type::add); + ASSERT_EQ(p.arity_set[2].size(), 4); ASSERT_EQ(p.metric, metric_t::mae); ASSERT_EQ(p.parsimony_coefficient, 0.001f); ASSERT_EQ(p.p_crossover, 0.9f); diff --git a/cpp/test/sg/genetic/program_test.cu b/cpp/test/sg/genetic/program_test.cu new file mode 100644 index 0000000000..1e2d3225df --- /dev/null +++ b/cpp/test/sg/genetic/program_test.cu @@ -0,0 +1,712 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cuml { +namespace genetic { + +class GeneticProgramTest : public ::testing::Test { + public: + GeneticProgramTest() + : d_data(0, cudaStream_t(0)), + d_y(0, cudaStream_t(0)), + d_lYpred(0, cudaStream_t(0)), + d_lY(0, cudaStream_t(0)), + d_lunitW(0, cudaStream_t(0)), + d_lW(0, cudaStream_t(0)), + dx2(0, cudaStream_t(0)), + dy2(0, cudaStream_t(0)), + dw2(0, cudaStream_t(0)), + dyp2(0, cudaStream_t(0)) + { + } + + protected: + void SetUp() override + { + CUDA_CHECK(cudaStreamCreate(&stream)); + handle.set_stream(stream); + + // Params + hyper_params.population_size = 2; + hyper_params.random_state = 123; + hyper_params.num_features = 3; + + // X[0] * X[1] + X[2] + 0.5 + h_nodes1.push_back(node(node::type::add)); + h_nodes1.push_back(node(node::type::add)); + h_nodes1.push_back(node(node::type::mul)); + h_nodes1.push_back(node(0)); + h_nodes1.push_back(node(1)); + h_nodes1.push_back(node(2)); + h_nodes1.push_back(node(0.5f)); + + // 0.5*X[1] - 0.4*X[2] + h_nodes2.push_back(node(node::type::sub)); + h_nodes2.push_back(node(node::type::mul)); + h_nodes2.push_back(node(0.5f)); + h_nodes2.push_back(node(1)); + h_nodes2.push_back(node(node::type::mul)); + h_nodes2.push_back(node(0.4f)); + h_nodes2.push_back(node(2)); + + // Programs + h_progs.resize(2); + h_progs[0].len = h_nodes1.size(); + h_progs[0].nodes = new node[h_progs[0].len]; + std::copy(h_nodes1.data(), h_nodes1.data() + h_nodes1.size(), h_progs[0].nodes); + + h_progs[1].len = h_nodes2.size(); + h_progs[1].nodes = new node[h_progs[1].len]; + std::copy(h_nodes2.data(), h_nodes2.data() + h_nodes2.size(), h_progs[1].nodes); + + // Loss weights + h_lunitW.resize(250, 1.0f); + + // Smaller input + hw2.resize(5, 1.0f); + + // Device memory + d_data.resize(75, stream); + d_y.resize(25, stream); + d_lYpred.resize(500, stream); + d_lY.resize(250, stream); + d_lunitW.resize(250, stream); + d_lW.resize(250, stream); + d_nodes1 = (node*)rmm::mr::get_current_device_resource()->allocate(7 * sizeof(node), stream); + d_nodes2 = (node*)rmm::mr::get_current_device_resource()->allocate(7 * sizeof(node), stream); + d_progs = + 
(program_t)rmm::mr::get_current_device_resource()->allocate(2 * sizeof(program), stream); + + CUDA_CHECK(cudaMemcpyAsync( + d_lYpred.data(), h_lYpred.data(), 500 * sizeof(float), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync( + d_lY.data(), h_lY.data(), 250 * sizeof(float), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync( + d_lunitW.data(), h_lunitW.data(), 250 * sizeof(float), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync( + d_lW.data(), h_lW.data(), 250 * sizeof(float), cudaMemcpyHostToDevice, stream)); + + CUDA_CHECK(cudaMemcpyAsync( + d_data.data(), h_data.data(), 75 * sizeof(float), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK( + cudaMemcpyAsync(d_y.data(), h_y.data(), 25 * sizeof(float), cudaMemcpyHostToDevice, stream)); + + CUDA_CHECK( + cudaMemcpyAsync(d_nodes1, h_nodes1.data(), 7 * sizeof(node), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK( + cudaMemcpyAsync(d_nodes2, h_nodes2.data(), 7 * sizeof(node), cudaMemcpyHostToDevice, stream)); + + program tmp(h_progs[0]); + delete[] tmp.nodes; + tmp.nodes = d_nodes1; + CUDA_CHECK(cudaMemcpyAsync(&d_progs[0], &tmp, sizeof(program), cudaMemcpyHostToDevice, stream)); + tmp.nodes = nullptr; + + tmp = program(h_progs[1]); + delete[] tmp.nodes; + tmp.nodes = d_nodes2; + CUDA_CHECK(cudaMemcpyAsync(&d_progs[1], &tmp, sizeof(program), cudaMemcpyHostToDevice, stream)); + tmp.nodes = nullptr; + + // Small input + dx2.resize(15, stream); + dy2.resize(5, stream); + dw2.resize(5, stream); + dyp2.resize(10, stream); + + CUDA_CHECK( + cudaMemcpyAsync(dx2.data(), hx2.data(), 15 * sizeof(float), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK( + cudaMemcpyAsync(dy2.data(), hy2.data(), 5 * sizeof(float), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK( + cudaMemcpyAsync(dw2.data(), hw2.data(), 5 * sizeof(float), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync( + dyp2.data(), hyp2.data(), 10 * sizeof(float), cudaMemcpyHostToDevice, stream)); + } + + void TearDown() override + { + rmm::mr::get_current_device_resource()->deallocate(d_nodes1, 7 * sizeof(node), stream); + rmm::mr::get_current_device_resource()->deallocate(d_nodes2, 7 * sizeof(node), stream); + rmm::mr::get_current_device_resource()->deallocate(d_progs, 2 * sizeof(program), stream); + CUDA_CHECK(cudaStreamDestroy(stream)); + } + + raft::handle_t handle; + cudaStream_t stream; + const int n_cols = 3; + const int n_progs = 2; + const int n_samples = 25; + const int n_samples2 = 5; + const float tolerance = 0.025f; // assuming upto 2.5% tolerance for results(for now) + + // 25*3 datapoints generated using numpy + // y = X[0] * X[1] + X[2] + 0.5 + std::vector h_data{ + -0.50446586, -2.06014071, 0.88514116, -2.3015387, 0.83898341, 1.65980218, -0.87785842, + 0.31563495, 0.3190391, 0.53035547, 0.30017032, -0.12289023, -1.10061918, -0.0126646, + 2.10025514, 1.13376944, -0.88762896, 0.05080775, -0.34934272, 2.18557541, 0.50249434, + -0.07557171, -0.52817175, -0.6871727, 0.51292982, + + -1.44411381, 1.46210794, 0.28558733, 0.86540763, 0.58662319, 0.2344157, -0.17242821, + 0.87616892, -0.7612069, -0.26788808, 0.61720311, -0.68372786, 0.58281521, -0.67124613, + 0.19091548, -0.38405435, -0.19183555, 1.6924546, -1.1425182, 1.51981682, 0.90159072, + 0.48851815, -0.61175641, -0.39675353, 1.25286816, + + -1.39649634, -0.24937038, 0.93110208, -1.07296862, -0.20889423, -1.11731035, -1.09989127, + 0.16003707, 1.74481176, -0.93576943, 0.12015895, 0.90085595, 0.04221375, -0.84520564, + -0.63699565, -0.3224172, 0.74204416, -0.74715829, -0.35224985, 
+
+  // 25*3 datapoints generated using numpy
+  // y = X[0] * X[1] + X[2] + 0.5
+  std::vector<float> h_data{
+    -0.50446586, -2.06014071, 0.88514116, -2.3015387, 0.83898341, 1.65980218, -0.87785842,
+    0.31563495, 0.3190391, 0.53035547, 0.30017032, -0.12289023, -1.10061918, -0.0126646,
+    2.10025514, 1.13376944, -0.88762896, 0.05080775, -0.34934272, 2.18557541, 0.50249434,
+    -0.07557171, -0.52817175, -0.6871727, 0.51292982,
+
+    -1.44411381, 1.46210794, 0.28558733, 0.86540763, 0.58662319, 0.2344157, -0.17242821,
+    0.87616892, -0.7612069, -0.26788808, 0.61720311, -0.68372786, 0.58281521, -0.67124613,
+    0.19091548, -0.38405435, -0.19183555, 1.6924546, -1.1425182, 1.51981682, 0.90159072,
+    0.48851815, -0.61175641, -0.39675353, 1.25286816,
+
+    -1.39649634, -0.24937038, 0.93110208, -1.07296862, -0.20889423, -1.11731035, -1.09989127,
+    0.16003707, 1.74481176, -0.93576943, 0.12015895, 0.90085595, 0.04221375, -0.84520564,
+    -0.63699565, -0.3224172, 0.74204416, -0.74715829, -0.35224985, 1.13162939, 1.14472371,
+    -0.29809284, 1.62434536, -0.69166075, -0.75439794};
+
+  std::vector<float> h_y{-0.16799022, -2.76151846, 1.68388718, -2.56473777, 0.78327289,
+                         -0.22822666, -0.44852371, 0.9365866,  2.001957,    -0.57784534,
+                         0.80542501,  1.48487942,  -0.09924385, -0.33670458, 0.26397558,
+                         -0.2578463,  1.41232295,  -0.16116848, 0.54688057,  4.95330364,
+                         2.09776794,  0.16498901,  2.44745782,  0.08097744,  0.3882355};
+
+  // Values for loss function tests: h_lYpred holds 250 predictions per
+  // program (2 x 250), h_lY holds the 250 labels
+  std::vector<float> h_lYpred{
+    0.06298f, 0.81894f, 0.12176f, 0.17104f, 0.12851f, 0.28721f, 0.85043f, 0.68120f, 0.57074f,
+    0.21796f, 0.96626f, 0.32337f, 0.21887f, 0.80867f, 0.96438f, 0.20052f, 0.28668f, 0.86931f,
+    0.71421f, 0.85405f, 0.13916f, 0.00316f, 0.59440f, 0.86299f, 0.67019f, 0.54309f, 0.82629f,
+    0.94563f, 0.01481f, 0.13665f, 0.77081f, 0.58024f, 0.02538f, 0.36610f, 0.13948f, 0.75034f,
+    0.80435f, 0.27488f, 0.74165f, 0.02921f, 0.51479f, 0.66415f, 0.27380f, 0.85304f, 0.95767f,
+    0.22758f, 0.38602f, 0.41555f, 0.53783f, 0.48663f, 0.11103f, 0.69397f, 0.21749f, 0.71930f,
+    0.28976f, 0.50971f, 0.68532f, 0.97518f, 0.71299f, 0.37629f, 0.56444f, 0.42280f, 0.51921f,
+    0.84366f, 0.30778f, 0.39493f, 0.74007f, 0.18280f, 0.22621f, 0.63083f, 0.46085f, 0.47259f,
+    0.65442f, 0.25453f, 0.23058f, 0.17460f, 0.30702f, 0.22421f, 0.37237f, 0.36660f, 0.29702f,
+    0.65276f, 0.30222f, 0.63844f, 0.99909f, 0.55084f, 0.05066f, 0.18914f, 0.36652f, 0.36765f,
+    0.93901f, 0.13575f, 0.72582f, 0.20223f, 0.06375f, 0.52581f, 0.77119f, 0.12127f, 0.27800f,
+    0.04008f, 0.01752f, 0.00394f, 0.68973f, 0.91931f, 0.48011f, 0.48363f, 0.09770f, 0.84381f,
+    0.80244f, 0.42710f, 0.82164f, 0.63239f, 0.08117f, 0.46195f, 0.49832f, 0.05717f, 0.16886f,
+    0.22311f, 0.45326f, 0.50748f, 0.19089f, 0.78211f, 0.34272f, 0.38456f, 0.64874f, 0.18216f,
+    0.64757f, 0.26900f, 0.20780f, 0.87067f, 0.16903f, 0.77285f, 0.70580f, 0.54404f, 0.97395f,
+    0.52550f, 0.81364f, 0.30085f, 0.36754f, 0.42492f, 0.79470f, 0.31590f, 0.26322f, 0.68332f,
+    0.96523f, 0.31110f, 0.97029f, 0.80217f, 0.77125f, 0.36302f, 0.13444f, 0.28420f, 0.20442f,
+    0.89692f, 0.50515f, 0.61952f, 0.48237f, 0.35080f, 0.75606f, 0.85438f, 0.70647f, 0.91793f,
+    0.24037f, 0.72867f, 0.84713f, 0.39838f, 0.49553f, 0.32876f, 0.22610f, 0.86573f, 0.99232f,
+    0.71321f, 0.30179f, 0.01941f, 0.84838f, 0.58587f, 0.43339f, 0.29490f, 0.07191f, 0.88531f,
+    0.26896f, 0.36085f, 0.96043f, 0.70679f, 0.39593f, 0.37642f, 0.76078f, 0.63827f, 0.36346f,
+    0.12755f, 0.07074f, 0.67744f, 0.35042f, 0.30773f, 0.15577f, 0.64096f, 0.05035f, 0.32882f,
+    0.33640f, 0.54106f, 0.76279f, 0.00414f, 0.17373f, 0.83551f, 0.18176f, 0.91190f, 0.03559f,
+    0.31992f, 0.86311f, 0.04054f, 0.49714f, 0.53551f, 0.65316f, 0.15681f, 0.80268f, 0.44978f,
+    0.26365f, 0.37162f, 0.97630f, 0.82863f, 0.73267f, 0.93207f, 0.47129f, 0.70817f, 0.57300f,
+    0.34240f, 0.89749f, 0.79844f, 0.67992f, 0.72523f, 0.43319f, 0.07310f, 0.61074f, 0.93830f,
+    0.90822f, 0.08077f, 0.28048f, 0.04549f, 0.44870f, 0.10337f, 0.93911f, 0.13464f, 0.16080f,
+    0.94620f, 0.15276f, 0.56239f, 0.38684f, 0.12437f, 0.98149f, 0.80650f, 0.44040f, 0.59698f,
+    0.82197f, 0.91634f, 0.89667f, 0.96333f, 0.21204f, 0.47457f, 0.95737f, 0.08697f, 0.50921f,
+    0.58647f, 0.71985f, 0.39455f, 0.73240f, 0.04227f, 0.74879f, 0.34403f, 0.94240f, 0.45158f,
+    0.83860f, 0.51819f, 0.87374f, 0.70416f, 0.52987f, 0.72727f, 0.53649f, 0.74878f, 0.13247f,
+    0.91358f, 0.61871f, 0.50048f, 0.04681f, 0.56370f, 0.68393f, 0.51947f, 0.85044f, 0.24416f,
+    0.39354f, 0.33526f, 0.66574f, 0.65638f, 0.15506f, 0.84167f, 0.84663f, 0.92094f, 0.14140f,
+    0.69364f, 0.40575f, 0.63543f, 0.35074f, 0.68887f, 0.70662f, 0.90424f, 0.09042f, 0.57486f,
+    0.52239f, 0.40711f, 0.82103f, 0.08674f, 0.14005f, 0.44922f, 0.81244f, 0.99037f, 0.26577f,
+    0.64744f, 0.25391f, 0.47913f, 0.09676f, 0.26023f, 0.86098f, 0.24472f, 0.15364f, 0.38980f,
+    0.02943f, 0.59390f, 0.25683f, 0.38976f, 0.90195f, 0.27418f, 0.45255f, 0.74992f, 0.07155f,
+    0.95425f, 0.77560f, 0.41618f, 0.27963f, 0.32602f, 0.75690f, 0.09356f, 0.73795f, 0.59604f,
+    0.97534f, 0.27677f, 0.06770f, 0.59517f, 0.64286f, 0.36224f, 0.22017f, 0.83546f, 0.21461f,
+    0.24793f, 0.08248f, 0.16668f, 0.74429f, 0.66674f, 0.68034f, 0.34710f, 0.82358f, 0.47555f,
+    0.50109f, 0.09328f, 0.98566f, 0.99481f, 0.41391f, 0.86833f, 0.38645f, 0.49203f, 0.44547f,
+    0.55391f, 0.87598f, 0.85542f, 0.56283f, 0.61385f, 0.70564f, 0.29067f, 0.91150f, 0.64787f,
+    0.18255f, 0.03792f, 0.69633f, 0.29029f, 0.31412f, 0.49111f, 0.34615f, 0.43144f, 0.31616f,
+    0.15405f, 0.44915f, 0.12777f, 0.09491f, 0.26003f, 0.71537f, 0.19450f, 0.91570f, 0.28420f,
+    0.77892f, 0.53199f, 0.66034f, 0.01978f, 0.35415f, 0.03664f, 0.42675f, 0.41304f, 0.33804f,
+    0.11290f, 0.89985f, 0.75959f, 0.59417f, 0.53113f, 0.38898f, 0.76259f, 0.83973f, 0.75809f,
+    0.65900f, 0.55141f, 0.14175f, 0.44740f, 0.95823f, 0.77612f, 0.48749f, 0.74491f, 0.57491f,
+    0.59119f, 0.26665f, 0.48599f, 0.85947f, 0.46245f, 0.08129f, 0.00825f, 0.29669f, 0.43499f,
+    0.47998f, 0.60173f, 0.26611f, 0.01223f, 0.81734f, 0.77892f, 0.79022f, 0.01394f, 0.45596f,
+    0.45259f, 0.32536f, 0.84229f, 0.43612f, 0.30531f, 0.10670f, 0.57758f, 0.65956f, 0.42007f,
+    0.32166f, 0.10552f, 0.63558f, 0.17990f, 0.50732f, 0.34599f, 0.16603f, 0.26309f, 0.04098f,
+    0.15997f, 0.79728f, 0.00528f, 0.35510f, 0.24344f, 0.07018f, 0.22062f, 0.92927f, 0.13373f,
+    0.50955f, 0.11199f, 0.75728f, 0.62117f, 0.18153f, 0.84993f, 0.04677f, 0.13013f, 0.92211f,
+    0.95474f, 0.88898f, 0.55561f, 0.22625f, 0.78700f, 0.73659f, 0.97613f, 0.02299f, 0.07724f,
+    0.78942f, 0.02193f, 0.05320f, 0.92053f, 0.35103f, 0.39305f, 0.24208f, 0.08225f, 0.78460f,
+    0.52144f, 0.32927f, 0.84725f, 0.36106f, 0.80349f};
+
+  std::vector<float> h_lY{
+    0.60960f, 0.61090f, 0.41418f, 0.90827f, 0.76181f, 0.31777f, 0.04096f, 0.27290f, 0.56879f,
+    0.75461f, 0.73555f, 0.41598f, 0.59506f, 0.08768f, 0.99554f, 0.20613f, 0.13546f, 0.32044f,
+    0.41057f, 0.38501f, 0.27894f, 0.24027f, 0.91171f, 0.26811f, 0.55595f, 0.71153f, 0.69739f,
+    0.53411f, 0.78365f, 0.60914f, 0.41856f, 0.61688f, 0.28741f, 0.28708f, 0.37029f, 0.47945f,
+    0.40612f, 0.75762f, 0.91728f, 0.70406f, 0.26717f, 0.71175f, 0.39243f, 0.35904f, 0.38469f,
+    0.08664f, 0.38611f, 0.35606f, 0.52801f, 0.96986f, 0.84780f, 0.56942f, 0.41712f, 0.17005f,
+    0.79105f, 0.74347f, 0.83473f, 0.06303f, 0.37864f, 0.66666f, 0.78153f, 0.11061f, 0.33880f,
+    0.82412f, 0.47141f, 0.53043f, 0.51184f, 0.34172f, 0.57087f, 0.88349f, 0.32870f, 0.11501f,
+    0.35460f, 0.23630f, 0.37728f, 0.96120f, 0.19871f, 0.78119f, 0.23860f, 0.70615f, 0.46745f,
+    0.43392f, 0.49967f, 0.39721f, 0.53185f, 0.27827f, 0.14435f, 0.82008f, 0.43275f, 0.82113f,
+    0.06428f, 0.53528f, 0.21594f, 0.86172f, 0.41172f, 0.96051f, 0.54487f, 0.01971f, 0.71222f,
+    0.04258f, 0.36715f, 0.24844f, 0.12494f, 0.34132f, 0.87059f, 0.70216f, 0.33533f, 0.10020f,
+    0.79337f, 0.26059f, 0.81314f, 0.54342f, 0.79115f, 0.71730f, 0.70860f, 0.00998f, 0.64761f,
+    0.01206f, 0.53463f, 0.94436f, 0.19639f, 0.23296f, 0.55945f, 0.14070f, 0.57765f, 0.50908f,
+    0.95720f, 0.95611f, 0.12311f, 0.95382f, 0.23116f, 0.36939f, 0.66395f, 0.76282f, 0.16314f,
+    0.00186f, 0.77662f, 0.58799f, 0.18155f, 0.10355f, 0.45982f, 0.34359f, 0.59476f, 0.72759f,
+    0.77310f, 0.50736f, 0.43720f, 0.63624f, 0.84569f, 0.73073f, 0.04179f, 0.64806f, 0.19924f,
+    0.96082f, 0.06270f, 0.27744f, 0.59384f, 0.07317f, 0.10979f, 0.47857f, 0.60274f, 0.54937f,
+    0.58563f, 0.45247f, 0.84396f, 0.43945f, 0.47719f, 0.40808f, 0.81152f, 0.48558f, 0.21577f,
+    0.93935f, 0.08222f, 0.43114f, 0.68239f, 0.78870f, 0.24300f, 0.84829f, 0.44764f, 0.57347f,
+    0.78353f, 0.30614f, 0.39493f, 0.40320f, 0.72849f, 0.39406f, 0.89363f, 0.33323f, 0.38395f,
+    0.94783f, 0.46082f, 0.30498f, 0.17110f, 0.14083f, 0.48474f, 0.45024f, 0.92586f, 0.77450f,
+    0.43503f, 0.45188f, 0.80866f, 0.24937f, 0.34205f, 0.35942f, 0.79689f, 0.77224f, 0.14354f,
+    0.54387f, 0.50787f, 0.31753f, 0.98414f, 0.03261f, 0.89748f, 0.82350f, 0.60235f, 0.00041f,
+    0.99696f, 0.39894f, 0.52078f, 0.54421f, 0.33405f, 0.81143f, 0.49764f, 0.44993f, 0.37257f,
+    0.16238f, 0.81337f, 0.51335f, 0.96118f, 0.98901f, 0.95259f, 0.36557f, 0.24654f, 0.99554f,
+    0.33408f, 0.01734f, 0.85852f, 0.41286f, 0.67371f, 0.93781f, 0.04977f, 0.17298f, 0.91502f,
+    0.70144f, 0.97356f, 0.12571f, 0.64375f, 0.10033f, 0.36798f, 0.90001f};
+
+  // Unitary weights
+  std::vector<float> h_lunitW;
+
+  // Non-unitary weights
+  std::vector<float> h_lW{
+    0.38674f, 0.59870f, 0.36761f, 0.59731f, 0.99057f, 0.24131f, 0.29727f, 0.94112f, 0.78962f,
+    0.71998f, 0.10983f, 0.33620f, 0.37988f, 0.14344f, 0.37377f, 0.06403f, 0.22877f, 0.21993f,
+    0.11340f, 0.28554f, 0.45453f, 0.14344f, 0.11715f, 0.23184f, 0.08622f, 0.26746f, 0.49058f,
+    0.06981f, 0.41885f, 0.04422f, 0.99925f, 0.71709f, 0.11910f, 0.49944f, 0.98116f, 0.66316f,
+    0.11646f, 0.25202f, 0.93223f, 0.81414f, 0.20446f, 0.23813f, 0.45380f, 0.83618f, 0.95958f,
+    0.72684f, 0.86808f, 0.96348f, 0.76092f, 0.86071f, 0.44155f, 0.85212f, 0.76185f, 0.51460f,
+    0.65627f, 0.38269f, 0.08251f, 0.07506f, 0.22281f, 0.05325f, 0.71190f, 0.62834f, 0.19348f,
+    0.44271f, 0.23677f, 0.81817f, 0.73055f, 0.48816f, 0.57524f, 0.45278f, 0.27998f, 0.35699f,
+    0.26875f, 0.63546f, 0.50990f, 0.21046f, 0.76892f, 0.74433f, 0.39302f, 0.55071f, 0.24554f,
+    0.56793f, 0.67852f, 0.43290f, 0.97266f, 0.52475f, 0.88402f, 0.79439f, 0.01496f, 0.46426f,
+    0.15537f, 0.35364f, 0.42962f, 0.47999f, 0.06357f, 0.78531f, 0.62165f, 0.45226f, 0.84973f,
+    0.63747f, 0.00593f, 0.31520f, 0.13150f, 0.47776f, 0.56420f, 0.21679f, 0.32107f, 0.62491f,
+    0.33747f, 0.86599f, 0.82573f, 0.26970f, 0.50087f, 0.86947f, 0.47433f, 0.91848f, 0.19534f,
+    0.45760f, 0.38407f, 0.18953f, 0.30000f, 0.37964f, 0.42509f, 0.55408f, 0.74500f, 0.44484f,
+    0.67679f, 0.12214f, 0.68380f, 0.74917f, 0.87429f, 0.04355f, 0.98426f, 0.88845f, 0.88318f,
+    0.64393f, 0.90849f, 0.87948f, 0.22915f, 0.86887f, 0.58676f, 0.51575f, 0.56549f, 0.41412f,
+    0.06593f, 0.40484f, 0.72931f, 0.02289f, 0.96391f, 0.61075f, 0.91701f, 0.29698f, 0.37095f,
+    0.42087f, 0.73251f, 0.93271f, 0.32687f, 0.48981f, 0.01081f, 0.11985f, 0.46962f, 0.02569f,
+    0.83989f, 0.21767f, 0.82370f, 0.35174f, 0.94939f, 0.46032f, 0.81569f, 0.66635f, 0.07019f,
+    0.68926f, 0.65628f, 0.19914f, 0.17936f, 0.64540f, 0.09031f, 0.05875f, 0.88790f, 0.83687f,
+    0.46605f, 0.08537f, 0.49514f, 0.44504f, 0.67687f, 0.28943f, 0.74668f, 0.43207f, 0.70990f,
+    0.62513f, 0.56137f, 0.94399f, 0.75806f, 0.41840f, 0.38428f, 0.30754f, 0.62633f, 0.23173f,
+    0.40750f, 0.49968f, 0.05536f, 0.11405f, 0.34185f, 0.36367f, 0.06341f, 0.66834f, 0.42899f,
+    0.08343f, 0.72266f, 0.33155f, 0.74943f, 0.15387f, 0.02475f, 0.35741f, 0.15806f, 0.35406f,
+    0.18226f, 0.31042f, 0.36047f, 0.62366f, 0.30036f, 0.66625f, 0.99695f, 0.99472f, 0.06743f,
+    0.56804f, 0.28185f, 0.77387f, 0.58763f, 0.77824f, 0.03720f, 0.99490f, 0.73720f, 0.93635f,
+    0.85669f, 0.91634f, 0.26065f, 0.97469f, 0.03867f, 0.52306f, 0.99167f, 0.90332f, 0.88546f,
+    0.07109f, 0.94168f, 0.10211f, 0.95949f, 0.86314f, 0.59917f, 0.41948f};
+
+  // Setup smaller input (5 x 3, column-major like h_data)
+  std::vector<float> hx2 = {0.06298,
+                            0.96626,
+                            0.13916,
+                            0.77081,
+                            0.51479,
+                            0.81894,
+                            0.32337,
+                            0.00316,
+                            0.58024,
+                            0.66415,
+                            0.12176,
+                            0.21887,
+                            0.59440,
+                            0.02538,
+                            0.27380};
+
+  std::vector<float> hy2 = {0.11103, 0.69397, 0.21749, 0.71930, 0.28976};
+  // hyp2 holds the two programs' outputs on hx2 (5 values per program)
+  std::vector<float> hyp2 = {
+    0.67334, 1.03133, 1.09484, 0.97263, 1.1157, 0.36077, 0.07413, -0.23618, 0.27997, 0.22255};
+  std::vector<float> hw2;
+
+  // Nodes and programs
+  std::vector<node> h_nodes1;
+  std::vector<node> h_nodes2;
+  std::vector<program> h_progs;
+
+  // Device ptrs
+  node* d_nodes1;
+  node* d_nodes2;
+  program_t d_progs;
+  rmm::device_uvector<float> d_data;
+  rmm::device_uvector<float> d_y;
+  rmm::device_uvector<float> d_lYpred;
+  rmm::device_uvector<float> d_lY;
+  rmm::device_uvector<float> d_lunitW;
+  rmm::device_uvector<float> d_lW;
+  rmm::device_uvector<float> dx2;
+  rmm::device_uvector<float> dy2;
+  rmm::device_uvector<float> dw2;
+  rmm::device_uvector<float> dyp2;
+
+  param hyper_params;
+};
+
+TEST_F(GeneticProgramTest, PearsonCoeff)
+{
+  raft::CompareApproxAbs<float> compApprox(tolerance);
+  float h_expected_score[2] = {0.09528403f, 0.08269963f};
+  float h_score[2]          = {0.0f, 0.0f};
+  rmm::device_uvector<float> d_score(2, stream);
+  hyper_params.metric = metric_t::pearson;
+
+  // Unitary weights
+  compute_metric(
+    handle, 250, 2, d_lY.data(), d_lYpred.data(), d_lunitW.data(), d_score.data(), hyper_params);
+  CUDA_CHECK(
+    cudaMemcpyAsync(h_score, d_score.data(), 2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
+  // (a device-to-pageable-host cudaMemcpyAsync returns only once the copy has
+  // completed, so h_score is safe to read here without an explicit sync)
+  std::copy(h_score, h_score + 2, std::ostream_iterator<float>(std::cerr, ";"));
+  std::cerr << std::endl;
+
+  for (int i = 0; i < 2; ++i) {
+    ASSERT_TRUE(compApprox(h_score[i], h_expected_score[i]));
+  }
+
+  // Unitary weights - small
+  h_expected_score[0] = 0.3247632f;
+  h_expected_score[1] = 0.0796348f;
+  compute_metric(
+    handle, n_samples2, n_progs, dy2.data(), dyp2.data(), dw2.data(), d_score.data(), hyper_params);
+  CUDA_CHECK(
+    cudaMemcpyAsync(h_score, d_score.data(), 2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
+  std::copy(h_score, h_score + 2, std::ostream_iterator<float>(std::cerr, ";"));
+  std::cerr << std::endl;
+  for (int i = 0; i < 2; ++i) {
+    ASSERT_TRUE(compApprox(h_score[i], h_expected_score[i]));
+  }
+
+  // Non-unitary weights
+  h_expected_score[0] = 0.14329584f;
+  h_expected_score[1] = 0.09064283f;
+  compute_metric(
+    handle, 250, 2, d_lY.data(), d_lYpred.data(), d_lW.data(), d_score.data(), hyper_params);
+  CUDA_CHECK(
+    cudaMemcpyAsync(h_score, d_score.data(), 2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
+  std::copy(h_score, h_score + 2, std::ostream_iterator<float>(std::cerr, ";"));
+  std::cerr << std::endl;
+
+  for (int i = 0; i < 2; ++i) {
+    ASSERT_TRUE(compApprox(h_score[i], h_expected_score[i]));
+  }
+}
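+
+// The remaining metric tests share this pattern: compute_metric() scores
+// n_progs sets of predictions against a single label/weight vector and
+// writes one float per program, so each expected-score pair below is
+// {program 0, program 1}. Unit-weight calls pass d_lunitW (all ones) or the
+// all-ones dw2; weighted calls pass d_lW.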
+TEST_F(GeneticProgramTest, SpearmanCoeff)
+{
+  raft::CompareApproxAbs<float> compApprox(tolerance);
+  float h_score[2] = {0.0f, 0.0f};
+  rmm::device_uvector<float> d_score(2, stream);
+  hyper_params.metric = metric_t::spearman;
+
+  // Unitary weights
+  float h_expected_score[2] = {0.09268333f, 0.07529861f};
+  compute_metric(
+    handle, 250, 2, d_lY.data(), d_lYpred.data(), d_lunitW.data(), d_score.data(), hyper_params);
+  CUDA_CHECK(
+    cudaMemcpyAsync(h_score, d_score.data(), 2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
+  std::copy(h_score, h_score + 2, std::ostream_iterator<float>(std::cerr, ";"));
+  std::cerr << std::endl;
+
+  for (int i = 0; i < 2; ++i) {
+    ASSERT_TRUE(compApprox(h_score[i], h_expected_score[i]));
+  }
+
+  // Unitary weights - small
+  h_expected_score[0] = 0.10000f;
+  h_expected_score[1] = 0.10000f;
+  compute_metric(
+    handle, n_samples2, n_progs, dy2.data(), dyp2.data(), dw2.data(), d_score.data(), hyper_params);
+  CUDA_CHECK(
+    cudaMemcpyAsync(h_score, d_score.data(), 2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
+  std::copy(h_score, h_score + 2, std::ostream_iterator<float>(std::cerr, ";"));
+  std::cerr << std::endl;
+  for (int i = 0; i < 2; ++i) {
+    ASSERT_TRUE(compApprox(h_score[i], h_expected_score[i]));
+  }
+
+  // Non-unitary weights
+  h_expected_score[0] = 0.14072408f;
+  h_expected_score[1] = 0.08157397f;
+  compute_metric(
+    handle, 250, 2, d_lY.data(), d_lYpred.data(), d_lW.data(), d_score.data(), hyper_params);
+  CUDA_CHECK(
+    cudaMemcpyAsync(h_score, d_score.data(), 2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
+  std::copy(h_score, h_score + 2, std::ostream_iterator<float>(std::cerr, ";"));
+  std::cerr << std::endl;
+
+  for (int i = 0; i < 2; ++i) {
+    ASSERT_TRUE(compApprox(h_score[i], h_expected_score[i]));
+  }
+}
+
+TEST_F(GeneticProgramTest, MeanSquareLoss)
+{
+  raft::CompareApprox<float> compApprox(tolerance);
+  float h_score[2] = {0.0f, 0.0f};
+  rmm::device_uvector<float> d_score(2, stream);
+  hyper_params.metric = metric_t::mse;
+
+  // Unitary weights
+  float h_expected_score[2] = {0.14297023, 0.14242104};
+  compute_metric(
+    handle, 250, 2, d_lY.data(), d_lYpred.data(), d_lunitW.data(), d_score.data(), hyper_params);
+  CUDA_CHECK(
+    cudaMemcpyAsync(h_score, d_score.data(), 2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
+  std::copy(h_score, h_score + 2, std::ostream_iterator<float>(std::cerr, ";"));
+  std::cerr << std::endl;
+  for (int i = 0; i < 2; ++i) {
+    ASSERT_TRUE(compApprox(h_score[i], h_expected_score[i]));
+  }
+
+  // Unitary weights - small
+  h_expected_score[0] = 0.3892163f;
+  h_expected_score[1] = 0.1699830f;
+  compute_metric(
+    handle, n_samples2, n_progs, dy2.data(), dyp2.data(), dw2.data(), d_score.data(), hyper_params);
+  CUDA_CHECK(
+    cudaMemcpyAsync(h_score, d_score.data(), 2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
+  std::copy(h_score, h_score + 2, std::ostream_iterator<float>(std::cerr, ";"));
+  std::cerr << std::endl;
+  for (int i = 0; i < 2; ++i) {
+    ASSERT_TRUE(compApprox(h_score[i], h_expected_score[i]));
+  }
+
+  // Non-unitary weights
+  h_expected_score[0] = 0.13842479f;
+  h_expected_score[1] = 0.14538825f;
+  compute_metric(
+    handle, 250, 2, d_lY.data(), d_lYpred.data(), d_lW.data(), d_score.data(), hyper_params);
+  CUDA_CHECK(
+    cudaMemcpyAsync(h_score, d_score.data(), 2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
+  std::copy(h_score, h_score + 2, std::ostream_iterator<float>(std::cerr, ";"));
+  std::cerr << std::endl;
+
+  for (int i = 0; i < 2; ++i) {
+    ASSERT_TRUE(compApprox(h_score[i], h_expected_score[i]));
+  }
+}
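+
+// A note on the weighted losses being checked here: the conventional
+// definition is a weighted mean of per-sample errors,
+//   loss = sum_i(w_i * err_i) / sum_i(w_i)
+// with err_i = |y_i - yhat_i| for MAE and (y_i - yhat_i)^2 for MSE, so the
+// unit-weight cases reduce to plain means; the expected values in these
+// tests assume that convention.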
+TEST_F(GeneticProgramTest, MeanAbsoluteLoss)
+{
+  raft::CompareApprox<float> compApprox(tolerance);
+  float h_score[2] = {0.0f, 0.0f};
+  rmm::device_uvector<float> d_score(2, stream);
+  hyper_params.metric = metric_t::mae;
+
+  // Unitary weights - big
+  float h_expected_score[2] = {0.30614017, 0.31275677};
+  compute_metric(
+    handle, 250, 2, d_lY.data(), d_lYpred.data(), d_lunitW.data(), d_score.data(), hyper_params);
+  CUDA_CHECK(
+    cudaMemcpyAsync(h_score, d_score.data(), 2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
+  std::copy(h_score, h_score + 2, std::ostream_iterator<float>(std::cerr, ";"));
+  std::cerr << std::endl;
+
+  for (int i = 0; i < 2; ++i) {
+    ASSERT_TRUE(compApprox(h_score[i], h_expected_score[i]));
+  }
+
+  // Unitary weights - small
+  h_expected_score[0] = 0.571255f;
+  h_expected_score[1] = 0.365957f;
+  compute_metric(
+    handle, n_samples2, n_progs, dy2.data(), dyp2.data(), dw2.data(), d_score.data(), hyper_params);
+  CUDA_CHECK(
+    cudaMemcpyAsync(h_score, d_score.data(), 2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
+  std::copy(h_score, h_score + 2, std::ostream_iterator<float>(std::cerr, ";"));
+  std::cerr << std::endl;
+
+  for (int i = 0; i < 2; ++i) {
+    ASSERT_TRUE(compApprox(h_score[i], h_expected_score[i]));
+  }
+
+  // Non-unitary weights - big
+  h_expected_score[0] = 0.29643119f;
+  h_expected_score[1] = 0.31756123f;
+  compute_metric(
+    handle, 250, 2, d_lY.data(), d_lYpred.data(), d_lW.data(), d_score.data(), hyper_params);
+  CUDA_CHECK(
+    cudaMemcpyAsync(h_score, d_score.data(), 2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
+  std::copy(h_score, h_score + 2, std::ostream_iterator<float>(std::cerr, ";"));
+  std::cerr << std::endl;
+
+  for (int i = 0; i < 2; ++i) {
+    ASSERT_TRUE(compApprox(h_score[i], h_expected_score[i]));
+  }
+}
+
+TEST_F(GeneticProgramTest, RMSLoss)
+{
+  raft::CompareApprox<float> compApprox(tolerance);
+  float h_score[2] = {0.0f, 0.0f};
+  rmm::device_uvector<float> d_score(2, stream);
+  hyper_params.metric = metric_t::rmse;
+
+  // Unitary weights
+  float h_expected_score[2] = {0.37811404, 0.37738713};
+  compute_metric(
+    handle, 250, 2, d_lY.data(), d_lYpred.data(), d_lunitW.data(), d_score.data(), hyper_params);
+  CUDA_CHECK(
+    cudaMemcpyAsync(h_score, d_score.data(), 2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
+  std::copy(h_score, h_score + 2, std::ostream_iterator<float>(std::cerr, ";"));
+  std::cerr << std::endl;
+
+  for (int i = 0; i < 2; ++i) {
+    ASSERT_TRUE(compApprox(h_score[i], h_expected_score[i]));
+  }
+
+  // Unitary weights - small
+  h_expected_score[0] = 0.6238720f;
+  h_expected_score[1] = 0.4122899f;
+  compute_metric(
+    handle, n_samples2, n_progs, dy2.data(), dyp2.data(), dw2.data(), d_score.data(), hyper_params);
+  CUDA_CHECK(
+    cudaMemcpyAsync(h_score, d_score.data(), 2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
+  std::copy(h_score, h_score + 2, std::ostream_iterator<float>(std::cerr, ";"));
+  std::cerr << std::endl;
+
+  for (int i = 0; i < 2; ++i) {
+    ASSERT_TRUE(compApprox(h_score[i], h_expected_score[i]));
+  }
+
+  // Non-unitary weights
+  h_expected_score[0] = 0.37205482f;
+  h_expected_score[1] = 0.38129811f;
+  compute_metric(
+    handle, 250, 2, d_lY.data(), d_lYpred.data(), d_lW.data(), d_score.data(), hyper_params);
+  CUDA_CHECK(
+    cudaMemcpyAsync(h_score, d_score.data(), 2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
+  std::copy(h_score, h_score + 2, std::ostream_iterator<float>(std::cerr, ";"));
+  std::cerr << std::endl;
+
+  for (int i = 0; i < 2; ++i) {
+    ASSERT_TRUE(compApprox(h_score[i], h_expected_score[i]));
+  }
+}
+
+TEST_F(GeneticProgramTest, LogLoss)
+{
+  raft::CompareApprox<float> compApprox(tolerance);
+  float h_score[2] = {0.0f, 0.0f};
+  rmm::device_uvector<float> d_score(2, stream);
+  hyper_params.metric = metric_t::logloss;
+
+  // Unitary weights
+  float h_expected_score[2] = {0.72276, 0.724011};
+  compute_metric(
+    handle, 250, 2, d_lY.data(), d_lYpred.data(), d_lunitW.data(), d_score.data(), hyper_params);
+  CUDA_CHECK(
+    cudaMemcpyAsync(h_score, d_score.data(), 2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
+  std::copy(h_score, h_score + 2, std::ostream_iterator<float>(std::cerr, ";"));
+  std::cerr << std::endl;
+  for (int i = 0; i < 2; ++i) {
+    ASSERT_TRUE(compApprox(h_score[i], h_expected_score[i]));
+  }
+
+  // Non-unitary weights
+  h_expected_score[0] = 0.715887f;
+  h_expected_score[1] = 0.721293f;
+  compute_metric(
+    handle, 250, 2, d_lY.data(), d_lYpred.data(), d_lW.data(), d_score.data(), hyper_params);
+  CUDA_CHECK(
+    cudaMemcpyAsync(h_score, d_score.data(), 2 * sizeof(float), cudaMemcpyDeviceToHost, stream));
+  std::copy(h_score, h_score + 2, std::ostream_iterator<float>(std::cerr, ";"));
+  std::cerr << std::endl;
+
+  for (int i = 0; i < 2; ++i) {
+    ASSERT_TRUE(compApprox(h_score[i], h_expected_score[i]));
+  }
+}
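+
+// ProgramExecution runs both device-resident programs on the 25x3 input:
+// program 0 re-encodes the expression that generated h_y, so its output
+// column must reproduce h_y, while program 1 (0.5*X[1] - 0.4*X[2]) is
+// checked against a closed-form recomputation from h_data.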
+TEST_F(GeneticProgramTest, ProgramExecution)
+{
+  raft::CompareApprox<float> compApprox(tolerance);
+
+  // Enable info-level logging
+  ML::Logger::get().setLevel(CUML_LEVEL_INFO);
+
+  // Allocate memory
+  std::vector<float> h_ypred(n_progs * n_samples, 0.0f);
+  rmm::device_uvector<float> d_ypred(n_progs * n_samples, stream);
+
+  // Execute programs
+  execute(handle, d_progs, n_samples, n_progs, d_data.data(), d_ypred.data());
+  CUDA_CHECK(cudaMemcpyAsync(h_ypred.data(),
+                             d_ypred.data(),
+                             n_progs * n_samples * sizeof(float),
+                             cudaMemcpyDeviceToHost,
+                             stream));
+  CUDA_CHECK(cudaStreamSynchronize(stream));
+
+  // Check results
+  for (int i = 0; i < n_samples; ++i) {
+    ASSERT_TRUE(compApprox(h_ypred[i], h_y[i]));
+  }
+
+  for (int i = 0; i < n_samples; ++i) {
+    ASSERT_TRUE(compApprox(h_ypred[n_samples + i],
+                           0.5 * h_data[n_samples + i] - 0.4 * h_data[2 * n_samples + i]));
+  }
+}
+
+TEST_F(GeneticProgramTest, ProgramFitnessScore)
+{
+  raft::CompareApprox<float> compApprox(tolerance);
+
+  std::vector<metric_t> all_metrics = {
+    metric_t::mae, metric_t::mse, metric_t::rmse, metric_t::pearson, metric_t::spearman};
+
+  // Expected scores are the small-dataset values from the metric tests
+  // above, two per metric in all_metrics order
+  std::vector<float> hexpscores = {
+    0.57126, 0.36596, 0.38922, 0.16998, 0.62387, 0.41229, 0.32476, 0.07963, 0.10000, 0.10000};
+
+  std::vector<float> hactualscores(10);
+
+  rmm::device_uvector<float> dactualscores(10, stream);
+
+  // Start execution for all metrics
+  for (int i = 0; i < 5; ++i) {
+    hyper_params.metric = all_metrics[i];
+    find_batched_fitness(handle,
+                         n_progs,
+                         d_progs,
+                         dactualscores.data() + 2 * i,
+                         hyper_params,
+                         n_samples2,
+                         dx2.data(),
+                         dy2.data(),
+                         dw2.data());
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+  }
+
+  CUDA_CHECK(cudaMemcpyAsync(hactualscores.data(),
+                             dactualscores.data(),
+                             10 * sizeof(float),
+                             cudaMemcpyDeviceToHost,
+                             stream));
+  std::copy(
+    hactualscores.begin(), hactualscores.end(), std::ostream_iterator<float>(std::cerr, ";"));
+  std::cerr << std::endl;
+
+  for (int i = 0; i < 10; ++i) {
+    ASSERT_TRUE(compApprox(std::abs(hactualscores[i]), hexpscores[i]));
+  }
+}
+
+}  // namespace genetic
+}  // namespace cuml
\ No newline at end of file