Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CHLO CAPI and PythonAPI for ragged dot #2737

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions stablehlo/integrations/c/ChloAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,127 @@ MlirStringRef chloComparisonTypeAttrGetValue(MlirAttribute attr) {
return wrap(mlir::chlo::stringifyComparisonType(
llvm::cast<mlir::chlo::ComparisonTypeAttr>(unwrap(attr)).getValue()));
}

//===----------------------------------------------------------------------===//
// RaggedDotDimensionNumbers
//===----------------------------------------------------------------------===//

MlirAttribute chloRaggedDotDimensionNumbersGet(
MlirContext ctx, intptr_t nLhsBatchingDimensions,
const int64_t *lhsBatchingDimensions, intptr_t nRhsBatchingDimensions,
const int64_t *rhsBatchingDimensions, intptr_t nLhsContractingDimensions,
const int64_t *lhsContractingDimensions, intptr_t nRhsContractingDimensions,
const int64_t *rhsContractingDimensions, intptr_t nLhsRaggedDimensions,
const int64_t *lhsRaggedDimensions, intptr_t nRhsGroupDimensions,
const int64_t *rhsGroupDimensions) {
return wrap(mlir::chlo::RaggedDotDimensionNumbersAttr::get(
unwrap(ctx),
llvm::ArrayRef(lhsBatchingDimensions, nLhsBatchingDimensions),
llvm::ArrayRef(rhsBatchingDimensions, nRhsBatchingDimensions),
llvm::ArrayRef(lhsContractingDimensions, nLhsContractingDimensions),
llvm::ArrayRef(rhsContractingDimensions, nRhsContractingDimensions),
llvm::ArrayRef(lhsRaggedDimensions, nLhsRaggedDimensions),
llvm::ArrayRef(rhsGroupDimensions, nRhsGroupDimensions)));
}

bool chloAttributeIsARaggedDotDimensionNumbers(MlirAttribute attr) {
return llvm::isa<mlir::chlo::RaggedDotDimensionNumbersAttr>(unwrap(attr));
}

intptr_t chloRaggedDotDimensionNumbersGetLhsBatchingDimensionsSize(
MlirAttribute attr) {
return llvm::cast<mlir::chlo::RaggedDotDimensionNumbersAttr>(unwrap(attr))
.getLhsBatchingDimensions()
.size();
}

int64_t chloRaggedDotDimensionNumbersGetLhsBatchingDimensionsElem(
MlirAttribute attr, intptr_t pos) {
return llvm::cast<mlir::chlo::RaggedDotDimensionNumbersAttr>(unwrap(attr))
.getLhsBatchingDimensions()[pos];
}

intptr_t chloRaggedDotDimensionNumbersGetRhsBatchingDimensionsSize(
MlirAttribute attr) {
return llvm::cast<mlir::chlo::RaggedDotDimensionNumbersAttr>(unwrap(attr))
.getRhsBatchingDimensions()
.size();
}

int64_t chloRaggedDotDimensionNumbersGetRhsBatchingDimensionsElem(
MlirAttribute attr, intptr_t pos) {
return llvm::cast<mlir::chlo::RaggedDotDimensionNumbersAttr>(unwrap(attr))
.getRhsBatchingDimensions()[pos];
}

intptr_t chloRaggedDotDimensionNumbersGetLhsContractingDimensionsSize(
MlirAttribute attr) {
return llvm::cast<mlir::chlo::RaggedDotDimensionNumbersAttr>(unwrap(attr))
.getLhsContractingDimensions()
.size();
}

int64_t chloRaggedDotDimensionNumbersGetLhsContractingDimensionsElem(
MlirAttribute attr, intptr_t pos) {
return llvm::cast<mlir::chlo::RaggedDotDimensionNumbersAttr>(unwrap(attr))
.getLhsContractingDimensions()[pos];
}

intptr_t chloRaggedDotDimensionNumbersGetRhsContractingDimensionsSize(
MlirAttribute attr) {
return llvm::cast<mlir::chlo::RaggedDotDimensionNumbersAttr>(unwrap(attr))
.getRhsContractingDimensions()
.size();
}

int64_t chloRaggedDotDimensionNumbersGetRhsContractingDimensionsElem(
MlirAttribute attr, intptr_t pos) {
return llvm::cast<mlir::chlo::RaggedDotDimensionNumbersAttr>(unwrap(attr))
.getRhsContractingDimensions()[pos];
}

intptr_t chloRaggedDotDimensionNumbersGetLhsRaggedDimensionsSize(
MlirAttribute attr) {
return llvm::cast<mlir::chlo::RaggedDotDimensionNumbersAttr>(unwrap(attr))
.getLhsRaggedDimensions()
.size();
}

int64_t chloRaggedDotDimensionNumbersGetLhsRaggedDimensionsElem(
MlirAttribute attr, intptr_t pos) {
return llvm::cast<mlir::chlo::RaggedDotDimensionNumbersAttr>(unwrap(attr))
.getLhsRaggedDimensions()[pos];
}

intptr_t chloRaggedDotDimensionNumbersGetRhsGroupDimensionsSize(
MlirAttribute attr) {
return llvm::cast<mlir::chlo::RaggedDotDimensionNumbersAttr>(unwrap(attr))
.getRhsGroupDimensions()
.size();
}

int64_t chloRaggedDotDimensionNumbersGetRhsGroupDimensionsElem(
MlirAttribute attr, intptr_t pos) {
return llvm::cast<mlir::chlo::RaggedDotDimensionNumbersAttr>(unwrap(attr))
.getRhsGroupDimensions()[pos];
}

//===----------------------------------------------------------------------===//
// PrecisionAttr
//===----------------------------------------------------------------------===//

MlirAttribute chloPrecisionAttrGet(MlirContext ctx, MlirStringRef value) {
std::optional<mlir::chlo::Precision> precision =
mlir::chlo::symbolizePrecision(unwrap(value));
if (!precision) llvm::report_fatal_error("Invalid value.");
return wrap(mlir::chlo::PrecisionAttr::get(unwrap(ctx), precision.value()));
}

bool chloAttributeIsAPrecisionAttr(MlirAttribute attr) {
return llvm::isa<mlir::chlo::PrecisionAttr>(unwrap(attr));
}

MlirStringRef chloPrecisionAttrGetValue(MlirAttribute attr) {
return wrap(mlir::chlo::stringifyPrecision(
llvm::cast<mlir::chlo::PrecisionAttr>(unwrap(attr)).getValue()));
}
64 changes: 64 additions & 0 deletions stablehlo/integrations/c/ChloAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,70 @@ MLIR_CAPI_EXPORTED bool chloAttributeIsAComparisonTypeAttr(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirStringRef
chloComparisonTypeAttrGetValue(MlirAttribute attr);

//===----------------------------------------------------------------------===//
// RaggedDotDimensionNumbers
//===----------------------------------------------------------------------===//

MLIR_CAPI_EXPORTED MlirAttribute chloRaggedDotDimensionNumbersGet(
MlirContext ctx, //
intptr_t nLhsBatchingDimensions, const int64_t *lhsBatchingDimensions, //
intptr_t nRhsBatchingDimensions, const int64_t *rhsBatchingDimensions, //
intptr_t nLhsContractingDimensions, //
const int64_t *lhsContractingDimensions, //
intptr_t nRhsContractingDimensions, //
const int64_t *rhsContractingDimensions, //
intptr_t nLhsRaggedDimensions, //
const int64_t *lhsRaggedDimensions, //
intptr_t nRhsGroupDimensions, //
const int64_t *rhsGroupDimensions);

MLIR_CAPI_EXPORTED bool chloAttributeIsARaggedDotDimensionNumbers(
MlirAttribute attr);

MLIR_CAPI_EXPORTED intptr_t
chloRaggedDotDimensionNumbersGetLhsBatchingDimensionsSize(MlirAttribute attr);
MLIR_CAPI_EXPORTED int64_t
chloRaggedDotDimensionNumbersGetLhsBatchingDimensionsElem(MlirAttribute attr,
intptr_t pos);
MLIR_CAPI_EXPORTED intptr_t
chloRaggedDotDimensionNumbersGetRhsBatchingDimensionsSize(MlirAttribute attr);
MLIR_CAPI_EXPORTED int64_t
chloRaggedDotDimensionNumbersGetRhsBatchingDimensionsElem(MlirAttribute attr,
intptr_t pos);
MLIR_CAPI_EXPORTED intptr_t
chloRaggedDotDimensionNumbersGetLhsContractingDimensionsSize(
MlirAttribute attr);
MLIR_CAPI_EXPORTED int64_t
chloRaggedDotDimensionNumbersGetLhsContractingDimensionsElem(MlirAttribute attr,
intptr_t pos);
MLIR_CAPI_EXPORTED intptr_t
chloRaggedDotDimensionNumbersGetRhsContractingDimensionsSize(
MlirAttribute attr);
MLIR_CAPI_EXPORTED int64_t
chloRaggedDotDimensionNumbersGetRhsContractingDimensionsElem(MlirAttribute attr,
intptr_t pos);
MLIR_CAPI_EXPORTED intptr_t
chloRaggedDotDimensionNumbersGetLhsRaggedDimensionsSize(MlirAttribute attr);
MLIR_CAPI_EXPORTED int64_t
chloRaggedDotDimensionNumbersGetLhsRaggedDimensionsElem(MlirAttribute attr,
intptr_t pos);
MLIR_CAPI_EXPORTED intptr_t
chloRaggedDotDimensionNumbersGetRhsGroupDimensionsSize(MlirAttribute attr);
MLIR_CAPI_EXPORTED int64_t
chloRaggedDotDimensionNumbersGetRhsGroupDimensionsElem(MlirAttribute attr,
intptr_t pos);

//===----------------------------------------------------------------------===//
// PrecisionAttr
//===----------------------------------------------------------------------===//

MLIR_CAPI_EXPORTED MlirAttribute chloPrecisionAttrGet(MlirContext ctx,
MlirStringRef value);

MLIR_CAPI_EXPORTED bool chloAttributeIsAPrecisionAttr(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirStringRef chloPrecisionAttrGetValue(MlirAttribute attr);

#ifdef __cplusplus
}
#endif
Expand Down
98 changes: 98 additions & 0 deletions stablehlo/integrations/python/ChloModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@ namespace nb = nanobind;

namespace {

// Returns a vector containing integers extracted from an attribute using the
// two provided callbacks.
std::vector<int64_t> attributePropertyVector(
MlirAttribute attr, llvm::function_ref<intptr_t(MlirAttribute)> sizeFn,
llvm::function_ref<int64_t(MlirAttribute, intptr_t)> getFn) {
std::vector<int64_t> result;
intptr_t size = sizeFn(attr);
result.reserve(size);
for (intptr_t i = 0; i < size; ++i) {
result.push_back(getFn(attr, i));
}
return result;
}

auto toPyString(MlirStringRef mlirStringRef) {
return nb::str(mlirStringRef.data, mlirStringRef.length);
}
Expand Down Expand Up @@ -79,4 +93,88 @@ NB_MODULE(_chlo, m) {
.def_property_readonly("value", [](MlirAttribute self) {
return toPyString(chloComparisonTypeAttrGetValue(self));
});

mlir::python::nanobind_adaptors::mlir_attribute_subclass(
m, "RaggedDotDimensionNumbers", chloAttributeIsARaggedDotDimensionNumbers)
.def_classmethod(
"get",
[](nb::object cls, const std::vector<int64_t> &lhsBatchingDims,
const std::vector<int64_t> &rhsBatchingDims,
const std::vector<int64_t> &lhsContractingDims,
const std::vector<int64_t> &rhsContractingDims,
const std::vector<int64_t> &lhsRaggedDims,
const std::vector<int64_t> &rhsGroupDims, MlirContext ctx) {
return cls(chloRaggedDotDimensionNumbersGet(
ctx, lhsBatchingDims.size(), lhsBatchingDims.data(),
rhsBatchingDims.size(), rhsBatchingDims.data(),
lhsContractingDims.size(), lhsContractingDims.data(),
rhsContractingDims.size(), rhsContractingDims.data(),
lhsRaggedDims.size(), lhsRaggedDims.data(), rhsGroupDims.size(),
rhsGroupDims.data()));
},
nb::arg("cls"), nb::arg("lhs_batching_dimensions"),
nb::arg("rhs_batching_dimensions"),
nb::arg("lhs_contracting_dimensions"),
nb::arg("rhs_contracting_dimensions"),
nb::arg("lhs_ragged_dimensions"), nb::arg("rhs_group_dimensions"),
nb::arg("context").none() = nb::none(),
"Creates a RaggedDotDimensionNumbers attribute with the given "
"dimension configuration.")
.def_property_readonly(
"lhs_batching_dimensions",
[](MlirAttribute self) {
return attributePropertyVector(
self, chloRaggedDotDimensionNumbersGetLhsBatchingDimensionsSize,
chloRaggedDotDimensionNumbersGetLhsBatchingDimensionsElem);
})
.def_property_readonly(
"rhs_batching_dimensions",
[](MlirAttribute self) {
return attributePropertyVector(
self, chloRaggedDotDimensionNumbersGetRhsBatchingDimensionsSize,
chloRaggedDotDimensionNumbersGetRhsBatchingDimensionsElem);
})
.def_property_readonly(
"lhs_contracting_dimensions",
[](MlirAttribute self) {
return attributePropertyVector(
self,
chloRaggedDotDimensionNumbersGetLhsContractingDimensionsSize,
chloRaggedDotDimensionNumbersGetLhsContractingDimensionsElem);
})
.def_property_readonly(
"rhs_contracting_dimensions",
[](MlirAttribute self) {
return attributePropertyVector(
self,
chloRaggedDotDimensionNumbersGetRhsContractingDimensionsSize,
chloRaggedDotDimensionNumbersGetRhsContractingDimensionsElem);
})
.def_property_readonly(
"lhs_ragged_dimensions",
[](MlirAttribute self) {
return attributePropertyVector(
self, chloRaggedDotDimensionNumbersGetLhsRaggedDimensionsSize,
chloRaggedDotDimensionNumbersGetLhsRaggedDimensionsElem);
})
.def_property_readonly("rhs_group_dimensions", [](MlirAttribute self) {
return attributePropertyVector(
self, chloRaggedDotDimensionNumbersGetRhsGroupDimensionsSize,
chloRaggedDotDimensionNumbersGetRhsGroupDimensionsElem);
});

mlir::python::nanobind_adaptors::mlir_attribute_subclass(
m, "PrecisionAttr", chloAttributeIsAPrecisionAttr)
.def_classmethod(
"get",
[](nb::object cls, const std::string &value, MlirContext ctx) {
return cls(chloPrecisionAttrGet(
ctx, mlirStringRefCreate(value.c_str(), value.size())));
},
nb::arg("cls"), nb::arg("value"),
nb::arg("context").none() = nb::none(),
"Creates a Precision attribute with the given value.")
.def_property_readonly("value", [](MlirAttribute self) {
return toPyString(chloPrecisionAttrGetValue(self));
});
}
27 changes: 27 additions & 0 deletions stablehlo/integrations/python/tests/chlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,30 @@ def test_comparison_type_attr():
assert attr is not None
assert str(attr) == ("#chlo<comparison_type FLOAT>")
assert attr.value == "FLOAT"


@run
def test_ragged_dot_dimension_numbers():
attr = chlo.RaggedDotDimensionNumbers.get(
lhs_batching_dimensions=[0],
rhs_batching_dimensions=[1],
lhs_contracting_dimensions=[2],
rhs_contracting_dimensions=[2],
lhs_ragged_dimensions=[1],
rhs_group_dimensions=[0],
)
assert attr is not None
assert str(attr) == (
"#chlo.ragged_dot<lhs_batching_dimensions = [0], "
"rhs_batching_dimensions = [1], "
"lhs_contracting_dimensions = [2], "
"rhs_contracting_dimensions = [2], >"
"lhs_ragged_dimensions = [1], "
"rhs_group_dimensions = [0]>"
)
assert attr.lhs_batching_dimensions == [0]
assert attr.rhs_batching_dimensions == [1]
assert attr.lhs_contracting_dimensions == [2]
assert attr.rhs_contracting_dimensions == [2]
assert attr.lhs_ragged_dimensions == [1]
assert attr.rhs_group_dimensions == [0]
Loading