Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean train #201

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 4 additions & 19 deletions paddle/trainer/Trainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,21 +88,8 @@ P_DEFINE_string(model_list, "",

namespace paddle {

/// Initialize the Trainer directly from command-line arguments.
///
/// Bootstraps process-wide state via initMain() and the (presumably
/// embedded) Python runtime via initPython() — TODO confirm what
/// initPython sets up — then builds a TrainerConfigHelper from the
/// parsed command-line flags and forwards to the config-based init()
/// overload.
///
/// @param argc argument count, as received by main().
/// @param argv argument vector, as received by main().
void Trainer::init(int argc, char** argv) {
initMain(argc, argv);
initPython(argc, argv);

auto config = TrainerConfigHelper::createFromFlagConfig();
// Trap invalid-operation, divide-by-zero and overflow FP errors as
// signals so numeric bugs fail fast instead of propagating NaN/Inf.
feenableexcept(FE_INVALID | FE_DIVBYZERO | FE_OVERFLOW);

init(config);
}

void Trainer::init(const std::shared_ptr<TrainerConfigHelper> &config,
bool testing,
const std::shared_ptr<GradientMachine> &gradientMachine,
const std::shared_ptr<DataProvider> &dataProvider,
const std::shared_ptr<DataProvider> &testDataProvider) {
bool testing) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This interface is used by other projects. Also used by recent pull request: https://github.com/baidu/Paddle/pull/193/files#diff-71eb11a1ef9cbe77f5e207dc915cc8b6

this->stats_ = std::make_shared<TrainerStats>();

config_ = config;
Expand Down Expand Up @@ -171,7 +158,7 @@ void Trainer::init(const std::shared_ptr<TrainerConfigHelper> &config,
}

// initialize trainer internal
trainerInternal_.init(config_, gradientMachine,
trainerInternal_.init(config_,
TrainerInternalConfig::createFromMode(mode_),
stats_, testing);
std::unique_ptr<ParameterUtilConfig> paramConfig(
Expand All @@ -192,8 +179,7 @@ void Trainer::init(const std::shared_ptr<TrainerConfigHelper> &config,
(!IGradientMachineMode::dataMustInCpu(mode_,
FLAGS_trainer_count));

dataProvider_ = dataProvider;
if (!dataProvider_ && config_->hasDataConfig()) {
if (config_->hasDataConfig()) {
dataProvider_.reset(DataProvider::create(*config_, *config_, gpuData));
}
if (dataProvider_) {
Expand All @@ -209,8 +195,7 @@ void Trainer::init(const std::shared_ptr<TrainerConfigHelper> &config,
}
}

testDataProvider_ = testDataProvider;
if (!testDataProvider_ && config_->hasTestDataConfig()) {
if (config_->hasTestDataConfig()) {
testDataProvider_.reset(
DataProvider::create(config_->getTestDataConfig(), *config_, gpuData));
}
Expand Down
15 changes: 1 addition & 14 deletions paddle/trainer/Trainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,10 @@ class Trainer {
*
* @param config TrainerConfig.
* @param testing true if only for testing
* @param gradientMachine GradientMachine that will be trained.
* nullptr if create from config.
* @param dataProvider Train Data Provider. null if create from config.
* @param testDataProvider Test Data Provider. null if create from config.
*/
virtual void init(
const std::shared_ptr<TrainerConfigHelper> &config,
bool testing = false,
const std::shared_ptr<GradientMachine> &gradientMachine = nullptr,
const std::shared_ptr<DataProvider> &dataProvider = nullptr,
const std::shared_ptr<DataProvider> &testDataProvider = nullptr);

/**
* Initialize Trainer from command line flags.
*/
void init(int argc, char** argv);

bool testing = false);

/**
* Train until num_passes reached.
Expand Down
10 changes: 3 additions & 7 deletions paddle/trainer/TrainerInternal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ limitations under the License. */
namespace paddle {

void TrainerInternal::init(const std::shared_ptr<TrainerConfigHelper> &config,
const GradientMachinePtr &gradientMachine,
std::unique_ptr<TrainerInternalConfig> &&intconfig,
const std::shared_ptr<TrainerStats> &stats,
bool testing) {
Expand All @@ -53,12 +52,9 @@ void TrainerInternal::init(const std::shared_ptr<TrainerConfigHelper> &config,
createParameterUpdater(testing);
}

gradientMachine_ = gradientMachine;
if (!gradientMachine) {
gradientMachine_.reset(GradientMachine::create(
config_->getConfig().model_config(), intconfig_->mode,
parameterUpdater_->getParameterTypes()));
}
gradientMachine_.reset(GradientMachine::create(
config_->getConfig().model_config(), intconfig_->mode,
parameterUpdater_->getParameterTypes()));
}

void TrainerInternal::trainOneBatch(int64_t batchId,
Expand Down
2 changes: 0 additions & 2 deletions paddle/trainer/TrainerInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,11 @@ class TrainerInternal {
/**
* Initializes trainer internal class
* @param config network config
* @param machine gradient machine
* @param intconfig training config
* @param stats training stats
* @param testing if it is in testing phase
*/
void init(const std::shared_ptr<TrainerConfigHelper> &config,
const GradientMachinePtr &machine,
std::unique_ptr<TrainerInternalConfig> &&intconfig,
const std::shared_ptr<TrainerStats> &stats,
bool testing);
Expand Down