[mkldnn-v1.0] Skip flaky test for unidirectional rnn_relu
Skip `test_rnnrelu_sym` and add an issue-tracking message

Add a return statement

Revert `test_rnnrelu_sym` to its original form
xziya committed Oct 22, 2019
1 parent d109033 commit 192f7a9
Showing 1 changed file with 11 additions and 13 deletions.
tests/python/unittest/test_operator.py: 11 additions & 13 deletions (24 changes)
```diff
@@ -38,6 +38,7 @@
 def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req, rtol=1e-2, atol=1e-4):
     if default_context().device_type == 'cpu':
         # NOTE(zixuanweeei): Currently, we don't add `add` requests support on fused mkl-dnn rnn operator.
+        # We tracked this issue by https://github.com/apache/incubator-mxnet/issues/16578
         if isinstance(grad_req, dict) and 'add' in grad_req.values():
             print("Skip the test when requiring `add` operation against gradients on CPU context.")
             return
```
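The guard above only fires for the dict form of `grad_req`. In MXNet's executor API, `grad_req` may be a single string applied to every argument or a dict mapping argument names to `'write'`, `'add'`, or `'null'`. A minimal standalone sketch of the condition, using hypothetical argument names (not taken from the test file):

```python
# Hypothetical per-argument gradient requests; only the dict form is
# inspected by the guard added in this commit.
grad_req = {'l0_i2h_weight': 'add', 'l0_i2h_bias': 'null'}

# Mirrors the skip in check_rnn_consistency: the fused MKL-DNN RNN operator
# does not support accumulating ('add') into existing gradient buffers on
# CPU, so any request map containing 'add' bails out of the check.
if isinstance(grad_req, dict) and 'add' in grad_req.values():
    print("Skip the test when requiring `add` operation against gradients on CPU context.")
```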
```diff
@@ -257,20 +258,17 @@ def test_rnntanh_bidirectional():
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnnrelu_sym():
-    Ts = [1, 5]
-    Ns = [1, 32]
-    Is = [32, 128, 512]
-    Hs = [32, 128, 512]
-    for T, N, I, H in itertools.product(Ts, Ns, Is, Hs):
-        fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_relu', get_next_state=True, prefix='')
-        stack = mx.rnn.SequentialRNNCell()
-        stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l0_'))
-        stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l1_'))
-        stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l2_'))
+    T, N, I, H = 5, 32, 200, 200
 
-        check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-        check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-        check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_relu', get_next_state=True, prefix='')
+    stack = mx.rnn.SequentialRNNCell()
+    stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l0_'))
+    stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l1_'))
+    stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l2_'))
+
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
```
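For scale, the sweep this commit removes covered 2 × 2 × 3 × 3 = 36 (T, N, I, H) configurations, each checked under three gradient-request modes, i.e. 108 consistency checks per run; the reverted version performs exactly 3 checks on the single configuration (5, 32, 200, 200). A quick sketch of that count:

```python
import itertools

# Size of the parameter sweep removed by this commit: each (T, N, I, H)
# combination was checked under the 'write', 'add', and 'null' modes.
Ts, Ns, Is, Hs = [1, 5], [1, 32], [32, 128, 512], [32, 128, 512]
num_configs = len(list(itertools.product(Ts, Ns, Is, Hs)))  # 2 * 2 * 3 * 3 = 36
print(num_configs * 3)  # 108 consistency checks per test run
```

To exercise just the reverted test locally, something like `nosetests -v tests/python/unittest/test_operator.py:test_rnnrelu_sym` should work, assuming the nose-based runner MXNet's Python unit tests used at the time.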
