From c9db851458278fc7a143e8376fef51d008b592bc Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Fri, 29 Dec 2023 08:34:41 +0000 Subject: [PATCH] [DimExpr] DimExpr support hash --- paddle/pir/dialect/shape/utils/dim_expr.cc | 52 +++++++++++++++++++ paddle/pir/dialect/shape/utils/dim_expr.h | 13 +++++ .../pir/shape_dialect/symbol_dim_expr_test.cc | 34 +++++++++--- 3 files changed, 93 insertions(+), 6 deletions(-) diff --git a/paddle/pir/dialect/shape/utils/dim_expr.cc b/paddle/pir/dialect/shape/utils/dim_expr.cc index 0d9b6ece23245c..61f7a582cb5a56 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr.cc +++ b/paddle/pir/dialect/shape/utils/dim_expr.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/pir/dialect/shape/utils/dim_expr.h" +#include "paddle/pir/core/utils.h" namespace symbol { @@ -184,4 +185,55 @@ std::ostream& operator<<(std::ostream& stream, const DimExpr& dim_expr) { return stream; } +namespace { + +std::size_t GetHashValueImpl(const std::int64_t& dim_expr) { return dim_expr; } + +std::size_t GetHashValueImpl(const std::string& dim_expr) { + return std::hash()(dim_expr); +} + +std::size_t GetHashValueImpl(const Negative& dim_expr) { + return -GetHashValue(dim_expr->data); +} + +std::size_t GetHashValueImpl(const Reciprocal& dim_expr) { + return pir::hash_combine(1, -GetHashValue(dim_expr->data)); +} + +std::size_t GetHashValueImpl(const List& exprs) { + std::size_t ret = 0; + for (const auto& expr : *exprs) { + ret = pir::hash_combine(ret, GetHashValue(expr)); + } + return ret; +} + +std::size_t GetHashValueImpl(const Add& dim_expr) { + return pir::hash_combine(1, GetHashValueImpl(dim_expr.operands)); +} + +std::size_t GetHashValueImpl(const Mul& dim_expr) { + return pir::hash_combine(2, GetHashValueImpl(dim_expr.operands)); +} + +std::size_t GetHashValueImpl(const Max& dim_expr) { + return pir::hash_combine(3, GetHashValueImpl(dim_expr.operands)); +} + +std::size_t GetHashValueImpl(const Min& dim_expr) { + return pir::hash_combine(4, GetHashValueImpl(dim_expr.operands)); +} + +std::size_t GetHashValueImpl(const Broadcast& dim_expr) { + return pir::hash_combine(5, GetHashValueImpl(dim_expr.operands)); +} + +} // namespace + +std::size_t GetHashValue(const DimExpr& dim_expr) { + return std::visit([](const auto& impl) { return GetHashValueImpl(impl); }, + dim_expr.variant()); +} + } // namespace symbol diff --git a/paddle/pir/dialect/shape/utils/dim_expr.h b/paddle/pir/dialect/shape/utils/dim_expr.h index 277a6febe66ed7..a65390200cd062 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr.h +++ b/paddle/pir/dialect/shape/utils/dim_expr.h @@ -253,4 +253,17 @@ IR_API std::string ToString(const DimExpr& dim_expr); IR_API std::ostream& operator<<(std::ostream&, const DimExpr& dim_expr); +IR_API std::size_t GetHashValue(const DimExpr& dim_expr); + } // namespace symbol + +namespace std { + +template <> +struct hash { + std::size_t operator()(const symbol::DimExpr& dim_expr) const { + return symbol::GetHashValue(dim_expr); + } +}; + +} // namespace std diff --git a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc index 6157850e3842c3..3aebb367d1a272 100644 --- a/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc +++ b/test/cpp/pir/shape_dialect/symbol_dim_expr_test.cc @@ -22,7 +22,7 @@ namespace symbol::test { // Construct DimExpr by overloaded operator(+, - , *, /) -TEST(DimExpr, dim_expr_naive) { +TEST(DimExpr, DimExprNaive) { DimExpr sym0 = DimExpr("S0"); DimExpr sym1 = DimExpr("S1"); DimExpr constant1 = DimExpr(1); @@ -30,7 +30,7 @@ TEST(DimExpr, dim_expr_naive) { } // Construct DimExpr by DimExprBuilder -TEST(DimExpr, dim_expr_builder) { +TEST(DimExpr, DimExprBuilder) { DimExprBuilder builder{nullptr}; DimExpr sym0 = DimExpr("S0"); DimExpr sym1 = DimExpr("S1"); @@ -40,7 +40,7 @@ TEST(DimExpr, dim_expr_builder) { } // Add constraints by DimExprBuilder -TEST(DimExpr, constraint) { +TEST(DimExpr, Constraint) { std::vector constraints{}; DimExprBuilder builder(&constraints); DimExpr sym0 = DimExpr("S0"); @@ -55,7 +55,7 @@ TEST(DimExpr, constraint) { extend_x = x.shape out = pd.reshape(y, extend_x) */ -TEST(DimExpr, data_shape_expr) { +TEST(DimExpr, DataShapeExpr) { // Show ideal ShapeOrDataDimExprs of each pir::Value std::vector x_shapes{DimExpr("S0"), DimExpr(2)}; std::vector y_shapes{DimExpr(1), DimExpr("S1"), DimExpr(2)}; @@ -80,7 +80,7 @@ TEST(Simplify, NumberArithmetic) { ASSERT_EQ((mul_div.Get()), 1); } -TEST(DimExpr, equal) { +TEST(DimExpr, Equal) { DimExprBuilder builder{nullptr}; DimExpr sym0 = DimExpr("S0"); DimExpr sym1 = DimExpr("S1"); @@ -111,7 +111,7 @@ TEST(DimExpr, equal) { builder.Broadcast(DimExpr("S0"), constant1)); } -TEST(DimExpr, print) { +TEST(DimExpr, Print) { DimExprBuilder builder{nullptr}; DimExpr sym0 = DimExpr("S0"); DimExpr sym1 = DimExpr("S1"); @@ -124,4 +124,26 @@ TEST(DimExpr, print) { ASSERT_EQ((ToString(builder.Broadcast(sym0, sym1))), "Broadcast(S0, S1)"); } +TEST(DimExpr, Hash) { + DimExprBuilder builder{nullptr}; + DimExpr sym0 = DimExpr("S0"); + DimExpr sym1 = DimExpr("S1"); + ASSERT_EQ((std::hash()(sym0 + sym1)), + (std::hash()(sym0 + sym1))); + ASSERT_NE((std::hash()(sym0 + sym1)), + (std::hash()(sym1 + sym0))); + ASSERT_NE((std::hash()(sym0 + sym1)), + (std::hash()(sym0 - sym1))); + ASSERT_NE((std::hash()(sym0 + sym1)), + (std::hash()(sym0 * sym1))); + ASSERT_NE((std::hash()(sym0 + sym1)), + (std::hash()(sym0 / sym1))); + ASSERT_NE((std::hash()(sym0 + sym1)), + (std::hash()(builder.Max(sym0, sym1)))); + ASSERT_NE((std::hash()(sym0 + sym1)), + (std::hash()(builder.Min(sym0, sym1)))); + ASSERT_NE((std::hash()(sym0 + sym1)), + (std::hash()(builder.Broadcast(sym0, sym1)))); +} + } // namespace symbol::test