From 213c8f25e9bb0fee1c31694636efd1ab999f0eb7 Mon Sep 17 00:00:00 2001 From: Praveen Narayanan Date: Fri, 7 Mar 2025 17:26:44 +0000 Subject: [PATCH] Add CHLO CAPI and PythonAPI for ragged dot --- stablehlo/integrations/c/ChloAttributes.cpp | 124 +++++++++++++++++++ stablehlo/integrations/c/ChloAttributes.h | 64 ++++++++++ stablehlo/integrations/python/ChloModule.cpp | 98 +++++++++++++++ stablehlo/integrations/python/tests/chlo.py | 27 ++++ 4 files changed, 313 insertions(+) diff --git a/stablehlo/integrations/c/ChloAttributes.cpp b/stablehlo/integrations/c/ChloAttributes.cpp index 4db047b2db..1393ad883c 100644 --- a/stablehlo/integrations/c/ChloAttributes.cpp +++ b/stablehlo/integrations/c/ChloAttributes.cpp @@ -66,3 +66,127 @@ MlirStringRef chloComparisonTypeAttrGetValue(MlirAttribute attr) { return wrap(mlir::chlo::stringifyComparisonType( llvm::cast(unwrap(attr)).getValue())); } + +//===----------------------------------------------------------------------===// +// RaggedDotDimensionNumbers +//===----------------------------------------------------------------------===// + +MlirAttribute chloRaggedDotDimensionNumbersGet( + MlirContext ctx, intptr_t nLhsBatchingDimensions, + const int64_t *lhsBatchingDimensions, intptr_t nRhsBatchingDimensions, + const int64_t *rhsBatchingDimensions, intptr_t nLhsContractingDimensions, + const int64_t *lhsContractingDimensions, intptr_t nRhsContractingDimensions, + const int64_t *rhsContractingDimensions, intptr_t nLhsRaggedDimensions, + const int64_t *lhsRaggedDimensions, intptr_t nRhsGroupDimensions, + const int64_t *rhsGroupDimensions) { + return wrap(mlir::chlo::RaggedDotDimensionNumbersAttr::get( + unwrap(ctx), + llvm::ArrayRef(lhsBatchingDimensions, nLhsBatchingDimensions), + llvm::ArrayRef(rhsBatchingDimensions, nRhsBatchingDimensions), + llvm::ArrayRef(lhsContractingDimensions, nLhsContractingDimensions), + llvm::ArrayRef(rhsContractingDimensions, nRhsContractingDimensions), + llvm::ArrayRef(lhsRaggedDimensions, nLhsRaggedDimensions), + llvm::ArrayRef(rhsGroupDimensions, nRhsGroupDimensions))); +} + +bool chloAttributeIsARaggedDotDimensionNumbers(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +intptr_t chloRaggedDotDimensionNumbersGetLhsBatchingDimensionsSize( + MlirAttribute attr) { + return llvm::cast(unwrap(attr)) + .getLhsBatchingDimensions() + .size(); +} + +int64_t chloRaggedDotDimensionNumbersGetLhsBatchingDimensionsElem( + MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr)) + .getLhsBatchingDimensions()[pos]; +} + +intptr_t chloRaggedDotDimensionNumbersGetRhsBatchingDimensionsSize( + MlirAttribute attr) { + return llvm::cast(unwrap(attr)) + .getRhsBatchingDimensions() + .size(); +} + +int64_t chloRaggedDotDimensionNumbersGetRhsBatchingDimensionsElem( + MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr)) + .getRhsBatchingDimensions()[pos]; +} + +intptr_t chloRaggedDotDimensionNumbersGetLhsContractingDimensionsSize( + MlirAttribute attr) { + return llvm::cast(unwrap(attr)) + .getLhsContractingDimensions() + .size(); +} + +int64_t chloRaggedDotDimensionNumbersGetLhsContractingDimensionsElem( + MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr)) + .getLhsContractingDimensions()[pos]; +} + +intptr_t chloRaggedDotDimensionNumbersGetRhsContractingDimensionsSize( + MlirAttribute attr) { + return llvm::cast(unwrap(attr)) + .getRhsContractingDimensions() + .size(); +} + +int64_t chloRaggedDotDimensionNumbersGetRhsContractingDimensionsElem( + MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr)) + .getRhsContractingDimensions()[pos]; +} + +intptr_t chloRaggedDotDimensionNumbersGetLhsRaggedDimensionsSize( + MlirAttribute attr) { + return llvm::cast(unwrap(attr)) + .getLhsRaggedDimensions() + .size(); +} + +int64_t chloRaggedDotDimensionNumbersGetLhsRaggedDimensionsElem( + MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr)) + .getLhsRaggedDimensions()[pos]; +} + +intptr_t chloRaggedDotDimensionNumbersGetRhsGroupDimensionsSize( + MlirAttribute attr) { + return llvm::cast(unwrap(attr)) + .getRhsGroupDimensions() + .size(); +} + +int64_t chloRaggedDotDimensionNumbersGetRhsGroupDimensionsElem( + MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr)) + .getRhsGroupDimensions()[pos]; +} + +//===----------------------------------------------------------------------===// +// PrecisionAttr +//===----------------------------------------------------------------------===// + +MlirAttribute chloPrecisionAttrGet(MlirContext ctx, MlirStringRef value) { + std::optional precision = + mlir::chlo::symbolizePrecision(unwrap(value)); + if (!precision) llvm::report_fatal_error("Invalid value."); + return wrap(mlir::chlo::PrecisionAttr::get(unwrap(ctx), precision.value())); +} + +bool chloAttributeIsAPrecisionAttr(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +MlirStringRef chloPrecisionAttrGetValue(MlirAttribute attr) { + return wrap(mlir::chlo::stringifyPrecision( + llvm::cast(unwrap(attr)).getValue())); +} diff --git a/stablehlo/integrations/c/ChloAttributes.h b/stablehlo/integrations/c/ChloAttributes.h index 7155bfd53c..ffd12733a2 100644 --- a/stablehlo/integrations/c/ChloAttributes.h +++ b/stablehlo/integrations/c/ChloAttributes.h @@ -47,6 +47,70 @@ MLIR_CAPI_EXPORTED bool chloAttributeIsAComparisonTypeAttr(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirStringRef chloComparisonTypeAttrGetValue(MlirAttribute attr); +//===----------------------------------------------------------------------===// +// RaggedDotDimensionNumbers +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED MlirAttribute chloRaggedDotDimensionNumbersGet( + MlirContext ctx, // + intptr_t nLhsBatchingDimensions, const int64_t *lhsBatchingDimensions, // + intptr_t nRhsBatchingDimensions, const int64_t *rhsBatchingDimensions, // + intptr_t nLhsContractingDimensions, // + const int64_t *lhsContractingDimensions, // + intptr_t nRhsContractingDimensions, // + const int64_t *rhsContractingDimensions, // + intptr_t nLhsRaggedDimensions, // + const int64_t *lhsRaggedDimensions, // + intptr_t nRhsGroupDimensions, // + const int64_t *rhsGroupDimensions); + +MLIR_CAPI_EXPORTED bool chloAttributeIsARaggedDotDimensionNumbers( + MlirAttribute attr); + +MLIR_CAPI_EXPORTED intptr_t +chloRaggedDotDimensionNumbersGetLhsBatchingDimensionsSize(MlirAttribute attr); +MLIR_CAPI_EXPORTED int64_t +chloRaggedDotDimensionNumbersGetLhsBatchingDimensionsElem(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED intptr_t +chloRaggedDotDimensionNumbersGetRhsBatchingDimensionsSize(MlirAttribute attr); +MLIR_CAPI_EXPORTED int64_t +chloRaggedDotDimensionNumbersGetRhsBatchingDimensionsElem(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED intptr_t +chloRaggedDotDimensionNumbersGetLhsContractingDimensionsSize( + MlirAttribute attr); +MLIR_CAPI_EXPORTED int64_t +chloRaggedDotDimensionNumbersGetLhsContractingDimensionsElem(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED intptr_t +chloRaggedDotDimensionNumbersGetRhsContractingDimensionsSize( + MlirAttribute attr); +MLIR_CAPI_EXPORTED int64_t +chloRaggedDotDimensionNumbersGetRhsContractingDimensionsElem(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED intptr_t +chloRaggedDotDimensionNumbersGetLhsRaggedDimensionsSize(MlirAttribute attr); +MLIR_CAPI_EXPORTED int64_t +chloRaggedDotDimensionNumbersGetLhsRaggedDimensionsElem(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED intptr_t +chloRaggedDotDimensionNumbersGetRhsGroupDimensionsSize(MlirAttribute attr); +MLIR_CAPI_EXPORTED int64_t +chloRaggedDotDimensionNumbersGetRhsGroupDimensionsElem(MlirAttribute attr, + intptr_t pos); + +//===----------------------------------------------------------------------===// +// PrecisionAttr +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED MlirAttribute chloPrecisionAttrGet(MlirContext ctx, + MlirStringRef value); + +MLIR_CAPI_EXPORTED bool chloAttributeIsAPrecisionAttr(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirStringRef chloPrecisionAttrGetValue(MlirAttribute attr); + #ifdef __cplusplus } #endif diff --git a/stablehlo/integrations/python/ChloModule.cpp b/stablehlo/integrations/python/ChloModule.cpp index 68e6743b8b..8720286637 100644 --- a/stablehlo/integrations/python/ChloModule.cpp +++ b/stablehlo/integrations/python/ChloModule.cpp @@ -22,6 +22,20 @@ namespace nb = nanobind; namespace { +// Returns a vector containing integers extracted from an attribute using the +// two provided callbacks. +std::vector attributePropertyVector( + MlirAttribute attr, llvm::function_ref sizeFn, + llvm::function_ref getFn) { + std::vector result; + intptr_t size = sizeFn(attr); + result.reserve(size); + for (intptr_t i = 0; i < size; ++i) { + result.push_back(getFn(attr, i)); + } + return result; +} + auto toPyString(MlirStringRef mlirStringRef) { return nb::str(mlirStringRef.data, mlirStringRef.length); } @@ -79,4 +93,88 @@ NB_MODULE(_chlo, m) { .def_property_readonly("value", [](MlirAttribute self) { return toPyString(chloComparisonTypeAttrGetValue(self)); }); + + mlir::python::nanobind_adaptors::mlir_attribute_subclass( + m, "RaggedDotDimensionNumbers", chloAttributeIsARaggedDotDimensionNumbers) + .def_classmethod( + "get", + [](nb::object cls, const std::vector &lhsBatchingDims, + const std::vector &rhsBatchingDims, + const std::vector &lhsContractingDims, + const std::vector &rhsContractingDims, + const std::vector &lhsRaggedDims, + const std::vector &rhsGroupDims, MlirContext ctx) { + return cls(chloRaggedDotDimensionNumbersGet( + ctx, lhsBatchingDims.size(), lhsBatchingDims.data(), + rhsBatchingDims.size(), rhsBatchingDims.data(), + lhsContractingDims.size(), lhsContractingDims.data(), + rhsContractingDims.size(), rhsContractingDims.data(), + lhsRaggedDims.size(), lhsRaggedDims.data(), rhsGroupDims.size(), + rhsGroupDims.data())); + }, + nb::arg("cls"), nb::arg("lhs_batching_dimensions"), + nb::arg("rhs_batching_dimensions"), + nb::arg("lhs_contracting_dimensions"), + nb::arg("rhs_contracting_dimensions"), + nb::arg("lhs_ragged_dimensions"), nb::arg("rhs_group_dimensions"), + nb::arg("context").none() = nb::none(), + "Creates a RaggedDotDimensionNumbers attribute with the given " + "dimension configuration.") + .def_property_readonly( + "lhs_batching_dimensions", + [](MlirAttribute self) { + return attributePropertyVector( + self, chloRaggedDotDimensionNumbersGetLhsBatchingDimensionsSize, + chloRaggedDotDimensionNumbersGetLhsBatchingDimensionsElem); + }) + .def_property_readonly( + "rhs_batching_dimensions", + [](MlirAttribute self) { + return attributePropertyVector( + self, chloRaggedDotDimensionNumbersGetRhsBatchingDimensionsSize, + chloRaggedDotDimensionNumbersGetRhsBatchingDimensionsElem); + }) + .def_property_readonly( + "lhs_contracting_dimensions", + [](MlirAttribute self) { + return attributePropertyVector( + self, + chloRaggedDotDimensionNumbersGetLhsContractingDimensionsSize, + chloRaggedDotDimensionNumbersGetLhsContractingDimensionsElem); + }) + .def_property_readonly( + "rhs_contracting_dimensions", + [](MlirAttribute self) { + return attributePropertyVector( + self, + chloRaggedDotDimensionNumbersGetRhsContractingDimensionsSize, + chloRaggedDotDimensionNumbersGetRhsContractingDimensionsElem); + }) + .def_property_readonly( + "lhs_ragged_dimensions", + [](MlirAttribute self) { + return attributePropertyVector( + self, chloRaggedDotDimensionNumbersGetLhsRaggedDimensionsSize, + chloRaggedDotDimensionNumbersGetLhsRaggedDimensionsElem); + }) + .def_property_readonly("rhs_group_dimensions", [](MlirAttribute self) { + return attributePropertyVector( + self, chloRaggedDotDimensionNumbersGetRhsGroupDimensionsSize, + chloRaggedDotDimensionNumbersGetRhsGroupDimensionsElem); + }); + + mlir::python::nanobind_adaptors::mlir_attribute_subclass( + m, "PrecisionAttr", chloAttributeIsAPrecisionAttr) + .def_classmethod( + "get", + [](nb::object cls, const std::string &value, MlirContext ctx) { + return cls(chloPrecisionAttrGet( + ctx, mlirStringRefCreate(value.c_str(), value.size()))); + }, + nb::arg("cls"), nb::arg("value"), + nb::arg("context").none() = nb::none(), + "Creates a Precision attribute with the given value.") + .def_property_readonly("value", [](MlirAttribute self) { + return toPyString(chloPrecisionAttrGetValue(self)); + }); } diff --git a/stablehlo/integrations/python/tests/chlo.py b/stablehlo/integrations/python/tests/chlo.py index 1d4cbd6cd2..de3ab5ee88 100644 --- a/stablehlo/integrations/python/tests/chlo.py +++ b/stablehlo/integrations/python/tests/chlo.py @@ -42,3 +42,30 @@ def test_comparison_type_attr(): assert attr is not None assert str(attr) == ("#chlo") assert attr.value == "FLOAT" + + +@run +def test_ragged_dot_dimension_numbers(): + attr = chlo.RaggedDotDimensionNumbers.get( + lhs_batching_dimensions=[0], + rhs_batching_dimensions=[1], + lhs_contracting_dimensions=[2], + rhs_contracting_dimensions=[2], + lhs_ragged_dimensions=[1], + rhs_group_dimensions=[0], + ) + assert attr is not None + assert str(attr) == ( + "#chlo.ragged_dot" + "lhs_ragged_dimensions = [1], " + "rhs_group_dimensions = [0]>" + ) + assert attr.lhs_batching_dimensions == [0] + assert attr.rhs_batching_dimensions == [1] + assert attr.lhs_contracting_dimensions == [2] + assert attr.rhs_contracting_dimensions == [2] + assert attr.lhs_ragged_dimensions == [1] + assert attr.rhs_group_dimensions == [0]