diff --git a/README.md b/README.md index 6737b54a06..23445afb1a 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,7 @@ repo](https://github.com/rapidsai/notebooks-contrib). | Category | Algorithm | Notes | | --- | --- | --- | | **Clustering** | Density-Based Spatial Clustering of Applications with Noise (DBSCAN) | Multi-node multi-GPU via Dask | +| | Hierarchical Density-Based Spatial Clustering of Applications with Noise (HDBSCAN) | Experimental | | | K-Means | Multi-node multi-GPU via Dask | | | Single-Linkage Agglomerative Clustering | | | **Dimensionality Reduction** | Principal Components Analysis (PCA) | Multi-node multi-GPU via Dask| diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 1da22c3fa6..55f0d9a985 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -225,6 +225,8 @@ if(BUILD_CUML_CPP_LIBRARY) src/glm/glm.cu src/genetic/genetic.cu src/genetic/node.cu + src/hdbscan/hdbscan.cu + src/hdbscan/condensed_hierarchy.cu src/holtwinters/holtwinters.cu src/kmeans/kmeans.cu src/knn/knn.cu diff --git a/cpp/include/cuml/cluster/hdbscan.hpp b/cpp/include/cuml/cluster/hdbscan.hpp new file mode 100644 index 0000000000..965cf58bb2 --- /dev/null +++ b/cpp/include/cuml/cluster/hdbscan.hpp @@ -0,0 +1,314 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include + +#include + +namespace ML { +namespace HDBSCAN { +namespace Common { + +/** + * The Condensed hierarchicy is represented by an edge list with + * parents as the source vertices, children as the destination, + * with attributes for the cluster size and lambda value. + * + * @tparam value_idx + * @tparam value_t + */ +template +class CondensedHierarchy { + public: + /** + * Constructs an empty condensed hierarchy object which requires + * condense() to be called in order to populate the state. + * @param handle_ + * @param n_leaves_ + */ + CondensedHierarchy(const raft::handle_t &handle_, size_t n_leaves_); + + /** + * Constructs a condensed hierarchy object with existing arrays + * which already contain a condensed hierarchy. + * @param handle_ + * @param n_leaves_ + * @param n_edges_ + * @param parents_ + * @param children_ + * @param lambdas_ + * @param sizes_ + */ + CondensedHierarchy(const raft::handle_t &handle_, size_t n_leaves_, + int n_edges_, value_idx *parents_, value_idx *children_, + value_t *lambdas_, value_idx *sizes_); + + /** + * Constructs a condensed hierarchy object by moving + * rmm::device_uvector. Used to construct cluster trees + * @param handle_ + * @param n_leaves_ + * @param n_edges_ + * @param n_clusters_ + * @param parents_ + * @param children_ + * @param lambdas_ + * @param sizes_ + */ + CondensedHierarchy(const raft::handle_t &handle_, size_t n_leaves_, + int n_edges_, int n_clusters_, + rmm::device_uvector &&parents_, + rmm::device_uvector &&children_, + rmm::device_uvector &&lambdas_, + rmm::device_uvector &&sizes_); + /** + * To maintain a high level of parallelism, the output from + * Condense::build_condensed_hierarchy() is sparse (the cluster + * nodes inside any collapsed subtrees will be 0). + * + * This function converts the sparse form to a dense form and renumbers + * the cluster nodes into a topological sort order. The renumbering + * reverses the values in the parent array since root has the largest value + * in the single-linkage tree. Then, it makes the combined parent and + * children arrays monotonic. Finally all of the arrays of the dendrogram + * are sorted by parent->children->sizes (e.g. topological). The root node + * will always have an id of 0 and the largest cluster size. + * + * Ths single-linkage tree dendrogram is a binary tree and parents/children + * can be found with simple indexing arithmetic but the condensed tree no + * longer has this property and so the tree now relies on either + * special indexing or the topological ordering for efficient traversal. + */ + void condense(value_idx *full_parents, value_idx *full_children, + value_t *full_lambdas, value_idx *full_sizes, + value_idx size = -1); + + value_idx get_cluster_tree_edges(); + + value_idx *get_parents() { return parents.data(); } + value_idx *get_children() { return children.data(); } + value_t *get_lambdas() { return lambdas.data(); } + value_idx *get_sizes() { return sizes.data(); } + value_idx get_n_edges() { return n_edges; } + int get_n_clusters() { return n_clusters; } + value_idx get_n_leaves() const { return n_leaves; } + + private: + const raft::handle_t &handle; + + rmm::device_uvector parents; + rmm::device_uvector children; + rmm::device_uvector lambdas; + rmm::device_uvector sizes; + + size_t n_edges; + size_t n_leaves; + int n_clusters; + value_idx root_cluster; +}; + +enum CLUSTER_SELECTION_METHOD { EOM = 0, LEAF = 1 }; + +class RobustSingleLinkageParams { + public: + int k = 5; + int min_samples = 5; + int min_cluster_size = 5; + int max_cluster_size = 0; + + float cluster_selection_epsilon = 0.0; + + bool allow_single_cluster = false; + + float alpha = 1.0; +}; + +class HDBSCANParams : public RobustSingleLinkageParams { + public: + CLUSTER_SELECTION_METHOD cluster_selection_method = + CLUSTER_SELECTION_METHOD::EOM; +}; + +/** + * Container object for output information common between + * robust single linkage variants. + * @tparam value_idx + * @tparam value_t + */ +template +class robust_single_linkage_output { + public: + /** + * Construct output object with empty device arrays of + * known size. + * @param handle_ raft handle for ordering cuda operations + * @param n_leaves_ number of data points + * @param labels_ labels array on device (size n_leaves) + * @param children_ dendrogram src/dst array (size n_leaves - 1, 2) + * @param sizes_ dendrogram cluster sizes array (size n_leaves - 1) + * @param deltas_ dendrogram distances array (size n_leaves - 1) + * @param mst_src_ min spanning tree source array (size n_leaves - 1) + * @param mst_dst_ min spanning tree destination array (size n_leaves - 1) + * @param mst_weights_ min spanninng tree distances array (size n_leaves - 1) + */ + robust_single_linkage_output(const raft::handle_t &handle_, int n_leaves_, + value_idx *labels_, value_idx *children_, + value_idx *sizes_, value_t *deltas_, + value_idx *mst_src_, value_idx *mst_dst_, + value_t *mst_weights_) + : handle(handle_), + n_leaves(n_leaves_), + n_clusters(0), + labels(labels_), + children(children_), + sizes(sizes_), + deltas(deltas_), + mst_src(mst_src_), + mst_dst(mst_dst_), + mst_weights(mst_weights_) {} + + int get_n_leaves() const { return n_leaves; } + int get_n_clusters() const { return n_clusters; } + value_idx *get_labels() { return labels; } + value_idx *get_children() { return children; } + value_idx *get_sizes() { return sizes; } + value_t *get_deltas() { return deltas; } + value_idx *get_mst_src() { return mst_src; } + value_idx *get_mst_dst() { return mst_dst; } + value_t *get_mst_weights() { return mst_weights; } + + /** + * The number of clusters is set by the algorithm once it is known. + * @param n_clusters_ number of resulting clusters + */ + void set_n_clusters(int n_clusters_) { n_clusters = n_clusters_; } + + protected: + const raft::handle_t &get_handle() { return handle; } + + const raft::handle_t &handle; + + int n_leaves; + int n_clusters; + + value_idx *labels; // size n_leaves + + // Dendrogram + value_idx *children; // size n_leaves * 2 + value_idx *sizes; // size n_leaves + value_t *deltas; // size n_leaves + + // MST (size n_leaves - 1). + value_idx *mst_src; + value_idx *mst_dst; + value_t *mst_weights; +}; + +/** + * Plain old container object to consolidate output + * arrays. This object is intentionally kept simple + * and straightforward in order to ease its use + * in the Python layer. For this reason, the MST + * arrays and renumbered dendrogram array, as well + * as its aggregated distances/cluster sizes, are + * kept separate. The condensed hierarchy is computed + * and populated in a separate object because its size + * is not known ahead of time. An RMM device vector is + * held privately and stabilities initialized explicitly + * since that size is also not known ahead of time. + * @tparam value_idx + * @tparam value_t + */ +template +class hdbscan_output : public robust_single_linkage_output { + public: + hdbscan_output(const raft::handle_t &handle_, int n_leaves_, + value_idx *labels_, value_t *probabilities_, + value_idx *children_, value_idx *sizes_, value_t *deltas_, + value_idx *mst_src_, value_idx *mst_dst_, + value_t *mst_weights_) + : robust_single_linkage_output( + handle_, n_leaves_, labels_, children_, sizes_, deltas_, mst_src_, + mst_dst_, mst_weights_), + probabilities(probabilities_), + stabilities(0, handle_.get_stream()), + condensed_tree(handle_, n_leaves_) {} + + // Using getters here, making the members private and forcing + // consistent state with the constructor. This should make + // it much easier to use / debug. + value_t *get_probabilities() { return probabilities; } + value_t *get_stabilities() { + ASSERT(stabilities.size() > 0, "stabilities needs to be initialized"); + return stabilities.data(); + } + + /** + * Once n_clusters is known, the stabilities array + * can be initialized. + * @param n_clusters_ + */ + void set_n_clusters(int n_clusters_) { + robust_single_linkage_output::set_n_clusters( + n_clusters_); + stabilities.resize( + n_clusters_, + robust_single_linkage_output::get_handle() + .get_stream()); + } + + CondensedHierarchy &get_condensed_tree() { + return condensed_tree; + } + + private: + value_t *probabilities; // size n_leaves + + // Size not known ahead of time. Initialize + // with `initialize_stabilities()` method. + rmm::device_uvector stabilities; + + // Use condensed hierarchy to wrap + // condensed tree outputs since we do not + // know the size ahead of time. + CondensedHierarchy condensed_tree; +}; + +template class CondensedHierarchy; + +}; // namespace Common +}; // namespace HDBSCAN + +/** + * Executes HDBSCAN clustering on an mxn-dimensional input array, X. + * @param[in] handle raft handle for resource reuse + * @param[in] X array (size m, n) on device in row-major format + * @param m number of rows in X + * @param n number of columns in X + * @param metric distance metric to use + * @param params struct of configuration hyper-parameters + * @param out struct of output data and arrays on device + */ +void hdbscan(const raft::handle_t &handle, const float *X, size_t m, size_t n, + raft::distance::DistanceType metric, + HDBSCAN::Common::HDBSCANParams ¶ms, + HDBSCAN::Common::hdbscan_output &out); +} // END namespace ML \ No newline at end of file diff --git a/cpp/include/cuml/cluster/linkage.hpp b/cpp/include/cuml/cluster/linkage.hpp index 64284be729..7a778ea5da 100644 --- a/cpp/include/cuml/cluster/linkage.hpp +++ b/cpp/include/cuml/cluster/linkage.hpp @@ -18,6 +18,7 @@ #include #include +#include namespace raft { class handle_t; diff --git a/cpp/src/hdbscan/condensed_hierarchy.cu b/cpp/src/hdbscan/condensed_hierarchy.cu new file mode 100644 index 0000000000..c262828afb --- /dev/null +++ b/cpp/src/hdbscan/condensed_hierarchy.cu @@ -0,0 +1,209 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include +#include + +#include +#include + +#include +#include + +#include +#include +#include + +#include + +namespace ML { +namespace HDBSCAN { +namespace Common { + +struct TupleComp { + template + __host__ __device__ bool operator()(const one &t1, const two &t2) { + // sort first by each parent, + if (thrust::get<0>(t1) < thrust::get<0>(t2)) return true; + if (thrust::get<0>(t1) > thrust::get<0>(t2)) return false; + + // within each parent, sort by each child, + if (thrust::get<1>(t1) < thrust::get<1>(t2)) return true; + if (thrust::get<1>(t1) > thrust::get<1>(t2)) return false; + + // then sort by value in descending order + return thrust::get<2>(t1) < thrust::get<2>(t2); + } +}; + +template +CondensedHierarchy::CondensedHierarchy( + const raft::handle_t &handle_, size_t n_leaves_) + : handle(handle_), + n_leaves(n_leaves_), + parents(0, handle.get_stream()), + children(0, handle.get_stream()), + lambdas(0, handle.get_stream()), + sizes(0, handle.get_stream()) {} + +template +CondensedHierarchy::CondensedHierarchy( + const raft::handle_t &handle_, size_t n_leaves_, int n_edges_, + value_idx *parents_, value_idx *children_, value_t *lambdas_, + value_idx *sizes_) + : handle(handle_), + n_leaves(n_leaves_), + n_edges(n_edges_), + parents(0, handle.get_stream()), + children(0, handle.get_stream()), + lambdas(0, handle.get_stream()), + sizes(0, handle.get_stream()) { + parents.resize(n_edges_, handle.get_stream()); + children.resize(n_edges_, handle.get_stream()); + lambdas.resize(n_edges_, handle.get_stream()); + sizes.resize(n_edges_, handle.get_stream()); + + raft::copy(parents.begin(), parents_, n_edges_, handle.get_stream()); + raft::copy(children.begin(), children_, n_edges_, handle.get_stream()); + raft::copy(lambdas.begin(), lambdas_, n_edges_, handle.get_stream()); + raft::copy(sizes.begin(), sizes_, n_edges_, handle.get_stream()); + + auto parents_ptr = thrust::device_pointer_cast(parents.data()); + + auto parents_min_max = + thrust::minmax_element(thrust::cuda::par.on(handle.get_stream()), + parents_ptr, parents_ptr + n_edges); + auto min_cluster = *parents_min_max.first; + auto max_cluster = *parents_min_max.second; + + n_clusters = max_cluster - min_cluster + 1; + + auto sort_keys = thrust::make_zip_iterator( + thrust::make_tuple(parents.begin(), children.begin(), sizes.begin())); + auto sort_values = + thrust::make_zip_iterator(thrust::make_tuple(lambdas.begin())); + + thrust::sort_by_key(thrust::cuda::par.on(handle.get_stream()), sort_keys, + sort_keys + n_edges, sort_values, TupleComp()); +} + +template +CondensedHierarchy::CondensedHierarchy( + const raft::handle_t &handle_, size_t n_leaves_, int n_edges_, + int n_clusters_, rmm::device_uvector &&parents_, + rmm::device_uvector &&children_, + rmm::device_uvector &&lambdas_, + rmm::device_uvector &&sizes_) + : handle(handle_), + n_leaves(n_leaves_), + n_edges(n_edges_), + n_clusters(n_clusters_), + parents(std::move(parents_)), + children(std::move(children_)), + lambdas(std::move(lambdas_)), + sizes(std::move(sizes_)) {} + +/** + * Populates the condensed hierarchy object with the output + * from Condense::condense_hierarchy + * @param full_parents + * @param full_children + * @param full_lambdas + * @param full_sizes + */ +template +void CondensedHierarchy::condense(value_idx *full_parents, + value_idx *full_children, + value_t *full_lambdas, + value_idx *full_sizes, + value_idx size) { + auto stream = handle.get_stream(); + + if (size == -1) size = 4 * (n_leaves - 1) + 2; + + n_edges = thrust::transform_reduce( + thrust::cuda::par.on(stream), full_sizes, full_sizes + size, + [=] __device__(value_idx a) { return a != -1; }, 0, + thrust::plus()); + + parents.resize(n_edges, stream); + children.resize(n_edges, stream); + lambdas.resize(n_edges, stream); + sizes.resize(n_edges, stream); + + auto in = thrust::make_zip_iterator( + thrust::make_tuple(full_parents, full_children, full_lambdas, full_sizes)); + + auto out = thrust::make_zip_iterator(thrust::make_tuple( + parents.data(), children.data(), lambdas.data(), sizes.data())); + + thrust::copy_if( + thrust::cuda::par.on(stream), in, in + size, out, + [=] __device__( + thrust::tuple tup) { + return thrust::get<3>(tup) != -1; + }); + + // TODO: Avoid the copies here by updating kernel + rmm::device_uvector parent_child(n_edges * 2, stream); + raft::copy_async(parent_child.begin(), children.begin(), n_edges, stream); + raft::copy_async(parent_child.begin() + n_edges, parents.begin(), n_edges, + stream); + + // find n_clusters + auto parents_ptr = thrust::device_pointer_cast(parents.data()); + auto max_parent = *(thrust::max_element(thrust::cuda::par.on(stream), + parents_ptr, parents_ptr + n_edges)); + + // now invert labels + auto invert_op = [max_parent, n_leaves = n_leaves] __device__(auto &x) { + return x >= n_leaves ? max_parent - x + n_leaves : x; + }; + + thrust::transform(thrust::cuda::par.on(stream), parent_child.begin(), + parent_child.end(), parent_child.begin(), invert_op); + + raft::label::make_monotonic(parent_child.data(), parent_child.data(), + parent_child.size(), stream, + handle.get_device_allocator(), true); + + raft::copy_async(children.begin(), parent_child.begin(), n_edges, stream); + raft::copy_async(parents.begin(), parent_child.begin() + n_edges, n_edges, + stream); + + auto parents_min_max = thrust::minmax_element( + thrust::cuda::par.on(stream), parents_ptr, parents_ptr + n_edges); + auto min_cluster = *parents_min_max.first; + auto max_cluster = *parents_min_max.second; + + n_clusters = max_cluster - min_cluster + 1; + + auto sort_keys = thrust::make_zip_iterator( + thrust::make_tuple(parents.begin(), children.begin(), sizes.begin())); + auto sort_values = + thrust::make_zip_iterator(thrust::make_tuple(lambdas.begin())); + + thrust::sort_by_key(thrust::cuda::par.on(stream), sort_keys, + sort_keys + n_edges, sort_values, TupleComp()); +} + +}; // namespace Common +}; // namespace HDBSCAN +}; // namespace ML diff --git a/cpp/src/hdbscan/detail/condense.cuh b/cpp/src/hdbscan/detail/condense.cuh new file mode 100644 index 0000000000..e5aa447c98 --- /dev/null +++ b/cpp/src/hdbscan/detail/condense.cuh @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "kernels/condense.cuh" + +#include + +#include + +#include +#include + +#include +#include + +#include + +#include +#include + +namespace ML { +namespace HDBSCAN { +namespace detail { +namespace Condense { + +/** + * Condenses a binary single-linkage tree dendrogram in the Scipy hierarchy + * format by collapsing subtrees that fall below a minimum cluster size. + * + * For increased parallelism, the output array sizes are held fixed but + * the result will be sparse (e.g. zeros in place of parents who have been + * removed / collapsed). This function accepts an empty instance of + * `CondensedHierarchy` and invokes the `condense()` function on it to + * convert the sparse output arrays into their dense form. + * + * @tparam value_idx + * @tparam value_t + * @tparam tpb + * @param handle + * @param[in] children parents/children from single-linkage dendrogram + * @param[in] delta distances from single-linkage dendrogram + * @param[in] sizes sizes from single-linkage dendrogram + * @param[in] min_cluster_size any subtrees less than this size will be + * collapsed. + * @param[in] n_leaves number of actual data samples in the dendrogram + * @param[out] condensed_tree output dendrogram. will likely no longer be + * a binary tree. + */ +template +void build_condensed_hierarchy( + const raft::handle_t &handle, const value_idx *children, const value_t *delta, + const value_idx *sizes, int min_cluster_size, int n_leaves, + Common::CondensedHierarchy &condensed_tree) { + cudaStream_t stream = handle.get_stream(); + auto exec_policy = rmm::exec_policy(stream); + + // Root is the last edge in the dendrogram + int root = 2 * (n_leaves - 1); + + auto d_ptr = thrust::device_pointer_cast(children); + value_idx n_vertices = + *(thrust::max_element(exec_policy, d_ptr, d_ptr + root)) + 1; + + // Prevent potential infinite loop from labeling disconnected + // connectivities graph. + RAFT_EXPECTS(n_vertices == (n_leaves - 1) * 2, + "Multiple components found in MST or MST is invalid. " + "Cannot find single-linkage solution."); + + rmm::device_uvector frontier(root + 1, stream); + + thrust::fill(exec_policy, frontier.begin(), frontier.end(), false); + + // Array to propagate the lambda of subtrees actively being collapsed + // through multiple bfs iterations. + rmm::device_uvector ignore(root + 1, stream); + + // Propagate labels from root + rmm::device_uvector relabel(root + 1, handle.get_stream()); + thrust::fill(exec_policy, relabel.begin(), relabel.end(), -1); + + raft::update_device(relabel.data() + root, &root, 1, handle.get_stream()); + + // Flip frontier for root + constexpr bool start = true; + raft::update_device(frontier.data() + root, &start, 1, handle.get_stream()); + + rmm::device_uvector out_parent((root + 1) * 2, stream); + rmm::device_uvector out_child((root + 1) * 2, stream); + rmm::device_uvector out_lambda((root + 1) * 2, stream); + rmm::device_uvector out_size((root + 1) * 2, stream); + + thrust::fill(exec_policy, out_parent.begin(), out_parent.end(), -1); + thrust::fill(exec_policy, out_child.begin(), out_child.end(), -1); + thrust::fill(exec_policy, out_lambda.begin(), out_lambda.end(), -1); + thrust::fill(exec_policy, out_size.begin(), out_size.end(), -1); + thrust::fill(exec_policy, ignore.begin(), ignore.end(), -1); + + // While frontier is not empty, perform single bfs through tree + size_t grid = raft::ceildiv(root + 1, (int)tpb); + + value_idx n_elements_to_traverse = + thrust::reduce(exec_policy, frontier.data(), frontier.data() + root + 1, 0); + + while (n_elements_to_traverse > 0) { + // TODO: Investigate whether it would be worth performing a gather/argmatch in order + // to schedule only the number of threads needed. (it might not be worth it) + condense_hierarchy_kernel<<>>( + frontier.data(), ignore.data(), relabel.data(), children, delta, sizes, + n_leaves, min_cluster_size, out_parent.data(), out_child.data(), + out_lambda.data(), out_size.data()); + + n_elements_to_traverse = thrust::reduce(exec_policy, frontier.data(), + frontier.data() + root + 1, 0); + + CUDA_CHECK(cudaStreamSynchronize(stream)); + } + + condensed_tree.condense(out_parent.data(), out_child.data(), + out_lambda.data(), out_size.data()); +} + +}; // end namespace Condense +}; // end namespace detail +}; // end namespace HDBSCAN +}; // end namespace ML \ No newline at end of file diff --git a/cpp/src/hdbscan/detail/extract.cuh b/cpp/src/hdbscan/detail/extract.cuh new file mode 100644 index 0000000000..6838f6f12e --- /dev/null +++ b/cpp/src/hdbscan/detail/extract.cuh @@ -0,0 +1,269 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include