Skip to content

Commit

Permalink
Add aten.scatter_reduce op definition (#1846)
Browse files Browse the repository at this point in the history
  • Loading branch information
cetiniz authored Feb 7, 2023
1 parent 3ebe5a5 commit 2a4a61f
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
55 changes: 55 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7638,6 +7638,61 @@ def Torch_AtenScatterAdd_Op : Torch_Op<"aten.scatter_add_", [
}];
}

def Torch_AtenScatterReduceTwoOp : Torch_Op<"aten.scatter_reduce.two", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchTensorType:$src,
Torch_StringType:$reduce,
Torch_BoolType:$include_self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScatterReduceTwoOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenScatterReduceTwoOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}

def Torch_AtenScatterReduce_TwoOp : Torch_Op<"aten.scatter_reduce_.two", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::scatter_reduce_.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchTensorType:$src,
Torch_StringType:$reduce,
Torch_BoolType:$include_self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScatterReduce_TwoOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenScatterReduce_TwoOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}

def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::cpu : (Tensor) -> (Tensor)")
emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
emit_with_mutating_variants("aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)")
emit_with_mutating_variants("aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)")
emit("aten::IntImplicit : (Tensor) -> (int)")
emit("aten::FloatImplicit : (Tensor) -> (float)")
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
Expand Down

0 comments on commit 2a4a61f

Please sign in to comment.