Skip to content

Commit

Permalink
Add limit support for load_dygraph loading jit.save result (#25935)
Browse files Browse the repository at this point in the history
* add limit support for load_dygraph loading jit.save result

* simplify unittest

* add unittests for coverage

* remove encoding limit of loading extra var info
  • Loading branch information
chenwhql authored Aug 7, 2020
1 parent 12bf9d7 commit 3eee046
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 34 deletions.
94 changes: 78 additions & 16 deletions python/paddle/fluid/dygraph/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@

import os
import collections
from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase
from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase, _varbase_creator, _dygraph_tracer
import pickle
import six
from . import learning_rate_scheduler
import warnings
from .. import core
from paddle.fluid.dygraph.io import VARIABLE_FILENAME, EXTRA_VAR_INFO_FILENAME, _load_persistable_vars

__all__ = [
'save_dygraph',
Expand Down Expand Up @@ -140,22 +141,83 @@ def load_dygraph(model_path, keep_name_table=False):
elif model_prefix.endswith(".pdopt"):
model_prefix = model_prefix[:-6]

params_file_path = model_prefix + ".pdparams"
if not os.path.exists(params_file_path):
raise RuntimeError("Parameter file [ {} ] not exists".format(
params_file_path))

with open(params_file_path, 'rb') as f:
para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')

if not keep_name_table and "StructuredToParameterName@@" in para_dict:
del para_dict["StructuredToParameterName@@"]
para_dict = None
opti_dict = None
params_file_path = model_prefix + ".pdparams"
opti_file_path = model_prefix + ".pdopt"
if os.path.exists(opti_file_path):
with open(opti_file_path, 'rb') as f:
opti_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
if not os.path.exists(params_file_path) and not os.path.exists(
opti_file_path):
# Load state dict by `jit.save` save format
# TODO(chenweihang): [Why not support `io.save_inference_model` save format here]
# The model saved by `save_inference_model` does not completely correspond to
# the information required by the `state_dict` under the dygraph.
# Although we reluctantly restore the `state_dict` in some scenarios,
# this may not be complete and there are some limitations, so this function
# will be considered later. The limitations include:
# 1. `save_inference_model` not save structured name, we need to remind
# the user to configure the `use_structured_name` argument when `set_dict`,
# but this argument is currently not public
# 2. if `save_inference_model` save all persistable variables in a single file,
# user need to give the variable name list to load `state_dict`

# 1. check model path
if not os.path.isdir(model_prefix):
raise ValueError("Model saved directory '%s' is not exists." %
model_prefix)
# 2. load `__variables.info__`
var_info_path = os.path.join(model_prefix, EXTRA_VAR_INFO_FILENAME)
if not os.path.exists(var_info_path):
raise RuntimeError(
"No target can be loaded. Now only supports loading `state_dict` from "
"the result saved by `imperative.save` and `imperative.jit.save`."
)
with open(var_info_path, 'rb') as f:
extra_var_info = pickle.load(f)
# 3. load `__variables__`
# TODO(chenweihang): now only supports loading from default save format:
# - all persistable vars saved in one file named `__variables__`
# for other case, we may need to modify the arguments of this API
var_file_path = os.path.join(model_prefix, VARIABLE_FILENAME)
if not os.path.exists(var_file_path):
raise RuntimeError(
"The parameter file to be loaded was not found. "
"Now only supports loading from the default save format, "
"and does not support custom params_filename and "
"save parameters separately.")
# 4. load all persistable vars
load_var_list = []
for name in sorted(extra_var_info):
var = _varbase_creator(name=name, persistable=True)
load_var_list.append(var)
_dygraph_tracer().trace_op(
type='load_combine',
inputs={},
outputs={'Out': load_var_list},
attrs={'file_path': var_file_path})
# 5. construct state_dict
para_dict = dict()
for var in load_var_list:
structured_name = extra_var_info[var.name].get('structured_name',
None)
if structured_name is None:
raise RuntimeError(
"Cannot find saved variable (%s)'s structured name in saved model.",
var.name)
para_dict[structured_name] = var.numpy()
# NOTE: `jit.save` doesn't save optimizer state
else:
# Load state dict by `save_dygraph` save format
if os.path.exists(params_file_path):
with open(params_file_path, 'rb') as f:
para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')

if not keep_name_table and "StructuredToParameterName@@" in para_dict:
del para_dict["StructuredToParameterName@@"]

if os.path.exists(opti_file_path):
with open(opti_file_path, 'rb') as f:
opti_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')

return para_dict, opti_dict
3 changes: 1 addition & 2 deletions python/paddle/fluid/dygraph/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,8 +425,7 @@ def _load_persistable_vars(model_path,
params_filename=None):
# 1. load extra var info
with open(var_info_path, 'rb') as f:
extra_var_info = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
extra_var_info = pickle.load(f)

# 2. construct var dict
load_var_dict = dict()
Expand Down
78 changes: 62 additions & 16 deletions python/paddle/fluid/tests/unittests/test_jit_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

from __future__ import print_function

import os
import unittest
import numpy as np

import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear
from paddle.fluid.dygraph import declarative
from paddle.fluid.dygraph import declarative, ProgramTranslator
from paddle.fluid.dygraph.io import VARIABLE_FILENAME, EXTRA_VAR_INFO_FILENAME

BATCH_SIZE = 32
BATCH_NUM = 20
Expand Down Expand Up @@ -77,8 +79,8 @@ def forward(self, x):

def train(layer):
# create optimizer
adam = fluid.optimizer.AdamOptimizer(
learning_rate=0.1, parameter_list=layer.parameters())
adam = fluid.optimizer.SGDOptimizer(
learning_rate=0.01, parameter_list=layer.parameters())
# create data loader
train_loader = fluid.io.DataLoader.from_generator(capacity=5)
train_loader.set_batch_generator(random_batch_reader())
Expand Down Expand Up @@ -111,44 +113,63 @@ def setUp(self):
# config seed
fluid.default_main_program().random_seed = SEED

def train_and_save_model(self):
def train_and_save_model(self, model_path=None, configs=None):
layer = LinearNet(784, 1)
example_inputs, layer, _ = train(layer)
final_model_path = model_path if model_path else self.model_path
orig_input_types = [type(x) for x in example_inputs]
fluid.dygraph.jit.save(
layer=layer, model_path=self.model_path, input_spec=example_inputs)
layer=layer,
model_path=final_model_path,
input_spec=example_inputs,
configs=configs)
new_input_types = [type(x) for x in example_inputs]
self.assertEqual(orig_input_types, new_input_types)
return layer

def test_save(self):
    """Smoke test: training a layer and saving it via jit.save must not raise."""
    # The saved artifacts themselves are inspected by the other tests;
    # here we only verify that the save path completes without error.
    self.train_and_save_model()

def test_load_infernece(self):
def test_save_load(self):
# train and save model
train_layer = self.train_and_save_model()
# load model
infer_layer = fluid.dygraph.jit.load(self.model_path)
program_translator = ProgramTranslator()
program_translator.enable(False)
loaded_layer = fluid.dygraph.jit.load(self.model_path)
self.load_and_inference(train_layer, loaded_layer)
self.load_dygraph_state_dict(train_layer)
self.load_and_finetune(train_layer, loaded_layer)
program_translator.enable(True)

def load_and_inference(self, train_layer, infer_layer):
    """Check the loaded layer reproduces the trained layer's forward output."""
    # Put both layers in eval mode so inference is deterministic.
    train_layer.eval()
    infer_layer.eval()
    # Push one random batch through both layers and require exact equality.
    batch = np.random.random((1, 784)).astype('float32')
    x = fluid.dygraph.to_variable(batch)
    self.assertTrue(
        np.array_equal(train_layer(x).numpy(), infer_layer(x).numpy()))

def test_load_finetune(self):
# train and save model
train_layer = self.train_and_save_model()
# load model
load_train_layer = fluid.dygraph.jit.load(self.model_path)
def load_and_finetune(self, train_layer, load_train_layer):
    """Continue training both the original and the loaded layer; losses must match."""
    # Switch both layers back into training mode before fine-tuning.
    for layer in (train_layer, load_train_layer):
        layer.train()
    # Train each layer further with the module-level train() helper;
    # with a fixed seed both runs should yield identical final losses.
    _, _, baseline_loss = train(train_layer)
    _, _, loaded_loss = train(load_train_layer)
    self.assertTrue(
        np.array_equal(baseline_loss.numpy(), loaded_loss.numpy()))

def load_dygraph_state_dict(self, train_layer):
    """Load the jit.save result via load_dygraph into a fresh layer and compare outputs."""
    train_layer.eval()
    # Construct a brand-new layer and initialize it from the saved state dict.
    new_layer = LinearNet(784, 1)
    model_dict, _ = fluid.dygraph.load_dygraph(self.model_path)
    new_layer.set_dict(model_dict)
    new_layer.eval()
    # Run one random batch through both layers; outputs must match exactly.
    sample = np.random.random((1, 784)).astype('float32')
    x = fluid.dygraph.to_variable(sample)
    self.assertTrue(
        np.array_equal(train_layer(x).numpy(), new_layer(x).numpy()))

def test_save_get_program_failed(self):
layer = LinearNetNotDeclarative(784, 1)
example_inputs, layer, _ = train(layer)
Expand All @@ -158,6 +179,31 @@ def test_save_get_program_failed(self):
model_path=self.model_path,
input_spec=example_inputs)

def test_load_dygraoh_no_path(self):
    """load_dygraph on a non-existent model directory must raise ValueError."""
    # NOTE(review): "dygraoh" in the method name is a typo for "dygraph";
    # left unchanged here to keep the test's public name stable.
    missing_path = "model.test_jit_save_load.no_path"
    new_layer = LinearNet(784, 1)
    with self.assertRaises(ValueError):
        model_dict, _ = fluid.dygraph.load_dygraph(missing_path)

def test_load_dygraph_no_var_info(self):
    """load_dygraph must raise RuntimeError when `__variables.info__` is missing."""
    model_path = "model.test_jit_save_load.no_var_info"
    self.train_and_save_model(model_path=model_path)
    # Delete the extra var info file to simulate an unsupported save directory.
    os.remove(os.path.join(model_path, EXTRA_VAR_INFO_FILENAME))
    new_layer = LinearNet(784, 1)
    with self.assertRaises(RuntimeError):
        model_dict, _ = fluid.dygraph.load_dygraph(model_path)

def test_load_dygraph_not_var_file(self):
    """load_dygraph must raise RuntimeError when params were saved under a custom filename."""
    model_path = "model.test_jit_save_load.no_var_file"
    # Save with a non-default params filename; load_dygraph only supports
    # the default `__variables__` file, so loading must fail.
    configs = fluid.dygraph.jit.SaveLoadConfig()
    configs.params_filename = "__params__"
    self.train_and_save_model(model_path=model_path, configs=configs)
    new_layer = LinearNet(784, 1)
    with self.assertRaises(RuntimeError):
        model_dict, _ = fluid.dygraph.load_dygraph(model_path)


class TestJitSaveLoadConfig(unittest.TestCase):
def setUp(self):
Expand Down

1 comment on commit 3eee046

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

✅❤️ Bravo! Your pull request passed all CI. 💙

Please sign in to comment.