Skip to content

Commit

Permalink
[bug] Fix sparse matrix memory release error
Browse files Browse the repository at this point in the history
ghstack-source-id: 87b4cc2568d311fd203e31775ab667d519cf0b78
Pull Request resolved: #8149
  • Loading branch information
listerily authored and Taichi Gardener committed Jun 9, 2023
1 parent 9c044e4 commit 179cbf9
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 13 deletions.
6 changes: 5 additions & 1 deletion python/taichi/linalg/sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,8 @@ def __init__(
max_num_triplets,
dtype,
storage_format,
get_runtime().prog,
)
self.ptr.create_ndarray(get_runtime().prog)
else:
raise TaichiRuntimeError("SparseMatrix only supports CPU and CUDA for now.")

Expand Down Expand Up @@ -292,5 +292,9 @@ def build(self, dtype=f32, _format="CSR"):
return SparseMatrix(sm=sm, dtype=self.dtype)
raise TaichiRuntimeError("Sparse matrix only supports CPU and CUDA backends.")

def __del__(self):
if get_runtime() is not None and get_runtime().prog is not None:
self.ptr.delete_ndarray(get_runtime().prog)


__all__ = ["SparseMatrix", "SparseMatrixBuilder"]
18 changes: 11 additions & 7 deletions taichi/program/sparse_matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,26 @@ SparseMatrixBuilder::SparseMatrixBuilder(int rows,
int cols,
int max_num_triplets,
DataType dtype,
const std::string &storage_format,
Program *prog)
const std::string &storage_format)
: rows_(rows),
cols_(cols),
max_num_triplets_(max_num_triplets),
dtype_(dtype),
storage_format_(storage_format),
prog_(prog) {
storage_format_(storage_format) {
auto element_size = data_type_size(dtype);
TI_ASSERT((element_size == 4 || element_size == 8));
}

SparseMatrixBuilder::~SparseMatrixBuilder() = default;

void SparseMatrixBuilder::create_ndarray(Program *prog) {
ndarray_data_base_ptr_ = prog->create_ndarray(
dtype_, std::vector<int>{3 * (int)max_num_triplets_ + 1});
ndarray_data_ptr_ = prog->get_ndarray_data_ptr_as_int(ndarray_data_base_ptr_);
}

SparseMatrixBuilder::~SparseMatrixBuilder() {
prog_->delete_ndarray(ndarray_data_base_ptr_);
void SparseMatrixBuilder::delete_ndarray(Program *prog) {
prog->delete_ndarray(ndarray_data_base_ptr_);
}

template <typename T, typename G>
Expand Down Expand Up @@ -153,7 +157,7 @@ void SparseMatrixBuilder::print_triplets_cuda() {
}

intptr_t SparseMatrixBuilder::get_ndarray_data_ptr() const {
return prog_->get_ndarray_data_ptr_as_int(ndarray_data_base_ptr_);
return ndarray_data_ptr_;
}

template <typename T, typename G>
Expand Down
9 changes: 6 additions & 3 deletions taichi/program/sparse_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@ class SparseMatrixBuilder {
int cols,
int max_num_triplets,
DataType dtype,
const std::string &storage_format,
Program *prog);
const std::string &storage_format);

~SparseMatrixBuilder();
void print_triplets_eigen();
void print_triplets_cuda();

void create_ndarray(Program *prog);

void delete_ndarray(Program *prog);

intptr_t get_ndarray_data_ptr() const;

std::unique_ptr<SparseMatrix> build();
Expand All @@ -44,13 +47,13 @@ class SparseMatrixBuilder {
private:
uint64 num_triplets_{0};
Ndarray *ndarray_data_base_ptr_{nullptr};
intptr_t ndarray_data_ptr_{0};
int rows_{0};
int cols_{0};
uint64 max_num_triplets_{0};
bool built_{false};
DataType dtype_{PrimitiveType::f32};
std::string storage_format_{"col_major"};
Program *prog_{nullptr};
};

class SparseMatrix {
Expand Down
12 changes: 10 additions & 2 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1170,12 +1170,20 @@ void export_lang(py::module &m) {

// Sparse Matrix
py::class_<SparseMatrixBuilder>(m, "SparseMatrixBuilder")
.def(py::init<int, int, int, DataType, const std::string &, Program *>(),
.def(py::init<int, int, int, DataType, const std::string &>(),
py::arg("rows"), py::arg("cols"), py::arg("max_num_triplets"),
py::arg("dt") = PrimitiveType::f32,
py::arg("storage_format") = "col_major", py::arg("prog") = nullptr)
py::arg("storage_format") = "col_major")
.def("print_triplets_eigen", &SparseMatrixBuilder::print_triplets_eigen)
.def("print_triplets_cuda", &SparseMatrixBuilder::print_triplets_cuda)
.def("create_ndarray",
[&](SparseMatrixBuilder *builder, Program *prog) {
return builder->create_ndarray(prog);
})
.def("delete_ndarray",
[&](SparseMatrixBuilder *builder, Program *prog) {
return builder->delete_ndarray(prog);
})
.def("get_ndarray_data_ptr", &SparseMatrixBuilder::get_ndarray_data_ptr)
.def("build", &SparseMatrixBuilder::build)
.def("build_cuda", &SparseMatrixBuilder::build_cuda)
Expand Down

0 comments on commit 179cbf9

Please sign in to comment.