Hi, how do I use the model after training it? #116

Open
wangwisdom opened this issue Apr 8, 2018 · 9 comments

wangwisdom commented Apr 8, 2018

Hi, how do I use the model after training it? That is, how do I feed in a single sentence to test it?

forever1dream commented Apr 10, 2018

The project is built with bazel, which is rather cumbersome. Modify the file seg_backend_api.cc as follows:
/*
 * Copyright 2016- 2018 Koth. All Rights Reserved.
 * =====================================================================================
 * Filename:  seg_backend_api.cc
 * Author:  Koth
 * Create Time:  2016-11-20 20:43:26
 * Description:
 *
 */
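// The header names were lost when this comment was posted; these are the
// standard headers that the code below actually uses.
#include <fstream>
#include <iostream>
#include <string>
#include <vector>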

#include "base/base.h"
#include "jsonxx.h"
#include "basic_string_util.h"
#include "tf_seg_model.h"
#include "pos_tagger.h"
#include "third_party/crow/include/crow.h"
#include "tensorflow/core/platform/init_main.h"

DEFINE_int32(port, 9090, "the api serving binding port");
DEFINE_string(model_path, "tensor_decode/models/seg_model.pbtxt", "the model path");
DEFINE_string(vocab_path, "tensor_decode/models/basic_vocab.txt", "char vocab path");
DEFINE_string(pos_model_path, "tensor_decode/models/pos_model.pbtxt", "the pos tagging model path");
DEFINE_string(word_vocab_path, "tensor_decode/models/word_vocab.txt", "word vocab path");
DEFINE_string(pos_vocab_path, "tensor_decode/models/pos_vocab.txt", "pos vocab path");
DEFINE_int32(max_sentence_len, 80, "max sentence len ");
DEFINE_string(user_dict_path, "", "user dict path");
DEFINE_int32(max_word_num, 50, "max num of word per sentence ");

class SegMiddleware
{
public:
    struct context {};
    SegMiddleware() {}
    ~SegMiddleware() {}
    void before_handle(crow::request& req, crow::response& res, context& ctx) {}
    void after_handle(crow::request& req, crow::response& res, context& ctx) {}
private:
};

int main(int argc, char* argv[])
{
    tensorflow::port::InitMain(argv[0], &argc, &argv);
    google::ParseCommandLineFlags(&argc, &argv, true);
    // Left over from the original HTTP server; unused in this offline version.
    // The template argument was probably stripped when the comment was posted.
    crow::App<SegMiddleware> app;
    kcws::TfSegModel model;
    CHECK(model.LoadModel(FLAGS_model_path,
                          FLAGS_vocab_path,
                          FLAGS_max_sentence_len,
                          FLAGS_user_dict_path))
        << "Load model error";

    // The POS tagger is optional: it is loaded only when a model path is given.
    if (!FLAGS_pos_model_path.empty()) {
        kcws::PosTagger* tagger = new kcws::PosTagger;
        CHECK(tagger->LoadModel(FLAGS_pos_model_path,
                                FLAGS_word_vocab_path,
                                FLAGS_vocab_path,
                                FLAGS_pos_vocab_path,
                                FLAGS_max_word_num)) << "load pos model error";
        model.SetPosTagger(tagger);
    }

    // File test: read one sentence per line. Change this path to your own corpus.
    std::ifstream in("/home/work/test/nlu-ner/train_and_dict/model/origin_corpus/origin_corpus");
    CHECK(in.is_open()) << "Cannot open input file";
    std::string sentence;
    while (std::getline(in, sentence)) {
    // Interactive alternative: read queries from stdin instead of a file.
    //while (1) {
    //    std::cout << "please input query:";
    //    std::cin >> sentence;
    //    if (sentence == "q") {
    //        return 0;
    //    }

        std::vector<std::string> result;
        std::vector<std::string> tags;
        if (model.Segment(sentence, &result, &tags)) {
            if (result.size() == tags.size()) {
                // One POS tag per word: print as "word/tag".
                int nl = result.size();
                for (int i = 0; i < nl; i++) {
                    std::cout << result[i] << "/" << tags[i] << " ";
                }
                std::cout << std::endl;
            } else {
                // No usable tags: print the words only.
                for (const std::string& str : result) {
                    std::cout << str << " ";
                }
                std::cout << std::endl;
            }
        } else {
            // Report the failure instead of silently dropping the sentence.
            std::cerr << "Parse request error: " << sentence << std::endl;
        }
    }

    return 0;
}
Then write your own compile command to build it, and you can run the binary directly, as follows:
#!/bin/bash

set -e -x

g++ -std=c++11 -o seg_backend_api ./kcws/cc/seg_backend_api.cc \
./kcws/cc/pos_tagger.cc ./kcws/cc/sentence_breaker.cc ./kcws/cc/tf_seg_model.cc ./kcws/cc/viterbi_decode.cc \
./utils/basic_vocab.cc ./utils/jsonxx.cc ./utils/py_word2vec_vob.cc ./utils/word2vec_vob.cc \
./tfmodel/tfmodel.cc \
-g -Wall -D_DEBUG -Wshadow -Wno-sign-compare -w -Xlinker -export-dynamic \
-I../tensorflow/ \
-I./kcws/cc/ \
-I./utils/ \
-I./tfmodel/ \
-I./third_party/gflags/include/ \
-I./third_party/glog/include/ \
-I/home/soft/boost/include/ \
-I/usr/include/python2.7/ \
-I../tensorflow/tensorflow/contrib/makefile/gen/proto \
-I../tensorflow/tensorflow/contrib/makefile/downloads/eigen \
-I../tensorflow/tensorflow/contrib/makefile/gen/protobuf/include \
-I../tensorflow/tensorflow/contrib/makefile/downloads/nsync/public/ \
-L../tensorflow/bazel-bin/tensorflow -ltensorflow_cc \
-L../tensorflow/bazel-bin/tensorflow -ltensorflow_framework \
-L./third_party/gflags/lib -lgflags \
-L./third_party/gflags/lib -lgflags_nothreads \
-L./third_party/glog/lib -lglog \
-L/home/soft/boost/lib -lboost_system \
-L/usr/lib64 -lpython2.7 \
-lm \
-ldl \
-lpthread
Remember to change the paths to match your own source tree.
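
The read-segment-print loop in the listing above can be separated from the model so that the same code serves both the file test and the interactive test. The following is only a sketch under assumptions: `SegmentFn`, `FormatSegments`, `RunSegmentLoop` and `WhitespaceSegment` are hypothetical names that are not part of kcws, and `WhitespaceSegment` is a stand-in that splits on spaces so the loop can be tried without TensorFlow.

```cpp
#include <cstddef>
#include <functional>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical callback type, shaped like the model.Segment(sentence, &result, &tags)
// call in the listing above.
using SegmentFn = std::function<bool(const std::string&,
                                     std::vector<std::string>*,
                                     std::vector<std::string>*)>;

// Joins words as "word/tag " when there is exactly one tag per word,
// otherwise as "word " (the same rule the listing applies).
std::string FormatSegments(const std::vector<std::string>& words,
                           const std::vector<std::string>& tags) {
    std::string out;
    const bool tagged = (words.size() == tags.size());
    for (std::size_t i = 0; i < words.size(); ++i) {
        out += words[i];
        if (tagged) out += "/" + tags[i];
        out += " ";
    }
    return out;
}

// Reads one sentence per line from `in`, segments it and writes the result to `out`.
// Empty lines are skipped. Returns the number of sentences segmented successfully.
int RunSegmentLoop(std::istream& in, std::ostream& out, const SegmentFn& segment) {
    int ok = 0;
    std::string sentence;
    while (std::getline(in, sentence)) {
        if (sentence.empty()) continue;
        std::vector<std::string> words;
        std::vector<std::string> tags;
        if (segment(sentence, &words, &tags)) {
            out << FormatSegments(words, tags) << "\n";
            ++ok;
        }
    }
    return ok;
}

// Stand-in for the real model (an assumption for illustration only):
// splits on whitespace and produces no tags.
bool WhitespaceSegment(const std::string& sentence,
                       std::vector<std::string>* words,
                       std::vector<std::string>* tags) {
    std::istringstream iss(sentence);
    std::string w;
    while (iss >> w) words->push_back(w);
    tags->clear();
    return true;
}
```

In the real program one would pass a lambda that forwards to `model.Segment` in place of `WhitespaceSegment`, and pass either `std::cin` or a `std::ifstream` opened from a command-line argument as `in`, which would remove the hardcoded corpus path. This assumes `Segment` accepts its arguments the way the listing calls it.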

@AlleyEli

@forever1dream Could you explain the method above in a bit more detail? Why not build with bazel? And I still don't understand how to actually use it.

forever1dream commented Apr 20, 2018

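@AlleyEli I find building with bazel cumbersome, and integrating it into my own project is cumbersome too. The main point of the method above is to work out which dependencies decoding needs: the -I entries are the required header directories, and the -L entries are the directories holding the required .a or .so libraries. From these you can write a makefile or whatever build setup you prefer. The code change mainly removes the network service from the original source and replaces it with local interactive input or a file-based test.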

@AlleyEli

@forever1dream Thank you very much!

@AlleyEli

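@forever1dream Could you tell me which environment you used for training and for running the model?
Mine is Ubuntu 16.04 + Python 2.7 + bazel 0.45 + TensorFlow 1.7.0.
I would like to compare.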

forever1dream commented May 18, 2018 via email

AlleyEli commented May 18, 2018

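@forever1dream thanks
I have wrapped training and inference in a set of Python interfaces (https://github.com/AlleyEli/kcws), and the tests pass.
I mainly wanted to compare environments. Version compatibility problems troubled me for a long time, so I wanted to see what you are running.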

@forever1dream

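@AlleyEli I have also put the whole build process on GitHub (https://github.com/forever1dream/cplus-kcws). Change the TensorFlow install path (placing that project and the TensorFlow install directory side by side in the same parent directory is enough) and the boost path, and it should run. Thanks for your Python training approach; I will go take a look.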

@AlleyEli

@forever1dream OK
