Skip to content

Commit

Permalink
【Hackathon 5th No.26】为 Paddle 新增 diagonal_scatter API (#6289)
Browse files Browse the repository at this point in the history
* add diagonal scatter docs

* update

* fix: name

* Update docs/api/paddle/diagonal_scatter_cn.rst

Co-authored-by: ooo oo <[email protected]>

* fix review suggestions

* update

* update

* add difference compare

---------

Co-authored-by: ooo oo <[email protected]>
  • Loading branch information
DanGuge and ooooo-create authored Nov 20, 2023
1 parent ab84306 commit f03f205
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/api/paddle/Overview_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ tensor 元素操作相关(如:转置,reshape 等)
" :ref:`paddle.view_as <cn_api_paddle_view_as>` ", "使用 other 的 shape,返回 x 的一个 view Tensor"
" :ref:`paddle.unfold <cn_api_paddle_unfold>` ", "返回 x 的一个 view Tensor。以滑动窗口式提取 x 的值"
" :ref:`paddle.masked_fill <cn_api_paddle_masked_fill>` ", "根据 mask 信息,将 value 中的值填充到 x 中 mask 对应为 True 的位置。"
" :ref:`paddle.diagonal_scatter <cn_api_paddle_diagonal_scatter>` ", "根据给定的轴 axis 和偏移量 offset,将张量 y 的值填充到张量 x 中"
" :ref:`paddle.index_fill <cn_api_paddle_index_fill>` ", "沿着指定轴 axis 将 index 中指定位置的 x 的值填充为 value"

.. _tensor_manipulation_inplace:
Expand Down
1 change: 1 addition & 0 deletions docs/api/paddle/Tensor/Overview_en.rst
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,4 @@ Methods
vander
hypot
hypot_
diagonal_scatter
10 changes: 10 additions & 0 deletions docs/api/paddle/Tensor_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3081,3 +3081,13 @@ masked_fill_(x, mask, value, name=None)
:::::::::

Inplace 版本的 :ref:`cn_api_paddle_masked_fill` API,对输入 `x` 采用 Inplace 策略。

diagonal_scatter(x, y, offset=0, axis1=0, axis2=1, name=None)
:::::::::
根据给定的轴 axis 和偏移量 offset,将张量 y 的值填充到张量 x 中。

返回:张量 y 填充到张量 x 中的结果。

返回类型:Tensor

请参考 :ref:`cn_api_paddle_diagonal_scatter`
37 changes: 37 additions & 0 deletions docs/api/paddle/diagonal_scatter_cn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
.. _cn_api_paddle_diagonal_scatter:

diagonal_scatter
-------------------------------

.. py:function:: paddle.diagonal_scatter(x, y, offset=0, axis1=0, axis2=1, name=None)
根据参数 ``offset``、``axis1``、``axis2``,将张量 ``y`` 填充到张量 ``x`` 的对应位置。

这个函数将会返回一个新的 ``Tensor``。

参数 ``offset`` 确定从指定的二维平面中获取对角线的位置:

- 如果 offset = 0,则嵌入主对角线。
- 如果 offset > 0,则嵌入主对角线右上的对角线。
- 如果 offset < 0,则嵌入主对角线左下的对角线。

参数
::::::::::::

- **x** (Tensor) - 输入张量,张量的维度至少为 2 维,支持 float16、float32、float64、bfloat16、uint8、int8、int16、int32、int64、bool、complex64、complex128 数据类型。
- **y** (Tensor) - 嵌入张量,将会被嵌入到输入张量中,支持 float16、float32、float64、bfloat16、uint8、int8、int16、int32、int64、bool、complex64、complex128 数据类型。
- **offset** (int, 可选) - 从指定的二维平面嵌入对角线的位置,默认值为 0,即主对角线。
- **axis1** (int, 可选) - 对角线的第一个维度,默认值为 0。
- **axis2** (int, 可选) - 对角线的第二个维度,默认值为 1。
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。

返回
::::::::::::
``Tensor``,返回一个根据给定的轴 ``axis`` 和偏移量 ``offset``,将张量 ``y`` 填充到张量 ``x`` 对应位置的新 ``Tensor``。


代码示例
::::::::::::

COPY-FROM: paddle.diagonal_scatter
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
## [ 仅参数名不一致 ] torch.Tensor.diagonal_scatter

### [torch.Tensor.diagonal_scatter](https://pytorch.org/docs/stable/generated/torch.Tensor.diagonal_scatter.html?highlight=diagonal_scatter#torch.Tensor.diagonal_scatter)

```python
torch.Tensor.diagonal_scatter(input, src, offset=0, dim1=0, dim2=1)
```

### [paddle.Tensor.diagonal_scatter](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/Tensor_cn.html#diagonal-scatter-x-y-offset-0-axis1-0-axis2-1-name-none)

```python
paddle.Tensor.diagonal_scatter(x, y, offset=0, axis1=0, axis2=1)
```

两者功能一致且参数用法一致,仅参数名不一致,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
|---------|--------------| -------------------------------------------------- |
| input | x | 输入张量,被嵌入的张量,仅参数名不一致。 |
| src | y | 用于嵌入的张量,仅参数名不一致。 |
| offset | offset | 从指定的二维平面嵌入对角线的位置,默认值为 0,即主对角线。 |
| dim1 | axis1 | 对角线的第一个维度,默认值为 0,仅参数名不一致。 |
| dim2 | axis2 | 对角线的第二个维度,默认值为 1,仅参数名不一致。 |
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
## [ 仅参数名不一致 ] torch.diagonal_scatter

### [torch.diagonal_scatter](https://pytorch.org/docs/stable/generated/torch.diagonal_scatter.html?highlight=diagonal_scatter#torch.diagonal_scatter)

```python
torch.diagonal_scatter(input,
src,
offset=0,
dim1=0,
dim2=1)
```

### [paddle.diagonal_scatter](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/diagonal_scatter_cn.html)

```python
paddle.diagonal_scatter(x,
y,
offset=0,
axis1=0,
axis2=1)
```

两者功能一致且参数用法一致,仅参数名不一致,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
|---------|--------------| -------------------------------------------------- |
| <font color='red'> input </font> | <font color='red'> x </font> | 输入张量,被嵌入的张量,仅参数名不一致。 |
| <font color='red'> src </font> | <font color='red'> y </font> | 用于嵌入的张量,仅参数名不一致。 |
| <font color='red'> offset </font> | <font color='red'> offset </font> | 从指定的二维平面嵌入对角线的位置,默认值为 0,即主对角线。 |
| <font color='red'> dim1 </font> | <font color='red'> axis1 </font> | 对角线的第一个维度,默认值为 0,仅参数名不一致。 |
| <font color='red'> dim2 </font> | <font color='red'> axis2 </font> | 对角线的第二个维度,默认值为 1,仅参数名不一致。 |
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,10 @@
| REFERENCE-MAPPING-ITEM(`torch.frexp`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/ops/torch.frexp.md) |
| REFERENCE-MAPPING-ITEM(`torch.nanmean`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/ops/torch.nanmean.md) |
| REFERENCE-MAPPING-ITEM(`torch.take_along_dim`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/ops/torch.take_along_dim.md) |
| REFERENCE-MAPPING-ITEM(`torch.diagonal_scatter`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/ops/torch.diagonal_scatter.md) |
| NOT-IMPLEMENTED-ITEM(`torch.geqrf`, https://pytorch.org/docs/stable/generated/torch.geqrf.html?highlight=geqrf#torch.geqrf) |
| NOT-IMPLEMENTED-ITEM(`torch.bitwise_right_shift`, https://pytorch.org/docs/stable/generated/torch.bitwise_right_shift.html#torch.bitwise_right_shift) |
| NOT-IMPLEMENTED-ITEM(`torch.is_conj`, https://pytorch.org/docs/stable/generated/torch.is_conj.html#torch.is_conj) |
| NOT-IMPLEMENTED-ITEM(`torch.diagonal_scatter`, https://pytorch.org/docs/stable/generated/torch.diagonal_scatter.html#torch.diagonal_scatter) |
| NOT-IMPLEMENTED-ITEM(`torch.select_scatter`, https://pytorch.org/docs/stable/generated/torch.select_scatter.html#torch.select_scatter) |
| NOT-IMPLEMENTED-ITEM(`torch.slice_scatter`, https://pytorch.org/docs/stable/generated/torch.slice_scatter.html#torch.slice_scatter) |
| NOT-IMPLEMENTED-ITEM(`torch.scatter_reduce`, https://pytorch.org/docs/stable/generated/torch.scatter_reduce.html#torch.scatter_reduce) |
Expand Down Expand Up @@ -790,7 +790,7 @@
| 167 | [torch.Tensor.any](https://pytorch.org/docs/stable/generated/torch.Tensor.any.html?highlight=torch+tensor+any#torch.Tensor.any) | [paddle.Tensor.any](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/Tensor_cn.html#any-axis-none-keepdim-false-name-none) | 功能一致, 参数不一致 , [差异对比](https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/Tensor/torch.Tensor.any.md) |
| 168 | [torch.Tensor.bitwise_right_shift](https://pytorch.org/docs/stable/generated/torch.Tensor.bitwise_right_shift.html#torch.Tensor.bitwise_right_shift) | | 功能缺失 |
| 169 | [torch.Tensor.is_conj](https://pytorch.org/docs/stable/generated/torch.Tensor.is_conj.html#torch.Tensor.is_conj) | | 功能缺失 |
| 170 | [torch.Tensor.diagonal_scatter](https://pytorch.org/docs/stable/generated/torch.Tensor.diagonal_scatter.html#torch.Tensor.diagonal_scatter) | | 功能缺失 |
| 170 | [torch.Tensor.diagonal_scatter](https://pytorch.org/docs/stable/generated/torch.Tensor.diagonal_scatter.html#torch.Tensor.diagonal_scatter) | [paddle.Tensor.diagonal_scatter](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/Tensor_cn.html#diagonal-scatter-x-y-offset-0-axis1-0-axis2-1-name-none) | 功能完全一致,仅参数名不一致 [差异对比](https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/Tensor/torch.Tensor.diagonal_scatter.md) |
| 171 | [torch.Tensor.scatter_reduce](https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce.html#torch.Tensor.scatter_reduce) | | 功能缺失 |
| 172 | [torch.Tensor.select_scatter](https://pytorch.org/docs/stable/generated/torch.Tensor.select_scatter.html#torch.Tensor.select_scatter) | | 功能缺失 |
| 173 | [torch.Tensor.slice_scatter](https://pytorch.org/docs/stable/generated/torch.Tensor.slice_scatter.html#torch.Tensor.slice_scatter) | | 功能缺失 |
Expand Down Expand Up @@ -1097,9 +1097,9 @@
| REFERENCE-MAPPING-ITEM(`torch.Tensor.less`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/Tensor/torch.Tensor.less.md) |
| REFERENCE-MAPPING-ITEM(`torch.Tensor.all`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/Tensor/torch.Tensor.all.md) |
| REFERENCE-MAPPING-ITEM(`torch.Tensor.any`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/Tensor/torch.Tensor.any.md) |
| REFERENCE-MAPPING-ITEM(`torch.Tensor.diagonal_scatter`, https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/Tensor/torch.Tensor.diagonal_scatter.md) |
| NOT-IMPLEMENTED-ITEM(`torch.Tensor.bitwise_right_shift`, https://pytorch.org/docs/stable/generated/torch.Tensor.bitwise_right_shift.html#torch.Tensor.bitwise_right_shift) |
| NOT-IMPLEMENTED-ITEM(`torch.Tensor.is_conj`, https://pytorch.org/docs/stable/generated/torch.Tensor.is_conj.html#torch.Tensor.is_conj) |
| NOT-IMPLEMENTED-ITEM(`torch.Tensor.diagonal_scatter`, https://pytorch.org/docs/stable/generated/torch.Tensor.diagonal_scatter.html#torch.Tensor.diagonal_scatter) |
| NOT-IMPLEMENTED-ITEM(`torch.Tensor.scatter_reduce`, https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce.html#torch.Tensor.scatter_reduce) |
| NOT-IMPLEMENTED-ITEM(`torch.Tensor.select_scatter`, https://pytorch.org/docs/stable/generated/torch.Tensor.select_scatter.html#torch.Tensor.select_scatter) |
| NOT-IMPLEMENTED-ITEM(`torch.Tensor.slice_scatter`, https://pytorch.org/docs/stable/generated/torch.Tensor.slice_scatter.html#torch.Tensor.slice_scatter) |
Expand Down

0 comments on commit f03f205

Please sign in to comment.