Skip to content

Commit

Permalink
Using 64-bit array lengths to increase scale of pca & tsvd (#3983)
Browse files Browse the repository at this point in the history
Addresses #2459 (likely not all of it)

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #3983
  • Loading branch information
cjnolet authored Jul 23, 2021
1 parent 33a01b0 commit 40af8af
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions cpp/src/pca/pca.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void truncCompExpVars(const raft::handle_t& handle,
const paramsTSVDTemplate<enum_solver> prms,
cudaStream_t stream)
{
int len = prms.n_cols * prms.n_cols;
size_t len = prms.n_cols * prms.n_cols;
auto allocator = handle.get_device_allocator();
device_buffer<math_t> components_all(allocator, stream, len);
device_buffer<math_t> explained_var_all(allocator, stream, prms.n_cols);
Expand Down Expand Up @@ -103,7 +103,7 @@ void pcaFit(const raft::handle_t& handle,

raft::stats::mean(mu, input, prms.n_cols, prms.n_rows, true, false, stream);

int len = prms.n_cols * prms.n_cols;
size_t len = prms.n_cols * prms.n_cols;
device_buffer<math_t> cov(handle.get_device_allocator(), stream, len);

Stats::cov(handle, cov.data(), input, mu, prms.n_cols, prms.n_rows, true, false, true, stream);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/tsvd/tsvd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ void tsvdFit(const raft::handle_t& handle,
int n_components = prms.n_components;
if (prms.n_components > prms.n_cols) n_components = prms.n_cols;

int len = prms.n_cols * prms.n_cols;
size_t len = prms.n_cols * prms.n_cols;
device_buffer<math_t> input_cross_mult(allocator, stream, len);

math_t alpha = math_t(1);
Expand Down
4 changes: 2 additions & 2 deletions cpp/src_prims/linalg/lstsq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ void lstsq(const raft::handle_t& handle,

ASSERT(n_rows > 1, "lstsq: number of rows cannot be less than two");

int U_len = n_rows * n_cols;
int V_len = n_cols * n_cols;
size_t U_len = n_rows * n_cols;
size_t V_len = n_cols * n_cols;

rmm::device_uvector<math_t> S(n_cols, stream);
rmm::device_uvector<math_t> V(V_len, stream);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src_prims/selection/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ void class_probs(std::vector<float*>& out,
cudaStream_t stream = raft::select_stream(user_stream, int_streams, n_int_streams, i);

int n_unique_labels = n_unique[i];
int cur_size = n_query_rows * n_unique_labels;
size_t cur_size = n_query_rows * n_unique_labels;

CUDA_CHECK(cudaMemsetAsync(out[i], 0, cur_size * sizeof(float), stream));

Expand Down

0 comments on commit 40af8af

Please sign in to comment.