Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Support LU sparse solver on CUDA backend #6967

Merged
merged 13 commits into from
Dec 26, 2022
6 changes: 5 additions & 1 deletion python/taichi/linalg/sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,11 @@ def _get_ndarray_addr(self):

def print_triplets(self):
"""Print the triplets stored in the builder"""
self.ptr.print_triplets()
taichi_arch = get_runtime().prog.config().arch
if taichi_arch == _ti_core.Arch.x64 or taichi_arch == _ti_core.Arch.arm64:
self.ptr.print_triplets_eigen()
elif taichi_arch == _ti_core.Arch.cuda:
self.ptr.print_triplets_cuda()

def build(self, dtype=f32, _format='CSR'):
"""Create a sparse matrix using the triplets"""
Expand Down
23 changes: 2 additions & 21 deletions python/taichi/linalg/sparse_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,32 +100,13 @@ def solve(self, b): # pylint: disable=R1710
return self.solver.solve(b)
if isinstance(b, Ndarray):
x = ScalarNdarray(b.dtype, [self.matrix.m])
self.solve_rf(self.matrix, b, x)
self.solver.solve_rf(get_runtime().prog, self.matrix.matrix, b.arr,
x.arr)
return x
raise TaichiRuntimeError(
f"The parameter type: {type(b)} is not supported in linear solvers for now."
)

def solve_cu(self, sparse_matrix, b):
if isinstance(sparse_matrix, SparseMatrix) and isinstance(b, Ndarray):
x = ScalarNdarray(b.dtype, [sparse_matrix.m])
self.solver.solve_cu(get_runtime().prog, sparse_matrix.matrix,
b.arr, x.arr)
return x
raise TaichiRuntimeError(
f"The parameter type: {type(sparse_matrix)}, {type(b)} and {type(x)} is not supported in linear solvers for now."
)

def solve_rf(self, sparse_matrix, b, x):
if isinstance(sparse_matrix, SparseMatrix) and isinstance(
b, Ndarray) and isinstance(x, Ndarray):
self.solver.solve_rf(get_runtime().prog, sparse_matrix.matrix,
b.arr, x.arr)
else:
raise TaichiRuntimeError(
f"The parameter type: {type(sparse_matrix)}, {type(b)} and {type(x)} is not supported in linear solvers for now."
)

def info(self):
"""Check if the linear systems are solved successfully.

Expand Down
48 changes: 42 additions & 6 deletions taichi/program/sparse_matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,53 @@ SparseMatrixBuilder::SparseMatrixBuilder(int rows,
prog_, dtype_, std::vector<int>{3 * (int)max_num_triplets_ + 1});
}

void SparseMatrixBuilder::print_triplets() {
num_triplets_ = ndarray_data_base_ptr_->read_int(std::vector<int>{0});
template <typename T, typename G>
void SparseMatrixBuilder::print_triplets_template() {
auto ptr = get_ndarray_data_ptr();
G *data = reinterpret_cast<G *>(ptr);
num_triplets_ = data[0];
fmt::print("n={}, m={}, num_triplets={} (max={})\n", rows_, cols_,
num_triplets_, max_num_triplets_);
data += 1;
for (int i = 0; i < num_triplets_; i++) {
auto idx = 3 * i + 1;
auto row = ndarray_data_base_ptr_->read_int(std::vector<int>{idx});
auto col = ndarray_data_base_ptr_->read_int(std::vector<int>{idx + 1});
auto val = ndarray_data_base_ptr_->read_float(std::vector<int>{idx + 2});
fmt::print("[{}, {}] = {}\n", data[i * 3], data[i * 3 + 1],
taichi_union_cast<T>(data[i * 3 + 2]));
}
}

void SparseMatrixBuilder::print_triplets_eigen() {
auto element_size = data_type_size(dtype_);
switch (element_size) {
case 4:
print_triplets_template<float32, int32>();
break;
case 8:
print_triplets_template<float64, int64>();
break;
default:
TI_ERROR("Unsupported sparse matrix data type!");
break;
}
}

void SparseMatrixBuilder::print_triplets_cuda() {
#ifdef TI_WITH_CUDA
CUDADriver::get_instance().memcpy_device_to_host(
&num_triplets_, (void *)get_ndarray_data_ptr(), sizeof(int));
fmt::print("n={}, m={}, num_triplets={} (max={})\n", rows_, cols_,
num_triplets_, max_num_triplets_);
auto len = 3 * num_triplets_ + 1;
std::vector<float32> trips(len);
strongoier marked this conversation as resolved.
Show resolved Hide resolved
CUDADriver::get_instance().memcpy_device_to_host(
(void *)trips.data(), (void *)get_ndarray_data_ptr(),
len * sizeof(float32));
for (auto i = 0; i < num_triplets_; i++) {
int row = taichi_union_cast<int>(trips[3 * i + 1]);
int col = taichi_union_cast<int>(trips[3 * i + 2]);
auto val = trips[i * 3 + 3];
fmt::print("[{}, {}] = {}\n", row, col, val);
}
#endif
}

intptr_t SparseMatrixBuilder::get_ndarray_data_ptr() const {
Expand Down
6 changes: 5 additions & 1 deletion taichi/program/sparse_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ class SparseMatrixBuilder {
const std::string &storage_format,
Program *prog);

void print_triplets();
void print_triplets_eigen();
void print_triplets_cuda();

intptr_t get_ndarray_data_ptr() const;

Expand All @@ -36,6 +37,9 @@ class SparseMatrixBuilder {
template <typename T, typename G>
void build_template(std::unique_ptr<SparseMatrix> &);

template <typename T, typename G>
void print_triplets_template();

private:
uint64 num_triplets_{0};
std::unique_ptr<Ndarray> ndarray_data_base_ptr_{nullptr};
Expand Down
Loading