diff --git a/cpp/include/raft/cache/cache_util.cuh b/cpp/include/raft/cache/cache_util.cuh
new file mode 100644
index 0000000000..ce8ef9a095
--- /dev/null
+++ b/cpp/include/raft/cache/cache_util.cuh
@@ -0,0 +1,351 @@
+/*
+ * Copyright (c) 2019-2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <cub/cub.cuh>
+#include <raft/cuda_utils.cuh>
+
+namespace raft {
+namespace cache {
+
+/**
+ * @brief Collect vectors of data from the cache into a contiguous memory buffer.
+ *
+ * We assume a contiguous memory layout for the output buffer, i.e. we collect
+ * column vectors into a column major output buffer, or row vectors into a row
+ * major output buffer.
+ *
+ * On exit, the output array is filled in the following way:
+ * out[i + n_vec*k] = cache[i + n_vec * cache_idx[k]], where i=0..n_vec-1, and
+ * k = 0..n-1 where cache_idx[k] >= 0
+ *
+ * We ignore vectors where cache_idx[k] < 0.
+ *
+ * @param [in] cache stores the cached data, size [n_vec x n_cached_vectors]
+ * @param [in] n_vec number of elements in a cached vector
+ * @param [in] cache_idx cache indices, size [n]
+ * @param [in] n the number of vectors that need to be collected
+ * @param [out] out vectors collected from the cache, size [n_vec * n]
+ */
+template <typename math_t>
+__global__ void get_vecs(const math_t *cache, int n_vec, const int *cache_idx,
+                         int n, math_t *out) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  int row = tid % n_vec;  // row idx
+  if (tid < n_vec * n) {
+    size_t out_col = tid / n_vec;  // col idx
+    size_t cache_col = cache_idx[out_col];
+    if (cache_idx[out_col] >= 0) {
+      if (row + out_col * n_vec < (size_t)n_vec * n) {
+        out[tid] = cache[row + cache_col * n_vec];
+      }
+    }
+  }
+}
+
+/**
+ * @brief Store vectors of data into the cache.
+ *
+ * Elements within a vector should be contiguous in memory (i.e. column vectors
+ * for column major data storage, or row vectors of row major data).
+ *
+ * If tile_idx==nullptr then the operation is the opposite of get_vecs,
+ * i.e. we store
+ * cache[i + cache_idx[k]*n_vec] = tile[i + k*n_vec], for i=0..n_vec-1, k=0..n-1
+ *
+ * If tile_idx != nullptr, then we permute the vectors from tile according
+ * to tile_idx. This allows storing vectors from a buffer where the individual
+ * vectors are not stored contiguously (but the elements of each vector shall
+ * be contiguous):
+ * cache[i + cache_idx[k]*n_vec] = tile[i + tile_idx[k]*n_vec],
+ * for i=0..n_vec-1, k=0..n-1
+ *
+ * @param [in] tile stores the data to be cached, size [n_vec x n_tile]
+ * @param [in] n_tile number of vectors in the input tile
+ * @param [in] n_vec number of elements in a cached vector
+ * @param [in] tile_idx indices of vectors that we want to store
+ * @param [in] n number of vectors that we want to store (n <= n_tile)
+ * @param [in] cache_idx cache indices, size [n], negative values are ignored
+ * @param [inout] cache updated cache
+ * @param [in] n_cache_vecs number of vectors that fit into the cache
+ */
+template <typename math_t>
+__global__ void store_vecs(const math_t *tile, int n_tile, int n_vec,
+                           const int *tile_idx, int n, const int *cache_idx,
+                           math_t *cache, int n_cache_vecs) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  int row = tid % n_vec;  // row idx
+  if (tid < n_vec * n) {
+    int tile_col = tid / n_vec;  // col idx
+    int data_col = tile_idx ? tile_idx[tile_col] : tile_col;
+    int cache_col = cache_idx[tile_col];
+
+    // We ignore negative values. The rest of the checks should be fulfilled
+    // if the cache is used properly
+    if (cache_col >= 0 && cache_col < n_cache_vecs && data_col < n_tile) {
+      cache[row + (size_t)cache_col * n_vec] =
+        tile[row + (size_t)data_col * n_vec];
+    }
+  }
+}
+
+/**
+ * @brief Map a key to a cache set.
+ *
+ * @param key key to be hashed
+ * @param n_cache_sets number of cache sets
+ * @return index of the cache set [0..n_cache_sets)
+ */
+int DI hash(int key, int n_cache_sets) { return key % n_cache_sets; }
+
+/**
+ * @brief Binary search to find the first element in the array which is greater
+ * than or equal to a given value.
+ * @param [in] array sorted array of n numbers
+ * @param [in] n length of the array
+ * @param [in] val the value to search for
+ * @return the index of the first element in the array for which
+ * array[idx] >= val. If there is no such element, then return n.
+ */
+int DI arg_first_ge(const int *array, int n, int val) {
+  int start = 0;
+  int end = n - 1;
+  if (array[0] >= val) return 0;
+  if (array[end] < val) return n;
+  while (start + 1 < end) {
+    int q = (start + end + 1) / 2;
+    // invariants:
+    // start < end
+    // start < q <= end
+    // array[start] < val && array[end] >= val
+    // at every iteration d = end - start is decreasing
+    // when d reaches 1, array[end] is the first element >= val
+    if (array[q] >= val) {
+      end = q;
+    } else {
+      start = q;
+    }
+  }
+  return end;
+}
+
+/**
+ * @brief Find the k-th occurrence of value in a sorted array.
+ *
+ * Assume that array is [0, 1, 1, 1, 2, 2, 4, 4, 4, 4, 6, 7],
+ * then find_nth_occurrence(array, 12, 4, 2) == 8, because array[8] stores
+ * the occurrence of value 4 with rank k = 2 (k is zero based, so this is the
+ * third element with value 4).
+ * If there are fewer than k+1 occurrences of val in the array, then return -1.
+ *
+ * @param [in] array sorted array of numbers, size [n]
+ * @param [in] n number of elements in the array
+ * @param [in] val the value we are searching for
+ * @param [in] k rank of the occurrence we are looking for (zero based)
+ * @return the idx of the k-th occurrence of val in array, or -1 if
+ * the value is not found.
+ */
+int DI find_nth_occurrence(const int *array, int n, int val, int k) {
+  int q = arg_first_ge(array, n, val);
+  if (q + k < n && array[q + k] == val) {
+    q += k;
+  } else {
+    q = -1;
+  }
+  return q;
+}
+
+/**
+ * @brief Rank the entries in a cache set according to the time stamp: for each
+ * entry we return its rank, i.e. its position if the time stamps were sorted
+ * in ascending order.
+ *
+ * Assume we have a single cache set with time stamps as:
+ * key (threadIdx.x): 0   1   2   3
+ * val (time stamp):  8   6   7   5
+ *
+ * The corresponding sorted key-value pairs:
+ * key:  3   1   2   0
+ * val:  5   6   7   8
+ * rank: 0th 1st 2nd 3rd
+ *
+ * On return, the rank is assigned for each thread:
+ * threadIdx.x: 0   1   2   3
+ * rank:        3   1   2   0
+ *
+ * For multiple cache sets, launch one block per cache set.
+ *
+ * @tparam nthreads number of threads per block (nthreads <= associativity)
+ * @tparam associativity number of items in a cache set
+ *
+ * @param [in] cache_time time stamp of caching the data,
+ *   size [associativity * n_cache_sets]
+ * @param [in] n_cache_sets number of cache sets
+ * @param [out] rank ranks within the cache set, size [nthreads * items_per_thread].
+ *   Each block should pass a different pointer for rank.
+ */
+template <int nthreads, int associativity>
+DI void rank_set_entries(const int *cache_time, int n_cache_sets, int *rank) {
+  const int items_per_thread = raft::ceildiv(associativity, nthreads);
+  typedef cub::BlockRadixSort<int, nthreads, items_per_thread, int>
+    BlockRadixSort;
+  __shared__ typename BlockRadixSort::TempStorage temp_storage;
+
+  int key[items_per_thread];
+  int val[items_per_thread];
+
+  int block_offset = blockIdx.x * associativity;
+
+  for (int j = 0; j < items_per_thread; j++) {
+    int k = threadIdx.x + j * nthreads;
+    // slots that do not exist get a sentinel time stamp so they sort last
+    int t = (k < associativity) ? cache_time[block_offset + k] : 32768;
+    key[j] = t;
+    val[j] = k;
+  }
+
+  BlockRadixSort(temp_storage).Sort(key, val);
+
+  for (int j = 0; j < items_per_thread; j++) {
+    if (val[j] < associativity) {
+      rank[val[j]] = threadIdx.x * items_per_thread + j;
+    }
+  }
+  __syncthreads();
+}
+
+/**
+ * @brief Assign cache locations to a set of keys using the LRU replacement policy.
+ *
+ * The keys and the corresponding cache_set arrays shall be sorted according
+ * to cache_set in ascending order. One block should be launched for every cache
+ * set.
+ *
+ * Each cache set is sorted according to the time stamp, and values from keys
+ * are filled in starting at the oldest time stamp. Entries that were accessed
+ * at the current time are not reassigned.
+ *
+ * @tparam nthreads number of threads per block
+ * @tparam associativity number of keys in a cache set
+ *
+ * @param [in] keys the keys that we want to cache, size [n]
+ * @param [in] n number of keys
+ * @param [in] cache_set cache set assigned to each key, size [n]
+ * @param [inout] cached_keys keys of already cached vectors,
+ * size [n_cache_sets*associativity], on exit it will be updated with the
+ * cached elements from keys.
+ * @param [in] n_cache_sets number of cache sets
+ * @param [inout] cache_time will be updated to "time" for those elements that
+ * could be assigned to a cache location, size [n_cache_sets*associativity]
+ * @param [in] time time stamp
+ * @param [out] cache_idx the cache idx assigned to the input, or -1 if it could
+ * not be cached, size [n]
+ */
+template <int nthreads, int associativity>
+__global__ void assign_cache_idx(const int *keys, int n, const int *cache_set,
+                                 int *cached_keys, int n_cache_sets,
+                                 int *cache_time, int time, int *cache_idx) {
+  int block_offset = blockIdx.x * associativity;
+
+  const int items_per_thread = raft::ceildiv(associativity, nthreads);
+
+  // the size of rank limits how large an associativity can be used in practice
+  __shared__ int rank[items_per_thread * nthreads];
+  rank_set_entries<nthreads, associativity>(cache_time, n_cache_sets, rank);
+
+  // Each thread will fill items_per_thread items in the cache.
+  // A slot is used only if it was not updated at the current time step
+  // (cache_time != time).
+  // We rank the slots according to the time stamp, so that the least recently
+  // used elements come to the front, and fill them with the working set.
+  // There might be elements which cannot be assigned to a cache location;
+  // these elements are assigned -1.
+
+  for (int j = 0; j < items_per_thread; j++) {
+    int i = threadIdx.x + j * nthreads;
+    int t_idx = block_offset + i;
+    bool mask = (i < associativity);
+    // whether this slot is available for writing
+    mask = mask && (cache_time[t_idx] != time);
+
+    // rank[i] tells which element this thread should store;
+    // we look up where the corresponding key is stored in the input array
+    if (mask) {
+      int k = find_nth_occurrence(cache_set, n, blockIdx.x, rank[i]);
+      if (k > -1) {
+        int key_val = keys[k];
+        cached_keys[t_idx] = key_val;
+        cache_idx[k] = t_idx;
+        cache_time[t_idx] = time;
+      }
+    }
+  }
+}
+
+/* Unnamed namespace is used to avoid multiple definition errors for the
+   following non-template function */
+namespace {
+/**
+ * @brief Get the cache indices for keys stored in the cache.
+ *
+ * For every key, we look up the corresponding cache position.
+ * If keys[k] is stored in the cache, then is_cached[k] is set to true, and
+ * cache_idx[k] stores the corresponding cache idx.
+ *
+ * If keys[k] is not stored in the cache, then we assign a cache set to it.
+ * This cache set is stored in cache_idx[k], and is_cached[k] is set to false.
+ * In this case AssignCacheIdx should be called afterwards, to get an assigned
+ * position within the cache set.
+ *
+ * For the keys that are found in the cache, the corresponding cache_time entry
+ * is set to the time input argument.
+ *
+ * @param [in] keys array of keys that we want to look up in the cache, size [n]
+ * @param [in] n number of keys to look up
+ * @param [inout] cached_keys keys stored in the cache, size [n_cache_sets * associativity]
+ * @param [in] n_cache_sets number of cache sets
+ * @param [in] associativity number of keys in a cache set
+ * @param [inout] cache_time time stamp when the indices were cached, size [n_cache_sets * associativity]
+ * @param [out] cache_idx cache indices of the working set elements, size [n]
+ * @param [out] is_cached whether the element is cached, size [n]
+ * @param [in] time iteration counter (used for time stamping)
+ */
+__global__ void get_cache_idx(int *keys, int n, int *cached_keys,
+                              int n_cache_sets, int associativity,
+                              int *cache_time, int *cache_idx, bool *is_cached,
+                              int time) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid < n) {
+    int widx = keys[tid];
+    int sidx = hash(widx, n_cache_sets);
+    int cidx = sidx * associativity;
+    int i = 0;
+    bool found = false;
+    // check whether the key is already stored in its cache set
+    while (i < associativity && !found) {
+      found = (cache_time[cidx + i] > 0 && cached_keys[cidx + i] == widx);
+      i++;
+    }
+    is_cached[tid] = found;
+    if (found) {
+      cidx = cidx + i - 1;
+      cache_time[cidx] = time;  // update time stamp
+      cache_idx[tid] = cidx;    // exact cache idx
+    } else {
+      cache_idx[tid] = sidx;  // assign cache set
+    }
+  }
+}
+};  // end unnamed namespace
+};  // namespace cache
+};  // namespace raft
diff --git a/cpp/include/raft/matrix/matrix.cuh b/cpp/include/raft/matrix/matrix.cuh
index f479ef8e34..5f5755e24e 100644
--- a/cpp/include/raft/matrix/matrix.cuh
+++ b/cpp/include/raft/matrix/matrix.cuh
@@ -24,6 +24,7 @@
 #include 
 #include 
 #include 
+#include <raft/cache/cache_util.cuh>
 #include 
 #include 
@@ -52,7 +53,12 @@ void copyRows(const m_t *in, idx_t n_rows, idx_t n_cols, m_t *out,
               const idx_array_t *indices, idx_t n_rows_indices,
               cudaStream_t stream, bool rowMajor = false) {
   if (rowMajor) {
-    ASSERT(false, "matrix.h: row major is not supported yet!");
+    const idx_t TPB = 256;
+    cache::
+      get_vecs<<<raft::ceildiv(n_rows_indices * n_cols, TPB), TPB, 0, stream>>>(
+        in, n_cols, indices, n_rows_indices, out);
+    CUDA_CHECK(cudaPeekAtLastError());
+    return;
   }
 
   idx_t size = n_rows_indices * n_cols;
diff --git a/cpp/test/matrix/matrix.cu b/cpp/test/matrix/matrix.cu
index 14c5cd1b9a..28222c0697 100644
--- a/cpp/test/matrix/matrix.cu
+++ b/cpp/test/matrix/matrix.cu
@@ -121,8 +121,14 @@ class MatrixCopyRowsTest : public ::testing::Test {
   void testCopyRows() {
     copyRows(input.data(), n_rows, n_cols, output.data(), indices.data(),
              n_selected, stream, false);
-    EXPECT_TRUE(raft::devArrMatchHost(
-      output_exp, output.data(), n_selected * n_cols, raft::Compare<math_t>()));
+    EXPECT_TRUE(raft::devArrMatchHost(output_exp_colmajor, output.data(),
+                                      n_selected * n_cols,
+                                      raft::Compare<math_t>()));
+    copyRows(input.data(), n_rows, n_cols, output.data(), indices.data(),
+             n_selected, stream, true);
+    EXPECT_TRUE(raft::devArrMatchHost(output_exp_rowmajor, output.data(),
+                                      n_selected * n_cols,
+                                      raft::Compare<math_t>()));
   }
 
  protected:
@@ -131,8 +137,10 @@ class MatrixCopyRowsTest : public ::testing::Test {
   int n_selected = 5;
 
   idx_array_t indices_host[5] = {0, 3, 4, 7, 9};
-  math_t output_exp[15] = {0,  3,  4,  7,  9,  10, 13, 14,
-                           17, 19, 20, 23, 24, 27, 29};
+  math_t output_exp_colmajor[15] = {0,  3,  4,  7,  9,  10, 13, 14,
+                                    17, 19, 20, 23, 24, 27, 29};
+  math_t output_exp_rowmajor[15] = {0,  1,  2,  9,  10, 11, 12, 13,
+                                    14, 21, 22, 23, 27, 28, 29};
   raft::handle_t handle;
   cudaStream_t stream;
   std::shared_ptr allocator;
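
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new
row-major copyRows path boils down to a single cache::get_vecs launch, using the same
10x3 input and row indices as MatrixCopyRowsTest above. The include paths, the plain
cudaMalloc/cudaMemcpy setup and all local names are assumptions for illustration only.

#include <raft/cache/cache_util.cuh>
#include <raft/cuda_utils.cuh>
#include <cstdio>
#include <vector>

int main() {
  const int n_rows = 10, n_cols = 3, n_selected = 5;

  // Row-major 10x3 matrix holding 0..29, and the rows we want to gather.
  std::vector<float> h_in(n_rows * n_cols);
  for (int i = 0; i < n_rows * n_cols; i++) h_in[i] = float(i);
  std::vector<int> h_idx = {0, 3, 4, 7, 9};

  float *d_in, *d_out;
  int *d_idx;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_out, n_selected * n_cols * sizeof(float));
  cudaMalloc(&d_idx, n_selected * sizeof(int));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_idx, h_idx.data(), n_selected * sizeof(int), cudaMemcpyHostToDevice);

  // For row-major data a row is contiguous, so n_vec == n_cols and each selected
  // row index plays the role of a cache column index.
  const int TPB = 256;
  raft::cache::get_vecs<<<raft::ceildiv(n_selected * n_cols, TPB), TPB>>>(
    d_in, n_cols, d_idx, n_selected, d_out);

  std::vector<float> h_out(n_selected * n_cols);
  cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float), cudaMemcpyDeviceToHost);
  // Expected output: 0 1 2  9 10 11  12 13 14  21 22 23  27 28 29
  for (float v : h_out) printf("%g ", v);
  printf("\n");

  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_idx);
  return 0;
}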
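
A second sketch (also illustrative, not part of the patch) of how the set-associative
kernels are meant to compose when a batch of keys is inserted into an empty cache:
get_cache_idx reports the target cache set for every miss, the keys are then sorted by
cache set, and assign_cache_idx picks slots within each set using the LRU policy. The
helper name assign_new_keys, the thrust-based sort, the chosen associativity and the
assumption that every key is new are mine, not part of the API.

#include <raft/cache/cache_util.cuh>
#include <raft/cuda_utils.cuh>
#include <thrust/device_vector.h>
#include <thrust/sort.h>

void assign_new_keys(const thrust::device_vector<int> &new_keys,
                     thrust::device_vector<int> &cached_keys,  // [n_sets * assoc], init to -1
                     thrust::device_vector<int> &cache_time,   // [n_sets * assoc], init to 0
                     int n_cache_sets, int time) {             // time must be >= 1
  constexpr int assoc = 8;     // cache associativity (compile-time parameter)
  constexpr int nthreads = 8;  // threads per block, nthreads <= assoc
  int n = new_keys.size();

  thrust::device_vector<int> keys = new_keys;
  thrust::device_vector<int> cache_set(n), cache_idx(n);
  thrust::device_vector<bool> is_cached(n);

  // 1. Look up every key. For misses the cache_idx output (cache_set here)
  //    receives the target cache set; for hits it would hold the exact cache
  //    index and is_cached would be set to true.
  const int TPB = 256;
  raft::cache::get_cache_idx<<<raft::ceildiv(n, TPB), TPB>>>(
    thrust::raw_pointer_cast(keys.data()), n,
    thrust::raw_pointer_cast(cached_keys.data()), n_cache_sets, assoc,
    thrust::raw_pointer_cast(cache_time.data()),
    thrust::raw_pointer_cast(cache_set.data()),
    thrust::raw_pointer_cast(is_cached.data()), time);

  // 2. assign_cache_idx expects the keys sorted by cache set and one block per set.
  thrust::sort_by_key(cache_set.begin(), cache_set.end(), keys.begin());
  raft::cache::assign_cache_idx<nthreads, assoc><<<n_cache_sets, nthreads>>>(
    thrust::raw_pointer_cast(keys.data()), n,
    thrust::raw_pointer_cast(cache_set.data()),
    thrust::raw_pointer_cast(cached_keys.data()), n_cache_sets,
    thrust::raw_pointer_cast(cache_time.data()), time,
    thrust::raw_pointer_cast(cache_idx.data()));
  cudaDeviceSynchronize();

  // cache_idx[k] now holds the slot assigned to the k-th key in sorted order;
  // the corresponding data vectors can be written with store_vecs.
}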