Skip to content

Commit

Permalink
Add PrintOp to the Interpreter dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
ghpvnist committed Jan 10, 2025
1 parent 51028d9 commit d8275d2
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 4 deletions.
11 changes: 10 additions & 1 deletion stablehlo/reference/Api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,21 @@ FailureOr<func::FuncOp> getMainFunction(ModuleOp module, StringRef mainName) {
class DefaultInterpreterFallback : public InterpreterFallback {
public:
DefaultInterpreterFallback(const InterpreterConfiguration &config)
: config(config){};
: config(config) {};

virtual llvm::Error operator()(Operation &op, Scope &scope,
Process *process) final {
llvm::StringRef funcName = op.getParentOfType<func::FuncOp>().getSymName();

if (auto printOp = dyn_cast<stablehlo::interpreter::PrintOp>(op)) {
auto input =
stablehlo::InterpreterValue(scope.findTensor(printOp.getOperand()));
auto result = stablehlo::interpreter::evalPrintOp(printOp, input);
scope.add(printOp.getResult(), input);
return wrapFallbackStatus(llvm::Error::success(), funcName,
"interpreter.print");
}

if (auto probeOp = dyn_cast<stablehlo::interpreter::ProbeOp>(op)) {
auto input =
stablehlo::InterpreterValue(scope.findTensor(probeOp.getOperand()));
Expand Down
14 changes: 14 additions & 0 deletions stablehlo/reference/InterpreterOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,20 @@ SmallVector<InterpreterValue> evalRunParallelOp(
return results;
}

InterpreterValue evalPrintOp(PrintOp &op, InterpreterValue operand) {
std::string ssaValueStr;
llvm::raw_string_ostream stream(ssaValueStr);
stream << op;

// Get the SSA name and print it like: `%0 = `
llvm::outs() << ssaValueStr.substr(0, ssaValueStr.find("=") + 2);

// Prints the tensor value
operand.getTensor().print(llvm::outs());
llvm::outs() << "\n";
return operand;
}

// `serializedProbeFileId` should be a unique positive integer which can be used
// to unambiguously derive a serialized filename for a given `probeId`.
llvm::Error evalProbeOp(InterpreterValue input, StringRef probeId,
Expand Down
12 changes: 9 additions & 3 deletions stablehlo/reference/InterpreterOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ limitations under the License.
#include "mlir/Support/LLVM.h"
#include "stablehlo/reference/Value.h"

#define GET_OP_CLASSES
#include "stablehlo/reference/InterpreterOps.h.inc"

namespace mlir {
namespace stablehlo {
namespace interpreter {
Expand All @@ -39,6 +42,12 @@ SmallVector<InterpreterValue> evalRunParallelOp(
ArrayRef<InterpreterValue> inputs, std::queue<StringAttr> &infeed,
SmallVector<SmallVector<StringAttr>> programs, SymbolTable &symbolTable);

// Print the SSA name followed by its type and value like:
// >>> %0 = tensor<i1> {
// ... [true]
// ... }
InterpreterValue evalPrintOp(PrintOp &op, InterpreterValue operand);

llvm::Error evalProbeOp(InterpreterValue input, StringRef probeId,
StringRef probeOutputDir,
int64_t serializedProbeFileId);
Expand All @@ -47,7 +56,4 @@ llvm::Error evalProbeOp(InterpreterValue input, StringRef probeId,
} // namespace stablehlo
} // namespace mlir

#define GET_OP_CLASSES
#include "stablehlo/reference/InterpreterOps.h.inc"

#endif // STABLEHLO_REFERENCE_INTERPRETEROPS_H
23 changes: 23 additions & 0 deletions stablehlo/reference/InterpreterOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,29 @@ def Interpreter_RunParallelOp : Op<Interpreter_Dialect, "run_parallel", []> {
let hasVerifier = 1;
}

def Interpreter_PrintOp : Op<Interpreter_Dialect, "print",
[SameOperandsAndResultType]> {
let summary = "Print operation";
let arguments = (ins
HLO_Tensor:$operand
);
let results = (outs HLO_Tensor:$result);
let description = [{
Print the value to stdout.

This is useful to print intermediate states of the tensors while debugging.
This should only be used to debug small tensors since every instance of this
op and its contents are printed to stdout. To gather information in bulk for
larger tensors, prefer using ProbeOp.

Example:
```mlir
%result = interpreter.print %operand : tensor<i1>
```
}];
let assemblyFormat = "$operand attr-dict `:` type($result)";
}

def Interpreter_ProbeOp : Op<Interpreter_Dialect, "probe",
[SameOperandsAndResultType]> {
let arguments = (ins
Expand Down

0 comments on commit d8275d2

Please sign in to comment.