Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sketch of statically-typed Exprs #5235

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
56 changes: 56 additions & 0 deletions src/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ struct StringImm : public ExprNode<StringImm> {

} // namespace Internal

template<typename T>
struct ExprT;

/** A fragment of Halide syntax. It's implemented as reference-counted
* handle to a concrete expression node, but it's immutable, so you
* can treat it as a value type. */
Expand Down Expand Up @@ -320,8 +323,61 @@ struct Expr : public Internal::IRHandle {
Type type() const {
return get()->type;
}

/** Convert to a statically-typed ExprT<T>, doing a runtime check as needed */
template<typename T>
ExprT<T> typed() const;
};

/** An ExprT<T> is just an Expr that can only hold values
* of type T; attempting to construct/copy/move one using
* the wrong type will fail at compile time. Note
* that an ExprT<T> is always implicitly convertible to an Expr,
* but the reverse is not true; you can use either:
*
* - Expr::typed<T>(), which does a runtime check to verify that
* the source Expr has the expected type (with assert-fail if
* not the case)
*
* - cast<T>(Expr), which will apply the usual coercion rules to
* forcibly produce an Expr<T> of the given type (or assert-fail
* if a cast is impossible by Halide rules)
*/
template<typename T>
struct ExprT final : public Expr {

HALIDE_ALWAYS_INLINE
ExprT() = default;

HALIDE_ALWAYS_INLINE
explicit ExprT(T x)
: Expr(x) {
}

HALIDE_ALWAYS_INLINE
explicit ExprT(const Internal::BaseExprNode *n)
: Expr(n) {
check_type(n->type);
}

static void check_type(const Type &t) {
user_assert(t == type_of<T>())
<< "Cannot convert an Expr of type " << t << " to ExprT<" << type_of<T>() << ">.\n";
}
};

// Must add a specialization for ExprT<bool> since Expr has no bool ctor
template<>
inline ExprT<bool>::ExprT(bool x)
: Expr(Internal::UIntImm::make(UInt(1), x)) {
}

template<typename T>
inline ExprT<T> Expr::typed() const {
ExprT<T>::check_type(type());
return ExprT<T>(get());
}

/** This lets you use an Expr as a key in a map of the form
* map<Expr, Foo, ExprCompare> */
struct ExprCompare {
Expand Down
4 changes: 2 additions & 2 deletions src/IROperator.h
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,8 @@ Expr rounding_mul_shift_right(Expr a, Expr b, int q);

/** Cast an expression to the halide type corresponding to the C++ type T. */
template<typename T>
inline Expr cast(Expr a) {
return cast(type_of<T>(), std::move(a));
inline ExprT<T> cast(Expr a) {
return cast(type_of<T>(), std::move(a)).template typed<T>();
}

/** Cast an expression to a new type. */
Expand Down
4 changes: 2 additions & 2 deletions src/Var.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
namespace Halide {

Var::Var(const std::string &n)
: e(Internal::Variable::make(Int(32), n)) {
: e(Internal::Variable::make(Int(32), n).typed<int>()) {
}

Var::Var()
: e(Internal::Variable::make(Int(32), Internal::make_entity_name(this, "Halide:.*:Var", 'v'))) {
: e(Internal::Variable::make(Int(32), Internal::make_entity_name(this, "Halide:.*:Var", 'v')).typed<int>()) {
}

Var Var::implicit(int n) {
Expand Down
6 changes: 3 additions & 3 deletions src/Var.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Var {
* construction of the Var to avoid making a fresh Expr every time
* the Var is used in a context in which is will be converted to
* one. */
Expr e;
ExprT<int> e;

public:
/** Construct a Var with the given name */
Expand Down Expand Up @@ -155,7 +155,7 @@ class Var {
//}

/** A Var can be treated as an Expr of type Int(32) */
operator const Expr &() const {
operator const ExprT<int> &() const {
return e;
}

Expand All @@ -178,7 +178,7 @@ struct ImplicitVar {
operator Var() const {
return to_var();
}
operator Expr() const {
operator ExprT<int>() const {
return to_var();
}
};
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ tests(GROUPS correctness
tuple_update_ops.cpp
tuple_vector_reduce.cpp
two_vector_args.cpp
typed_expr.cpp
undef.cpp
uninitialized_read.cpp
unique_func_image.cpp
Expand Down
85 changes: 85 additions & 0 deletions test/correctness/typed_expr.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#include "Halide.h"
#include <iostream>

using namespace Halide;
using namespace Halide::Internal;

template<typename T>
void check_type(const Expr &e) {
if (e.type() != type_of<T>()) {
std::cerr << "constant of type " << type_of<T>() << " returned expr of type " << e.type() << "\n";
exit(-1);
}
}

template<typename T>
void test_expr(T value) {
std::cout << "Test " << type_of<T>() << " = " << (0 + value) << "\n";

{
ExprT<T> et(value);
check_type<T>(et);

// ExprT<> -> Expr is always OK
Expr e0 = et;
check_type<T>(e0);

Expr e1(et);
check_type<T>(e1);

Expr e2(std::move(et));
check_type<T>(e2);
}

{
ExprT<T> et(value);
check_type<T>(et);

// ExprT<int> et_nope = et; // won't compile, wrong types

// Cast the type to an int -- is generally ok, will
// coerce the values as appropriate
// (except for strings, which fill fail at runtime)
if (!std::is_same<T, const char *>::value) {
ExprT<int> et1 = cast<int>(et);
check_type<int>(et1);
}

// Obviously this won't even compile
// ExprT<int> et2 = et.typed<T>();
// check_type<int>(et2);

// Will fail at runtime if et isn't an int32
if (std::is_same<T, int>::value) {
ExprT<int> et3 = et.template typed<int>();
check_type<int>(et3);
}
}
}

template<typename T>
void test_expr_range() {
test_expr<T>((T)0);
test_expr<T>((T)1);
}

int main(int argc, char **argv) {
test_expr_range<bool>();
test_expr_range<uint8_t>();
test_expr_range<uint16_t>();
test_expr_range<uint32_t>();
test_expr_range<int8_t>();
test_expr_range<int16_t>();
test_expr_range<int32_t>();
test_expr_range<int64_t>();
test_expr_range<uint64_t>();
test_expr_range<float16_t>();
test_expr_range<bfloat16_t>();
test_expr_range<float>();
test_expr_range<double>();

test_expr<const char *>("foo");

std::cout << "Success!\n";
return 0;
}