Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix dropout and doc (#20124)
Browse files Browse the repository at this point in the history
  • Loading branch information
barry-jin authored Apr 7, 2021
1 parent 798cfe1 commit 4745e0d
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 6 deletions.
4 changes: 1 addition & 3 deletions python/mxnet/ndarray/numpy_extension/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ def pooling(data=None, kernel=None, stride=None, pad=None, pool_type="max",

# pylint: disable=too-many-arguments, unused-argument
@set_module('mxnet.ndarray.numpy_extension')
def dropout(data, p=0.5, mode="training", axes=None, cudnn_off=True, **kwargs):
def dropout(data, p=0.5, mode="training", axes=None, cudnn_off=False, **kwargs):
r"""Applies dropout operation to input array.
- During training, each element of the input is set to zero with probability p.
Expand Down Expand Up @@ -869,10 +869,8 @@ def one_hot(data, depth=None, on_value=1.0, off_value=0.0, dtype="float32"):
>>> npx.one_hot(data, 3)
array([[[0., 1., 0.],
[1., 0., 0.]],
[[0., 1., 0.],
[1., 0., 0.]],
[[0., 0., 1.],
[1., 0., 0.]]], dtype=float64)
"""
Expand Down
4 changes: 1 addition & 3 deletions python/mxnet/numpy_extension/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ def pooling(data=None, kernel=None, stride=None, pad=None, pool_type="max",

# pylint: disable=too-many-arguments, unused-argument
@set_module('mxnet.numpy_extension')
def dropout(data, p=0.5, mode="training", axes=None, cudnn_off=True, **kwargs):
def dropout(data, p=0.5, mode="training", axes=None, cudnn_off=False, **kwargs):
r"""Applies dropout operation to input array.
- During training, each element of the input is set to zero with probability p.
Expand Down Expand Up @@ -829,10 +829,8 @@ def one_hot(data, depth=None, on_value=1.0, off_value=0.0, dtype="float32"):
>>> npx.one_hot(data, 3)
array([[[0., 1., 0.],
[1., 0., 0.]],
[[0., 1., 0.],
[1., 0., 0.]],
[[0., 0., 1.],
[1., 0., 0.]]], dtype=float64)
"""
Expand Down

0 comments on commit 4745e0d

Please sign in to comment.