Skip to content

Commit

Permalink
refactor: optimize Expression::serialize
Browse files Browse the repository at this point in the history
  • Loading branch information
TennyZhuang committed Oct 18, 2021
1 parent 88ca7f2 commit 7c996a5
Show file tree
Hide file tree
Showing 11 changed files with 179 additions and 130 deletions.
10 changes: 8 additions & 2 deletions taichi/ir/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,15 @@

TLANG_NAMESPACE_BEGIN

std::string Expr::serialize() const {
void Expr::serialize(std::stringstream &ss) const {
TI_ASSERT(expr);
return expr->serialize();
expr->serialize(ss);
}

std::string Expr::serialize() const {
std::stringstream ss;
serialize(ss);
return ss.str();
}

void Expr::set_tb(const std::string &tb) {
Expand Down
2 changes: 2 additions & 0 deletions taichi/ir/expr.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "taichi/util/str.h"
#include "taichi/ir/type_utils.h"

TLANG_NAMESPACE_BEGIN
Expand Down Expand Up @@ -77,6 +78,7 @@ class Expr {
Expr operator[](const ExprGroup &indices) const;

std::string serialize() const;
void serialize(std::stringstream &ss) const;

void operator+=(const Expr &o);
void operator-=(const Expr &o);
Expand Down
14 changes: 9 additions & 5 deletions taichi/ir/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,19 @@ ExprGroup ExprGroup::loaded() const {
return indices_loaded;
}

std::string ExprGroup::serialize() const {
std::string ret;
void ExprGroup::serialize(std::stringstream &ss) const {
for (int i = 0; i < (int)exprs.size(); i++) {
ret += exprs[i].serialize();
exprs[i].serialize(ss);
if (i + 1 < (int)exprs.size()) {
ret += ", ";
ss << ", ";
}
}
return ret;
}

std::string ExprGroup::serialize() const {
std::stringstream ss;
serialize(ss);
return ss.str();
}

} // namespace lang
Expand Down
6 changes: 5 additions & 1 deletion taichi/ir/expression.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "taichi/util/str.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/expr.h"

Expand Down Expand Up @@ -36,7 +37,7 @@ class Expression {
stmt = nullptr;
}

virtual std::string serialize() = 0;
virtual void serialize(std::stringstream &ss) = 0;

virtual void flatten(FlattenContext *ctx) {
TI_NOT_IMPLEMENTED;
Expand Down Expand Up @@ -98,7 +99,10 @@ class ExprGroup {
return exprs[i];
}

void serialize(std::stringstream &ss) const;

std::string serialize() const;

ExprGroup loaded() const;
};

Expand Down
93 changes: 51 additions & 42 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,17 @@ void RandExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

std::string UnaryOpExpression::serialize() {
void UnaryOpExpression::serialize(std::stringstream &ss) {
ss << '(';
if (is_cast()) {
std::string reint = type == UnaryOpType::cast_value ? "" : "reinterpret_";
return fmt::format("({}{}<{}> {})", reint, unary_op_type_name(type),
data_type_name(cast_type), operand->serialize());
ss << (type == UnaryOpType::cast_value ? "" : "reinterpret_");
ss << unary_op_type_name(type);
ss << '<' << data_type_name(cast_type) << "> ";
} else {
return fmt::format("({} {})", unary_op_type_name(type),
operand->serialize());
ss << unary_op_type_name(type) << ' ';
}
operand->serialize(ss);
ss << ')';
}

bool UnaryOpExpression::is_cast() const {
Expand Down Expand Up @@ -185,16 +187,19 @@ void GlobalVariableExpression::flatten(FlattenContext *ctx) {
ctx->push_back(std::move(ptr));
}

std::string GlobalPtrExpression::serialize() {
std::string s = fmt::format(
"{}[", snode ? snode->get_node_type_name_hinted() : var.serialize());
void GlobalPtrExpression::serialize(std::stringstream &ss) {
if (snode) {
ss << snode->get_node_type_name_hinted();
} else {
var.serialize(ss);
}
ss << '[';
for (int i = 0; i < (int)indices.size(); i++) {
s += indices.exprs[i]->serialize();
indices.exprs[i]->serialize(ss);
if (i + 1 < (int)indices.size())
s += ", ";
ss << ", ";
}
s += "]";
return s;
ss << ']';
}

void GlobalPtrExpression::flatten(FlattenContext *ctx) {
Expand Down Expand Up @@ -299,19 +304,19 @@ void RangeAssumptionExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

std::string LoopUniqueExpression::serialize() {
std::string result = "loop_unique(" + input->serialize();
void LoopUniqueExpression::serialize(std::stringstream &ss) {
ss << "loop_unique(";
input.serialize(ss);
for (int i = 0; i < covers.size(); i++) {
if (i == 0)
result += ", covers=[";
result += covers[i]->get_node_type_name_hinted();
ss << ", covers=[";
ss << covers[i]->get_node_type_name_hinted();
if (i == (int)covers.size() - 1)
result += "]";
ss << ']';
else
result += ", ";
ss << ", ";
}
result += ")";
return result;
ss << ')';
}

void LoopUniqueExpression::flatten(FlattenContext *ctx) {
Expand Down Expand Up @@ -339,28 +344,29 @@ void IdExpression::flatten(FlattenContext *ctx) {
}
}

std::string AtomicOpExpression::serialize() {
void AtomicOpExpression::serialize(std::stringstream &ss) {
if (op_type == AtomicOpType::add) {
return fmt::format("atomic_add({}, {})", dest.serialize(), val.serialize());
ss << "atomic_add(";
} else if (op_type == AtomicOpType::sub) {
return fmt::format("atomic_sub({}, {})", dest.serialize(), val.serialize());
ss << "atomic_sub(";
} else if (op_type == AtomicOpType::min) {
return fmt::format("atomic_min({}, {})", dest.serialize(), val.serialize());
ss << "atomic_min(";
} else if (op_type == AtomicOpType::max) {
return fmt::format("atomic_max({}, {})", dest.serialize(), val.serialize());
ss << "atomic_max(";
} else if (op_type == AtomicOpType::bit_and) {
return fmt::format("atomic_bit_and({}, {})", dest.serialize(),
val.serialize());
ss << "atomic_bit_and(";
} else if (op_type == AtomicOpType::bit_or) {
return fmt::format("atomic_bit_or({}, {})", dest.serialize(),
val.serialize());
ss << "atomic_bit_or(";
} else if (op_type == AtomicOpType::bit_xor) {
return fmt::format("atomic_bit_xor({}, {})", dest.serialize(),
val.serialize());
ss << "atomic_bit_xor(";
} else {
// min/max not supported in the LLVM backend yet.
TI_NOT_IMPLEMENTED;
}
dest.serialize(ss);
ss << ", ";
val.serialize(ss);
ss << ")";
}

void AtomicOpExpression::flatten(FlattenContext *ctx) {
Expand Down Expand Up @@ -389,15 +395,17 @@ void AtomicOpExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

std::string SNodeOpExpression::serialize() {
void SNodeOpExpression::serialize(std::stringstream &ss) {
ss << snode_op_type_name(op_type);
ss << '(';
ss << snode->get_node_type_name_hinted() << ", [";
indices.serialize(ss);
ss << "]";
if (value.expr) {
return fmt::format("{}({}, [{}], {})", snode_op_type_name(op_type),
snode->get_node_type_name_hinted(), indices.serialize(),
value.serialize());
} else {
return fmt::format("{}({}, [{}])", snode_op_type_name(op_type),
snode->get_node_type_name_hinted(), indices.serialize());
ss << ' ';
value.serialize(ss);
}
ss << ')';
}

void SNodeOpExpression::flatten(FlattenContext *ctx) {
Expand Down Expand Up @@ -469,9 +477,10 @@ void FuncCallExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

std::string FuncCallExpression::serialize() {
return fmt::format("func_call(\"{}\", {})", func->func_key.get_full_name(),
args.serialize());
void FuncCallExpression::serialize(std::stringstream &ss) {
ss << "func_call(\"" << func->func_key.get_full_name() << "\", ";
args.serialize(ss);
ss << ')';
}

Block *ASTBuilder::current_block() {
Expand Down
Loading

0 comments on commit 7c996a5

Please sign in to comment.