diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index eb3a97ce053c3..4d7d9d8a2d07b 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -30,6 +30,7 @@ limitations under the License. */ #include "paddle/phi/infermeta/spmd_rules/softmax.h" #include "paddle/phi/infermeta/spmd_rules/split.h" #include "paddle/phi/infermeta/spmd_rules/transpose.h" +#include "paddle/phi/infermeta/spmd_rules/unsqueeze.h" /** * Design Notes: @@ -71,7 +72,7 @@ PD_REGISTER_SPMD_RULE( // default data parallel rule PD_REGISTER_SPMD_RULE( - unsqueeze, + default_data_parallel, PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmd), PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmdReverse)); PD_REGISTER_SPMD_RULE( @@ -85,6 +86,12 @@ PD_REGISTER_SPMD_RULE( PD_INFER_SPMD(phi::distributed::ReplicatedInferSpmd), PD_INFER_SPMD(phi::distributed::ReplicatedInferSpmdReverse)); +// unsqueeze rule +PD_REGISTER_SPMD_RULE( + unsqueeze, + PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmd), + PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmdReverse)); + // elementwise unary rule PD_REGISTER_SPMD_RULE( assign,
diff --git a/paddle/phi/infermeta/spmd_rules/unsqueeze.cc b/paddle/phi/infermeta/spmd_rules/unsqueeze.cc
new file mode 100644
index 0000000000000..6af4210f92d80
--- /dev/null
+++ b/paddle/phi/infermeta/spmd_rules/unsqueeze.cc
@@ -0,0 +1,280 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/infermeta/spmd_rules/unsqueeze.h"
+#include <algorithm>
+#include <vector>
+
+#include "glog/logging.h"
+
+#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
+#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
+#include "paddle/phi/core/distributed/auto_parallel/utils.h"
+#include "paddle/phi/infermeta/spmd_rules/dim_trans.h"
+#include "paddle/phi/infermeta/spmd_rules/utils.h"
+
+namespace phi {
+namespace distributed {
+
+using phi::distributed::auto_parallel::str_join;
+
+// Resolves negative unsqueeze axes by adding `x_ndim + 1`; non-negative
+// entries are returned unchanged.
+// NOTE(review): `x_ndim + 1` equals the output rank only when a single axis
+// is inserted -- confirm the intended convention for several negative axes.
+static std::vector<int64_t> NormalizeUnsqueezeAxis(
+    const std::vector<int64_t>& axis, int x_ndim) {
+  std::vector<int64_t> normalized(axis);
+  for (int64_t& a : normalized) {
+    if (a < 0) {
+      a += x_ndim + 1;
+    }
+  }
+  return normalized;
+}
+
+// Counts the positions in [0, ndim) that do not appear in `axis`.
+static int64_t CountKeptAxes(const std::vector<int64_t>& axis, int64_t ndim) {
+  int64_t kept = 0;
+  for (int64_t i = 0; i < ndim; i++) {
+    if (std::find(axis.begin(), axis.end(), i) == axis.end()) {
+      kept++;
+    }
+  }
+  return kept;
+}
+
+// Builds one DimTrans per axis of the unsqueezed output and writes the
+// output shape to *out_shape.
+//   x_shape:   shape of the input tensor.
+//   out_shape: resized to x_shape.size() + axis.size() and filled in here.
+//   axis:      output positions of the inserted axes (non-negative).
+// Output axes listed in `axis`, and axes copied from an input axis of size 1,
+// share a single Singleton; any other output axis i gets InputDim(j), where
+// j is the input axis it is copied from.
+// Raises InvalidArgument when the entries of `axis` are not unique values in
+// [0, x_shape.size() + axis.size()); the loop would otherwise read past the
+// end of x_shape.
+std::vector<DimTrans*> MakeUnsqueezeDimTrans(
+    const std::vector<int64_t>& x_shape,
+    std::vector<int64_t>* out_shape,
+    const std::vector<int64_t>& axis) {
+  const int64_t x_ndim = static_cast<int64_t>(x_shape.size());
+  const int64_t n = x_ndim + static_cast<int64_t>(axis.size());
+  // Validate before anything is allocated.
+  PADDLE_ENFORCE_EQ(
+      CountKeptAxes(axis, n),
+      x_ndim,
+      phi::errors::InvalidArgument(
+          "The unsqueeze axis [%s] is invalid for an input of rank [%d]: "
+          "its entries must be unique and lie in [0, %d).",
+          str_join(axis),
+          x_ndim,
+          n));
+
+  std::vector<DimTrans*> ret;
+  ret.resize(n);
+  out_shape->resize(n);
+  std::fill(ret.begin(), ret.end(), new Singleton());
+  std::fill(out_shape->begin(), out_shape->end(), 1);
+
+  for (int64_t i = 0, j = 0; i < n; i++) {
+    auto it = std::find(axis.begin(), axis.end(), i);
+
+    if (it == axis.end()) {
+      if (x_shape[j] != 1) {
+        ret[i] = new InputDim(j);
+        (*out_shape)[i] = x_shape[j];
+      }
+
+      j++;
+    }
+  }
+
+  return ret;
+}
+
+// Builds one DimTrans per axis of the input x, mapping it back to the
+// output axis it was copied to.
+//   out_shape: shape of the unsqueezed output.
+//   axis:      output positions of the inserted axes (non-negative).
+//   x_ndim:    rank of the input tensor.
+//   out_ndim:  rank of the output tensor.
+// Output axes not listed in `axis` are matched to input axes in order; input
+// axis j gets InputDim(i) when its output axis i has size != 1 and otherwise
+// shares a single Singleton.
+// Raises InvalidArgument when `axis` leaves more than x_ndim output axes
+// unmatched; the loop would otherwise write past the end of the result.
+std::vector<DimTrans*> MakeUnsqueezeDimTransReverse(
+    const std::vector<int64_t>& out_shape,
+    const std::vector<int64_t>& axis,
+    const int& x_ndim,
+    const int& out_ndim) {
+  // Validate before anything is allocated.
+  PADDLE_ENFORCE_LE(
+      CountKeptAxes(axis, out_ndim),
+      static_cast<int64_t>(x_ndim),
+      phi::errors::InvalidArgument(
+          "The unsqueeze axis [%s] is invalid for an input of rank [%d] "
+          "and an output of rank [%d].",
+          str_join(axis),
+          x_ndim,
+          out_ndim));
+
+  std::vector<DimTrans*> ret;
+  ret.resize(x_ndim);
+  std::fill(ret.begin(), ret.end(), new Singleton());
+
+  for (int64_t i = 0, j = 0; i < out_ndim; i++) {
+    auto it = std::find(axis.begin(), axis.end(), i);
+
+    if (it == axis.end()) {
+      if (out_shape[i] != 1) {
+        ret[j] = new InputDim(i);
+      }
+
+      j++;
+    }
+  }
+
+  return ret;
+}
+
+// Forward SPMD rule of unsqueeze.
+//   x:    input tensor (shape and dist attr).
+//   axis: positions of the inserted size-1 axes; negative entries are
+//         offset by rank(x) + 1.
+// Returns {{dist attr of x}, {dist attr of the output}}, both copied from
+// x's dist attr with the dims_mapping replaced by the inferred one.
+// Raises InvalidArgument when rank(x) differs from its dims_mapping size.
+SpmdInfo UnsqueezeInferSpmd(const DistMetaTensor& x,
+                            const std::vector<int64_t>& axis) {
+  // Step0: Verify input args based on unsqueeze logic
+  auto x_shape = phi::vectorize(x.dims());
+  int x_ndim = static_cast<int>(x_shape.size());
+  auto x_dist_attr_src = x.dist_attr();
+  std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
+  PADDLE_ENFORCE_EQ(
+      x_ndim,
+      static_cast<int>(x_dims_mapping.size()),
+      phi::errors::InvalidArgument("The Tensor X's rank [%d] and X's "
+                                   "dims_mapping size [%d] are not matched.",
+                                   x_ndim,
+                                   x_dims_mapping.size()));
+
+  // Step1: Build the transformation from
+  // the original shape to the target shape
+  std::vector<int64_t> out_shape;
+  std::vector<int64_t> axis_copy = NormalizeUnsqueezeAxis(axis, x_ndim);
+  std::vector<DimTrans*> trans =
+      MakeUnsqueezeDimTrans(x_shape, &out_shape, axis_copy);
+
+  // Step2: Infer the dims mapping of input (if reshard is
+  // needed) and output from the dimension transformation.
+  std::vector<std::vector<int64_t>> dims_mapping_vec =
+      InferFromDimTrans(x, trans);
+
+  // Step3: Update the dist attributes of input
+  // and output with the inferred dims mapping.
+  TensorDistAttr x_dist_attr_dst(x_dist_attr_src);
+  x_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
+  TensorDistAttr out_dist_attr(x_dist_attr_src);
+  out_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
+
+  VLOG(4) << "UnsqueezeInferSpmd: X shape: [" << str_join(x_shape)
+          << "] Out shape: [" << str_join(out_shape) << "]";
+  VLOG(4) << "Transformation from input to output:";
+  for (int64_t i = 0, n = static_cast<int64_t>(trans.size()); i < n; i++) {
+    DimTrans* t = trans[i];
+    VLOG(4) << "\tOut axis[" << i << "]: " << t->to_string();
+  }
+  VLOG(4) << "X dims_mapping_src: [" << str_join(x_dims_mapping)
+          << "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0])
+          << "]\n Out dims_mapping: [" << str_join(dims_mapping_vec[1])
+          << "]\n\n";
+
+  // CleanUp() is declared outside this file; presumably it releases the
+  // DimTrans objects allocated above -- confirm in dim_trans.h.
+  CleanUp();
+
+  return {{x_dist_attr_dst}, {out_dist_attr}};
+}
+
+// Reverse SPMD rule of unsqueeze.
+//   x:    input tensor; its shape and dist attr are used.
+//   out:  output tensor (shape and dist attr).
+//   axis: positions of the inserted size-1 axes; negative entries are
+//         offset by rank(x) + 1.
+// Returns {{dist attr of x}, {dist attr of out}}, each copied from the
+// tensor's own dist attr with the dims_mapping replaced by the inferred one.
+// Raises InvalidArgument when rank(out) differs from its dims_mapping size.
+SpmdInfo UnsqueezeInferSpmdReverse(const DistMetaTensor& x,
+                                   const DistMetaTensor& out,
+                                   const std::vector<int64_t>& axis) {
+  // Step0: Verify input args based on unsqueeze logic
+  auto x_shape = phi::vectorize(x.dims());
+  int x_ndim = static_cast<int>(x_shape.size());
+  auto out_shape = phi::vectorize(out.dims());
+  int out_ndim = static_cast<int>(out_shape.size());
+  auto out_dist_attr_src = out.dist_attr();
+  std::vector<int64_t> out_dims_mapping = out_dist_attr_src.dims_mapping();
+  PADDLE_ENFORCE_EQ(
+      out_ndim,
+      static_cast<int>(out_dims_mapping.size()),
+      phi::errors::InvalidArgument("The Tensor Out's rank [%d] and Out's "
+                                   "dims_mapping size [%d] are not matched.",
+                                   out_ndim,
+                                   out_dims_mapping.size()));
+
+  // Step1: Build the transformation from the output shape
+  // to original shape. This function infers the dims mapping
+  // from output to input, we first get the transformation
+  // from output to input so that we can infer the dims mapping
+  // with the map from output axes to input axes.
+  std::vector<int64_t> axis_copy = NormalizeUnsqueezeAxis(axis, x_ndim);
+  std::vector<DimTrans*> trans =
+      MakeUnsqueezeDimTransReverse(out_shape, axis_copy, x_ndim, out_ndim);
+
+  // Step2: Infer the dims mapping of input with
+  // output's dims_mapping and the transformation.
+  std::vector<std::vector<int64_t>> dims_mapping_vec =
+      InferFromDimTrans(out, trans);
+
+  // Step3: Update the dist attributes of input
+  // and output with the inferred dims mapping
+  TensorDistAttr out_dist_attr_dst(out_dist_attr_src);
+  out_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
+  TensorDistAttr x_dist_attr(x.dist_attr());
+  x_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
+
+  VLOG(4) << "UnsqueezeInferSpmdReverse: Out shape: [" << str_join(out_shape)
+          << "] X shape: [" << str_join(x_shape) << "]";
+  VLOG(4) << "Transformation from output to input:";
+  for (int64_t i = 0, n = static_cast<int64_t>(trans.size()); i < n; i++) {
+    DimTrans* t = trans[i];
+    VLOG(4) << "\tX axis[" << i << "]: " << t->to_string();
+  }
+  VLOG(4) << "Out dims_mapping_src: [" << str_join(out_dims_mapping) << "] "
+          << "dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]";
+  VLOG(4) << "X dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n";
+
+  // CleanUp() is declared outside this file; presumably it releases the
+  // DimTrans objects allocated above -- confirm in dim_trans.h.
+  CleanUp();
+
+  return {{x_dist_attr}, {out_dist_attr_dst}};
+}
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/unsqueeze.h b/paddle/phi/infermeta/spmd_rules/unsqueeze.h
new file mode 100644
index 0000000000000..a2f3490409b83
--- /dev/null
+++ b/paddle/phi/infermeta/spmd_rules/unsqueeze.h
@@ -0,0 +1,38 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <vector>
+
+#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
+#include "paddle/phi/core/distributed/type_defs.h"
+
+namespace phi {
+namespace distributed {
+
+// Forward SPMD rule of unsqueeze, defined in unsqueeze.cc: infers the dist
+// attrs of `x` and of the output from x's dist attr and the inserted `axis`.
+// Returns {{dist attr of x}, {dist attr of the output}}.
+SpmdInfo UnsqueezeInferSpmd(const DistMetaTensor& x,
+                            const std::vector<int64_t>& axis);
+
+// Reverse SPMD rule of unsqueeze, defined in unsqueeze.cc: infers the dist
+// attrs of `x` and `out` from out's dist attr and the inserted `axis`.
+// Returns {{dist attr of x}, {dist attr of out}}.
+SpmdInfo UnsqueezeInferSpmdReverse(const DistMetaTensor& x,
+                                   const DistMetaTensor& out,
+                                   const std::vector<int64_t>& axis);
+}  // namespace distributed
+}  // namespace phi
diff --git a/test/auto_parallel/spmd_rules/CMakeLists.txt b/test/auto_parallel/spmd_rules/CMakeLists.txt index 5c8f78b6c6544..97c2d1dc0205e 100644 --- a/test/auto_parallel/spmd_rules/CMakeLists.txt +++ b/test/auto_parallel/spmd_rules/CMakeLists.txt @@ -20,6 +20,7 @@ if(WITH_DISTRIBUTE) py_test_modules(test_layer_norm_rule MODULES test_layer_norm_rule) py_test_modules(test_slice_rule MODULES test_slice_rule) py_test_modules(test_flatten_rule MODULES test_flatten_rule) + py_test_modules(test_unsqueeze_rule MODULES test_unsqueeze_rule) py_test_modules(test_concat_rule MODULES test_concat_rule) # End of unittests WITH single card WITHOUT timeout diff --git
a/test/auto_parallel/spmd_rules/test_default_data_parallel_rule.py b/test/auto_parallel/spmd_rules/test_default_data_parallel_rule.py index 8d69da185246e..f8ceb1b88bf96 100644 --- a/test/auto_parallel/spmd_rules/test_default_data_parallel_rule.py +++ b/test/auto_parallel/spmd_rules/test_default_data_parallel_rule.py @@ -26,7 +26,7 @@ class TestDefaultDataParallelSPMDRule(unittest.TestCase): def setUp(self): # After replaced all spmd rules by phi impl, we can recover the # api name to `get_spmd_rule` - self.rule = core.get_phi_spmd_rule("unsqueeze") + self.rule = core.get_phi_spmd_rule("default_data_parallel") x_shape = [10, 10, 32, 48] y_shape = [32, 48]
diff --git a/test/auto_parallel/spmd_rules/test_unsqueeze_rule.py b/test/auto_parallel/spmd_rules/test_unsqueeze_rule.py
new file mode 100644
index 0000000000000..afb851279ca36
--- /dev/null
+++ b/test/auto_parallel/spmd_rules/test_unsqueeze_rule.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from collections import OrderedDict
+
+from paddle.distributed.auto_parallel.static.dist_attribute import (
+    DistTensorSpec,
+    TensorDistAttr,
+)
+from paddle.distributed.fleet import auto
+from paddle.framework import core
+
+
+class TestUnsqueezeSPMDRule(unittest.TestCase):
+    """Checks the dims_mapping inferred by the phi `unsqueeze` SPMD rule."""
+
+    # Forward cases: (x_shape, x_mapping, axis, expected_x, expected_out),
+    # where the last two are the dims_mapping expected for x and the output.
+    FORWARD_CASES = [
+        ([8, 16], [0, 1], [0], [0, 1], [-1, 0, 1]),
+        ([8, 16], [0, 1], [-1], [0, 1], [0, 1, -1]),
+        ([8, 16], [0, 1], [1, 2], [0, 1], [0, -1, -1, 1]),
+        ([8, 16], [0, 1], [0, 1, 2], [0, 1], [-1, -1, -1, 0, 1]),
+        ([8, 16], [1, 0], [0], [1, 0], [-1, 1, 0]),
+        ([8, 16], [1, 0], [-1], [1, 0], [1, 0, -1]),
+        ([8, 16], [1, 0], [1, 2], [1, 0], [1, -1, -1, 0]),
+        ([8, 16], [1, 0], [0, 1, 2], [1, 0], [-1, -1, -1, 1, 0]),
+        # The size-1 axis of x is expected to lose its sharding.
+        ([1, 8, 16], [0, 1, -1], [0], [-1, 1, -1], [-1, -1, 1, -1]),
+    ]
+
+    # Backward cases:
+    # (x_shape, out_shape, out_mapping, axis, expected_x, expected_out).
+    BACKWARD_CASES = [
+        ([8, 16], [1, 8, 16], [-1, 0, 1], [0], [0, 1], [-1, 0, 1]),
+        ([8, 16], [8, 16, 1], [0, 1, -1], [-1], [0, 1], [0, 1, -1]),
+        (
+            [8, 16],
+            [8, 1, 1, 16],
+            [0, -1, -1, 1],
+            [1, 2],
+            [0, 1],
+            [0, -1, -1, 1],
+        ),
+        (
+            [8, 16],
+            [1, 1, 1, 8, 16],
+            [-1, -1, -1, 0, 1],
+            [0, 1, 2],
+            [0, 1],
+            [-1, -1, -1, 0, 1],
+        ),
+        ([8, 16], [1, 8, 16], [-1, 1, 0], [0], [1, 0], [-1, 1, 0]),
+        ([8, 16], [8, 16, 1], [1, 0, -1], [-1], [1, 0], [1, 0, -1]),
+        (
+            [8, 16],
+            [8, 1, 1, 16],
+            [1, -1, -1, 0],
+            [1, 2],
+            [1, 0],
+            [1, -1, -1, 0],
+        ),
+        (
+            [8, 16],
+            [1, 1, 1, 8, 16],
+            [-1, -1, -1, 1, 0],
+            [0, 1, 2],
+            [1, 0],
+            [-1, -1, -1, 1, 0],
+        ),
+        # The mesh axis on the size-1 output axis is expected to be dropped.
+        (
+            [1, 8, 16],
+            [1, 1, 8, 16],
+            [-1, 0, 1, -1],
+            [0],
+            [-1, 1, -1],
+            [-1, -1, 1, -1],
+        ),
+    ]
+
+    def setUp(self):
+        self.rule = core.get_phi_spmd_rule("unsqueeze")
+
+        x_shape = [8, 16]
+        process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
+
+        x_tensor_dist_attr = TensorDistAttr()
+        x_tensor_dist_attr.dims_mapping = [-1, -1]
+        x_tensor_dist_attr.process_mesh = process_mesh
+        self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
+        self.attrs = OrderedDict()
+
+    def _check_result(self, result_dist_attrs, expected_x, expected_out):
+        """Asserts the single x / output dims_mapping returned by the rule."""
+        inferred_input_dist_attrs = result_dist_attrs[0]
+        inferred_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(len(inferred_input_dist_attrs), 1)
+        self.assertEqual(len(inferred_output_dist_attrs), 1)
+        self.assertEqual(inferred_input_dist_attrs[0].dims_mapping, expected_x)
+        self.assertEqual(
+            inferred_output_dist_attrs[0].dims_mapping, expected_out
+        )
+
+    def test_unsqueeze_infer_forward(self):
+        for x_shape, x_mapping, axis, exp_x, exp_out in self.FORWARD_CASES:
+            with self.subTest(x_shape=x_shape, x_mapping=x_mapping, axis=axis):
+                self.x_dist_tensor_spec.shape = x_shape
+                self.x_dist_tensor_spec.set_dims_mapping(x_mapping)
+                self.attrs['axis'] = axis
+                result_dist_attrs = self.rule.infer_forward(
+                    self.x_dist_tensor_spec, self.attrs['axis']
+                )
+                self._check_result(result_dist_attrs, exp_x, exp_out)
+
+    def test_unsqueeze_infer_backward(self):
+        process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
+
+        output_tensor_dist_attr = TensorDistAttr()
+        output_tensor_dist_attr.dims_mapping = [-1, -1]
+        output_tensor_dist_attr.process_mesh = process_mesh
+        self.output_dist_tensor_spec = DistTensorSpec(
+            [8, 16], output_tensor_dist_attr
+        )
+
+        for (
+            x_shape,
+            out_shape,
+            out_mapping,
+            axis,
+            exp_x,
+            exp_out,
+        ) in self.BACKWARD_CASES:
+            with self.subTest(out_shape=out_shape, axis=axis):
+                # Only the shape of x is set here; its dims_mapping stays at
+                # the [-1, -1] assigned in setUp.
+                self.x_dist_tensor_spec.shape = x_shape
+                self.output_dist_tensor_spec.shape = out_shape
+                self.output_dist_tensor_spec.set_dims_mapping(out_mapping)
+                self.attrs['axis'] = axis
+                result_dist_attrs = self.rule.infer_backward(
+                    self.x_dist_tensor_spec,
+                    self.output_dist_tensor_spec,
+                    self.attrs['axis'],
+                )
+                self._check_result(result_dist_attrs, exp_x, exp_out)
+
+
+if __name__ == "__main__":
+    unittest.main()