From 05d899b36b76545d2439dbe47e4659d644ced227 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Tue, 21 Mar 2023 16:23:16 -0400 Subject: [PATCH 01/13] Stop setting package version attribute in wheels (#1359) This PR removes modification of the `__init__.py::version` attribute that occurs during the wheel build process. See https://github.com/rapidsai/ops/issues/2592 for more information. Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Sevag H (https://github.com/sevagh) URL: https://github.com/rapidsai/raft/pull/1359 --- ci/release/apply_wheel_modifications.sh | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ci/release/apply_wheel_modifications.sh b/ci/release/apply_wheel_modifications.sh index ed3d2a15fd..efc8f0c77c 100755 --- a/ci/release/apply_wheel_modifications.sh +++ b/ci/release/apply_wheel_modifications.sh @@ -6,10 +6,6 @@ VERSION=${1} CUDA_SUFFIX=${2} -# __init__.py versions -sed -i "s/__version__ = .*/__version__ = \"${VERSION}\"/g" python/pylibraft/pylibraft/__init__.py -sed -i "s/__version__ = .*/__version__ = \"${VERSION}\"/g" python/raft-dask/raft_dask/__init__.py - # pyproject.toml versions sed -i "s/^version = .*/version = \"${VERSION}\"/g" python/pylibraft/pyproject.toml sed -i "s/^version = .*/version = \"${VERSION}\"/g" python/raft-dask/pyproject.toml From a7e619cfec8b17a467122e0fd123aedad1bc5e06 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 22 Mar 2023 22:53:40 +0100 Subject: [PATCH 02/13] Remove usage of Dask's `get_worker` (#1365) In dask/distributed#7580 get_worker was modified to return the worker of a task, thus it cannot be used by client.run, and we must now use dask_worker as the first argument to client.run to obtain the worker. Authors: - Peter Andreas Entschev (https://github.com/pentschev) Approvers: - Corey J. Nolet (https://github.com/cjnolet) - AJ Schmidt (https://github.com/ajschmidt8) URL: https://github.com/rapidsai/raft/pull/1365 --- ci/wheel_smoke_test_raft_dask.py | 21 ++- python/raft-dask/raft_dask/common/comms.py | 159 ++++++++++++------ python/raft-dask/raft_dask/test/test_comms.py | 24 ++- 3 files changed, 133 insertions(+), 71 deletions(-) diff --git a/ci/wheel_smoke_test_raft_dask.py b/ci/wheel_smoke_test_raft_dask.py index 32c13e61ca..5709ac901c 100644 --- a/ci/wheel_smoke_test_raft_dask.py +++ b/ci/wheel_smoke_test_raft_dask.py @@ -1,4 +1,19 @@ -from dask.distributed import Client, wait +# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from dask.distributed import Client, get_worker, wait from dask_cuda import LocalCUDACluster, initialize from raft_dask.common import ( @@ -23,12 +38,12 @@ def func_test_send_recv(sessionId, n_trials): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return perform_test_comms_send_recv(handle, n_trials) def func_test_collective(func, sessionId, root): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return func(handle, root) diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py index 56e40b98da..ebe9a8dc4f 100644 --- a/python/raft-dask/raft_dask/common/comms.py +++ b/python/raft-dask/raft_dask/common/comms.py @@ -19,7 +19,7 @@ import warnings from collections import OrderedDict -from dask.distributed import default_client, get_worker +from dask.distributed import default_client from pylibraft.common.handle import Handle @@ -242,7 +242,7 @@ def destroy(self): self.ucx_initialized = False -def local_handle(sessionId): +def local_handle(sessionId, dask_worker=None): """ Simple helper function for retrieving the local handle_t instance for a comms session on a worker. @@ -251,16 +251,19 @@ def local_handle(sessionId): ---------- sessionId : str session identifier from an initialized comms instance + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) Returns ------- handle : raft.Handle or None """ - state = get_raft_comm_state(sessionId, get_worker()) + state = get_raft_comm_state(sessionId, dask_worker) return state["handle"] if "handle" in state else None -def get_raft_comm_state(sessionId, state_object=None): +def get_raft_comm_state(sessionId, state_object=None, dask_worker=None): """ Retrieves cuML comms state on the scheduler node, for the given sessionId, creating a new session if it does not exist. If no session id is given, @@ -271,13 +274,16 @@ def get_raft_comm_state(sessionId, state_object=None): sessionId : SessionId value to retrieve from the dask_scheduler instances state_object : Object (either Worker, or Scheduler) on which the raft comm state will retrieved (or created) + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) Returns ------- session state : str session state associated with sessionId """ - state_object = state_object if state_object is not None else get_worker() + state_object = state_object if state_object is not None else dask_worker if not hasattr(state_object, "_raft_comm_state"): state_object._raft_comm_state = {} @@ -308,13 +314,19 @@ def set_nccl_root(sessionId, state_object): return raft_comm_state["nccl_uid"] -def get_ucx(): +def get_ucx(dask_worker=None): """ A simple convenience wrapper to make sure UCP listener and endpoints are only ever assigned once per worker. 
+ + Parameters + ---------- + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ raft_comm_state = get_raft_comm_state( - sessionId="ucp", state_object=get_worker() + sessionId="ucp", state_object=dask_worker ) if "ucx" not in raft_comm_state: raft_comm_state["ucx"] = UCX.get() @@ -371,7 +383,7 @@ def _func_set_scheduler_as_nccl_root(sessionId, verbose, dask_scheduler): return nccl_uid -def _func_set_worker_as_nccl_root(sessionId, verbose): +def _func_set_worker_as_nccl_root(sessionId, verbose, dask_worker=None): """ Creates a persistent nccl uniqueId on the scheduler node. @@ -380,63 +392,74 @@ def _func_set_worker_as_nccl_root(sessionId, verbose): ---------- sessionId : Associated session to attach the unique ID to. verbose : Indicates whether or not to emit additional information + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) Return ------ uniqueId : byte str NCCL uniqueId, associating this DASK worker as its root node. """ - worker = get_worker() if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg=f"Setting worker as NCCL root for session, '{sessionId}'", ) - nccl_uid = set_nccl_root(sessionId=sessionId, state_object=worker) + nccl_uid = set_nccl_root(sessionId=sessionId, state_object=dask_worker) if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg="Done setting scheduler as NCCL root." ) return nccl_uid -def _func_ucp_listener_port(): - return get_ucx().listener_port() +def _func_ucp_listener_port(dask_worker=None): + return get_ucx(dask_worker=dask_worker).listener_port() async def _func_init_all( - sessionId, uniqueId, comms_p2p, worker_info, verbose, streams_per_handle + sessionId, + uniqueId, + comms_p2p, + worker_info, + verbose, + streams_per_handle, + dask_worker=None, ): - worker = get_worker() raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=worker + sessionId=sessionId, state_object=dask_worker ) raft_comm_state["nccl_uid"] = uniqueId - raft_comm_state["wid"] = worker_info[get_worker().address]["rank"] + raft_comm_state["wid"] = worker_info[dask_worker.address]["rank"] raft_comm_state["nworkers"] = len(worker_info) if verbose: - worker.log_event(topic="info", msg="Initializing NCCL.") + dask_worker.log_event(topic="info", msg="Initializing NCCL.") start = time.time() - _func_init_nccl(sessionId, uniqueId) + _func_init_nccl(sessionId, uniqueId, dask_worker=dask_worker) if verbose: elapsed = time.time() - start - worker.log_event( + dask_worker.log_event( topic="info", msg=f"NCCL Initialization took: {elapsed} seconds." ) if comms_p2p: if verbose: - worker.log_event(topic="info", msg="Initializing UCX Endpoints") + dask_worker.log_event( + topic="info", msg="Initializing UCX Endpoints" + ) if verbose: start = time.time() - await _func_ucp_create_endpoints(sessionId, worker_info) + await _func_ucp_create_endpoints( + sessionId, worker_info, dask_worker=dask_worker + ) if verbose: elapsed = time.time() - start @@ -444,18 +467,22 @@ async def _func_init_all( f"Done initializing UCX endpoints." f"Took: {elapsed} seconds.\nBuilding handle." 
) - worker.log_event(topic="info", msg=msg) + dask_worker.log_event(topic="info", msg=msg) - _func_build_handle_p2p(sessionId, streams_per_handle, verbose) + _func_build_handle_p2p( + sessionId, streams_per_handle, verbose, dask_worker=dask_worker + ) if verbose: - worker.log_event(topic="info", msg="Done building handle.") + dask_worker.log_event(topic="info", msg="Done building handle.") else: - _func_build_handle(sessionId, streams_per_handle, verbose) + _func_build_handle( + sessionId, streams_per_handle, verbose, dask_worker=dask_worker + ) -def _func_init_nccl(sessionId, uniqueId): +def _func_init_nccl(sessionId, uniqueId, dask_worker=None): """ Initialize ncclComm_t on worker @@ -466,11 +493,13 @@ def _func_init_nccl(sessionId, uniqueId): uniqueId : array[byte] The NCCL unique Id generated from the client. + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ - worker = get_worker() raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=get_worker() + sessionId=sessionId, state_object=dask_worker, dask_worker=dask_worker ) wid = raft_comm_state["wid"] nWorkers = raft_comm_state["nworkers"] @@ -480,13 +509,15 @@ def _func_init_nccl(sessionId, uniqueId): n.init(nWorkers, uniqueId, wid) raft_comm_state["nccl"] = n except Exception as e: - worker.log_event( + dask_worker.log_event( topic="error", msg=f"An error occurred initializing NCCL: {e}." ) raise -def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): +def _func_build_handle_p2p( + sessionId, streams_per_handle, verbose, dask_worker=None +): """ Builds a handle_t on the current worker given the initialized comms @@ -495,14 +526,16 @@ def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): sessionId : str id to reference state for current comms instance. streams_per_handle : int number of internal streams to create verbose : bool print verbose logging output + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ - worker = get_worker() if verbose: - worker.log_event(topic="info", msg="Building p2p handle.") + dask_worker.log_event(topic="info", msg="Building p2p handle.") - ucp_worker = get_ucx().get_worker() + ucp_worker = get_ucx(dask_worker).get_worker() raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=worker + sessionId=sessionId, state_object=dask_worker ) handle = Handle(n_streams=streams_per_handle) @@ -512,21 +545,23 @@ def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): workerId = raft_comm_state["wid"] if verbose: - worker.log_event(topic="info", msg="Injecting comms on handle.") + dask_worker.log_event(topic="info", msg="Injecting comms on handle.") inject_comms_on_handle( handle, nccl_comm, ucp_worker, eps, nWorkers, workerId, verbose ) if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg="Finished injecting comms on handle." ) raft_comm_state["handle"] = handle -def _func_build_handle(sessionId, streams_per_handle, verbose): +def _func_build_handle( + sessionId, streams_per_handle, verbose, dask_worker=None +): """ Builds a handle_t on the current worker given the initialized comms @@ -535,17 +570,19 @@ def _func_build_handle(sessionId, streams_per_handle, verbose): sessionId : str id to reference state for current comms instance. 
streams_per_handle : int number of internal streams to create verbose : bool print verbose logging output + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ - worker = get_worker() if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg="Finished injecting comms on handle." ) handle = Handle(n_streams=streams_per_handle) raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=worker + sessionId=sessionId, state_object=dask_worker ) workerId = raft_comm_state["wid"] @@ -558,16 +595,18 @@ def _func_build_handle(sessionId, streams_per_handle, verbose): raft_comm_state["handle"] = handle -def _func_store_initial_state(nworkers, sessionId, uniqueId, wid): +def _func_store_initial_state( + nworkers, sessionId, uniqueId, wid, dask_worker=None +): raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=get_worker() + sessionId=sessionId, state_object=dask_worker ) raft_comm_state["nccl_uid"] = uniqueId raft_comm_state["wid"] = wid raft_comm_state["nworkers"] = nworkers -async def _func_ucp_create_endpoints(sessionId, worker_info): +async def _func_ucp_create_endpoints(sessionId, worker_info, dask_worker): """ Runs on each worker to create ucp endpoints to all other workers @@ -577,6 +616,9 @@ async def _func_ucp_create_endpoints(sessionId, worker_info): uuid unique id for this instance worker_info : dict Maps worker addresses to NCCL ranks & UCX ports + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ eps = [None] * len(worker_info) count = 1 @@ -584,40 +626,47 @@ async def _func_ucp_create_endpoints(sessionId, worker_info): for k in worker_info: ip, port = parse_host_port(k) - ep = await get_ucx().get_endpoint(ip, worker_info[k]["port"]) + ep = await get_ucx(dask_worker=dask_worker).get_endpoint( + ip, worker_info[k]["port"] + ) eps[worker_info[k]["rank"]] = ep count += 1 raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=get_worker() + sessionId=sessionId, state_object=dask_worker ) raft_comm_state["ucp_eps"] = eps -async def _func_destroy_all(sessionId, comms_p2p, verbose=False): - worker = get_worker() +async def _func_destroy_all( + sessionId, comms_p2p, verbose=False, dask_worker=None +): if verbose: - worker.log_event(topic="info", msg="Destroying NCCL session state.") + dask_worker.log_event( + topic="info", msg="Destroying NCCL session state." + ) raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=worker + sessionId=sessionId, state_object=dask_worker ) if "nccl" in raft_comm_state: raft_comm_state["nccl"].destroy() del raft_comm_state["nccl"] if verbose: - worker.log_event(topic="info", msg="NCCL session state destroyed.") + dask_worker.log_event( + topic="info", msg="NCCL session state destroyed." 
+ ) else: if verbose: - worker.log_event( + dask_worker.log_event( topic="warning", msg=f"Session state for, '{sessionId}', " f"does not contain expected 'nccl' element", ) if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg=f"Destroying CUDA handle for sessionId, '{sessionId}.'", ) @@ -626,7 +675,7 @@ async def _func_destroy_all(sessionId, comms_p2p, verbose=False): del raft_comm_state["handle"] else: if verbose: - worker.log_event( + dask_worker.log_event( topic="warning", msg=f"Session state for, '{sessionId}', " f"does not contain expected 'handle' element", diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index 74ec446e94..3a430f9270 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2022, NVIDIA CORPORATION. +# Copyright (c) 2019-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ import pytest -from dask.distributed import Client, wait +from dask.distributed import Client, get_worker, wait try: from raft_dask.common import ( @@ -60,32 +60,32 @@ def test_comms_init_no_p2p(cluster): def func_test_collective(func, sessionId, root): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return func(handle, root) def func_test_send_recv(sessionId, n_trials): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return perform_test_comms_send_recv(handle, n_trials) def func_test_device_send_or_recv(sessionId, n_trials): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return perform_test_comms_device_send_or_recv(handle, n_trials) def func_test_device_sendrecv(sessionId, n_trials): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return perform_test_comms_device_sendrecv(handle, n_trials) def func_test_device_multicast_sendrecv(sessionId, n_trials): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return perform_test_comms_device_multicast_sendrecv(handle, n_trials) def func_test_comm_split(sessionId, n_trials): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return perform_test_comm_split(handle, n_trials) @@ -114,11 +114,9 @@ def func_check_uid_on_scheduler(sessionId, uniqueId, dask_scheduler): ) -def func_check_uid_on_worker(sessionId, uniqueId): - from dask.distributed import get_worker - +def func_check_uid_on_worker(sessionId, uniqueId, dask_worker=None): return func_check_uid( - sessionId=sessionId, uniqueId=uniqueId, state_object=get_worker() + sessionId=sessionId, uniqueId=uniqueId, state_object=dask_worker ) @@ -127,7 +125,7 @@ def test_handles(cluster): client = Client(cluster) def _has_handle(sessionId): - return local_handle(sessionId) is not None + return local_handle(sessionId, dask_worker=get_worker()) is not None try: cb = Comms(verbose=True) From 08e7012bc00140f77732fd73b134f388edf119dd Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 23 Mar 2023 03:24:56 +0100 Subject: [PATCH 03/13] Reduce compile times of distance specializations (#1307) Following the findings in https://github.com/ahendriksen/raft/tree/investigate-compile-time-reduction-strategies#investigation-of-compile-times, 
this PR reduces the compile times of the pairwise distance specializations. This is achieved by:

1. Reducing the number of included files in the translation units where kernels are instantiated; specifically, `spdlog` and `rmm` are avoided.
2. Limiting loop unrolling in kernels with expensive operations in the inner loop.

Additional improvements geared towards iterative development:

1. The tests do not have to be recompiled when the internals of a pairwise distance kernel change. Before, a rebuild was triggered due to an include of `raft/distance/distance.cuh`.
2. Addition of a fine-tuning benchmark for the pairwise distance kernels that separates building the kernel from the benchmark code. This dramatically speeds up development: compiling an empty benchmark takes roughly 18 seconds on my machine, whereas recompiling a kernel takes ~3.8 seconds. Without this addition, a commit like 35a2ad437 would require substantially more time to make sure that performance is not degraded.

![image](https://user-images.githubusercontent.com/4172822/225383120-5f8a82f9-0b46-4c39-bc1d-7b2a0551e881.png)

```
Parallel build time before: 270 seconds (6 cores, SMT, 12 jobs)
Parallel build time after: 147 seconds (6 cores, SMT, 12 jobs)
Sum of compile times before: 3022.6 seconds
Sum of compile times after: 1816.2 seconds

Comparison of compile times between headers and compiled:
path before (s) after (s) change (s) change (%)
pairwise_test None 0.486 None None
ance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu.o 101.1 10.3 -90.8 -89.8%
src/distance/distance/specializations/detail/canberra_float_float_float_int.cu.o 52.9 6.3 -46.6 -88.0%
/distance/distance/specializations/detail/canberra_double_double_double_int.cu.o 48.5 6.4 -42.1 -86.8%
stance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu.o 65.3 10.4 -55.0 -84.1%
istance/distance/specializations/detail/kl_divergence_float_float_float_int.cu.o 70.2 12.6 -57.6 -82.0%
stance/distance/specializations/detail/correlation_double_double_double_int.cu.o 46.7 8.9 -37.8 -80.9%
distance/specializations/detail/hellinger_expanded_double_double_double_int.cu.o 41.6 8.1 -33.5 -80.6%
nce/distance/specializations/detail/jensen_shannon_double_double_double_int.cu.o 74.6 15.1 -59.5 -79.7%
ir/src/distance/distance/specializations/detail/l1_double_double_double_int.cu.o 40.9 8.4 -32.5 -79.4%
ance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu.o 40.7 8.6 -32.1 -78.8%
distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu.o 40.8 9.0 -31.7 -77.8%
istance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu.o 45.9 10.2 -35.7 -77.8%
src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu.o 41.2 9.5 -31.8 -77.0%
istance/distance/specializations/detail/russel_rao_double_double_double_int.cu.o 29.5 7.2 -22.3 -75.6%
t.dir/src/distance/distance/specializations/detail/l1_float_float_float_int.cu.o 47.3 13.2 -34.1 -72.2%
ce/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu.o 47.0 13.3 -33.7 -71.6%
/distance/distance/specializations/detail/correlation_float_float_float_int.cu.o 49.4 14.1 -35.3 -71.5%
ce/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu.o 43.6 12.5 -31.1 -71.4%
c/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu.o 28.5 8.2 -20.3 -71.2%
ance/distance/specializations/detail/kl_divergence_double_double_double_int.cu.o 75.8 21.9 -53.9 -71.1%
istance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu.o 46.2 13.5 -32.7 -70.7% ir/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu.o 43.1 12.7 -30.4 -70.6% stance/distance/specializations/detail/l2_expanded_double_double_double_int.cu.o 52.3 24.9 -27.3 -52.3% /distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu.o 75.8 40.3 -35.5 -46.8% rc/distance/distance/specializations/detail/cosine_double_double_double_int.cu.o 53.5 28.7 -24.8 -46.4% r/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu.o 83.9 50.1 -33.8 -40.3% CMakeFiles/pairwise_test.dir/test/distance/fused_l2_nn.cu.o 85.1 64.1 -21.1 -24.7% wise_test.dir/src/distance/distance/specializations/fused_l2_nn_float_int64.cu.o 56.2 42.9 -13.3 -23.6% irwise_test.dir/src/distance/distance/specializations/fused_l2_nn_float_int.cu.o 52.5 40.2 -12.3 -23.5% CMakeFiles/pairwise_test.dir/test/distance/dist_lp_unexp.cu.o 56.3 43.3 -13.0 -23.1% CMakeFiles/pairwise_test.dir/test/distance/dist_russell_rao.cu.o 55.7 44.0 -11.7 -21.0% rwise_test.dir/src/distance/distance/specializations/fused_l2_nn_double_int.cu.o 45.3 36.4 -9.0 -19.8% CMakeFiles/pairwise_test.dir/test/distance/dist_l2_unexp.cu.o 54.6 44.1 -10.6 -19.3% CMakeFiles/pairwise_test.dir/test/distance/dist_canberra.cu.o 51.6 42.1 -9.6 -18.6% CMakeFiles/pairwise_test.dir/test/distance/dist_l2_exp.cu.o 53.1 43.4 -9.6 -18.2% CMakeFiles/pairwise_test.dir/test/distance/dist_l_inf.cu.o 53.2 43.9 -9.3 -17.5% CMakeFiles/pairwise_test.dir/test/distance/dist_hellinger.cu.o 53.1 44.0 -9.0 -17.0% CMakeFiles/pairwise_test.dir/test/distance/dist_hamming.cu.o 52.3 43.4 -8.9 -17.0% CMakeFiles/pairwise_test.dir/test/distance/dist_l2_sqrt_exp.cu.o 54.0 45.6 -8.4 -15.6% CMakeFiles/pairwise_test.dir/test/distance/dist_l1.cu.o 52.6 44.5 -8.1 -15.4% CMakeFiles/pairwise_test.dir/test/distance/dist_kl_divergence.cu.o 52.4 44.7 -7.7 -14.8% ise_test.dir/src/distance/distance/specializations/fused_l2_nn_double_int64.cu.o 43.5 37.2 -6.4 -14.7% CMakeFiles/pairwise_test.dir/test/distance/dist_cos.cu.o 52.4 44.8 -7.6 -14.5% CMakeFiles/pairwise_test.dir/test/distance/dist_jensen_shannon.cu.o 53.2 45.7 -7.6 -14.2% CMakeFiles/pairwise_test.dir/test/distance/dist_inner_product.cu.o 51.1 44.8 -6.3 -12.4% istance/distance/specializations/detail/inner_product_float_float_float_int.cu.o 39.5 35.1 -4.5 -11.3% CMakeFiles/pairwise_test.dir/test/distance/dist_correlation.cu.o 51.7 46.8 -4.9 -9.5% ance/distance/specializations/detail/inner_product_double_double_double_int.cu.o 37.1 33.9 -3.1 -8.5% src/distance/distance/specializations/detail/kernels/gram_matrix_base_float.cu.o 45.3 41.7 -3.6 -8.0% rc/distance/distance/specializations/detail/kernels/gram_matrix_base_double.cu.o 42.5 39.6 -2.9 -6.8% stance/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu.o 40.4 38.5 -1.9 -4.8% CMakeFiles/pairwise_test.dir/test/distance/dist_adj.cu.o 123.3 117.8 -5.4 -4.4% CMakeFiles/pairwise_test.dir/test/distance/gram.cu.o 55.3 53.4 -1.9 -3.5% build.ninja 4.0 4.0 +0.0 +0.1% istance/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu.o 45.2 45.6 +0.4 +0.8% .dir/src/distance/distance/specializations/detail/kernels/tanh_kernel_float.cu.o 45.2 46.0 +0.8 +1.7% dir/src/distance/distance/specializations/detail/kernels/tanh_kernel_double.cu.o 39.0 39.8 +0.8 +2.1% CMakeFiles/pairwise_test.dir/src/distance/distance/pairwise_distance.cu.o 39.6 50.1 +10.5 +26.6% ``` Authors: - Allard Hendriksen 
(https://github.com/ahendriksen) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/raft/pull/1307 --- cpp/CMakeLists.txt | 4 - cpp/bench/CMakeLists.txt | 5 + cpp/bench/distance/tune_pairwise/bench.cu | 151 +++++ cpp/bench/distance/tune_pairwise/kernel.cu | 88 +++ cpp/bench/distance/tune_pairwise/kernel.cuh | 44 ++ cpp/include/raft/core/kvp.hpp | 2 +- cpp/include/raft/distance/detail/distance.cuh | 107 ++-- .../distance/detail/distance_ops/canberra.cuh | 5 +- .../detail/distance_ops/correlation.cuh | 4 +- .../distance/detail/distance_ops/cosine.cuh | 11 +- .../distance/detail/distance_ops/cutlass.cuh | 6 +- .../distance/detail/distance_ops/hamming.cuh | 4 +- .../detail/distance_ops/hellinger.cuh | 4 +- .../detail/distance_ops/jensen_shannon.cuh | 5 +- .../detail/distance_ops/kl_divergence.cuh | 5 +- .../raft/distance/detail/distance_ops/l1.cuh | 4 +- .../distance/detail/distance_ops/l2_exp.cuh | 11 +- .../distance/detail/distance_ops/l2_unexp.cuh | 4 +- .../distance/detail/distance_ops/l_inf.cuh | 4 +- .../distance/detail/distance_ops/lp_unexp.cuh | 5 +- .../detail/distance_ops/russel_rao.cuh | 4 +- .../distance/detail/distance_ops/template.cuh | 10 +- .../raft/distance/detail/fused_l2_nn.cuh | 133 ++--- .../detail/pairwise_distance_base.cuh | 83 ++- .../detail/pairwise_distance_cutlass_base.cuh | 29 +- .../detail/pairwise_matrix/dispatch.cuh | 114 ++-- .../pairwise_matrix/dispatch_layout.cuh | 21 +- .../detail/pairwise_matrix/dispatch_sm60.cuh | 26 +- .../detail/pairwise_matrix/dispatch_sm80.cuh | 16 +- .../detail/pairwise_matrix/kernel_sm60.cuh | 69 +-- .../detail/00_write_template.py | 148 +++++ .../specializations/detail/canberra.cuh | 50 +- .../specializations/detail/correlation.cuh | 51 +- .../specializations/detail/cosine.cuh | 51 +- .../detail/hamming_unexpanded.cuh | 51 +- .../detail/hellinger_expanded.cuh | 50 +- .../specializations/detail/jensen_shannon.cuh | 50 +- .../specializations/detail/kl_divergence.cuh | 49 +- .../distance/specializations/detail/l1.cuh | 48 +- .../specializations/detail/l2_expanded.cuh | 49 +- .../detail/l2_sqrt_expanded.cuh | 54 -- .../detail/l2_sqrt_unexpanded.cuh | 54 -- .../specializations/detail/l2_unexpanded.cuh | 49 +- .../distance/specializations/detail/l_inf.cuh | 48 +- .../specializations/detail/lp_unexpanded.cuh | 49 +- .../specializations/detail/russel_rao.cuh | 50 +- .../distance/specializations/distance.cuh | 2 - .../raft/spatial/knn/detail/fused_l2_knn.cuh | 524 +++++++++--------- cpp/include/raft/util/arch.cuh | 23 +- cpp/include/raft/util/cuda_dev_essentials.cuh | 91 +++ cpp/include/raft/util/cuda_rt_essentials.hpp | 60 ++ cpp/include/raft/util/cuda_utils.cuh | 105 +--- cpp/include/raft/util/cudart_utils.hpp | 38 +- cpp/include/raft/util/device_loads_stores.cuh | 5 +- .../detail/00_write_template.py | 159 ++++++ .../canberra_double_double_double_int.cu | 36 +- .../detail/canberra_float_float_float_int.cu | 35 +- .../correlation_double_double_double_int.cu | 35 +- .../correlation_float_float_float_int.cu | 35 +- .../detail/cosine_double_double_double_int.cu | 35 +- .../detail/cosine_float_float_float_int.cu | 35 +- ...ing_unexpanded_double_double_double_int.cu | 35 +- ...amming_unexpanded_float_float_float_int.cu | 35 +- ...inger_expanded_double_double_double_int.cu | 35 +- ...ellinger_expanded_float_float_float_int.cu | 34 +- ...jensen_shannon_double_double_double_int.cu | 
36 +- .../jensen_shannon_float_float_float_int.cu | 36 +- .../kl_divergence_double_double_double_int.cu | 35 +- .../kl_divergence_float_float_float_int.cu | 35 +- .../detail/l1_double_double_double_int.cu | 35 +- .../detail/l1_float_float_float_int.cu | 35 +- .../l2_expanded_double_double_double_int.cu | 37 +- .../l2_expanded_float_float_float_int.cu | 36 +- ..._sqrt_expanded_double_double_double_int.cu | 38 -- .../l2_sqrt_expanded_float_float_float_int.cu | 38 -- ...qrt_unexpanded_double_double_double_int.cu | 38 -- ...2_sqrt_unexpanded_float_float_float_int.cu | 38 -- .../l2_unexpanded_double_double_double_int.cu | 35 +- .../l2_unexpanded_float_float_float_int.cu | 35 +- .../detail/l_inf_double_double_double_int.cu | 34 +- .../detail/l_inf_float_float_float_int.cu | 35 +- .../lp_unexpanded_double_double_double_int.cu | 35 +- .../lp_unexpanded_float_float_float_int.cu | 35 +- .../russel_rao_double_double_double_int.cu | 36 +- .../russel_rao_float_float_float_int.cu | 35 +- cpp/test/distance/distance_base.cuh | 74 ++- cpp/test/distance/fused_l2_nn.cu | 30 +- 87 files changed, 2057 insertions(+), 2000 deletions(-) create mode 100644 cpp/bench/distance/tune_pairwise/bench.cu create mode 100644 cpp/bench/distance/tune_pairwise/kernel.cu create mode 100644 cpp/bench/distance/tune_pairwise/kernel.cuh create mode 100644 cpp/include/raft/distance/specializations/detail/00_write_template.py delete mode 100644 cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh delete mode 100644 cpp/include/raft/distance/specializations/detail/l2_sqrt_unexpanded.cuh create mode 100644 cpp/include/raft/util/cuda_dev_essentials.cuh create mode 100644 cpp/include/raft/util/cuda_rt_essentials.hpp create mode 100644 cpp/src/distance/specializations/detail/00_write_template.py delete mode 100644 cpp/src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu delete mode 100644 cpp/src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu delete mode 100644 cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu delete mode 100644 cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index bdaacb4a85..034dc059b0 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -312,10 +312,6 @@ if(RAFT_COMPILE_LIBRARY) src/distance/specializations/detail/l1_double_double_double_int.cu src/distance/specializations/detail/l2_expanded_float_float_float_int.cu src/distance/specializations/detail/l2_expanded_double_double_double_int.cu - src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu - src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu - src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu src/distance/specializations/detail/l_inf_double_double_double_int.cu diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 8049074c09..d92ccba8e3 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -72,6 +72,11 @@ if(BUILD_BENCH) OPTIONAL LIB ) + ConfigureBench( + NAME TUNE_DISTANCE PATH bench/distance/tune_pairwise/kernel.cu + bench/distance/tune_pairwise/bench.cu bench/main.cpp + ) + ConfigureBench( NAME 
DISTANCE_BENCH diff --git a/cpp/bench/distance/tune_pairwise/bench.cu b/cpp/bench/distance/tune_pairwise/bench.cu new file mode 100644 index 0000000000..87159ab1b1 --- /dev/null +++ b/cpp/bench/distance/tune_pairwise/bench.cu @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Tuning benchmarks. +// +// Goals: +// +// 1. Fast compile times to maintain iteration speed. +// 2. Create benchmarks that can inform the design of the kernels. +// +// Non-goals: +// +// 1. Measure every distance operation. Instead measures just one distance +// operation at the same time. +// 2. Be useful for finding performance regressions. This is handled by the +// normal benchmarks. +// +// So far, both goals are partly achieved. +// +// RE (1), COMPILE TIMES: kernel.cu is fast to compile. This file is not. +// When the internals of a pairwise distance kernel is changed, this file is not +// recompiled. +// +// RE 2, benchmarks with intent: this file contains a benchmark to check the +// maximal throughput of a kernel. Measuring other things, like performance on +// skinny or wide matrices is not yet implemented. + +#include "kernel.cuh" // launch_kernel +#include // std::min +#include // RAFT_BENCH_REGISTER +#include // pairwise_matrix_params +#include // rmm::device_uvector +#include // std::vector + +namespace raft::bench::distance::tune { + +// Max throughput benchmark. +// +// Goal: Measure the maximum distances/sec that can be computed. +// +// To achieve this, we make sure that: +// +// - Input data size is a multiple of the block tile size. +// +// - Perfect distribution of work between SMs, i.e. the number of block tiles is +// a large multiple (num_waves) of the number of blocks (#SMs * occupancy). +// +// - Multiple iterations over Kblk are executed (num_k_iters). +struct throughput_param { + int num_waves; + int occupancy; + int num_k_iters; +}; + +const std::vector throughput_params{ + // 32 waves, requested occupancy of 4, and 32 k iterations typically achieves + // maximum throughput. No need to pick higher values. + {32, 4, 32}, +}; + +struct throughput_bench : public fixture { + const throughput_param p; + + throughput_bench(const throughput_param& p_) : p(p_) {} + + void run_benchmark(::benchmark::State& state) override + { + // Get block size: + int block_m, block_n, block_k; + get_block_size(block_m, block_n, block_k); + + // Determine number of blocks that will be launched. This informs the size + // of the inputs as well as the grid size. + const int num_sms = raft::getMultiProcessorCount(); + const int max_occupancy = get_max_occupancy(); + const int occupancy = std::min(p.occupancy, max_occupancy); + const int num_blocks = occupancy * num_sms; + dim3 grid(num_blocks); + + // Create input sizes that are a multiple of the block tile size. 
+ size_t m = block_m; + size_t n = block_n * p.num_waves * num_blocks; + size_t k = block_k * p.num_k_iters; + + // DataT, OutT, IdxT, etc, are defined in tuned_kernel.cuh + rmm::device_uvector x_vec(m * k, stream); + rmm::device_uvector y_vec(n * k, stream); + rmm::device_uvector x_norm_vec(m, stream); + rmm::device_uvector y_norm_vec(n, stream); + rmm::device_uvector out_vec(m * n, stream); + + auto x = x_vec.data(); + auto y = y_vec.data(); + auto x_norm = x_norm_vec.data(); + auto y_norm = y_norm_vec.data(); + auto out = out_vec.data(); + FinOpT fin_op{}; + + // Create kernel parameter struct. Flip x and y if column major. + IdxT ldx = row_major ? k : m; + IdxT ldy = row_major ? k : n; + IdxT ld_out = row_major ? n : m; + + // Template parameters of pairwise_matrix_params are defined in kernel.cuh + pairwise_matrix_params kparams{ + IdxT(m), IdxT(n), IdxT(k), ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, row_major}; + + // Run benchmark + loop_on_state(state, [&]() { launch_kernel(kparams, grid, stream); }); + + // Report metrics. We don't report flop/s because we do not know for each + // distance operation how many flops it costs. For L2_unexp and l1, we can + // double this number to get the flop/s. For l2 expanded, core_ops/s should + // equal flop/s (modulo the sqrt and subtracting from the norm). + size_t num_core_ops = m * n * k; + size_t read_elts = n * k + m * k; + size_t write_elts = m * n; + + state.counters["m"] = benchmark::Counter(m); + state.counters["n"] = benchmark::Counter(n); + state.counters["k"] = benchmark::Counter(k); + state.counters["occupancy"] = benchmark::Counter(occupancy); + state.counters["# waves"] = benchmark::Counter(p.num_waves); + state.counters["# k iters"] = benchmark::Counter(p.num_k_iters); + + state.counters["core_ops/s"] = benchmark::Counter(num_core_ops, + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + + state.counters["BW"] = benchmark::Counter(write_elts * sizeof(OutT) + read_elts * sizeof(DataT), + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + } +}; + +RAFT_BENCH_REGISTER(throughput_bench, "", throughput_params); + +} // namespace raft::bench::distance::tune diff --git a/cpp/bench/distance/tune_pairwise/kernel.cu b/cpp/bench/distance/tune_pairwise/kernel.cu new file mode 100644 index 0000000000..3112e1ea9a --- /dev/null +++ b/cpp/bench/distance/tune_pairwise/kernel.cu @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "kernel.cuh" +#include // pairwise_matrix_sm60_wrapper +#include // raft::linalg::Policy4x4 +#include // raft::util::arch::SM_compute_arch + +namespace raft::bench::distance::tune { + +// Distance op +using OpT = raft::distance::detail::ops::lp_unexp_distance_op; +constexpr float metric_arg = 2.0; +OpT distance_op{metric_arg}; + +// Kernel policy +constexpr int vec_len = 1; +using Policy = typename raft::linalg::Policy4x4::Policy; + +// Architecture +namespace arch = raft::util::arch; +constexpr auto sm_compat_range = arch::SM_range(arch::SM_min(), arch::SM_future()); + +void launch_kernel(pairwise_matrix_params params, dim3 grid, cudaStream_t stream) +{ + dim3 block(Policy::Nthreads); + int smem_size = OpT::shared_mem_size(); + + // Obtain function pointer to kernel + auto kernel = raft::distance::detail::pairwise_matrix_kernel; + + kernel<<>>(distance_op, params); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +void get_block_size(int& m, int& n, int& k) +{ + m = Policy::Mblk; + n = Policy::Nblk; + k = Policy::Kblk; +} + +void* get_kernel_ptr() +{ + auto kernel = raft::distance::detail::pairwise_matrix_kernel; + return reinterpret_cast(kernel); +} + +int get_max_occupancy() +{ + void* kernel_ptr = get_kernel_ptr(); + int max_occupancy; + int smem_size = OpT::shared_mem_size(); + + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, kernel_ptr, Policy::Nthreads, smem_size)); + + return max_occupancy; +} + +} // namespace raft::bench::distance::tune diff --git a/cpp/bench/distance/tune_pairwise/kernel.cuh b/cpp/bench/distance/tune_pairwise/kernel.cuh new file mode 100644 index 0000000000..5da54a343c --- /dev/null +++ b/cpp/bench/distance/tune_pairwise/kernel.cuh @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include // lp_unexp_distance_op +#include // pairwise_matrix_params + +namespace raft::bench::distance::tune { + +// Launch one specific kernel with the following template parameters +constexpr bool row_major = true; +using DataT = float; +using AccT = float; +using OutT = DataT; +using IdxT = int; + +using FinOpT = raft::identity_op; + +using pairwise_matrix_params = + raft::distance::detail::pairwise_matrix_params; + +// Launches kernel +void launch_kernel(pairwise_matrix_params, dim3, cudaStream_t); + +// Describes the block size that is decided by the policy +void get_block_size(int& m, int& n, int& k); + +int get_max_occupancy(); + +} // namespace raft::bench::distance::tune diff --git a/cpp/include/raft/core/kvp.hpp b/cpp/include/raft/core/kvp.hpp index 8d3321eb77..192d160d45 100644 --- a/cpp/include/raft/core/kvp.hpp +++ b/cpp/include/raft/core/kvp.hpp @@ -20,7 +20,7 @@ #ifdef _RAFT_HAS_CUDA #include -#include +#include // raft::shfl_xor #endif namespace raft { /** diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index f469250b45..7493c4e558 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -16,25 +16,18 @@ #pragma once -#include -#include - -#include -#include -#include -#include -#include - #include - +#include #include #include - +#include +#include #include #include -#include -#include -#include +#include +#include +#include +#include namespace raft { namespace distance { @@ -140,14 +133,14 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - AccT* norm_col_vec = workspace; - AccT* norm_row_vec = workspace; - AccT* sq_norm_col_vec = workspace; - AccT* sq_norm_row_vec = workspace; + AccT* x_norm = workspace; + AccT* y_norm = workspace; + AccT* sq_x_norm = workspace; + AccT* sq_y_norm = workspace; if (x != y) { - norm_row_vec += m; + y_norm += m; - raft::linalg::reduce(norm_col_vec, + raft::linalg::reduce(x_norm, x, k, m, @@ -158,7 +151,7 @@ void distance_impl(raft::resources const& handle, false, raft::identity_op(), raft::add_op()); - raft::linalg::reduce(norm_row_vec, + raft::linalg::reduce(y_norm, y, k, n, @@ -170,12 +163,12 @@ void distance_impl(raft::resources const& handle, raft::identity_op(), raft::add_op()); - sq_norm_col_vec += (m + n); - sq_norm_row_vec = sq_norm_col_vec + m; - raft::linalg::rowNorm(sq_norm_col_vec, x, k, m, raft::linalg::L2Norm, is_row_major, stream); - raft::linalg::rowNorm(sq_norm_row_vec, y, k, n, raft::linalg::L2Norm, is_row_major, stream); + sq_x_norm += (m + n); + sq_y_norm = sq_x_norm + m; + raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream); + raft::linalg::rowNorm(sq_y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream); } else { - raft::linalg::reduce(norm_col_vec, + raft::linalg::reduce(x_norm, x, k, m, @@ -186,15 +179,15 @@ void distance_impl(raft::resources const& handle, false, raft::identity_op(), raft::add_op()); - sq_norm_col_vec += m; - sq_norm_row_vec = sq_norm_col_vec; - raft::linalg::rowNorm(sq_norm_col_vec, x, k, m, raft::linalg::L2Norm, is_row_major, stream); + sq_x_norm += m; + sq_y_norm = sq_x_norm; + raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream); } using OpT = ops::correlation_distance_op; - OpT corr_op(is_row_major, sq_norm_col_vec, sq_norm_row_vec, m, n, k); + OpT corr_op(is_row_major, sq_x_norm, sq_y_norm, m, n, k); 
pairwise_matrix_dispatch( - corr_op, m, n, k, x, y, norm_col_vec, norm_row_vec, out, fin_op, stream, is_row_major); + corr_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template @@ -223,22 +216,22 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - DataT* norm_A = workspace; - DataT* norm_B = workspace; + DataT* x_norm = workspace; + DataT* y_norm = workspace; if (x != y) { - norm_B += m; + y_norm += m; raft::linalg::rowNorm( - norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); raft::linalg::rowNorm( - norm_B, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); } else { raft::linalg::rowNorm( - norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); } ops::cosine_distance_op distance_op{}; pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template @@ -389,10 +382,6 @@ void distance_impl(raft::resources const& handle, return (!x_zero) * raft::exp(input); }; - // This op takes some shortcuts when x equals y. So its behavior changes based - // on this. - ops::kl_divergence_op kl_divergence{is_row_major, x == y}; - if (x != y) { raft::linalg::unaryOp( (DataT*)y, y, n * k, unaryOp_lambda, stream); @@ -401,8 +390,12 @@ void distance_impl(raft::resources const& handle, const DataT* x_norm = nullptr; const DataT* y_norm = nullptr; - pairwise_matrix_dispatch( - kl_divergence, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); + // This op takes some shortcuts when x equals y. So its behavior changes based + // on this. 
+ ops::kl_divergence_op distance_op{is_row_major, x == y}; + + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); if (x != y) { // Now reverse previous log (x) back to x using (e ^ log(x)) @@ -464,22 +457,22 @@ void distance_impl_l2_expanded( // NOTE: different name "workspace size error"); ASSERT(workspace != nullptr, "workspace is null"); - DataT* norm_A = workspace; - DataT* norm_B = workspace; + DataT* x_norm = workspace; + DataT* y_norm = workspace; if (x != y) { - norm_B += m; + y_norm += m; raft::linalg::rowNorm( - norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); raft::linalg::rowNorm( - norm_B, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); } else { raft::linalg::rowNorm( - norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); } ops::l2_exp_distance_op distance_op{perform_sqrt}; pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template @@ -543,13 +536,13 @@ void distance_impl(raft::resources const& handle, ops::l2_unexp_distance_op l2_op(perform_sqrt); // The unexpanded L2 does not require the norms of a and b to be calculated. - const DataT* norm_A = nullptr; - const DataT* norm_B = nullptr; + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); pairwise_matrix_dispatch( - l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + l2_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template @@ -571,13 +564,13 @@ void distance_impl(raft::resources const& handle, ops::l2_unexp_distance_op l2_op(perform_sqrt); // The unexpanded L2 does not require the norms of a and b to be calculated. - const DataT* norm_A = nullptr; - const DataT* norm_B = nullptr; + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); pairwise_matrix_dispatch( - l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + l2_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template diff --git a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh index 930294ce31..eaf37b7e9c 100644 --- a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh @@ -16,7 +16,8 @@ #pragma once -#include +#include // raft::abs +#include // DI namespace raft::distance::detail::ops { @@ -42,7 +43,7 @@ struct canberra_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. 
template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh index 289b69070a..4fc4bb8297 100644 --- a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -61,7 +61,7 @@ struct correlation_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize + (2 * (Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } diff --git a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh index 7c37c27b4e..0883136c9f 100644 --- a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -26,7 +26,7 @@ struct cosine_cutlass_op { __device__ cosine_cutlass_op() noexcept {} __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept { - return static_cast(1.0) - (AccT)(accVal / (aNorm * bNorm)); + return static_cast(1.0) - static_cast(accVal / (aNorm * bNorm)); } __device__ AccT operator()(DataT aData) const noexcept { return aData; } }; @@ -53,7 +53,7 @@ struct cosine_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } @@ -76,7 +76,10 @@ struct cosine_distance_op { } } - cosine_cutlass_op get_cutlass_op() { return cosine_cutlass_op(); } + constexpr cosine_cutlass_op get_cutlass_op() const + { + return cosine_cutlass_op(); + } }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh b/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh index d3eb90467b..7a4fe0ce83 100644 --- a/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh @@ -16,7 +16,8 @@ #pragma once -#include +#include // std::false_type +#include // std::declval namespace raft::distance::detail::ops { @@ -34,7 +35,8 @@ struct has_cutlass_op : std::false_type { // Specialization recognizes types that do support CUTLASS template -struct has_cutlass_op> : std::true_type { +struct has_cutlass_op().get_cutlass_op())>> + : std::true_type { }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh index 1cfdcfdc73..475b8892e9 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -45,7 +45,7 @@ struct hamming_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. 
template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh index c4aecc7a6f..0489b45854 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -42,7 +42,7 @@ struct hellinger_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh index 41eeb9dd83..e46c63734c 100644 --- a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh @@ -15,7 +15,8 @@ */ #pragma once -#include +#include // raft::log +#include // DI namespace raft::distance::detail::ops { @@ -44,7 +45,7 @@ struct jensen_shannon_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh index d046b62c30..d083c5ddcc 100644 --- a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh @@ -15,7 +15,8 @@ */ #pragma once -#include +#include // raft::log +#include // DI namespace raft::distance::detail::ops { @@ -49,7 +50,7 @@ struct kl_divergence_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/l1.cuh b/cpp/include/raft/distance/detail/distance_ops/l1.cuh index 8ec4000827..7e86fd3603 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l1.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l1.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -41,7 +41,7 @@ struct l1_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. 
template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index 2a7af53813..95577fd311 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -54,7 +54,7 @@ struct l2_exp_distance_op { using AccT = AccType; using IdxT = IdxType; - bool sqrt; + const bool sqrt; l2_exp_distance_op(bool sqrt_) noexcept : sqrt(sqrt_) {} @@ -67,7 +67,7 @@ struct l2_exp_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } @@ -102,7 +102,10 @@ struct l2_exp_distance_op { } } - l2_exp_cutlass_op get_cutlass_op() { return l2_exp_cutlass_op(sqrt); } + constexpr l2_exp_cutlass_op get_cutlass_op() const + { + return l2_exp_cutlass_op(sqrt); + } }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh index f0ea591eaf..62c212ee8f 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -46,7 +46,7 @@ struct l2_unexp_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh index fb21fb1a21..88853a3083 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -42,7 +42,7 @@ struct l_inf_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh index 71dfd51a6e..290f4af1b4 100644 --- a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh @@ -15,7 +15,8 @@ */ #pragma once -#include +#include // raft::pow, raft::abs +#include // DI namespace raft::distance::detail::ops { @@ -45,7 +46,7 @@ struct lp_unexp_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. 
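// Illustrative sketch (not part of this patch): the detection idiom that
// has_cutlass_op in cutlass.cuh above appears to use, and that
// cosine_distance_op and l2_exp_distance_op now satisfy via their
// constexpr get_cutlass_op() const. An op "has a CUTLASS path" exactly when
// get_cutlass_op() is callable on it; std::declval plus the void_t pattern
// select the true_type specialization in that case. The toy ops below are
// stand-ins, not RAFT types.
#include <type_traits>  // std::false_type, std::true_type, std::void_t
#include <utility>      // std::declval

template <typename T, typename = void>
struct has_cutlass_op_sketch : std::false_type {};

template <typename T>
struct has_cutlass_op_sketch<T, std::void_t<decltype(std::declval<T>().get_cutlass_op())>>
  : std::true_type {};

struct plain_op {};  // no CUTLASS path
struct cutlass_capable_op {
  constexpr int get_cutlass_op() const { return 0; }  // stand-in for a cutlass functor
};

static_assert(!has_cutlass_op_sketch<plain_op>::value, "no get_cutlass_op -> false");
static_assert(has_cutlass_op_sketch<cutlass_capable_op>::value, "get_cutlass_op -> true");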
template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh index ea09e4d1db..63dbf350d1 100644 --- a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -47,7 +47,7 @@ struct russel_rao_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/template.cuh b/cpp/include/raft/distance/detail/distance_ops/template.cuh index 6998f3cad4..4320068361 100644 --- a/cpp/include/raft/distance/detail/distance_ops/template.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/template.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -42,8 +42,8 @@ struct template_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template - constexpr size_t shared_mem_size() + template + static constexpr size_t shared_mem_size() { return Policy::SmemSize + TODO; } @@ -59,6 +59,10 @@ struct template_distance_op { { TODO; } + + // If exist, returns a cutlass op that performs the same operation. + // See cosine and l2_exp distance ops for an example. + constexpr l2_exp_cutlass_op get_cutlass_op() const { TODO; } }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 8fbd7a9c69..be6fed9f10 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -16,23 +16,20 @@ #pragma once -#include -#include -#include -#include -#include -#include +#include // size_t +#include // std::numeric_limits +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include // PairwiseDistances +#include // Policy +#include // raft::ceildiv, raft::shfl namespace raft { namespace distance { namespace detail { -#if (ENABLE_MEMCPY_ASYNC == 1) -#include -using namespace nvcuda::experimental; -#endif - template struct KVPMinReduceImpl { typedef raft::KeyValuePair KVP; @@ -124,11 +121,10 @@ DI void updateReducedVal( template __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, const DataT* x, @@ -142,7 +138,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, int* mutex, ReduceOpT redOp, KVPReduceOpT pairRedOp, - CoreLambda core_op, + OpT distance_op, FinalLambda fin_op) { extern __shared__ char smem[]; @@ -163,24 +159,6 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, IdxT gridStrideY) { KVPReduceOpT pairRed_op(pairRedOp); -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } - } - if (Sqrt) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto acc_ij = acc[i][j]; - 
acc[i][j] = acc_ij > DataT{0} ? raft::sqrt(acc_ij) : DataT{0}; - } - } - } - // intra thread reduce const auto acccolid = threadIdx.x % P::AccThCols; const auto accrowid = threadIdx.x / P::AccThCols; @@ -229,18 +207,18 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, }; IdxT lda = k, ldb = k, ldd = n; - PairwiseDistances + row_major, + write_out> obj(x, y, m, @@ -251,9 +229,9 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, ldd, xn, yn, - nullptr, + nullptr, // Output pointer smem, - core_op, + distance_op, epilog_lambda, fin_op, rowEpilog_lambda); @@ -289,9 +267,6 @@ void fusedL2NNImpl(OutT* min, constexpr auto maxVal = std::numeric_limits::max(); typedef KeyValuePair KVPair; - // Accumulation operation lambda - auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); if (initOutBuffer) { initKernel @@ -300,59 +275,25 @@ void fusedL2NNImpl(OutT* min, } constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - if (sqrt) { - auto fusedL2NNSqrt = fusedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NNSqrt); - - fusedL2NNSqrt<<>>(min, - x, - y, - xn, - yn, - m, - n, - k, - maxVal, - workspace, - redOp, - pairRedOp, - core_lambda, - raft::identity_op{}); - } else { - auto fusedL2NN = fusedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NN); - fusedL2NN<<>>(min, - x, - y, - xn, - yn, - m, - n, - k, - maxVal, - workspace, - redOp, - pairRedOp, - core_lambda, - raft::identity_op{}); - } + using AccT = DataT; + ops::l2_exp_distance_op distance_op{sqrt}; + + raft::identity_op fin_op{}; + + auto kernel = fusedL2NNkernel; + + dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); + + kernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); RAFT_CUDA_TRY(cudaGetLastError()); } diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 0293f10c29..c6b09be31e 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -14,14 +14,11 @@ * limitations under the License. */ #pragma once -#include -#include -#include -#include -#include -#include +#include // raft::linalg::Contractions_NT +#include // ceildiv +#include // RAFT_CUDA_TRY -#include +#include // size_t namespace raft { namespace distance { @@ -29,16 +26,12 @@ namespace detail { /** * @brief Device class for L1, L2 and cosine distance metrics. - * @tparam useNorms whether norms are needed * @tparam DataT input data-type (for A and B matrices) * @tparam AccT accumulation data-type * @tparam OutT output data-type (for C and D matrices) * @tparam IdxT index data-type * @tparam Policy struct which tunes the Contraction kernel - * @tparam CoreLambda tells how to accumulate an x and y into - acc. its signature: - template void core_lambda(AccT& acc, - const DataT& x, const DataT& y) + * @tparam OpT A distance operation, e.g., cosine_distance_op. * @tparam EpilogueLambda applies an elementwise function to compute final values. Its signature is: template void epilogue_lambda @@ -56,19 +49,17 @@ namespace detail { * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine * @param[output] pD output matrix * @param[in] smem shared mem buffer for intermediate storage of A, B, xn & yn. - * @param core_op the core accumulation operation lambda + * @param distance_op the distance operation, e.g. cosine_distance_op * @param epilog_op the epilog operation lambda * @param fin_op the final gemm epilogue lambda * @param rowEpilog_op epilog lambda that executes when a full row has been processed */ -template > struct PairwiseDistances : public BaseClass { + // Get accumulation type from distance_op + using AccT = typename OpT::AccT; + private: typedef Policy P; const DataT* xn; @@ -83,7 +77,7 @@ struct PairwiseDistances : public BaseClass { const DataT* const yBase; OutT* dOutput; char* smem; - CoreLambda core_op; + OpT distance_op; EpilogueLambda epilog_op; FinalLambda fin_op; rowEpilogueLambda rowEpilog_op; @@ -109,7 +103,7 @@ struct PairwiseDistances : public BaseClass { const DataT* _yn, OutT* _dOutput, char* _smem, - CoreLambda _core_op, + OpT _distance_op, EpilogueLambda _epilog_op, FinalLambda _fin_op, rowEpilogueLambda _rowEpilog_op) @@ -119,7 +113,7 @@ struct PairwiseDistances : public BaseClass { yBase(_y), dOutput(_dOutput), smem(_smem), - core_op(_core_op), + distance_op(_distance_op), epilog_op(_epilog_op), fin_op(_fin_op), rowEpilog_op(_rowEpilog_op), @@ -159,15 +153,25 @@ struct PairwiseDistances : public BaseClass { this->switch_read_buffer(); // Epilog: - if (useNorms) { + if (distance_op.use_norms) { DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; load_norms(tile_idx_m, tile_idx_n, regxn, regyn); // Overlap ldg with epilog computation ldgNextGridStride(tile_idx_m, tile_idx_n); + // Calculate distance_op epilog. 
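// Illustrative sketch (not part of this patch): the op interface that
// PairwiseDistances now consumes in place of the old useNorms template flag
// and the core accumulation lambda, and that fusedL2NNImpl above relies on
// when it builds a single ops::l2_exp_distance_op{sqrt} instead of two
// Sqrt-templated kernel instantiations. The shape mirrors template.cuh
// earlier in this diff; the l2-style bodies are examples only, and device
// qualifiers are omitted so the sketch stays host-compilable.
#include <cstddef>

template <typename DataT, typename AccT, typename IdxT>
struct sketch_distance_op {
  // run() uses this to decide whether to load row norms before the epilog.
  static constexpr bool use_norms = true;
  // accumulate() uses this to decide whether to unroll its inner loop.
  static constexpr bool expensive_inner_loop = false;

  // Extra shared memory on top of Policy::SmemSize (here: the two norm tiles).
  template <typename Policy>
  static constexpr size_t shared_mem_size()
  {
    return Policy::SmemSize + (Policy::Mblk + Policy::Nblk) * sizeof(DataT);
  }

  // Per-element accumulation, called from accumulate_reg_tile().
  void core(AccT& acc, DataT x, DataT y) const { acc += x * y; }

  // Tile-wide epilog, called from run() with the row norms (or nullptr when
  // use_norms is false). Shown for a 1x1 accumulator tile to keep it short.
  template <typename Policy>
  void epilog(AccT acc[1][1], DataT* regxn, DataT* regyn, IdxT, IdxT) const
  {
    acc[0][0] = regxn[0] + regyn[0] - AccT(2) * acc[0][0];
  }
};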
+ // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + distance_op.template epilog(acc, regxn, regyn, tile_idx_n, tile_idx_m); + // And any possible additional epilogs epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); } else { // Overlap ldg with epilog computation ldgNextGridStride(tile_idx_m, tile_idx_n); + // Calculate distance_op epilog. + // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + distance_op.template epilog(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + // And any possible additional epilogs epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); } if (writeOut) { store_output(tile_idx_m, tile_idx_n); } @@ -201,24 +205,41 @@ struct PairwiseDistances : public BaseClass { } } - DI void accumulate() + DI void accumulate_reg_tile(DataT (®_x)[P::AccRowsPerTh][P::Veclen], + DataT (®_y)[P::AccColsPerTh][P::Veclen]) { #pragma unroll - for (int ki = 0; ki < P::Kblk; ki += P::Veclen) { - this->ldsXY(ki); + for (int v = 0; v < P::Veclen; ++v) { #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { -#pragma unroll - for (int v = 0; v < P::Veclen; ++v) { - core_op(acc[i][j], this->regx[i][v], this->regy[j][v]); - } + distance_op.core(acc[i][j], reg_x[i][v], reg_y[j][v]); } } } } + DI void accumulate() + { + // We have a separate ldsXY and accumulate_reg_tile outside the loop body, + // so that these separated calls can be interspersed with preceding and + // following instructions, thereby hiding latency. + this->ldsXY(0); + + // If expensive inner loop, do not unroll loop. + constexpr int num_iterations = P::Kblk / P::Veclen - 1; + constexpr int unroll_count = decltype(distance_op)::expensive_inner_loop ? 1 : num_iterations; +#pragma unroll unroll_count + for (int ki = P::Veclen; ki < P::Kblk; ki += P::Veclen) { + accumulate_reg_tile(this->regx, this->regy); + this->ldsXY(ki); + } + + // Accumulate last loaded tile. 
+ accumulate_reg_tile(this->regx, this->regy); + } + DI void load_norms(IdxT tile_idx_m, IdxT tile_idx_n, DataT (®xn)[P::AccRowsPerTh], @@ -274,7 +295,11 @@ struct PairwiseDistances : public BaseClass { template dim3 launchConfigGenerator(IdxT m, IdxT n, std::size_t sMemSize, T func) { - const auto numSMs = raft::getMultiProcessorCount(); + int devId; + RAFT_CUDA_TRY(cudaGetDevice(&devId)); + int numSMs; + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, devId)); + int numBlocksPerSm = 0; dim3 grid; diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh index c5fdd28117..efcd5d9389 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -64,21 +64,20 @@ template -typename std::enable_if::value>::type cutlassDistanceKernel( - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - OpT distance_op, - cudaStream_t stream) +std::enable_if_t::value> cutlassDistanceKernel(const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + FinalLambda fin_op, + OpT distance_op, + cudaStream_t stream) { static_assert(!(std::is_same::value), "OutType bool is not supported use uint8_t instead"); diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 8524ce6fdf..e04b56ee8a 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -15,63 +15,74 @@ */ #pragma once -#include -#include -#include -#include -#include -#include -#include -#include +/* This file has two responsibilities: + * + * 1. Dispatch to the correct implementation of a kernel based on the + * architecture of the device on which the kernel will be launched. For + * instance, the cosine distance has a CUTLASS-based implementation that can + * be used on SM80+ and the normal implementation that is used on older + * architectures. + * + * 2. Provide concise function templates that can be instantiated in + * src/distance/distance/specializations/detail/. Previously, + * raft::distance::detail::distance was instantiated. The function + * necessarily required a large set of include files, which slowed down the + * build. The raft::distance::detail::pairwise_matrix_arch_dispatch functions + * do not require as large an include files set, which speeds up the build. + */ + +#include // ops::has_cutlass_op +#include // dispatch_sm60 +#include // pairwise_matrix_params +#include // raft::util::arch::SM_* + +// NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. +// Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). +// Therefore, it is the including file's responsibility to include the correct +// dispatch_smXX.cuh headers, as is done in raft/distance/detail/distance.cuh +// and the specializations in src/distance/distance/specializations/detail/. namespace raft::distance::detail { +// This forward-declaration ensures that we do not need to include +// dispatch_sm80.cuh if we are not calling it in practice. This makes compiling +// all the non-CUTLASS based distance specializations faster. 
For CUTLASS-based +// distances, dispatch_sm80.cuh has to be included by the file including this +// file. template -void pairwise_matrix_dispatch(OpT distance_op, - IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - // Create kernel parameter struct. Flip x and y if column major. - IdxT ldx = is_row_major ? k : m; - IdxT ldy = is_row_major ? k : n; - IdxT ld_out = is_row_major ? n : m; - - pairwise_matrix_params params{ - m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; - - if (!params.is_row_major) { params.flip_x_and_y(); } + typename SM_compat_t> +void pairwise_matrix_sm80_dispatch(OpT, + pairwise_matrix_params, + SM_compat_t, + cudaStream_t); +template +void pairwise_matrix_instantiation_point(OpT distance_op, + pairwise_matrix_params params, + cudaStream_t stream) +{ // On CUDA 12: // - always execute normal kernel // // On CUDA 11 and below: // - execute CUTLASS-based kernel on SM_80 and above // - execute normal kernel below SM_80 + namespace arch = raft::util::arch; constexpr bool is_ctk_12 = __CUDACC_VER_MAJOR__ == 12; constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op(); if constexpr (is_ctk_12 || cutlass_op_unavailable) { // Always execute legacy kernels on CUDA 12 - auto any_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future()); + auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); } else { - auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future()); - auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + auto legacy_range = arch::SM_range(arch::SM_min(), arch::SM_80()); // Get pointer to SM60 kernel to determine the runtime architecture of the // current system. Other methods to determine the architecture (that do not @@ -79,7 +90,7 @@ void pairwise_matrix_dispatch(OpT distance_op, // https://github.com/NVIDIA/cub/issues/545 auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range); void* kernel_ptr = reinterpret_cast(sm60_wrapper.kernel_ptr); - auto runtime_arch = raft::arch::kernel_runtime_arch(kernel_ptr); + auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr); if (cutlass_range.contains(runtime_arch)) { // If device is SM_80 or later, use CUTLASS-based kernel. @@ -92,4 +103,35 @@ void pairwise_matrix_dispatch(OpT distance_op, } } +template +void pairwise_matrix_dispatch(OpT distance_op, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) +{ + // Create kernel parameter struct. Flip x and y if column major. + IdxT ldx = is_row_major ? k : m; + IdxT ldy = is_row_major ? k : n; + IdxT ld_out = is_row_major ? 
n : m; + + pairwise_matrix_params params{ + m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; + + if (!params.is_row_major) { params.flip_x_and_y(); } + pairwise_matrix_instantiation_point(distance_op, params, stream); +} + }; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh index c1e4c08af4..f2b0e59822 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh @@ -15,10 +15,11 @@ */ #pragma once -#include "kernel_sm60.cuh" -#include -#include - +#include // std::min +#include // size_t +#include // RAFT_EXPECTS +#include // pairwise_matrix_params +#include // std::integral_constant namespace raft::distance::detail { /** @@ -99,15 +100,15 @@ auto dispatch_layout(bool row_major, int vec_len, F&& f) { if (row_major) { switch (vec_len) { - case 4: return f(std::bool_constant(), vec_len_constant<4>()); - case 2: return f(std::bool_constant(), vec_len_constant<2>()); - default: return f(std::bool_constant(), vec_len_constant<1>()); + case 4: return f(std::true_type(), vec_len_constant<4>()); + case 2: return f(std::true_type(), vec_len_constant<2>()); + default: return f(std::true_type(), vec_len_constant<1>()); } } else { switch (vec_len) { - case 4: return f(std::bool_constant(), vec_len_constant<4>()); - case 2: return f(std::bool_constant(), vec_len_constant<2>()); - default: return f(std::bool_constant(), vec_len_constant<1>()); + case 4: return f(std::false_type(), vec_len_constant<4>()); + case 2: return f(std::false_type(), vec_len_constant<2>()); + default: return f(std::false_type(), vec_len_constant<1>()); } } } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh index 6e284007ea..2080fbe9cd 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh @@ -15,10 +15,10 @@ */ #pragma once -#include -#include -#include -#include +#include // std::min +#include // dispatch_layout +#include // pairwise_matrix_sm60_wrapper +#include // raft::linalg::Policy4x4 namespace raft::distance::detail { @@ -35,7 +35,11 @@ pairwise_matrix_sm60_wrapper pairwise_matrix_sm6 { int vec_len = determine_vec_len(params); - return dispatch_layout(params.is_row_major, vec_len, [&](auto row_major, auto vec_len_aligned) { + // f takes compile-time constants row_major and vec_len aligned and returns + // the corresponding kernel wrapper. The wrapper contains the launch + // parameters of the kernel: a pointer to the kernel function, grid size, + // block size, and shared memory size. + auto f = [&](auto row_major, auto vec_len_aligned) { // row_major and vec_len are std::integral_constants of type bool and int // respectively. 
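// Illustrative sketch (not part of this patch): a self-contained, runnable
// reduction of the dispatch_layout pattern that the lambda f above plugs
// into, plus a caller. Runtime values (is the matrix row major? which vector
// length is aligned?) are mapped onto std::integral_constant arguments of a
// generic lambda, so the lambda body can use them as compile-time constants,
// e.g. to select a kernel policy. Names are stand-ins for the RAFT helpers.
#include <cstdio>
#include <type_traits>

template <int n>
using vec_len_constant = std::integral_constant<int, n>;

template <typename F>
auto dispatch_layout_sketch(bool row_major, int vec_len, F&& f)
{
  if (row_major) {
    switch (vec_len) {
      case 4: return f(std::true_type{}, vec_len_constant<4>{});
      case 2: return f(std::true_type{}, vec_len_constant<2>{});
      default: return f(std::true_type{}, vec_len_constant<1>{});
    }
  } else {
    switch (vec_len) {
      case 4: return f(std::false_type{}, vec_len_constant<4>{});
      case 2: return f(std::false_type{}, vec_len_constant<2>{});
      default: return f(std::false_type{}, vec_len_constant<1>{});
    }
  }
}

int main()
{
  auto f = [](auto row_major, auto vec_len_aligned) {
    // Both parameters are integral_constants, so their values are usable in
    // constexpr contexts (template arguments, if constexpr, ...).
    constexpr bool is_row_major = decltype(row_major)::value;
    constexpr int vec_len       = decltype(vec_len_aligned)::value;
    std::printf("row_major=%d vec_len=%d\n", int(is_row_major), vec_len);
  };
  dispatch_layout_sketch(true, 4, f);   // dispatched as: row major, vec_len 4
  dispatch_layout_sketch(false, 2, f);  // dispatched as: col major, vec_len 2
}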
@@ -46,15 +50,19 @@ pairwise_matrix_sm60_wrapper pairwise_matrix_sm6 // Prevent double, vec_len=4 combination (this is not supported) constexpr int vec_len = std::min(vec_len_op, static_cast(16 / sizeof(DataT))); - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - typedef typename std::conditional::type Policy; + using RowPolicy = typename raft::linalg::Policy4x4::Policy; + using ColPolicy = typename raft::linalg::Policy4x4::ColPolicy; + using Policy = typename std::conditional::type; auto wrapper = make_pairwise_matrix_sm60_wrapper(distance_op, params, sm_compat_range); return wrapper; - }); + }; + + // Dispatch_layout calls f with appropriate compile time constants based on + // the runtime values of params.is_row_major and vec_len. + return dispatch_layout(params.is_row_major, vec_len, f); } template // std::min -#include -#include +#include // std::min +#include // cutlassDistanceKernel +#include // dispatch_layout namespace raft::distance::detail { @@ -34,7 +34,9 @@ void pairwise_matrix_sm80_dispatch(OpT distance_op, { int vec_len = determine_vec_len(params); - dispatch_layout(params.is_row_major, vec_len, [&](auto row_major, auto vec_len_aligned) { + // f takes compile-time constants row_major and vec_len aligned and runs the + // corresponding cutlass launch code. + auto f = [&](auto row_major, auto vec_len_aligned) { // row_major and vec_len are std::integral_constants of type bool and int // respectively. @@ -56,7 +58,11 @@ void pairwise_matrix_sm80_dispatch(OpT distance_op, params.fin_op, distance_op, stream); - }); + }; + + // Dispatch_layout calls f with appropriate compile time constants based on + // the runtime values of params.is_row_major and vec_len. + dispatch_layout(params.is_row_major, vec_len, f); } }; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index 6e3ab7b26b..2d0a98862e 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -15,11 +15,11 @@ */ #pragma once -#include -#include -#include -#include -#include +#include // assert +#include // raft::void_op +#include // PairwiseDistances +#include // pairwise_matrix_params +#include // raft::util::arch::SM_compute_arch namespace raft::distance::detail { @@ -36,43 +36,27 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel( { // Early exit to minimize the size of the kernel when it is not supposed to be compiled. constexpr SM_compat_t sm_compat_range{}; - if constexpr (!sm_compat_range.contains(raft::arch::SM_compute_arch())) { + if constexpr (!sm_compat_range.contains(raft::util::arch::SM_compute_arch())) { assert(false); return; } extern __shared__ char smem[]; - using AccT = typename OpT::AccT; - - // Wrap operator back into lambdas. This is temporary and should be removed. 
- // See: https://github.com/rapidsai/raft/issues/1323 - auto core_op = [distance_op] __device__(AccT & acc, DataT & x, DataT & y) { - distance_op.core(acc, x, y); - }; - auto epilog_op = [distance_op] __device__(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - // Use .template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - distance_op.template epilog(acc, regxn, regyn, gridStrideX, gridStrideY); - }; - + // The epilog is already provided by distance_op. Do not provide additional + // epilogs. + auto epilog_op = raft::void_op(); // No support for row_epilog_op. auto row_epilog_op = raft::void_op(); // Always write output constexpr bool write_out = true; constexpr bool use_norms = distance_op.use_norms; - PairwiseDistances -void pairwise_matrix(OpT distance_op, - pairwise_matrix_params params, - cudaStream_t stream) -{ - dim3 blk(Policy::Nthreads); - // Use .template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - size_t smem_size = distance_op.template shared_mem_size(); - // Obtain function pointer to kernel - auto kernel = - pairwise_matrix_kernel; - dim3 grid = launchConfigGenerator(params.m, params.n, smem_size, kernel); - - kernel<<>>(distance_op, params); - RAFT_CUDA_TRY(cudaGetLastError()); -} - // The type of a pointer to the pairwise matrix kernel. The following template // arguments are type-erased: // @@ -181,9 +140,9 @@ pairwise_matrix_sm60_wrapper make_pairwise_matri SM_compat_t sm_compat_range) { dim3 block(Policy::Nthreads); - // Use .template to disambiguate (See: + // Use ::template to disambiguate (See: // https://en.cppreference.com/w/cpp/language/dependent_name) - int smem_size = distance_op.template shared_mem_size(); + int smem_size = OpT::template shared_mem_size(); // Obtain function pointer to kernel auto kernel = pairwise_matrix_kernel; diff --git a/cpp/include/raft/distance/specializations/detail/00_write_template.py b/cpp/include/raft/distance/specializations/detail/00_write_template.py new file mode 100644 index 0000000000..63ae6580b4 --- /dev/null +++ b/cpp/include/raft/distance/specializations/detail/00_write_template.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 + +# This template manages all files in this directory, apart from +# inner_product.cuh and kernels.cuh. + + +# NOTE: this template is not perfectly formatted. Use pre-commit to get +# everything in shape again. +start_template = """/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +namespace raft::distance::detail { + +""" + +extern_template = """ +extern template void pairwise_matrix_instantiation_point( + OpT, + pairwise_matrix_params, + cudaStream_t); +""" + +end_template = """} // namespace raft::distance::detail +""" + +data_type_instances = [ + dict( + DataT="float", + AccT="float", + OutT="float", + IdxT="int", + ), + dict( + DataT="double", + AccT="double", + OutT="double", + IdxT="int", + ), +] + + + + +op_instances = [ + dict( + path_prefix="canberra", + OpT="ops::canberra_distance_op", + ), + dict( + path_prefix="correlation", + OpT="ops::correlation_distance_op", + ), + dict( + path_prefix="cosine", + OpT="ops::cosine_distance_op", + # cosine uses CUTLASS for SM80+ + ), + dict( + path_prefix="hamming_unexpanded", + OpT="ops::hamming_distance_op", + ), + dict( + path_prefix="hellinger_expanded", + OpT="ops::hellinger_distance_op", + ), + # inner product is handled by cublas. + dict( + path_prefix="jensen_shannon", + OpT="ops::jensen_shannon_distance_op", + ), + dict( + path_prefix="kl_divergence", + OpT="ops::kl_divergence_op", + ), + dict( + path_prefix="l1", + OpT="ops::l1_distance_op", + ), + dict( + path_prefix="l2_expanded", + OpT="ops::l2_exp_distance_op", + # L2 expanded uses CUTLASS for SM80+ + ), + dict( + path_prefix="l2_unexpanded", + OpT="ops::l2_unexp_distance_op", + ), + dict( + path_prefix="l_inf", + OpT="ops::l_inf_distance_op", + ), + dict( + path_prefix="lp_unexpanded", + OpT="ops::lp_unexp_distance_op", + ), + dict( + path_prefix="russel_rao", + OpT="ops::russel_rao_distance_op", + ), +] + +def fill_in(s, template): + for k, v in template.items(): + s = s.replace(k, v) + return s + +for op_instance in op_instances: + path = fill_in("path_prefix.cuh", op_instance) + with open(path, "w") as f: + f.write(start_template) + + for data_type_instance in data_type_instances: + op_data_instance = { + k : fill_in(v, data_type_instance) + for k, v in op_instance.items() + } + instance = { + **op_data_instance, + **data_type_instance, + "FinopT": "raft::identity_op", + } + + text = fill_in(extern_template, instance) + + f.write(text) + + f.write(end_template) diff --git a/cpp/include/raft/distance/specializations/detail/canberra.cuh b/cpp/include/raft/distance/specializations/detail/canberra.cuh index badce715a5..276c85e5f6 100644 --- a/cpp/include/raft/distance/specializations/detail/canberra.cuh +++ b/cpp/include/raft/distance/specializations/detail/canberra.cuh @@ -16,37 +16,25 @@ #pragma once -#include #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::canberra_distance_op, + int, + float, + float, + raft::identity_op>(ops::canberra_distance_op, + pairwise_matrix_params, + cudaStream_t); + +extern template void pairwise_matrix_instantiation_point< + ops::canberra_distance_op, + int, + double, + double, + raft::identity_op>(ops::canberra_distance_op, + pairwise_matrix_params, + 
cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/correlation.cuh b/cpp/include/raft/distance/specializations/detail/correlation.cuh index 013a0d43a3..f019f678df 100644 --- a/cpp/include/raft/distance/specializations/detail/correlation.cuh +++ b/cpp/include/raft/distance/specializations/detail/correlation.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::correlation_distance_op, + int, + float, + float, + raft::identity_op>(ops::correlation_distance_op, + pairwise_matrix_params, + cudaStream_t); + +extern template void pairwise_matrix_instantiation_point< + ops::correlation_distance_op, + int, + double, + double, + raft::identity_op>(ops::correlation_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/cosine.cuh b/cpp/include/raft/distance/specializations/detail/cosine.cuh index c88bd1b0f6..dcde4ec286 100644 --- a/cpp/include/raft/distance/specializations/detail/cosine.cuh +++ b/cpp/include/raft/distance/specializations/detail/cosine.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + raft::identity_op>( + ops::cosine_distance_op, + pairwise_matrix_params, + cudaStream_t); + +extern template void pairwise_matrix_instantiation_point< + ops::cosine_distance_op, + int, + double, + double, + raft::identity_op>(ops::cosine_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh index 3c5cad3315..1d6964fbce 100644 --- a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern 
template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::hamming_distance_op, + int, + float, + float, + raft::identity_op>(ops::hamming_distance_op, + pairwise_matrix_params, + cudaStream_t); + +extern template void pairwise_matrix_instantiation_point< + ops::hamming_distance_op, + int, + double, + double, + raft::identity_op>(ops::hamming_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh b/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh index bf214c046f..f96a06f919 100644 --- a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh @@ -18,37 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point< + ops::hellinger_distance_op, + int, + float, + float, + raft::identity_op>(ops::hellinger_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::hellinger_distance_op, + int, + double, + double, + raft::identity_op>(ops::hellinger_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh index 145834fb70..0b58646582 100644 --- a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh +++ b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh @@ -18,37 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point< + ops::jensen_shannon_distance_op, + int, + float, + float, + raft::identity_op>(ops::jensen_shannon_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::jensen_shannon_distance_op, + int, + double, + double, + raft::identity_op>(ops::jensen_shannon_distance_op, + 
pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh b/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh index f0928916cd..5c164e0fd4 100644 --- a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh +++ b/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + raft::identity_op>( + ops::kl_divergence_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point, + int, + double, + double, + raft::identity_op>( + ops::kl_divergence_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l1.cuh b/cpp/include/raft/distance/specializations/detail/l1.cuh index 23261a2571..870627d909 100644 --- a/cpp/include/raft/distance/specializations/detail/l1.cuh +++ b/cpp/include/raft/distance/specializations/detail/l1.cuh @@ -18,35 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + raft::identity_op>( + ops::l1_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point, + int, + double, + double, + raft::identity_op>( + ops::l1_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh index f953018b7d..ee3207bcce 100644 --- a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - 
int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + raft::identity_op>( + ops::l2_exp_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::l2_exp_distance_op, + int, + double, + double, + raft::identity_op>(ops::l2_exp_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh deleted file mode 100644 index 9f5f6a3706..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); - -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/specializations/detail/l2_sqrt_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_sqrt_unexpanded.cuh deleted file mode 100644 index 94531ddc33..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l2_sqrt_unexpanded.cuh +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include - -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); - -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh index 224b21fce8..1fbf57632b 100644 --- a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point< + ops::l2_unexp_distance_op, + int, + float, + float, + raft::identity_op>(ops::l2_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::l2_unexp_distance_op, + int, + double, + double, + raft::identity_op>(ops::l2_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l_inf.cuh b/cpp/include/raft/distance/specializations/detail/l_inf.cuh index 9a46d7b488..388d3bf439 100644 --- a/cpp/include/raft/distance/specializations/detail/l_inf.cuh +++ b/cpp/include/raft/distance/specializations/detail/l_inf.cuh @@ -18,35 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + raft::identity_op>( + ops::l_inf_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::l_inf_distance_op, + int, + double, + double, + raft::identity_op>(ops::l_inf_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh index 
e05ef02c42..d8e86ce6f2 100644 --- a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point< + ops::lp_unexp_distance_op, + int, + float, + float, + raft::identity_op>(ops::lp_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::lp_unexp_distance_op, + int, + double, + double, + raft::identity_op>(ops::lp_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh index afc87997c0..4803fb8ab0 100644 --- a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh +++ b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh @@ -18,37 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point< + ops::russel_rao_distance_op, + int, + float, + float, + raft::identity_op>(ops::russel_rao_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::russel_rao_distance_op, + int, + double, + double, + raft::identity_op>(ops::russel_rao_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/distance.cuh b/cpp/include/raft/distance/specializations/distance.cuh index 8daa398b49..a34f696e9e 100644 --- a/cpp/include/raft/distance/specializations/distance.cuh +++ b/cpp/include/raft/distance/specializations/distance.cuh @@ -27,8 +27,6 @@ #include #include #include -#include -#include #include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 4e18a210d4..4a571c1447 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -22,6 +22,8 @@ #include "processing.cuh" #include #include +#include +#include #include #include @@ -183,13 +185,11 @@ DI void updateSortedWarpQ( } } -template Pair; @@ -222,295 +222,279 @@ __global__ 
__launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x using namespace raft::neighbors::detail::faiss_select; typedef WarpSelect, NumWarpQ, NumThreadQ, 32> myWarpSelect; - auto rowEpilog_lambda = [m, n, numOfNN, out_dists, out_inds, mutexes] __device__( - IdxT gridStrideY) { - if (gridDim.x == 1) { return; } - - Pair* shDumpKV = nullptr; - if (useNorms) { - shDumpKV = (Pair*)(&smem[Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT))]); - } else { - shDumpKV = (Pair*)(&smem[Policy::SmemSize]); - } - - const int lid = threadIdx.x % warpSize; - const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); - - // 0 -> consumer done consuming the buffer. - // -1 -> consumer started consuming the buffer - // -2 -> producer done filling the buffer - // 1 -> prod acquired to fill the buffer - if (blockIdx.x == 0) { - auto cta_processed = 0; - myWarpSelect heapArr1(identity, keyMax, numOfNN); - myWarpSelect heapArr2(identity, keyMax, numOfNN); - myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; - __syncwarp(); - - loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); - - while (cta_processed < gridDim.x - 1) { - if (threadIdx.x == 0) { - while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], -2, -1) != -2) - ; - } - __threadfence(); - __syncthreads(); + auto rowEpilog_lambda = + [m, n, &distance_op, numOfNN, out_dists, out_inds, mutexes] __device__(IdxT gridStrideY) { + if (gridDim.x == 1) { return; } + + // Use ::template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_offset = OpT::template shared_mem_size(); + Pair* shDumpKV = (Pair*)(&smem[smem_offset]); + + const int lid = threadIdx.x % warpSize; + const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); + + // 0 -> consumer done consuming the buffer. 
+ // -1 -> consumer started consuming the buffer + // -2 -> producer done filling the buffer + // 1 -> prod acquired to fill the buffer + if (blockIdx.x == 0) { + auto cta_processed = 0; + myWarpSelect heapArr1(identity, keyMax, numOfNN); + myWarpSelect heapArr2(identity, keyMax, numOfNN); + myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; + __syncwarp(); + + loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); + + while (cta_processed < gridDim.x - 1) { + if (threadIdx.x == 0) { + while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], -2, -1) != -2) + ; + } + __threadfence(); + __syncthreads(); #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { #pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - Pair otherKV; - otherKV.value = identity; - otherKV.key = keyMax; - const auto idx = j * warpSize + lid; - if (idx < numOfNN) { - otherKV.value = out_dists[rowId * numOfNN + idx]; - otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - shDumpKV[shMemRowId * numOfNN + idx] = otherKV; + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + otherKV.value = out_dists[rowId * numOfNN + idx]; + otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + shDumpKV[shMemRowId * numOfNN + idx] = otherKV; + } } } } - } - __threadfence(); - __syncthreads(); + __threadfence(); + __syncthreads(); - if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], 0); } - __threadfence(); + if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], 0); } + __threadfence(); // Perform merging of otherKV with topk's across warp. 
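// Illustrative sketch (not part of this patch): a host-side analogue of the
// merge step below. Each producer block has written its partial top-k
// (distance, index) pairs for a row; the consumer block folds every incoming
// pair into its own k-best set, which the warp-level code does with
// heapArr[i]->add(value, key) followed by reduce(). Here the k-best set is
// modeled with a std::priority_queue keyed on distance; names are stand-ins.
#include <cstddef>
#include <cstdint>
#include <queue>
#include <utility>
#include <vector>

using DistIdx = std::pair<float, uint32_t>;  // (distance, index); smaller distance is better

void fold_partial_topk(std::priority_queue<DistIdx>& best,  // max-heap: worst element on top
                       const std::vector<DistIdx>& partial,
                       std::size_t k)
{
  for (const auto& kv : partial) {
    if (best.size() < k) {
      best.push(kv);                           // still filling up to k entries
    } else if (kv.first < best.top().first) {  // better than the current worst
      best.pop();
      best.push(kv);
    }
  }
}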
#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { #pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - Pair otherKV; - otherKV.value = identity; - otherKV.key = keyMax; - const auto idx = j * warpSize + lid; - if (idx < numOfNN) { - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - otherKV = shDumpKV[shMemRowId * numOfNN + idx]; + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + otherKV = shDumpKV[shMemRowId * numOfNN + idx]; + } + heapArr[i]->add(otherKV.value, otherKV.key); } - heapArr[i]->add(otherKV.value, otherKV.key); } } + cta_processed++; } - cta_processed++; - } #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { - bool needSort = (heapArr[i]->numVals > 0); - needSort = __any_sync(0xffffffff, needSort); - if (needSort) { heapArr[i]->reduce(); } + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { + bool needSort = (heapArr[i]->numVals > 0); + needSort = __any_sync(0xffffffff, needSort); + if (needSort) { heapArr[i]->reduce(); } + } } - } - storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); - } else { - if (threadIdx.x == 0) { - while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], 0, 1) != 0) - ; - } - __threadfence(); - __syncthreads(); + storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); + } else { + if (threadIdx.x == 0) { + while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], 0, 1) != 0) + ; + } + __threadfence(); + __syncthreads(); #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { - for (int idx = lid; idx < numOfNN; idx += warpSize) { - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx]; - out_dists[rowId * numOfNN + idx] = KVPair.value; - out_inds[rowId * numOfNN + idx] = (IdxT)KVPair.key; + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { + for (int idx = lid; idx < numOfNN; idx += warpSize) { + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx]; + out_dists[rowId * numOfNN + idx] = KVPair.value; + out_inds[rowId * numOfNN + idx] = (IdxT)KVPair.key; + } } } - } - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], -2); } - __threadfence(); - } - }; + __threadfence(); + __syncthreads(); - // epilogue operation lambda for final value calculation - auto epilog_lambda = [numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__( - AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - if (useNorms) { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < 
Policy::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } + if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], -2); } + __threadfence(); } - } + }; - Pair* shDumpKV = nullptr; - if (useNorms) { - constexpr size_t shmemSize = - Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); - shDumpKV = (Pair*)(&smem[shmemSize]); - } else { - shDumpKV = (Pair*)(&smem[Policy::SmemSize]); - } + // epilogue operation lambda for final value calculation + auto epilog_lambda = + [&distance_op, numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__( + AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + // Use ::template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_offset = OpT::template shared_mem_size(); + Pair* shDumpKV = (Pair*)(&smem[smem_offset]); + + constexpr uint32_t mask = 0xffffffffu; + const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); + const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); + const int lid = raft::laneId(); - constexpr uint32_t mask = 0xffffffffu; - const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); - const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); - const int lid = raft::laneId(); - - myWarpSelect heapArr1(identity, keyMax, numOfNN); - myWarpSelect heapArr2(identity, keyMax, numOfNN); - myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; - if (usePrevTopKs) { - if (gridStrideX == blockIdx.x * Policy::Nblk) { - loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, numOfNN, starty); + myWarpSelect heapArr1(identity, keyMax, numOfNN); + myWarpSelect heapArr2(identity, keyMax, numOfNN); + myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; + if (usePrevTopKs) { + if (gridStrideX == blockIdx.x * Policy::Nblk) { + loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, numOfNN, starty); + } } - } - if (gridStrideX > blockIdx.x * Policy::Nblk) { + if (gridStrideX > blockIdx.x * Policy::Nblk) { #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; - heapArr[i]->warpKTop = tempKV.value; - } + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; + heapArr[i]->warpKTop = tempKV.value; + } - // total vals can atmost be 256, (32*8) - int numValsWarpTopK[Policy::AccRowsPerTh]; - int anyWarpTopKs = 0; + // total vals can atmost be 256, (32*8) + int numValsWarpTopK[Policy::AccRowsPerTh]; + int anyWarpTopKs = 0; #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - numValsWarpTopK[i] = 0; - if (rowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + numValsWarpTopK[i] = 0; + if (rowId < m) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - if (colId < ldd) { - if (acc[i][j] < heapArr[i]->warpKTop) { numValsWarpTopK[i]++; } + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { numValsWarpTopK[i]++; } + } } + 
anyWarpTopKs += numValsWarpTopK[i]; } - anyWarpTopKs += numValsWarpTopK[i]; } - } - anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); - if (anyWarpTopKs) { - Pair* allWarpTopKs = (Pair*)(&smem[0]); - uint32_t needScanSort[Policy::AccRowsPerTh]; + anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); + if (anyWarpTopKs) { + Pair* allWarpTopKs = (Pair*)(&smem[0]); + uint32_t needScanSort[Policy::AccRowsPerTh]; #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - needScanSort[i] = 0; - if (gmemRowId < m) { - int myVals = numValsWarpTopK[i]; - needScanSort[i] = __ballot_sync(mask, myVals > 0); - if (needScanSort[i]) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + needScanSort[i] = 0; + if (gmemRowId < m) { + int myVals = numValsWarpTopK[i]; + needScanSort[i] = __ballot_sync(mask, myVals > 0); + if (needScanSort[i]) { #pragma unroll - for (unsigned int k = 1; k <= 16; k *= 2) { - const unsigned int n = __shfl_up_sync(mask, numValsWarpTopK[i], k); - if (lid >= k) { numValsWarpTopK[i] += n; } + for (unsigned int k = 1; k <= 16; k *= 2) { + const unsigned int n = __shfl_up_sync(mask, numValsWarpTopK[i], k); + if (lid >= k) { numValsWarpTopK[i] += n; } + } } + // As each thread will know its total vals to write. + // we only store its starting location. + numValsWarpTopK[i] -= myVals; } - // As each thread will know its total vals to write. - // we only store its starting location. - numValsWarpTopK[i] -= myVals; - } - if (needScanSort[i]) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (gmemRowId < m) { - if (needScanSort[i] & ((uint32_t)1 << lid)) { + if (needScanSort[i]) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { + if (needScanSort[i] & ((uint32_t)1 << lid)) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - if (colId < ldd) { - if (acc[i][j] < heapArr[i]->warpKTop) { - Pair otherKV = {colId, acc[i][j]}; - allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = otherKV; - numValsWarpTopK[i]++; + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { + Pair otherKV = {colId, acc[i][j]}; + allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = otherKV; + numValsWarpTopK[i]++; + } } } } + __syncwarp(); + const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); + loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, numOfNN); + updateSortedWarpQ( + heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); } - __syncwarp(); - const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); - loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, numOfNN); - updateSortedWarpQ( - heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); } } - } - __syncthreads(); + __syncthreads(); #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - if (needScanSort[i]) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - const auto gmemRowId = starty + i * Policy::AccThRows; - if (gmemRowId < m) { - storeWarpQShmem(heapArr[i], shDumpKV, rowId, numOfNN); + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + if (needScanSort[i]) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + const auto gmemRowId = starty + i * Policy::AccThRows; + if (gmemRowId < m) { + 
storeWarpQShmem(heapArr[i], shDumpKV, rowId, numOfNN); + } } } } - } - } else { + } else { #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (gmemRowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - Pair otherKV = {keyMax, identity}; - if (colId < ldd) { - otherKV.value = acc[i][j]; - otherKV.key = colId; + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + Pair otherKV = {keyMax, identity}; + if (colId < ldd) { + otherKV.value = acc[i][j]; + otherKV.key = colId; + } + heapArr[i]->add(otherKV.value, otherKV.key); } - heapArr[i]->add(otherKV.value, otherKV.key); - } - bool needSort = (heapArr[i]->numVals > 0); - needSort = __any_sync(mask, needSort); - if (needSort) { heapArr[i]->reduce(); } - storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, numOfNN); + bool needSort = (heapArr[i]->numVals > 0); + needSort = __any_sync(mask, needSort); + if (needSort) { heapArr[i]->reduce(); } + storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, numOfNN); + } } } - } - if (((gridStrideX + Policy::Nblk * gridDim.x) >= n) && gridDim.x == 1) { - // This is last iteration of grid stride X - loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); - storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); - } - }; + if (((gridStrideX + Policy::Nblk * gridDim.x) >= n) && gridDim.x == 1) { + // This is last iteration of grid stride X + loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); + storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); + } + }; - raft::distance::detail::PairwiseDistances + write_out> obj(x, y, m, @@ -521,9 +505,9 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x ldd, _xn, _yn, - nullptr, + nullptr, // output ptr, can be null as write_out == false. 
smem, - core_op, + distance_op, epilog_lambda, fin_op, rowEpilog_lambda); @@ -562,38 +546,32 @@ void fusedL2UnexpKnnImpl(const DataT* x, dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = x - y; - acc += diff * diff; - }; - typedef cub::KeyValuePair Pair; - if (isRowMajor) { - constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN distance_op{sqrt}; + raft::identity_op fin_op{}; + + if constexpr (isRowMajor) { + constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; + constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; + isRowMajor>; auto fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; if (numOfNN <= 32) { @@ -604,8 +582,10 @@ void fusedL2UnexpKnnImpl(const DataT* x, ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); } - const auto sharedMemSize = KPolicy::SmemSize + (KPolicy::Mblk * numOfNN * sizeof(Pair)); - dim3 grid = raft::distance::detail::launchConfigGenerator( + const auto sharedMemSize = + distance_op.template shared_mem_size() + KPolicy::Mblk * numOfNN * sizeof(Pair); + + dim3 grid = raft::distance::detail::launchConfigGenerator( m, n, sharedMemSize, fusedL2UnexpKnnRowMajor); if (grid.x > 1) { @@ -628,9 +608,8 @@ void fusedL2UnexpKnnImpl(const DataT* x, lda, ldb, ldd, - core_lambda, - raft::identity_op{}, - sqrt, + distance_op, + fin_op, (uint32_t)numOfNN, (int*)workspace, out_dists, @@ -753,36 +732,33 @@ void fusedL2ExpKnnImpl(const DataT* x, ASSERT(workspace != nullptr, "workspace is null"); dim3 blk(KPolicy::Nthreads); - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; }; typedef cub::KeyValuePair Pair; - if (isRowMajor) { - constexpr auto fusedL2ExpKnn32RowMajor = fusedL2kNN distance_op{sqrt}; + raft::identity_op fin_op{}; + + if constexpr (isRowMajor) { + constexpr auto fusedL2ExpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2ExpKnn64RowMajor = fusedL2kNN; + constexpr auto fusedL2ExpKnn64RowMajor = fusedL2kNN; + isRowMajor>; auto fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; if (numOfNN <= 32) { @@ -793,9 +769,8 @@ void fusedL2ExpKnnImpl(const DataT* x, ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); } - const auto sharedMemSize = KPolicy::SmemSize + - ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)) + - (KPolicy::Mblk * numOfNN * sizeof(Pair)); + const auto sharedMemSize = + distance_op.template shared_mem_size() + (KPolicy::Mblk * numOfNN * sizeof(Pair)); dim3 grid = raft::distance::detail::launchConfigGenerator( m, n, sharedMemSize, fusedL2ExpKnnRowMajor); int32_t* mutexes = nullptr; @@ -835,9 +810,8 @@ void fusedL2ExpKnnImpl(const DataT* x, lda, ldb, ldd, - core_lambda, - raft::identity_op{}, - sqrt, + distance_op, + fin_op, (uint32_t)numOfNN, mutexes, out_dists, diff --git a/cpp/include/raft/util/arch.cuh b/cpp/include/raft/util/arch.cuh index 8c48b87269..dc35b10063 100644 --- a/cpp/include/raft/util/arch.cuh +++ b/cpp/include/raft/util/arch.cuh @@ -15,25 +15,27 @@ */ #pragma once -namespace raft::arch { +#include // RAFT_CUDA_TRY -/* raft::arch provides the following facilities: +namespace raft::util::arch { + +/* raft::util::arch provides the following facilities: * - * - raft::arch::SM_XX : hardcoded compile-time constants for various compute - * architectures. 
The values raft::arch::SM_min and raft::arch::SM_future + * - raft::util::arch::SM_XX : hardcoded compile-time constants for various compute + * architectures. The values raft::util::arch::SM_min and raft::util::arch::SM_future * represent architectures that are always smaller and larger (respectively) * than any architecture that can be encountered in practice. * - * - raft::arch::SM_compute_arch : a compile-time value for the *current* + * - raft::util::arch::SM_compute_arch : a compile-time value for the *current* * compute architecture that a kernel is compiled with. It can only be used * inside kernels with a template argument. * - * - raft::arch::kernel_runtime_arch : a function that computes at *run-time* + * - raft::util::arch::kernel_runtime_arch : a function that computes at *run-time* * which version of a kernel will launch (i.e., it will return the compute * architecture of the version of the kernel that will be launched by the * driver). * - * - raft::arch::SM_range : a compile-time value to represent an open interval + * - raft::util::arch::SM_range : a compile-time value to represent an open interval * of compute architectures. This can be used to check if the current * compile-time architecture is in a specified compatibility range. */ @@ -46,9 +48,6 @@ struct SM_generic { public: __host__ __device__ constexpr int value() const { return n; } }; - -// A dummy kernel that is used to determine the runtime architecture. -__global__ inline void dummy_runtime_kernel() {} } // namespace detail // A list of architectures that RAPIDS explicitly builds for (SM60, ..., SM90) @@ -119,7 +118,7 @@ struct SM_runtime { inline SM_runtime kernel_runtime_arch(void* kernel) { cudaFuncAttributes attributes; - cudaFuncGetAttributes(&attributes, kernel); + RAFT_CUDA_TRY(cudaFuncGetAttributes(&attributes, kernel)); return SM_runtime(10 * attributes.ptxVersion); } @@ -143,4 +142,4 @@ struct SM_range { } }; -} // namespace raft::arch +} // namespace raft::util::arch diff --git a/cpp/include/raft/util/cuda_dev_essentials.cuh b/cpp/include/raft/util/cuda_dev_essentials.cuh new file mode 100644 index 0000000000..5080dc33ee --- /dev/null +++ b/cpp/include/raft/util/cuda_dev_essentials.cuh @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// This file provides a few essential functions for use in __device__ code. The +// scope is necessarily limited to ensure that compilation times are minimized. +// Please make sure not to include large / expensive files from here. + +namespace raft { + +/** helper macro for device inlined functions */ +#define DI inline __device__ +#define HDI inline __host__ __device__ +#define HD __host__ __device__ + +/** + * @brief Provide a ceiling division operation ie. ceil(a / b) + * @tparam IntType supposed to be only integers for now! 
+ */ +template +constexpr HDI IntType ceildiv(IntType a, IntType b) +{ + return (a + b - 1) / b; +} + +/** + * @brief Provide an alignment function ie. ceil(a / b) * b + * @tparam IntType supposed to be only integers for now! + */ +template +constexpr HDI IntType alignTo(IntType a, IntType b) +{ + return ceildiv(a, b) * b; +} + +/** + * @brief Provide an alignment function ie. (a / b) * b + * @tparam IntType supposed to be only integers for now! + */ +template +constexpr HDI IntType alignDown(IntType a, IntType b) +{ + return (a / b) * b; +} + +/** + * @brief Check if the input is a power of 2 + * @tparam IntType data type (checked only for integers) + */ +template +constexpr HDI bool isPo2(IntType num) +{ + return (num && !(num & (num - 1))); +} + +/** + * @brief Give logarithm of the number to base-2 + * @tparam IntType data type (checked only for integers) + */ +template +constexpr HDI IntType log2(IntType num, IntType ret = IntType(0)) +{ + return num <= IntType(1) ? ret : log2(num >> IntType(1), ++ret); +} + +/** number of threads per warp */ +static const int WarpSize = 32; + +/** get the laneId of the current thread */ +DI int laneId() +{ + int id; + asm("mov.s32 %0, %%laneid;" : "=r"(id)); + return id; +} + +} // namespace raft diff --git a/cpp/include/raft/util/cuda_rt_essentials.hpp b/cpp/include/raft/util/cuda_rt_essentials.hpp new file mode 100644 index 0000000000..e5f3af4e61 --- /dev/null +++ b/cpp/include/raft/util/cuda_rt_essentials.hpp @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// This file provides a few essential functions that wrap the CUDA runtime API. +// The scope is necessarily limited to ensure that compilation times are +// minimized. Please make sure not to include large / expensive files from here. + +#include +#include + +namespace raft { + +/** + * @brief Exception thrown when a CUDA error is encountered. + */ +struct cuda_error : public raft::exception { + explicit cuda_error(char const* const message) : raft::exception(message) {} + explicit cuda_error(std::string const& message) : raft::exception(message) {} +}; + +} // namespace raft + +/** + * @brief Error checking macro for CUDA runtime API functions. 
+ * + * Invokes a CUDA runtime API function call, if the call does not return + * cudaSuccess, invokes cudaGetLastError() to clear the error and throws an + * exception detailing the CUDA error that occurred + * + */ +#define RAFT_CUDA_TRY(call) \ + do { \ + cudaError_t const status = call; \ + if (status != cudaSuccess) { \ + cudaGetLastError(); \ + std::string msg{}; \ + SET_ERROR_MSG(msg, \ + "CUDA error encountered at: ", \ + "call='%s', Reason=%s:%s", \ + #call, \ + cudaGetErrorName(status), \ + cudaGetErrorString(status)); \ + throw raft::cuda_error(msg); \ + } \ + } while (0) diff --git a/cpp/include/raft/util/cuda_utils.cuh b/cpp/include/raft/util/cuda_utils.cuh index 5be9dc999a..687a6b4651 100644 --- a/cpp/include/raft/util/cuda_utils.cuh +++ b/cpp/include/raft/util/cuda_utils.cuh @@ -23,113 +23,10 @@ #include #include #include - -#ifndef ENABLE_MEMCPY_ASYNC -// enable memcpy_async interface by default for newer GPUs -#if __CUDA_ARCH__ >= 800 -#define ENABLE_MEMCPY_ASYNC 1 -#endif -#else // ENABLE_MEMCPY_ASYNC -// disable memcpy_async for all older GPUs -#if __CUDA_ARCH__ < 800 -#define ENABLE_MEMCPY_ASYNC 0 -#endif -#endif // ENABLE_MEMCPY_ASYNC +#include namespace raft { -/** helper macro for device inlined functions */ -#define DI inline __device__ -#define HDI inline __host__ __device__ -#define HD __host__ __device__ - -/** - * @brief Provide a ceiling division operation ie. ceil(a / b) - * @tparam IntType supposed to be only integers for now! - */ -template -constexpr HDI IntType ceildiv(IntType a, IntType b) -{ - return (a + b - 1) / b; -} - -/** - * @brief Provide an alignment function ie. ceil(a / b) * b - * @tparam IntType supposed to be only integers for now! - */ -template -constexpr HDI IntType alignTo(IntType a, IntType b) -{ - return ceildiv(a, b) * b; -} - -/** - * @brief Provide an alignment function ie. (a / b) * b - * @tparam IntType supposed to be only integers for now! - */ -template -constexpr HDI IntType alignDown(IntType a, IntType b) -{ - return (a / b) * b; -} - -/** - * @brief Check if the input is a power of 2 - * @tparam IntType data type (checked only for integers) - */ -template -constexpr HDI bool isPo2(IntType num) -{ - return (num && !(num & (num - 1))); -} - -/** - * @brief Give logarithm of the number to base-2 - * @tparam IntType data type (checked only for integers) - */ -template -constexpr HDI IntType log2(IntType num, IntType ret = IntType(0)) -{ - return num <= IntType(1) ? 
ret : log2(num >> IntType(1), ++ret); -} - -/** Device function to apply the input lambda across threads in the grid */ -template -DI void forEach(int num, L lambda) -{ - int idx = (blockDim.x * blockIdx.x) + threadIdx.x; - const int numThreads = blockDim.x * gridDim.x; -#pragma unroll - for (int itr = 0; itr < ItemsPerThread; ++itr, idx += numThreads) { - if (idx < num) lambda(idx, itr); - } -} - -/** number of threads per warp */ -static const int WarpSize = 32; - -/** get the laneId of the current thread */ -DI int laneId() -{ - int id; - asm("mov.s32 %0, %%laneid;" : "=r"(id)); - return id; -} - -/** - * @brief Swap two values - * @tparam T the datatype of the values - * @param a first input - * @param b second input - */ -template -HDI void swapVals(T& a, T& b) -{ - T tmp = a; - a = b; - b = tmp; -} - /** Device function to have atomic add support for older archs */ template DI void myAtomicAdd(Type* address, Type val) diff --git a/cpp/include/raft/util/cudart_utils.hpp b/cpp/include/raft/util/cudart_utils.hpp index 0feb188ad8..0a7ca23028 100644 --- a/cpp/include/raft/util/cudart_utils.hpp +++ b/cpp/include/raft/util/cudart_utils.hpp @@ -25,6 +25,7 @@ #pragma once #include +#include #include #include #include @@ -40,42 +41,7 @@ #include #include #include - -namespace raft { - -/** - * @brief Exception thrown when a CUDA error is encountered. - */ -struct cuda_error : public raft::exception { - explicit cuda_error(char const* const message) : raft::exception(message) {} - explicit cuda_error(std::string const& message) : raft::exception(message) {} -}; - -} // namespace raft - -/** - * @brief Error checking macro for CUDA runtime API functions. - * - * Invokes a CUDA runtime API function call, if the call does not return - * cudaSuccess, invokes cudaGetLastError() to clear the error and throws an - * exception detailing the CUDA error that occurred - * - */ -#define RAFT_CUDA_TRY(call) \ - do { \ - cudaError_t const status = call; \ - if (status != cudaSuccess) { \ - cudaGetLastError(); \ - std::string msg{}; \ - SET_ERROR_MSG(msg, \ - "CUDA error encountered at: ", \ - "call='%s', Reason=%s:%s", \ - #call, \ - cudaGetErrorName(status), \ - cudaGetErrorString(status)); \ - throw raft::cuda_error(msg); \ - } \ - } while (0) +#include // FIXME: Remove after consumers rename #ifndef CUDA_TRY diff --git a/cpp/include/raft/util/device_loads_stores.cuh b/cpp/include/raft/util/device_loads_stores.cuh index 2b87c44d60..c9bda26b81 100644 --- a/cpp/include/raft/util/device_loads_stores.cuh +++ b/cpp/include/raft/util/device_loads_stores.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ #pragma once -#include +#include // uintX_t +#include // DI namespace raft { diff --git a/cpp/src/distance/specializations/detail/00_write_template.py b/cpp/src/distance/specializations/detail/00_write_template.py new file mode 100644 index 0000000000..3f2f853569 --- /dev/null +++ b/cpp/src/distance/specializations/detail/00_write_template.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 + +# NOTE: this template is not perfectly formatted. Use pre-commit to get +# everything in shape again. +template = """/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +INCLUDE_SM_HEADERS + +namespace raft::distance::detail { + +template void pairwise_matrix_instantiation_point( + OpT, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail +""" + +data_type_instances = [ + dict( + DataT="float", + AccT="float", + OutT="float", + IdxT="int", + ), + dict( + DataT="double", + AccT="double", + OutT="double", + IdxT="int", + ), +] + +op_instances = [ + dict( + path_prefix="canberra", + OpT="ops::canberra_distance_op", + archs = [60], + ), + dict( + path_prefix="correlation", + OpT="ops::correlation_distance_op", + archs = [60], + ), + dict( + path_prefix="cosine", + OpT="ops::cosine_distance_op", + archs = [60, 80], + ), + dict( + path_prefix="hamming_unexpanded", + OpT="ops::hamming_distance_op", + archs = [60], + ), + dict( + path_prefix="hellinger_expanded", + OpT="ops::hellinger_distance_op", + archs = [60], + ), + # inner product is handled by cublas. + dict( + path_prefix="jensen_shannon", + OpT="ops::jensen_shannon_distance_op", + archs = [60], + ), + dict( + path_prefix="kl_divergence", + OpT="ops::kl_divergence_op", + archs = [60], + ), + dict( + path_prefix="l1", + OpT="ops::l1_distance_op", + archs = [60], + ), + dict( + path_prefix="l2_expanded", + OpT="ops::l2_exp_distance_op", + archs = [60, 80], + ), + dict( + path_prefix="l2_unexpanded", + OpT="ops::l2_unexp_distance_op", + archs = [60], + ), + dict( + path_prefix="l_inf", + OpT="ops::l_inf_distance_op", + archs = [60], + ), + dict( + path_prefix="lp_unexpanded", + OpT="ops::lp_unexp_distance_op", + archs = [60], + ), + dict( + path_prefix="russel_rao", + OpT="ops::russel_rao_distance_op", + archs = [60], + ), +] + +def fill_in(s, template): + for k, v in template.items(): + s = s.replace(k, v) + return s + +def fill_include_sm_headers(op_instance): + include_headers ="\n".join([ + f"#include " + for arch in op_instance["archs"] + ]) + + return { + "path_prefix": op_instance["path_prefix"], + "OpT": op_instance["OpT"], + "INCLUDE_SM_HEADERS": include_headers + } + +for op_instance in op_instances: + op_instance = fill_include_sm_headers(op_instance) + + for data_type_instance in data_type_instances: + op_data_instance = { + k : fill_in(v, data_type_instance) + for k, v in op_instance.items() + } + instance = { + **op_data_instance, + **data_type_instance, + "FinopT": "decltype(raft::identity_op())", + } + + text = fill_in(template, instance) + + path = fill_in("path_prefix_DataT_AccT_OutT_IdxT.cu", instance) + with open(path, "w") as f: + f.write(text) diff --git a/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu b/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu index 4e9e608792..037d218178 100644 --- a/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu @@ -14,24 +14,20 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::canberra_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu b/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu index 6dfc385e55..0ed8ea7bb0 100644 --- a/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::canberra_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu b/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu index 2df77a4b5d..0c11f0621e 100644 --- a/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu @@ -14,27 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::correlation_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu b/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu index 76ed00afa6..396e158554 100644 --- a/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::correlation_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu b/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu index 3e0bcb92ed..e9afb6f563 100644 --- a/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu @@ -14,26 +14,21 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::cosine_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu b/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu index 23131ce2c7..1033c491d6 100644 --- a/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu @@ -14,26 +14,21 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::cosine_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu index b618fd024c..195115914d 100644 --- a/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu @@ -14,27 +14,20 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::hamming_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu index 18e7aad9e9..a74c6c404e 100644 --- a/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::hamming_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu index 08ab20cfe5..bac1dd7bd0 100644 --- a/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu @@ -14,27 +14,20 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::hellinger_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu index 79eed075fb..77c113b1a9 100644 --- a/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu @@ -14,26 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::hellinger_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu b/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu index ed84ee6dc4..188e52c152 100644 --- a/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu @@ -14,25 +14,21 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void + pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::jensen_shannon_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu b/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu index a241af767c..b0afbf7bb2 100644 --- a/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu @@ -14,25 +14,21 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void + pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::jensen_shannon_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu b/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu index c4c944d123..f06ae85414 100644 --- a/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu @@ -14,25 +14,20 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::kl_divergence_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu b/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu index aa1db5a837..00d5a5ee5b 100644 --- a/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::kl_divergence_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu index 391a1c2aa4..5c235316da 100644 --- a/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l1_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu index 7b45e52ca1..fb293ca83d 100644 --- a/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::l1_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu index 8c5f746fa2..2c02f0224f 100644 --- a/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu @@ -14,24 +14,21 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l2_exp_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu index c266125f98..85e25a25ca 100644 --- a/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu @@ -14,25 +14,21 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::l2_exp_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu deleted file mode 100644 index 399b120527..0000000000 --- a/cpp/src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu deleted file mode 100644 index 66de212b8e..0000000000 --- a/cpp/src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu deleted file mode 100644 index 562d93b2de..0000000000 --- a/cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { - -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu deleted file mode 100644 index 386bbafc5f..0000000000 --- a/cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu index 7733c3af48..5b4d995d14 100644 --- a/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l2_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu index 4ea18d31de..a63c3f0bb8 100644 --- a/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::l2_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu index 74414f8fd6..831167523f 100644 --- a/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu @@ -14,26 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l_inf_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu index e418fc455f..02e667cbe3 100644 --- a/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::l_inf_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu index 402cb51b7e..ebd71065ec 100644 --- a/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu @@ -14,25 +14,20 @@ * limitations under the License. 
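The hunks in this part of the patch replace hand-written `template void distance(...)` instantiations with `pairwise_matrix_instantiation_point` instantiations, one small translation unit per distance operator and element type. For readers less familiar with why the specializations live in separate .cu files, here is a minimal, generic sketch of the explicit-instantiation / extern-template pattern they rely on. Every file and symbol name in it (pairwise_op.hpp, pairwise_l2, MYLIB_COMPILED) is hypothetical and only illustrative, not part of RAFT.

// pairwise_op.hpp -- hypothetical header, for illustration only (not a RAFT file).
#pragma once
#include <cstddef>

template <typename T>
void pairwise_l2(const T* x, const T* y, T* dist,
                 std::size_t m, std::size_t n, std::size_t k)
{
  // Naive host loop standing in for the real pairwise-distance kernel launch.
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < n; ++j) {
      T acc = T{0};
      for (std::size_t d = 0; d < k; ++d) {
        T diff = x[i * k + d] - y[j * k + d];
        acc += diff * diff;
      }
      dist[i * n + j] = acc;
    }
}

#if defined(MYLIB_COMPILED)
// When building against the prebuilt library, promise every including
// translation unit that these instantiations already exist in a library
// object file, so the template body above is never compiled at the call site.
extern template void pairwise_l2<float>(const float*, const float*, float*,
                                        std::size_t, std::size_t, std::size_t);
extern template void pairwise_l2<double>(const double*, const double*, double*,
                                         std::size_t, std::size_t, std::size_t);
#endif

// pairwise_l2_float.cpp -- compiled into the library; one explicit
// instantiation per file, mirroring the one-.cu-per-type layout above.
#include "pairwise_op.hpp"

template void pairwise_l2<float>(const float*, const float*, float*,
                                 std::size_t, std::size_t, std::size_t);

Splitting the instantiations across many small translation units keeps any single object file small, lets the build compile them in parallel, and means a change to one operator only recompiles its own files.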
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::lp_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu index 7efe2b3349..b94a81fdce 100644 --- a/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::lp_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu b/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu index b1e6f5e1f4..6f952fcc37 100644 --- a/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu @@ -14,26 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::russel_rao_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu b/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu index 1e12bcd705..3223ce33a7 100644 --- a/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. 
*/ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::russel_rao_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 0e084f2ad8..438e212fbd 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -16,16 +16,24 @@ #include "../test_utils.cuh" #include -#include -#include -#include -#include -#include -#include +#include // common::nvtx::range + +#include // make_device_matrix_view +#include // raft::device_resources +#include // raft::sqrt +#include // raft::distance::DistanceType +#include +#include // rmm::device_uvector + +// When the distance library is precompiled, include only the raft_runtime +// headers. This way, a small change in one of the kernel internals does not +// trigger a rebuild of the test files (it of course still triggers a rebuild of +// the raft specializations) #if defined RAFT_COMPILED -#include +#include +#else +#include #endif -#include namespace raft { namespace distance { @@ -409,6 +417,25 @@ template return os; } +// TODO: Remove when mdspan-based raft::runtime::distance::pairwise_distance is +// implemented. +// +// Context: +// https://github.com/rapidsai/raft/issues/1338 +template +constexpr bool layout_to_row_major(); + +template <> +constexpr bool layout_to_row_major() +{ + return true; +} +template <> +constexpr bool layout_to_row_major() +{ + return false; +} + template void distanceLauncher(raft::device_resources const& handle, DataType* x, @@ -422,12 +449,23 @@ void distanceLauncher(raft::device_resources const& handle, DataType threshold, DataType metric_arg = 2.0f) { +#if defined RAFT_COMPILED + // TODO: Implement and use mdspan-based + // raft::runtime::distance::pairwise_distance here. 
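The comment added to distance_base.cuh above states the goal: when libraft is precompiled (`RAFT_COMPILED`), the test includes only the `raft_runtime` headers and calls a plain, non-template entry point, so a change inside the kernels rebuilds the library but not the test translation units. Below is a small generic sketch of that runtime-API technique, reusing the hypothetical `pairwise_op.hpp` from the previous sketch; `MYLIB_COMPILED`, `pairwise_l2_f32`, and the file names are likewise made up for illustration and are not RAFT's actual layout.

// pairwise_runtime.hpp -- hypothetical "runtime" header: plain declarations,
// no templates and no device code, so it is cheap to include anywhere and
// needs no CUDA compiler on the consumer side.
#pragma once
#include <cstddef>

void pairwise_l2_f32(const float* x, const float* y, float* dist,
                     std::size_t m, std::size_t n, std::size_t k);

// pairwise_runtime.cpp -- built once into the shared library; the only place
// that needs the template implementation from pairwise_op.hpp.
#include "pairwise_op.hpp"
#include "pairwise_runtime.hpp"

void pairwise_l2_f32(const float* x, const float* y, float* dist,
                     std::size_t m, std::size_t n, std::size_t k)
{
  pairwise_l2<float>(x, y, dist, m, n, k);
}

// launch_l2.hpp -- caller side, mirroring the #if dispatch in distanceLauncher:
// use the precompiled entry point when MYLIB_COMPILED is defined, otherwise
// instantiate the template header-only in this translation unit.
#pragma once
#if !defined(MYLIB_COMPILED)
#include "pairwise_op.hpp"
#endif
#include "pairwise_runtime.hpp"

inline void launch_l2(const float* x, const float* y, float* dist,
                      std::size_t m, std::size_t n, std::size_t k)
{
#if defined(MYLIB_COMPILED)
  pairwise_l2_f32(x, y, dist, m, n, k);     // links against the prebuilt library
#else
  pairwise_l2<float>(x, y, dist, m, n, k);  // compiled right here
#endif
}

This is the same shape as calling `raft::runtime::distance::pairwise_distance` on the precompiled path versus the header-only `raft::distance::distance` call in the #else branch of distanceLauncher.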
+ // + // Context: + // https://github.com/rapidsai/raft/issues/1338 + bool row_major = layout_to_row_major(); + raft::runtime::distance::pairwise_distance( + handle, x, y, dist, m, n, k, distanceType, row_major, metric_arg); +#else auto x_v = make_device_matrix_view(x, m, k); auto y_v = make_device_matrix_view(y, n, k); auto dist_v = make_device_matrix_view(dist, m, n); raft::distance::distance( handle, x_v, y_v, dist_v, metric_arg); +#endif } template @@ -523,9 +561,25 @@ class BigMatrixDistanceTest : public ::testing::Test { auto testInfo = testing::UnitTest::GetInstance()->current_test_info(); common::nvtx::range fun_scope("test::%s/%s", testInfo->test_suite_name(), testInfo->name()); + void pairwise_distance(raft::device_resources const& handle, + float* x, + float* y, + float* dists, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); + constexpr bool row_major = true; + constexpr float metric_arg = 0.0f; +#if defined RAFT_COMPILED + raft::runtime::distance::pairwise_distance( + handle, x.data(), x.data(), dist.data(), m, n, k, distanceType, row_major, metric_arg); +#else raft::distance::distance( - handle, x.data(), x.data(), dist.data(), m, n, k, true, 0.0f); - + handle, x.data(), x.data(), dist.data(), m, n, k, row_major, metric_arg); +#endif RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); } diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index 4a74d7f16a..383ad39319 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -182,22 +182,20 @@ class FusedL2NNTest : public ::testing::TestWithParam> { int m = params.m; int n = params.n; int k = params.k; - MinAndDistanceReduceOp redOp; - fusedL2NN, int>( - out, - x.data(), - y.data(), - xn.data(), - yn.data(), - m, - n, - k, - (void*)workspace.data(), - redOp, - raft::distance::KVPMinReduce(), - Sqrt, - true, - stream); + + const bool init_out_buffer = true; + fusedL2NNMinReduce, int>(out, + x.data(), + y.data(), + xn.data(), + yn.data(), + m, + n, + k, + (void*)workspace.data(), + Sqrt, + init_out_buffer, + stream); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); } }; From 7c73f23ddf5f81ad2cd057c650b7e2e8947c5265 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 23 Mar 2023 03:29:03 -0700 Subject: [PATCH 04/13] Add nccl to dependencies.yaml (#1361) raft-dask requires nccl to build, add to the dependencies.yaml so that when creating a clean raft conda environment - we can build all of raft out of the box Authors: - Ben Frederickson (https://github.com/benfred) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. 
Nolet (https://github.com/cjnolet) - AJ Schmidt (https://github.com/ajschmidt8) URL: https://github.com/rapidsai/raft/pull/1361 --- conda/environments/all_cuda-118_arch-x86_64.yaml | 1 + dependencies.yaml | 1 + 2 files changed, 2 insertions(+) diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 39f1fef4d5..7972a8824d 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -33,6 +33,7 @@ dependencies: - libcusolver=11.4.1.48 - libcusparse-dev=11.7.5.86 - libcusparse=11.7.5.86 +- nccl>=2.9.9 - ninja - numpydoc - pydata-sphinx-theme diff --git a/dependencies.yaml b/dependencies.yaml index 9fbf26bcd1..c06ce4a20f 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -53,6 +53,7 @@ dependencies: packages: - c-compiler - cxx-compiler + - nccl>=2.9.9 specific: - output_types: conda matrices: From 419f0c28cd18654064bfd3a6ed2f638e239f46b3 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Thu, 23 Mar 2023 09:28:16 -0400 Subject: [PATCH 05/13] Generate pyproject dependencies with dfg (#1364) This PR uses dependencies.yaml to generate the dependency lists in pyproject.toml Authors: - Vyas Ramasubramani (https://github.com/vyasr) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - AJ Schmidt (https://github.com/ajschmidt8) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1364 --- .pre-commit-config.yaml | 2 +- .../all_cuda-118_arch-x86_64.yaml | 9 +- conda/recipes/pylibraft/meta.yaml | 1 + dependencies.yaml | 107 +++++++++++++++--- python/pylibraft/pyproject.toml | 24 ++-- python/raft-dask/pyproject.toml | 24 ++-- 6 files changed, 127 insertions(+), 40 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7606914589..630b8788f8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -101,7 +101,7 @@ repos: args: ["--toml", "pyproject.toml"] exclude: (?x)^(^CHANGELOG.md$) - repo: https://github.com/rapidsai/dependency-file-generator - rev: v1.4.0 + rev: v1.5.1 hooks: - id: rapids-dependency-file-generator args: ["--clean"] diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 7972a8824d..1afebc98e6 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -18,13 +18,14 @@ dependencies: - cupy - cxx-compiler - cython>=0.29,<0.30 -- dask-cuda=23.04 +- dask-cuda==23.4.* - dask>=2023.1.1 - distributed>=2023.1.1 - doxygen>=1.8.20 - gcc_linux-64=11.* - graphviz - ipython +- joblib>=0.11 - libcublas-dev=11.11.3.6 - libcublas=11.11.3.6 - libcurand-dev=10.3.0.86 @@ -35,12 +36,14 @@ dependencies: - libcusparse=11.7.5.86 - nccl>=2.9.9 - ninja +- numba>=0.49 +- numpy>=1.21 - numpydoc - pydata-sphinx-theme - pytest - pytest-cov - recommonmark -- rmm=23.04 +- rmm==23.4.* - scikit-build>=0.13.1 - scikit-learn - scipy @@ -48,6 +51,6 @@ dependencies: - sphinx-markdown-tables - sysroot_linux-64==2.17 - ucx-proc=*=gpu -- ucx-py=0.31.* +- ucx-py==0.31.* - ucx>=1.13.0 name: all_cuda-118_arch-x86_64 diff --git a/conda/recipes/pylibraft/meta.yaml b/conda/recipes/pylibraft/meta.yaml index a528064348..7730801801 100644 --- a/conda/recipes/pylibraft/meta.yaml +++ b/conda/recipes/pylibraft/meta.yaml @@ -36,6 +36,7 @@ requirements: - cython >=0.29,<0.30 - libraft {{ version }} - libraft-headers {{ version }} + 
- numpy >=1.21 - python x.x - rmm ={{ minor_version }} - scikit-build >=0.13.1 diff --git a/dependencies.yaml b/dependencies.yaml index c06ce4a20f..dd361a0cdf 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -7,11 +7,14 @@ files: arch: [x86_64] includes: - build + - build_pylibraft - cudatoolkit - develop - docs - - run - - test_python + - run_raft_dask + - run_pylibraft + - test_python_common + - test_pylibraft test_cpp: output: none includes: @@ -21,7 +24,8 @@ files: includes: - cudatoolkit - py_version - - test_python + - test_python_common + - test_pylibraft checks: output: none includes: @@ -33,6 +37,54 @@ files: - cudatoolkit - docs - py_version + py_build_pylibraft: + output: pyproject + pyproject_dir: python/pylibraft + extras: + table: build-system + includes: + - build + - build_pylibraft + - build_wheels + py_run_pylibraft: + output: pyproject + pyproject_dir: python/pylibraft + extras: + table: project + includes: + - run_pylibraft + py_test_pylibraft: + output: pyproject + pyproject_dir: python/pylibraft + extras: + table: project.optional-dependencies + key: test + includes: + - test_python_common + - test_pylibraft + py_build_raft_dask: + output: pyproject + pyproject_dir: python/raft-dask + extras: + table: build-system + includes: + - build + - build_wheels + py_run_raft_dask: + output: pyproject + pyproject_dir: python/raft-dask + extras: + table: project + includes: + - run_raft_dask + py_test_raft_dask: + output: pyproject + pyproject_dir: python/raft-dask + extras: + table: project.optional-dependencies + key: test + includes: + - test_python_common channels: - rapidsai - rapidsai-nightly @@ -42,10 +94,9 @@ channels: dependencies: build: common: - - output_types: [conda, requirements] + - output_types: [conda, requirements, pyproject] packages: - cmake>=3.23.1,!=3.25.0 - - cuda-python >=11.7.1,<12.0 - cython>=0.29,<0.30 - ninja - scikit-build>=0.13.1 @@ -67,6 +118,12 @@ dependencies: packages: - gcc_linux-aarch64=11.* - sysroot_linux-aarch64==2.17 + build_pylibraft: + common: + - output_types: [conda, requirements, pyproject] + packages: + - &cuda_python cuda-python >=11.7.1,<12.0 + - &rmm rmm==23.4.* checks: common: - output_types: [conda, requirements] @@ -151,6 +208,12 @@ dependencies: - recommonmark - sphinx-copybutton - sphinx-markdown-tables + build_wheels: + common: + - output_types: pyproject + packages: + - wheel + - setuptools py_version: specific: - output_types: conda @@ -170,23 +233,41 @@ dependencies: - matrix: packages: - python>=3.8,<3.11 - run: + run_pylibraft: common: - - output_types: [conda] + - output_types: [conda, pyproject] + packages: + - &numpy numpy>=1.21 + - *cuda_python + - *rmm + run_raft_dask: + common: + - output_types: [conda, pyproject] packages: - dask>=2023.1.1 + - dask-cuda==23.4.* - distributed>=2023.1.1 + - joblib>=0.11 + - numba>=0.49 + - *numpy + - ucx-py==0.31.* + - output_types: conda + packages: - ucx>=1.13.0 - - ucx-py=0.31.* - ucx-proc=*=gpu - - rmm=23.04 - - dask-cuda=23.04 - test_python: + - output_types: pyproject + packages: + - pylibraft==23.4.* + test_python_common: common: - - output_types: [conda, requirements] + - output_types: [conda, requirements, pyproject] packages: - - cupy - pytest - pytest-cov + test_pylibraft: + common: + - output_types: [conda, requirements, pyproject] + packages: + - cupy - scikit-learn - scipy diff --git a/python/pylibraft/pyproject.toml b/python/pylibraft/pyproject.toml index 7d92fd0763..fed15bbab0 100644 --- a/python/pylibraft/pyproject.toml +++ 
b/python/pylibraft/pyproject.toml @@ -15,15 +15,15 @@ [build-system] requires = [ - "wheel", - "setuptools", - "cython>=0.29,<0.30", - "cuda-python>=11.7.1,<12.0", - "scikit-build>=0.13.1", "cmake>=3.23.1,!=3.25.0", + "cuda-python >=11.7.1,<12.0", + "cython>=0.29,<0.30", "ninja", "rmm==23.4.*", -] + "scikit-build>=0.13.1", + "setuptools", + "wheel", +] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. build-backend = "setuptools.build_meta" [project] @@ -37,10 +37,10 @@ authors = [ license = { text = "Apache 2.0" } requires-python = ">=3.8" dependencies = [ - "numpy", - "cuda-python>=11.7.1,<12.0", + "cuda-python >=11.7.1,<12.0", + "numpy>=1.21", "rmm==23.4.*", -] +] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. classifiers = [ "Intended Audience :: Developers", "Programming Language :: Python", @@ -50,10 +50,12 @@ classifiers = [ [project.optional-dependencies] test = [ + "cupy", "pytest", - "scipy", + "pytest-cov", "scikit-learn", -] + "scipy", +] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. [project.urls] Homepage = "https://github.com/rapidsai/raft" diff --git a/python/raft-dask/pyproject.toml b/python/raft-dask/pyproject.toml index 2fe6522f57..fe490ea117 100644 --- a/python/raft-dask/pyproject.toml +++ b/python/raft-dask/pyproject.toml @@ -15,13 +15,13 @@ [build-system] requires = [ - "wheel", - "setuptools", - "cython>=0.29,<0.30", - "scikit-build>=0.13.1", "cmake>=3.23.1,!=3.25.0", + "cython>=0.29,<0.30", "ninja", -] + "scikit-build>=0.13.1", + "setuptools", + "wheel", +] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. [project] name = "raft-dask" @@ -34,15 +34,15 @@ authors = [ license = { text = "Apache 2.0" } requires-python = ">=3.8" dependencies = [ - "numpy", - "numba>=0.49", - "joblib>=0.11", "dask-cuda==23.4.*", "dask>=2023.1.1", - "ucx-py==0.31.*", "distributed>=2023.1.1", + "joblib>=0.11", + "numba>=0.49", + "numpy>=1.21", "pylibraft==23.4.*", -] + "ucx-py==0.31.*", +] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. classifiers = [ "Intended Audience :: Developers", "Programming Language :: Python", @@ -53,8 +53,8 @@ classifiers = [ [project.optional-dependencies] test = [ "pytest", - "dask[distributed,dataframe]", -] + "pytest-cov", +] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. [project.urls] Homepage = "https://github.com/rapidsai/raft" From 31847afbaa55ead3ee99d44fcbe0c41ff8e1f726 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 23 Mar 2023 18:48:01 -0400 Subject: [PATCH 06/13] Python API for brute-force KNN (#1292) Closes #1289 Authors: - Corey J. 
Nolet (https://github.com/cjnolet) - Ben Frederickson (https://github.com/benfred) Approvers: - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/1292 --- cpp/CMakeLists.txt | 1 + .../raft_runtime/neighbors/brute_force.hpp | 19 +- .../brute_force_knn_int64_t_float.cu | 46 ++--- python/pylibraft/pylibraft/common/mdspan.pyx | 1 - .../pylibraft/neighbors/CMakeLists.txt | 2 +- .../pylibraft/pylibraft/neighbors/__init__.py | 5 +- .../pylibraft/neighbors/brute_force.pyx | 179 ++++++++++++++++++ .../pylibraft/pylibraft/neighbors/common.pyx | 12 +- .../pylibraft/neighbors/cpp/__init__.pxd | 0 .../pylibraft/neighbors/cpp/__init__.py | 14 ++ .../pylibraft/neighbors/cpp/brute_force.pxd | 55 ++++++ .../pylibraft/test/test_brue_force.py | 99 ++++++++++ .../pylibraft/pylibraft/test/test_doctests.py | 3 +- 13 files changed, 395 insertions(+), 41 deletions(-) create mode 100644 python/pylibraft/pylibraft/neighbors/brute_force.pyx create mode 100644 python/pylibraft/pylibraft/neighbors/cpp/__init__.pxd create mode 100644 python/pylibraft/pylibraft/neighbors/cpp/__init__.py create mode 100644 python/pylibraft/pylibraft/neighbors/cpp/brute_force.pxd create mode 100644 python/pylibraft/pylibraft/test/test_brue_force.py diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 034dc059b0..c1704552ec 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -304,6 +304,7 @@ if(RAFT_COMPILE_LIBRARY) # These are somehow missing a kernel definition which is causing a compile error. # src/distance/specializations/detail/kernels/rbf_kernel_double.cu # src/distance/specializations/detail/kernels/rbf_kernel_float.cu + src/neighbors/brute_force_knn_int64_t_float.cu src/distance/specializations/detail/kernels/tanh_kernel_double.cu src/distance/specializations/detail/kernels/tanh_kernel_float.cu src/distance/specializations/detail/kl_divergence_float_float_float_int.cu diff --git a/cpp/include/raft_runtime/neighbors/brute_force.hpp b/cpp/include/raft_runtime/neighbors/brute_force.hpp index 19904f4f78..12da6ff101 100644 --- a/cpp/include/raft_runtime/neighbors/brute_force.hpp +++ b/cpp/include/raft_runtime/neighbors/brute_force.hpp @@ -21,18 +21,17 @@ namespace raft::runtime::neighbors::brute_force { -#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \ - void knn(raft::device_resources const& handle, \ - std::vector> index, \ - raft::device_matrix_view search, \ - raft::device_matrix_view indices, \ - raft::device_matrix_view distances, \ - int k, \ - distance::DistanceType metric = distance::DistanceType::L2Unexpanded, \ - std::optional metric_arg = std::make_optional(2.0f), \ +#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \ + void knn(raft::device_resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + distance::DistanceType metric = distance::DistanceType::L2Unexpanded, \ + std::optional metric_arg = std::make_optional(2.0f), \ std::optional global_id_offset = std::nullopt); -RAFT_INST_BFKNN(int64_t, float, uint32_t, raft::row_major, raft::row_major); +RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major); #undef RAFT_INST_BFKNN diff --git a/cpp/src/neighbors/brute_force_knn_int64_t_float.cu b/cpp/src/neighbors/brute_force_knn_int64_t_float.cu index b0411a59ce..585084fc97 100644 --- a/cpp/src/neighbors/brute_force_knn_int64_t_float.cu 
+++ b/cpp/src/neighbors/brute_force_knn_int64_t_float.cu @@ -14,8 +14,6 @@ * limitations under the License. */ -#pragma once - #include #include #include @@ -24,30 +22,34 @@ #include +#include + namespace raft::runtime::neighbors::brute_force { -#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \ - void knn(raft::device_resources const& handle, \ - std::vector> index, \ - raft::device_matrix_view search, \ - raft::device_matrix_view indices, \ - raft::device_matrix_view distances, \ - distance::DistanceType metric = distance::DistanceType::L2Unexpanded, \ - std::optional metric_arg = std::make_optional(2.0f), \ - std::optional global_id_offset = std::nullopt) \ - { \ - raft::neighbors::brute_force::knn(handle, \ - index, \ - search, \ - indices, \ - distances, \ - static_cast(indices.extent(1)), \ - metric, \ - metric_arg, \ - global_id_offset); \ +#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \ + void knn(raft::device_resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset) \ + { \ + std::vector> vec; \ + vec.push_back(index); \ + raft::neighbors::brute_force::knn(handle, \ + vec, \ + search, \ + indices, \ + distances, \ + static_cast(distances.extent(1)), \ + metric, \ + metric_arg, \ + global_id_offset); \ } -RAFT_INST_BFKNN(int64_t, float, uint32_t, raft::row_major, raft::row_major); +RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major); #undef RAFT_INST_BFKNN diff --git a/python/pylibraft/pylibraft/common/mdspan.pyx b/python/pylibraft/pylibraft/common/mdspan.pyx index c7b42ecab7..f35a94bb9c 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pyx +++ b/python/pylibraft/pylibraft/common/mdspan.pyx @@ -159,7 +159,6 @@ cdef device_matrix_view[float, int64_t, row_major] \ return make_device_matrix_view[float, int64_t, row_major]( cai.data, shape[0], shape[1]) - cdef device_matrix_view[uint8_t, int64_t, row_major] \ get_dmv_uint8(cai, check_shape) except *: if cai.dtype != np.uint8: diff --git a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt index 98f0d7f67a..7b9c1591c1 100644 --- a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt +++ b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt @@ -13,7 +13,7 @@ # ============================================================================= # Set the list of Cython files to build -set(cython_sources common.pyx refine.pyx) +set(cython_sources common.pyx refine.pyx brute_force.pyx) set(linked_libraries raft::raft raft::compiled) # Build all of the Cython targets diff --git a/python/pylibraft/pylibraft/neighbors/__init__.py b/python/pylibraft/pylibraft/neighbors/__init__.py index f7510ba2db..a50b6f21a7 100644 --- a/python/pylibraft/pylibraft/neighbors/__init__.py +++ b/python/pylibraft/pylibraft/neighbors/__init__.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
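The new `brute_force_knn_int64_t_float.cu` above pins the template parameters (float data, int64_t indices, row-major layout) behind a single precompiled `knn` entry point in `raft::runtime::neighbors::brute_force`, which is what the Cython layer further down binds to. For orientation, here is a minimal C++ caller-side sketch of that entry point, assuming libraft was built with the compiled library and is linked in; the include paths follow the headers referenced in this patch, the buffers are left as uninitialized placeholders, and the sizes are arbitrary.

#include <cstdint>

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft_runtime/neighbors/brute_force.hpp>

#include <rmm/device_uvector.hpp>

int main()
{
  raft::device_resources handle;
  auto stream = handle.get_stream();

  int64_t n_index = 1000, n_queries = 10, dim = 16, k = 5;

  // Placeholder device buffers; a real caller would copy its data in first.
  rmm::device_uvector<float> index(n_index * dim, stream);
  rmm::device_uvector<float> queries(n_queries * dim, stream);
  rmm::device_uvector<int64_t> neighbors(n_queries * k, stream);
  rmm::device_uvector<float> distances(n_queries * k, stream);

  // Row-major views with the int64_t extents the runtime instantiation expects.
  auto index_view = raft::make_device_matrix_view<float, int64_t>(index.data(), n_index, dim);
  auto query_view = raft::make_device_matrix_view<float, int64_t>(queries.data(), n_queries, dim);
  auto nbr_view   = raft::make_device_matrix_view<int64_t, int64_t>(neighbors.data(), n_queries, k);
  auto dist_view  = raft::make_device_matrix_view<float, int64_t>(distances.data(), n_queries, k);

  // k is taken from distances.extent(1), exactly as the runtime wrapper does.
  raft::runtime::neighbors::brute_force::knn(handle,
                                             index_view,
                                             query_view,
                                             nbr_view,
                                             dist_view,
                                             raft::distance::DistanceType::L2Unexpanded);

  handle.sync_stream();
  return 0;
}

The `pylibraft.neighbors.brute_force.knn` wrapper added later in this patch drives the same entry point from Python, using `get_dmv_float` and `get_dmv_int64` to build these views from `__cuda_array_interface__` objects.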
# + +from pylibraft.neighbors import brute_force + from .refine import refine -__all__ = ["common", "refine"] +__all__ = ["common", "refine", "brute_force"] diff --git a/python/pylibraft/pylibraft/neighbors/brute_force.pyx b/python/pylibraft/pylibraft/neighbors/brute_force.pyx new file mode 100644 index 0000000000..dbd888756d --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/brute_force.pyx @@ -0,0 +1,179 @@ +# +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import numpy as np + +from cython.operator cimport dereference as deref +from libcpp cimport bool, nullptr +from libcpp.vector cimport vector + +from pylibraft.distance.distance_type cimport DistanceType + +from pylibraft.common import ( + DeviceResources, + auto_convert_output, + cai_wrapper, + device_ndarray, +) + +from libc.stdint cimport int64_t, uintptr_t + +from pylibraft.common.cpp.optional cimport optional +from pylibraft.common.handle cimport device_resources +from pylibraft.common.mdspan cimport get_dmv_float, get_dmv_int64 + +from pylibraft.common.handle import auto_sync_handle +from pylibraft.common.input_validation import is_c_contiguous +from pylibraft.common.interruptible import cuda_interruptible + +from pylibraft.distance.distance_type cimport DistanceType + +# TODO: Centralize this + +from pylibraft.distance.pairwise_distance import DISTANCE_TYPES + +from pylibraft.common.cpp.mdspan cimport ( + device_matrix_view, + host_matrix_view, + make_device_matrix_view, + make_host_matrix_view, + row_major, +) +from pylibraft.neighbors.cpp.brute_force cimport knn as c_knn + + +def _get_array_params(array_interface, check_dtype=None): + dtype = np.dtype(array_interface["typestr"]) + if check_dtype is None and dtype != check_dtype: + raise TypeError("dtype %s not supported" % dtype) + shape = array_interface["shape"] + if len(shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(shape)) + data = array_interface["data"][0] + return (shape, dtype, data) + + +@auto_sync_handle +@auto_convert_output +def knn(dataset, queries, k=None, indices=None, distances=None, + metric="sqeuclidean", metric_arg=2.0, + global_id_offset=0, handle=None): + """ + Perform a brute-force nearest neighbors search. + + Parameters + ---------- + dataset : array interface compliant matrix, row-major layout, + shape (n_samples, dim). Supported dtype [float] + queries : array interface compliant matrix, row-major layout, + shape (n_queries, dim) Supported dtype [float] + k : int + Number of neighbors to search (k <= 2048). Optional if indices or + distances arrays are given (in which case their second dimension + is k). + indices : Optional array interface compliant matrix shape + (n_queries, k), dtype int64_t. If supplied, neighbor + indices will be written here in-place. 
(default None) + Supported dtype uint64 + distances : Optional array interface compliant matrix shape + (n_queries, k), dtype float. If supplied, neighbor + indices will be written here in-place. (default None) + + {handle_docstring} + + Returns + ------- + indices: array interface compliant object containing resulting indices + shape (n_queries, k) + + distances: array interface compliant object containing resulting distances + shape (n_queries, k) + + Examples + -------- + + >>> import cupy as cp + + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors.brute_force import knn + + >>> n_samples = 50000 + >>> n_features = 50 + >>> n_queries = 1000 + + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> # Search using the built index + >>> queries = cp.random.random_sample((n_queries, n_features), + ... dtype=cp.float32) + >>> k = 40 + >>> distances, neighbors = knn(dataset, queries, k) + >>> distances = cp.asarray(distances) + >>> neighbors = cp.asarray(neighbors) + """ + + if handle is None: + handle = DeviceResources() + + dataset_cai = cai_wrapper(dataset) + queries_cai = cai_wrapper(queries) + + if k is None: + if indices is not None: + k = cai_wrapper(indices).shape[1] + elif distances is not None: + k = cai_wrapper(distances).shape[1] + else: + raise ValueError("Argument k must be specified if both indices " + "and distances arg is None") + + n_queries = cai_wrapper(queries).shape[0] + + if indices is None: + indices = device_ndarray.empty((n_queries, k), dtype='int64') + + if distances is None: + distances = device_ndarray.empty((n_queries, k), dtype='float32') + + cdef DistanceType c_metric = DISTANCE_TYPES[metric] + + distances_cai = cai_wrapper(distances) + indices_cai = cai_wrapper(indices) + + cdef optional[float] c_metric_arg = metric_arg + cdef optional[int64_t] c_global_offset = global_id_offset + + cdef device_resources* handle_ = \ + handle.getHandle() + + if dataset_cai.dtype == np.float32: + with cuda_interruptible(): + c_knn(deref(handle_), + get_dmv_float(dataset_cai, check_shape=True), + get_dmv_float(queries_cai, check_shape=True), + get_dmv_int64(indices_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True), + c_metric, + c_metric_arg, + c_global_offset) + else: + raise TypeError("dtype %s not supported" % dataset_cai.dtype) + + return (distances, indices) diff --git a/python/pylibraft/pylibraft/neighbors/common.pyx b/python/pylibraft/pylibraft/neighbors/common.pyx index a8380b589b..24c1abcf18 100644 --- a/python/pylibraft/pylibraft/neighbors/common.pyx +++ b/python/pylibraft/pylibraft/neighbors/common.pyx @@ -22,13 +22,15 @@ import warnings from pylibraft.distance.distance_type cimport DistanceType +SUPPORTED_DISTANCES = { + "sqeuclidean": DistanceType.L2Expanded, + "euclidean": DistanceType.L2SqrtExpanded, + "inner_product": DistanceType.InnerProduct, + +} + def _get_metric(metric): - SUPPORTED_DISTANCES = { - "sqeuclidean": DistanceType.L2Expanded, - "euclidean": DistanceType.L2SqrtExpanded, - "inner_product": DistanceType.InnerProduct - } if metric not in SUPPORTED_DISTANCES: if metric == "l2_expanded": warnings.warn("Using l2_expanded as a metric name is deprecated," diff --git a/python/pylibraft/pylibraft/neighbors/cpp/__init__.pxd b/python/pylibraft/pylibraft/neighbors/cpp/__init__.pxd new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/pylibraft/pylibraft/neighbors/cpp/__init__.py b/python/pylibraft/pylibraft/neighbors/cpp/__init__.py new file mode 100644 
index 0000000000..a7e7b75096 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cpp/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pylibraft/pylibraft/neighbors/cpp/brute_force.pxd b/python/pylibraft/pylibraft/neighbors/cpp/brute_force.pxd new file mode 100644 index 0000000000..de5e0af267 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cpp/brute_force.pxd @@ -0,0 +1,55 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import numpy as np + +import pylibraft.common.handle + +from cython.operator cimport dereference as deref +from libc.stdint cimport int8_t, int64_t, uint8_t, uint64_t, uintptr_t +from libcpp cimport bool, nullptr +from libcpp.string cimport string +from libcpp.vector cimport vector + +from rmm._lib.memory_resource cimport device_memory_resource + +from pylibraft.common.cpp.mdspan cimport ( + device_matrix_view, + host_matrix_view, + make_device_matrix_view, + make_host_matrix_view, + row_major, +) +from pylibraft.common.cpp.optional cimport optional +from pylibraft.common.handle cimport device_resources +from pylibraft.distance.distance_type cimport DistanceType + + +cdef extern from "raft_runtime/neighbors/brute_force.hpp" \ + namespace "raft::runtime::neighbors::brute_force" nogil: + + cdef void knn(const device_resources & handle, + device_matrix_view[float, int64_t, row_major] index, + device_matrix_view[float, int64_t, row_major] search, + device_matrix_view[int64_t, int64_t, row_major] indices, + device_matrix_view[float, int64_t, row_major] distances, + DistanceType metric, + optional[float] metric_arg, + optional[int64_t] global_id_offset) except + diff --git a/python/pylibraft/pylibraft/test/test_brue_force.py b/python/pylibraft/pylibraft/test/test_brue_force.py new file mode 100644 index 0000000000..f349be892d --- /dev/null +++ b/python/pylibraft/pylibraft/test/test_brue_force.py @@ -0,0 +1,99 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pytest +from scipy.spatial.distance import cdist + +from pylibraft.common import DeviceResources, Stream, device_ndarray +from pylibraft.neighbors.brute_force import knn + + +@pytest.mark.parametrize("n_index_rows", [32, 100]) +@pytest.mark.parametrize("n_query_rows", [32, 100]) +@pytest.mark.parametrize("n_cols", [40, 100]) +@pytest.mark.parametrize("k", [1, 5, 32]) +@pytest.mark.parametrize( + "metric", + [ + "euclidean", + "cityblock", + "chebyshev", + "canberra", + "correlation", + "russellrao", + "cosine", + "sqeuclidean", + # "inner_product", + ], +) +@pytest.mark.parametrize("inplace", [True, False]) +@pytest.mark.parametrize("order", ["F", "C"]) +@pytest.mark.parametrize("dtype", [np.float32]) +def test_knn( + n_index_rows, n_query_rows, n_cols, k, inplace, metric, order, dtype +): + index = np.random.random_sample((n_index_rows, n_cols)).astype(dtype) + queries = np.random.random_sample((n_query_rows, n_cols)).astype(dtype) + + # RussellRao expects boolean arrays + if metric == "russellrao": + index[index < 0.5] = 0.0 + index[index >= 0.5] = 1.0 + queries[queries < 0.5] = 0.0 + queries[queries >= 0.5] = 1.0 + + indices = np.zeros((n_query_rows, k), dtype="int64") + distances = np.zeros((n_query_rows, k), dtype=dtype) + + index_device = device_ndarray(index) + + queries_device = device_ndarray(queries) + indices_device = device_ndarray(indices) + distances_device = device_ndarray(distances) + + s2 = Stream() + handle = DeviceResources(stream=s2) + ret_distances, ret_indices = knn( + index_device, + queries_device, + k, + indices=indices_device, + distances=distances_device, + metric=metric, + handle=handle, + ) + handle.sync() + + pw_dists = cdist(queries, index, metric=metric) + + distances_device = ret_distances if not inplace else distances_device + + actual_distances = distances_device.copy_to_host() + + actual_distances[actual_distances <= 1e-5] = 0.0 + argsort = np.argsort(pw_dists, axis=1) + + for i in range(pw_dists.shape[0]): + expected_indices = argsort[i] + gpu_dists = actual_distances[i] + + if metric == "correlation" or metric == "cosine": + gpu_dists = gpu_dists[::-1] + + cpu_ordered = pw_dists[i, expected_indices] + np.testing.assert_allclose( + cpu_ordered[:k], gpu_dists, atol=1e-4, rtol=1e-4 + ) diff --git a/python/pylibraft/pylibraft/test/test_doctests.py b/python/pylibraft/pylibraft/test/test_doctests.py index 3276ca115f..34be6c55f5 100644 --- a/python/pylibraft/pylibraft/test/test_doctests.py +++ b/python/pylibraft/pylibraft/test/test_doctests.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -96,6 +96,7 @@ def _find_doctests_in_obj(obj, finder=None, criteria=None): DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.distance)) DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors)) DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors.ivf_pq)) +DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors.brute_force)) DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.random)) From 0df5cee684cb032e81308e355852b437265a6ffd Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 23 Mar 2023 23:55:31 +0100 Subject: [PATCH 07/13] Relax UCX pin to allow 1.14 (#1366) UCX 1.14.0 was recently released and conda-forge package was updated in https://github.com/conda-forge/ucx-split-feedstock/pull/111 with several packaging improvements. Relax the pin to allow installing UCX v1.14.x as well. Authors: - Peter Andreas Entschev (https://github.com/pentschev) - AJ Schmidt (https://github.com/ajschmidt8) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - AJ Schmidt (https://github.com/ajschmidt8) URL: https://github.com/rapidsai/raft/pull/1366 --- conda/recipes/raft-dask/conda_build_config.yaml | 2 +- conda/recipes/raft-dask/meta.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/conda/recipes/raft-dask/conda_build_config.yaml b/conda/recipes/raft-dask/conda_build_config.yaml index ef22522116..778b187870 100644 --- a/conda/recipes/raft-dask/conda_build_config.yaml +++ b/conda/recipes/raft-dask/conda_build_config.yaml @@ -11,7 +11,7 @@ sysroot_version: - "2.17" ucx_version: - - "1.13.0" + - ">=1.13.0,<1.15.0" ucx_py_version: - "0.31.*" diff --git a/conda/recipes/raft-dask/meta.yaml b/conda/recipes/raft-dask/meta.yaml index b387f0f47c..59a67fe148 100644 --- a/conda/recipes/raft-dask/meta.yaml +++ b/conda/recipes/raft-dask/meta.yaml @@ -54,7 +54,7 @@ requirements: - pylibraft {{ version }} - python x.x - rmm ={{ minor_version }} - - ucx >={{ ucx_version }} + - ucx {{ ucx_version }} - ucx-proc=*=gpu - ucx-py {{ ucx_py_version }} From 1b18d1fd5e143a547c94dec1fbf793bdb4685400 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 23 Mar 2023 20:28:57 -0400 Subject: [PATCH 08/13] Adding architecture diagram to README.md (#1370) I noticed some folks on Twitter asking about an architecture diagram for RAFT. I think it's a good idea to provide this in the README. Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Ben Frederickson (https://github.com/benfred) - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/raft/pull/1370 --- README.md | 10 +++++++--- img/arch.png | Bin 0 -> 52500 bytes 2 files changed, 7 insertions(+), 3 deletions(-) create mode 100644 img/arch.png diff --git a/README.md b/README.md index 8519ebcae1..8c6e817bf9 100755 --- a/README.md +++ b/README.md @@ -35,12 +35,16 @@ While not exhaustive, the following general categories help summarize the accele | **Tools & Utilities** | common utilities for developing CUDA applications, multi-node multi-gpu infrastructure | -All of RAFT's C++ APIs can be accessed header-only and optional pre-compiled shared libraries can 1) speed up compile times and 2) enable the APIs to be used without CUDA-enabled compilers. 
+RAFT is a C++ header-only template library with an optional shared library that +1) can speed up compile times for common template types, and +2) provides host-accessible "runtime" APIs, which don't require a CUDA compiler to use -In addition to the C++ library, RAFT also provides 2 Python libraries: -- `pylibraft` - lightweight low-level Python wrappers around RAFT's host-accessible "runtime" APIs. +In addition being a C++ library, RAFT also provides 2 Python libraries: +- `pylibraft` - lightweight Python wrappers around RAFT's host-accessible "runtime" APIs. - `raft-dask` - multi-node multi-GPU communicator infrastructure for building distributed algorithms on the GPU with Dask. +![RAFT is a C++ header-only template library with optional shared library and lightweight Python wrappers](img/arch.png) + ## Getting started ### RAPIDS Memory Manager (RMM) diff --git a/img/arch.png b/img/arch.png new file mode 100644 index 0000000000000000000000000000000000000000..ea9cad9204d16a1ff4587a08e53e8530c5ba0e33 GIT binary patch literal 52500 zcmcG#WmsEH)HO`7;O_43Qrsz0v`{E6#i6*n1Ziwwc z-c?%h0}{0PA(?-Hz9)8jtK+8WXzAu@>S6(7<=|* zhh$tVOx7^6Pj<@P3Ije_HYz zgm~HiX>=U_xBC6+A4PQ_fqm7KE?;$ZZsoUD!vdaLH0IakC;K_G`c^{5qdf)0u`s#^ zllxi-|1;$)za0{F5H&DB}c#@Dtd z`1RWJL)6wP9QJEV=+(Ui<&+iTz4`YR`tnr?K=*&8Kg_*Gj)k2--AGa%h&Dx&X?q9w z@Nq%HT=LmxAOkJtXZAA^7qI*#Al<$%5zd7&e38|3>I0YZ-zB(;HK5_vs=z-hnu1^z zXlN)Auhlz%4hyiUH=11;{!ulQpnOdIxJ#w%GNRT(@L!I`4V!_3C;N^=nj3E_45k&p zJZD3|=VfGcVs$rsqOE1#e?LYXq_H3?D~=|d{hO?LWFcx=6kEbvDy?U!ARpF;L+&NK zO9%1t7>sS?bx2XP0<(7FeEj3z^7rPzeETwR8r&D79*ERDCcTPieHG&whxo&gBthGg zpj8}<)#^uE3lY!4{PWYa3j`)fHi~@xrUG4l<<~Ed6i5`oZS*oajzQSUgqlm+Unu>m z9==0}zL&z@$c!d}J9IK|UoA{X!IQ+nCT0f`VRxv23a$<2q*M5at>I3^txx}UpvM-a zNm$Bq(YrJWDS#P^%d&0kaTAfvrEa7LMqpovt;^v9x9=zYqm}}-nUyn9_Ufg4Wdo-# zs4sEjznR)kyOXtEb2lHx`gGevh8>1=VCY%z2}m7E25$1P9&ujpe^;adjXiU|mC9_C z-O+vDA!DE+#Xw}K(hLASuFFnMC)Y-FJ!qNYkWv!mjm!>CQ5h)xv?m5LWY}KpMGw&X zOPnUgV#hrmrA7Mu&k z-Z8>myy-?SDRpmM4-->+TRRN^MwWrGC8veGb`%fqhH#4F$lCDoH`OwFf>uY}6Pb9MxN;kxB@Qa|c)&;Kv}Tf)8(HydFvEy0$9v|;|vUO(C4AS{Ddc-24JAmp7%^<=j1^v!Y!!C8kF zpjbP`EHuKk?o8CK3mdHw4ePU%9wVj&tNzPqeVMOe(znngBbj==kOH^p&_Xd~DyRpb z@YTJ^h`$U?^08zKZY7?mv_wNy@a%Lf$(dLSwEW61qURf}MS2RlIKIJj@H#Z}2V|6& zo!QRZTbkb+e28KCH`=?iNqHaa1nxO^hE1S=x#}(`M6SO<+BsFM%wHUj15}U^-$J4h zs{<&gn!(l8959=evyFQ%@|%&ze+0;H_rN)>2dhb-^S0K|!O4K1pbh`^R};>Ey+$Yz zYMMtG^~2tP?!cpOc00x{zMjP1DrV{fRR@|L1>7l7p*=#brEyehH-Q?#_oML5!&1s8 z`D|y(tJ=PG<$r@0H~0Vt4S{Illjq>n@`3jf;zhoF<>c^0QKoCc0|*BVG5+fY#MlzG z=pAG6R))rBUur5nxsMFtv|4ayW&Vw@36Cj=j)td?K{@JngRNF>lwt0Mjqe6AiA?>9 z>*^?Ei_moFzV=JIyVf#qAG(akcIoF+$r{Ud|5CtRTi(Bs4C_un8vQYAW#@@5q>szL zWT_Oz5hd|iGRUZ!u8Wh_qX%Y9GCLs;M#ZX@Xja#~CYymqO=$GeO^PFC`mu_QxJr!Z zf076!C~l_|QJAjHuULD3#xA2CD6_|2pbw((eSWfr5s zvC2}%t|F){Up>$xC-kH-ys;5F z0ZRcUvmz+4jX6%J-yOnddZ5?HEB}(GKopsTkmi%$@>RlU*Y&W+-H6RrV;8gDH+WA^ zbLE-@nKAyuOEcB402&(P-E|C+IT_1LEzuTe7}-(ywfuOOO>|Ii+{o4lthX|{^{Z$O zCVcNoIkOOFsA<}b=9DjeW_U)??j-wROXh4v1_U7t!YFFOKd=lcgn6_o#g>%{Lg$_{ zM$!V+?S~cTpOyHXZ39k2AV|0UUUI_!9D_LJ70i$=%_YZ~{|fU*$E-L(ljJ%Ax|zLL(#GycBEC* z>$-nIOO&Jf`Xw5U7M~}%htX3y7f2w6XsDlaGvEdJ&#SkkNrIs+N#ek?h&s>q*LR31 zENNEVZ3M_ggE+kXZllNkW(!IAI*iSYas3a*oMu12mhaZ7C8k{{TrkRl&4?sv z{Vg9~&AND>{Tg*l-W1^o4a)-83C%ek=ri;E>uXww0uHbqJ>&Dif$;v;D=?$^XH8jB z+F~E#vSXb2;mX%B67@Igk|w6OL*8KY9C}Jl#+kk+u4L`!_<=px=M;ujINJ%;z*{TY0`S)c zbVKW83FM>L*WcDdRGOE1H^b1iF1bSvekwxHehZyv9NE`N&fR_j9O(k&1K}3urRt%u z1H~5#Q7?EdZ=r2G=`og+BMNVc$J6eJ>$L?PGz3(t61)o(`N?=pA2}R95*E%6YiPa# zK~s>5Gf!HJvsW@Qdo37*2&g+uC0@SWIT>fU$Hz}?%&nyG!fK*qiz_;n{Dn%`r 
zCTB_L3U6aNpFqBzYPvj~{Uv8-2Ar?%*~sF0Mpj|H8|sB=|JId?mNU;%iBIy#7kY%L zLT%C4%%9DdujDE}$IsMPC*lrdsK*jIwy@0B+~jkRF6&$d<-1^Mfu4_Tux@touW7^E zWq(CS``oNKcnW`$I}a>i9cp;kqk6sTjlsJb6xq>^jF=dSoQK+kend%Jhdtdo1%?(g zJ6gjK^Ydx6y=-tI^tP?cRXxg{Z3?Qo3Yxs-Px!3u(>RU-u6n;1?~T18F!&qwNH{I6~x@+8T6)h zzEy6DcRDsG(43vVxx4Y`x(d%Edgy~MRplI#);C+uf?Arvm zv})}!J{ep%Vw+8OF4FsTS8&=Nd-=6EN#R~wP1AgsU~2)@`!yVV@FIbRT3J~Ln(ErUbLb2Z1E`)an+G~{Oo7H`u_=HFXS_E?3kObkAD zTTJ1%Wp6ILw_WlUmi5@u>0lX~e*NM*{dw(HMGf>NlR((z$0HDu79>SL*oLE?cZVY( zF>=A90v=Hdn2y?}=~>^H=^wBtPZz0oMX0i#{tD9|BJUL#a}g?VNn34*4}YvfuU7-) zDUdbfDi&*R(6k(;We!RdHu{@c#FlJHW25cxWo;`&0QhUvpo8fT!eQ{2asAoPUYeoa zCRw+`Sr-`If2y6T^pUV&FC|lEfRgyAZupJ=*Yl7Zx8d;U;pfIIMpn|9~F-JJ?~5 zNF)bC-Z{j`%kYB%Ov}ahCnsCr zn7+BM`AWSwWH@+mboidJ!euWNSpn0X0i6>}b9HTid9KDfgI{ByQ?6`lM(nbPz3w8` zmwvQ6VKyVo=}YGj`*5PH@E;=iv%rckXV_7LLNO&o`4m<$6~P95`~9 zOGRm$TkJ)?;_S4bS_5bw&8n~6L#K0W7wB)gBCBemDy#L;U2L%|K$btJtzQ@OTDrFH zxWW9qYrKJH*W`JA!;q|BoF85T($o9q7KYM(rcYDGdg|4faz1x)%BN0MMaqQZk>a-z z67mhd;1|`VDw+i{@ATtc1dtZsWa%3TJ9_l?H}3^~p=`tUoo@K|8z%sCbdu|s^{^YM z!cp@1OA}B*@<#`0hM)B4HkL}k@P`~l?N(GjkMZ$@Q*G)N4)S!Wni!7w;=XzX^|A+i zfxZ(r;)S8L?@Or-KiMkCP_GZ3FD}?5Gr@EyoaR4o-Yj(}vBBYZhrP!#Hp<$u_62Di zgAuY80F)+Jqc=Dm+%W0!uPm>mh%~cny&SyCSf~qLFQ3+H9#D3*QRcz)_S=a?dTv*I zT~@T&I#*uGY`VTkSj5AUA5FO+c7Rf~Qr$X+&-oT{2vMYdYtO|S`O^7&IKwxWJuO+& zTGFX#%R?riy9=>TH6wk+heK_)LUch9tjYH%k4{)b1UO%K4R`oTF}26M0i*w|kYR!` zbxCveZZL0j<+-O{e7j`B2l3H{mhT~6ztzXb*`N)w$drcF%qtjK=xQ|DYu)-JBrBf% zl;6rYHf39NkePbJ&ePwSn>J-;4SNx(rN1e1UUgENQ&w!a&_1)rj!Zh;Q^Ot1-z`iX zb8vc}Ybesa0-OTZgua%AaAe?QI*as?zC&d%3+wYOpYKW?$XKX?yc1n6T~~>zY|^0>-`{#r5JkWg|F9bl4Gz>w$t-UK=?Nll zfbb%Sl7yzRWY;JNEJbw3>Ni@IG0o;})vxL8ANNVfzUd{ywK;7=*3x|nB*`CtswpS9lH`eH*3&tF{n*k+%0Z2WwVs)`YiaEi)XrI zRpX0GUPeC`ZNVr+QAX1~cl0k!NZMyFZ4~Z#STMe1oiwT&+7X0yoc~m1g3s^E#f|>J zMlbkUYiDcKkC(fIa8YZSQDy402#IYNiyKjx&sfr4B^SKLwj+*9l9vsH+vEFE!Q&2e zB7jdI#cx7PUQ*|yKAw&_SZ&2CFM)CS9=9qh3W*f5t^te1i%p*d5)`+J4Y^?ys6|Ow zs~xROC;E5B3v~sz>`poaI$nHW%BuM6YU3WJA!0S-dK$`2dE zZ!WBUCjq;D{2{~QeYEwCwnjWSqOJ5TT4Y4MzO(g<#9 zzr=a{0Z;%}9T3$umj=Eesac{~p8=IGLUiPMYlu>ib(unveyXf>pmbzlSSvbHAs{Ho z7WHlw;ax*81~rFA6k4?2t*PKKz%el(1ay4c^VwRyDtkTeP90>e7;eAg!W!wi1{X8P zFs9PzWp+_s{F%96cP01&I9gEysV?Cq&5E+HNUA{f#|G8*8<6N5a2nZI*KA)GyrM`w zNGsjIe%3p+9rnq9fGic}_Z`1#-ZyDg%#o~pXkFE`T1adtH5YzgyObmTT&IxXci=`Q|NP{ci$A!^tJggj`M1-QiIRe zPL7O@3&-W73AwvPgDURwDR%TrXW;B$<+I z;i=ZRxKlB+TUJ6u>u92<*D&^6)4pdfB47q4DYtlkLh4CO68kzLAh*xMV`;ih{+>Wr z#{E^0!6kc>0LC5-wcb*vP(dA~U#+68N7h$1hhl>A_;q~ZquSk!{VdC@;YHrr{m>UQG|%T z?>6-d9-JtV20`&ViQ|=^^dw4B^_ZdYwR*>j~ z@b}y(Q|v4t%5^_}4V+$Vg_ye5el=o?I{r0fQ(}vy%zFU5UiE14h?#>a-ik0MwZQaU zr43stGi~$?xH&egf2Pij_cZ3R`7VW!R_n%%!|Pek7RLlNtU#C!(8=2Kl@R9ep0*8) zg2t7zvzV)THfKI|daJz+4OKX=rXg`QgB>NLIPRecWI#}H|y3hXKw zz!%$)N5g-bDYXMf_nIi(@SgO<5wN2j2-HvC&AC=3DG@B5)9@8X)44%XyTPb!ftBOgjMfP=<@pRjF&TP}RSw zHp;@4A+mN-P#id;Mm^Bk3UbKIE&2)((i7j*F_`Cdsn}A0W#tK*^q4 zFe#tXd99s;(ho|YFF%;!!^x=u$Cpa}q1>XM8Y6eoq485V!q@vZ=~;G2LsB?%jBu_6 z+y^S;BX|ktr^4Nb^c_6c=Z5MCG364ph$dn7CE#<4^3gcdsP@*Li6i7@kgE7!0FC>j zUmhQYb-qau>k~&BAy>n{?GW1K*!8i`irZ#2(Buw?A-REXq>~sMg(R@QE3lbKfW&3U z8<2}5+#aBE8R4e68vfkXR%3{yXCl}Y`uUp;o<*R>pp9`xY<6R^#m;V8XqRHV3U}Op zQTOrqTaSJPP6Q};s+V&+SYxuEmhFz3;FrY-FR^lqRM{m(QOVeu-Xelfy`8c^GEVRc^!b`63tq`(Jt*R50k+<5h>))T-FJjMc#^nU@d zNrzm@8so|l6jy5A{kL1M7Uy0{P{~n&|WO`xvBo4ws zQ>q`So#dx)%$>p{_549}b}rt?)P9X;HDgm6=Y6w&_I6^`ynss^@y}j`&^gIa8#y4Gx}l;oh7xNm#=exhF$*YM5FLAb+ht=O zgiT7)clr_GF@yYbROWSL$;R~r^@nOzO9B)9Y$r*+*-mGwH*p-a2`t4NN6GzUe+1_b zAhpO!U{n$CVhm(1=fqjhQNn_^Y3?*&eQZ{(nr+a~xRv=OhAn?Nf+bKS^YGu9x}?V$ zf=YA{OW(_hLHV49%G|Ho8{cZP{?yW6G0QK(R~rna8urYuU7}q7C~D(qy-f$=OdjHN 
zaf{=6Ld)Rb0Fgbq+&rVg1}({o5~b?Rc4cH)%{Pmo0KfEirc~SLp!126l-X8i&^v?j z^-s#;1f%b%!fQ`YOW$>(l_mS_Oa;!k(S#?QR9VW+OJ5NE83RUFKwA7_SKSjd!QSz_ z=4oXXTOb=&4z>*4=Xv|8CDkZp3`=zG_rIv9hPRI@dARGRuRfA@Y*=pbW`$`BJSsj= zIhP;6{*r;bLxfJpps;T86Afut2m8oKpjw1J^Y#y~r{$(4PrfqWl>W?hIRcY0FibK? z-xcySTu6byzL8@Vuqi}vRGX6vF0Yav2zF9>c1ve`&uY1!cHE`9*0KSgWf;L=R-j7q zhi$Ca7$Y$V-L~8K5hx!_z#(_*0;0#Q@d8Z-Dzn2nH|2;F$f#bsm|$8Ax!~wb2i`#m zu%dacR(H@`?lE~8eKoc!|FDmoNB|qt{wF*q%gtVmo17~FA}-W3BiDq4EY`QI(d(ZE ze_Gyu(|poZcK|7h9!g0Rj{R)7)ZI4;()tDU){y+=L9h{6C_aEis&%~lVKGD0%S|eo z-*ni{EwWCr<YCMk*U7yiG~HoxDOin>5joGzG#0g|0iwx;ix6iYfwSq^U^%C2covRHGd3(5 zDICm({T2ep@#k_x;FgWYaZI$=wEPn5ce}+mi@5E)LK(VO|~>iJO|8^NVp z;Bpto0DP}Mq^0Nb-qLkF|`FM;@M=VlCueVe^E zWIvrW*OrihMY7Nk3N=HLS}gJD;v~1DC2wW+$sH>oCGS`RA7otLS?6^*jw`j%;EBS0 z0ZaR;UxhyjLg3MF85>;g>d!|fO9@}=yw-T1W|aCFEO~K%0ebusFI9-3RnQ@av>v84 zc1kR~B9}1RD%*z^UB5mP7q6MVa^@)1UBLfiShcAXz+O`Cjm9gf#`3sIoO~tL{HQ@c zxWlqFGb@>eRZjRqJkGI#iS-PKQ5n{;X>=Xo-~MIdiZJp`lxLHM5u_*q^+w}@dvzC* zwc`>y*$(O6Zbya`xY*ycH0J&peL>*>z=CULsitp#=~Pv&mtJ_G&?Fjwb}>;PM$V`{ zXi;T3nQ`zSb%~7{Y{@@Y0}^B5KH`EZO&1e!<*di>I0ZuooY|(Zw63uER>zx;Z*OqBYE1Ui+1n+2fIzYrKG4hmCFm{E`)}na)Z!(e0;Je8H6AdD#tcIe*B=v zMt%rutPcx1X*g|q=>6serGpbp?2i@<*8@Z7N%q^*$y0vsheRH07Dn-2mQwjt;kE{z zS-q<0n^-+VTJ4Upb*G?&Y+5w)OT6)KS^U9zfczP8Lqy4WD#p*8#Z z_Xj|Z1@xyUs64VUL3U{puSe5Zm8yB|#_HYv;}^((GZYxV*x`GP^C0H#UI0%!Gz2sOvWF`W+?faFOfQ_ik%Vuk(hYuFT5Yewoa9=u>QM9 zTDX7DdD;_791g6>7tRl)otXT^?*qF0`)?ACxeBfukU`+~H}`Y3hxL4+1 zn7N>!>)$*bwvZx{P-dV4|N8ac^bcU(<9TBEQ79>Y`I8Wk15y~wk^f^@tH8<6<4<1I zZX1yTf~@@e26Gj}P(i^Ww7+>9V>>G706Mb><-d6rc>MUksP(@cX=0X_R~1~J{mT}> zsZcovc7VSc;os>~P_P3)8UC54iHZNijgisg{Sf~O%v?p?gcwL;sr)NJz$*VE0F=VV zyESy5lNQTcI=7%71V$2&HpfV z|MbKS_4|*Hl8rHUPl4sDaK5}@sEDP)`>8U*zd~dbxl&&k8H-Cu783VUz(qAIZ7LXd z=@++k#Vi+w$r*pR`YtD@xUaCBh6eW4K|LQkhYNZwW z`{t^zWmWgp7;I+fHpJqw2gZGdYg`=THSlG_&@JCTsHm_a@ByEKn|$pE->lYIH9jHb z!}(Tso@cP)z1ye{fJ95UmTb{|ltUF3vx4gSg`XT>`;zHHh0o-}>02ih*1ij@*QbLM zq$tOrG1u)lylz2xS&-CHirCAaYL|#Yc9;=|Wv`uuqJJb#qnWkwP{r2D7!0e2*jM3U zyFx|G314}+u8vNIzgeyApQzgIx_Ckm3&Mru)BM;^?a~k#tUaj4g&u^_|)sNfN`{_*&k1|4momw8;oCxtBZ2-)EIUY`=;(+LIu)e20iyQ~Lwe^)64b-~*DyPR4k2ZSQ z)&02bSCwjOODVd~ik!@UJGD+`Z3a!si{5pSvLLAj2JracLl|DBaFT3v98kn29hoCbD`Y2n+UzgU2N!JjXYw25B9;7P#XrEV95c%NVN@6cX3 zxRIJOXfxZgfBe9?Bq`R6&3NUou!x0q$?^&_Sxi0$CthsEV?fxZdOwcdPlBAP1Gz{= zGI?4+qIzuL+C3>Y4gqi5tbUMBPFKM1{yZc09{ZVq)NnYYC&yODcU(q8L!|N2T{ks4 zE?i;Ye%YEc$6+rst2J$P;Yn@1dsvXZ51>~c^;-y`+H*mx$XI^B8!b|3}Ih+%DeBZOtC@=)WX#vX6K8~ zC}|QjT`k^?+<(Z2_Cp$IP1cz#DtDov>RaAGyWJer2^+NJr**3H#sEfYH(zbx8f7jJZl z5AGMZF;YJK?A&ri!<&-Ieyc?k_8@>txnjVc&|hA~zf=0%_JIr(mLgozFDal(ktQ&J z?Ue_1I6myG;^E195}I|Py8f&EWDXw4Fxo>y^k&oEtq|z_Em72AA;QLzL&GY<$6MMC zskt`vNbXk2ax;d9lR0F2J8X5m+gWZ*T8ee6{PWSASbKM}uN0YDF2oQMfvc3}kmaSz zz2gE@Ha%RU^2u;k820LfA4x&V)x`o@BUOH87Th8iCZ|35xj$SrF%!%PsivPxrq5W6yza9CMV4Y*mLlty9wTSVg&ocoF( zxw($$YNp$bYtsH!S%I$pwY9}dd9lUaSKJpVUap5%+v^h!OZ!dd-J(u1#1FU5sIl=3 zuQ(0f@m?K}NjZ-5tuG`f*fxGI)rwAqrg>_hmCTK*KdtzXv}K*@Sk zB-3W*jX;^hf4+%1$!Y{gF`)B05ru=#7=eAm-UHVYd%DfqYDJluNo`$KS{6U+p7Lnt zY(I`%P9>4SPD-P(0EElnAGw^ezc?O*P!bC$Gj4l8#zP!Ayc!}CGpKV5sK&F^mxdc< zF|JjIYC?Rc!q;DXK`oe zoKVOL1`@-h19MgB#6eSe4bg+@sPT14ZIi?Iwt6CO_wo8)@Fcfy5}Z0++__q$VoNsY z;|jOf4e@zjHWZDi_TI1)`jr!=B|gn8YBp0d=g#6r(o!8Yz=+s&j zuV>w>Qy|N#P>25cO&Zx&IngUvhsB*%7bk9TKX?gP?C>SVtNa>%fRj_gqT6y{TEz9v ztU+NKLqD=)#Xgl!pH;i=)Ej+zV$3|459h5_+8PRam%AU&3Mn{l7(D;-^8_Uzc{*@T*Jr<#Pe~g=lh}*AppC z`BciSRWhn^TsyL6q`H>RMVU_gYr)D{aB>CbS<6=PM%sZ3P9_g`P0ecG(BOzK zMo&2-gkO)OxN|7>Y)%zKNh)>G?DD@BR(^4EF@<0MH405bTa}Ia;?H^!cqto2dQ=b5 zt61|y%Fe}Mz-MxYmt5d&iE$`-pkwIsoKl`+q2r;qdToLn_7CmqaQE5iCMg5o?wPnL 
z?H3Tn%=!IJW-r zqNVKT(ttj5Mflxx)nYxw`3+8t1lG8t=8c`XJp~gY{q9+E3rLCJV3t(4gl5s0jL21| zlQx-7SH=ZT(xKSRS`g9KZyETTN_0MgMZT_8cjp*+^NzwvEDM5ZDSr&0Cw6{9N}K7K z8;YFs3-;SPk{2lSG9HWK^9%W+r$yHjYSq3=u~8SU4N$|fN9qSf#X$WHgya#9``m## zw#ck@ly5>;aZQuw_jFU^mM3!f;-}xp-N!^B;i7V%p@G>&-Ch+fRcQ)vMRQizyua*dng6XoL;(iY;5{<1AigZss1xy108&|*^> z%@I+-Vk6+UQEXA~e(Fy8)vc-Fiarr$oPWv{rBtVb`PAp;Ktt9--EjfdnBYV1nCx0> zhW@32PN=wU8pDx0Y7s=VLswT_%FG*uqc(US=F45qLhegZ0xe=HcY>vqWIkKV?pGgt zqEbGQ^4!X%#E0o0518cVSYKxH)(1FQzL)A!<>K99H!`2@KzcA_c!#B^fJAqt;!sa!)Nuxy>S;2C-b(?wCMyR>UDPd2z=}Y3U~3WW*&}T} zTUMiGIcLR(E)8}`jsCUr>NnT^wA7pt$dfh1OW}v_$K^xw)U2jAA9p9Jw@+PL6^I2$ zUf54rpV)AnZ`DC&?S!U`+VEpO=t;jV988h@tT`p|!rtd@o7x*vlBVFhg!v0nWj`t` z5tLQ*WRz;}**^cL2*Z@7CUi!tqM+)N1_;7o{5LBmAhi5%*I>Xy|1Ok(j{bMc!he<) z{|^hY|F_LwE&M`Nv)KG)iYX!f7W*0b#V9iCToPN2A~c05<)^G{^qmkmynmJ?Xa9-S z&5acS-$0Ww;2IE>=B3VBVQKoHGwGpoIgcOT8v)ORGjVQe<}xEw609DgI1%J-C2vq# zG|S(Y*hqt8;X5qe5=-$RZ&>?oF2O%uqvw70yX7x^ts;~~yH*FwJ~TJxtRvNJy;7Tz zjolwDG1$X|Y!qIO4N{j>7J6mHs2J~@P?nly=H0gJXHV+CHrs9}xs>A< zmLPG?z;O|UYo4`tYfleSF)=c3gxYG) zM-0`e*E57B3ZGF>7uqFIRGIxaiwntW*&Yy^X?xFfS=M2E3|h}P@xJ1|xsc_tefJV< zj4N}%?cVIhI!wj=wDzhsfJ^;fy}eabTwT{ShzHl;794`R6q4Xt5C{<5-Q6JscMa}N zaCaxTLvRXrDcq&&dEY$W_xI?&>oIy?)ZICC&OY1LTyxEd5T9TgLOSR%$*f+NnU-YR z8^QCAJIbYbHOnZC=#;BT=Vi&Hx;mC5U3#3Z}-{& zU*m0s)#DJu=b%}Y9G0@8uXKXhr0y>?YTrW0oIBAHbD-J~fmjXnes@v*JDeZQ&T*%2 zn=naJKcyh05y|m?Gb?UMuO1|@rn$UD9FxRpw&d4Um-UI6osm!JE5gAg0qU)V=Gfi`w}k7)WK>4F8s-d`n(4y_%e$bJCQ&a!87`w`+EZAPjnX z!m5=Rn4tURi*pQIF|Sx0&(Bw8j~V^IV9(g5`3M5i*<>rPBL;plFKN(&@IHIg3?*sx zZ=JR*OAGgm8MqaQO8*v|I?^C~`EE>WdqkTgFSFao>grU@9=l0Zj5VJzY2`QZ0i4p; zmK|1bt@+A55xD)8fjbHf4)PMX`ey3W-nxrs$L~4dqBJLPIP0W2#1vk1fmyIdpXmGS zFuKQt=TA>wr3>WUyCQ7z*onwRd)J>g zLP~AzHsH54A2h$Q{_(}$S}$yiiuLgeJe;0B>@YHAIgpiG&BfwRx4-%UJ7tzraMA{0 zTJA@nubtK7W*)CXMLQX0dih-mmu}d@ELV6(Wbwm_WPV{g z@6=Ru`P1`mu~@~DujbnS^w)%pR~6Qsx`$v|0-4_|x!sNe*ag$G$25RDsU%r1V3Wiq zsMpzVi8bkv`jTFUk8x3hNA!$YimF#_&u2Yfdkxj=4sJW zkCB>E*`)_Xse*jr<&DnTz469I=YQP3}7;HTIZ1-$7bRAoaor7bk+BaA9YO$T~eJNN0HK5T!8Rv3}( z4=P}{_{Pre3{N-{*ziFW#~}fh?#MrFR>Q9x{JK z78xxGrg!}TXqH45v7>b$keXWI)%7ZHsaKj>sfPGwuJqmNq`PlS7g8%|03pUqyQEUM zEypDS^*=8Wz^8R(i{G+(SQE-4QRVuj?HJ?u8Q$)3EosZC9i0(|bNP3;wU;fUb)@A- z)qGXgvSWXXXSk5%a$C{!ulQyMU~36S9|WaHnZ^n_qpbk%@@1-qWJXgt;P}*~r=hj7 zqr-?eNjcF+jsfy+PAl}wKP8Owd;hUihqO3G_6>=8t#k+;)adf2XT-dg#6XlLD~~u; z4{=F|CeE1Sor9s34Ax+s-!ENl3a%|eZ)?e{Z+ypms=kvf>-`nt@3smGIqzYOyZ>Tp zzfyXr+*gv@1LV{v%led3C3WJ=vN9s;#UB^OHwqEKPiqxQ{=?!(%34qhk0?W9`;9lm zZPWY`PU+RRquhPIQaCwp8+e{R5ZN3ZRLV@Pg;V;*l|1tLp5QJWICHOru4bKQB)(s3 zmv_yyk^IR~TMt=39g#FkHiIU!pm~=mr7l^opa9#sGBM8$Pi1P#3k7Pt0@y7t>W$~% zpM2&Kz0l^tzr)N6G)v%Oz;#M)i^|1QWZQPb=RDIH4@8HS zRC1?>dL5WTmvRp`9E_|sL}xR&j?nLoWr;%2B*YD+!?Zl;8%oe(4n%+nR?4HTGW}v5 zVnYEs9}!8+0WGtClL6`iS4{u{5JVHQbOu#pa&8oqmfl`&uelzjfA2=56jM;*?ant{ zb#?uT&tdmf;ES6n=F8{OV%M4KpfQ5Z`Iy`mRK!D9fzM)NSutUt#m$vj?FI8aHc0SQ3EiSE z<15+7w!Rpt0t;P*j?_3@py!J7){&g!hTVWUlzGpeqkRW%0tvR#_MC=IzSYKe|UiF?1_l zfX8*`?ihg;%|vTF-?1#2P*xwxO7fSPs^fn@H31W}av2W3zM#77MI{rtO34^?@lGrr zx(NW=-Bs`DD2^bV%)XMT|-j%WPxv50S~D z`gA9r%t}~Mi%SHlwRhsDJ#h*1w);@0#{zw78l6eyBt1ZmSPENb31(bSEzHy&&x;f@|+)|eeb$YLQCR88>7xeDH2cFU0kS7hcce8h;ZtDFq z4|G*ZyTLgQsbir!(t=v_NBalaIVq^Lv6Gd3LgE}YZ48ZG;hE`!Cb94B8(z)=(5uYV zRTyZYj&51``shcFLmEv2a8P+_w{B-U^+%BfhypyM-r%I2mo5*0y- z3j)w67t^yBKjomO8j~-#(LZDaE?$GINOjoho-lA{P{k!(RWK}#1*ZMG4@G5xI(0zz za0_xL(N7;^ipGjtD4VXr?K+UnWKWzmG3;`$wH2TEO6zS32hiY(?k-w71g3{Qm};Dn zpe7Ry?O;}!*YF+i|NB%}*$7hK=hUP@Akj2-oVQQHa$71$SJ7(h2#`fTLDm>*>Xt#Z zI8$HxgNG&Ubna##TDqk2jBL>=s9{ghJ41GS7 
z;{V$4!^05$EvYuv94tKW-JVja2wUjdQj%6@Mmz5JJQaGRQF7gf#y3;1e&@eYwZlhJC0`}jRnf9E`?O_V`jpWl4eebt+ovhw zUT&_dSNrNQ=u>joax1Q8%Q)M6r2|P^v>5mSQT661mBB)ZNoF;+iWDr}g2L%^WZz_; zr=pDlGWeqHOuW9(dJK+JuceQYSpA~Yx2#d|$V+@Nkz3MU8m(`QhAE1>)kk+Ss@R|} z9I!BSPx-l$hRopUM_8rm9ubT&zH{*|nbRop=Z63sDIJPJ@+$ooUumGmc*aGkcu-2t zZ%L0`BHh=;{*%M2>OciL%h3JT5dq>Sx#ZdhrR3L#;<}wDk1@W66o^{l;PR)}L#C8; zg-d%z1AkiZ)i>!LKm>!pnu8Rl3)y4sLLF+dL9!R7Chkff>gSbJ z?E0o%=0f%OzEFWF7uP;1_%K9+_Gnzp>q-q+RiiqjXg=p6eG?V!S?eB7;5pK&{31l@ zhw{Y$?*|$)B)s_Y$DUM9kQn>hxyWT0Q_ZEen?$vp8Ul`K-Or&ak1UZv8Dj2*qldri zXG_=(_a`E6nxHlGqvJazuvK=Wz^~|Qqk2WX<&(UY+{8;pov;cK&C{=>b$THx z3Dnzox|Pq^K~hk*bkVbge89WTvcKWzfbO*{x??{ zm;Qf{EdM7d3%TO|H31A+TK;cvhO-S!LAOOge<8NB0#P#YgdNdvf-*a!@4vyh>^boP zLSLSG9kmJB$!Jb<`srtbE9sgdP71-eO|db!r*}Wvg65sjiMmP(Un8T^wzcRVjVGh@ zY<&+LeJ`0Af7>ycs;h)Ny)gA_;Kv)w{Kh8wEUi!fiePSIlk<=PE45ubm)-cq_d`g$ z3Xq^B!>_|lI9CTISxuc{Wql&A{AR{y;0vQhl3?3ar%l_^>psbk;x6%#sqV9y;7H4j?@l+v@pb&ni>X`kEl*4 zxt}_e50#}GZp|=`AZ)Y2FtLGMXk*0kP)Bf>P-tJJ5afuNG+hc94MTBNdN>paf7Kl7 znRO*vhCw1Vq_8yp=rD~UlbRIvyCp1M#l_wmZPI2%d`E}>*>F$r$Ew%j?U7*P!k_7r zGjHHRjjCW1bMnvu_cP!4$a>?r*)cm{+@?Ev$(xTWQ{t*I78^NQtZdP>dEO}}o6zY% zr#~`nG9yeb6;i#or-Ys!QOj^B7Hb?mL2Xe2Plbt>EVH_v^aV7#4Q~8$e(P*`#=27q z#>0Ma6-bX+J#SY0+ar~6(2!)|ytc2S-$yz(#r@sN%{V=nx_!J_^T^c4@ZZniXKhP? zU>pS(J|b(#`O0KvaOwFHfcNZ^G0JDj5HY~3bpFaX*X(n$7}p~@bG?fy()gqOmvCeX zJ9VreXr$+|Vy;Qg=vwaC8%S%ge^+fYx|=KGl#rR$D(ZSVUi3@>sCyQO#Ivl(f3=3K z_=L&KC)_0{cvNp^1iMj?+e1tSy-~1aFz5hsDm&c4Lmp|3e$Rnrsxh(6SrVfn@xhMcK*d4QX^;zPQf^ZfGA&g0D*BT^4P$cUkAzdU&M zckVgP5~Y0P=|S-jG@;n5OLsp77dE$AN4L@?OAV&vvL`3IIU~iiI-cBgkbdCGOk8}O z-$0_HIp~q^ea`sm8@n_bN=&CpcWNI|Dnd%~0eEx@ z@rY=Umz{v=Ib?baVaoH72>O3wIwz-JZ2|DV`M<&!Tpo-SivObhD=0gZJoDXqH%9i` zXJxjzuRBVhWs`hA0gD72HW<}Ll97pg9?oOa#RA<{t$V zpX7uA!i_hc-s2$D`JYQ(Yv42s@1SHin?gdiU}h7b^JI^P@rifdQB(DCeUuJvn*;Qo z10nT|l5J^1t9JNt23(W|a5;q-E9Ku5M}zAPbi4r(ry&j$VE0e*JH~;y!+tdAYNzV- z(`2VD^OF3BUjSEJRHU(HsviMAwc3N)`@)*sMcEd@L`S?(dY~qjO6^)|eixGBrqkQN zuyOXpFTaFRiop!CtD&pe7J|^C!_@^T^gju38umTe;{kT1PQC{f0KEAAB@V34RSJJX zmZ2FytzRnTcks~17G#QJgKEl5o3xAn)7r}!N?w~DKO!(`(kt0T?&9N*ISwWEmny*dC2$DxKWYrhhoyV#lwc5mow9#_m=lMr?Uwq?Du?7-T~O(}SGyg3O%EvYj*5TymkVIqc(?pIRLGQ;IAviR;TMW) zloQ0i?jE)i4QP5FfA`M}>61x|Q+T!^5H#U0`8YQI1_qEE%!0wOxGL+ zl4)&|cEA!)BHjS^EnRuE)wV6p9XHG!p-d4&yRCENrX6q`<~;)E%Uk$VF-Q^@FCD~L z`vo}yc56xEF6(K`3E#4I>elY-1Inp+ViXV-z`wr zF2C!eQE7j?h_SnpENy{hoFa>qegMwAIy=l2KS&zU_hiamrm%xx?o*L?Tq^%;`&oq$ z`y!1IkoWXptQ1!@-1R-Zq3{mV|8hNz@o=LGcrHrd5>5>8;&EwJOZLGjG{cYQp!VNs zhJ#2cx)e4|XahscB>2VP`a+=o2O{tMS}E|Ub)0z5zDMTR1Ue39f!~hHp$zGbF$J_E z&=~A8IPcs#$;TeEYSCc5(wTB;8)37uq{@}8Lki}QH=5oK68~v-xe96}n=~Uy{oGy} z%PVq>S=y{+?g9<1&vom7+w^(M@SE$zXO-mrK|WUs8L^Cv$g#Z(;BK^hU?V>?gsPJ& z4%Uz?tjgQ70b%JN;Jp6I|8ETe8NxNWkIvBGdSM*N#D9gRy_d9a5UCeDz=@Sw=lx%FU@#WqM9~{3s3z8B$4cx= zz#i9%Z8^C}tMgxFUv9FHcX>07<>;T>cHMO0Cv5VcQ!-onnBamG=?1_*$s~nctVcx? 
zTY`S6kXGKcC#XTeAffl(eXZ(NdbY@Exfx-B|Gsh9j|y*ZSoQHlyNyAgZ1@Z9ZCCoR zXZrS~(Tg(0v(AhtGxfJw_T{Eq?WaRL!Q3=R684lfed{lCm9nGt2um7G0hBl6yEkzY zx}3Qt-jgfPEgs1|t75-WY;P*q99F-%;71+da4ov^jKc6xZk4lD@Vu@4Yx#TZL%}ys za#<6ShIqdYJJ;fA*~$b66b_BR_=U=+!GO~8D@>1>^;#tQQCrDyYX3d7f)+!`0m&-- zZ(}xpZ&%VYT0t1|6@{9$I*xP?Z3G$U&Q{{14&Xa9-f?Uf$l~r;$N3nezZhE+m|~lR zHAM@fKoOWCS_uxl_2=+LT7McK7UX{OsokrhVinW|sL8^Ddiz$gKX3_@fOige0-ODw64K^iK>=`%|WRSWHqY0sM3_RfXu0*h4ZJb@%=*!BKyxMdu$$|4-L4j#epF_O$h7fPWaLsWXJ;m;W z%6wxpJeo|rER_GorDm>v6;S^POKE-ihQH-?k2DjO< zQ6XB6nx2!BpB(#kQD9t1H-Vd+xB$oZ&vU?6-OePpXjQL9G_T0PGmk)avF$KK)bI0_ z(Z0GpbS4qQfP{H5EixZ3a4;LonMEh*9h;E?kOW9a&_{@e7-cqcZ@)m*t~z^lH{y`s zm|~RLf3Px$Z45;;e_3WTX+O7cx3H<0a5I%4z6vw0cand5mHEE?4fUx@7Duqv<0Vtz z9C7J4z{eMentm;n1x(v|$grc|mOe!Q-EX)6brkps4(B0EfwFbEKm3TCl9(0^YM_%5lPu5k$?Th6{t$7)&afCQ6uwB3eY-1m#8nR;Uo($gS|k*+t)z+mgUY%gS>SGAV@eR161lfI2 zdv7ZaZx7MX(FReQfsjMRr`Py@j+2hlh}dTBJtDx+-wp?aJl0-uux0@7I*vd#&lBuD z;2XtuNQ(P@Khmvlb!f+{ocmF>np5#p)sEd8`fnY-2Mgv8;7k}~O7*L&_c&E=7s>~h zkI_v?^_@L_pzPo>Pr|ojMPhqs$=PqcA$05(t47p=f(nMK7&3K25@(eQrKcBe6gX~@ z#aKIn(4VPN2+}1trXe}>br32vmn!q_QfAfcK?g!UP4b^QGK%wl&gVg`u3K-7OHM39 z-}twOali%GYGomHdgB4x@EYaBW&Nz4~?eZZ_M796K?A{)>Y z4B5X^?=D5NzMn;)?0gj#=`!Pp%C0aT_kst3d4%dE^w!&G2iEF>c9E#=jTgLME4J~K02eSv#~KEv>!18cqafhTW0-aQflXk9S? z_C8oZ#3hEZtVmjpoZM~2C}i~XfXn}ccgpB!#Woq%w$NQlV-$FlAt?^rOC_~-yBjW9 z$R;G|pMT`WCZBfd+PEa1W=uFfDeuvma_r~!33^UQDi|>&-e{HeU+z4XU0-+SF~{&b@7*Y?4OfS|i&-%1{DZS$JS(<14c}v=FP9E`+ zSxXvIFZh;S4Ahaq@jAzZ*_T|~15EmPho@Ox?{)S_%!qDvP-8sj_6nvi*atJ+ZjaL{ zi3dg?-49~Y(6*jXM!3X(b$bYuErMWqSc$g{z33?YV&sTRH|1))sn z*G4;}6@N$?s&>0N*{-fyL=1b^6ce27v1ru4(KRZ9b&QBg9b^a;%XgyCH?cmAmWn8Q zpM(fb?Q`-ddo9cKPt|z8RN7rUPA{~*d9#T~F4Mo0B9dY&y8G;{hjxZ+IoXeG0HoqJ zWukd%#RwBSbC2&wAy--9Ne#7nUSfsC`mmYP?2NKbmMybNTF0KedrLxDzf8YWml7`& z!H%FA+P)@QJq?GSQO~pBSUVlO?TJ5O>t;#H+O8qxK`tCgBIW4*AdSvSET0lWi|L_7 zFPYK2Ok|(Ypt62&E>v60o^A8+V9g1E=`2l}@9v8@vz1+1E1!(kT*F=Ei%VV_p}~6e zP=bDi_L*)Qm)tNV+}Yz*yLByhj0FV;pJl^6*RWt_WQdG{iz-g;`$12y|BQqe6VJe< zgm|`AM*4d#-py$V;4cxaQQd``r`t70MkbQ5QpRS_)5ZyEHV4@c7amO!l&^F;UO4l1}amOj}8t zU-#{UpC$8=v=`1FtR7>JXbJE+l2*oPJ>}cvF*v!mp#0u8Vo1)*^i2(|-=shdR`;=D zQOaIPCY+torUOwDTJ4fFYb&-E0ILXLX_?R-5a6&h0IB?KQ95e3U_%HOPX zgI)2pU&&4voJPBa=3m%8GV!R7nQB}+?nSQl4{Uen)tCJS{;BD~;cF(kRhplg z3|Iw{6yI(fJT#@~M!PBbSz_F%JGK*RMlKuPhVT<*r#yQ)Zo!0ALef%+9LzM|~?(X7DVcQAr$Hog}DJguuc;kPZ_?Fei zQ~UgOSK=)CK{ekEcCCEi43nEl0Q^(?WZaI|=17>mnF(rfz6(s0wk}((zvbCf7log= z?R>6X_YYlP6H?-?IZY)MXb;+PlBXl5A1J);UViYRpS0zzbHhrBjEm3~v5!smluo`$ zf4=aXi1bObTlWyz^oQmmA2zfZ25JvbD;K?{yvs{qBO-pdUF~9D(NP$FTBoS8d(Z?2 z3lH3Nz9JfljoowG>|p@rT`Q%G*w}ryh@ZXZe8k~Wb1LFg=ya*x0+_-T1$y3mO()&s z!^EUE7Vg)dHv^E*Ql6mI7C_*ap! z*vmZlT#@@=Cca!WJ-mN~vyq2)Lv616_L=Y43`n}a(5`j5O{zNm!W}@<5|x|@ zX}|sgT`Ztoe0(AteY;N`MZBnbMQtQW$G%-j4O&?%JX^3qk*?#Tf1L$t7T-+;8!&~r zRZ$A<`HD^eK6Hy;4L1#~{FDgmJFGyuAu##@YYup&gddj~3r&?cnZhE_q2sgGnjlB- zg*!HqC*%xb7ZdGP9S3!p#J5!$`8}SeZXS0iU66mp$eg_I#>|P^8{m+Yh48zPDbK$B zJ~?0rd?$8wNV&&IG%xkuMJC`YU&<;cI8SjrA*7>WpQ_|jiP}T~ToI8^P z^gpj!sSC9IlUYFX`n})_t!(3Z-!beB;gS^W^=+BaLNw~-5p%)#812MDxXxvv0jl0G zjtWl68+C$po)A>N9Yvr*H6S3QH`@{2<%3^Dc& zIRO+&shn5tC(%&hZmZ)lx|u_Yyr6ON=qUXGSPwydira9{shM;3aY;Vq#>h8MJE4%3 zjk#Pi_fliJf!qW_$P~oTep+m^V}6>1s>JXp-T;)Z)g<;#zT0v*7wGU;kC=B{=rhGH>HTdo4j72Jbr7Wm~qfhBxa5aod@{E4*fyFbS#B#xEU-nu&((# zH&{-O@Fjo_>swd+$r~Up`*l6c*zcSj%dLEi947&)M5y`4>IjZFsvPBER+OOQAYvL& zZKD~g{-@G4$S}Ms)MV_G9UmnTR$sUG?K$L+cjk@Y+mX0$s|95l<3t>r3zbN9H2jDD zv_6bZvQhuLv%TUf;7&0y?fDiv4u!DQjW)%VPb3Dy#Mw^5V^(|)7Ik#`*z~U^LA_t3 z=?~|DIL|^<;Ky-0roI3uH|_>A1UAg$vyS7BIrpF7Swso&-w;Inu34F~S_6^9s+?_? 
zkc9cprHK4P0a^Nn$y>(x;790u#m*tG>^p683(-qr;{nCJA(qLfD=E9@#Qa>jZ_ne_ zXV-lyIKS?hmxF<(tBWB7r!1QRUCeZM9eQ=^{68ELbhm~X?l#x&dEncUFO@G4qO}R+ zwuBTpw;OM-t}i(I7QAw4ven&e_vEFSjBLyD-HYi2VM<=CY5e$<6hb$}K;Bh%=ZN2( z;NO!~B1oR!wBb!160${_3aMbCvs1ZseN}TMPWVBB_M4ev)}L;8SS0=*i8HJLb*X?TCk^vb)$0 z4@wb6?+8;7b3j{6K7Itt-DZnWO%__4+Z2y8$F|XYY-*yU{%pt`D954_HhDdZ(RiD< zZfQQFrxsjJ%OE~V;+vY9U;8lJl-xsiHO#X3YxzH0&Ui||7!+4Fae0>mAOBU*{$#hL z3gRA1m%%$pp|Se`(i5TC&L=Hz`+6+Y1esQ&nTA8IOYW>2fU_$WgGziQAyE71TuOe>o&%dFoFsZ z=XSgR*vpn~#mze;*|k!?Zz5ayxnG=WS^0rGMziB;cP`F?j4QnkhJeyp5>{H?d}R7HtYOrYa-IK-ZnE!j%o(T+h@R*jvRa#&VF|k(OC7(Z ze$@y6^SKM7ii#%h?X$An4N4;4_Z&NRs#dY@dvBO^hGQo$#!iz}{uD9sc*H2&Rvx2g zK7w8MR6(kL^nFtZ^8tQfSAGz!-#UWqiCx`YH|KM%H64q5MyZZm*4=CT3#0qh6cQU~ z*u3uU-q+hydv48ZGh%_h(iTpINClwFNI?zRUbFogH!n-gg&ZB--w@E7KI`{mq)b!T z^GXWJ$eKe3Wk04Z9~aEOq96-fb!AC@jP{F+z=|Z~xvlM(oU_ikSAu?L=~nCarqZsm z4X1*dA97tRgo6XV$D{aKeb+q^G0UwLZD;yMh)9wuY5lZsA(E{H6q;Oam$0VS{d{)T zK++0LG1HxtYRO={a*R}77OH>tlA7|a-hq(o8LxIjm+{?t>cnLDr330f7hL$z+dAmw zos=rbi^Q+)Bv|3z0mfQe4GjfuDF^d|v?vYFd=i_{$C?h&kQ!c4s--Pyx~t^$?Jwsd z3A3~JE_*+8nHO7(D%B+oizzKb4y5yXeN~_`!R{kf8t3e ze^wHyN`3yhW~Mm7Kk2_;H+khFF5byKVWZ$zqe-S721@?Jmzw@F#_tW`dH;3NxYo$d zx+Q?IjO&*9N{M{3LBe?H1Htv};>fbGUJg7HnczYjG(clv-#~xsO)BO{Y#PH{7(2LL zP;lXTd(=RS>~}<#q4$+fdp%v)+H|F|%{$@lb+pRJXAo>mw5;WdVVGpUUKp>hkm0qw zgE|V5Yd@z2NN4G3Okk8t5XdWzVbAB=s5mDjg~(Rxppiqh3e6z#o|S%gHD$Iu*P}Ys zYB_WDYF0C=xh+Wqq%$hkB+`4Wa_oLl)l}>22evZP)}VpdsSMmUi=m&D70|@5jsaE< zM8h4)bNRp)XN(UiAx>t~hkAXj%wd)}Xk!vjwoNgQHg1=TW$^7gI&vy(pWp!K>I|dd z>qZ);B4Pz4$;(2zl#8ixI+A;-lg?}RR;FZnaxOox*?9Kv9JG44w6+Ulo(Cd@pt%vmO!#;4TUlI zqjLQ}*_}?WPE<&j|HOUS-pLX3*vl@85vkMglbMz z?8q-T((ucDgDfR_#;yZAHtdqEHJg3xBS>&8nkX%Gvv3*4)NSuaw)PtH%;+5m?_m*m znhhHS-Z4A#;awI2^&^8i10w1LFKgw7$o6S{0{#C8K;P*PiYFlvxONMq2f6tU{c!54 zHY$KAmV)$|sowU-qqWJn^cNDwqRNep+^9M3irm(5-sY?I7Jrwj*YG0x#Y(?4x-Xh)j%$UZE>srlHB9>@Bz1SRRqiT213 z2dhN_IHM74jupmlZh2i&csXZvQHjlR4#sPf88KGp(I2;Y_E~;eJyp6VAAOVWnWJ3I zr;APEfH@vIK}jaO6I(w`CE+K7N=~@Q-wq$sT8$(}%CP2F_d9dcS;=)qf+^s9&*I&7 z_y}>F8ZyBez_+IOL7W~p4u|qm9`K&F^ruoEf{Jg?#xm#n`*$mDzBVBYgNXo#cj_0t z^WzyO7c6%tUCIJ~hUFEp8XI{-TbpHW3 zm9sKHt;tw_UWmd863wG&VSFP$*b^!chC3Tcj^IK4eI^^7B_*>6DGj(u_tqa4vwF>X>_xy1{stFtte2Hu40&BrDt z^`lv&-pE*-YJ2BH%*m-kd!v=#q@688?@!f{#K+x}<3cCNXtmnczVJ~@;aiZK7srcL zsXLwgFfZGw)ULy3EZ%Hv3x{@~e^GmIbw$IshJnY5#ZT zIDSS#rHa^B2@m>g!ZkGOwe7-#B*Du%6p#JVJcs?FS8(?`@#Xm~A9t?4@VL4{N;~)*)(_o-ah2<9*Hw-dTQKJ{R9#%XgMhW*3{2d=k{z z+(5`_iWP;0qhV%y$d{+kX^oZn4`V)lOO>aXGlCTVF!)-21$o!x%@*Ltt*B2QMmi?# z_W!{>I?Lfi{gx4_$mY*5p7I?Fow`IK;gcp(3lQMkq|_y1s}}TI%y-_}+l$d0oo8x( zySUPW^K*Afb0WW(w(y)FRyunuk0F!LD1KpKzbS%CR;ub07KgvHr0&k>S$*Q7tt_d~ zY?W@{Gb5xRK=rHd*UwWN5qUqPhT5I=H0XGKWrT!}5}da;#MLZ8$0hhjyrXd-;eBRo z$?vT{@eXCSL4#9F38aMJ48!{_5Lr$L4xqP148U~F3CqF^%7h6Xvw%;NzB;XtTQNRyA<5e{PR05j!*X6}d zG}`6h<>cy!87Y@q|^td63PuLu0`HWLG|sNAvAIgTJoDrKBKrPUe+|!{E9JR zPw_HhV2Ma(ZYV3ksMilx$b-#7jPv%2yH#y}c0i*G@k+_QBk8(-&EL^1bALI7QS1KhQ3f5?6@0O$G&pN+Tow8Iro}U@Wzbc~X5e+#_&t_bSE3;n$Yg zc0lOlRE3HAL~8SeZNSq|Zg-+4!9{@wGv`m&JH**X2~rAF+Q{_V8%ZYq=E0TRvgp0r z9Ry#U+&*U*iuLXGPTud#`59iOyV}p~dWJmv*4k^E&C8T8+xENL)qxgQ3)lQshVtQ` z%S}fx10^q7@1{KG;6|-$E5WB^HJm1r{7VjWOM>;`5TUkeqJm7)3ysWX5;Em?oG_^B zUs@`#aRN@-i!~PO4k{XAz1;+uD>_im+J2RYyw}3xzY5M{bkxgOkXvt%@8)J#OW%XR z;lWpnY50XDYfzM)2utJmpHP#n-QoU|2yh$hn{-`VkX0r_lPp}6wV zR;|m2Zql|j9luPyc6Vx_>hAyiV!OTdHgEi_U|R~C%UMBZE$!S5zS3$*B;2>3luxugXV)*0=o+fj>G2 zJvRGz!?H7C`CZrv{dt!_T0gKVj0QiG>ko1kb^%mwddhPwt1!~9e8VdN9 zT-9D5V0vNkiX85&?%U8D?R>L>Dw{Ef4fP9tPpCcya)b*yz`Iq<+7lwg8MU2DK%-vz z-Xr-nKN-FHd|AqgW}JH@_*nrh&9F1rl(ZdgHOgav#2I2|y0YJ~#zq_UWaOX0ZgF4+ zqA3P93)6@xq>Km;KeE@U`hNr!8o?$T+DN 
zp3iX=&Puq2*S#9Lv0iXbd!J07G~0NEGX>C=!QhJVu;j40Uc+yMT;FP=w0dSeTB`a2$Pc zCwY(;#Nm^mq@Av-L!V$`xNKUe&NpEOgKhYDb87fUTkiG>nluO=2f`{-)RjPCcPJw=(%U>yv567s2d`}s8EQx zI-;{1NU}Fctg%+Z(pbP}pSInyE6cr^5&&63(4iKSv6cxmy7oC6ztHeCw8e~FLHj3U zQvz2?Q#2Q}NoA3(Z2BQ4n1^S6FA`8vWvCWYO6A zvXa+zZZ0?IXyv4FUHSP233+)=B(h(*EjFY-0al0xgWRVwXXA_}BlLi{N~p{=SLM8S z=BvKstf#keKBFOm_W9`vQokQXH(9F;;g5@H`TdKl(=B z`c&&ZuZaHU<=kJ^zFU-`lmea;lxYw=jCX&u#~MP!Ie2_#nbxkRCCbS z$dK7qDAc>RY0=|1p^3-6=JWCRJ`CFzAtw&g-^&Q1q`H~`#2#gh{^eT934Y_Lub3AO zZ0_GT;$(Szo#QBHqm^a%f5!ie3t${aHz0pncB}E(_i2AJng0I70@WxQZs!*dzu;ky z%E?b)y*El9DU+=rc?Yy?V%?7t3PYAzNsYe}6v@gEQwx3v^t>Rg$n`vyJyL|47fr%V zo@sr+-*?3v&3GjdNfn3~vobYU{9x>P0dH0EXiuYH@S!4hODnr*=#-)XAYU>H$LI|A z@!Q*(Y}#O=vOrZvi!Da!rVbaWdjiBAB@E~jzGC}L+%`BNRHHBY1-Jc021VnFoQAb7 z)EaiIo7U3pH7@GOXIO~y@d)LDi7edjlEN)8`D$?m_qyeAaw}KAN98@AS==ET+sZK? zNlC%wiO`07V8~wT*JmG_@)C%=ALzN5-hR9{bCdQbThFZk`Xi&ej~2~dv*QTZJ^Y00 z0}-urZS?Y*u<6dvEk0c_oo(4)q7ltju7oXw2)X&H)^@smk43O!!(QPl#Vn>HC2T#7 zXzfqKRV6RKz_;_f9z9CR zK~u5oYb*=Nm>m<@wkQ3YKGGOa&$dYBF&>ugB zh@=c>!JW4hi`+&EHh^_3OWdopQp#QbvoPQnd4aHSiQ;h4#=JUxOJx9a`d*E&ujB$` zx9*VU`SciEUjHAWei#ng1_ru*NLq;u+rg=q+^zds|7riv{%CFWbZIC*FUS3lGfSc? zpCS;@{D&?OZf|PpW@w&HSjb;p_10NWRGZ)K)I?Tz0YskJzp)PU6~E0mMApyFb2Mvm ztH)5qB~Yv`fI=!-_ShYv`S7ptKy{ z;_`WmN;tl{=a;^v2w?o`7$> zXg#pbhOxF|rUP@zMJvBbyuo2tH~pOke}(8mA-8;B@ls%TKo7UgZ}cEqvTzl(!ML(k z!F>hKoKQi*%&AXqZ-5GS^s>cncf6Gm9cpK#MsE%0pNQC4oorfi|JBNaxZKk%e_v=L z0sj<^W@M2+zd{DRKFg91PO1D=1cXF5GSYvg{~%=&)c<>>IQG91dXP;R>3{X|;sXDD zOo%Cm=>M=j#i=R&S9c*k_OEW<|JU2k8|$SQHgL`e-zt;+6@r1tVM@+#QpJ0&{pZBc s*+2N7W5Pc_>-w7iYeW41=Oe$83(5`>&9UpELP4601W>$ERR72S1tr@~ literal 0 HcmV?d00001 From 8f1fa073257ccf26628098bca5985238ff495a11 Mon Sep 17 00:00:00 2001 From: Yong Wang Date: Sat, 25 Mar 2023 00:21:07 +0800 Subject: [PATCH 09/13] Improve the performance of radix top-k (#1175) The main changes are: - Add a one-block version. It uses single thread block for one row of a batch and is used when `len` is relatively small (<= 16384) - Avoid writing candidates to buffers when the number of candidates is larger than buffer length. - Add a parameter to control whether to use a fused filter in the last pass or use a standalone filter kernel. The later case is preferable when the leading bits of inputs are almost same. - Early stopping: when the target bucket contains `k` values, we can stop the computation earlier - Many implementation details are polished, like the initialization of `counter`, calculation of kernel launch parameters, and the scan step - Tests and benchmarks are updated to include the new implementations. New benchmarks are added to demonstrate the advantage of adaptive version. Authors: - Yong Wang (https://github.com/yong-wang) - Corey J. 
Nolet (https://github.com/cjnolet) - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/1175 --- cpp/bench/matrix/select_k.cu | 128 +- cpp/include/raft/matrix/detail/select_k.cuh | 2 +- .../raft/matrix/detail/select_radix.cuh | 1149 ++++++++++++----- cpp/include/raft/spatial/knn/knn.cuh | 4 +- .../raft_internal/matrix/select_k.cuh | 43 +- cpp/test/matrix/select_k.cu | 12 +- 6 files changed, 975 insertions(+), 363 deletions(-) diff --git a/cpp/bench/matrix/select_k.cu b/cpp/bench/matrix/select_k.cu index d4873e2640..870119db52 100644 --- a/cpp/bench/matrix/select_k.cu +++ b/cpp/bench/matrix/select_k.cu @@ -35,6 +35,10 @@ #include #include +#include +#include +#include + namespace raft::matrix { using namespace raft::bench; // NOLINT @@ -50,7 +54,23 @@ struct selection : public fixture { { raft::sparse::iota_fill(in_ids_.data(), IdxT(p.batch_size), IdxT(p.len), stream); raft::random::RngState state{42}; - raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0)); + + KeyT min_value = -1.0; + KeyT max_value = 1.0; + if (p.use_same_leading_bits) { + if constexpr (std::is_same_v) { + uint32_t min_bits = 0x3F800000; // 1.0 + uint32_t max_bits = 0x3F8000FF; // 1.00003 + memcpy(&min_value, &min_bits, sizeof(KeyT)); + memcpy(&max_value, &max_bits, sizeof(KeyT)); + } else if constexpr (std::is_same_v) { + uint64_t min_bits = 0x3FF0000000000000; // 1.0 + uint64_t max_bits = 0x3FF0000FFFFFFFFF; // 1.000015 + memcpy(&min_value, &min_bits, sizeof(KeyT)); + memcpy(&max_value, &max_bits, sizeof(KeyT)); + } + } + raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), min_value, max_value); } void run_benchmark(::benchmark::State& state) override // NOLINT @@ -60,6 +80,7 @@ struct selection : public fixture { try { std::ostringstream label_stream; label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k; + if (params_.use_same_leading_bits) { label_stream << "#same-leading-bits"; } state.SetLabel(label_stream.str()); loop_on_state(state, [this, &handle]() { select::select_k_impl(handle, @@ -85,21 +106,55 @@ struct selection : public fixture { }; const std::vector kInputs{ - {20000, 500, 1, true}, {20000, 500, 2, true}, {20000, 500, 4, true}, - {20000, 500, 8, true}, {20000, 500, 16, true}, {20000, 500, 32, true}, - {20000, 500, 64, true}, {20000, 500, 128, true}, {20000, 500, 256, true}, - - {1000, 10000, 1, true}, {1000, 10000, 2, true}, {1000, 10000, 4, true}, - {1000, 10000, 8, true}, {1000, 10000, 16, true}, {1000, 10000, 32, true}, - {1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true}, - - {100, 100000, 1, true}, {100, 100000, 2, true}, {100, 100000, 4, true}, - {100, 100000, 8, true}, {100, 100000, 16, true}, {100, 100000, 32, true}, - {100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true}, - - {10, 1000000, 1, true}, {10, 1000000, 2, true}, {10, 1000000, 4, true}, - {10, 1000000, 8, true}, {10, 1000000, 16, true}, {10, 1000000, 32, true}, - {10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, true}, + {20000, 500, 1, true}, + {20000, 500, 2, true}, + {20000, 500, 4, true}, + {20000, 500, 8, true}, + {20000, 500, 16, true}, + {20000, 500, 32, true}, + {20000, 500, 64, true}, + {20000, 500, 128, true}, + {20000, 500, 256, true}, + + {1000, 10000, 1, true}, + {1000, 10000, 2, true}, + {1000, 10000, 4, true}, + {1000, 10000, 8, true}, + {1000, 
10000, 16, true}, + {1000, 10000, 32, true}, + {1000, 10000, 64, true}, + {1000, 10000, 128, true}, + {1000, 10000, 256, true}, + + {100, 100000, 1, true}, + {100, 100000, 2, true}, + {100, 100000, 4, true}, + {100, 100000, 8, true}, + {100, 100000, 16, true}, + {100, 100000, 32, true}, + {100, 100000, 64, true}, + {100, 100000, 128, true}, + {100, 100000, 256, true}, + + {10, 1000000, 1, true}, + {10, 1000000, 2, true}, + {10, 1000000, 4, true}, + {10, 1000000, 8, true}, + {10, 1000000, 16, true}, + {10, 1000000, 32, true}, + {10, 1000000, 64, true}, + {10, 1000000, 128, true}, + {10, 1000000, 256, true}, + + {10, 1000000, 1, true, false, true}, + {10, 1000000, 2, true, false, true}, + {10, 1000000, 4, true, false, true}, + {10, 1000000, 8, true, false, true}, + {10, 1000000, 16, true, false, true}, + {10, 1000000, 32, true, false, true}, + {10, 1000000, 64, true, false, true}, + {10, 1000000, 128, true, false, true}, + {10, 1000000, 256, true, false, true}, }; #define SELECTION_REGISTER(KeyT, IdxT, A) \ @@ -109,24 +164,27 @@ const std::vector kInputs{ RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \ } -SELECTION_REGISTER(float, uint32_t, kPublicApi); // NOLINT -SELECTION_REGISTER(float, uint32_t, kRadix8bits); // NOLINT -SELECTION_REGISTER(float, uint32_t, kRadix11bits); // NOLINT -SELECTION_REGISTER(float, uint32_t, kWarpAuto); // NOLINT -SELECTION_REGISTER(float, uint32_t, kWarpImmediate); // NOLINT -SELECTION_REGISTER(float, uint32_t, kWarpFiltered); // NOLINT -SELECTION_REGISTER(float, uint32_t, kWarpDistributed); // NOLINT -SELECTION_REGISTER(float, uint32_t, kWarpDistributedShm); // NOLINT - -SELECTION_REGISTER(double, uint32_t, kRadix8bits); // NOLINT -SELECTION_REGISTER(double, uint32_t, kRadix11bits); // NOLINT -SELECTION_REGISTER(double, uint32_t, kWarpAuto); // NOLINT - -SELECTION_REGISTER(double, int64_t, kRadix8bits); // NOLINT -SELECTION_REGISTER(double, int64_t, kRadix11bits); // NOLINT -SELECTION_REGISTER(double, int64_t, kWarpImmediate); // NOLINT -SELECTION_REGISTER(double, int64_t, kWarpFiltered); // NOLINT -SELECTION_REGISTER(double, int64_t, kWarpDistributed); // NOLINT -SELECTION_REGISTER(double, int64_t, kWarpDistributedShm); // NOLINT +SELECTION_REGISTER(float, uint32_t, kPublicApi); // NOLINT +SELECTION_REGISTER(float, uint32_t, kRadix8bits); // NOLINT +SELECTION_REGISTER(float, uint32_t, kRadix11bits); // NOLINT +SELECTION_REGISTER(float, uint32_t, kRadix11bitsExtraPass); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpAuto); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpImmediate); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpFiltered); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpDistributed); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpDistributedShm); // NOLINT + +SELECTION_REGISTER(double, uint32_t, kRadix8bits); // NOLINT +SELECTION_REGISTER(double, uint32_t, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, uint32_t, kRadix11bitsExtraPass); // NOLINT +SELECTION_REGISTER(double, uint32_t, kWarpAuto); // NOLINT + +SELECTION_REGISTER(double, int64_t, kRadix8bits); // NOLINT +SELECTION_REGISTER(double, int64_t, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, int64_t, kRadix11bitsExtraPass); // NOLINT +SELECTION_REGISTER(double, int64_t, kWarpImmediate); // NOLINT +SELECTION_REGISTER(double, int64_t, kWarpFiltered); // NOLINT +SELECTION_REGISTER(double, int64_t, kWarpDistributed); // NOLINT +SELECTION_REGISTER(double, int64_t, kWarpDistributedShm); // NOLINT } // namespace raft::matrix diff --git 
a/cpp/include/raft/matrix/detail/select_k.cuh b/cpp/include/raft/matrix/detail/select_k.cuh index ac1ba3dfa3..20c2fb119d 100644 --- a/cpp/include/raft/matrix/detail/select_k.cuh +++ b/cpp/include/raft/matrix/detail/select_k.cuh @@ -84,7 +84,7 @@ void select_k(const T* in_val, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); } else { select::radix::select_k= 4 ? 11 : 8), 512>( - in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, true, stream, mr); } } diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 643a63d9db..7ac40ac0eb 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -16,11 +16,11 @@ #pragma once -#include - -#include #include #include +#include +#include +#include #include #include #include @@ -35,8 +35,8 @@ #include namespace raft::matrix::detail::select::radix { +namespace impl { -constexpr int ITEM_PER_THREAD = 32; constexpr int VECTORIZED_READ_SIZE = 16; template @@ -51,13 +51,6 @@ _RAFT_HOST_DEVICE constexpr int calc_num_passes() return ceildiv(sizeof(T) * 8, BitsPerPass); } -// Minimum reasonable block size for the given radix size. -template -_RAFT_HOST_DEVICE constexpr int calc_min_block_size() -{ - return 1 << std::max(BitsPerPass - 4, Pow2::Log2 + 1); -} - /** * Bit 0 is the least significant (rightmost); * this implementation processes input from the most to the least significant bit. @@ -82,23 +75,43 @@ _RAFT_DEVICE constexpr unsigned calc_mask(int pass) } /** - * Use cub to twiddle bits - so that we can correctly compare bits of floating-point values as well + * Use CUB to twiddle bits - so that we can correctly compare bits of floating-point values as well * as of integers. */ template -_RAFT_DEVICE typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater) +_RAFT_DEVICE typename cub::Traits::UnsignedBits twiddle_in(T key, bool select_min) { auto bits = reinterpret_cast::UnsignedBits&>(key); bits = cub::Traits::TwiddleIn(bits); - if (greater) { bits = ~bits; } + if (!select_min) { bits = ~bits; } return bits; } +template +_RAFT_DEVICE T twiddle_out(typename cub::Traits::UnsignedBits bits, bool select_min) +{ + if (!select_min) { bits = ~bits; } + bits = cub::Traits::TwiddleOut(bits); + return reinterpret_cast(bits); +} + template -_RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool greater) +_RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool select_min) +{ + static_assert(BitsPerPass <= sizeof(int) * 8 - 1, + "BitsPerPass is too large that the result type could not be int"); + return (twiddle_in(x, select_min) >> start_bit) & mask; +} + +template +_RAFT_HOST_DEVICE IdxT calc_buf_len(IdxT len) { - static_assert(BitsPerPass <= sizeof(int) * 8 - 1); // so return type can be int - return (twiddle_in(x, greater) >> start_bit) & mask; + // When writing is skipped, only read `in`(type T). + // When writing is not skipped, read `in_buf`(T) and `in_idx_buf`(IdxT), and write `out_buf`(T) + // and `out_idx_buf`(IdxT). + // The ratio between these cases determines whether to skip writing and hence the buffer size. 
+ constexpr float ratio = 2 + sizeof(IdxT) * 2.0 / sizeof(T); + return len / ratio; } /** @@ -111,17 +124,18 @@ _RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool greater) * @tparam IdxT indexing type * @tparam Func void (T x, IdxT idx) * + * @param thread_rank rank of the calling thread among all participating threads + * @param num_threads number of the threads that participate in processing * @param in the input data * @param len the number of elements to read * @param f the lambda taking two arguments (T x, IdxT idx) */ template -_RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f) +_RAFT_DEVICE void vectorized_process( + size_t thread_rank, size_t num_threads, const T* in, IdxT len, Func f) { - const IdxT stride = blockDim.x * gridDim.x; - const int tid = blockIdx.x * blockDim.x + threadIdx.x; if constexpr (sizeof(T) >= VECTORIZED_READ_SIZE || VECTORIZED_READ_SIZE % sizeof(T) != 0) { - for (IdxT i = tid; i < len; i += stride) { + for (IdxT i = thread_rank; i < len; i += num_threads) { f(in[i], i); } } else { @@ -134,8 +148,8 @@ _RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f) const IdxT skip_cnt_left = std::min((IdxT)(align_bytes::roundUp(in) - in), len); // The main loop: process all aligned data - for (IdxT i = tid * wide_t::Ratio + skip_cnt_left; i + wide_t::Ratio <= len; - i += stride * wide_t::Ratio) { + for (IdxT i = thread_rank * wide_t::Ratio + skip_cnt_left; i + wide_t::Ratio <= len; + i += num_threads * wide_t::Ratio) { wide.load(in, i); #pragma unroll for (int j = 0; j < wide_t::Ratio; ++j) { @@ -145,30 +159,55 @@ _RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f) static_assert(WarpSize >= wide_t::Ratio); // Processes the skipped elements on the left - if (tid < skip_cnt_left) { f(in[tid], tid); } + if (thread_rank < skip_cnt_left) { f(in[thread_rank], thread_rank); } // Processes the skipped elements on the right const IdxT skip_cnt_right = align_elems::mod(len - skip_cnt_left); - const IdxT remain_i = len - skip_cnt_right + tid; + const IdxT remain_i = len - skip_cnt_right + thread_rank; if (remain_i < len) { f(in[remain_i], remain_i); } } } template -struct Counter { +struct alignas(128) Counter { + // We are processing the values in multiple passes, from most significant to least significant. In + // each pass, we keep the length of input (`len`) and the `k` of current pass, and update them at + // the end of the pass. IdxT k; IdxT len; + + // `previous_len` is the length of input in previous pass. Note that `previous_len` rather + // than `len` is used for the filtering step because filtering is indeed for previous pass (see + // comments before `radix_kernel`). IdxT previous_len; - int bucket; - IdxT filter_cnt; - unsigned int finished_block_cnt; - IdxT out_cnt; - IdxT out_back_cnt; + // We determine the bits of the k_th value inside the mask processed by the pass. The + // already known bits are stored in `kth_value_bits`. It's used to discriminate a element is a + // result (written to `out`), a candidate for next pass (written to `out_buf`), or not useful + // (discarded). The bits that are not yet processed do not matter for this purpose. + typename cub::Traits::UnsignedBits kth_value_bits; + + // Record how many elements have passed filtering. It's used to determine the position in the + // `out_buf` where an element should be written. + alignas(128) IdxT filter_cnt; + + // For a row inside a batch, we may launch multiple thread blocks. 
This counter is used to + // determine if the current block is the last running block. If so, this block will execute scan() + // and choose_bucket(). + alignas(128) unsigned int finished_block_cnt; + + // Record how many elements have been written to the front of `out`. Elements less (if + // select_min==true) than the k-th value are written from front to back. + alignas(128) IdxT out_cnt; + + // Record how many elements have been written to the back of `out`. Elements equal to the k-th + // value are written from back to front. We need to keep count of them separately because the + // number of elements that <= the k-th value might exceed k. + alignas(128) IdxT out_back_cnt; }; /** - * Fused filtering of the current phase and building histogram for the next phase - * (see steps 4-1 in `radix_kernel` description). + * Fused filtering of the current pass and building histogram for the next pass + * (see steps 4 & 1 in `radix_kernel` description). */ template _RAFT_DEVICE void filter_and_histogram(const T* in_buf, @@ -177,12 +216,12 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, IdxT* out_idx_buf, T* out, IdxT* out_idx, - IdxT len, + IdxT previous_len, Counter* counter, IdxT* histogram, - bool greater, + bool select_min, int pass, - int k) + bool early_stop) { constexpr int num_buckets = calc_num_buckets(); __shared__ IdxT histogram_smem[num_buckets]; @@ -198,19 +237,20 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, // Passed to vectorized_process, this function executes in all blocks in parallel, // i.e. the work is split along the input (both, in batches and chunks of a single row). // Later, the histograms are merged using atomicAdd. - auto f = [greater, start_bit, mask](T value, IdxT) { - int bucket = calc_bucket(value, start_bit, mask, greater); - atomicAdd(histogram_smem + bucket, IdxT(1)); + auto f = [select_min, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram_smem + bucket, static_cast(1)); }; - vectorized_process(in_buf, len, f); + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, + in_buf, + previous_len, + f); } else { - const IdxT previous_len = counter->previous_len; - const int want_bucket = counter->bucket; - IdxT& filter_cnt = counter->filter_cnt; - IdxT& out_cnt = counter->out_cnt; - const IdxT counter_len = counter->len; + IdxT* p_filter_cnt = &counter->filter_cnt; + IdxT* p_out_cnt = &counter->out_cnt; + const auto kth_value_bits = counter->kth_value_bits; const int previous_start_bit = calc_start_bit(pass - 1); - const unsigned previous_mask = calc_mask(pass - 1); // See the remark above on the distributed execution of `f` using vectorized_process. auto f = [in_idx_buf, @@ -218,38 +258,50 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, out_idx_buf, out, out_idx, - greater, - k, + select_min, start_bit, mask, previous_start_bit, - previous_mask, - want_bucket, - &filter_cnt, - &out_cnt, - counter_len](T value, IdxT i) { - int prev_bucket = - calc_bucket(value, previous_start_bit, previous_mask, greater); - if (prev_bucket == want_bucket) { - IdxT pos = atomicAdd(&filter_cnt, IdxT(1)); - out_buf[pos] = value; - if (out_idx_buf) { out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; } - int bucket = calc_bucket(value, start_bit, mask, greater); - atomicAdd(histogram_smem + bucket, IdxT(1)); - - if (counter_len == 1) { - out[k - 1] = value; - out_idx[k - 1] = in_idx_buf ? 
in_idx_buf[i] : i; + kth_value_bits, + p_filter_cnt, + p_out_cnt, + early_stop](T value, IdxT i) { + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if (previous_bits == kth_value_bits) { + if (early_stop) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else { + if (out_buf) { + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram_smem + bucket, static_cast(1)); } - } else if (prev_bucket < want_bucket) { - IdxT pos = atomicAdd(&out_cnt, IdxT(1)); + } + // the condition `(out_buf || early_stop)` is a little tricky: + // If we skip writing to `out_buf` (when `out_buf` is nullptr), we should skip writing to + // `out` too. So we won't write the same value to `out` multiple times in different passes. + // And if we keep skipping the writing, values will be written in `last_filter_kernel()` at + // last. But when `early_stop` is true, we need to write to `out` since it's the last chance. + else if ((out_buf || early_stop) && previous_bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } }; - - vectorized_process(in_buf, previous_len, f); + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, + in_buf, + previous_len, + f); } + if (early_stop) { return; } __syncthreads(); // merge histograms produced by individual blocks @@ -259,69 +311,184 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, } /** - * Replace a part of the histogram with its own prefix sum, starting from the `start` and adding - * `current` to each entry of the result. 
+ * Replace histogram with its own prefix sum * (step 2 in `radix_kernel` description) */ template -_RAFT_DEVICE void scan(volatile IdxT* histogram, - const int start, - const int num_buckets, - const IdxT current) +_RAFT_DEVICE void scan(volatile IdxT* histogram) { - typedef cub::BlockScan BlockScan; - __shared__ typename BlockScan::TempStorage temp_storage; + constexpr int num_buckets = calc_num_buckets(); + if constexpr (num_buckets >= BlockSize) { + static_assert(num_buckets % BlockSize == 0); + constexpr int items_per_thread = num_buckets / BlockSize; + typedef cub::BlockLoad BlockLoad; + typedef cub::BlockStore + BlockStore; + typedef cub::BlockScan BlockScan; - IdxT thread_data = 0; - int index = start + threadIdx.x; - if (index < num_buckets) { thread_data = histogram[index]; } + __shared__ union { + typename BlockLoad::TempStorage load; + typename BlockScan::TempStorage scan; + typename BlockStore::TempStorage store; + } temp_storage; + IdxT thread_data[items_per_thread]; - BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); - __syncthreads(); - if (index < num_buckets) { histogram[index] = thread_data + current; } - __syncthreads(); // This sync is necessary, as the content of histogram needs - // to be read after + BlockLoad(temp_storage.load).Load(histogram, thread_data); + __syncthreads(); + + BlockScan(temp_storage.scan).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + BlockStore(temp_storage.store).Store(histogram, thread_data); + } else { + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + IdxT thread_data = 0; + if (threadIdx.x < num_buckets) { thread_data = histogram[threadIdx.x]; } + + BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + if (threadIdx.x < num_buckets) { histogram[threadIdx.x] = thread_data; } + } } /** * Calculate in which bucket the k-th value will fall - * (steps 2-3 in `radix_kernel` description) + * (steps 3 in `radix_kernel` description) */ -template -_RAFT_DEVICE void choose_bucket(Counter* counter, IdxT* histogram, const IdxT k) +template +_RAFT_DEVICE void choose_bucket(Counter* counter, + const IdxT* histogram, + const IdxT k, + const int pass) { constexpr int num_buckets = calc_num_buckets(); - int index = threadIdx.x; - IdxT last_prefix_sum = 0; - int num_pass = 1; - if constexpr (num_buckets >= BlockSize) { - static_assert(num_buckets % BlockSize == 0); - num_pass = num_buckets / BlockSize; + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + IdxT prev = (i == 0) ? 0 : histogram[i - 1]; + IdxT cur = histogram[i]; + + // one and only one thread will satisfy this condition, so counter is written by only one thread + if (prev < k && cur >= k) { + counter->k = k - prev; // how many values still are there to find + counter->len = cur - prev; // number of values in next pass + typename cub::Traits::UnsignedBits bucket = i; + int start_bit = calc_start_bit(pass); + counter->kth_value_bits |= bucket << start_bit; + } } +} - for (int i = 0; i < num_pass && (last_prefix_sum < k); i++) { - // Turn the i-th chunk of the histogram into its prefix sum. - scan(histogram, i * BlockSize, num_buckets, last_prefix_sum); - if (index < num_buckets) { - // Number of values in the previous `index-1` buckets (see the `scan` op above) - IdxT prev = (index == 0) ? 
0 : histogram[index - 1]; - // Number of values in `index` buckets - IdxT cur = histogram[index]; - - // one and only one thread will satisfy this condition, so only write once - if (prev < k && cur >= k) { - counter->k = k - prev; // how many values still are there to find - counter->previous_len = counter->len; - counter->len = cur - prev; // number of values in `index` bucket - counter->bucket = index; +// For one-block version, last_filter() could be called when pass < num_passes - 1. +// So `pass` could not be constexpr +template +_RAFT_DEVICE void last_filter(const T* in_buf, + const IdxT* in_idx_buf, + T* out, + IdxT* out_idx, + IdxT current_len, + IdxT k, + Counter* counter, + const bool select_min, + const int pass) +{ + const auto kth_value_bits = counter->kth_value_bits; + const int start_bit = calc_start_bit(pass); + + // changed in choose_bucket(); need to reload + const IdxT needed_num_of_kth = counter->k; + IdxT* p_out_cnt = &counter->out_cnt; + IdxT* p_out_back_cnt = &counter->out_back_cnt; + for (IdxT i = threadIdx.x; i < current_len; i += blockDim.x) { + const T value = in_buf[i]; + const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; + if (bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + out[pos] = value; + // For one-block version, `in_idx_buf` could be nullptr at pass 0. + // For non one-block version, if writing has been skipped, `in_idx_buf` could be nullptr if + // `in_buf` is `in` + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else if (bits == kth_value_bits) { + IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); + if (back_pos < needed_num_of_kth) { + IdxT pos = k - 1 - back_pos; + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } } - index += BlockSize; - // this will break the loop when the counter is set (cur >= k), because last_prefix_sum >= cur - last_prefix_sum = histogram[(i + 1) * BlockSize - 1]; } } +template +__global__ void last_filter_kernel(const T* in, + const IdxT* in_idx, + const T* in_buf, + const IdxT* in_idx_buf, + T* out, + IdxT* out_idx, + IdxT len, + IdxT k, + Counter* counters, + const bool select_min) +{ + const size_t batch_id = blockIdx.y; // size_t to avoid multiplication overflow + + Counter* counter = counters + batch_id; + IdxT previous_len = counter->previous_len; + if (previous_len == 0) { return; } + const IdxT buf_len = calc_buf_len(len); + if (previous_len > buf_len || in_buf == in) { + in_buf = in + batch_id * len; + in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; + previous_len = len; + } else { + in_buf += batch_id * buf_len; + in_idx_buf += batch_id * buf_len; + } + out += batch_id * k; + out_idx += batch_id * k; + + constexpr int pass = calc_num_passes() - 1; + constexpr int start_bit = calc_start_bit(pass); + + const auto kth_value_bits = counter->kth_value_bits; + const IdxT needed_num_of_kth = counter->k; + IdxT* p_out_cnt = &counter->out_cnt; + IdxT* p_out_back_cnt = &counter->out_back_cnt; + + auto f = [k, + select_min, + kth_value_bits, + needed_num_of_kth, + p_out_cnt, + p_out_back_cnt, + in_idx_buf, + out, + out_idx](T value, IdxT i) { + const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; + if (bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + out[pos] = value; + out_idx[pos] = in_idx_buf ? 
in_idx_buf[i] : i; + } else if (bits == kth_value_bits) { + IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); + if (back_pos < needed_num_of_kth) { + IdxT pos = k - 1 - back_pos; + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + } + }; + + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, + in_buf, + previous_len, + f); +} + /** * * It is expected to call this kernel multiple times (passes), in each pass we process a radix, @@ -350,35 +517,79 @@ _RAFT_DEVICE void choose_bucket(Counter* counter, IdxT* histogram, cons * * In the implementation, the filtering step is delayed to the next pass so the filtering and * histogram computation are fused. In this way, inputs are read once rather than twice. + * + * During the filtering step, we won't write candidates (elements in bucket j) to `out_buf` if the + * number of candidates is larger than the length of `out_buf` (this could happen when the leading + * bits of input values are almost the same). And then in the next pass, inputs are read from `in` + * rather than from `in_buf`. The benefit is that we can save the cost of writing candidates and + * their indices. */ -template -__global__ void __launch_bounds__(BlockSize) radix_kernel(const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, - T* out, - IdxT* out_idx, - Counter* counters, - IdxT* histograms, - const IdxT len, - const int k, - const bool greater, - const int pass) +template +__global__ void radix_kernel(const T* in, + const IdxT* in_idx, + const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + Counter* counters, + IdxT* histograms, + const IdxT len, + const IdxT k, + const bool select_min, + const int pass) { - __shared__ bool isLastBlockDone; + const size_t batch_id = blockIdx.y; + auto counter = counters + batch_id; + IdxT current_k; + IdxT previous_len; + IdxT current_len; + if (pass == 0) { + current_k = k; + previous_len = len; + // Need to do this so setting counter->previous_len for the next pass is correct. + // This value is meaningless for pass 0, but it's fine because pass 0 won't be the + // last pass in this implementation so pass 0 won't hit the "if (pass == + // num_passes - 1)" branch. + // Maybe it's better to reload counter->previous_len and use it rather than + // current_len in last_filter() + current_len = len; + } else { + current_k = counter->k; + current_len = counter->len; + previous_len = counter->previous_len; + } + if (current_len == 0) { return; } - constexpr int num_buckets = calc_num_buckets(); - constexpr int num_passes = calc_num_passes(); - const int batch_id = blockIdx.y; - in_buf += batch_id * len; - out_buf += batch_id * len; + // When k=len, early_stop will be true at pass 0. It means filter_and_histogram() should handle + // correctly the case that pass=0 and early_stop=true. However, this special case of k=len is + // handled in other way in select_k() so such case is not possible here. + const bool early_stop = (current_len == current_k); + const IdxT buf_len = calc_buf_len(len); + + // "previous_len > buf_len" means previous pass skips writing buffer + if (pass == 0 || pass == 1 || previous_len > buf_len) { + in_buf = in + batch_id * len; + in_idx_buf = in_idx ? 
(in_idx + batch_id * len) : nullptr; + previous_len = len; + } else { + in_buf += batch_id * buf_len; + in_idx_buf += batch_id * buf_len; + } + // "current_len > buf_len" means current pass will skip writing buffer + if (pass == 0 || current_len > buf_len) { + out_buf = nullptr; + out_idx_buf = nullptr; + } else { + out_buf += batch_id * buf_len; + out_idx_buf += batch_id * buf_len; + } out += batch_id * k; out_idx += batch_id * k; - if (in_idx_buf) { in_idx_buf += batch_id * len; } - if (out_idx_buf) { out_idx_buf += batch_id * len; } - auto counter = counters + batch_id; - auto histogram = histograms + batch_id * num_buckets; + constexpr int num_buckets = calc_num_buckets(); + auto histogram = histograms + batch_id * num_buckets; filter_and_histogram(in_buf, in_idx_buf, @@ -386,126 +597,468 @@ __global__ void __launch_bounds__(BlockSize) radix_kernel(const T* in_buf, out_idx_buf, out, out_idx, - len, + previous_len, counter, histogram, - greater, + select_min, pass, - k); + early_stop); __threadfence(); + bool isLastBlock = false; if (threadIdx.x == 0) { unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1); - isLastBlockDone = (finished == (gridDim.x - 1)); + isLastBlock = (finished == (gridDim.x - 1)); } - // Synchronize to make sure that each thread reads the correct value of - // isLastBlockDone. - __syncthreads(); - if (isLastBlockDone) { - if (counter->len == 1 && threadIdx.x == 0) { - counter->previous_len = 0; - counter->len = 0; - } - // init counter, other members of counter is initialized with 0 by - // cudaMemset() - if (pass == 0 && threadIdx.x == 0) { - counter->k = k; - counter->len = len; - counter->out_back_cnt = 0; + if (__syncthreads_or(isLastBlock)) { + if (early_stop) { + if (threadIdx.x == 0) { + // `last_filter_kernel()` requires setting previous_len + counter->previous_len = 0; + counter->len = 0; + } + return; } + + scan(histogram); __syncthreads(); + choose_bucket(counter, histogram, current_k, pass); + __syncthreads(); + + constexpr int num_passes = calc_num_passes(); + // reset for next pass + if (pass != num_passes - 1) { + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram[i] = 0; + } + } + if (threadIdx.x == 0) { + // `last_filter_kernel()` requires setting previous_len even in the last pass + counter->previous_len = current_len; + // not necessary for the last pass, but put it here anyway + counter->filter_cnt = 0; + } + + if constexpr (fused_last_filter) { + if (pass == num_passes - 1) { + last_filter(out_buf ? out_buf : in_buf, + out_idx_buf ? out_idx_buf : in_idx_buf, + out, + out_idx, + out_buf ? 
current_len : len, + k, + counter, + select_min, + pass); + } + } + } +} + +template +int calc_chunk_size(int batch_size, IdxT len, int sm_cnt, Kernel kernel) +{ + int active_blocks; + RAFT_CUDA_TRY( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&active_blocks, kernel, BlockSize, 0)); + + constexpr int items_per_thread = 32; + constexpr int num_waves = 10; + int chunk_size = + std::max(1, num_waves * sm_cnt * active_blocks * BlockSize * items_per_thread / len); + return std::min(chunk_size, batch_size); +} - IdxT ori_k = counter->k; +template +unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) +{ + static_assert(VECTORIZED_READ_SIZE / sizeof(T) >= 1); + + int active_blocks; + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &active_blocks, radix_kernel, BlockSize, 0)); + active_blocks *= sm_cnt; - if (counter->len > 0) { - choose_bucket(counter, histogram, ori_k); + IdxT best_num_blocks = 0; + float best_tail_wave_penalty = 1.0f; + const IdxT max_num_blocks = ceildiv(len, VECTORIZED_READ_SIZE / sizeof(T) * BlockSize); + for (int num_waves = 1;; ++num_waves) { + IdxT num_blocks = std::min( + max_num_blocks, static_cast(std::max(num_waves * active_blocks / batch_size, 1))); + IdxT items_per_thread = ceildiv(len, num_blocks * BlockSize); + items_per_thread = alignTo(items_per_thread, VECTORIZED_READ_SIZE / sizeof(T)); + num_blocks = ceildiv(len, items_per_thread * BlockSize); + float actual_num_waves = static_cast(num_blocks) * batch_size / active_blocks; + float tail_wave_penalty = + (ceilf(actual_num_waves) - actual_num_waves) / ceilf(actual_num_waves); + + // 0.15 is determined experimentally. It also ensures breaking the loop early, + // e.g. when num_waves > 7, tail_wave_penalty will always <0.15 + if (tail_wave_penalty < 0.15) { + best_num_blocks = num_blocks; + break; + } else if (tail_wave_penalty < best_tail_wave_penalty) { + best_num_blocks = num_blocks; + best_tail_wave_penalty = tail_wave_penalty; } - __syncthreads(); - if (pass == num_passes - 1) { - const IdxT previous_len = counter->previous_len; - const int want_bucket = counter->bucket; - int start_bit = calc_start_bit(pass); - unsigned mask = calc_mask(pass); - - // radix topk - IdxT& out_cnt = counter->out_cnt; - for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { - const T value = out_buf[i]; - int bucket = calc_bucket(value, start_bit, mask, greater); - if (bucket < want_bucket) { - IdxT pos = atomicAdd(&out_cnt, IdxT(1)); - out[pos] = value; - out_idx[pos] = out_idx_buf[i]; - } else if (bucket == want_bucket) { - IdxT needed_num_of_kth = counter->k; - IdxT back_pos = atomicAdd(&(counter->out_back_cnt), IdxT(1)); - if (back_pos < needed_num_of_kth) { - IdxT pos = k - 1 - back_pos; - out[pos] = value; - out_idx[pos] = out_idx_buf[i]; - } - } + if (num_blocks == max_num_blocks) { break; } + } + return best_num_blocks; +} + +template +_RAFT_HOST_DEVICE void set_buf_pointers(const T* in, + const IdxT* in_idx, + T* buf1, + IdxT* idx_buf1, + T* buf2, + IdxT* idx_buf2, + int pass, + const T*& in_buf, + const IdxT*& in_idx_buf, + T*& out_buf, + IdxT*& out_idx_buf) +{ + if (pass == 0) { + in_buf = in; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + in_buf = in; + in_idx_buf = in_idx; + out_buf = buf1; + out_idx_buf = idx_buf1; + } else if (pass % 2 == 0) { + in_buf = buf1; + in_idx_buf = idx_buf1; + out_buf = buf2; + out_idx_buf = idx_buf2; + } else { + in_buf = buf2; + in_idx_buf = idx_buf2; + out_buf = buf1; + out_idx_buf = idx_buf1; + } +} + 
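Below is a minimal host-only sketch of the pass-to-buffer mapping that `set_buf_pointers()` encodes (the `buffers_for_pass()` helper and the `main()` driver are illustrative and not part of the RAFT sources): pass 0 reads `in` and writes no buffer because filtering is deferred to the next pass, pass 1 reads `in` and writes `buf1`/`idx_buf1`, and later passes ping-pong between the two buffer pairs.

```c++
// Host-only illustration (not part of the patch) of the pass -> buffer
// mapping implemented by set_buf_pointers() above.
#include <cassert>
#include <string>

std::string buffers_for_pass(int pass)
{
  if (pass == 0) { return "in -> (none)"; }  // pass 0 only builds a histogram; filtering is deferred
  if (pass == 1) { return "in -> buf1"; }    // first filtered candidates land in buf1
  return (pass % 2 == 0) ? "buf1 -> buf2"    // even passes read buf1 and write buf2
                         : "buf2 -> buf1";   // odd passes read buf2 and write buf1
}

int main()
{
  assert(buffers_for_pass(0) == "in -> (none)");
  assert(buffers_for_pass(1) == "in -> buf1");
  assert(buffers_for_pass(2) == "buf1 -> buf2");
  assert(buffers_for_pass(3) == "buf2 -> buf1");
  assert(buffers_for_pass(4) == "buf1 -> buf2");
  return 0;
}
```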
+template +void radix_topk(const T* in, + const IdxT* in_idx, + int batch_size, + IdxT len, + IdxT k, + T* out, + IdxT* out_idx, + bool select_min, + bool fused_last_filter, + unsigned grid_dim, + int sm_cnt, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + // TODO: is it possible to relax this restriction? + static_assert(calc_num_passes() > 1); + constexpr int num_buckets = calc_num_buckets(); + + auto kernel = radix_kernel; + const size_t max_chunk_size = + calc_chunk_size(batch_size, len, sm_cnt, kernel); + if (max_chunk_size != static_cast(batch_size)) { + grid_dim = calc_grid_dim(max_chunk_size, len, sm_cnt); + } + const IdxT buf_len = calc_buf_len(len); + + size_t req_aux = max_chunk_size * (sizeof(Counter) + num_buckets * sizeof(IdxT)); + size_t req_buf = max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); + size_t mem_req = req_aux + req_buf + 256 * 6; // might need extra memory for alignment + + auto pool_guard = raft::get_pool_memory_resource(mr, mem_req); + if (pool_guard) { + RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", + pool_guard->pool_size()); + } + + rmm::device_uvector> counters(max_chunk_size, stream, mr); + rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); + rmm::device_uvector buf1(max_chunk_size * buf_len, stream, mr); + rmm::device_uvector idx_buf1(max_chunk_size * buf_len, stream, mr); + rmm::device_uvector buf2(max_chunk_size * buf_len, stream, mr); + rmm::device_uvector idx_buf2(max_chunk_size * buf_len, stream, mr); + + for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { + int chunk_size = std::min(max_chunk_size, batch_size - offset); + RAFT_CUDA_TRY( + cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream)); + + const T* chunk_in = in + offset * len; + const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr; + T* chunk_out = out + offset * k; + IdxT* chunk_out_idx = out_idx + offset * k; + + const T* in_buf = nullptr; + const IdxT* in_idx_buf = nullptr; + T* out_buf = nullptr; + IdxT* out_idx_buf = nullptr; + + dim3 blocks(grid_dim, chunk_size); + constexpr int num_passes = calc_num_passes(); + + for (int pass = 0; pass < num_passes; ++pass) { + set_buf_pointers(chunk_in, + chunk_in_idx, + buf1.data(), + idx_buf1.data(), + buf2.data(), + idx_buf2.data(), + pass, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf); + + if (fused_last_filter && pass == num_passes - 1) { + kernel = radix_kernel; } - __syncthreads(); - } else { - // reset for next pass - for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { - histogram[i] = 0; + + kernel<<>>(chunk_in, + chunk_in_idx, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + chunk_out, + chunk_out_idx, + counters.data(), + histograms.data(), + len, + k, + select_min, + pass); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + if (!fused_last_filter) { + last_filter_kernel<<>>(chunk_in, + chunk_in_idx, + out_buf, + out_idx_buf, + chunk_out, + chunk_out_idx, + len, + k, + counters.data(), + select_min); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + } +} + +// The following a few functions are for the one-block version, which uses single thread block for +// each row of a batch. 
+template +_RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + Counter* counter, + IdxT* histogram, + bool select_min, + int pass) +{ + constexpr int num_buckets = calc_num_buckets(); + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram[i] = 0; + } + IdxT* p_filter_cnt = &counter->filter_cnt; + if (threadIdx.x == 0) { *p_filter_cnt = 0; } + __syncthreads(); + + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); + const IdxT previous_len = counter->previous_len; + + if (pass == 0) { + auto f = [histogram, select_min, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + }; + vectorized_process(threadIdx.x, blockDim.x, in_buf, previous_len, f); + } else { + // not use vectorized_process here because it increases #registers a lot + IdxT* p_out_cnt = &counter->out_cnt; + const auto kth_value_bits = counter->kth_value_bits; + const int previous_start_bit = calc_start_bit(pass - 1); + + for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { + const T value = in_buf[i]; + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if (previous_bits == kth_value_bits) { +#if CUDART_VERSION < 12000 + // Avoiding potential compiler bug in CUDA 11 + volatile +#endif + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; + + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + } else if (previous_bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } - if (threadIdx.x == 0) { counter->filter_cnt = 0; } } } } -/** - * Calculate the minimal batch size, such that GPU is still fully occupied. - */ template -inline dim3 get_optimal_grid_size(size_t req_batch_size, size_t len) +__global__ void radix_topk_one_block_kernel(const T* in, + const IdxT* in_idx, + const IdxT len, + const IdxT k, + T* out, + IdxT* out_idx, + const bool select_min, + T* buf1, + IdxT* idx_buf1, + T* buf2, + IdxT* idx_buf2) { - int dev_id, sm_count, occupancy, max_grid_dim_y; - RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id)); - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&max_grid_dim_y, cudaDevAttrMaxGridDimY, dev_id)); - RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &occupancy, radix_kernel, BlockSize, 0)); - - // number of block we'd use if the batch size is enough to occupy the gpu in any case - size_t blocks_per_row = ceildiv(len, BlockSize * ITEM_PER_THREAD); - - // fully occupy GPU - size_t opt_batch_size = ceildiv(sm_count * occupancy, blocks_per_row); - // round it up to the closest pow-of-two for better data alignment - opt_batch_size = isPo2(opt_batch_size) ? opt_batch_size : (1 << (log2(opt_batch_size) + 1)); - // Take a max possible pow-of-two grid_dim_y - max_grid_dim_y = isPo2(max_grid_dim_y) ? max_grid_dim_y : (1 << log2(max_grid_dim_y)); - // If the optimal batch size is very small compared to the requested batch size, we know - // the extra required memory is not significant and we can increase the batch size for - // better occupancy when the grid size is not multiple of the SM count. 
- // Also don't split the batch size when there is not much work overall. - const size_t safe_enlarge_factor = 9; - const size_t min_grid_size = 1024; - while ((opt_batch_size << safe_enlarge_factor) < req_batch_size || - blocks_per_row * opt_batch_size < min_grid_size) { - opt_batch_size <<= 1; + constexpr int num_buckets = calc_num_buckets(); + __shared__ Counter counter; + __shared__ IdxT histogram[num_buckets]; + + if (threadIdx.x == 0) { + counter.k = k; + counter.len = len; + counter.previous_len = len; + counter.kth_value_bits = 0; + counter.out_cnt = 0; + counter.out_back_cnt = 0; } + __syncthreads(); + + const size_t batch_id = blockIdx.x; // size_t to avoid multiplication overflow + in += batch_id * len; + if (in_idx) { in_idx += batch_id * len; } + out += batch_id * k; + out_idx += batch_id * k; + buf1 += batch_id * len; + idx_buf1 += batch_id * len; + buf2 += batch_id * len; + idx_buf2 += batch_id * len; + const T* in_buf = nullptr; + const IdxT* in_idx_buf = nullptr; + T* out_buf = nullptr; + IdxT* out_idx_buf = nullptr; + + constexpr int num_passes = calc_num_passes(); + for (int pass = 0; pass < num_passes; ++pass) { + set_buf_pointers( + in, in_idx, buf1, idx_buf1, buf2, idx_buf2, pass, in_buf, in_idx_buf, out_buf, out_idx_buf); + + IdxT current_len = counter.len; + IdxT current_k = counter.k; + + filter_and_histogram_for_one_block(in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + &counter, + histogram, + select_min, + pass); + __syncthreads(); + + scan(histogram); + __syncthreads(); + + choose_bucket(&counter, histogram, current_k, pass); + if (threadIdx.x == 0) { counter.previous_len = current_len; } + __syncthreads(); - // Do not exceed the max grid size. - opt_batch_size = std::min(opt_batch_size, size_t(max_grid_dim_y)); - // Don't do more work than needed - opt_batch_size = std::min(opt_batch_size, req_batch_size); - // Let more blocks share one row if the required batch size is too small. - while (opt_batch_size * blocks_per_row < size_t(sm_count * occupancy) && - // Ensure we still can read data somewhat efficiently - len * sizeof(T) > 2 * VECTORIZED_READ_SIZE * BlockSize * blocks_per_row) { - blocks_per_row <<= 1; + if (counter.len == counter.k || pass == num_passes - 1) { + last_filter(pass == 0 ? in : out_buf, + pass == 0 ? in_idx : out_idx_buf, + out, + out_idx, + current_len, + k, + &counter, + select_min, + pass); + break; + } } +} - return dim3(blocks_per_row, opt_batch_size); +// radix_topk() might use multiple thread blocks for one row of a batch. In contrast, the following +// one-block version uses single thread block for one row of a batch, so intermediate data, like +// counters and global histograms, can be kept in shared memory and cheap sync operations can be +// used. It's used when len is relatively small or when the number of blocks per row calculated by +// `calc_grid_dim()` is 1. 
+template +void radix_topk_one_block(const T* in, + const IdxT* in_idx, + int batch_size, + IdxT len, + IdxT k, + T* out, + IdxT* out_idx, + bool select_min, + int sm_cnt, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + static_assert(calc_num_passes() > 1); + + auto kernel = radix_topk_one_block_kernel; + const size_t max_chunk_size = + calc_chunk_size(batch_size, len, sm_cnt, kernel); + + auto pool_guard = + raft::get_pool_memory_resource(mr, + max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT)) + + 256 * 4 // might need extra memory for alignment + ); + if (pool_guard) { + RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", + pool_guard->pool_size()); + } + + rmm::device_uvector buf1(len * max_chunk_size, stream, mr); + rmm::device_uvector idx_buf1(len * max_chunk_size, stream, mr); + rmm::device_uvector buf2(len * max_chunk_size, stream, mr); + rmm::device_uvector idx_buf2(len * max_chunk_size, stream, mr); + + for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { + int chunk_size = std::min(max_chunk_size, batch_size - offset); + kernel<<>>(in + offset * len, + in_idx ? (in_idx + offset * len) : nullptr, + len, + k, + out + offset * k, + out_idx + offset * k, + select_min, + buf1.data(), + idx_buf1.data(), + buf2.data(), + idx_buf2.data()); + } } +} // namespace impl + /** * Select k smallest or largest key/values from each row in the input data. * @@ -546,6 +1099,12 @@ inline dim3 get_optimal_grid_size(size_t req_batch_size, size_t len) * the payload selected together with `out`. * @param select_min * whether to select k smallest (true) or largest (false) keys. + * @param fused_last_filter + * when it's true, the last filter is fused into the kernel in the last pass and only one thread + * block will do the filtering; when false, a standalone filter kernel with multiple thread + * blocks is called. The later case is preferable when leading bits of input data are almost the + * same. That is, when the value range of input data is narrow. In such case, there could be a + * large number of inputs for the last filter, hence using multiple thread blocks is beneficial. * @param stream * @param mr an optional memory resource to use across the calls (you can provide a large enough * memory pool here to avoid memory allocations within the call). @@ -553,109 +1112,65 @@ inline dim3 get_optimal_grid_size(size_t req_batch_size, size_t len) template void select_k(const T* in, const IdxT* in_idx, - size_t batch_size, - size_t len, - int k, + int batch_size, + IdxT len, + IdxT k, T* out, IdxT* out_idx, bool select_min, + bool fused_last_filter, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = nullptr) { - // reduce the block size if the input length is too small. 
- if constexpr (BlockSize > calc_min_block_size()) { - if (BlockSize * ITEM_PER_THREAD > len) { - return select_k( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); + if (k == len) { + RAFT_CUDA_TRY( + cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); + if (in_idx) { + RAFT_CUDA_TRY(cudaMemcpyAsync( + out_idx, in_idx, sizeof(IdxT) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); + } else { + auto out_idx_view = + raft::make_device_vector_view(out_idx, static_cast(len) * batch_size); + raft::device_resources handle(stream); + raft::linalg::map_offset(handle, out_idx_view, raft::mod_const_op(len)); } + return; } - // TODO: is it possible to relax this restriction? - static_assert(calc_num_passes() > 1); - constexpr int num_buckets = calc_num_buckets(); - - dim3 blocks = get_optimal_grid_size(batch_size, len); - size_t max_chunk_size = blocks.y; - - size_t req_aux = max_chunk_size * (sizeof(Counter) + num_buckets * sizeof(IdxT)); - size_t req_buf = max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT)); - size_t mem_req = req_aux + req_buf; - size_t mem_free, mem_total; - RAFT_CUDA_TRY(cudaMemGetInfo(&mem_free, &mem_total)); - std::optional managed_memory; - rmm::mr::device_memory_resource* mr_buf = nullptr; - if (mem_req > mem_free) { - // if there's not enough memory for buffers on the device, resort to the managed memory. - mem_req = req_aux; - managed_memory.emplace(); - mr_buf = &managed_memory.value(); - } - - auto pool_guard = raft::get_pool_memory_resource(mr, mem_req); - if (pool_guard) { - RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", - pool_guard->pool_size()); + // TODO: use device_resources::get_device_properties() instead; should change it when we refactor + // resource management + int sm_cnt; + { + int dev; + RAFT_CUDA_TRY(cudaGetDevice(&dev)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev)); } - if (mr_buf == nullptr) { mr_buf = mr; } - - rmm::device_uvector> counters(max_chunk_size, stream, mr); - rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); - rmm::device_uvector buf1(max_chunk_size * len, stream, mr_buf); - rmm::device_uvector idx_buf1(max_chunk_size * len, stream, mr_buf); - rmm::device_uvector buf2(max_chunk_size * len, stream, mr_buf); - rmm::device_uvector idx_buf2(max_chunk_size * len, stream, mr_buf); - for (size_t offset = 0; offset < batch_size; offset += max_chunk_size) { - blocks.y = std::min(max_chunk_size, batch_size - offset); + constexpr int items_per_thread = 32; - RAFT_CUDA_TRY( - cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream)); - - const T* in_buf = nullptr; - const IdxT* in_idx_buf = nullptr; - T* out_buf = nullptr; - IdxT* out_idx_buf = nullptr; - - constexpr int num_passes = calc_num_passes(); - - for (int pass = 0; pass < num_passes; ++pass) { - if (pass == 0) { - in_buf = in + offset * len; - in_idx_buf = nullptr; - out_buf = nullptr; - out_idx_buf = nullptr; - } else if (pass == 1) { - in_buf = in + offset * len; - in_idx_buf = in_idx ? 
in_idx + offset * len : nullptr; - out_buf = buf1.data(); - out_idx_buf = idx_buf1.data(); - } else if (pass % 2 == 0) { - in_buf = buf1.data(); - in_idx_buf = idx_buf1.data(); - out_buf = buf2.data(); - out_idx_buf = idx_buf2.data(); - } else { - in_buf = buf2.data(); - in_idx_buf = idx_buf2.data(); - out_buf = buf1.data(); - out_idx_buf = idx_buf1.data(); - } - - radix_kernel - <<>>(in_buf, - in_idx_buf, - out_buf, - out_idx_buf, - out + offset * k, - out_idx + offset * k, - counters.data(), - histograms.data(), - len, - k, - !select_min, - pass); - RAFT_CUDA_TRY(cudaPeekAtLastError()); + if (len <= BlockSize * items_per_thread) { + impl::radix_topk_one_block( + in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); + } else { + unsigned grid_dim = + impl::calc_grid_dim(batch_size, len, sm_cnt); + if (grid_dim == 1) { + impl::radix_topk_one_block( + in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); + } else { + impl::radix_topk(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + fused_last_filter, + grid_dim, + sm_cnt, + stream, + mr); } } } diff --git a/cpp/include/raft/spatial/knn/knn.cuh b/cpp/include/raft/spatial/knn/knn.cuh index 692d262043..a7bbfd9500 100644 --- a/cpp/include/raft/spatial/knn/knn.cuh +++ b/cpp/include/raft/spatial/knn/knn.cuh @@ -153,12 +153,12 @@ template case SelectKAlgo::RADIX_8_BITS: matrix::detail::select::radix::select_k( - in_keys, in_values, n_inputs, input_len, k, out_keys, out_values, select_min, stream); + in_keys, in_values, n_inputs, input_len, k, out_keys, out_values, select_min, true, stream); break; case SelectKAlgo::RADIX_11_BITS: matrix::detail::select::radix::select_k( - in_keys, in_values, n_inputs, input_len, k, out_keys, out_values, select_min, stream); + in_keys, in_values, n_inputs, input_len, k, out_keys, out_values, select_min, true, stream); break; case SelectKAlgo::WARP_SORT: diff --git a/cpp/internal/raft_internal/matrix/select_k.cuh b/cpp/internal/raft_internal/matrix/select_k.cuh index ede6382c33..188122c9b4 100644 --- a/cpp/internal/raft_internal/matrix/select_k.cuh +++ b/cpp/internal/raft_internal/matrix/select_k.cuh @@ -33,7 +33,8 @@ struct params { size_t len; int k; bool select_min; - bool use_index_input = true; + bool use_index_input = true; + bool use_same_leading_bits = false; }; inline auto operator<<(std::ostream& os, const params& ss) -> std::ostream& @@ -42,7 +43,8 @@ inline auto operator<<(std::ostream& os, const params& ss) -> std::ostream& os << ", len: " << ss.len; os << ", k: " << ss.k; os << (ss.select_min ? ", asc" : ", dsc"); - os << (ss.use_index_input ? "}" : ", no-input-index}"); + os << (ss.use_index_input ? "" : ", no-input-index"); + os << (ss.use_same_leading_bits ? 
", same-leading-bits}" : "}"); return os; } @@ -50,6 +52,7 @@ enum class Algo { kPublicApi, kRadix8bits, kRadix11bits, + kRadix11bitsExtraPass, kWarpAuto, kWarpImmediate, kWarpFiltered, @@ -63,6 +66,7 @@ inline auto operator<<(std::ostream& os, const Algo& algo) -> std::ostream& case Algo::kPublicApi: return os << "kPublicApi"; case Algo::kRadix8bits: return os << "kRadix8bits"; case Algo::kRadix11bits: return os << "kRadix11bits"; + case Algo::kRadix11bitsExtraPass: return os << "kRadix11bitsExtraPass"; case Algo::kWarpAuto: return os << "kWarpAuto"; case Algo::kWarpImmediate: return os << "kWarpImmediate"; case Algo::kWarpFiltered: return os << "kWarpFiltered"; @@ -103,11 +107,38 @@ void select_k_impl(const device_resources& handle, } } case Algo::kRadix8bits: - return detail::select::radix::select_k( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); + return detail::select::radix::select_k(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + true, // fused_last_filter + stream); case Algo::kRadix11bits: - return detail::select::radix::select_k( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); + return detail::select::radix::select_k(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + true, // fused_last_filter + stream); + case Algo::kRadix11bitsExtraPass: + return detail::select::radix::select_k(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + false, // fused_last_filter + stream); case Algo::kWarpAuto: return detail::select::warpsort::select_k( in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); diff --git a/cpp/test/matrix/select_k.cu b/cpp/test/matrix/select_k.cu index 392464eb27..2a40d70abc 100644 --- a/cpp/test/matrix/select_k.cu +++ b/cpp/test/matrix/select_k.cu @@ -332,6 +332,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT testing::Values(select::Algo::kPublicApi, select::Algo::kRadix8bits, select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass, select::Algo::kWarpImmediate, select::Algo::kWarpFiltered, select::Algo::kWarpDistributed))); @@ -426,6 +427,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT testing::Combine(inputs_random_longlist, testing::Values(select::Algo::kRadix8bits, select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass, select::Algo::kWarpImmediate, select::Algo::kWarpFiltered, select::Algo::kWarpDistributed, @@ -440,6 +442,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT testing::Combine(inputs_random_longlist, testing::Values(select::Algo::kRadix8bits, select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass, select::Algo::kWarpImmediate, select::Algo::kWarpFiltered, select::Algo::kWarpDistributed, @@ -451,7 +454,11 @@ TEST_P(ReferencedRandomDoubleInt, LargeSize) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P( // NOLINT SelectK, ReferencedRandomDoubleInt, - testing::Combine(inputs_random_largesize, testing::Values(select::Algo::kWarpAuto))); + testing::Combine(inputs_random_largesize, + testing::Values(select::Algo::kWarpAuto, + select::Algo::kRadix8bits, + select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass))); using ReferencedRandomFloatSizeT = SelectK::params_random>; @@ -459,6 +466,7 @@ TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P(SelectK, // NOLINT ReferencedRandomFloatSizeT, testing::Combine(inputs_random_largek, - testing::Values(select::Algo::kRadix11bits))); + testing::Values(select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass))); } // namespace raft::matrix 
From 9389108b7f7f100c883092a46f09738f23ab8151 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 24 Mar 2023 13:42:03 -0400 Subject: [PATCH 10/13] RAFT skeleton project template (#1312) This is a copy and modification of a user's project but I think this is going to be generally useful to users as the same types of challenges are going to come up again. In this case, the user wasn't able to build/link because they weren't using `rapids-cmake` to propagate important configuration settings. I think having a skeleton project available that we build in CI and keep up to date will help new users build more applications on RAFT. TODO: - [x] Make building the template optional - [x] Verify this can build in CMake and reuse already built/installed bits - [x] Add to docs / readme and reference in README.md - [x] Add a little example of invoking an API (maybe `pairwise_distances`?) to `main()` Authors: - Corey J. Nolet (https://github.com/cjnolet) - Ben Frederickson (https://github.com/benfred) Approvers: - Micka (https://github.com/lowener) - Dante Gama Dessavre (https://github.com/dantegd) - Divye Gala (https://github.com/divyegala) - AJ Schmidt (https://github.com/ajschmidt8) URL: https://github.com/rapidsai/raft/pull/1312 --- .github/workflows/build.yaml | 1 + README.md | 61 ++++-------------- build.sh | 15 ++++- .../recipes/libraft/build_libraft_template.sh | 5 ++ conda/recipes/libraft/meta.yaml | 36 +++++++++++ cpp/template/CMakeLists.txt | 38 +++++++++++ cpp/template/README.md | 18 ++++++ cpp/template/build.sh | 41 ++++++++++++ .../cmake/thirdparty/fetch_rapids.cmake | 21 ++++++ cpp/template/cmake/thirdparty/get_raft.cmake | 62 ++++++++++++++++++ cpp/template/src/test_distance.cu | 42 ++++++++++++ docs/source/build.md | 64 ++++++++----------- python/pylibraft/setup.cfg | 38 +++++++++++ setup.cfg | 55 ++++++++++++++++ 14 files changed, 409 insertions(+), 88 deletions(-) create mode 100644 conda/recipes/libraft/build_libraft_template.sh create mode 100644 cpp/template/CMakeLists.txt create mode 100644 cpp/template/README.md create mode 100755 cpp/template/build.sh create mode 100644 cpp/template/cmake/thirdparty/fetch_rapids.cmake create mode 100644 cpp/template/cmake/thirdparty/get_raft.cmake create mode 100644 cpp/template/src/test_distance.cu create mode 100644 python/pylibraft/setup.cfg create mode 100644 setup.cfg diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 41b6a639d8..32aab5656b 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -52,6 +52,7 @@ jobs: branch: ${{ inputs.branch }} date: ${{ inputs.date }} sha: ${{ inputs.sha }} + skip_upload_pkgs: libraft-template docs-build: if: github.ref_type == 'branch' && github.event_name == 'push' needs: python-build diff --git a/README.md b/README.md index 8c6e817bf9..81973ff82e 100755 --- a/README.md +++ b/README.md @@ -1,6 +1,5 @@ #
RAFT: Reusable Accelerated Functions and Tools
- ![Navigating the canyons of accelerated possibilities](img/raft.png) ## Resources @@ -85,9 +84,9 @@ raft::device_resources handle; int n_samples = 5000; int n_features = 50; -auto input = raft::make_device_matrix(handle, n_samples, n_features); -auto labels = raft::make_device_vector(handle, n_samples); -auto output = raft::make_device_matrix(handle, n_samples, n_samples); +auto input = raft::make_device_matrix(handle, n_samples, n_features); +auto labels = raft::make_device_vector(handle, n_samples); +auto output = raft::make_device_matrix(handle, n_samples, n_samples); raft::random::make_blobs(handle, input.view(), labels.view()); @@ -222,52 +221,15 @@ pip install raft-dask-cu11 --extra-index-url=https://pypi.ngc.nvidia.com ### CMake & CPM -RAFT uses the [RAPIDS-CMake](https://github.com/rapidsai/rapids-cmake) library, which makes it simple to include in downstream cmake projects. RAPIDS CMake provides a convenience layer around CPM. - -After [installing](https://github.com/rapidsai/rapids-cmake#installation) rapids-cmake in your project, you can begin using RAFT by placing the code snippet below in a file named `get_raft.cmake` and including it in your cmake build with `include(get_raft.cmake)`. This will make available several targets to add to configure the link libraries for your artifacts. - -```cmake - -set(RAFT_VERSION "22.12") -set(RAFT_FORK "rapidsai") -set(RAFT_PINNED_TAG "branch-${RAFT_VERSION}") - -function(find_and_configure_raft) - set(oneValueArgs VERSION FORK PINNED_TAG COMPILE_LIBRARIES) - cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" - "${multiValueArgs}" ${ARGN} ) - - #----------------------------------------------------- - # Invoke CPM find_package() - #----------------------------------------------------- - - rapids_cpm_find(raft ${PKG_VERSION} - GLOBAL_TARGETS raft::raft - BUILD_EXPORT_SET projname-exports - INSTALL_EXPORT_SET projname-exports - CPM_ARGS - GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git - GIT_TAG ${PKG_PINNED_TAG} - SOURCE_SUBDIR cpp - OPTIONS - "BUILD_TESTS OFF" - "BUILD_BENCH OFF" - "RAFT_COMPILE_LIBRARIES ${PKG_COMPILE_LIBRARIES}" - ) - -endfunction() - -# Change pinned tag here to test a commit in CI -# To use a different RAFT locally, set the CMake variable -# CPM_raft_SOURCE=/path/to/local/raft -find_and_configure_raft(VERSION ${RAFT_VERSION}.00 - FORK ${RAFT_FORK} - PINNED_TAG ${RAFT_PINNED_TAG} - COMPILE_LIBRARIES NO -) -``` +RAFT uses the [RAPIDS-CMake](https://github.com/rapidsai/rapids-cmake) library, which makes it easy to include in downstream cmake projects. RAPIDS-CMake provides a convenience layer around CPM. Please refer to [these instructions](https://github.com/rapidsai/rapids-cmake#installation) to install and use rapids-cmake in your project. + +#### Example Template Project + +You can find an [example RAFT](cpp/template/README.md) project template in the `cpp/template` directory, which demonstrates how to build a new application with RAFT or incorporate RAFT into an existing cmake project. + +#### CMake Targets -Several CMake targets can be made available by adding components in the table below to the `RAFT_COMPONENTS` list above, separated by spaces. The `raft::raft` target will always be available. RAFT headers require, at a minimum, the CUDA toolkit libraries and RMM dependencies. +Additional CMake targets can be made available by adding components in the table below to the `RAFT_COMPONENTS` list above, separated by spaces. 
The `raft::raft` target will always be available. RAFT headers require, at a minimum, the CUDA toolkit libraries and RMM dependencies. | Component | Target | Description | Base Dependencies | |-------------|---------------------|-----------------------------------------------------------|---------------------------------------| @@ -321,6 +283,7 @@ The folder structure mirrors other RAPIDS repos, with the following folders: - `internal`: A private header-only component that hosts the code shared between benchmarks and tests. - `scripts`: Helpful scripts for development - `src`: Compiled APIs and template specializations for the shared libraries + - `template`: A skeleton template containing the bare-bones file structure and cmake configuration for writing applications with RAFT. - `test`: Googletests source code - `docs`: Source code and scripts for building library documentation (Uses breath, doxygen, & pydocs) - `python`: Source code for Python libraries. diff --git a/build.sh b/build.sh index b5a72f4205..9468d2cab0 100755 --- a/build.sh +++ b/build.sh @@ -18,7 +18,7 @@ ARGS=$* # script, and that this script resides in the repo dir! REPODIR=$(cd $(dirname $0); pwd) -VALIDARGS="clean libraft pylibraft raft-dask docs tests bench clean --uninstall -v -g -n --compile-lib --allgpuarch --no-nvtx --show_depr_warn -h" +VALIDARGS="clean libraft pylibraft raft-dask docs tests bench template clean --uninstall -v -g -n --compile-lib --allgpuarch --no-nvtx --show_depr_warn -h" HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool=] [--limit-tests=] [--limit-bench=] where is: clean - remove all existing build artifacts and configuration (start over) @@ -29,6 +29,7 @@ HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool= is: -v - verbose build mode @@ -354,13 +355,12 @@ if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || has -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} \ -DCMAKE_CUDA_ARCHITECTURES=${RAFT_CMAKE_CUDA_ARCHITECTURES} \ -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ - -DRAFT_COMPILE_LIBRARIES=${COMPILE_LIBRARIES} \ + -DRAFT_COMPILE_LIBRARY=${COMPILE_LIBRARY} \ -DRAFT_NVTX=${NVTX} \ -DDISABLE_DEPRECATION_WARNINGS=${DISABLE_DEPRECATION_WARNINGS} \ -DBUILD_TESTS=${BUILD_TESTS} \ -DBUILD_BENCH=${BUILD_BENCH} \ -DCMAKE_MESSAGE_LOG_LEVEL=${CMAKE_LOG_LEVEL} \ - -DRAFT_COMPILE_LIBRARY=${COMPILE_LIBRARY} \ ${CACHE_ARGS} \ ${EXTRA_CMAKE_ARGS} @@ -410,3 +410,12 @@ if hasArg docs; then cd ${SPHINX_BUILD_DIR} sphinx-build -b html source _html fi + +################################################################################ +# Initiate build for example RAFT application template (if needed) + +if hasArg template; then + pushd cpp/template + ./build.sh + popd +fi diff --git a/conda/recipes/libraft/build_libraft_template.sh b/conda/recipes/libraft/build_libraft_template.sh new file mode 100644 index 0000000000..9759402884 --- /dev/null +++ b/conda/recipes/libraft/build_libraft_template.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2023, NVIDIA CORPORATION. 
+ +# Just building template so we verify it uses libraft.so and fail if it doesn't build +./build.sh template \ No newline at end of file diff --git a/conda/recipes/libraft/meta.yaml b/conda/recipes/libraft/meta.yaml index 2a724672ab..f911166a9a 100644 --- a/conda/recipes/libraft/meta.yaml +++ b/conda/recipes/libraft/meta.yaml @@ -150,3 +150,39 @@ outputs: home: https://rapids.ai/ license: Apache-2.0 summary: libraft tests + - name: libraft-template + version: {{ version }} + script: build_libraft_template.sh + build: + script_env: *script_env + number: {{ GIT_DESCRIBE_NUMBER }} + string: cuda{{ cuda_major }}_{{ date_string }}_{{ GIT_DESCRIBE_HASH }}_{{ GIT_DESCRIBE_NUMBER }} + ignore_run_exports_from: + - {{ compiler('cuda') }} + requirements: + build: + - {{ compiler('c') }} + - {{ compiler('cuda') }} {{ cuda_version }} + - {{ compiler('cxx') }} + - cmake {{ cmake_version }} + - ninja + - sysroot_{{ target_platform }} {{ sysroot_version }} + host: + - {{ pin_subpackage('libraft', exact=True) }} + - {{ pin_subpackage('libraft-headers', exact=True) }} + - cuda-profiler-api {{ cuda_profiler_api_host_version }} + - libcublas {{ libcublas_host_version }} + - libcublas-dev {{ libcublas_host_version }} + - libcurand {{ libcurand_host_version }} + - libcurand-dev {{ libcurand_host_version }} + - libcusolver {{ libcusolver_host_version }} + - libcusolver-dev {{ libcusolver_host_version }} + - libcusparse {{ libcusparse_host_version }} + - libcusparse-dev {{ libcusparse_host_version }} + run: + - {{ pin_subpackage('libraft', exact=True) }} + - {{ pin_subpackage('libraft-headers', exact=True) }} + about: + home: https://rapids.ai/ + license: Apache-2.0 + summary: libraft template diff --git a/cpp/template/CMakeLists.txt b/cpp/template/CMakeLists.txt new file mode 100644 index 0000000000..501a5c9503 --- /dev/null +++ b/cpp/template/CMakeLists.txt @@ -0,0 +1,38 @@ +# ============================================================================= +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. 
+ +cmake_minimum_required(VERSION 3.23.1 FATAL_ERROR) + +# ------------- configure rapids-cmake --------------# + +include(cmake/thirdparty/fetch_rapids.cmake) +include(rapids-cmake) +include(rapids-cpm) +include(rapids-cuda) +include(rapids-export) +include(rapids-find) + +# ------------- configure project --------------# + +rapids_cuda_init_architectures(test_raft) + +project(test_raft LANGUAGES CXX CUDA) + +# ------------- configure raft -----------------# + +rapids_cpm_init() +include(cmake/thirdparty/get_raft.cmake) + +# -------------- compile tasks ----------------- # +add_executable(TEST_RAFT src/test_distance.cu) +target_link_libraries(TEST_RAFT PRIVATE raft::raft raft::compiled) diff --git a/cpp/template/README.md b/cpp/template/README.md new file mode 100644 index 0000000000..348dff270a --- /dev/null +++ b/cpp/template/README.md @@ -0,0 +1,18 @@ +# Example RAFT Project Template + +This template project provides a drop-in sample to either start building a new application with, or using RAFT in an existing CMake project. + +First, please refer to our [installation docs](https://docs.rapids.ai/api/raft/stable/build.html#cuda-gpu-requirements) for the minimum requirements to use RAFT. + +Once the minimum requirements are satisfied, this example template application can be built with the provided `build.sh` script. This is a bash script that calls the appropriate CMake commands, so you can look into it to see the typical CMake based build workflow. + +This directory (`RAFT_SOURCE/cpp/template`) can be copied directly in order to build a new application with RAFT. + +RAFT can be integrated into an existing CMake project by copying the contents in the `configure rapids-cmake` and `configure raft` sections of the provided `CMakeLists.txt` into your project, along with `cmake/thirdparty/get_raft.cmake`. + +Make sure to link against the appropriate Cmake targets. Use `raft::raft`to add make the headers available and `raft::compiled` when utilizing the shared library. + +```cmake +target_link_libraries(your_app_target PRIVATE raft::raft raft::compiled) +``` + diff --git a/cpp/template/build.sh b/cpp/template/build.sh new file mode 100755 index 0000000000..3ac00fc9af --- /dev/null +++ b/cpp/template/build.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +# Copyright (c) 2023, NVIDIA CORPORATION. + +# raft empty project template build script + +# Abort script on first error +set -e + +PARALLEL_LEVEL=${PARALLEL_LEVEL:=`nproc`} + +BUILD_TYPE=Release +BUILD_DIR=build/ + +RAFT_REPO_REL="" +EXTRA_CMAKE_ARGS="" +set -e + + +if [[ ${RAFT_REPO_REL} != "" ]]; then + RAFT_REPO_PATH="`readlink -f \"${RAFT_REPO_REL}\"`" + EXTRA_CMAKE_ARGS="${EXTRA_CMAKE_ARGS} -DCPM_raft_SOURCE=${RAFT_REPO_PATH}" +fi + +if [ "$1" == "clean" ]; then + rm -rf build + exit 0 +fi + +mkdir -p $BUILD_DIR +cd $BUILD_DIR + +cmake \ + -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ + -DRAFT_NVTX=OFF \ + -DCMAKE_CUDA_ARCHITECTURES="NATIVE" \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ + ${EXTRA_CMAKE_ARGS} \ + ../ + +cmake --build . -j${PARALLEL_LEVEL} diff --git a/cpp/template/cmake/thirdparty/fetch_rapids.cmake b/cpp/template/cmake/thirdparty/fetch_rapids.cmake new file mode 100644 index 0000000000..40ba83be9e --- /dev/null +++ b/cpp/template/cmake/thirdparty/fetch_rapids.cmake @@ -0,0 +1,21 @@ +# ============================================================================= +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. + +# Use this variable to update RAPIDS and RAFT versions +set(RAPIDS_VERSION "23.04") + +if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) + file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake + ${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) +endif() +include(${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) diff --git a/cpp/template/cmake/thirdparty/get_raft.cmake b/cpp/template/cmake/thirdparty/get_raft.cmake new file mode 100644 index 0000000000..5463942adf --- /dev/null +++ b/cpp/template/cmake/thirdparty/get_raft.cmake @@ -0,0 +1,62 @@ +# ============================================================================= +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. + +# Use RAPIDS_VERSION from cmake/thirdparty/fetch_rapids.cmake +set(RAFT_VERSION "${RAPIDS_VERSION}") +set(RAFT_FORK "rapidsai") +set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") + +function(find_and_configure_raft) + set(oneValueArgs VERSION FORK PINNED_TAG COMPILE_LIBRARY ENABLE_NVTX ENABLE_MNMG_DEPENDENCIES) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN} ) + + set(RAFT_COMPONENTS "") + if(PKG_COMPILE_LIBRARY) + string(APPEND RAFT_COMPONENTS " compiled") + endif() + + if(PKG_ENABLE_MNMG_DEPENDENCIES) + string(APPEND RAFT_COMPONENTS " distributed") + endif() + + #----------------------------------------------------- + # Invoke CPM find_package() + #----------------------------------------------------- + rapids_cpm_find(raft ${PKG_VERSION} + GLOBAL_TARGETS raft::raft + BUILD_EXPORT_SET raft-template-exports + INSTALL_EXPORT_SET raft-template-exports + COMPONENTS ${RAFT_COMPONENTS} + CPM_ARGS + GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git + GIT_TAG ${PKG_PINNED_TAG} + SOURCE_SUBDIR cpp + OPTIONS + "BUILD_TESTS OFF" + "BUILD_BENCH OFF" + "RAFT_NVTX ${ENABLE_NVTX}" + "RAFT_COMPILE_LIBRARY ${PKG_COMPILE_LIBRARY}" + ) +endfunction() + +# Change pinned tag here to test a commit in CI +# To use a different RAFT locally, set the CMake variable +# CPM_raft_SOURCE=/path/to/local/raft +find_and_configure_raft(VERSION ${RAFT_VERSION}.00 + FORK ${RAFT_FORK} + PINNED_TAG ${RAFT_PINNED_TAG} + COMPILE_LIBRARY ON + ENABLE_MNMG_DEPENDENCIES OFF + ENABLE_NVTX OFF +) diff --git a/cpp/template/src/test_distance.cu b/cpp/template/src/test_distance.cu new file mode 100644 index 0000000000..b86dde70e5 --- /dev/null +++ b/cpp/template/src/test_distance.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#ifdef RAFT_COMPILED +#include +#endif + +int main() +{ + raft::device_resources handle; + + int n_samples = 5000; + int n_features = 50; + + auto input = raft::make_device_matrix(handle, n_samples, n_features); + auto labels = raft::make_device_vector(handle, n_samples); + auto output = raft::make_device_matrix(handle, n_samples, n_samples); + + raft::random::make_blobs(handle, input.view(), labels.view()); + + auto metric = raft::distance::DistanceType::L2SqrtExpanded; + raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric); +} diff --git a/docs/source/build.md b/docs/source/build.md index bbb454736a..e8e6ac8a14 100644 --- a/docs/source/build.md +++ b/docs/source/build.md @@ -75,7 +75,7 @@ Once installed, `libraft` headers (and dependencies which were downloaded and in ``` -### C++ Shared Libraries (optional) +### C++ Shared Library (optional) A shared library can be built for speeding up compile times. The shared library also contains a runtime API that allows you to invoke RAFT APIs directly from C++ source files (without `nvcc`). The shared library can also significantly improve re-compile times both while developing RAFT and using its APIs to develop applications. Pass the `--compile-lib` flag to `build.sh` to build the library: ```bash @@ -109,7 +109,7 @@ Compile the tests using the `tests` target in `build.sh`. Test compile times can be improved significantly by using the optional shared libraries. If installed, they will be used automatically when building the tests but `--compile-libs` can be used to add additional compilation units and compile them with the tests. ```bash -./build.sh libraft tests --compile-libs +./build.sh libraft tests --compile-lib ``` The tests are broken apart by algorithm category, so you will find several binaries in `cpp/build/` named `*_TEST`. @@ -151,19 +151,17 @@ make -j install RAFT's cmake has the following configurable flags available:. -| Flag | Possible Values | Default Value | Behavior | -| --- | --- | --- | --- | -| BUILD_TESTS | ON, OFF | ON | Compile Googletests | -| BUILD_BENCH | ON, OFF | OFF | Compile benchmarks | -| raft_FIND_COMPONENTS | nn distance | | Configures the optional components as a space-separated list | -| RAFT_COMPILE_LIBRARIES | ON, OFF | ON if either BUILD_TESTS or BUILD_BENCH is ON; otherwise OFF | Compiles all `libraft` shared libraries (these are required for Googletests) | -| RAFT_COMPILE_NN_LIBRARY | ON, OFF | OFF | Compiles the `libraft-nn` shared library | -| RAFT_COMPILE_DIST_LIBRARY | ON, OFF | OFF | Compiles the `libraft-distance` shared library | -| DETECT_CONDA_ENV | ON, OFF | ON | Enable detection of conda environment for dependencies | -| RAFT_NVTX | ON, OFF | OFF | Enable NVTX Markers | -| CUDA_ENABLE_KERNELINFO | ON, OFF | OFF | Enables `kernelinfo` in nvcc. 
This is useful for `compute-sanitizer` | -| CUDA_ENABLE_LINEINFO | ON, OFF | OFF | Enable the -lineinfo option for nvcc | -| CUDA_STATIC_RUNTIME | ON, OFF | OFF | Statically link the CUDA runtime | +| Flag | Possible Values | Default Value | Behavior | +|---------------------------|----------------------| --- | --- | +| BUILD_TESTS | ON, OFF | ON | Compile Googletests | +| BUILD_BENCH | ON, OFF | OFF | Compile benchmarks | +| raft_FIND_COMPONENTS | compiled distributed | | Configures the optional components as a space-separated list | +| RAFT_COMPILE_LIBRARY | ON, OFF | ON if either BUILD_TESTS or BUILD_BENCH is ON; otherwise OFF | Compiles all `libraft` shared libraries (these are required for Googletests) | +| DETECT_CONDA_ENV | ON, OFF | ON | Enable detection of conda environment for dependencies | +| RAFT_NVTX | ON, OFF | OFF | Enable NVTX Markers | +| CUDA_ENABLE_KERNELINFO | ON, OFF | OFF | Enables `kernelinfo` in nvcc. This is useful for `compute-sanitizer` | +| CUDA_ENABLE_LINEINFO | ON, OFF | OFF | Enable the -lineinfo option for nvcc | +| CUDA_STATIC_RUNTIME | ON, OFF | OFF | Statically link the CUDA runtime | Currently, shared libraries are provided for the `libraft-nn` and `libraft-distance` components. @@ -248,9 +246,9 @@ PROPERTIES CXX_STANDARD 17 ``` -### C++ header-only integration +### C++ header-only integration (without cmake) -When the needed [build dependencies](#build-dependencies) are already satisfied, RAFT can be trivially integrated into downstream projects by cloning the repository and adding `cpp/include` from RAFT to the include path: +While not a highly suggested method for building against RAFT, when all of the needed [build dependencies](#build-dependencies) are already satisfied, RAFT can be integrated into downstream projects by cloning the repository and adding `cpp/include` from RAFT to the include path: ```cmake set(RAFT_GIT_DIR ${CMAKE_CURRENT_BINARY_DIR}/raft CACHE STRING "Path to RAFT repo") ExternalProject_Add(raft @@ -262,8 +260,12 @@ ExternalProject_Add(raft INSTALL_COMMAND "") set(RAFT_INCLUDE_DIR ${RAFT_GIT_DIR}/raft/cpp/include CACHE STRING "RAFT include variable") ``` +### C++ header-only integration (with cmake) -If RAFT has already been installed, such as by using the `build.sh` script, use `find_package(raft)` and the `raft::raft` target. + +When using cmake, you can install RAFT headers into your environment with `./build.sh libraft`. + +If the RAFT headers have already been installed into your environment with cmake or through conda, such as by using the `build.sh` script, use `find_package(raft)` and the `raft::raft` target. ### Using C++ pre-compiled shared libraries @@ -271,17 +273,19 @@ Use `find_package(raft COMPONENTS compiled distributed)` to enable the shared li The pre-compiled libraries contain template specializations for commonly used types, such as single- and double-precision floating-point. In order to use the symbols in the pre-compiled libraries, the compiler needs to be told not to instantiate templates that are already contained in the shared libraries. By convention, these header files are named `specializations.cuh` and located in the base directory for the packages that contain specializations. 
-The following example tells the compiler to ignore the pre-compiled templates for the `raft::distance` API so any symbols already compiled into the `libraft` shared library will be used instead: +The following example tells the compiler to ignore the pre-compiled templates for the `raft::distance` API so any symbols already compiled into the `libraft` shared library will be used instead. RAFT's cmake creates a variable `RAFT_COMPILED` which can be used to ignore the pre-compiled template specializations only when the shared library has been enabled through cmake (such as by specifying the `compiled` component in `find_package`): ```c++ +#ifdef RAFT_COMPILED #include #include +#endif ``` ### Building RAFT C++ from source in cmake RAFT uses the [RAPIDS-CMake](https://github.com/rapidsai/rapids-cmake) library so it can be more easily included into downstream projects. RAPIDS cmake provides a convenience layer around the [CMake Package Manager (CPM)](https://github.com/cpm-cmake/CPM.cmake). -The following example is similar to invoking `find_package(raft)` but uses `rapids_cpm_find`, which provides a richer and more flexible configuration landscape by using CPM to fetch any dependencies not already available to the build. The `raft::raft` link target will be made available and it's recommended that it be used as a `PRIVATE` link dependency in downstream projects. The `COMPILE_LIBRARIES` option enables the building the shared libraries. +The following example is similar to invoking `find_package(raft)` but uses `rapids_cpm_find`, which provides a richer and more flexible configuration landscape by using CPM to fetch any dependencies not already available to the build. The `raft::raft` link target will be made available and it's recommended that it be used as a `PRIVATE` link dependency in downstream projects. The `COMPILE_LIBRARY` option enables the building the shared libraries. The following `cmake` snippet enables a flexible configuration of RAFT: @@ -292,19 +296,10 @@ set(RAFT_FORK "rapidsai") set(RAFT_PINNED_TAG "branch-${RAFT_VERSION}") function(find_and_configure_raft) - set(oneValueArgs VERSION FORK PINNED_TAG COMPILE_LIBRARY CLONE_ON_PIN) + set(oneValueArgs VERSION FORK PINNED_TAG COMPILE_LIBRARY) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) - - #----------------------------------------------------- - # Clone RAFT locally if PINNED_TAG has been changed - #----------------------------------------------------- - if(PKG_CLONE_ON_PIN AND NOT PKG_PINNED_TAG STREQUAL "branch-${RAFT_VERSION}") - message("Pinned tag found: ${PKG_PINNED_TAG}. Cloning raft locally.") - set(CPM_DOWNLOAD_raft ON) - set(CMAKE_IGNORE_PATH "${CMAKE_INSTALL_PREFIX}/include/raft;${CMAKE_IGNORE_PATH}) - endif() - + #----------------------------------------------------- # Invoke CPM find_package() #----------------------------------------------------- @@ -332,15 +327,12 @@ endfunction() find_and_configure_raft(VERSION ${RAFT_VERSION}.00 FORK ${RAFT_FORK} PINNED_TAG ${RAFT_PINNED_TAG} - - # When PINNED_TAG above doesn't match cuml, - # force local raft clone in build directory - # even if it's already installed. - CLONE_ON_PIN ON COMPILE_LIBRARY NO ) ``` +You can find a fully-functioning [example template project](../../cpp/template/README.md) in the `cpp/template` directory, which provides everything you need to build a new application with RAFT or incorporate RAFT Into your existing libraries. 
+ ## Uninstall Once built and installed, RAFT can be safely uninstalled using `build.sh` by specifying any or all of the installed components. Please note that since `pylibraft` depends on `libraft`, uninstalling `pylibraft` will also uninstall `libraft`: diff --git a/python/pylibraft/setup.cfg b/python/pylibraft/setup.cfg new file mode 100644 index 0000000000..7d1a0c9065 --- /dev/null +++ b/python/pylibraft/setup.cfg @@ -0,0 +1,38 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. + +[isort] +line_length=79 +multi_line_output=3 +include_trailing_comma=True +force_grid_wrap=0 +combine_as_imports=True +order_by_type=True +known_dask= + dask + distributed + dask_cuda +known_rapids= + nvtext + cudf + cuml + cugraph + dask_cudf + rmm +known_first_party= + raft + pylibraft +default_section=THIRDPARTY +sections=FUTURE,STDLIB,THIRDPARTY,DASK,RAPIDS,FIRSTPARTY,LOCALFOLDER +skip= + thirdparty + .eggs + .git + .hg + .mypy_cache + .tox + .venv + _build + buck-out + build + dist + __init__.py diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000..e64641d05b --- /dev/null +++ b/setup.cfg @@ -0,0 +1,55 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. + +[flake8] +filename = *.py, *.pyx, *.pxd, *.pxi +exclude = __init__.py, *.egg, build, docs, .git +force-check = True +ignore = + # line break before binary operator + W503, + # whitespace before : + E203 +per-file-ignores = + # Rules ignored only in Cython: + # E211: whitespace before '(' (used in multi-line imports) + # E225: Missing whitespace around operators (breaks cython casting syntax like ) + # E226: Missing whitespace around arithmetic operators (breaks cython pointer syntax like int*) + # E227: Missing whitespace around bitwise or shift operator (Can also break casting syntax) + # E275: Missing whitespace after keyword (Doesn't work with Cython except?) + # E402: invalid syntax (works for Python, not Cython) + # E999: invalid syntax (works for Python, not Cython) + # W504: line break after binary operator (breaks lines that end with a pointer) + *.pyx: E211, E225, E226, E227, E275, E402, E999, W504 + *.pxd: E211, E225, E226, E227, E275, E402, E999, W504 + *.pxi: E211, E225, E226, E227, E275, E402, E999, W504 + +[pydocstyle] +# Due to https://github.com/PyCQA/pydocstyle/issues/363, we must exclude rather +# than include using match-dir. Note that as discussed in +# https://stackoverflow.com/questions/65478393/how-to-filter-directories-using-the-match-dir-flag-for-pydocstyle, +# unlike the match option above this match-dir will have no effect when +# pydocstyle is invoked from pre-commit. Therefore this exclusion list must +# also be maintained in the pre-commit config file. +match-dir = ^(?!(ci|cpp|conda|docs|java|notebooks)).*$ +# Allow missing docstrings for docutils +ignore-decorators = .*(docutils|doc_apply|copy_docstring).* +select = + D201, D204, D206, D207, D208, D209, D210, D211, D214, D215, D300, D301, D302, D403, D405, D406, D407, D408, D409, D410, D411, D412, D414, D418 + # Would like to enable the following rules in the future: + # D200, D202, D205, D400 + +[mypy] +ignore_missing_imports = True +# If we don't specify this, then mypy will check excluded files if +# they are imported by a checked file. 
+follow_imports = skip + +[codespell] +# note: pre-commit passes explicit lists of files here, which this skip file list doesn't override - +# this is only to allow you to run codespell interactively +skip = ./.git,./.github,./cpp/build,.*egg-info.*,./.mypy_cache,.*_skbuild +# ignore short words, and typename parameters like OffsetT +ignore-regex = \b(.{1,4}|[A-Z]\w*T)\b +ignore-words-list = inout,unparseable,numer +builtin = clear +quiet-level = 3 From f4c7f1f94e135c511de67459ffac3d9d6e9f9794 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 24 Mar 2023 17:40:46 -0400 Subject: [PATCH 11/13] Adding some functions back in that seem to be a copy/paste error (#1373) Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/1373 --- cpp/include/raft/util/cuda_dev_essentials.cuh | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/cpp/include/raft/util/cuda_dev_essentials.cuh b/cpp/include/raft/util/cuda_dev_essentials.cuh index 5080dc33ee..bb9ebbba59 100644 --- a/cpp/include/raft/util/cuda_dev_essentials.cuh +++ b/cpp/include/raft/util/cuda_dev_essentials.cuh @@ -88,4 +88,30 @@ DI int laneId() return id; } +/** Device function to apply the input lambda across threads in the grid */ +template +DI void forEach(int num, L lambda) +{ + int idx = (blockDim.x * blockIdx.x) + threadIdx.x; + const int numThreads = blockDim.x * gridDim.x; +#pragma unroll + for (int itr = 0; itr < ItemsPerThread; ++itr, idx += numThreads) { + if (idx < num) lambda(idx, itr); + } +} + +/** + * @brief Swap two values + * @tparam T the datatype of the values + * @param a first input + * @param b second input + */ +template +HDI void swapVals(T& a, T& b) +{ + T tmp = a; + a = b; + b = tmp; +} + } // namespace raft From 76c828dd4da4dc922626ba2a440a46dea6ab03b9 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Sat, 25 Mar 2023 20:09:35 +0100 Subject: [PATCH 12/13] Add extern template for ivfflat_interleaved_scan (#1360) This should cut compilation time for refine_d_int64_t_float.cu.o et al from ~900 seconds to 29 seconds. The refine specialization contain >100 instances of the ivfflat_interleaved_scan kernel, even though these should be seperately compiled by the ivfflat_search specializations. The call to ivf_flat_interleaved_scan is [here](https://github.com/rapidsai/raft/blob/56ac43ad93a319a61073dce1b3b937f6f13ade63/cpp/include/raft/neighbors/detail/refine.cuh#L121). Depends on (so please merge after) PR #1307. Authors: - Allard Hendriksen (https://github.com/ahendriksen) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. 
Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1360 --- .../raft/neighbors/detail/ivf_flat_search.cuh | 8 ++++ cpp/include/raft/neighbors/detail/refine.cuh | 8 ++++ .../neighbors/specializations/ivf_flat.cuh | 25 ++++++++++++- .../ivfflat_search_float_int64_t.cu | 37 ++++++++++++++++--- .../ivfflat_search_int8_t_int64_t.cu | 28 +++++++++++--- .../ivfflat_search_uint8_t_int64_t.cu | 28 +++++++++++--- 6 files changed, 115 insertions(+), 19 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh index f657070df4..e6533eaf51 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh @@ -1065,6 +1065,14 @@ void ivfflat_interleaved_scan(const index& index, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) { + // greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan + // function is used in both raft::neighbors::ivf_flat::search and + // raft::neighbors::detail::refine_device. To prevent a duplicate + // instantiation of this function (which defines ~270 kernels) in the refine + // specializations, an extern template definition is provided. Please check + // related function calls after editing this function definition. Search for + // `greppable-id-specializations-ivf-flat-search` to find them. + const int capacity = bound_by_power_of_two(k); select_interleaved_scan_kernel::run(capacity, index.veclen(), diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index f244d5875c..aedfc42698 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -117,6 +117,14 @@ void refine_device(raft::device_resources const& handle, n_queries, n_candidates); + // greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan + // function is used in both raft::neighbors::ivf_flat::search and + // raft::neighbors::detail::refine_device. To prevent a duplicate + // instantiation of this function (which defines ~270 kernels) in the refine + // specializations, an extern template definition is provided. Please check + // and adjust the extern template definition and the instantiation when the + // below function call is edited. Search for + // `greppable-id-specializations-ivf-flat-search` to find them. uint32_t grid_dim_x = 1; raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< data_t, diff --git a/cpp/include/raft/neighbors/specializations/ivf_flat.cuh b/cpp/include/raft/neighbors/specializations/ivf_flat.cuh index 013c7359e5..161f3462c9 100644 --- a/cpp/include/raft/neighbors/specializations/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/specializations/ivf_flat.cuh @@ -20,6 +20,13 @@ namespace raft::neighbors::ivf_flat { +// greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan +// function is used in both raft::neighbors::ivf_flat::search and +// raft::neighbors::detail::refine_device. To prevent a duplicate instantiation +// of this function (which defines ~270 kernels) in the refine specializations, +// an extern template definition is provided here. Please check related function +// calls after editing template definition below. Search for +// `greppable-id-specializations-ivf-flat-search` to find them. 
#define RAFT_INST(T, IdxT) \ extern template auto build(raft::device_resources const& handle, \ const index_params& params, \ @@ -44,7 +51,23 @@ namespace raft::neighbors::ivf_flat { const raft::neighbors::ivf_flat::index&, \ raft::device_matrix_view, \ raft::device_matrix_view, \ - raft::device_matrix_view); + raft::device_matrix_view); \ + \ + extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ + T, \ + typename raft::spatial::knn::detail::utils::config::value_t, \ + IdxT>(const index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream); RAFT_INST(float, int64_t); RAFT_INST(int8_t, int64_t); diff --git a/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu index 6de65546c8..dce7083139 100644 --- a/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu +++ b/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu @@ -18,12 +18,37 @@ namespace raft::neighbors::ivf_flat { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ +// greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan +// function is used in both raft::neighbors::ivf_flat::search and +// raft::neighbors::detail::refine_device. To prevent a duplicate instantiation +// of this function (which defines ~270 kernels) in the refine specializations, +// an extern template definition is provided. To make sure +// ivfflat_interleaved_scan is actually compiled here, we explicitly instantiate +// it below. Please check related function calls after editing template +// definition below. Search for `greppable-id-specializations-ivf-flat-search` +// to find them. 
+#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ + T, \ + typename raft::spatial::knn::detail::utils::config::value_t, \ + IdxT>(const index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream); \ + \ + template void search(raft::device_resources const&, \ + raft::neighbors::ivf_flat::search_params const&, \ + const raft::neighbors::ivf_flat::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ raft::device_matrix_view); RAFT_MAKE_INSTANCE(float, int64_t); diff --git a/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu index 8eda240ccd..b03d878bae 100644 --- a/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu +++ b/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu @@ -18,12 +18,28 @@ namespace raft::neighbors::ivf_flat { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ + T, \ + typename raft::spatial::knn::detail::utils::config::value_t, \ + IdxT>(const index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream); \ + \ + template void search(raft::device_resources const&, \ + raft::neighbors::ivf_flat::search_params const&, \ + const raft::neighbors::ivf_flat::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ raft::device_matrix_view); RAFT_MAKE_INSTANCE(int8_t, int64_t); diff --git a/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu index 8ff6533628..2d42bae0d1 100644 --- a/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu @@ -18,12 +18,28 @@ namespace raft::neighbors::ivf_flat { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ + T, \ + typename raft::spatial::knn::detail::utils::config::value_t, \ + IdxT>(const index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream); \ + \ + template void search(raft::device_resources const&, \ + 
raft::neighbors::ivf_flat::search_params const&, \ + const raft::neighbors::ivf_flat::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ raft::device_matrix_view); RAFT_MAKE_INSTANCE(uint8_t, int64_t); From 0d3bd3da5a2eb77a5ca2f7f9b9ed367030811c06 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 28 Mar 2023 01:00:57 -0700 Subject: [PATCH 13/13] add a distance epilogue function to the bfknn call (#1371) Add the ability for a user to specify an epilogue function to run after the distance in the brute_force::knn call. This lets us remove faiss from cuml, by updating the hdbscan reachability code (https://github.com/rapidsai/cuml/pull/5293) Authors: - Ben Frederickson (https://github.com/benfred) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1371 --- cpp/include/raft/neighbors/brute_force.cuh | 32 ++++++------ .../raft/neighbors/detail/knn_brute_force.cuh | 51 ++++++++++++++----- .../neighbors/specializations/brute_force.cuh | 3 +- .../raft/spatial/knn/detail/ball_cover.cuh | 1 - .../brute_force_knn_int64_t_float.cu | 11 +--- .../brute_force_knn_impl_long_float_int.cu | 3 +- .../brute_force_knn_impl_long_float_uint.cu | 3 +- .../brute_force_knn_impl_uint_float_int.cu | 3 +- .../brute_force_knn_impl_uint_float_uint.cu | 3 +- cpp/test/neighbors/ball_cover.cu | 1 - cpp/test/neighbors/knn.cu | 2 +- 11 files changed, 69 insertions(+), 44 deletions(-) diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index 4891cc5f8d..dac1a29c7f 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -122,9 +122,8 @@ inline void knn_merge_parts( * * raft::raft::device_resources handle; * ... - * int k = 10; * auto metric = raft::distance::DistanceType::L2SqrtExpanded; - * brute_force::knn(handle, index, search, indices, distances, k, metric); + * brute_force::knn(handle, index, search, indices, distances, metric); * @endcode * * @param[in] handle: the cuml handle to use @@ -132,28 +131,31 @@ inline void knn_merge_parts( * @param[in] search: matrix (size n*d) to be used for searching the index * @param[out] indices: matrix (size n*k) to store output knn indices * @param[out] distances: matrix (size n*k) to store the output knn distance - * @param[in] k: the number of nearest neighbors to return * @param[in] metric: distance metric to use. Euclidean (L2) is used by default * @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This * is ignored if the metric_type is not Minkowski. * @param[in] global_id_offset: optional starting global id mapping for the local partition * (assumes the index contains contiguous ids in the global id space) + * @param[in] distance_epilogue: optional epilogue function to run after computing distances. This + function takes a triple of the (value, rowid, colid) for each + element in the pairwise distances and returns a transformed value + back. 
*/ template + typename search_layout, + typename epilogue_op = raft::identity_op> void knn(raft::device_resources const& handle, std::vector> index, raft::device_matrix_view search, raft::device_matrix_view indices, raft::device_matrix_view distances, - value_int k, distance::DistanceType metric = distance::DistanceType::L2Unexpanded, std::optional metric_arg = std::make_optional(2.0f), - std::optional global_id_offset = std::nullopt) + std::optional global_id_offset = std::nullopt, + epilogue_op distance_epilogue = raft::identity_op()) { RAFT_EXPECTS(index[0].extent(1) == search.extent(1), "Number of dimensions for both index and search matrices must be equal"); @@ -161,15 +163,14 @@ void knn(raft::device_resources const& handle, RAFT_EXPECTS(indices.extent(0) == distances.extent(0) && distances.extent(0) == search.extent(0), "Number of rows in output indices and distances matrices must equal number of rows " "in search matrix."); - RAFT_EXPECTS( - indices.extent(1) == distances.extent(1) && distances.extent(1) == static_cast(k), - "Number of columns in output indices and distances matrices must be equal to k"); + RAFT_EXPECTS(indices.extent(1) == distances.extent(1) && distances.extent(1), + "Number of columns in output indices and distances matrices must the same"); bool rowMajorIndex = std::is_same_v; bool rowMajorQuery = std::is_same_v; std::vector inputs; - std::vector sizes; + std::vector sizes; for (std::size_t i = 0; i < index.size(); ++i) { inputs.push_back(const_cast(index[i].data_handle())); sizes.push_back(index[i].extent(0)); @@ -183,18 +184,19 @@ void knn(raft::device_resources const& handle, raft::neighbors::detail::brute_force_knn_impl(handle, inputs, sizes, - static_cast(index[0].extent(1)), + index[0].extent(1), // TODO: This is unfortunate. Need to fix. 
const_cast(search.data_handle()), - static_cast(search.extent(0)), + search.extent(0), indices.data_handle(), distances.data_handle(), - k, + indices.extent(1), rowMajorIndex, rowMajorQuery, trans_arg, metric, - metric_arg.value_or(2.0f)); + metric_arg.value_or(2.0f), + distance_epilogue); } /** diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 875fc3b37c..a776ce2586 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -47,7 +47,9 @@ using namespace raft::spatial::knn; * Calculates brute force knn, using a fixed memory budget * by tiling over both the rows and columns of pairwise_distances */ -template +template void tiled_brute_force_knn(const raft::device_resources& handle, const ElementType* search, // size (m ,d) const ElementType* index, // size (n ,d) @@ -58,9 +60,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle, ElementType* distances, // size (m, k) IndexType* indices, // size (m, k) raft::distance::DistanceType metric, - float metric_arg = 0.0, - size_t max_row_tile_size = 0, - size_t max_col_tile_size = 0) + float metric_arg = 2.0, + size_t max_row_tile_size = 0, + size_t max_col_tile_size = 0, + DistanceEpilogue distance_epilogue = raft::identity_op()) { // Figure out the number of rows/cols to tile for size_t tile_rows = 0; @@ -152,25 +155,41 @@ void tiled_brute_force_knn(const raft::device_resources& handle, metric_arg); if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded) { - auto row_norms = search_norms.data() + i; - auto col_norms = index_norms.data() + j; + auto row_norms = search_norms.data(); + auto col_norms = index_norms.data(); auto dist = temp_distances.data(); raft::linalg::map_offset( handle, raft::make_device_vector_view(dist, current_query_size * current_centroid_size), - [=] __device__(IndexType i) { - IndexType row = i / current_centroid_size, col = i % current_centroid_size; + [=] __device__(IndexType idx) { + IndexType row = i + (idx / current_centroid_size); + IndexType col = j + (idx % current_centroid_size); - auto val = row_norms[row] + col_norms[col] - 2.0 * dist[i]; + auto val = row_norms[row] + col_norms[col] - 2.0 * dist[idx]; // due to numerical instability (especially around self-distance) // the distances here could be slightly negative, which will // cause NaN values in the subsequent sqrt. Clamp to 0 val = val * (val >= 0.0001); if (metric == raft::distance::DistanceType::L2SqrtExpanded) { val = sqrt(val); } + val = distance_epilogue(val, row, col); return val; }); + } else { + // if we're not l2 distance, and we have a distance epilogue - run it now + if constexpr (!std::is_same_v) { + auto distances_ptr = temp_distances.data(); + raft::linalg::map_offset( + handle, + raft::make_device_vector_view(temp_distances.data(), + current_query_size * current_centroid_size), + [=] __device__(size_t idx) { + IndexType row = i + (idx / current_centroid_size); + IndexType col = j + (idx % current_centroid_size); + return distance_epilogue(distances_ptr[idx], row, col); + }); + } } select_k(temp_distances.data(), @@ -250,7 +269,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle, * @param[in] metric corresponds to the raft::distance::DistanceType enum (default is L2Expanded) * @param[in] metricArg metric argument to use. 
Corresponds to the p arg for lp norm */ -template +template void brute_force_knn_impl( raft::device_resources const& handle, std::vector& input, @@ -265,7 +287,8 @@ void brute_force_knn_impl( bool rowMajorQuery = true, std::vector* translations = nullptr, raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, - float metricArg = 0) + float metricArg = 0, + DistanceEpilogue distance_epilogue = raft::identity_op()) { auto userStream = handle.get_stream(); @@ -355,6 +378,7 @@ void brute_force_knn_impl( auto stream = handle.get_next_usable_stream(i); if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && + std::is_same_v && (metric == raft::distance::DistanceType::L2Unexpanded || metric == raft::distance::DistanceType::L2SqrtUnexpanded || metric == raft::distance::DistanceType::L2Expanded || @@ -424,7 +448,10 @@ void brute_force_knn_impl( out_d_ptr, out_i_ptr, tiled_metric, - metricArg); + metricArg, + 0, + 0, + distance_epilogue); break; } } diff --git a/cpp/include/raft/neighbors/specializations/brute_force.cuh b/cpp/include/raft/neighbors/specializations/brute_force.cuh index d418d40185..1337beb68a 100644 --- a/cpp/include/raft/neighbors/specializations/brute_force.cuh +++ b/cpp/include/raft/neighbors/specializations/brute_force.cuh @@ -36,7 +36,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(long, float, int); RAFT_INST(long, float, unsigned int); RAFT_INST(uint32_t, float, int); diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 99d688e232..c8fc6eefda 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -185,7 +185,6 @@ void k_closest_landmarks(raft::device_resources const& handle, make_device_matrix_view(query_pts, n_query_pts, inputs[0].extent(1)), make_device_matrix_view(R_knn_inds, n_query_pts, k), make_device_matrix_view(R_knn_dists, n_query_pts, k), - k, index.get_metric()); } diff --git a/cpp/src/neighbors/brute_force_knn_int64_t_float.cu b/cpp/src/neighbors/brute_force_knn_int64_t_float.cu index 585084fc97..88545b3607 100644 --- a/cpp/src/neighbors/brute_force_knn_int64_t_float.cu +++ b/cpp/src/neighbors/brute_force_knn_int64_t_float.cu @@ -38,15 +38,8 @@ namespace raft::runtime::neighbors::brute_force { { \ std::vector> vec; \ vec.push_back(index); \ - raft::neighbors::brute_force::knn(handle, \ - vec, \ - search, \ - indices, \ - distances, \ - static_cast(distances.extent(1)), \ - metric, \ - metric_arg, \ - global_id_offset); \ + raft::neighbors::brute_force::knn( \ + handle, vec, search, indices, distances, metric, metric_arg, global_id_offset); \ } RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major); diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu index 07810aa576..04aa42c9f1 100644 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu +++ b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu @@ -32,7 +32,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(long, float, int); #undef 
RAFT_INST } // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu index 0cb873b40a..a8b9d4299a 100644 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu +++ b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu @@ -32,7 +32,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(long, float, unsigned int); #undef RAFT_INST } // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu index f8a69b896f..c97e6e936a 100644 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu +++ b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu @@ -32,7 +32,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(uint32_t, float, int); #undef RAFT_INST } // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu index 3c23d1f3e0..87451c385a 100644 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu +++ b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu @@ -32,7 +32,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(uint32_t, float, unsigned int); #undef RAFT_INST } // namespace raft::neighbors::detail diff --git a/cpp/test/neighbors/ball_cover.cu b/cpp/test/neighbors/ball_cover.cu index 9b51d585de..46ef3a9150 100644 --- a/cpp/test/neighbors/ball_cover.cu +++ b/cpp/test/neighbors/ball_cover.cu @@ -121,7 +121,6 @@ void compute_bfknn(const raft::device_resources& handle, make_device_matrix_view(X2, n_query_rows, d), make_device_matrix_view(inds, n_query_rows, k), make_device_matrix_view(dists, n_query_rows, k), - k, metric); } diff --git a/cpp/test/neighbors/knn.cu b/cpp/test/neighbors/knn.cu index 4bb977432c..bcd4b9cb0b 100644 --- a/cpp/test/neighbors/knn.cu +++ b/cpp/test/neighbors/knn.cu @@ -96,7 +96,7 @@ class KNNTest : public ::testing::TestWithParam { raft::make_device_matrix_view(distances_.data(), rows_, k_); auto metric = raft::distance::DistanceType::L2Unexpanded; - knn(handle, index, search, indices, distances, k_, metric, std::make_optional(0)); + knn(handle, index, search, indices, distances, metric, std::make_optional(0)); build_actual_output<<>>( actual_labels_.data(), rows_, k_, search_labels_.data(), indices_.data());