Skip to content

Commit

Permalink
[sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 484351696
  • Loading branch information
tlu7 authored and jax authors committed Oct 27, 2022
1 parent fc8f40c commit 66e75ed
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 13 deletions.
8 changes: 4 additions & 4 deletions jaxlib/gpu/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ std::pair<size_t, py::bytes> BuildCsrMatvecDescriptor(
SparseConst beta = ConstZero(y.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize(
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
GPUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
GPUSPARSE_SPMV_CSR_ALG, &buffer_size)));

JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
Expand Down Expand Up @@ -346,7 +346,7 @@ std::pair<size_t, py::bytes> BuildCsrMatmatDescriptor(
SparseConst beta = ConstZero(C.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize(
handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_CSR_ALG, &buffer_size)));

JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
Expand Down Expand Up @@ -467,7 +467,7 @@ std::pair<size_t, py::bytes> BuildCooMatvecDescriptor(
SparseConst beta = ConstZero(y.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize(
handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type,
GPUSPARSE_MV_ALG_DEFAULT, &buffer_size)));
GPUSPARSE_SPMV_COO_ALG, &buffer_size)));

JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
Expand Down Expand Up @@ -537,7 +537,7 @@ std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
SparseConst beta = ConstZero(C.type);
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize(
handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a,
mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_ALG_DEFAULT, &buffer_size)));
mat_b, &beta, mat_c, C.type, GPUSPARSE_SPMM_COO_ALG, &buffer_size)));

JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
Expand Down
8 changes: 4 additions & 4 deletions jaxlib/gpu/sparse_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ static absl::Status CsrMatvec_(gpuStream_t stream, void** buffers,

JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
gpusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
d.y.type, GPUSPARSE_MV_ALG_DEFAULT, buf)));
d.y.type, GPUSPARSE_SPMV_CSR_ALG, buf)));

JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
Expand Down Expand Up @@ -324,7 +324,7 @@ static absl::Status CsrMatmat_(gpuStream_t stream, void** buffers,
/*ld=*/d.C.cols, Cbuf, d.C.type, GPUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM(
handle.get(), d.op_A, /*opB=*/GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
mat_a, mat_b, &beta, mat_c, d.C.type, GPUSPARSE_SPMM_ALG_DEFAULT, buf)));
mat_a, mat_b, &beta, mat_c, d.C.type, GPUSPARSE_SPMM_CSR_ALG, buf)));

JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
Expand Down Expand Up @@ -463,7 +463,7 @@ static absl::Status CooMatvec_(gpuStream_t stream, void** buffers,

JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
gpusparseSpMV(handle.get(), d.op, &alpha, mat_a, vec_x, &beta, vec_y,
d.y.type, GPUSPARSE_MV_ALG_DEFAULT, buf)));
d.y.type, GPUSPARSE_SPMV_COO_ALG, buf)));

JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnVec(vec_x)));
Expand Down Expand Up @@ -529,7 +529,7 @@ static absl::Status CooMatmat_(gpuStream_t stream, void** buffers,
/*batchStride=*/d.C.batch_stride)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM(
handle.get(), d.op_A, /*opB=*/GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
mat_a, mat_b, &beta, mat_c, d.C.type, GPUSPARSE_SPMM_ALG_DEFAULT, buf)));
mat_a, mat_b, &beta, mat_c, d.C.type, GPUSPARSE_SPMM_COO_ALG, buf)));

JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroySpMat(mat_a)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusparseDestroyDnMat(mat_b)));
Expand Down
26 changes: 22 additions & 4 deletions jaxlib/gpu/vendor.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,28 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
#define GPUSPARSE_INDEX_64I CUSPARSE_INDEX_64I
#define GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT CUSPARSE_DENSETOSPARSE_ALG_DEFAULT
#define GPUSPARSE_INDEX_BASE_ZERO CUSPARSE_INDEX_BASE_ZERO
#define GPUSPARSE_MV_ALG_DEFAULT CUSPARSE_MV_ALG_DEFAULT
// Use CUSPARSE_SPMV_COO_ALG2 and CUSPARSE_SPMV_CSR_ALG2 for SPMV and
// use CUSPARSE_SPMM_COO_ALG2 and CUSPARSE_SPMM_CSR_ALG3 for SPMM, which
// provide deterministic (bit-wise) results for each run.
// CUSPARSE_SPMV_COO_ALG2 is available since cuda version 11.2.1
// CUSPARSE_SPMV_CSR_ALG2 is available since cuda version 11.2.1
// CUSPARSE_SPMM_COO_ALG2 is available since cuda version 11.0.3
// CUSPARSE_SPMM_CSR_ALG3 is available since cuda version 11.2.1
#if JAX_GPU_HAVE_SPARSE
#define GPUSPARSE_SPMV_COO_ALG CUSPARSE_SPMV_COO_ALG2
#define GPUSPARSE_SPMV_CSR_ALG CUSPARSE_SPMV_CSR_ALG2
#define GPUSPARSE_SPMM_COO_ALG CUSPARSE_SPMM_COO_ALG2
#define GPUSPARSE_SPMM_CSR_ALG CUSPARSE_SPMM_CSR_ALG3
#else
#define GPUSPARSE_SPMV_COO_ALG CUSPARSE_MV_ALG_DEFAULT
#define GPUSPARSE_SPMV_CSR_ALG CUSPARSE_MV_ALG_DEFAULT
#define GPUSPARSE_SPMM_COO_ALG CUSPARSE_SPMM_ALG_DEFAULT
#define GPUSPARSE_SPMM_CSR_ALG CUSPARSE_SPMM_ALG_DEFAULT
#endif
#define GPUSPARSE_OPERATION_NON_TRANSPOSE CUSPARSE_OPERATION_NON_TRANSPOSE
#define GPUSPARSE_OPERATION_TRANSPOSE CUSPARSE_OPERATION_TRANSPOSE
#define GPUSPARSE_ORDER_ROW CUSPARSE_ORDER_ROW
#define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT CUSPARSE_SPARSETODENSE_ALG_DEFAULT
#define GPUSPARSE_SPMM_ALG_DEFAULT CUSPARSE_SPMM_ALG_DEFAULT
#define GPUSPARSE_STATUS_SUCCESS CUSPARSE_STATUS_SUCCESS

#define gpuGetLastError cudaGetLastError
Expand Down Expand Up @@ -418,13 +434,15 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
#define GPUSPARSE_INDEX_32I HIPSPARSE_INDEX_32I
#define GPUSPARSE_INDEX_64I HIPSPARSE_INDEX_64I
#define GPUSPARSE_DENSETOSPARSE_ALG_DEFAULT HIPSPARSE_DENSETOSPARSE_ALG_DEFAULT
#define GPUSPARSE_MV_ALG_DEFAULT HIPSPARSE_MV_ALG_DEFAULT
#define GPUSPARSE_SPMV_COO_ALG HIPSPARSE_MV_ALG_DEFAULT
#define GPUSPARSE_SPMV_CSR_ALG HIPSPARSE_MV_ALG_DEFAULT
#define GPUSPARSE_SPMM_COO_ALG HIPSPARSE_SPMM_ALG_DEFAULT
#define GPUSPARSE_SPMM_CSR_ALG HIPSPARSE_SPMM_ALG_DEFAULT
#define GPUSPARSE_INDEX_BASE_ZERO HIPSPARSE_INDEX_BASE_ZERO
#define GPUSPARSE_OPERATION_NON_TRANSPOSE HIPSPARSE_OPERATION_NON_TRANSPOSE
#define GPUSPARSE_OPERATION_TRANSPOSE HIPSPARSE_OPERATION_TRANSPOSE
#define GPUSPARSE_ORDER_ROW HIPSPARSE_ORDER_ROW
#define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT HIPSPARSE_SPARSETODENSE_ALG_DEFAULT
#define GPUSPARSE_SPMM_ALG_DEFAULT HIPSPARSE_SPMM_ALG_DEFAULT
#define GPUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS

#define gpuGetLastError hipGetLastError
Expand Down
6 changes: 5 additions & 1 deletion tests/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,7 +1208,11 @@ def f_sparse(lhs_bcoo, lhs, rhs):
# TODO(tianjianlu): In some cases, this fails python_should_be_executing.
# self._CompileAndCheck(f_sparse, args_maker)
self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)
self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker)
if dtype is np.complex128:
atol = 1E-1
else:
atol = 1E-2
self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker, atol=atol, rtol=1E-6)
else:
lhs_bcoo, lhs, rhs = args_maker()
matmat_expected = f_dense(lhs_bcoo, lhs, rhs)
Expand Down

0 comments on commit 66e75ed

Please sign in to comment.