Skip to content

Commit

Permalink
Add dtype functions for aten.ge.Tensor and aten.le.Tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
ramiro050 committed May 5, 2023
1 parent f58f9c2 commit 899d8bc
Show file tree
Hide file tree
Showing 19 changed files with 4,139 additions and 2,560 deletions.
2 changes: 1 addition & 1 deletion build_tools/python_deploy/build_linux_packages.sh
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ function test_in_tree() {
python -m e2e_testing.main --config=lazy_tensor_core -v

echo ":::: Run TorchDynamo e2e integration tests"
python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic
python -m e2e_testing.main --config=torchdynamo -v
}

function setup_venv() {
Expand Down
8 changes: 2 additions & 6 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,6 @@
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
# ERROR: assert isinstance(e, FakeTensor)
"RsubInt0d_NumToTensor_Module_basic",

# Dtype function transition failures
"MobilenetV3Module_basic",
"ResNet18Module_basic",
"ResNet18StaticModule_basic",
}

STABLEHLO_PASS_SET = {
Expand Down Expand Up @@ -706,7 +701,8 @@
"FullLikeModuleInt2DStatic_basic",
"FullModuleInt3D_basic",
"FullModuleFloat2D_basic",
"RepeatModule_basic"
"RepeatModule_basic",
"ResNet18StaticModule_basic",
}

LTC_XFAIL_SET = {
Expand Down
3 changes: 2 additions & 1 deletion lib/Conversion/TorchToLinalg/TensorConstructors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ class ConvertAtenEmptyMemoryFormatOp
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
Type resultElementType;
if (op.getDtype().getType().isa<Torch::NoneType>()) {
resultElementType = resultType.getElementType();
resultElementType = getDefaultDtypeForTorchScalar(
Torch::FloatType::get(op->getContext()));
} else {
int64_t dtypeInt;
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
Expand Down
2,794 changes: 1,981 additions & 813 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Large diffs are not rendered by default.

14 changes: 11 additions & 3 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3716,8 +3716,14 @@ class DecomposeAtenRandnGeneratorOp
LogicalResult matchAndRewrite(AtenRandnGeneratorOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type resultType = op.getType();
auto resultType = op.getType().cast<BaseTensorType>();

if (!resultType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "expected result type to have a dtype");
}

Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype());
Value none = rewriter.create<ConstantNoneOp>(loc);
Value low = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr((double)0.0));
Expand All @@ -3729,11 +3735,13 @@ class DecomposeAtenRandnGeneratorOp
loc, rewriter.getF64FloatAttr((double)(2.0 * 3.14159)));

Value emptyTensorA = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, resultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(),
loc, resultType, op.getSize(), /*dtype=*/dtype,
/*layout=*/op.getLayout(),
/*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
/*memory_format=*/none);
Value emptyTensorB = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, resultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(),
loc, resultType, op.getSize(), /*dtype=*/dtype,
/*layout=*/op.getLayout(),
/*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
/*memory_format=*/none);

Expand Down
Loading

0 comments on commit 899d8bc

Please sign in to comment.