Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] Add AWDRNN Pretrained model test #20018

Merged
merged 8 commits into from
Mar 19, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/mxnet/contrib/onnx/mx2onnx/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def export_model(sym, params, in_shapes=None, in_types=np.float32,
if not isinstance(in_types, list):
in_types = [in_types for _ in range(len(in_shapes))]
in_types_t = [mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(i_t)] for i_t in in_types]
assert len(in_types) == len(in_shapes), "The lengths of in_types and in_shapes must equal"
# if input parameters are strings(file paths), load files and create symbol parameter objects
if isinstance(sym, string_types) and isinstance(params, string_types):
logging.info("Converting json and weight file to sym and params")
Expand Down
61 changes: 61 additions & 0 deletions tests/python-pytest/onnx/test_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,3 +868,64 @@ def test_dynamic_shape_bert_inference_onnxruntime(tmp_path, model):
finally:
shutil.rmtree(tmp_path)


@with_seed()
@pytest.mark.parametrize('model_name', [('awd_lstm_lm_600', 600), ('awd_lstm_lm_1150', 1150)])
@pytest.mark.parametrize('seq_length', [16, 128, 256])
def test_awd_rnn_lstm_pretrained_inference_onnxruntime(tmp_path, model_name, seq_length):
    """Check that a pretrained AWD LSTM model gives matching results in ONNX Runtime.

    The model is fetched through gluonnlp, run once natively, exported to
    symbol/params files, converted to ONNX, and run again through
    onnxruntime. The sequence output and every recurrent state returned by
    the native run are compared against the ONNX Runtime predictions.

    Parameters
    ----------
    tmp_path : path-like
        Scratch directory for the exported files; removed when the test ends.
    model_name : tuple
        ``(gluonnlp model name, hidden size)``. Only the name (index 0) is
        used by this test.
    seq_length : int
        Length of the random input sequence.
    """
    try:
        import gluonnlp as nlp
        ctx = mx.cpu()
        dataset = 'wikitext-2'
        model, _ = nlp.model.get_model(
            name=model_name[0],
            ctx=ctx,
            pretrained=True,
            dataset_name=dataset,
            dropout=0)
        model.hybridize()

        batch = 2
        # 33278 is presumably the vocabulary size of the pretrained model -- TODO confirm.
        inputs = mx.nd.random.randint(0, 33278, shape=(seq_length, batch),
                                      ctx=ctx).astype('float32')
        begin_state = model.begin_state(func=mx.nd.random.uniform, low=0, high=1,
                                        batch_size=batch, dtype='float32', ctx=ctx)
        out, out_state = model(inputs, begin_state)

        prefix = "%s/awd_lstm" % tmp_path
        model.export(prefix)
        sym_file = "%s-symbol.json" % prefix
        params_file = "%s-0000.params" % prefix
        onnx_file = "%s.onnx" % prefix

        # Flatten the nested per-layer state lists once, so the shapes, the
        # feed tensors and the expected outputs all derive from one ordering.
        in_states = [state for layer_states in begin_state for state in layer_states]
        expected_states = [state for layer_states in out_state for state in layer_states]

        input_shapes = [(seq_length, batch)] + [np.shape(state) for state in in_states]
        input_types = [np.float32] * len(input_shapes)
        mx.contrib.onnx.export_model(sym_file, params_file, input_shapes,
                                     input_types, onnx_file, verbose=True)

        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess = onnxruntime.InferenceSession(onnx_file, sess_options)

        in_tensors = [inputs] + in_states
        sess_inputs = sess.get_inputs()
        # zip() would silently drop unmatched entries, so check the count first.
        assert len(sess_inputs) == len(in_tensors), \
            "ONNX model expects %d inputs but %d tensors were prepared" % (
                len(sess_inputs), len(in_tensors))
        input_dict = {meta.name: tensor.asnumpy()
                      for meta, tensor in zip(sess_inputs, in_tensors)}
        pred = sess.run(None, input_dict)

        # The session returns the recurrent states first and the sequence
        # output right after them.
        assert_almost_equal(out, pred[len(expected_states)])
        for expected, actual in zip(expected_states, pred):
            assert_almost_equal(expected, actual)

    finally:
        shutil.rmtree(tmp_path)