Merge pull request #689 from PaddlePaddle/jzy2
Add NER demo for Fluid.
Showing 6 changed files with 468 additions and 0 deletions.
README.md (new file, 120 lines):
# Named Entity Recognition

The following is a brief directory layout for this example:

```text
.
├── data              # data this example depends on, fetched externally
├── network_conf.py   # model definition
├── reader.py         # data reading interface, fetched externally
├── README.md         # this document
├── train.py          # training script
├── infer.py          # inference script
├── utils.py          # common helper functions, fetched externally
└── utils_extend.py   # extensions to utils.py
```

## Introduction and Model Details

The PaddlePaddle v2 example [Named Entity Recognition](https://github.com/PaddlePaddle/models/blob/develop/sequence_tagging_for_ner/README.md) describes the NER task in detail, so it is not re-introduced here.
The model keeps the v2 structure; the only difference is that the original plain RNN is replaced with an LSTM, as sketched below.

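For orientation, here is the core of that swap: a fragment mirroring the `fluid.layers.dynamic_lstm` call in `network_conf.py` below (`hidden` and `hidden_dim` come from that file; this is a sketch, not a standalone program):

```python
# Each recurrent layer is an LSTM over a variable-length (LoD) sequence;
# `hidden` is the FC-projected feature tensor feeding the layer.
rnn = fluid.layers.dynamic_lstm(
    input=hidden,
    size=hidden_dim,
    candidate_activation='relu',
    gate_activation='sigmoid',
    cell_activation='sigmoid',
    is_reverse=False)  # True for the backward direction of the BiLSTM
```
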
## Getting the Data

Follow the data-acquisition steps of the PaddlePaddle v2 [Named Entity Recognition](https://github.com/PaddlePaddle/models/blob/develop/sequence_tagging_for_ner/README.md) example: copy that example's data folder into this directory, then run its download.sh script to fetch the training and test data.

## Getting the Shared Scripts

From the PaddlePaddle v2 [Named Entity Recognition](https://github.com/PaddlePaddle/models/blob/develop/sequence_tagging_for_ner/README.md) example, copy the data-reading file [reader.py](https://github.com/PaddlePaddle/models/blob/develop/sequence_tagging_for_ner/reader.py) and [utils.py](https://github.com/PaddlePaddle/models/blob/develop/sequence_tagging_for_ner/utils.py), which provides shared helpers such as dictionary loading, into this directory. This example uses both scripts.

## Training

1. Run `sh data/download.sh`.
2. Edit the `main` function in `train.py` to point at your data:

```python
main(
    train_data_file="data/train",
    test_data_file="data/test",
    vocab_file="data/vocab.txt",
    target_file="data/target.txt",
    emb_file="data/wordVectors.txt",
    model_save_dir="models",
    num_passes=1000,
    use_gpu=False,
    parallel=False)
```

3. Run `python train.py`. **Note: run as-is, it uses the sample data; replace it with your real labeled data.** Training prints a log like the one below (a quick check of the `F1_score` column follows the log):

```text
Pass 127, Batch 9525, Cost 4.0867705, Precision 0.3954984, Recall 0.37846154, F1_score0.38679245
Pass 127, Batch 9530, Cost 3.137265, Precision 0.42971888, Recall 0.38351256, F1_score0.405303
Pass 127, Batch 9535, Cost 3.6240938, Precision 0.4272152, Recall 0.41795665, F1_score0.4225352
Pass 127, Batch 9540, Cost 3.5352352, Precision 0.48464164, Recall 0.4536741, F1_score0.46864685
Pass 127, Batch 9545, Cost 4.1130385, Precision 0.40131578, Recall 0.3836478, F1_score0.39228293
Pass 127, Batch 9550, Cost 3.6826708, Precision 0.43333334, Recall 0.43730888, F1_score0.43531203
Pass 127, Batch 9555, Cost 3.6363933, Precision 0.42424244, Recall 0.3962264, F1_score0.4097561
Pass 127, Batch 9560, Cost 3.6101768, Precision 0.51363635, Recall 0.353125, F1_score0.41851854
Pass 127, Batch 9565, Cost 3.5935276, Precision 0.5152439, Recall 0.5, F1_score0.5075075
Pass 127, Batch 9570, Cost 3.4987144, Precision 0.5, Recall 0.4330218, F1_score0.46410686
Pass 127, Batch 9575, Cost 3.4659843, Precision 0.39864865, Recall 0.38064516, F1_score0.38943896
Pass 127, Batch 9580, Cost 3.1702557, Precision 0.5, Recall 0.4490446, F1_score0.47315437
Pass 127, Batch 9585, Cost 3.1587276, Precision 0.49377593, Recall 0.4089347, F1_score0.4473684
Pass 127, Batch 9590, Cost 3.5043538, Precision 0.4556962, Recall 0.4600639, F1_score0.45786962
Pass 127, Batch 9595, Cost 2.981989, Precision 0.44981414, Recall 0.45149255, F1_score0.4506518
[TrainSet] pass_id:127 pass_precision:[0.46023396] pass_recall:[0.43197003] pass_f1_score:[0.44565433]
[TestSet] pass_id:127 pass_precision:[0.4708409] pass_recall:[0.47971722] pass_f1_score:[0.4752376]
```

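The `F1_score` column is the harmonic mean of the per-batch precision and recall; a quick sanity check against the first log line above:

```python
def f1_score(precision, recall):
    # Harmonic mean of precision and recall.
    return 2 * precision * recall / (precision + recall)

# First log line: Precision 0.3954984, Recall 0.37846154
print(f1_score(0.3954984, 0.37846154))  # ~0.38679, matching F1_score
```
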
## Inference

1. Edit the `infer` function in [infer.py](./infer.py) to specify the path of the model to test, the test data, the vocabulary file, and the tag (target) file. The default arguments are:

```python
infer(
    model_path="models/params_pass_0",
    batch_size=6,
    test_data_file="data/test",
    vocab_file="data/vocab.txt",
    target_file="data/target.txt",
    use_gpu=False)
```

2. Run `python infer.py` in a terminal to start inference. You will see predictions like the following (a sample from a model trained for 70 passes):

```text
leicestershire B-ORG B-LOC
extended O O
their O O
first O O
innings O O
by O O
DGDG O O
runs O O
before O O
being O O
bowled O O
out O O
for O O
296 O O
with O O
england B-LOC B-LOC
discard O O
andy B-PER B-PER
caddick I-PER I-PER
taking O O
three O O
for O O
DGDG O O
. O O
```

The output has three tab-separated ("\t") columns: the input word, the gold label, and the predicted label. Input sequences are separated by blank lines. A small sketch of parsing this format back into sentences follows.

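The sketch below assumes the predictions were redirected to a file; the path `pred.txt` is hypothetical:

```python
def load_predictions(path):
    """Parse tab-separated (word, gold, predicted) lines; blank lines
    separate sentences."""
    sentences, current = [], []
    with open(path) as f:
        for raw in f:
            line = raw.rstrip("\n")
            if not line:  # blank line: sentence boundary
                if current:
                    sentences.append(current)
                current = []
                continue
            word, gold, pred = line.split("\t")
            current.append((word, gold, pred))
    if current:  # file may not end with a blank line
        sentences.append(current)
    return sentences

sentences = load_predictions("pred.txt")  # hypothetical output file
```
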
## Example Results

<p align="center">
<img src="imgs/convergence_curve.png" width="80%" align="center"/><br/>
Figure 1. Learning curve: the horizontal axis is the number of training passes, the vertical axis the F1 score.
</p>

Binary image file (the convergence curve referenced above); not rendered in the diff view.

infer.py (new file, 71 lines):
import numpy as np

import paddle.fluid as fluid
import paddle.v2 as paddle

from network_conf import ner_net
import reader
from utils import load_dict, load_reverse_dict
from utils_extend import to_lodtensor


def infer(model_path, batch_size, test_data_file, vocab_file, target_file,
          use_gpu):
    """
    Use the model saved under model_path to predict the test data.
    The (word, gold tag, predicted tag) triples are printed to stdout;
    returns nothing.
    """
    word_dict = load_dict(vocab_file)
    word_reverse_dict = load_reverse_dict(vocab_file)

    label_dict = load_dict(target_file)
    label_reverse_dict = load_reverse_dict(target_file)

    test_data = paddle.batch(
        reader.data_reader(test_data_file, word_dict, label_dict),
        batch_size=batch_size)
    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    inference_scope = fluid.core.Scope()
    with fluid.scope_guard(inference_scope):
        [inference_program, feed_target_names,
         fetch_targets] = fluid.io.load_inference_model(model_path, exe)
        for data in test_data():
            word = to_lodtensor(map(lambda x: x[0], data), place)
            mark = to_lodtensor(map(lambda x: x[1], data), place)
            target = to_lodtensor(map(lambda x: x[2], data), place)
            crf_decode = exe.run(
                inference_program,
                feed={"word": word,
                      "mark": mark,
                      "target": target},
                fetch_list=fetch_targets,
                return_numpy=False)
            # Level-0 LoD offsets mark the sentence boundaries in the
            # flattened CRF decoding result.
            lod_info = (crf_decode[0].lod())[0]
            np_data = np.array(crf_decode[0])
            assert len(data) == len(lod_info) - 1
            for sen_index in xrange(len(data)):
                assert len(data[sen_index][0]) == lod_info[
                    sen_index + 1] - lod_info[sen_index]
                word_index = 0
                for tag_index in xrange(lod_info[sen_index],
                                        lod_info[sen_index + 1]):
                    word = word_reverse_dict[data[sen_index][0][word_index]]
                    gold_tag = label_reverse_dict[data[sen_index][2][
                        word_index]]
                    tag = label_reverse_dict[np_data[tag_index][0]]
                    print word + "\t" + gold_tag + "\t" + tag
                    word_index += 1
                # Blank line between output sequences.
                print ""


if __name__ == "__main__":
    infer(
        model_path="models/params_pass_0",
        batch_size=6,
        test_data_file="data/test",
        vocab_file="data/vocab.txt",
        target_file="data/target.txt",
        use_gpu=False)
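
A note on the decoding loop above: Fluid returns the CRF decoding as a flattened tensor whose level-0 LoD offsets mark sentence boundaries, so sentence `i` occupies rows `lod[i]:lod[i+1]`. A pure-Python illustration with made-up offsets:

```python
# Hypothetical level-0 LoD offsets for a batch of 3 sentences of
# lengths 4, 2, and 5: offsets are cumulative and start at 0.
lod = [0, 4, 6, 11]
flat_rows = list(range(11))  # stands in for the flattened CRF output

for i in range(len(lod) - 1):
    print("sentence %d -> rows %s" % (i, flat_rows[lod[i]:lod[i + 1]]))
# sentence 0 -> rows [0, 1, 2, 3]
# sentence 1 -> rows [4, 5]
# sentence 2 -> rows [6, 7, 8, 9, 10]
```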

network_conf.py (new file, 127 lines):
import math

import paddle.fluid as fluid
from paddle.fluid.initializer import NormalInitializer

from utils import logger, load_dict, get_embedding


def ner_net(word_dict_len, label_dict_len, parallel, stack_num=2):
    # "mark" is a binary per-word feature (the capitalization mark of the
    # v2 NER demo), hence a 2-entry embedding table.
    mark_dict_len = 2
    word_dim = 50
    mark_dim = 5
    hidden_dim = 300
    IS_SPARSE = True
    embedding_name = 'emb'

    def _net_conf(word, mark, target):
        word_embedding = fluid.layers.embedding(
            input=word,
            size=[word_dict_len, word_dim],
            dtype='float32',
            is_sparse=IS_SPARSE,
            param_attr=fluid.ParamAttr(
                name=embedding_name, trainable=False))

        mark_embedding = fluid.layers.embedding(
            input=mark,
            size=[mark_dict_len, mark_dim],
            dtype='float32',
            is_sparse=IS_SPARSE)

        word_caps_vector = fluid.layers.concat(
            input=[word_embedding, mark_embedding], axis=1)
        mix_hidden_lr = 1

        rnn_para_attr = fluid.ParamAttr(
            initializer=NormalInitializer(
                loc=0.0, scale=0.0),
            learning_rate=mix_hidden_lr)
        hidden_para_attr = fluid.ParamAttr(
            initializer=NormalInitializer(
                loc=0.0, scale=(1. / math.sqrt(hidden_dim) / 3)),
            learning_rate=mix_hidden_lr)

        hidden = fluid.layers.fc(
            input=word_caps_vector,
            name="__hidden00__",
            size=hidden_dim,
            act="tanh",
            bias_attr=fluid.ParamAttr(initializer=NormalInitializer(
                loc=0.0, scale=(1. / math.sqrt(hidden_dim) / 3))),
            param_attr=fluid.ParamAttr(initializer=NormalInitializer(
                loc=0.0, scale=(1. / math.sqrt(hidden_dim) / 3))))

        # Stacked bidirectional LSTM: each direction stacks stack_num
        # (fc -> dynamic_lstm) layers, alternating is_reverse per layer.
        fea = []
        for direction in ["fwd", "bwd"]:
            for i in range(stack_num):
                if i != 0:
                    hidden = fluid.layers.fc(
                        name="__hidden%02d_%s__" % (i, direction),
                        size=hidden_dim,
                        act="stanh",
                        bias_attr=fluid.ParamAttr(initializer=NormalInitializer(
                            loc=0.0, scale=1.0)),
                        input=[hidden, rnn[0], rnn[1]],
                        param_attr=[
                            hidden_para_attr, rnn_para_attr, rnn_para_attr
                        ])
                rnn = fluid.layers.dynamic_lstm(
                    name="__rnn%02d_%s__" % (i, direction),
                    input=hidden,
                    size=hidden_dim,
                    candidate_activation='relu',
                    gate_activation='sigmoid',
                    cell_activation='sigmoid',
                    bias_attr=fluid.ParamAttr(initializer=NormalInitializer(
                        loc=0.0, scale=1.0)),
                    is_reverse=(i % 2) if direction == "fwd" else not i % 2,
                    param_attr=rnn_para_attr)
                fea += [hidden, rnn[0], rnn[1]]

        rnn_fea = fluid.layers.fc(
            size=hidden_dim,
            bias_attr=fluid.ParamAttr(initializer=NormalInitializer(
                loc=0.0, scale=(1. / math.sqrt(hidden_dim) / 3))),
            act="stanh",
            input=fea,
            param_attr=[hidden_para_attr, rnn_para_attr, rnn_para_attr] * 2)

        # Per-tag emission scores feeding the linear-chain CRF.
        emission = fluid.layers.fc(
            size=label_dict_len,
            input=rnn_fea,
            param_attr=fluid.ParamAttr(initializer=NormalInitializer(
                loc=0.0, scale=(1. / math.sqrt(hidden_dim) / 3))))

        crf_cost = fluid.layers.linear_chain_crf(
            input=emission,
            label=target,
            param_attr=fluid.ParamAttr(
                name='crfw',
                initializer=NormalInitializer(
                    loc=0.0, scale=(1. / math.sqrt(hidden_dim) / 3)),
                learning_rate=mix_hidden_lr))
        avg_cost = fluid.layers.mean(x=crf_cost)
        return avg_cost, emission

    word = fluid.layers.data(name='word', shape=[1], dtype='int64', lod_level=1)
    mark = fluid.layers.data(name='mark', shape=[1], dtype='int64', lod_level=1)
    target = fluid.layers.data(
        name="target", shape=[1], dtype='int64', lod_level=1)

    if parallel:
        places = fluid.layers.get_places()
        pd = fluid.layers.ParallelDo(places)
        with pd.do():
            word_ = pd.read_input(word)
            mark_ = pd.read_input(mark)
            target_ = pd.read_input(target)
            avg_cost, emission_base = _net_conf(word_, mark_, target_)
            pd.write_output(avg_cost)
            pd.write_output(emission_base)
        avg_cost_list, emission = pd()
        avg_cost = fluid.layers.mean(x=avg_cost_list)
        emission.stop_gradient = True
    else:
        avg_cost, emission = _net_conf(word, mark, target)

    return avg_cost, emission, word, mark, target
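
The `is_reverse` expression in the LSTM stack above is easy to misread: within each direction the layers alternate orientation, with the `fwd` stack starting un-reversed and the `bwd` stack starting reversed. Enumerating it for the default `stack_num=2`:

```python
stack_num = 2
for direction in ["fwd", "bwd"]:
    for i in range(stack_num):
        # Same expression as in ner_net above.
        is_reverse = (i % 2) if direction == "fwd" else not i % 2
        print("%s layer %d: is_reverse=%s" % (direction, i, bool(is_reverse)))
# fwd layer 0: is_reverse=False
# fwd layer 1: is_reverse=True
# bwd layer 0: is_reverse=True
# bwd layer 1: is_reverse=False
```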