Re-write Chapter 1 in Book to use new Fluid API #524
@@ -54,8 +54,9 @@ After setting up our model, there are several major steps to go through to train

Our program starts with importing necessary packages:

```python
-import paddle.v2 as paddle
-import paddle.v2.dataset.uci_housing as uci_housing
+import paddle
+import paddle.fluid as fluid
+import numpy
```

We encapsulated the [UCI Housing Data Set](https://archive.ics.uci.edu/ml/datasets/Housing) in our Python module `uci_housing`. This module can
@@ -116,49 +117,58 @@ When training complex models, we usually have one more split: the validation set

`fit_a_line/trainer.py` demonstrates the training using [PaddlePaddle](http://paddlepaddle.org).

-### Initialize PaddlePaddle
+### Datafeeder Configuration

-```python
-paddle.init(use_gpu=False, trainer_count=1)
-```
+We first define data feeders for test and train. A feeder reads a `BATCH_SIZE` of samples at a time and feeds them to the training/testing process. Samples are shuffled within a buffer of `buf_size` before batching, so the order in which the data is seen is randomized.

-### Model Configuration
+```python
+BATCH_SIZE = 20

-Linear regression is essentially a fully-connected layer with linear activation:
+train_reader = paddle.batch(
+    paddle.reader.shuffle(
+        paddle.dataset.uci_housing.train(), buf_size=500),
+    batch_size=BATCH_SIZE)

-```python
-x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
-y_predict = paddle.layer.fc(input=x,
-                            size=1,
-                            act=paddle.activation.Linear())
-y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
-cost = paddle.layer.square_error_cost(input=y_predict, label=y)
+test_reader = paddle.batch(
+    paddle.reader.shuffle(
+        paddle.dataset.uci_housing.test(), buf_size=500),
+    batch_size=BATCH_SIZE)
```
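
If you want to sanity-check the readers defined above, a small sketch like the following should work (it assumes `paddle.batch` returns a callable that yields lists of `(features, price)` tuples, as in the v2 reader interface used here):

```python
# Peek at one batch from train_reader to see its structure.
one_batch = next(train_reader())
print("samples per batch:", len(one_batch))   # expected to be BATCH_SIZE
sample_x, sample_y = one_batch[0]             # one sample: 13 features and a price
print("features:", sample_x, "price:", sample_y)
```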

-### Save Topology
+### Train Program Configuration

+The `train_program` builds the network used for training: it declares the input and label layers, the linear prediction layer (the same layers the inference program defined later will reuse), and the loss, and it must return `avg_loss` as its first return value. The trainer configured below takes this function and minimizes the returned loss.

Review comment: You might need to explain a bit more on the train program topic.

Review comment: I feel "use the inference_program to setup the train_program" is a little confusing because we do not show the inference_program until later.
```python
-# Save the inference topology to protobuf.
-inference_topology = paddle.topology.Topology(layers=y_predict)
-with open("inference_topology.pkl", 'wb') as f:
-    inference_topology.serialize_for_inference(f)
+def train_program():
+    y = fluid.layers.data(name='y', shape=[1], dtype='float32')

+    # feature vector of length 13
+    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
+    y_predict = fluid.layers.fc(input=x, size=1, act=None)

+    loss = fluid.layers.square_error_cost(input=y_predict, label=y)
+    avg_loss = fluid.layers.mean(loss)

+    return avg_loss
```
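
For reference, `square_error_cost` followed by `mean` computes the mean squared error of the predictions over a batch of size $B$:

$$\mathrm{avg\_loss} = \frac{1}{B}\sum_{i=1}^{B}\left(\hat{y}_i - y_i\right)^2$$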

-### Create Parameters
+### Specify Place

+Specify the training environment: whether training runs on the CPU or on a GPU.

```python
-parameters = paddle.parameters.create(cost)
+use_cuda = False
+place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
```
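
If you would rather detect GPU support at runtime than hard-code `use_cuda`, something like the following sketch should work (assuming `fluid.core.is_compiled_with_cuda()` is available in this Fluid version):

```python
# Fall back to CPU automatically when Paddle was not built with CUDA support.
use_cuda = fluid.core.is_compiled_with_cuda()
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
```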

### Create Trainer

+The trainer takes the `train_program` as input, together with the place and the optimizer.

```python
-optimizer = paddle.optimizer.Momentum(momentum=0)

-trainer = paddle.trainer.SGD(cost=cost,
-                             parameters=parameters,
-                             update_equation=optimizer)
+trainer = fluid.Trainer(
+    train_func=train_program,
+    place=place,
+    optimizer=fluid.optimizer.SGD(learning_rate=0.001))
```
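
The optimizer is pluggable; as a sketch (assuming `fluid.optimizer.Momentum` accepts these arguments), the plain SGD above could be swapped for momentum SGD:

```python
# Same trainer, but with a momentum optimizer instead of plain SGD.
trainer = fluid.Trainer(
    train_func=train_program,
    place=place,
    optimizer=fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9))
```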

### Feeding Data
@@ -168,105 +178,92 @@ PaddlePaddle provides the
for loading the training data. A reader may return multiple columns, and we need to specify a feed order so that each column is fed to the corresponding data layer.

```python
-feeding={'x': 0, 'y': 1}
+feed_order=['x', 'y']
```
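
Each record produced by the `uci_housing` reader is a `(features, price)` tuple, and `feed_order` simply names, in column order, the data layer each column is fed to. A small illustrative sketch (assuming the record layout shown in the original chapter):

```python
# The first column of each record feeds the 'x' layer, the second feeds 'y'.
record = next(paddle.dataset.uci_housing.test()())
print("x <-", record[0])   # 13 housing features
print("y <-", record[1])   # median house price
```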

Moreover, an event handler is provided to print the training progress:

```python
-# event_handler to print training and testing info
-def event_handler(event):
-    if isinstance(event, paddle.event.EndIteration):
-        if event.batch_id % 100 == 0:
-            print "Pass %d, Batch %d, Cost %f" % (
-                event.pass_id, event.batch_id, event.cost)

-    if isinstance(event, paddle.event.EndPass):
-        result = trainer.test(
-            reader=paddle.batch(
-                uci_housing.test(), batch_size=2),
-            feeding=feeding)
-        print "Test %d, Cost %f" % (event.pass_id, result.cost)
-```
+# Specify the directory path to save the parameters
+params_folder = "fit_a_line.inference.model"

-```python
-# event_handler to plot training and testing info
+# Plot data
from paddle.v2.plot import Ploter

train_title = "Train cost"
test_title = "Test cost"
plot_cost = Ploter(train_title, test_title)

step = 0

-def event_handler_plot(event):
+# event_handler to print training and testing info
+def event_handler(event):
    global step
-    if isinstance(event, paddle.event.EndIteration):
-        if step % 10 == 0:  # every 10 batches, record a train cost
-            plot_cost.append(train_title, step, event.cost)

+    if isinstance(event, fluid.EndStepEvent):
-        if step % 100 == 0:  # every 100 batches, record a test cost
-            result = trainer.test(
-                reader=paddle.batch(
-                    uci_housing.test(), batch_size=2),
-                feeding=feeding)
-            plot_cost.append(test_title, step, result.cost)
+        test_metrics = trainer.test(
+            reader=test_reader, feed_order=feed_order)

-        if step % 100 == 0:  # every 100 batches, update cost plot
+        print(test_metrics[0])

+        plot_cost.append(test_title, step, test_metrics[0])
        plot_cost.plot()

-        step += 1
+        if test_metrics[0] < 10.0:
+            # If the loss is small enough, we can stop the training.
+            print('loss is less than 10.0, stop')
+            trainer.stop()

+        if step >= 2000:
+            # Or if it has been running for enough steps
+            print('has been running for 2000 steps, stop')
+            trainer.stop()

-    if isinstance(event, paddle.event.EndPass):
-        if event.pass_id % 10 == 0:
-            with open('params_pass_%d.tar' % event.pass_id, 'w') as f:
-                trainer.save_parameter_to_tar(f)
+        # We can save the trained parameters for the inference later
+        if params_folder is not None:
+            trainer.save_params(params_folder)

+        step += 1
```

Review comment: Not sure if we need to terminate early; fit a line is pretty small.
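
For runs outside a notebook, where the `Ploter` output is less convenient, a stripped-down handler is enough. This is only a sketch, reusing the `trainer`, `test_reader`, `feed_order`, and global `step` defined above:

```python
# A minimal, plot-free event handler: print the test loss every 100 steps.
def event_handler_print(event):
    global step
    if isinstance(event, fluid.EndStepEvent):
        if step % 100 == 0:
            test_metrics = trainer.test(reader=test_reader, feed_order=feed_order)
            print("step %d, test loss %f" % (step, test_metrics[0]))
        step += 1
```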

### Start Training

```python
+%matplotlib inline

+# The training could take up to a few minutes.
trainer.train(
-    reader=paddle.batch(
-        paddle.reader.shuffle(
-            uci_housing.train(), buf_size=500),
-        batch_size=2),
-    feeding=feeding,
-    event_handler=event_handler_plot,
-    num_passes=30)
+    reader=train_reader,
+    num_epochs=100,
+    event_handler=event_handler,
+    feed_order=feed_order)
```

Review comment (on `%matplotlib inline`): What does this line do?
Reply: It shows the plot inline in the Jupyter notebook; it helps with plotting the graphics inline.
Review comment: And it needs to be inside of a python block?
Reply: Yes.

![png](./image/train_and_test.png)

-### Apply model
+### Inference

+Initialize the Inferencer with the `inference_program` and `params_folder`, the directory where we saved the trained parameters.

-#### 1. generate testing data
+#### Set up the Inference Program

+Similar to `trainer.train`, the Inferencer takes an `inference_program` to perform inference. This program is essentially the `train_program` pruned so that only the `y_predict` part of the network remains.

```python
-test_data_creator = paddle.dataset.uci_housing.test()
-test_data = []
-test_label = []

-for item in test_data_creator():
-    test_data.append((item[0],))
-    test_label.append(item[1])
-    if len(test_data) == 5:
-        break
+def inference_program():
+    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
+    y_predict = fluid.layers.fc(input=x, size=1, act=None)
+    return y_predict
```

-#### 2. inference

```python
-# load parameters from tar file.
-# users can remove the comments and change the model name
-# with open('params_pass_20.tar', 'r') as f:
-#     parameters = paddle.parameters.Parameters.from_tar(f)
+inferencer = fluid.Inferencer(
+    infer_func=inference_program, param_path=params_folder, place=place)

-probs = paddle.infer(
-    output_layer=y_predict, parameters=parameters, input=test_data)
+batch_size = 10
+tensor_x = numpy.random.uniform(0, 10, [batch_size, 13]).astype("float32")

-for i in xrange(len(probs)):
-    print "label=" + str(test_label[i][0]) + ", predict=" + str(probs[i][0])
+results = inferencer.infer({'x': tensor_x})
+print("infer results: ", results[0])
```
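
The random inputs above exercise the API but are not realistic housing records. As a sketch (using the `(features, price)` record layout from the original chapter's test reader), the same Inferencer can be run on a few real test samples:

```python
# Run inference on five real test samples instead of random vectors.
sample_x = []
sample_y = []
for item in paddle.dataset.uci_housing.test()():
    sample_x.append(item[0])
    sample_y.append(item[1])
    if len(sample_x) == 5:
        break

real_x = numpy.array(sample_x).astype("float32")
results = inferencer.infer({'x': real_x})
for label, predict in zip(sample_y, results[0]):
    print("label=%s, predict=%s" % (label, predict))
```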

## Summary

Review comment: All these sections, including training and inference, are under the ## Dataset section; should we bring these one level up?