diff --git a/paddle/operators/matmul_op.cc b/paddle/operators/matmul_op.cc
index 3336978c8d8d9..3fe8bc165e10b 100644
--- a/paddle/operators/matmul_op.cc
+++ b/paddle/operators/matmul_op.cc
@@ -37,118 +37,156 @@ class MatMulOp : public framework::OperatorWithKernel {
     bool transpose_x = context->Attrs().Get<bool>("transpose_X");
     bool transpose_y = context->Attrs().Get<bool>("transpose_Y");
+    int x_num_col_dims = context->Attrs().Get<int>("x_num_col_dims");
+    int y_num_col_dims = context->Attrs().Get<int>("y_num_col_dims");
+
     PADDLE_ENFORCE_GE(dim_x.size(), 1,
                       "Input tensor X must be at least 1-dimensional.");
     PADDLE_ENFORCE_GE(dim_y.size(), 1,
                       "Input tensor Y must be at least 1-dimensional.");
-    std::vector<int64_t> out_dim;
-    int64_t batch_count = 1;
-    if (dim_x.size() > 3) {
-      PADDLE_ENFORCE_EQ(
-          dim_y.size(), dim_x.size(),
-          "The dimensions of X and Y must be the same, and both of "
-          "them should be %d-dimensional.",
-          dim_x.size());
-
-      // The first rank-2 dimensions are accumulated on the batch_count, and the
-      // last two dimensions are used for matrix multiplication.
-      for (int j = 0; j < dim_x.size() - 2; ++j) {
-        PADDLE_ENFORCE_EQ(dim_y[j], dim_x[j],
-                          "The %d-th dimension of X and Y must be the same.",
-                          j);
-        out_dim.push_back(dim_x[j]);
-        batch_count *= dim_x[j];
-      }
-    }
-
-    int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0;
-    bool remove_initial_dim = false, remove_final_dim = false;
-
-    switch (dim_x.size()) {
-      case 1:
-        if (transpose_x) {
-          M = dim_x[0];
-          KX = 1;
-        } else {
-          M = 1;
-          KX = dim_x[0];
-          remove_initial_dim = true;
-        }
-        break;
-      case 2:
-        M = transpose_x ? dim_x[1] : dim_x[0];
-        KX = transpose_x ? dim_x[0] : dim_x[1];
-        break;
-      case 3:
-        batchCountX = dim_x[0];
-        M = transpose_x ? dim_x[2] : dim_x[1];
-        KX = transpose_x ? dim_x[1] : dim_x[2];
-        break;
-      default:
-        batchCountX = batch_count;
-        size_t mat_s = dim_x.size() - 2;
-        M = transpose_x ? dim_x[mat_s + 1] : dim_x[mat_s];
-        KX = transpose_x ? dim_x[mat_s] : dim_x[mat_s + 1];
-        break;
-    }
-
-    switch (dim_y.size()) {
-      case 1:
-        if (transpose_y) {
-          N = dim_y[0];
-          KY = 1;
-        } else {
-          N = 1;
-          KY = dim_y[0];
-          remove_final_dim = true;
-        }
-        break;
-      case 2:
-        KY = transpose_y ? dim_y[1] : dim_y[0];
-        N = transpose_y ? dim_y[0] : dim_y[1];
-        break;
-      case 3:
-        batchCountY = dim_y[0];
-        KY = transpose_y ? dim_y[2] : dim_y[1];
-        N = transpose_y ? dim_y[1] : dim_y[2];
-        break;
-      default:
-        batchCountY = batch_count;
-        size_t mat_s = dim_y.size() - 2;
-        KY = transpose_y ? dim_y[mat_s + 1] : dim_y[mat_s];
-        N = transpose_y ? dim_y[mat_s] : dim_y[mat_s + 1];
-    }
-
-    PADDLE_ENFORCE_EQ(
-        KX, KY,
-        "First matrix's width must be equal with second matrix's height.");
-    if (batchCountX && batchCountY) {
-      PADDLE_ENFORCE_EQ(
-          batchCountX, batchCountY,
-          "When Input(X) and Input(Y) are both three dimensional, they "
-          "must have the same batch dimension.");
-    }
-    int batchCount = std::max(batchCountX, batchCountY);
-
-    std::vector<int64_t> dim_out;
-    if (batchCount) {
-      if (dim_x.size() > 3) {
-        dim_out.insert(dim_out.begin(), out_dim.begin(), out_dim.end());
-      } else {
-        dim_out.push_back(batchCount);
-      }
-    }
-    if (!remove_initial_dim) {
-      dim_out.push_back(M);
-    }
-    if (!remove_final_dim) {
-      dim_out.push_back(N);
-    }
-    if (dim_out.size() == 0) {
-      // We don't support 0-dimensional Tensors (scalars), so instead
-      // treat the output as a Tensor of shape (1, ) in this case.
-      dim_out.push_back(1);
-    }
+    std::vector<int64_t> dim_out;
+    if (x_num_col_dims == 0 && y_num_col_dims == 0) {
+      std::vector<int64_t> out_dim;
+      int64_t batch_count = 1;
+      if (dim_x.size() > 3) {
+        PADDLE_ENFORCE_EQ(
+            dim_y.size(), dim_x.size(),
+            "The dimensions of X and Y must be the same, and both of "
+            "them should be %d-dimensional.",
+            dim_x.size());
+
+        // The first rank-2 dimensions are accumulated on the batch_count,
+        // and the last two dimensions are used for matrix multiplication.
+        for (int j = 0; j < dim_x.size() - 2; ++j) {
+          PADDLE_ENFORCE_EQ(dim_y[j], dim_x[j],
+                            "The %d-th dimension of X and Y must be the same.",
+                            j);
+          out_dim.push_back(dim_x[j]);
+          batch_count *= dim_x[j];
+        }
+      }
+
+      int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0;
+      bool remove_initial_dim = false, remove_final_dim = false;
+
+      switch (dim_x.size()) {
+        case 1:
+          if (transpose_x) {
+            M = dim_x[0];
+            KX = 1;
+          } else {
+            M = 1;
+            KX = dim_x[0];
+            remove_initial_dim = true;
+          }
+          break;
+        case 2:
+          M = transpose_x ? dim_x[1] : dim_x[0];
+          KX = transpose_x ? dim_x[0] : dim_x[1];
+          break;
+        case 3:
+          batchCountX = dim_x[0];
+          M = transpose_x ? dim_x[2] : dim_x[1];
+          KX = transpose_x ? dim_x[1] : dim_x[2];
+          break;
+        default:
+          batchCountX = batch_count;
+          size_t mat_s = dim_x.size() - 2;
+          M = transpose_x ? dim_x[mat_s + 1] : dim_x[mat_s];
+          KX = transpose_x ? dim_x[mat_s] : dim_x[mat_s + 1];
+          break;
+      }
+
+      switch (dim_y.size()) {
+        case 1:
+          if (transpose_y) {
+            N = dim_y[0];
+            KY = 1;
+          } else {
+            N = 1;
+            KY = dim_y[0];
+            remove_final_dim = true;
+          }
+          break;
+        case 2:
+          KY = transpose_y ? dim_y[1] : dim_y[0];
+          N = transpose_y ? dim_y[0] : dim_y[1];
+          break;
+        case 3:
+          batchCountY = dim_y[0];
+          KY = transpose_y ? dim_y[2] : dim_y[1];
+          N = transpose_y ? dim_y[1] : dim_y[2];
+          break;
+        default:
+          batchCountY = batch_count;
+          size_t mat_s = dim_y.size() - 2;
+          KY = transpose_y ? dim_y[mat_s + 1] : dim_y[mat_s];
+          N = transpose_y ? dim_y[mat_s] : dim_y[mat_s + 1];
+      }
+
+      PADDLE_ENFORCE_EQ(
+          KX, KY,
+          "First matrix's width must be equal with second matrix's height.");
+      if (batchCountX && batchCountY) {
+        PADDLE_ENFORCE_EQ(
+            batchCountX, batchCountY,
+            "When Input(X) and Input(Y) are both three dimensional, they "
+            "must have the same batch dimension.");
+      }
+      int batchCount = std::max(batchCountX, batchCountY);
+
+      if (batchCount) {
+        if (dim_x.size() > 3) {
+          dim_out.insert(dim_out.begin(), out_dim.begin(), out_dim.end());
+        } else {
+          dim_out.push_back(batchCount);
+        }
+      }
+      if (!remove_initial_dim) {
+        dim_out.push_back(M);
+      }
+      if (!remove_final_dim) {
+        dim_out.push_back(N);
+      }
+      if (dim_out.size() == 0) {
+        // We don't support 0-dimensional Tensors (scalars), so instead
+        // treat the output as a Tensor of shape (1, ) in this case.
+        dim_out.push_back(1);
+      }
+    } else {
+      if (x_num_col_dims == 0) {
+        x_num_col_dims = 1;
+      }
+      if (y_num_col_dims == 0) {
+        y_num_col_dims = 1;
+      }
+      PADDLE_ENFORCE_GT(
+          dim_x.size(), x_num_col_dims,
+          "The input tensor X's rank of MatMulOp should be larger than "
+          "x_num_col_dims.");
+      PADDLE_ENFORCE_GT(
+          dim_y.size(), y_num_col_dims,
+          "The input tensor Y's rank of MatMulOp should be larger than "
+          "y_num_col_dims.");
+
+      auto x_mat_dims = framework::flatten_to_2d(dim_x, x_num_col_dims);
+      auto y_mat_dims = framework::flatten_to_2d(dim_y, y_num_col_dims);
+
+      PADDLE_ENFORCE_EQ(
+          x_mat_dims[1], y_mat_dims[0],
+          "First matrix's width must be equal with second matrix's height.");
+
+      dim_out.reserve(
+          static_cast<size_t>(x_num_col_dims + dim_y.size() - y_num_col_dims));
+
+      for (int i = 0; i < x_num_col_dims; ++i) {
+        dim_out.push_back(dim_x[i]);
+      }
+
+      for (int i = y_num_col_dims; i < dim_y.size(); ++i) {
+        dim_out.push_back(dim_y[i]);
+      }
+    }
     context->SetOutputDim("Out", framework::make_ddim(dim_out));
     context->ShareLoD("X", /*->*/ "Out");
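Reviewer note: a minimal Python sketch (not part of the diff) of what the new `x_num_col_dims`/`y_num_col_dims` branch of `InferShape` computes. It assumes `framework::flatten_to_2d` simply multiplies the leading and trailing dimension sizes; the helper names below are illustrative, not Paddle API.

```python
from functools import reduce
from operator import mul


def flatten_to_2d(dims, num_col_dims):
    # First `num_col_dims` dims form the height, the rest the width,
    # e.g. [2, 3, 4, 5, 6] with num_col_dims=3 -> [24, 30].
    return [reduce(mul, dims[:num_col_dims], 1),
            reduce(mul, dims[num_col_dims:], 1)]


def infer_matmul_shape(dim_x, dim_y, x_num_col_dims, y_num_col_dims):
    x_mat = flatten_to_2d(dim_x, x_num_col_dims)
    y_mat = flatten_to_2d(dim_y, y_num_col_dims)
    # Mirrors PADDLE_ENFORCE_EQ(x_mat_dims[1], y_mat_dims[0], ...).
    assert x_mat[1] == y_mat[0], "width of X must equal height of Y"
    # Out keeps X's leading dims and Y's trailing dims.
    return list(dim_x[:x_num_col_dims]) + list(dim_y[y_num_col_dims:])


assert infer_matmul_shape([2, 3, 4, 5, 6], [30, 7], 3, 1) == [2, 3, 4, 7]
```

The reserved capacity `x_num_col_dims + dim_y.size() - y_num_col_dims` matches the length of that output list (3 + 2 - 1 = 4 in the example).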
@@ -162,6 +200,37 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "The first input of MatMul op");
     AddInput("Y", "The second input of MatMul op");
     AddOutput("Out", "The output of MatMul op");
+    AddAttr<int>(
+        "x_num_col_dims",
+        R"DOC((int, default 0), The matmul_op can take tensors with more than
+              two dimensions as its inputs. If the input $X$ is a tensor with
+              more than two dimensions, $X$ will be flattened into a
+              two-dimensional matrix first. The flattening rule is: the first
+              `x_num_col_dims` dimensions will be flattened to form the first
+              dimension of the final matrix (the height of the matrix), and
+              the rest `rank(X) - x_num_col_dims` dimensions are flattened to
+              form the second dimension of the final matrix (the width of the
+              matrix). As a result, the height of the flattened matrix is
+              equal to the product of $X$'s first `x_num_col_dims` dimensions'
+              sizes, and the width of the flattened matrix is equal to the
+              product of $X$'s last `rank(X) - x_num_col_dims` dimensions'
+              sizes. For example, suppose $X$ is a 5-dimensional tensor with
+              the shape [2, 3, 4, 5, 6], and `x_num_col_dims` = 3. Then the
+              flattened matrix will have a shape [2 x 3 x 4, 5 x 6] =
+              [24, 30]. The default value 0 indicates the input is a 2-D
+              matrix.
+        )DOC")
+        .SetDefault(0)
+        .EqualGreaterThan(0);
+    AddAttr<int>(
+        "y_num_col_dims",
+        R"DOC((int, default 0), The matmul_op can take tensors with more than
+              two dimensions as its inputs. If the input $Y$ is a tensor with
+              more than two dimensions, $Y$ will be flattened into a
+              two-dimensional matrix first. The attribute `y_num_col_dims`
+              determines how $Y$ is flattened.
+              See comments of `x_num_col_dims` for more details.
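The worked example in the `x_num_col_dims` doc above can be sanity-checked in a couple of lines (illustrative only):

```python
import numpy as np

x = np.zeros((2, 3, 4, 5, 6))
# x_num_col_dims = 3: the first three dims collapse into the height.
assert x.reshape(2 * 3 * 4, -1).shape == (24, 30)
```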
+ )DOC") + .SetDefault(0) + .EqualGreaterThan(0); AddAttr("transpose_X", R"DOC(If true, use the transpose of `X`. )DOC") diff --git a/paddle/operators/matmul_op.h b/paddle/operators/matmul_op.h index fe6a97465f899..831697052cd9b 100644 --- a/paddle/operators/matmul_op.h +++ b/paddle/operators/matmul_op.h @@ -38,9 +38,38 @@ class MatMulKernel : public framework::OpKernel { bool transpose_x = context.Attr("transpose_X"); bool transpose_y = context.Attr("transpose_Y"); - math::MatMulFunctor()( - context.template device_context(), x, transpose_x, y, - transpose_y, T(1), out, T(0)); + int x_num_col_dims = context.template Attr("x_num_col_dims"); + int y_num_col_dims = context.template Attr("y_num_col_dims"); + + if (x_num_col_dims == 0 && y_num_col_dims == 0) { + math::MatMulFunctor()( + context.template device_context(), x, transpose_x, y, + transpose_y, T(1), out, T(0)); + } else { + if (x_num_col_dims == 0) { + x_num_col_dims = 1; + } + if (y_num_col_dims == 0) { + y_num_col_dims = 1; + } + const Tensor x_matrix = + x.dims().size() > 2 ? framework::ReshapeToMatrix(x, x_num_col_dims) + : x; + const Tensor y_matrix = + y.dims().size() > 2 ? framework::ReshapeToMatrix(y, y_num_col_dims) + : y; + + auto out_dim = out->dims(); + if (out_dim.size() != 2) { + out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); + } + math::matmul( + context.template device_context(), x_matrix, false, + y_matrix, false, 1, out, 0); + if (out_dim.size() != 2) { + out->Resize(out_dim); + } + } } }; @@ -124,111 +153,160 @@ class MatMulGradKernel : public framework::OpKernel { bool transpose_x = context.Attr("transpose_X"); bool transpose_y = context.Attr("transpose_Y"); - std::vector x_dims = vectorize(x.dims()); - std::vector y_dims = vectorize(y.dims()); + int x_num_col_dims = context.template Attr("x_num_col_dims"); + int y_num_col_dims = context.template Attr("y_num_col_dims"); + auto& dev_ctx = context.template device_context(); - // If X is a vector, reshape it to a matrix. - if (x_dims.size() == 1) { - x_dims.insert(x_dims.begin(), 1); - } + if (x_num_col_dims == 0 && y_num_col_dims == 0) { + std::vector x_dims = vectorize(x.dims()); + std::vector y_dims = vectorize(y.dims()); - // If Y is a vector, reshape it to a matrix. - if (y_dims.size() == 1) { - y_dims.push_back(1); - } + // If X is a vector, reshape it to a matrix. + if (x_dims.size() == 1) { + x_dims.insert(x_dims.begin(), 1); + } - int batch_count = 0; - // The first rank-2 dimensions are accumulated on the batch_count, and the - // last two dimensions are used for matrix multiplication. - if (x_dims.size() > 3) { - batch_count = accumulate(x_dims.begin(), x_dims.end() - 2, 1, - std::multiplies()); - } - // Fix the dOut dimensions. - int M = 0, N = 0, batchCountX = 0, batchCountY = 0; - - switch (x_dims.size()) { - case 2: - M = transpose_x ? x_dims[1] : x_dims[0]; - break; - case 3: - batchCountX = x_dims[0]; - M = transpose_x ? x_dims[2] : x_dims[1]; - break; - default: - batchCountX = batch_count; - size_t mat_s = x_dims.size() - 2; - M = transpose_x ? x_dims[mat_s + 1] : x_dims[mat_s]; - } + // If Y is a vector, reshape it to a matrix. + if (y_dims.size() == 1) { + y_dims.push_back(1); + } - switch (y_dims.size()) { - case 2: - N = transpose_y ? y_dims[0] : y_dims[1]; - break; - case 3: - batchCountY = y_dims[0]; - N = transpose_y ? y_dims[1] : y_dims[2]; - break; - default: - batchCountY = batch_count; - size_t mat_s = y_dims.size() - 2; - N = transpose_y ? 
@@ -124,111 +153,160 @@ class MatMulGradKernel : public framework::OpKernel {
     bool transpose_x = context.Attr<bool>("transpose_X");
     bool transpose_y = context.Attr<bool>("transpose_Y");
 
-    std::vector<int64_t> x_dims = vectorize(x.dims());
-    std::vector<int64_t> y_dims = vectorize(y.dims());
-
-    // If X is a vector, reshape it to a matrix.
-    if (x_dims.size() == 1) {
-      x_dims.insert(x_dims.begin(), 1);
-    }
-
-    // If Y is a vector, reshape it to a matrix.
-    if (y_dims.size() == 1) {
-      y_dims.push_back(1);
-    }
-
-    int batch_count = 0;
-    // The first rank-2 dimensions are accumulated on the batch_count, and the
-    // last two dimensions are used for matrix multiplication.
-    if (x_dims.size() > 3) {
-      batch_count = accumulate(x_dims.begin(), x_dims.end() - 2, 1,
-                               std::multiplies<int>());
-    }
-    // Fix the dOut dimensions.
-    int M = 0, N = 0, batchCountX = 0, batchCountY = 0;
-
-    switch (x_dims.size()) {
-      case 2:
-        M = transpose_x ? x_dims[1] : x_dims[0];
-        break;
-      case 3:
-        batchCountX = x_dims[0];
-        M = transpose_x ? x_dims[2] : x_dims[1];
-        break;
-      default:
-        batchCountX = batch_count;
-        size_t mat_s = x_dims.size() - 2;
-        M = transpose_x ? x_dims[mat_s + 1] : x_dims[mat_s];
-    }
-
-    switch (y_dims.size()) {
-      case 2:
-        N = transpose_y ? y_dims[0] : y_dims[1];
-        break;
-      case 3:
-        batchCountY = y_dims[0];
-        N = transpose_y ? y_dims[1] : y_dims[2];
-        break;
-      default:
-        batchCountY = batch_count;
-        size_t mat_s = y_dims.size() - 2;
-        N = transpose_y ? y_dims[mat_s] : y_dims[mat_s + 1];
-    }
-    if (batchCountX && batchCountY) {
-      PADDLE_ENFORCE_EQ(
-          batchCountX, batchCountY,
-          "When Input(X) and Input(Y) are both three dimensional, they "
-          "must have the same batch dimension.");
-    }
-    int batchCount = std::max(batchCountX, batchCountY);
-    std::vector<int64_t> dout_dims = {M, N};
-    if (batchCount) {
-      if (x_dims.size() > 3) {
-        dout_dims.insert(dout_dims.begin(), x_dims.begin(), x_dims.end() - 2);
-      } else {
-        dout_dims.insert(dout_dims.begin(), batchCount);
-      }
-    }
-    Tensor X = Reshape<T>(x, make_ddim(x_dims));
-    Tensor Y = Reshape<T>(y, make_ddim(y_dims));
-    Tensor dOut = Reshape<T>(dout, make_ddim(dout_dims));
-
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    if (dx) {
-      dx->mutable_data<T>(context.GetPlace());
-      const Tensor& dOut_for_dX =
-          (x_dims.size() == 2 && y_dims.size() == 3)
-              ? CombineBatchAndN<DeviceContext, T>(dev_ctx, dOut)
-              : dOut;
-      if (x_dims.size() == 2 && y_dims.size() == 3) {
-        Y = transpose_y ? CombineBatchAndM<T>(Y)
-                        : CombineBatchAndN<DeviceContext, T>(dev_ctx, Y);
-      }
-      if (transpose_x) {
-        math::MatMulFunctor<DeviceContext, T>()(
-            dev_ctx, Y, transpose_y, dOut_for_dX, transpose_x, T(1), dx, T(0));
-      } else {
-        math::MatMulFunctor<DeviceContext, T>()(
-            dev_ctx, dOut_for_dX, transpose_x, Y, !transpose_y, T(1), dx, T(0));
-      }
-    }
-    if (dy) {
-      dy->mutable_data<T>(context.GetPlace());
-      const Tensor& dOut_for_dY = (y_dims.size() == 2 && x_dims.size() == 3)
-                                      ? CombineBatchAndM<T>(dOut)
-                                      : dOut;
-      if (y_dims.size() == 2 && x_dims.size() == 3) {
-        X = transpose_x ? CombineBatchAndN<DeviceContext, T>(dev_ctx, X)
-                        : CombineBatchAndM<T>(X);
-        dOut = CombineBatchAndM<T>(dOut);
-      }
-      if (transpose_y) {
-        math::MatMulFunctor<DeviceContext, T>()(
-            dev_ctx, dOut_for_dY, transpose_y, X, transpose_x, T(1), dy, T(0));
-      } else {
-        math::MatMulFunctor<DeviceContext, T>()(
-            dev_ctx, X, !transpose_x, dOut_for_dY, transpose_y, T(1), dy, T(0));
-      }
-    }
+    int x_num_col_dims = context.template Attr<int>("x_num_col_dims");
+    int y_num_col_dims = context.template Attr<int>("y_num_col_dims");
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+
+    if (x_num_col_dims == 0 && y_num_col_dims == 0) {
+      std::vector<int64_t> x_dims = vectorize(x.dims());
+      std::vector<int64_t> y_dims = vectorize(y.dims());
+
+      // If X is a vector, reshape it to a matrix.
+      if (x_dims.size() == 1) {
+        x_dims.insert(x_dims.begin(), 1);
+      }
+
+      // If Y is a vector, reshape it to a matrix.
+      if (y_dims.size() == 1) {
+        y_dims.push_back(1);
+      }
+
+      int batch_count = 0;
+      // The first rank-2 dimensions are accumulated on the batch_count, and
+      // the last two dimensions are used for matrix multiplication.
+      if (x_dims.size() > 3) {
+        batch_count = accumulate(x_dims.begin(), x_dims.end() - 2, 1,
+                                 std::multiplies<int>());
+      }
+      // Fix the dOut dimensions.
+      int M = 0, N = 0, batchCountX = 0, batchCountY = 0;
+
+      switch (x_dims.size()) {
+        case 2:
+          M = transpose_x ? x_dims[1] : x_dims[0];
+          break;
+        case 3:
+          batchCountX = x_dims[0];
+          M = transpose_x ? x_dims[2] : x_dims[1];
+          break;
+        default:
+          batchCountX = batch_count;
+          size_t mat_s = x_dims.size() - 2;
+          M = transpose_x ? x_dims[mat_s + 1] : x_dims[mat_s];
+      }
+
+      switch (y_dims.size()) {
+        case 2:
+          N = transpose_y ? y_dims[0] : y_dims[1];
+          break;
+        case 3:
+          batchCountY = y_dims[0];
+          N = transpose_y ? y_dims[1] : y_dims[2];
+          break;
+        default:
+          batchCountY = batch_count;
+          size_t mat_s = y_dims.size() - 2;
+          N = transpose_y ? y_dims[mat_s] : y_dims[mat_s + 1];
+      }
+      if (batchCountX && batchCountY) {
+        PADDLE_ENFORCE_EQ(
+            batchCountX, batchCountY,
+            "When Input(X) and Input(Y) are both three dimensional, they "
+            "must have the same batch dimension.");
+      }
+      int batchCount = std::max(batchCountX, batchCountY);
+      std::vector<int64_t> dout_dims = {M, N};
+      if (batchCount) {
+        if (x_dims.size() > 3) {
+          dout_dims.insert(dout_dims.begin(), x_dims.begin(),
+                           x_dims.end() - 2);
+        } else {
+          dout_dims.insert(dout_dims.begin(), batchCount);
+        }
+      }
+      Tensor X = Reshape<T>(x, make_ddim(x_dims));
+      Tensor Y = Reshape<T>(y, make_ddim(y_dims));
+      Tensor dOut = Reshape<T>(dout, make_ddim(dout_dims));
+
+      if (dx) {
+        dx->mutable_data<T>(context.GetPlace());
+        const Tensor& dOut_for_dX =
+            (x_dims.size() == 2 && y_dims.size() == 3)
+                ? CombineBatchAndN<DeviceContext, T>(dev_ctx, dOut)
+                : dOut;
+        if (x_dims.size() == 2 && y_dims.size() == 3) {
+          Y = transpose_y ? CombineBatchAndM<T>(Y)
+                          : CombineBatchAndN<DeviceContext, T>(dev_ctx, Y);
+        }
+        if (transpose_x) {
+          math::MatMulFunctor<DeviceContext, T>()(dev_ctx, Y, transpose_y,
+                                                  dOut_for_dX, transpose_x,
+                                                  T(1), dx, T(0));
+        } else {
+          math::MatMulFunctor<DeviceContext, T>()(dev_ctx, dOut_for_dX,
                                                  transpose_x, Y, !transpose_y,
+                                                  T(1), dx, T(0));
+        }
+      }
+
+      if (dy) {
+        dy->mutable_data<T>(context.GetPlace());
+        const Tensor& dOut_for_dY = (y_dims.size() == 2 && x_dims.size() == 3)
+                                        ? CombineBatchAndM<T>(dOut)
+                                        : dOut;
+        if (y_dims.size() == 2 && x_dims.size() == 3) {
+          X = transpose_x ? CombineBatchAndN<DeviceContext, T>(dev_ctx, X)
+                          : CombineBatchAndM<T>(X);
+          dOut = CombineBatchAndM<T>(dOut);
+        }
+        if (transpose_y) {
+          math::MatMulFunctor<DeviceContext, T>()(dev_ctx, dOut_for_dY,
+                                                  transpose_y, X, transpose_x,
+                                                  T(1), dy, T(0));
+        } else {
+          math::MatMulFunctor<DeviceContext, T>()(dev_ctx, X, !transpose_x,
+                                                  dOut_for_dY, transpose_y,
+                                                  T(1), dy, T(0));
+        }
+      }
+    } else {
+      if (x_num_col_dims == 0) {
+        x_num_col_dims = 1;
+      }
+      if (y_num_col_dims == 0) {
+        y_num_col_dims = 1;
+      }
+      const Tensor x_matrix =
+          x.dims().size() > 2 ? framework::ReshapeToMatrix(x, x_num_col_dims)
+                              : x;
+      const Tensor y_matrix =
+          y.dims().size() > 2 ? framework::ReshapeToMatrix(y, y_num_col_dims)
+                              : y;
+      const Tensor* dout = context.Input<Tensor>(framework::GradVarName("Out"));
+
+      Tensor dout_mat;
+      dout_mat.ShareDataWith(*dout);
+      dout_mat.Resize({framework::flatten_to_2d(x.dims(), x_num_col_dims)[0],
+                       framework::flatten_to_2d(y.dims(), y_num_col_dims)[1]});
+
+      if (dx) {
+        dx->mutable_data<T>(context.GetPlace());
+        Tensor dx_matrix = dx->dims().size() > 2
+                               ? framework::ReshapeToMatrix(*dx, x_num_col_dims)
+                               : *dx;
+
+        // dx = dout * y'. dx: M x K, dout: M x N, y: K x N
+        math::matmul<DeviceContext, T>(dev_ctx, dout_mat, false, y_matrix,
+                                       true, 1, &dx_matrix, 0);
+      }
+      if (dy) {
+        dy->mutable_data<T>(context.GetPlace());
+        Tensor dy_matrix = dy->dims().size() > 2
+                               ? framework::ReshapeToMatrix(*dy, y_num_col_dims)
+                               : *dy;
+        // dy = x' * dout. dy: K x N, dout: M x N, x: M x K
+        math::matmul<DeviceContext, T>(dev_ctx, x_matrix, true, dout_mat,
+                                       false, 1, &dy_matrix, 0);
+      }
+    }
   }
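Backward bookkeeping for the flattened path: with `Out = X_mat * Y_mat`, the two `math::matmul` calls compute `dX_mat = dOut_mat * Y_mat^T` and `dY_mat = X_mat^T * dOut_mat`, and the matrix views map the results back onto the original shapes. A NumPy sketch under the same assumptions as above:

```python
import numpy as np


def flattened_matmul_grad(x, y, dout, x_num_col_dims=1, y_num_col_dims=1):
    m = int(np.prod(x.shape[:x_num_col_dims]))
    n = int(np.prod(y.shape[y_num_col_dims:]))
    x_mat = x.reshape(m, -1)                    # M x K
    y_mat = y.reshape(-1, n)                    # K x N
    dout_mat = dout.reshape(m, n)               # M x N, like dout_mat.Resize
    dx = (dout_mat @ y_mat.T).reshape(x.shape)  # dx = dout * y'
    dy = (x_mat.T @ dout_mat).reshape(y.shape)  # dy = x' * dout
    return dx, dy


x = np.random.rand(2, 3, 4)
y = np.random.rand(4, 5)
dout = np.random.rand(2, 3, 5)
dx, dy = flattened_matmul_grad(x, y, dout, x_num_col_dims=2)
assert dx.shape == x.shape and dy.shape == y.shape
```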
diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py
index a79479f469a0c..38f1a4f44956d 100644
--- a/python/paddle/v2/fluid/layers/nn.py
+++ b/python/paddle/v2/fluid/layers/nn.py
@@ -166,14 +166,11 @@ def fc(input,
         w = helper.create_parameter(
             attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False)
-        tmp = helper.create_tmp_variable(dtype)
-        helper.append_op(
-            type="mul",
-            inputs={"X": input_var,
-                    "Y": w},
-            outputs={"Out": tmp},
-            attrs={"x_num_col_dims": num_flatten_dims,
-                   "y_num_col_dims": 1})
+        if len(input_shape) > 2:
+            tmp = matmul(input_var, w, False, False, num_flatten_dims, 1)
+        else:
+            tmp = matmul(input_var, w, False, False)
+
         mul_results.append(tmp)
 
     # sum
@@ -2286,7 +2283,13 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
     return out
 
 
-def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
+def matmul(x,
+           y,
+           transpose_x=False,
+           transpose_y=False,
+           x_num_col_dims=0,
+           y_num_col_dims=0,
+           name=None):
     """
     Applies matrix multiplication to two tensors.
 
@@ -2320,6 +2323,26 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
         y (Variable): The input variable which is a Tensor or LoDTensor.
         transpose_x (bool): Whether to transpose :math:`x` before multiplication.
         transpose_y (bool): Whether to transpose :math:`y` before multiplication.
+        x_num_col_dims (integer): The matmul layer can accept an input tensor
+                                  with more than two dimensions. If this
+                                  happens, the multidimensional tensor will
+                                  first be flattened into a 2-dimensional
+                                  matrix. The parameter `x_num_col_dims`
+                                  determines how the input tensor is
+                                  flattened: the first `x_num_col_dims`
+                                  (inclusive, index starts from 1) dimensions
+                                  will be flattened to form the first
+                                  dimension of the final matrix (the height of
+                                  the matrix), and the rest
+                                  `rank(X) - x_num_col_dims` dimensions are
+                                  flattened to form the second dimension of
+                                  the final matrix (the width of the matrix).
+                                  For example, suppose `X` is a 5-dimensional
+                                  tensor with the shape [2, 3, 4, 5, 6], and
+                                  `x_num_col_dims` = 3. Then the flattened
+                                  matrix will have the shape
+                                  [2 x 3 x 4, 5 x 6] = [24, 30]. By default,
+                                  `x_num_col_dims` is set to 0, indicating
+                                  that the input is a 2-D matrix.
+        y_num_col_dims (integer): See comments of `x_num_col_dims` for
+                                  more details.
         name(str|None): A name for this layer(optional). If set None, the layer
             will be named automatically.
 
@@ -2387,8 +2410,12 @@ def __check_input(x, y):
         inputs={'X': x,
                 'Y': y},
         outputs={'Out': out},
-        attrs={'transpose_X': transpose_x,
-               'transpose_Y': transpose_y})
+        attrs={
+            'transpose_X': transpose_x,
+            'transpose_Y': transpose_y,
+            'x_num_col_dims': x_num_col_dims,
+            'y_num_col_dims': y_num_col_dims,
+        })
     return out
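Finally, a hedged usage sketch of the extended Python layer (shapes are illustrative; `append_batch_size=False` is assumed to be available on `layers.data` in this branch):

```python
import paddle.v2.fluid as fluid

x = fluid.layers.data(
    name='x', shape=[2, 3, 4, 6], dtype='float32', append_batch_size=False)
y = fluid.layers.data(
    name='y', shape=[6, 5], dtype='float32', append_batch_size=False)

# x is flattened to [2 * 3 * 4, 6]; the result has shape [2, 3, 4, 5].
out = fluid.layers.matmul(x, y, x_num_col_dims=3, y_num_col_dims=1)
```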