Skip to content

Commit

Permalink
Share StableHLO/MHLO pretty printers for ReduceOp and WhileOp
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609037997
  • Loading branch information
GleasonK authored and TensorFlow MLIR Team committed Feb 21, 2024
1 parent 7a21ae5 commit f0fa7c1
Show file tree
Hide file tree
Showing 8 changed files with 461 additions and 779 deletions.
373 changes: 8 additions & 365 deletions mhlo/IR/hlo_ops.cc

Large diffs are not rendered by default.

407 changes: 407 additions & 0 deletions stablehlo/stablehlo/dialect/AssemblyFormat.cpp

Large diffs are not rendered by default.

37 changes: 32 additions & 5 deletions stablehlo/stablehlo/dialect/AssemblyFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,25 @@ limitations under the License.
#ifndef STABLEHLO_DIALECT_ASSEMBLYFORMAT_H
#define STABLEHLO_DIALECT_ASSEMBLYFORMAT_H

#include <cstdint>
#include <functional>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "stablehlo/dialect/Base.h"

Expand Down Expand Up @@ -154,6 +160,15 @@ void printComplexOpType(OpAsmPrinter& p, Operation* op, ShapedType lhs,
ParseResult parseComplexOpType(OpAsmParser& parser, Type& lhs, Type& rhs,
Type& result);

// Print reduce with or without compact printing
void printReduceOp(OpAsmPrinter& p, Operation* op, ValueRange inputs,
ArrayRef<int64_t> dimensions, Region& body);

// Parse reduce with or without compact parsing
ParseResult parseReduceOp(
OpAsmParser& parser, OperationState& result,
std::function<Attribute(OpBuilder&, ArrayRef<int64_t>)> createDimensions);

// SelectOpType - only print the condition and result type when branch types
// match the result type.
//
Expand All @@ -170,15 +185,27 @@ void printSelectOpType(OpAsmPrinter& p, Operation* op, ShapedType pred,
ParseResult parseSelectOpType(OpAsmParser& parser, Type& pred, Type& onTrue,
Type& onFalse, Type& result);

// Print a `while` op.
//
// op ::= `stablehlo.while` `(` assignment-list `)` `:` types attribute-dict
// `cond` region
// `do` region
// assignment-list ::= assignment | assignment `,` assignment-list
// assignment ::= ssa-value `=` ssa-value
void printWhileOp(OpAsmPrinter& p, Operation* op, Region& cond, Region& body);

// Parse reduce with or without compact parsing
ParseResult parseWhileOp(OpAsmParser& parser, OperationState& result);

//===----------------------------------------------------------------------===//
// Attribute Printers and Parsers
//===----------------------------------------------------------------------===//

// SliceRanges - Used to print multi-dimensional ranges for slice.
void printSliceRanges(OpAsmPrinter& p, Operation* op,
ArrayRef<int64_t> startIndices,
ArrayRef<int64_t> limitIndices,
ArrayRef<int64_t> strides);
llvm::ArrayRef<int64_t> startIndices,
llvm::ArrayRef<int64_t> limitIndices,
llvm::ArrayRef<int64_t> strides);

ParseResult parseSliceRanges(OpAsmParser& parser,
DenseI64ArrayAttr& startIndices,
Expand Down
Loading

0 comments on commit f0fa7c1

Please sign in to comment.