Merge pull request PaddlePaddle#31 from AIPG/wojtuss/bilstm
bilstm topology added to text_classification model (Senta)
Uss, Wojciech authored and GitHub Enterprise committed Aug 22, 2018
2 parents 9a4a7ef + 14fee20 commit 8f40794
Showing 10 changed files with 43,218 additions and 12 deletions.
200 changes: 200 additions & 0 deletions fluid/text_classification/data/test_data/corpus.test

Large diffs are not rendered by default.

32,896 changes: 32,896 additions & 0 deletions fluid/text_classification/data/train.vocab

Large diffs are not rendered by default.

10,000 changes: 10,000 additions & 0 deletions fluid/text_classification/data/train_data/corpus.train

Large diffs are not rendered by default.

14 changes: 9 additions & 5 deletions fluid/text_classification/infer.py
@@ -1,7 +1,4 @@
import sys
import time
import unittest
import contextlib
import numpy as np
import argparse

@@ -17,7 +14,13 @@ def parse_args():
'--batch_size',
type=int,
default=128,
help='The size of a batch. (default: %(default)d, usually: 128 for "bow" and "gru", 4 for "cnn" and "lstm")')
help='The size of a batch. (default: %(default)d, usually: 128 for "bow" and "gru", 4 for "cnn", "lstm" and "bilstm").')
parser.add_argument(
"--dataset",
type=str,
default='imdb',
choices=['imdb', 'data'],
help="Dataset to be used: 'imdb' or 'data' (from 'data' subdirectory).")
parser.add_argument(
'--device',
type=str,
@@ -77,7 +80,7 @@ def infer(args):
wpses = [0] * total_passes
acces = [0] * total_passes
word_dict, train_reader, test_reader = utils.prepare_data(
"imdb", self_dict=False, batch_size=args.batch_size,
args.dataset, self_dict=False, batch_size=args.batch_size,
buf_size=50000)
pass_acc = 0.0
pass_data_len = 0
@@ -100,6 +103,7 @@ def infer(args):
fetch_list=fetch_targets,
return_numpy=True)
batch_time = time.time() - start
# TODO: add class accuracy measurement as in Senta
word_count = len([w for d in data for w in d[0]])
batch_times[pass_id] += batch_time
word_counts[pass_id] += word_count
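The counters touched in this hunk (batch_times, word_counts, and the wpses/acces lists above) feed the throughput report. A minimal sketch, with made-up numbers, of the words-per-second figure they suggest; the aggregation shown here is an assumption, not code from this commit:

```
# Sketch only: per-pass words-per-second from the accumulators seen in infer.py.
batch_times = [1.8, 1.7]      # example: total inference time per pass (seconds)
word_counts = [25600, 25600]  # example: total words processed per pass

wpses = [wc / bt for wc, bt in zip(word_counts, batch_times)]
for pass_id, wps in enumerate(wpses):
    print("pass %d: %.1f words/s" % (pass_id, wps))
```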
49 changes: 49 additions & 0 deletions fluid/text_classification/nets.py
@@ -93,6 +93,55 @@ def lstm_net(data,
return avg_cost, acc, prediction


def bilstm_net(data,
label,
dict_dim,
emb_dim=128,
hid_dim=128,
hid_dim2=96,
class_dim=2,
emb_lr=30.0):
"""
Bi-LSTM net
"""
# embedding layer
emb = fluid.layers.embedding(
input=data,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(learning_rate=emb_lr))

# bi-lstm layer
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)

rfc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)

lstm_h, c = fluid.layers.dynamic_lstm(
input=fc0, size=hid_dim * 4, is_reverse=False)

rlstm_h, c = fluid.layers.dynamic_lstm(
input=rfc0, size=hid_dim * 4, is_reverse=True)

# extract the last step of each sequence
lstm_last = fluid.layers.sequence_last_step(input=lstm_h)
rlstm_last = fluid.layers.sequence_last_step(input=rlstm_h)

lstm_last_tanh = fluid.layers.tanh(lstm_last)
rlstm_last_tanh = fluid.layers.tanh(rlstm_last)

# concat layer
lstm_concat = fluid.layers.concat(input=[lstm_last, rlstm_last], axis=1)

# fully connected layer
fc1 = fluid.layers.fc(input=lstm_concat, size=hid_dim2, act='tanh')
# softmax layer
prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost)
acc = fluid.layers.accuracy(input=prediction, label=label)

return avg_cost, acc, prediction


def gru_net(data,
label,
dict_dim,
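The new bilstm_net runs a forward and a reverse dynamic_lstm over the same embedding, takes the last step of each direction, concatenates the two, and classifies through a tanh fc plus softmax fc. A minimal usage sketch (not part of the commit), assuming the 'words'/'label' layout used elsewhere in this model; the dictionary size and the optimizer are illustrative assumptions:

```
import paddle.fluid as fluid
from nets import bilstm_net

dict_dim = 32897  # assumption: vocabulary size + 1 so the <unk> id fits the embedding
# variable-length sequence of word ids (LoD level 1) and one integer label per sample
data = fluid.layers.data(name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")

avg_cost, acc, prediction = bilstm_net(data, label, dict_dim)

# 0.002 mirrors the learning rate registered for 'bilstm' in train.py;
# Adagrad is an assumption here, not necessarily the optimizer train.py uses.
optimizer = fluid.optimizer.Adagrad(learning_rate=0.002)
optimizer.minimize(avg_cost)
```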
2 changes: 2 additions & 0 deletions fluid/text_classification/scripts/README.md
@@ -24,6 +24,7 @@ train_bow.sh
train_cnn.sh
train_gru.sh
train_lstm.sh
train_bilstm.sh
```

## Inference
@@ -35,4 +36,5 @@ infer_bow.sh
infer_cnn.sh
infer_gru.sh
infer_lstm.sh
infer_bilstm.sh
```
7 changes: 7 additions & 0 deletions fluid/text_classification/scripts/infer_bilstm.sh
@@ -0,0 +1,7 @@
#!/bin/bash
time python ../infer.py \
--device CPU \
--model_path bilstm_model/epoch0 \
--num_passes 100 \
--profile

7 changes: 7 additions & 0 deletions fluid/text_classification/scripts/train_bilstm.sh
@@ -0,0 +1,7 @@
#!/bin/bash
time python ../train.py \
--device CPU \
--model_save_dir bilstm_model \
--num_passes 1 \
bilstm

18 changes: 13 additions & 5 deletions fluid/text_classification/train.py
@@ -10,24 +10,32 @@
from nets import bow_net
from nets import cnn_net
from nets import lstm_net
from nets import bilstm_net
from nets import gru_net

nets = {'bow': bow_net, 'cnn': cnn_net, 'lstm': lstm_net, 'gru': gru_net}
nets = {'bow': bow_net, 'cnn': cnn_net, 'lstm': lstm_net, 'bilstm': bilstm_net,
'gru': gru_net}
# learning rates
lrs = {'bow': 0.002, 'cnn': 0.01, 'lstm': 0.05, 'gru': 0.05}
lrs = {'bow': 0.002, 'cnn': 0.01, 'lstm': 0.05, 'bilstm': 0.002, 'gru': 0.05}

def parse_args():
parser = argparse.ArgumentParser("Run inference.")
parser.add_argument(
'topology',
type=str,
choices=['bow', 'cnn', 'lstm', 'gru'],
choices=['bow', 'cnn', 'lstm', 'bilstm', 'gru'],
help='Topology used for the model (bow/cnn/lstm/bilstm/gru).')
parser.add_argument(
"--dataset",
type=str,
default='imdb',
choices=['imdb', 'data'],
help="Dataset to be used: 'imdb' or 'data' (from 'data' subdirectory).")
parser.add_argument(
'--batch_size',
type=int,
default=128,
help='The size of a batch. (default: %(default)d, usually: 128 for "bow" and "gru", 4 for "cnn" and "lstm")')
help='The size of a batch. (default: %(default)d, usually: 128 for "bow" and "gru", 4 for "cnn", "lstm" and "bilstm").')
parser.add_argument(
'--device',
type=str,
@@ -122,7 +130,7 @@ def train(train_reader,

def train_net(args):
word_dict, train_reader, test_reader = utils.prepare_data(
"imdb", self_dict=False, batch_size=128, buf_size=50000)
args.dataset, self_dict=False, batch_size=128, buf_size=50000)

train(
train_reader,
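After this change the positional topology argument and the new --dataset flag jointly pick the network, its learning rate, and the data source. A simplified sketch of that selection (the actual training loop in train.py is omitted):

```
# Sketch only: what 'python train.py --dataset data bilstm' roughly resolves to.
import utils
from nets import bilstm_net

nets = {'bilstm': bilstm_net}   # subset of the dict defined in train.py
lrs = {'bilstm': 0.002}

topology, dataset = 'bilstm', 'data'
net_fn, lr = nets[topology], lrs[topology]
word_dict, train_reader, test_reader = utils.prepare_data(
    dataset, self_dict=False, batch_size=128, buf_size=50000)
```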
37 changes: 35 additions & 2 deletions fluid/text_classification/utils.py
@@ -1,6 +1,8 @@
import sys
import time
import numpy as np
import random
import os

import paddle
import paddle.fluid as fluid
@@ -48,18 +50,43 @@ def data2tensor(data, place):
return {"words": input_seq, "label": y_data}


def data_reader(file_path, word_dict, is_shuffle=True):
"""
Convert word sequences into word-id slots
"""
unk_id = len(word_dict)
all_data = []
with open(file_path, "r") as fin:
for line in fin:
cols = line.strip().split("\t")
label = int(cols[0])
wids = [word_dict[x] if x in word_dict else unk_id
for x in cols[1].split(" ")]
all_data.append((wids, label))
if is_shuffle:
random.shuffle(all_data)

def reader():
for doc, label in all_data:
yield doc, label
return reader


def prepare_data(data_type="imdb",
self_dict=False,
batch_size=128,
buf_size=50000):
"""
prepare data
"""
script_path = os.path.dirname(__file__)
if self_dict:
word_dict = load_vocab(data_type + ".vocab")
else:
if data_type == "imdb":
word_dict = paddle.dataset.imdb.word_dict()
elif data_type == "data":
word_dict = load_vocab(script_path + "/data/train.vocab")
else:
raise RuntimeError("No such dataset")

@@ -68,12 +95,18 @@ def prepare_data(data_type="imdb",
paddle.reader.shuffle(
paddle.dataset.imdb.train(word_dict), buf_size=buf_size),
batch_size=batch_size)

test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imdb.test(word_dict), buf_size=buf_size),
batch_size=batch_size)
elif data_type == "data":
train_reader = paddle.batch(
data_reader(script_path + "/data/train_data/corpus.train", word_dict, True),
batch_size=batch_size)
test_reader = paddle.batch(
data_reader(script_path + "/data/test_data/corpus.test", word_dict, False),
batch_size=batch_size)
else:
raise RuntimeError("no such dataset")
raise RuntimeError("No such dataset")

return word_dict, train_reader, test_reader

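data_reader parses each tab-separated line of the new corpus files ("label<TAB>space-separated words") into a (word-id list, label) pair, mapping unknown words to len(word_dict). A minimal sketch of reading the local test corpus directly, assuming it is run from the fluid/text_classification directory so the relative paths resolve:

```
# Sketch only: iterating the new local corpus through data_reader.
import utils

word_dict = utils.load_vocab("data/train.vocab")   # load_vocab is defined in utils.py
reader = utils.data_reader("data/test_data/corpus.test", word_dict, is_shuffle=False)

for word_ids, label in reader():
    print(label, word_ids[:10])   # the label and the first ten word ids of one sample
    break
```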