Skip to content

Commit

Permalink
[Lang] Sparse matrix element modify
Browse files Browse the repository at this point in the history
  • Loading branch information
FantasyVR committed Oct 8, 2021
1 parent 0bd85c3 commit 2468b24
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 0 deletions.
3 changes: 3 additions & 0 deletions python/taichi/lang/sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def __matmul__(self, other):
def __getitem__(self, indices):
return self.matrix.get_element(indices[0], indices[1])

def __setitem__(self, indices, value):
self.matrix.set_element(indices[0], indices[1], value)

def __str__(self):
return self.matrix.to_string()

Expand Down
4 changes: 4 additions & 0 deletions taichi/program/sparse_matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,5 +118,9 @@ float32 SparseMatrix::get_element(int row, int col) {
return matrix_.coeff(row, col);
}

void SparseMatrix::set_element(int row, int col, float32 value){
matrix_.coeffRef(row,col) = value;
}

} // namespace lang
} // namespace taichi
1 change: 1 addition & 0 deletions taichi/program/sparse_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class SparseMatrix {
Eigen::SparseMatrix<float32> &get_matrix();
const Eigen::SparseMatrix<float32> &get_matrix() const;
float32 get_element(int row, int col);
void set_element(int row, int col, float32 value);

friend SparseMatrix operator+(const SparseMatrix &sm1,
const SparseMatrix &sm2);
Expand Down
1 change: 1 addition & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,7 @@ void export_lang(py::module &m) {
.def("transpose", &SparseMatrix::transpose,
py::return_value_policy::reference_internal)
.def("get_element", &SparseMatrix::get_element)
.def("set_element", &SparseMatrix::set_element)
.def("num_rows", &SparseMatrix::num_rows)
.def("num_cols", &SparseMatrix::num_cols);

Expand Down
14 changes: 14 additions & 0 deletions tests/python/test_sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ def fill(Abuilder: ti.sparse_matrix_builder()):
for i in range(n):
assert A[i, i] == i

@ti.test(arch=ti.cpu)
def test_sparse_matrix_element_modify():
n = 8
Abuilder = ti.SparseMatrixBuilder(n, n, max_num_triplets=100)

@ti.kernel
def fill(Abuilder: ti.sparse_matrix_builder()):
for i in range(n):
Abuilder[i, i] += i

fill(Abuilder)
A = Abuilder.build()
A[0,0] = 1024.0
assert A[0,0] == 1024.0

@ti.test(arch=ti.cpu)
def test_sparse_matrix_addition():
Expand Down

0 comments on commit 2468b24

Please sign in to comment.