build: update llvm tag to 2dde4ba (#1229)
Summary of changes:
 - Tensor dialect now sets `emitAccessorPrefix` to prefixed, thus
   requiring updates to methods that retrieve arguments (see the
   sketch after this list) [https://reviews.llvm.org/D131361]
 - Update MHLO to build with LLVM commit hash 2dde4ba
 - Replace `AbsOp` with `AbsFOp` [https://reviews.llvm.org/D131325]
 - Replace deprecated `getValue()` with `value()`
   [https://reviews.llvm.org/D131349]
 - Remove `AnalysisState::defaultInitialize()`
   [https://reviews.llvm.org/D131746]
 - Update MHLO MLIR tests to use the updated assembly format
 - Disable two failing TOSA tests (GitHub issue #1231)
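For reference, the accessor rename called out in the first bullet is mechanical. A minimal sketch of the new spelling, assuming the MLIR headers from the updated LLVM commit (the `castSource` helper is hypothetical, not part of this change):

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;

// With `emitAccessorPrefix` set to prefixed (D131361), the ODS-generated
// accessor for the `source` operand of tensor.cast is spelled getSource()
// rather than source(); result accessors gain the prefix the same way
// (result() becomes getResult()).
static Value castSource(tensor::CastOp castOp) {
  return castOp.getSource(); // before this change: castOp.source()
}
```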
ashay authored Aug 16, 2022
1 parent 3b3cb99 commit 84d345c
Showing 18 changed files with 75 additions and 85 deletions.
2 changes: 0 additions & 2 deletions e2e_testing/torchscript/xfail_sets.py
@@ -28,7 +28,6 @@
"ElementwiseBinaryModule_basic",
"ElementwiseSigmoidModule_basic",
"ElementwiseExpModule_basic",
"ElementwiseReluModule_basic",
"ElementwiseFloorModule_basic",
"ElementwiseLogModule_basic",
"ElementwiseBinaryStaticShapeModule_basic",
@@ -103,7 +102,6 @@
"ElementwiseFlattenBroadcastModule_basic",
"SquareModule_basic",
"MaxPool2dStaticModule_basic",
"ResNet18StaticModule_basic",
"NativeLayerNormModule4D_basic",
"LayerNormNormalizeOverAllDimsModule_basic",
"PermuteModule_basic",
@@ -516,7 +516,7 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<TMTensorOp> {
for (OpOperand *opOperand : op.getInputOperands()) {
auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
- ? tensorCastOp.source()
+ ? tensorCastOp.getSource()
: opOperand->get());
}
// Init tensors may fold, in which case the resultType must also change.
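For context, a minimal sketch of the folding idiom this hunk updates, with the prefixed accessor applied (the `foldedOperand` helper is hypothetical):

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;

// If an operand is produced by a tensor.cast that the consumer op can
// absorb, substitute the cast's source for the cast result.
static Value foldedOperand(Value operand) {
  auto castOp = operand.getDefiningOp<tensor::CastOp>();
  if (castOp && canFoldIntoConsumerOp(castOp))
    return castOp.getSource(); // prefixed accessor after D131361
  return operand;
}
```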
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 3202 files
2 changes: 1 addition & 1 deletion externals/mlir-hlo
Submodule mlir-hlo updated 96 files
+124 −74 BUILD
+0 −1 CMakeLists.txt
+1 −1 README.md
+2 −2 WORKSPACE
+1 −1 build_tools/llvm_version.txt
+1 −0 include/mlir-hlo/Dialect/CMakeLists.txt
+5 −5 include/mlir-hlo/Dialect/gml_st/IR/gml_st_legacy_ops.td
+14 −3 include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.td
+2 −1 include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops_base.td
+18 −6 include/mlir-hlo/Dialect/gml_st/IR/gml_st_set_ops.td
+0 −3 include/mlir-hlo/Dialect/gml_st/transforms/passes.h
+0 −5 include/mlir-hlo/Dialect/gml_st/transforms/passes.td
+3 −13 include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h
+2 −1 include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
+0 −5 include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h
+178 −10 include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+98 −4 include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.h
+26 −1 include/mlir-hlo/Dialect/mhlo/transforms/legalize_to_linalg_utils.h
+5 −5 include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h
+5 −0 include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
+5 −2 include/mlir-hlo/Dialect/mhlo/transforms/passes.h
+15 −0 include/mlir-hlo/Dialect/thlo/CMakeLists.txt
+22 −0 include/mlir-hlo/Dialect/thlo/IR/CMakeLists.txt
+34 −0 include/mlir-hlo/Dialect/thlo/IR/thlo_ops.h
+50 −10 include/mlir-hlo/Dialect/thlo/IR/thlo_ops.td
+4 −4 lib/Analysis/shape_component_analysis.cc
+13 −11 lib/CAPI/Attributes.cc
+1 −0 lib/Dialect/CMakeLists.txt
+226 −610 lib/Dialect/gml_st/IR/gml_st_ops.cc
+0 −1 lib/Dialect/gml_st/transforms/CMakeLists.txt
+6 −8 lib/Dialect/gml_st/transforms/bufferize_tiled_loop.cc
+9 −1 lib/Dialect/gml_st/transforms/fusion.cc
+2 −2 lib/Dialect/lhlo/IR/lhlo_ops.cc
+2 −2 lib/Dialect/lhlo/transforms/lhlo_legalize_to_parallel_loops.cc
+2 −0 lib/Dialect/mhlo/IR/CMakeLists.txt
+7 −7 lib/Dialect/mhlo/IR/chlo_ops.cc
+63 −18 lib/Dialect/mhlo/IR/hlo_ops.cc
+5 −2 lib/Dialect/mhlo/transforms/CMakeLists.txt
+10 −27 lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
+1 −1 lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
+4 −4 lib/Dialect/mhlo/transforms/hlo_legalize_to_memref.cc
+18 −16 lib/Dialect/mhlo/transforms/legalize_mhlo_to_thlo.cc
+65 −74 lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+2 −2 lib/Dialect/mhlo/transforms/legalize_to_standard.cc
+1 −1 lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc
+15 −0 lib/Dialect/thlo/CMakeLists.txt
+33 −0 lib/Dialect/thlo/IR/CMakeLists.txt
+620 −0 lib/Dialect/thlo/IR/thlo_ops.cc
+1 −0 lib/Transforms/CMakeLists.txt
+3 −3 lib/Transforms/buffer_packing.cc
+1 −1 lib/Transforms/copy_removal.cc
+2 −2 lib/Transforms/gml_st_pipeline.cc
+39 −26 tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir
+117 −0 tests/Dialect/gml_st/canonicalize.mlir
+37 −11 tests/Dialect/gml_st/compose_set_ops.mlir
+29 −11 tests/Dialect/gml_st/fusion.mlir
+9 −9 tests/Dialect/gml_st/invalid.mlir
+0 −35 tests/Dialect/gml_st/ops.mlir
+94 −101 tests/Dialect/gml_st/tiling_and_fusion.mlir
+12 −12 tests/Dialect/lhlo/lhlo-fuse-linalg.mlir
+7 −7 tests/Dialect/lhlo/ops.mlir
+6 −6 tests/Dialect/mhlo/canonicalize/canonicalize.mlir
+15 −15 tests/Dialect/mhlo/canonicalize/convolution.mlir
+4 −4 tests/Dialect/mhlo/canonicalize/reshape.mlir
+2 −2 tests/Dialect/mhlo/canonicalize/transpose.mlir
+2 −2 tests/Dialect/mhlo/expand_hlo_tuples.mlir
+4 −4 tests/Dialect/mhlo/group_reduction_dimensions.mlir
+1 −1 tests/Dialect/mhlo/hlo-legalize-gather-to-torch-index-select.mlir
+2 −2 tests/Dialect/mhlo/hlo-legalize-to-lhlo.mlir
+297 −651 tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
+1 −1 tests/Dialect/mhlo/hlo-legalize-to-memref.mlir
+1 −1 tests/Dialect/mhlo/legalize-control-flow.mlir
+9 −9 tests/Dialect/mhlo/legalize-mhlo-to-thlo.mlir
+15 −15 tests/Dialect/mhlo/lower-general-dot.mlir
+1 −1 tests/Dialect/mhlo/materialize-broadcasts.mlir
+6 −6 tests/Dialect/mhlo/mhlo_canonicalize_reduction.mlir
+53 −0 tests/Dialect/mhlo/mhlo_infer_shape_type_methods.mlir
+1 −1 tests/Dialect/mhlo/mhlo_reduce_pretty_print.mlir
+44 −0 tests/Dialect/mhlo/ops.mlir
+3 −3 tests/Dialect/mhlo/optimize-hlo.mlir
+7 −7 tests/Dialect/mhlo/restrict_max_rank.mlir
+2 −2 tests/Dialect/mhlo/sparse_gendot_lower.mlir
+2 −2 tests/Dialect/mhlo/sparse_lower.mlir
+2 −2 tests/Dialect/mhlo/sparse_transpose.mlir
+20 −20 tests/Dialect/mhlo/symbolic-shape-optimization.mlir
+1 −1 tests/Dialect/mhlo/unfuse_batch_norm.mlir
+49 −0 tests/Dialect/thlo/invalid.mlir
+63 −0 tests/Dialect/thlo/ops.mlir
+48 −0 tests/end2end/gml_st_broadcast.mlir
+40 −0 tests/end2end/gml_st_concat.mlir
+2 −2 tests/end2end/gml_st_transpose.mlir
+86 −180 tests/gml_st_pipeline.mlir
+1 −1 tests/legalize-trigonometric-to-approximation.mlir
+52 −52 tests/rank-specialization.mlir
+3 −3 tests/shape-component-analysis.mlir
+2 −1 tools/mlir-hlo-opt/mlir-hlo-opt.cc
4 changes: 2 additions & 2 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -576,7 +576,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
rewriter
.create<tensor::CollapseShapeOp>(loc, intermediateResultType,
castedInput, inputAssociations)
- .result();
+ .getResult();
}

if (llvm::any_of(outputAssociations, [](ReassociationIndices indices) {
@@ -588,7 +588,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
expandedInput.has_value() ? expandedInput.value()
: castedInput,
outputAssociations)
- .result();
+ .getResult();
}

Value result = collapsedInput.has_value() ? collapsedInput.value()
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToLinalg/Reduction.cpp
@@ -239,7 +239,7 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
Value elem = payloadArgs[0];
Value result = payloadArgs[1];
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
- auto abs = b.create<math::AbsOp>(loc, self);
+ auto abs = b.create<math::AbsFOp>(loc, self);
AtenLinalgVectorNormOp::Adaptor adaptor(operands);
Value ord = convertScalarToDtype(b, loc, adaptor.ord(), resultElementType);
auto pow = b.create<math::PowFOp>(loc, abs, ord);
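D131325 renames the floating-point absolute-value op, so `math::AbsOp` no longer exists. A minimal sketch of the replacement (the `emitAbs` helper is hypothetical):

```cpp
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// math.absf computes |x| for floating-point operands; integers are served
// by a separate op (math.absi) rather than one overloaded abs.
static Value emitAbs(OpBuilder &b, Location loc, Value floatOperand) {
  return b.create<math::AbsFOp>(loc, floatOperand);
}
```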
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -210,7 +210,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<arith::OrIOp>(loc, lhsTest, rhsTest);
}
if (isa<AtenAbsOp>(op))
- return b.create<math::AbsOp>(loc, payloadArgs[0]);
+ return b.create<math::AbsFOp>(loc, payloadArgs[0]);
if (isa<AtenSigmoidOp>(op)) {
auto negate = createCalculationForMathOpWithDtypeConversion<arith::NegFOp>(
b, converter, payloadArgs[0], op);
3 changes: 0 additions & 3 deletions lib/Conversion/TorchToMhlo/Basic.cpp
@@ -1063,9 +1063,6 @@ LogicalResult ConvertAtenOp<ValsemVariantAtenUniformOp>::matchAndRewrite(
op.getLoc(),
rewriter.getFloatAttr(inputTy.getElementType(), toDoubleValue));

- auto outType = getTypeConverter()
-     ->convertType(op.getType())
-     .template dyn_cast<TensorType>();
rewriter.replaceOpWithNewOp<mhlo::RngOp>(
op, inputTy, fromTensor, toTensor, mhloShape, mhlo::RngDistribution::UNIFORM);
return success();
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToMhlo/Linear.cpp
@@ -531,7 +531,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
std::copy(outputPadding.begin(), outputPadding.end(),
edgePaddingHighVec.begin() + 2);
Value paddingValue =
- mhlo::getConstTensor<float>(rewriter, op, {0.0}, {}).getValue();
+ mhlo::getConstTensor<float>(rewriter, op, {0.0}, {}).value();
paddingValue = mhlo::promoteType(rewriter, paddingValue, inputTy);
mlir::DenseIntElementsAttr edgePaddingLow =
rewriter.getI64VectorAttr(edgePaddingLowVec);
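D131349 deprecates `llvm::Optional<T>::getValue()` in favor of `value()`, matching the `std::optional` interface. A minimal, self-contained sketch of the new spelling (the `unwrapOr` helper is hypothetical):

```cpp
#include "llvm/ADT/Optional.h"

// value() asserts that the Optional is engaged, mirroring
// std::optional::value(); getValue() is the deprecated spelling.
static int unwrapOr(llvm::Optional<int> maybe, int fallback) {
  return maybe ? maybe.value() : fallback; // previously: maybe.getValue()
}
```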
6 changes: 2 additions & 4 deletions lib/Conversion/TorchToMhlo/Reduction.cpp
@@ -87,11 +87,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
if (!initValue) return llvm::None;
Value initIndex;
if (mlir::mhlo::kMhloDimSizeBits == 32) {
- initIndex =
-     mhlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).getValue();
+ initIndex = mhlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
} else {
- initIndex =
-     mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).getValue();
+ initIndex = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
}

DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
9 changes: 3 additions & 6 deletions lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp
@@ -94,19 +94,16 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) {
/// unsafe
class InlineGlobalSlotsAnalysisState : public AnalysisState {
public:
- InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) {}
+ InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) {
+   setSafe();
+ }

bool isUninitialized() const override {
// We are an optimistic analysis, so we are always default initialized to
// the optimistic "assumed safe" state.
return false;
}

- ChangeResult defaultInitialize() override {
-   // We are an optimistic analysis, so the default state is always "safe".
-   return setSafe();
- }

void print(raw_ostream &os) const override {
os << "InlineGlobalSlotsAnalysisState(" << (isSafe ? "safe" : "unsafe")
<< ")";
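With `AnalysisState::defaultInitialize()` removed (D131746), a state that wants an optimistic default now establishes it eagerly in its constructor, as the hunk above does. A minimal sketch of the pattern (the `OptimisticState` class is hypothetical):

```cpp
#include "mlir/Analysis/DataFlowFramework.h"

// An optimistic state: the default value is set up front in the
// constructor instead of via the removed defaultInitialize() hook.
class OptimisticState : public mlir::AnalysisState {
public:
  OptimisticState(mlir::ProgramPoint point) : AnalysisState(point) {
    isSafe = true; // optimistic "assumed safe" default
  }

  // Because the default is already meaningful, the state never reports
  // itself as uninitialized.
  bool isUninitialized() const override { return false; }

  void print(llvm::raw_ostream &os) const override {
    os << (isSafe ? "safe" : "unsafe");
  }

private:
  bool isSafe = false;
};
```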
12 changes: 6 additions & 6 deletions test/Conversion/TorchToMhlo/basic.mlir
@@ -7,7 +7,7 @@
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = "mhlo.copy"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = mhlo.copy %[[VAL_1]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -47,7 +47,7 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
// CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]]
// CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64>
// CHECK: %[[T2:.*]] = mhlo.convert %[[T1]] : tensor<1xi64>
// CHECK: %[[T3:.*]] = "mhlo.reshape"(%[[T2]]) : (tensor<1xi64>) -> tensor<i64>
// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor<i64>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<i64> -> !torch.vtensor<[],si64>
// CHECK: return %[[T4]] : !torch.vtensor<[],si64>
func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> {
@@ -229,16 +229,16 @@ func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],
// CHECK: %true = torch.constant.bool true
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<[1, 21, 20]> : tensor<3xi64>
// CHECK: %[[VAL_6:.*]] = "mhlo.dynamic_reshape"(%[[VAL_1]], %[[VAL_5]]) : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
// CHECK: %[[VAL_6:.*]] = mhlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<21xf32>
// CHECK: %[[VAL_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor<21xf32>
// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "mhlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
// CHECK: %[[VAL_12:.*]] = mhlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64>
// CHECK: %[[VAL_13:.*]] = "mhlo.dynamic_reshape"(%[[VAL_9]], %[[VAL_12]]) : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_13:.*]] = mhlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
// CHECK: %[[VAL_15:.*]] = "mhlo.dynamic_reshape"(%[[VAL_10]], %[[VAL_14]]) : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
// CHECK: %[[VAL_15:.*]] = mhlo.dynamic_reshape %[[VAL_10]], %[[VAL_14]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
// CHECK: %[[VAL_16:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
// CHECK: %[[VAL_17:.*]] = "mhlo.dynamic_reshape"(%[[VAL_11]], %[[VAL_16]]) : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
// CHECK: %[[VAL_17:.*]] = mhlo.dynamic_reshape %[[VAL_11]], %[[VAL_16]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32>
// CHECK: %[[VAL_18:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_19:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32>
// CHECK: %[[VAL_20:.*]] = mhlo.multiply %[[VAL_13]], %[[VAL_18]] : tensor<3x7x4x5xf32>
6 changes: 3 additions & 3 deletions test/Conversion/TorchToMhlo/dropout.mlir
@@ -10,7 +10,7 @@
// CHECK: %[[CST_2:.*]] = arith.constant 1.000000e+00 : f64
// CHECK: %[[CST_3:.*]] = arith.subf %[[CST_2]], %[[ARG1]] : f64
// CHECK: %[[T3:.*]] = tensor.from_elements %[[CST_3]] : tensor<1xf64>
// CHECK: %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf64>) -> tensor<f64>
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf64>) -> tensor<f64>
// CHECK: %[[T5:.*]] = mhlo.convert(%[[ARG0]]) : (tensor<?x?xf32>) -> tensor<?x?xf64>
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T5]], %[[CST_1]] : tensor<?x?xf64>
// CHECK: %[[CST_I64_0:.*]] = arith.index_cast %[[DIM_0]] : index to i64
@@ -33,7 +33,7 @@
// CHECK: shape.assuming_yield %[[T19]] : tensor<?x?xf32>
// CHECK: }
// CHECK: %[[T20:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xf64>) -> tensor<1xf32>
// CHECK: %[[T21:.*]] = "mhlo.reshape"(%[[T20]]) : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T21:.*]] = mhlo.reshape %[[T20]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[T22:.*]] = shape.shape_of %[[T15]] : tensor<?x?xf32> -> tensor<2xindex>
// CHECK: %[[T23:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T21]], %[[T22]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[T24:.*]] = mhlo.multiply %[[T15]], %[[T23]] : tensor<?x?xf32>
@@ -44,4 +44,4 @@ func.func @torch.aten.native_dropout.train(%arg0: !torch.vtensor<[?,?],f32>, %ar
%bool_true = torch.constant.bool true
%result0, %result1 = torch.aten.native_dropout %arg0, %arg1, %bool_true: !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>
return %result0, %result1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>
- }
\ No newline at end of file
+ }