diff --git a/examples/oneflow2onnx/nodes/CPU/test_fill.py b/examples/oneflow2onnx/nodes/CPU/test_fill.py
new file mode 100644
index 0000000..c4236a6
--- /dev/null
+++ b/examples/oneflow2onnx/nodes/CPU/test_fill.py
@@ -0,0 +1,50 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import tempfile
+import oneflow as flow
+from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
+
+
+class FillModule(flow.nn.Module):
+    def __init__(self) -> None:
+        super(FillModule, self).__init__()
+
+    def forward(self, x: flow.Tensor) -> flow.Tensor:
+        return x.fill_(5)
+
+
+m = FillModule()
+
+
+class fillOpGraph(flow.nn.Graph):
+    def __init__(self):
+        super().__init__()
+        self.m = m
+
+    def build(self, x):
+        out = self.m(x)
+        return out
+
+
+def test_fill():
+
+    fill_graph = fillOpGraph()
+    fill_graph._compile(flow.randn(1, 5))
+
+    convert_to_onnx_and_check(fill_graph, onnx_model_path="/tmp")
+
+
+test_fill()
diff --git a/examples/oneflow2onnx/nodes/GPU/test_fill.py b/examples/oneflow2onnx/nodes/GPU/test_fill.py
new file mode 100644
index 0000000..616968d
--- /dev/null
+++ b/examples/oneflow2onnx/nodes/GPU/test_fill.py
@@ -0,0 +1,50 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import tempfile
+import oneflow as flow
+from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
+
+
+class FillModule(flow.nn.Module):
+    def __init__(self) -> None:
+        super(FillModule, self).__init__()
+
+    def forward(self, x: flow.Tensor) -> flow.Tensor:
+        return x.fill_(5)
+
+
+m = FillModule().to("cuda")
+
+
+class fillOpGraph(flow.nn.Graph):
+    def __init__(self):
+        super().__init__()
+        self.m = m
+
+    def build(self, x):
+        out = self.m(x)
+        return out
+
+
+def test_fill():
+
+    fill_graph = fillOpGraph()
+    fill_graph._compile(flow.randn(1, 5).to("cuda"))
+
+    convert_to_onnx_and_check(fill_graph, onnx_model_path="/tmp", device="gpu")
+
+
+test_fill()
diff --git a/oneflow_onnx/oneflow2onnx/handlers/math.py b/oneflow_onnx/oneflow2onnx/handlers/math.py
index 62fd438..62b48e4 100644
--- a/oneflow_onnx/oneflow2onnx/handlers/math.py
+++ b/oneflow_onnx/oneflow2onnx/handlers/math.py
@@ -739,3 +739,32 @@ def Version_13(cls, ctx, node, **kwargs):
         var_node = ctx.MakeNode(
             "ReduceMean", [sqr_sub], op_name_scope=node.name, name="var", dtypes=dtypes, attr={"axes": origin_dim, "keepdims": keepdim_mean}, outputs=[node.output_tensor_names[0]]
         )
+
+
+@flow_op("fill_", onnx_op="Constant")
+class Fill:
+    @classmethod
+    def Version_1(cls, ctx, node, **kwargs):
+        is_floating_value = node.attrs["is_floating_value"]
+        output_name = node.output_tensor_names[0]
+        out_shape = ctx.get_shape(output_name)
+
+        if is_floating_value:
+            values = np.full(shape=out_shape, fill_value=node.attrs["floating_value"], dtype=np.float32)
+        else:
+            values = np.full(shape=out_shape, fill_value=node.attrs["integral_value"], dtype=np.float32)
+
+        ctx.RemoveNode(node.name)
+        ctx.MakeConst(output_name, values)
+
+    @classmethod
+    def Version_9(cls, ctx, node, **kwargs):
+        cls.Version_1(ctx, node, **kwargs)
+
+    @classmethod
+    def Version_11(cls, ctx, node, **kwargs):
+        cls.Version_1(ctx, node, **kwargs)
+
+    @classmethod
+    def Version_13(cls, ctx, node, **kwargs):
+        cls.Version_1(ctx, node, **kwargs)