diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index a3fc8e50fdc67a..23e47640d37614 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2199,7 +2199,7 @@ def tensor_split(x, indices_or_sections, axis=0, name=None): Split the input tensor into multiple sub-Tensors along ``axis``, allowing not being of equal size. Args: - x (Tensor): A Tensor whose dimension must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, complex64, complex128 or int64. + x (Tensor): A Tensor whose dimension must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int8, int32 or int64. indices_or_sections (int|list|tuple): If ``indices_or_sections`` is an int ``n``, ``x`` is split into ``n`` sections along ``axis``. If ``x`` is divisible by ``n``, each section will be ``x.shape[axis] / n``. If ``x`` is not divisible by ``n``, the first ``int(x.shape[axis] % n)`` sections will have size ``int(x.shape[axis] / n) + 1``, and the rest will be ``int(x.shape[axis] / n).