Skip to content

Commit

Permalink
[lang] Fix scalarization for PrintStmt (#6945)
Browse files Browse the repository at this point in the history
Issue: #6927

### Brief Summary

Co-authored-by: Zhao Liang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 23, 2022
1 parent 6268cfb commit bbe38aa
Showing 1 changed file with 48 additions and 3 deletions.
51 changes: 48 additions & 3 deletions taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,15 +281,60 @@ class Scalarize : public BasicStmtVisitor {
Stmt *print_stmt = std::get<Stmt *>(content);
if (print_stmt->is<MatrixInitStmt>()) {
auto matrix_init_stmt = print_stmt->cast<MatrixInitStmt>();
for (size_t j = 0; j < matrix_init_stmt->values.size(); j++) {
new_contents.push_back(matrix_init_stmt->values[j]);
auto tensor_shape =
print_stmt->ret_type->as<TensorType>()->get_shape();

bool is_matrix = tensor_shape.size() == 2;
int m = tensor_shape[0];

new_contents.push_back("[");
if (is_matrix) {
int n = tensor_shape[1];
for (size_t i = 0; i < m; i++) {
new_contents.push_back("[");
for (size_t j = 0; j < n; j++) {
size_t index = i * n + j;
new_contents.push_back(matrix_init_stmt->values[index]);
if (j != n - 1)
new_contents.push_back(", ");
}
new_contents.push_back("]");

if (i != m - 1)
new_contents.push_back(", ");
}
} else {
for (size_t i = 0; i < m; i++) {
new_contents.push_back(matrix_init_stmt->values[i]);
if (i != m - 1)
new_contents.push_back(", ");
}
}
new_contents.push_back("]");
} else {
new_contents.push_back(print_stmt);
}
}
}
modifier_.insert_before(stmt, Stmt::make<PrintStmt>(new_contents));

// Merge string contents
std::vector<std::variant<Stmt *, std::string>> merged_contents;
std::string merged_string = "";
for (const auto &content : new_contents) {
if (auto string_content = std::get_if<std::string>(&content)) {
merged_string += *string_content;
} else {
if (!merged_string.empty()) {
merged_contents.push_back(merged_string);
merged_string = "";
}
merged_contents.push_back(content);
}
}
if (!merged_string.empty())
merged_contents.push_back(merged_string);

modifier_.insert_before(stmt, Stmt::make<PrintStmt>(merged_contents));
modifier_.erase(stmt);
}

Expand Down

0 comments on commit bbe38aa

Please sign in to comment.