Skip to content

Commit

Permalink
Broke some internal tests
Browse files Browse the repository at this point in the history
Reverts e3bcd73

PiperOrigin-RevId: 657786386
  • Loading branch information
tensorflower-gardener authored and copybara-github committed Jul 31, 2024
1 parent 2cea900 commit 9bb1871
Show file tree
Hide file tree
Showing 10 changed files with 1 addition and 228 deletions.
1 change: 1 addition & 0 deletions xla/translate/hlo_to_mhlo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ cc_library(
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/mlir_hlo",
"//xla/service:hlo_module_config",
"//xla/service:hlo_proto_cc",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down
55 changes: 0 additions & 55 deletions xla/translate/hlo_to_mhlo/hlo_function_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ limitations under the License.
#include "xla/comparison_util.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_input_output_alias_config.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
Expand Down Expand Up @@ -2501,60 +2500,6 @@ absl::Status HloFunctionImporter::ConvertShapeToMlirLayout(
return Internal("Couldn't convert layout.");
}

// std::string FrontendAttributesToString(
// const FrontendAttributes& frontend_attributes) {
// std::vector<std::pair<std::string, std::string>> sorted_attributes(
// frontend_attributes.map().begin(), frontend_attributes.map().end());
// absl::c_sort(sorted_attributes);
// const auto formatter = [](std::string* out,
// const std::pair<std::string, std::string>& item)
// {
// if (LexesAsJsonDict(item.second)) {
// absl::StrAppend(out, item.first, "=", item.second);
// } else {
// absl::StrAppend(out, item.first, "=\"", item.second, "\"");
// }
// };
// return absl::StrFormat("{%s}",
// absl::StrJoin(sorted_attributes, ",", formatter));
// }

// Converts the HLO input/output alias config into an MHLO-style ArrayAttr.
// Each element is a dictionary of the form
//   {output_index, alias = {parameter_index, parameter_number, kind}}.
// Always succeeds and returns a (possibly empty) non-null attribute.
mlir::Attribute ConvertInputOutputAlias(const HloInputOutputAliasConfig& alias,
                                        mlir::Builder* builder) {
  llvm::SmallVector<mlir::Attribute> entries;
  alias.ForEachAlias([&](const ShapeIndex& output_index,
                         const HloInputOutputAliasConfig::Alias& alias_info) {
    // Render the alias kind as the string form the MHLO attribute expects.
    llvm::StringRef kind_str;
    switch (alias_info.kind) {
      case HloInputOutputAliasConfig::AliasKind::kMayAlias:
        kind_str = "may_alias";
        break;
      case HloInputOutputAliasConfig::AliasKind::kMustAlias:
        kind_str = "must_alias";
        break;
      default:
        kind_str = "undefined_alias";
        break;
    }

    llvm::SmallVector<mlir::NamedAttribute, 3> alias_attrs;
    alias_attrs.push_back(builder->getNamedAttr(
        "parameter_index",
        builder->getDenseI64ArrayAttr(
            ArrayRef<int64_t>(alias_info.parameter_index.begin(),
                              alias_info.parameter_index.end()))));
    alias_attrs.push_back(builder->getNamedAttr(
        "parameter_number",
        builder->getI64IntegerAttr(alias_info.parameter_number)));
    alias_attrs.push_back(
        builder->getNamedAttr("kind", builder->getStringAttr(kind_str)));

    llvm::SmallVector<mlir::NamedAttribute, 2> entry_attrs;
    entry_attrs.push_back(builder->getNamedAttr(
        "output_index", builder->getDenseI64ArrayAttr(ArrayRef<int64_t>(
                            output_index.begin(), output_index.end()))));
    entry_attrs.push_back(builder->getNamedAttr(
        "alias", builder->getDictionaryAttr(alias_attrs)));

    entries.push_back(builder->getDictionaryAttr(entry_attrs));
  });
  return builder->getArrayAttr(entries);
}

mlir::Attribute ConvertSharding(const HloSharding& sharding,
mlir::Builder* builder) {
return builder->getStringAttr(sharding.ToString(/*include_metadata=*/true));
Expand Down
7 changes: 0 additions & 7 deletions xla/translate/hlo_to_mhlo/hlo_function_importer.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ limitations under the License.
#include "mlir/IR/Operation.h"
#include "mlir/IR/ValueRange.h"
#include "xla/comparison_util.h"
#include "xla/hlo/ir/hlo_input_output_alias_config.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/service/hlo.pb.h"
Expand Down Expand Up @@ -298,12 +297,6 @@ class HloFunctionImporter {
bool flatten_computation_args_result_;
};

// Returns a StringAttr that carries a prettyprinted representation of the
// given HLO C++ input_output_alias_config.
// Always succeeds and returns a non-empty attribute.
mlir::Attribute ConvertInputOutputAlias(const HloInputOutputAliasConfig& alias,
mlir::Builder* builder);

// Returns a StringAttr that carries a prettyprinted representation of the
// given HLO C++ sharding.
// Always succeeds and returns a non-empty attribute.
Expand Down
4 changes: 0 additions & 4 deletions xla/translate/hlo_to_mhlo/hlo_module_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,6 @@ absl::Status HloModuleImporter::Import(const HloModule& hlo_module) {
ConvertSharding(hlo_module.spmd_output_sharding(), &builder_));
}

module->setAttr("mhlo.input_output_alias",
ConvertInputOutputAlias(
hlo_module.input_output_alias_config(), &builder_));

if (hlo_module.has_spmd_parameters_shardings()) {
llvm::SmallVector<mlir::Attribute> parameter_shardings;
parameter_shardings.reserve(hlo_module.spmd_parameters_shardings().size());
Expand Down
13 changes: 0 additions & 13 deletions xla/translate/hlo_to_mhlo/tests/module_attributes.hlo
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,6 @@
# FLATTEN-CHECK-LABEL: module @main attributes {
hlo_module {
name: "main"
input_output_alias {
entries {
output_shape_index: 0
parameter_number: 0
kind: MAY_ALIAS
}
entries {
output_shape_index: 1
parameter_number: 1
kind: MAY_ALIAS
}
}
entry_computation_name: "main.5"
computations {
name: "main.5"
Expand Down Expand Up @@ -229,7 +217,6 @@ hlo_module {
value: "attr_value"
}
}
# CHECK-SAME: mhlo.input_output_alias = [{alias = {kind = "may_alias", parameter_index = array<i64>, parameter_number = 0 : i64}, output_index = array<i64: 0>}, {alias = {kind = "may_alias", parameter_index = array<i64>, parameter_number = 1 : i64}, output_index = array<i64: 1>}]
# CHECK-SAME: mhlo.is_dynamic = true
is_dynamic: true
# CHECK-SAME: mhlo.use_auto_spmd_partitioning = true
Expand Down
1 change: 0 additions & 1 deletion xla/translate/mhlo_to_hlo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ cc_library(
"//xla:types",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/mlir_hlo",
"//xla/service:hlo_parser",
"//xla/service:hlo_proto_cc",
Expand Down
95 changes: 0 additions & 95 deletions xla/translate/mhlo_to_hlo/attribute_exporter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,99 +185,4 @@ std::optional<xla::OpSharding> ConvertSharding(llvm::StringRef sharding) {
return std::nullopt;
}

// Rebuilds an HloInputOutputAliasProto from the entries of the
// "mhlo.input_output_alias" ArrayAttr. Returns std::nullopt when the
// attribute list is empty (nothing to export).
std::optional<xla::HloInputOutputAliasProto> ConvertInputOutputAlias(
    llvm::ArrayRef<mlir::Attribute> aliasing) {
  if (aliasing.empty()) return std::nullopt;

  // Maps the string kind back to the proto enum; any unrecognized string
  // falls back to UNDEFINED_ALIAS.
  const auto kind_from_string = [](mlir::StringRef kind) {
    if (kind == "may_alias") return xla::Kind::MAY_ALIAS;
    if (kind == "must_alias") return xla::Kind::MUST_ALIAS;
    return xla::Kind::UNDEFINED_ALIAS;
  };

  xla::HloInputOutputAliasProto result;
  for (mlir::Attribute element : aliasing) {
    auto entry_dict = mlir::cast<mlir::DictionaryAttr>(element);
    auto alias_dict =
        mlir::cast<mlir::DictionaryAttr>(entry_dict.get("alias"));

    HloInputOutputAliasProto::AliasEntryProto* proto_entry =
        result.add_entries();
    for (int64_t dim :
         mlir::cast<mlir::DenseI64ArrayAttr>(entry_dict.get("output_index"))
             .asArrayRef()) {
      proto_entry->add_output_shape_index(dim);
    }
    proto_entry->set_parameter_number(
        mlir::cast<mlir::IntegerAttr>(alias_dict.get("parameter_number"))
            .getInt());
    for (int64_t dim : mlir::cast<mlir::DenseI64ArrayAttr>(
                           alias_dict.get("parameter_index"))
                           .asArrayRef()) {
      proto_entry->add_parameter_shape_index(dim);
    }
    proto_entry->set_kind(kind_from_string(
        mlir::cast<mlir::StringAttr>(alias_dict.get("kind")).getValue()));
  }
  return result;
}

// Translates an MHLO DotDimensionNumbersAttr into the XLA proto form,
// copying the four batching/contracting dimension lists verbatim.
DotDimensionNumbers ConvertDotDimensionNumbers(
    mlir::mhlo::DotDimensionNumbersAttr input) {
  DotDimensionNumbers output;

  // Appends one dimension list from the MLIR attribute into a repeated
  // proto field.
  const auto copy_dims = [](auto dims, auto* field) {
    for (int64_t dim : dims) field->Add(dim);
  };

  copy_dims(input.getLhsBatchingDimensions(),
            output.mutable_lhs_batch_dimensions());
  copy_dims(input.getRhsBatchingDimensions(),
            output.mutable_rhs_batch_dimensions());
  copy_dims(input.getLhsContractingDimensions(),
            output.mutable_lhs_contracting_dimensions());
  copy_dims(input.getRhsContractingDimensions(),
            output.mutable_rhs_contracting_dimensions());

  return output;
}

// Builds the XLA DotDimensionNumbers proto directly from dimension spans.
// Overload used when the dimensions are already available as plain arrays
// rather than as an MHLO attribute.
DotDimensionNumbers ConvertDotDimensionNumbers(
    absl::Span<const int64_t> lhs_batch, absl::Span<const int64_t> lhs_contract,
    absl::Span<const int64_t> rhs_batch,
    absl::Span<const int64_t> rhs_contract) {
  DotDimensionNumbers output;
  // RepeatedField::Add(begin, end) appends the whole span in one call.
  output.mutable_lhs_batch_dimensions()->Add(lhs_batch.begin(),
                                             lhs_batch.end());
  output.mutable_rhs_batch_dimensions()->Add(rhs_batch.begin(),
                                             rhs_batch.end());
  output.mutable_lhs_contracting_dimensions()->Add(lhs_contract.begin(),
                                                   lhs_contract.end());
  output.mutable_rhs_contracting_dimensions()->Add(rhs_contract.begin(),
                                                   rhs_contract.end());
  return output;
}

// Converts an MLIR ArrayAttr whose elements are IntegerAttrs into a
// std::vector<int64_t>.
//
// Returns an Internal error if any element of the array is not an integer
// attribute (e.g. a malformed layout).
absl::StatusOr<std::vector<int64_t>> ConvertMlirArrayAttrToInt64Array(
    const mlir::ArrayAttr& array) {
  std::vector<int64_t> converted_array;
  // Reserve up front and append as we validate; this avoids the previous
  // signed/unsigned mix (`int rank = array.size()`) and the value-init-then-
  // overwrite of every element.
  converted_array.reserve(array.size());
  for (mlir::Attribute element : array) {
    auto attr = mlir::dyn_cast<mlir::IntegerAttr>(element);
    if (!attr) {
      return Internal("Type Error: Expected layout integer attribute");
    }
    converted_array.push_back(attr.getInt());
  }
  return converted_array;
}
} // namespace xla
4 changes: 0 additions & 4 deletions xla/translate/mhlo_to_hlo/attribute_exporter.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ limitations under the License.

#include "absl/status/statusor.h"
#include "mlir/IR/Attributes.h"
#include "xla/hlo/ir/hlo_input_output_alias_config.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/service/hlo.pb.h"
#include "xla/shape_util.h"
Expand Down Expand Up @@ -60,8 +59,5 @@ ConvertOutputOperandAliasing(mlir::ArrayAttr aliasArrayAttr);
// Will fail if both attempts at parsing failed.
std::optional<xla::OpSharding> ConvertSharding(mlir::StringRef sharding);

std::optional<xla::HloInputOutputAliasProto> ConvertInputOutputAlias(
llvm::ArrayRef<mlir::Attribute> aliasing);

} // namespace xla
#endif // XLA_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_
7 changes: 0 additions & 7 deletions xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3736,13 +3736,6 @@ absl::Status ConvertMlirHloToHlo(mlir::ModuleOp module,
*hlo_module.mutable_spmd_output_sharding() =
*xla::ConvertSharding(spmd_output_sharding.getValue());
}
if (auto input_output_alias =
module->getAttrOfType<mlir::ArrayAttr>("mhlo.input_output_alias")) {
if (std::optional<xla::HloInputOutputAliasProto> input_output_alias_proto =
xla::ConvertInputOutputAlias(input_output_alias.getValue())) {
*hlo_module.mutable_input_output_alias() = *input_output_alias_proto;
}
}
if (auto spmd_parameters_sharding = module->getAttrOfType<mlir::ArrayAttr>(
"mhlo.spmd_parameters_shardings")) {
for (const auto& sharding : spmd_parameters_sharding.getValue()) {
Expand Down
42 changes: 0 additions & 42 deletions xla/translate/mhlo_to_hlo/tests/module_attributes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -100,45 +100,3 @@ module @ModuleWithFrontendAttributes attributes {
func.return %arg0 : tensor<1xf32>
}
}



// -----

module attributes {
// CHECK: input_output_alias {
// CHECK-NEXT: entries {
// CHECK-NEXT: output_shape_index: 0
// CHECK-NEXT: kind: MAY_ALIAS
// CHECK-NEXT: }
// CHECK-NEXT: entries {
// CHECK-NEXT: output_shape_index: 1
// CHECK-NEXT: parameter_number: 1
// CHECK-NEXT: kind: MAY_ALIAS
// CHECK-NEXT: }
// CHECK-NEXT: }
mhlo.input_output_alias = [
{
alias =
{
kind = "may_alias",
parameter_index = array<i64>,
parameter_number = 0 : i64
},
output_index = array<i64: 0>
},
{
alias =
{
kind = "may_alias",
parameter_index = array<i64>,
parameter_number = 1 : i64
},
output_index = array<i64: 1>
}
]
} {
func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32> ) -> (tensor<1xf32>, tensor<1xf32>) {
func.return %arg0, %arg1: tensor<1xf32>, tensor<1xf32>
}
}

0 comments on commit 9bb1871

Please sign in to comment.