Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revise python save load api using new load/save op #7995

Merged
merged 10 commits into from
Feb 1, 2018
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/paddle/v2/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def find_name(var_list, name):
no_kernel_op_set = {
'feed', 'fetch', 'save', 'load', 'recurrent',
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send',
'recv', 'parallel_do'
'recv', 'parallel_do', 'save_combine', 'load_combine'
}
if type not in no_kernel_op_set:
self.desc.infer_var_type(self.block.desc)
Expand Down
141 changes: 95 additions & 46 deletions python/paddle/v2/fluid/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def is_parameter(var):


def is_persistable(var):
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST:
return False
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Xreki I have changed the code and go with option 3 using your suggestion. For option 2, there is problem. Because in the python side of the code, the operator op field of var will only be associated with the operator that have this variable as its output. So for feed variable, since it is not the output of any operator. Its op data member will be None.

return var.persistable


Expand All @@ -60,7 +63,12 @@ def _clone_var_in_block_(block, var):
persistable=True)


def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
def save_vars(executor,
dirname,
main_program=None,
vars=None,
predicate=None,
save_file_name=None):
"""
Save variables to directory by executor.

Expand All @@ -69,9 +77,12 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
:param main_program: program. If vars is None, then filter all variables in this
program which fit `predicate`. Default default_main_program.
:param predicate: The Predicate describes a callable that returns a variable
as a bool. If it returns true, the variables will be saved.
:param vars: variables need to be saved. If specify vars, program & predicate
as a bool. If it returns true, the corresponding input variable will be saved.
:param vars: variables need to be saved. If vars is specified, program & predicate
will be ignored
:param save_file_name: The name of a single file that all vars are saved to.
If it is None, save variables to separate files.

:return: None
"""
if vars is None:
Expand All @@ -83,21 +94,39 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
save_vars(
executor,
dirname=dirname,
vars=filter(predicate, main_program.list_vars()))
vars=filter(predicate, main_program.list_vars()),
save_file_name=save_file_name)
else:
save_program = Program()
save_block = save_program.global_block()

save_var_map = {}
for each_var in vars:
new_var = _clone_var_in_block_(save_block, each_var)
if save_file_name is None:
save_block.append_op(
type='save',
inputs={'X': [new_var]},
outputs={},
attrs={'file_path': os.path.join(dirname, new_var.name)})
else:
save_var_map[new_var.name] = new_var

if save_file_name is not None:
save_var_list = []
for name in sorted(save_var_map.keys()):
save_var_list.append(save_var_map[name])

save_block.append_op(
type='save',
inputs={'X': [new_var]},
type='save_combine',
inputs={'X': save_var_list},
outputs={},
attrs={'file_path': os.path.join(dirname, new_var.name)})
attrs={'file_path': os.path.join(dirname, save_file_name)})

executor.run(save_program)


def save_params(executor, dirname, main_program=None):
def save_params(executor, dirname, main_program=None, save_file_name=None):
"""
Save all parameters to directory with executor.
"""
Expand All @@ -106,10 +135,12 @@ def save_params(executor, dirname, main_program=None):
dirname=dirname,
main_program=main_program,
vars=None,
predicate=is_parameter)
predicate=is_parameter,
save_file_name=save_file_name)


def save_persistables(executor, dirname, main_program=None):
def save_persistables(executor, dirname, main_program=None,
save_file_name=None):
"""
Save all persistables to directory with executor.
"""
Expand All @@ -118,21 +149,30 @@ def save_persistables(executor, dirname, main_program=None):
dirname=dirname,
main_program=main_program,
vars=None,
predicate=is_persistable)
predicate=is_persistable,
save_file_name=save_file_name)


def load_vars(executor, dirname, main_program=None, vars=None, predicate=None):
def load_vars(executor,
dirname,
main_program=None,
vars=None,
predicate=None,
load_file_name=None):
"""
Load variables from directory by executor.

:param executor: executor that save variable
:param executor: executor that load variable
:param dirname: directory path
:param main_program: program. If vars is None, then filter all variables in this
program which fit `predicate`. Default default_main_program().
:param predicate: The Predicate describes a callable that returns a variable
as a bool. If it returns true, the variables will be loaded.
:param vars: variables need to be loaded. If specify vars, program &
as a bool. If it returns true, the corresponding input variable will be loaded.
:param vars: variables need to be loaded. If vars is specified, program &
predicate will be ignored
:param load_file_name: The name of the single file that all vars are loaded from.
If it is None, load variables from separate files.

:return: None
"""
if vars is None:
Expand All @@ -144,42 +184,62 @@ def load_vars(executor, dirname, main_program=None, vars=None, predicate=None):
load_vars(
executor,
dirname=dirname,
vars=filter(predicate, main_program.list_vars()))
vars=filter(predicate, main_program.list_vars()),
load_file_name=load_file_name)
else:
load_prog = Program()
load_block = load_prog.global_block()

load_var_map = {}
for each_var in vars:
assert isinstance(each_var, Variable)
new_var = _clone_var_in_block_(load_block, each_var)
if load_file_name is None:
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [new_var]},
attrs={'file_path': os.path.join(dirname, new_var.name)})
else:
load_var_map[new_var.name] = new_var

if load_file_name is not None:
load_var_list = []
for name in sorted(load_var_map.keys()):
load_var_list.append(load_var_map[name])

load_block.append_op(
type='load',
type='load_combine',
inputs={},
outputs={"Out": [new_var]},
attrs={'file_path': os.path.join(dirname, new_var.name)})
outputs={"Out": load_var_list},
attrs={'file_path': os.path.join(dirname, load_file_name)})

executor.run(load_prog)


def load_params(executor, dirname, main_program=None):
def load_params(executor, dirname, main_program=None, load_file_name=None):
"""
load all parameters from directory by executor.
"""
load_vars(
executor,
dirname=dirname,
main_program=main_program,
predicate=is_parameter)
predicate=is_parameter,
load_file_name=load_file_name)


def load_persistables(executor, dirname, main_program=None):
def load_persistables(executor, dirname, main_program=None,
load_file_name=None):
"""
load all persistables from directory by executor.
"""
load_vars(
executor,
dirname=dirname,
main_program=main_program,
predicate=is_persistable)
predicate=is_persistable,
load_file_name=load_file_name)


def get_inference_program(target_vars, main_program=None):
Expand Down Expand Up @@ -238,7 +298,8 @@ def save_inference_model(dirname,
feeded_var_names,
target_vars,
executor,
main_program=None):
main_program=None,
save_file_name=None):
"""
Build a model especially for inference,
and save it to directory by the executor.
Expand All @@ -249,6 +310,8 @@ def save_inference_model(dirname,
:param executor: executor that save inference model
:param main_program: original program, which will be pruned to build the inference model.
Default default_main_program().
:param save_file_name: The name of a single file that all parameters are saved to.
If it is None, save parameters to separate files.

:return: None
"""
Expand Down Expand Up @@ -283,25 +346,7 @@ def save_inference_model(dirname,
with open(model_file_name, "wb") as f:
f.write(inference_program.desc.serialize_to_string())

save_params(executor, dirname, main_program)


def load_persistables_if_exist(executor, dirname, main_program=None):
filenames = next(os.walk(dirname))[2]
filenames = set(filenames)

def _is_presistable_and_exist_(var):
if not is_persistable(var):
return False
else:
return var.name in filenames

load_vars(
executor,
dirname,
main_program=main_program,
vars=None,
predicate=_is_presistable_and_exist_)
save_persistables(executor, dirname, inference_program, save_file_name)


def get_feed_targets_names(program):
Expand All @@ -322,13 +367,15 @@ def get_fetch_targets_names(program):
return fetch_targets_names


def load_inference_model(dirname, executor):
def load_inference_model(dirname, executor, load_file_name=None):
"""
Load inference model from a directory

:param dirname: directory path
:param executor: executor that load inference model

:param load_file_name: The name of the single file that all parameters are loaded from.
If it is None, load parameters from separate files.

:return: [program, feed_target_names, fetch_targets]
program: program especially for inference.
feed_target_names: Names of variables that need to feed data
Expand All @@ -342,7 +389,7 @@ def load_inference_model(dirname, executor):
program_desc_str = f.read()

program = Program.parse_from_string(program_desc_str)
load_persistables_if_exist(executor, dirname, program)
load_persistables(executor, dirname, program, load_file_name)

feed_target_names = get_feed_targets_names(program)
fetch_target_names = get_fetch_targets_names(program)
Expand All @@ -359,6 +406,7 @@ def get_parameter_value(para, executor):

:param executor: executor for retrieving the value
:param para: the given parameter

:return: the LoDTensor for the parameter
"""
assert is_parameter(para)
Expand All @@ -377,6 +425,7 @@ def get_parameter_value_by_name(name, executor, program=None):
:param name: the name of the parameter
:param program: the program where the variable is found
Default default_main_program().

:return: the LoDTensor for the variable
"""
if program is None:
Expand Down