Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ir][refactor] Move all Expression subclasses to frontend_ir.h #919

Merged
merged 1 commit into from
May 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
340 changes: 340 additions & 0 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
#pragma once

#include <string>
#include <vector>

#include "taichi/lang_util.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/expr.h"

TLANG_NAMESPACE_BEGIN

// Frontend Statements

class FrontendAllocaStmt : public Stmt {
public:
Identifier ident;
Expand Down Expand Up @@ -193,4 +198,339 @@ class FrontendWhileStmt : public Stmt {
DEFINE_ACCEPT
};

// Expressions

class ArgLoadExpression : public Expression {
public:
int arg_id;

ArgLoadExpression(int arg_id) : arg_id(arg_id) {
}

std::string serialize() override {
return fmt::format("arg[{}]", arg_id);
}

void flatten(FlattenContext *ctx) override {
auto ran = std::make_unique<ArgLoadStmt>(arg_id);
ctx->push_back(std::move(ran));
stmt = ctx->back_stmt();
}
};

class RandExpression : public Expression {
public:
DataType dt;

RandExpression(DataType dt) : dt(dt) {
}

std::string serialize() override {
return fmt::format("rand<{}>()", data_type_name(dt));
}

void flatten(FlattenContext *ctx) override {
auto ran = std::make_unique<RandStmt>(dt);
ctx->push_back(std::move(ran));
stmt = ctx->back_stmt();
}
};

class UnaryOpExpression : public Expression {
public:
UnaryOpType type;
Expr operand;
DataType cast_type;

UnaryOpExpression(UnaryOpType type, const Expr &operand)
: type(type), operand(smart_load(operand)) {
cast_type = DataType::unknown;
}

bool is_cast() const;

std::string serialize() override;

void flatten(FlattenContext *ctx) override;
};

class BinaryOpExpression : public Expression {
public:
BinaryOpType type;
Expr lhs, rhs;

BinaryOpExpression(const BinaryOpType &type, const Expr &lhs, const Expr &rhs)
: type(type) {
this->lhs.set(smart_load(lhs));
this->rhs.set(smart_load(rhs));
}

std::string serialize() override {
return fmt::format("({} {} {})", lhs->serialize(),
binary_op_type_symbol(type), rhs->serialize());
}

void flatten(FlattenContext *ctx) override {
// if (stmt)
// return;
lhs->flatten(ctx);
rhs->flatten(ctx);
ctx->push_back(std::make_unique<BinaryOpStmt>(type, lhs->stmt, rhs->stmt));
ctx->stmts.back()->tb = tb;
stmt = ctx->back_stmt();
}
};

class TernaryOpExpression : public Expression {
public:
TernaryOpType type;
Expr op1, op2, op3;

TernaryOpExpression(TernaryOpType type,
const Expr &op1,
const Expr &op2,
const Expr &op3)
: type(type) {
this->op1.set(load_if_ptr(op1));
this->op2.set(load_if_ptr(op2));
this->op3.set(load_if_ptr(op3));
}

std::string serialize() override {
return fmt::format("{}({} {} {})", ternary_type_name(type),
op1->serialize(), op2->serialize(), op3->serialize());
}

void flatten(FlattenContext *ctx) override {
// if (stmt)
// return;
op1->flatten(ctx);
op2->flatten(ctx);
op3->flatten(ctx);
ctx->push_back(
std::make_unique<TernaryOpStmt>(type, op1->stmt, op2->stmt, op3->stmt));
stmt = ctx->back_stmt();
}
};

class ExternalTensorExpression : public Expression {
public:
DataType dt;
int dim;
int arg_id;

ExternalTensorExpression(const DataType &dt, int dim, int arg_id)
: dt(dt), dim(dim), arg_id(arg_id) {
set_attribute("dim", std::to_string(dim));
}

std::string serialize() override {
return fmt::format("{}d_ext_arr", dim);
}

void flatten(FlattenContext *ctx) override {
auto ptr = Stmt::make<ArgLoadStmt>(arg_id, true);
ctx->push_back(std::move(ptr));
stmt = ctx->back_stmt();
}
};

class GlobalVariableExpression : public Expression {
public:
Identifier ident;
DataType dt;
SNode *snode;
bool has_ambient;
TypedConstant ambient_value;
bool is_primal;
Expr adjoint;

GlobalVariableExpression(DataType dt, const Identifier &ident)
: ident(ident), dt(dt) {
snode = nullptr;
has_ambient = false;
is_primal = true;
}

GlobalVariableExpression(SNode *snode) : snode(snode) {
dt = snode->dt;
has_ambient = false;
is_primal = true;
}

void set_snode(SNode *snode) {
this->snode = snode;
set_attribute("dim", std::to_string(snode->num_active_indices));
}

std::string serialize() override {
return "#" + ident.name();
}

void flatten(FlattenContext *ctx) override {
TI_ASSERT(snode->num_active_indices == 0);
auto ptr = Stmt::make<GlobalPtrStmt>(LaneAttribute<SNode *>(snode),
std::vector<Stmt *>());
ctx->push_back(std::move(ptr));
}
};

class GlobalPtrExpression : public Expression {
public:
Expr var;
ExprGroup indices;

GlobalPtrExpression(const Expr &var, const ExprGroup &indices)
: var(var), indices(indices) {
}

std::string serialize() override;

void flatten(FlattenContext *ctx) override;

bool is_lvalue() const override {
return true;
}
};

class EvalExpression : public Expression {
public:
Stmt *stmt_ptr;
int stmt_id;
EvalExpression(Stmt *stmt) : stmt_ptr(stmt), stmt_id(stmt_ptr->id) {
// cache stmt->id since it may be released later
}

std::string serialize() override {
return fmt::format("%{}", stmt_id);
}

void flatten(FlattenContext *ctx) override {
stmt = stmt_ptr;
}
};

class RangeAssumptionExpression : public Expression {
public:
Expr input, base;
int low, high;

RangeAssumptionExpression(const Expr &input,
const Expr &base,
int low,
int high)
: input(input), base(base), low(low), high(high) {
}

std::string serialize() override {
return fmt::format("assume_in_range({}{:+d} <= ({}) < {}{:+d})",
base.serialize(), low, input.serialize(),
base.serialize(), high);
}

void flatten(FlattenContext *ctx) override {
input->flatten(ctx);
base->flatten(ctx);
ctx->push_back(
Stmt::make<RangeAssumptionStmt>(input->stmt, base->stmt, low, high));
stmt = ctx->back_stmt();
}
};

class IdExpression : public Expression {
public:
Identifier id;
IdExpression(const std::string &name = "") : id(name) {
}
IdExpression(const Identifier &id) : id(id) {
}

std::string serialize() override {
return id.name();
}

void flatten(FlattenContext *ctx) override {
ctx->push_back(std::make_unique<LocalLoadStmt>(
LocalAddress(ctx->current_block->lookup_var(id), 0)));
stmt = ctx->back_stmt();
}

bool is_lvalue() const override {
return true;
}
};

// ti.atomic_*() is an expression with side effect.
class AtomicOpExpression : public Expression {
public:
AtomicOpType op_type;
Expr dest, val;

AtomicOpExpression(AtomicOpType op_type, const Expr &dest, const Expr &val)
: op_type(op_type), dest(dest), val(val) {
}

std::string serialize() override;

void flatten(FlattenContext *ctx) override;
};

class SNodeOpExpression : public Expression {
public:
SNode *snode;
SNodeOpType op_type;
ExprGroup indices;
Expr value;

SNodeOpExpression(SNode *snode, SNodeOpType op_type, const ExprGroup &indices)
: snode(snode), op_type(op_type), indices(indices) {
}

SNodeOpExpression(SNode *snode,
SNodeOpType op_type,
const ExprGroup &indices,
const Expr &value)
: snode(snode), op_type(op_type), indices(indices), value(value) {
}

std::string serialize() override;

void flatten(FlattenContext *ctx) override;
};

class GlobalLoadExpression : public Expression {
public:
Expr ptr;
GlobalLoadExpression(const Expr &ptr) : ptr(ptr) {
}

std::string serialize() override {
return "gbl load " + ptr.serialize();
}

void flatten(FlattenContext *ctx) override {
ptr->flatten(ctx);
ctx->push_back(std::make_unique<GlobalLoadStmt>(ptr->stmt));
stmt = ctx->back_stmt();
}
};

class ConstExpression : public Expression {
public:
TypedConstant val;

template <typename T>
ConstExpression(const T &x) : val(x) {
}

std::string serialize() override {
return val.stringify();
}

void flatten(FlattenContext *ctx) override {
ctx->push_back(Stmt::make<ConstStmt>(val));
stmt = ctx->back_stmt();
}
};

TLANG_NAMESPACE_END
Loading