Cast number to float when shape function takes Scalar arg (#1978)
To keep things simple in shape functions, `Scalar` inputs are
considered `float`s. This means that when inserting the shape
functions into the IR, we must cast any `!torch.number`s into `float`s
so that the operand type matches the expected type in the shape
function. This commit adds the cast from `Scalar` to `float`.
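The effect on the IR can be sketched roughly as follows (an illustrative before/after modeled on the test added in this commit; the SSA value names and the `!torch.list<int>` shape-function result type are assumptions for the sketch):

```mlir
// Before: the op's Scalar operand is a !torch.number, but the reified
// shape function expects a !torch.float.
%t = torch.aten.arange %scalar, %none, %none, %none, %none
    : !torch.number, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor

// After: adjustFunctionArg inserts a torch.aten.Float.Scalar cast so the
// operand type matches the shape function's expected argument type.
%float = torch.aten.Float.Scalar %scalar : !torch.number -> !torch.float
%shape = func.call @__torch_mlir_shape_fn.aten.arange(%float, ...)
    : (!torch.float, ...) -> !torch.list<int>
```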
ramiro050 authored Mar 28, 2023
1 parent 72bb902 commit d803ab4
Showing 2 changed files with 21 additions and 0 deletions.
lib/Conversion/TorchToLinalg (first changed file)
@@ -194,6 +194,16 @@ FailureOr<Value> Torch::adjustFunctionArg(
     return b.create<DerefineOp>(loc, desiredType, operand).getResult();
   }
 
+  // To keep things simple in shape functions, `Scalar` inputs are considered
+  // `float`s. This is safe since output shape of torch ops never depends on the
+  // dtype of input scalars. However, this also means we sometimes have to
+  // manually turn `Scalar`s into `float`s when inserting the shape functions
+  // into the IR.
+  if (operandType.isa<Torch::NumberType>() &&
+      desiredType.isa<Torch::FloatType>()) {
+    return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult();
+  }
+
   // If the operand type is statically !torch.optional, then we need to do
   // different things for the None and non-None cases.
   // For the None case, we just need to derefine it to the desired type.
11 changes: 11 additions & 0 deletions test/Dialect/Torch/reify-shape-calculations.mlir
@@ -231,3 +231,14 @@ func.func @adjust_shape_function_arg$list(%arg0: !torch.vtensor, %arg1: !torch.v
   %1 = torch.aten.index.Tensor %arg0, %0 : !torch.vtensor, !torch.list<vtensor> -> !torch.vtensor
   return %1 : !torch.vtensor
 }
+
+// -----
+
+// CHECK-LABEL: func.func @adjust_shape_function_arg$number(
+// CHECK:         %[[FLOAT:.*]] = torch.aten.Float.Scalar {{.*}} : !torch.number -> !torch.float
+// CHECK:         %[[VAL_9:.*]] = func.call @__torch_mlir_shape_fn.aten.arange(%[[FLOAT]], {{.*}}) : (!torch.float, {{.*}}
+func.func @adjust_shape_function_arg$number(%arg0: !torch.number) -> !torch.vtensor {
+  %none = torch.constant.none
+  %1 = torch.aten.arange %arg0, %none, %none, %none, %none : !torch.number, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
+  return %1 : !torch.vtensor
+}
