Rewrite dtype functions for softmax ops in Python
gpetters94 committed Mar 29, 2023
1 parent d3a49fd commit fd8d8af
Showing 3 changed files with 75 additions and 60 deletions.
46 changes: 46 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9821,6 +9821,52 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %6 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
" return %6 : !torch.tuple<int, int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._softmax\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.softmax.int\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" %2 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" torch.prim.If.yield %2 : !torch.int\n"
" } else {\n"
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %2 : !torch.int\n"
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.log_softmax.int\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" %2 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" torch.prim.If.yield %2 : !torch.int\n"
" } else {\n"
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %2 : !torch.int\n"
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
"}\n"
"";
// clang-format on
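The serialized MLIR above is the generated abstract interpretation library; it is produced from the Python dtype functions added in the third file of this commit. As a reading aid, the torch.prim.If in aten.softmax.int compiles down from a Python conditional roughly like the sketch below (an illustrative paraphrase, not the generator's exact input or output):

def softmax_int_dtype_rule(self_rank_dtype, dim, dtype=None):
    # Keep the input tensor's dtype unless an explicit dtype override is given.
    # self_rank_dtype is a (rank, dtype) pair, matching the !torch.tuple<int, int> argument.
    _rank, self_dtype = self_rank_dtype
    return self_dtype if dtype is None else dtype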
71 changes: 11 additions & 60 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -461,11 +461,6 @@ class TypeAnalysis : public dataflow::SparseDataFlowAnalysis<
template <typename OpTy>
void visitAtenCatLikeOp(OpTy op, ArrayRef<const ValueState *> operands);

template <typename OpTy>
void visitAtenSoftmaxLikeOp(OpTy op, ArrayRef<const ValueState *> operands);
template <typename OpTy>
void visitAten_SoftmaxLikeOp(OpTy op, ArrayRef<const ValueState *> operands);

void visitNumToTensorOp(PrimNumToTensorScalarOp op);
void visitBinaryScalarOp(Operation *op,
ArrayRef<const ValueState *> operands);
@@ -619,19 +614,17 @@ void TypeAnalysis::visitOperation(Operation *op,
AtenBitwiseNotOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopyOp, AtenCumsumOp,
AtenLayerNormOp, AtenClampOp, AtenClampMinOp, AtenClampMaxOp,
AtenNegOp, AtenFloorOp, Aten_SoftmaxBackwardDataOp, AtenDropoutOp,
AtenTanhBackwardOp, AtenHardtanhBackwardOp,
Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp, AtenAbsOp,
AtenThresholdOp, AtenSquareOp, AtenUniformOp, AtenBernoulliOp,
AtenBernoulli_FloatOp, AtenBernoulliTensorOp,
ValsemVariantAtenBernoulliFloatOp, AtenBernoulliTensorOp,
AtenBernoulliPOp, AtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp,
AtenMaxPool2dOp, AtenAvgPool2dOp, AtenAdaptiveAvgPool2dOp,
AtenFlattenUsingIntsOp, AtenSqueezeOp, AtenSqueezeDimOp,
AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp, AtenReshapeOp,
Aten_ReshapeAliasOp, AtenResize_Op, AtenTransposeIntOp, AtenTOp,
AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp,
AtenNegOp, AtenFloorOp, AtenDropoutOp, AtenTanhBackwardOp,
AtenHardtanhBackwardOp, AtenAddIntOp, AtenAbsOp, AtenThresholdOp,
AtenSquareOp, AtenUniformOp, AtenBernoulliOp, AtenBernoulli_FloatOp,
AtenBernoulliTensorOp, ValsemVariantAtenBernoulliFloatOp,
AtenBernoulliTensorOp, AtenBernoulliPOp, AtenFillScalarOp,
AtenHardsigmoidOp, AtenCloneOp, AtenHardswishOp, AtenSiluOp,
AtenHardtanhOp, AtenMaskedSelectOp, AtenMaxPool2dOp, AtenAvgPool2dOp,
AtenAdaptiveAvgPool2dOp, AtenFlattenUsingIntsOp, AtenSqueezeOp,
AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp,
AtenReshapeOp, Aten_ReshapeAliasOp, AtenResize_Op, AtenTransposeIntOp,
AtenTOp, AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp,
AtenSelectScatterOp, AtenNarrowOp, AtenSliceTensorOp,
AtenScatterReduceTwoOp, AtenSliceScatterOp, AtenGatherOp,
AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp,
@@ -905,21 +898,6 @@ void TypeAnalysis::visitOperation(Operation *op,
return;
}

if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
return;
}
if (auto _softmaxOp = dyn_cast<Aten_SoftmaxOp>(op)) {
visitAten_SoftmaxLikeOp(_softmaxOp, operands);
return;
} else if (auto _logSoftmaxOp = dyn_cast<Aten_LogSoftmaxOp>(op)) {
visitAten_SoftmaxLikeOp(_logSoftmaxOp, operands);
return;
} else if (auto logSoftmaxIntOp = dyn_cast<AtenLogSoftmaxIntOp>(op)) {
visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands);
return;
}

if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) {
visitNumToTensorOp(numToTensorOp);
return;
@@ -1262,33 +1240,6 @@ void TypeAnalysis::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
incorporateKnowledge(op.getResult(), knowledge);
}

// Common template for softmax-like ops, e.g., log_softmax.
template <typename OpTy>
void TypeAnalysis::visitAtenSoftmaxLikeOp(
OpTy op, ArrayRef<const ValueState *> operands) {
auto input = operands[0]->getValue();
auto dtype = op.getDtype();
ValueKnowledge knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, input.dtype);
incorporateKnowledge(op.getResult(), knowledge);
}

// Common template for softmax-like ops, e.g., log_softmax (underscore variant).
template <typename OpTy>
void TypeAnalysis::visitAten_SoftmaxLikeOp(
OpTy op, ArrayRef<const ValueState *> operands) {
auto input = operands[0]->getValue();
ValueKnowledge knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
bool halfToFloat;
if (matchPattern(op.getHalfToFloat(), m_TorchConstantBool(&halfToFloat))) {
knowledge.dtype =
halfToFloat ? Float32Type::get(op->getContext()) : input.dtype;
}
incorporateKnowledge(op.getResult(), knowledge);
}
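For reference, the dtype rule that these removed templates computed reads roughly as follows in Python (an illustrative sketch, not code from this commit; it assumes half_to_float is statically known and uses torch.float32 as a stand-in for Float32Type):

import torch

def underscore_softmax_dtype_rule(input_dtype, half_to_float):
    # Mirrors the removed visitAten_SoftmaxLikeOp: a constant-true half_to_float
    # promotes the result to float32, otherwise the input dtype is propagated.
    return torch.float32 if half_to_float else input_dtype

def softmax_like_dtype_rule(input_dtype, dtype=None):
    # Mirrors the removed visitAtenSoftmaxLikeOp: an explicit dtype argument
    # wins, otherwise the input dtype is propagated.
    return input_dtype if dtype is None else dtype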

void TypeAnalysis::visitAtenScalarImplicitOp(
AtenScalarImplicitOp op, ArrayRef<const ValueState *> operands) {
auto knowledge =
18 changes: 18 additions & 0 deletions
@@ -2232,6 +2232,24 @@ def aten〇native_batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_r
assert is_float_dtype(input_dtype)
return input_dtype, input_dtype, input_dtype

def aten〇_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, half_to_float: bool) -> int:
return self_rank_dtype[1]

def aten〇softmax〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int:
return self_rank_dtype[1] if dtype is None else int(dtype)

def aten〇_log_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, half_to_float: bool) -> int:
return self_rank_dtype[1]

def aten〇log_softmax〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int:
return self_rank_dtype[1] if dtype is None else int(dtype)

def aten〇_log_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int], dim: int, input_dtype: int) -> int:
return grad_output_rank_dtype[1]

def aten〇_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int], dim: int, input_dtype: int) -> int:
return grad_output_rank_dtype[1]

# ==============================================================================
# Main
# ==============================================================================
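A quick sanity check of the dtype rules added above (not part of the commit; the integer codes below are assumed to follow PyTorch's ScalarType numbering, where 5 is float16 and 6 is float32):

FLOAT16, FLOAT32 = 5, 6  # assumed ScalarType codes for half and float

# The underscore variants propagate the input dtype.
assert aten〇_softmax〡dtype((2, FLOAT16), dim=-1, half_to_float=False) == FLOAT16
assert aten〇_log_softmax〡dtype((2, FLOAT16), dim=-1, half_to_float=False) == FLOAT16

# The .int variants honor an explicit dtype override and otherwise fall back to the input dtype.
assert aten〇softmax〇int〡dtype((2, FLOAT16), dim=-1) == FLOAT16
assert aten〇softmax〇int〡dtype((2, FLOAT16), dim=-1, dtype=FLOAT32) == FLOAT32

# The backward-data variants follow the grad_output dtype.
assert aten〇_softmax_backward_data〡dtype((2, FLOAT32), (2, FLOAT32), dim=-1, input_dtype=FLOAT16) == FLOAT32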
