Python trainer api (#193)
* Python trainer API and demo

* Adding missing PaddleAPIPrivate.h

* Adding api_train.sh

* More comments

* Bump up patch version to 0b3
emailweixu authored and reyoung committed Oct 27, 2016
1 parent 46bd5f5 commit cbe734b
Showing 28 changed files with 707 additions and 312 deletions.
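
In outline, the new API drives training from a plain Python script instead of going through the paddle train binary (compare api_train.sh with train.sh below). A condensed sketch of that flow, using only names that appear in demo/quick_start/api_train.py in this diff; trainer_config, converter, train_dataset, num_passes, and batch_size are assumed to be prepared as in that demo:

    from py_paddle import swig_paddle as api

    api.initPaddle("--use_gpu=0", "--trainer_count=1")
    # Build the model from the parsed config, then a trainer that drives it.
    model = api.GradientMachine.createFromConfigProto(trainer_config.model_config)
    trainer = api.Trainer.create(trainer_config, model)

    trainer.startTrain()
    for train_pass in xrange(num_passes):
        trainer.startTrainPass()
        for pos in xrange(0, len(train_dataset), batch_size):
            batch = train_dataset[pos:pos + batch_size]
            trainer.trainOneDataBatch(len(batch), converter(batch))
        trainer.finishTrainPass()
    trainer.finishTrain()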
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 2.8)
 project(paddle CXX C)
 set(PADDLE_MAJOR_VERSION 0)
 set(PADDLE_MINOR_VERSION 8)
-set(PADDLE_PATCH_VERSION 0b2)
+set(PADDLE_PATCH_VERSION 0b3)
 set(PADDLE_VERSION ${PADDLE_MAJOR_VERSION}.${PADDLE_MINOR_VERSION}.${PADDLE_PATCH_VERSION})

 set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake")
1 change: 1 addition & 0 deletions cmake/swig.cmake
@@ -27,6 +27,7 @@ function(generate_python_api target_name)
     COMMAND swig -python -c++ -outcurrentdir -I../ api/Paddle.swig
         && mv ${PROJ_ROOT}/paddle/swig_paddle.py ${PROJ_ROOT}/paddle/py_paddle/swig_paddle.py
     DEPENDS ${PROJ_ROOT}/paddle/api/Paddle.swig
+            ${PROJ_ROOT}/paddle/api/PaddleAPI.h
     WORKING_DIRECTORY ${PROJ_ROOT}/paddle
     COMMENT "Generate Python API from swig")
   add_custom_target(${target_name} ALL DEPENDS
114 changes: 114 additions & 0 deletions demo/quick_start/api_train.py
@@ -0,0 +1,114 @@
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import itertools
import random

from paddle.trainer.config_parser import parse_config
from py_paddle import swig_paddle as api
from py_paddle import DataProviderConverter
from paddle.trainer.PyDataProvider2 \
    import integer_value, integer_value_sequence, sparse_binary_vector

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_data",
                        type=str, required=False, help="train data file")
    parser.add_argument("--test_data", type=str, help="test data file")
    parser.add_argument("--config",
                        type=str, required=True, help="config file name")
    parser.add_argument("--dict_file", required=True, help="dictionary file")
    parser.add_argument("--seq",
                        default=1, type=int,
                        help="whether to use sequence training")
    parser.add_argument("--use_gpu", default=0, type=int,
                        help="whether to use GPU for training")
    parser.add_argument("--trainer_count", default=1, type=int,
                        help="number of threads for training")
    parser.add_argument("--num_passes", default=5, type=int,
                        help="number of training passes")
    return parser.parse_args()

UNK_IDX = 0

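# Each input line is "<label>\t<space-separated words>"; words are mapped
# to integer ids via word_dict, with unknown words mapped to UNK_IDX.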
def load_data(file_name, word_dict):
    with open(file_name, 'r') as f:
        for line in f:
            label, comment = line.strip().split('\t')
            words = comment.split()
            word_slot = [word_dict.get(w, UNK_IDX) for w in words]
            yield word_slot, int(label)

def load_dict(dict_file):
    word_dict = dict()
    with open(dict_file, 'r') as f:
        for i, line in enumerate(f):
            w = line.strip().split()[0]
            word_dict[w] = i
    return word_dict

def main():
    options = parse_arguments()
    api.initPaddle("--use_gpu=%s" % options.use_gpu,
                   "--trainer_count=%s" % options.trainer_count)

    word_dict = load_dict(options.dict_file)
    train_dataset = list(load_data(options.train_data, word_dict))
    if options.test_data:
        test_dataset = list(load_data(options.test_data, word_dict))
    else:
        test_dataset = None

    trainer_config = parse_config(options.config,
                                  "dict_file=%s" % options.dict_file)
    # No need to have a data provider for the trainer
    trainer_config.ClearField('data_config')
    trainer_config.ClearField('test_data_config')

    # create a GradientMachine from the model configuration
    model = api.GradientMachine.createFromConfigProto(
        trainer_config.model_config)
    # create a trainer for the gradient machine
    trainer = api.Trainer.create(trainer_config, model)

    # create a data converter which converts data to PaddlePaddle
    # internal format
    input_types = [
        integer_value_sequence(len(word_dict)) if options.seq
        else sparse_binary_vector(len(word_dict)),
        integer_value(2)]
    converter = DataProviderConverter(input_types)

    batch_size = trainer_config.opt_config.batch_size
    trainer.startTrain()
    for train_pass in xrange(options.num_passes):
        trainer.startTrainPass()
        random.shuffle(train_dataset)
        for pos in xrange(0, len(train_dataset), batch_size):
            batch = itertools.islice(train_dataset, pos, pos + batch_size)
            size = min(batch_size, len(train_dataset) - pos)
            trainer.trainOneDataBatch(size, converter(batch))
        trainer.finishTrainPass()
        if test_dataset:
            trainer.startTestPeriod()
            for pos in xrange(0, len(test_dataset), batch_size):
                batch = itertools.islice(test_dataset, pos, pos + batch_size)
                size = min(batch_size, len(test_dataset) - pos)
                trainer.testOneDataBatch(size, converter(batch))
            trainer.finishTestPeriod()
    trainer.finishTrain()

if __name__ == '__main__':
    main()
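
A side note on the batching in the loop above: itertools.islice walks train_dataset from its head on every call, so for long datasets plain list slicing is the simpler equivalent (train_dataset is a list here). A minimal alternative sketch, same names as the demo:

    # Equivalent batching with list slicing; len(batch) replaces the
    # separate size computation, since the final slice is simply shorter.
    for pos in xrange(0, len(train_dataset), batch_size):
        batch = train_dataset[pos:pos + batch_size]
        trainer.trainOneDataBatch(len(batch), converter(batch))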
29 changes: 29 additions & 0 deletions demo/quick_start/api_train.sh
@@ -0,0 +1,29 @@
#!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e

# Note: if using trainer_config.emb.py, trainer_config.cnn.py
# or trainer_config.lstm.py, you need to change --seq=0 below to --seq=1,
# because they are sequence models.
python api_train.py \
  --config=trainer_config.lr.py \
  --trainer_count=2 \
  --num_passes=15 \
  --use_gpu=0 \
  --seq=0 \
  --train_data=data/train.txt \
  --test_data=data/test.txt \
  --dict_file=data/dict.txt \
  2>&1 | tee 'train.log'
2 changes: 1 addition & 1 deletion demo/quick_start/train.sh
@@ -24,7 +24,7 @@ paddle train \
   --config=$cfg \
   --save_dir=./output \
   --trainer_count=4 \
-  --log_period=20 \
+  --log_period=100 \
   --num_passes=15 \
   --use_gpu=false \
   --show_parameter_stats_period=100 \
3 changes: 1 addition & 2 deletions demo/quick_start/trainer_config.lr.py
@@ -16,7 +16,7 @@

 from paddle.trainer_config_helpers import *

-dict_file = "./data/dict.txt"
+dict_file = get_config_arg('dict_file', str, "./data/dict.txt")
 word_dict = dict()
 with open(dict_file, 'r') as f:
     for i, line in enumerate(f):
@@ -63,7 +63,6 @@
     label = data_layer(name="label", size=2)

     # Define cross-entropy classification loss and error.
-    classification_cost(input=output, label=label)
     cls = classification_cost(input=output, label=label)
     outputs(cls)
 else:
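
The get_config_arg change above is the counterpart of the parse_config call in api_train.py: extra variables are handed to the config as "key=value" strings and read back with a fallback default. Both ends of that handshake, as they appear in this diff:

    # api_train.py passes the dictionary path to the config at parse time:
    trainer_config = parse_config(options.config,
                                  "dict_file=%s" % options.dict_file)

    # trainer_config.lr.py reads it back, keeping the old path as default:
    dict_file = get_config_arg('dict_file', str, "./data/dict.txt")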
4 changes: 2 additions & 2 deletions demo/sentiment/predict.py
@@ -46,8 +46,8 @@ def __init__(self, train_conf, dict_file, model_dir=None, label_file = None):
         conf = parse_config(train_conf, "is_predict=1")
         self.network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config)
         self.network.loadParameters(self.model_dir)
-        slots = [integer_value_sequence(self.dict_dim)]
-        self.converter = DataProviderConverter(slots)
+        input_types = [integer_value_sequence(self.dict_dim)]
+        self.converter = DataProviderConverter(input_types)

     def load_dict(self):
         """
19 changes: 1 addition & 18 deletions paddle/api/Arguments.cpp
@@ -14,27 +14,10 @@ limitations under the License. */


 #include "PaddleAPI.h"
+#include "PaddleAPIPrivate.h"

 #include "paddle/parameter/Argument.h"

-struct ArgumentsPrivate {
-  std::vector<paddle::Argument> outputs;
-
-  inline paddle::Argument& getArg(size_t idx) throw(RangeError) {
-    if (idx < outputs.size()) {
-      return outputs[idx];
-    } else {
-      RangeError e;
-      throw e;
-    }
-  }
-
-  template <typename T>
-  std::shared_ptr<T>& cast(void* rawPtr) const {
-    return *(std::shared_ptr<T>*)(rawPtr);
-  }
-};
-
 size_t Arguments::getSlotNum() const { return m->outputs.size(); }

 Arguments* Arguments::createArguments(size_t slotNum) {
3 changes: 3 additions & 0 deletions paddle/api/CMakeLists.txt
@@ -40,6 +40,8 @@ configure_file(

 generate_python_api(python_swig_sources)

+file(GLOB PY_PADDLE_PYTHON_FILES ${PROJ_ROOT}/paddle/py_paddle/*.py)
+
 # TODO(yuyang18) : make wheel name calculated by cmake
 add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/dist/.timestamp
     COMMAND ${PYTHON_EXECUTABLE} setup.py bdist_wheel
@@ -55,6 +57,7 @@ add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/dist/.timestamp
     paddle_trainer
     paddle_api
     paddle_cuda
+    ${PY_PADDLE_PYTHON_FILES}
 )

 install(DIRECTORY ${PROJ_ROOT}/paddle/dist/
44 changes: 13 additions & 31 deletions paddle/api/ConfigParser.cpp
@@ -14,17 +14,9 @@ limitations under the License. */


 #include "PaddleAPI.h"
+#include "PaddleAPIPrivate.h"
 #include "paddle/trainer/Trainer.h"

-struct TrainerConfigPrivate {
-  std::shared_ptr<paddle::TrainerConfig> conf;
-  TrainerConfigPrivate() : conf(std::make_shared<paddle::TrainerConfig>()) {}
-};
-
-struct ModelConfigPrivate {
-  std::shared_ptr<paddle::TrainerConfig> conf;
-};
-
 struct ParameterConfigPrivate {
   paddle::ParameterPtr parameter;
   paddle::ParameterConfig config;
@@ -39,30 +31,26 @@ struct ParameterConfigPrivate {
   }
 };

-struct OptimizationConfigPrivate {
-  std::shared_ptr<paddle::TrainerConfig> trainer_config;
-  paddle::OptimizationConfig config;
-
-  paddle::OptimizationConfig& getConfig() {
-    if (trainer_config != nullptr) {
-      return *trainer_config->mutable_opt_config();
-    } else {
-      return config;
-    }
-  }
-};
-
 TrainerConfig::TrainerConfig() : m(new TrainerConfigPrivate()) {}

 TrainerConfig::~TrainerConfig() { delete m; }

 TrainerConfig* TrainerConfig::createFromTrainerConfigFile(
     const std::string& confPath) {
   LOG(INFO) << "load trainer config from " << confPath;
-  paddle::TrainerConfigHelper helper(confPath);
-  //! TODO(yuyang18): Make TrainerConfigPrivate to TrainerConfigHelper
+  auto conf = std::make_shared<paddle::TrainerConfigHelper>(confPath);
   auto retv = new TrainerConfig();
-  *retv->m->conf = helper.getConfig();
+  retv->m->conf = conf;
   return retv;
 }

+TrainerConfig* TrainerConfig::createFromProtoString(
+    const std::string& str) {
+  auto retv = new TrainerConfig();
+  paddle::TrainerConfig trainerConfigProto;
+  auto conf = std::make_shared<paddle::TrainerConfigHelper>(trainerConfigProto);
+  CHECK(conf->getMutableConfig().ParseFromString(str));
+  retv->m->conf = conf;
+  return retv;
+}
+
@@ -76,10 +64,6 @@ ModelConfig* TrainerConfig::getModelConfig() const {
   return retv;
 }

-void* ModelConfig::getPaddleModelConfig() const {
-  return m->conf->mutable_model_config();
-}
-
 ParameterConfig::ParameterConfig() : m(new ParameterConfigPrivate()) {}

 ParameterConfig::~ParameterConfig() {
@@ -132,8 +116,6 @@ OptimizationConfig* TrainerConfig::getOptimizationConfig() const {
   return opt_config;
 }

-void* OptimizationConfig::getRawPtr() { return &m->getConfig(); }
-
 OptimizationConfig* OptimizationConfig::createFromProtoString(
     const std::string& str) {
   auto conf = new OptimizationConfig();
18 changes: 5 additions & 13 deletions paddle/api/GradientMachine.cpp
@@ -14,30 +14,22 @@ limitations under the License. */


 #include "PaddleAPI.h"
-#include "paddle/gserver/gradientmachines/GradientMachine.h"
+#include "PaddleAPIPrivate.h"

 #include "paddle/gserver/gradientmachines/NeuralNetwork.h"
 #include "Internal.h"

 std::vector<int> GradientMachine::defaultParamTypes = {
     PARAMETER_VALUE, PARAMETER_GRADIENT, PARAMETER_MOMENTUM};

-struct GradientMachinePrivate {
-  std::shared_ptr<paddle::GradientMachine> machine;
-
-  template <typename T>
-  inline T& cast(void* ptr) {
-    return *(T*)(ptr);
-  }
-};
-
 GradientMachine::GradientMachine() : m(new GradientMachinePrivate()) {}

 GradientMachine::~GradientMachine() { delete m; }

 GradientMachine* GradientMachine::createFromPaddleModelPtr(
-    void* confPtr, GradientMatchineCreateMode mode,
+    const void* confPtr, GradientMatchineCreateMode mode,
     const std::vector<int>& types) {
-  auto& conf = *(paddle::ModelConfig*)(confPtr);
+  auto& conf = *(const paddle::ModelConfig*)(confPtr);
   std::vector<ParameterType> realTypes;
   staticCastVector(&realTypes, types);
   auto machineRawPtr = paddle::GradientMachine::create(conf, mode, realTypes);
@@ -66,7 +58,7 @@ GradientMachine* GradientMachine::createByConfigProtoStr(
 GradientMachine* GradientMachine::createByModelConfig(
     ModelConfig* conf, GradientMatchineCreateMode mode,
     const std::vector<int>& types) {
-  auto confPtr = (paddle::ModelConfig*)conf->getPaddleModelConfig();
+  auto confPtr = &conf->m->conf->getModelConfig();
   return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types);
 }