-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Revise python save load api using new load/save op #7995
Changes from 7 commits
16f62cc
5084ff3
32d7551
93ecb79
a9082bf
57ede7f
e6fba90
b7565ea
7e200ec
7db51df
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -60,7 +60,12 @@ def _clone_var_in_block_(block, var): | |
persistable=True) | ||
|
||
|
||
def save_vars(executor, dirname, main_program=None, vars=None, predicate=None): | ||
def save_vars(executor, | ||
dirname, | ||
main_program=None, | ||
vars=None, | ||
predicate=None, | ||
save_file_name=None): | ||
""" | ||
Save variables to directory by executor. | ||
|
||
|
@@ -69,9 +74,12 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None): | |
:param main_program: program. If vars is None, then filter all variables in this | ||
program which fit `predicate`. Default default_main_program. | ||
:param predicate: The Predicate describes a callable that returns a variable | ||
as a bool. If it returns true, the variables will be saved. | ||
:param vars: variables need to be saved. If specify vars, program & predicate | ||
as a bool. If it returns true, the corresponding input variable will be saved. | ||
:param vars: variables need to be saved. If vars is specified, program & predicate | ||
will be ignored | ||
:param save_file_name: The name of a single file that all vars are saved to. | ||
If it is None, save variables to separate files. | ||
|
||
:return: None | ||
""" | ||
if vars is None: | ||
|
@@ -83,21 +91,40 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None): | |
save_vars( | ||
executor, | ||
dirname=dirname, | ||
vars=filter(predicate, main_program.list_vars())) | ||
vars=filter(predicate, main_program.list_vars()), | ||
save_file_name=save_file_name) | ||
else: | ||
save_program = Program() | ||
save_block = save_program.global_block() | ||
for each_var in vars: | ||
new_var = _clone_var_in_block_(save_block, each_var) | ||
|
||
if save_file_name is None: | ||
for each_var in vars: | ||
new_var = _clone_var_in_block_(save_block, each_var) | ||
save_block.append_op( | ||
type='save', | ||
inputs={'X': [new_var]}, | ||
outputs={}, | ||
attrs={'file_path': os.path.join(dirname, new_var.name)}) | ||
else: | ||
save_var_map = {} | ||
for each_var in vars: | ||
new_var = _clone_var_in_block_(save_block, each_var) | ||
save_var_map[new_var.name] = new_var | ||
|
||
save_var_list = [] | ||
for name in sorted(save_var_map.keys()): | ||
save_var_list.append(save_var_map[name]) | ||
|
||
save_block.append_op( | ||
type='save', | ||
inputs={'X': [new_var]}, | ||
type='save_combine', | ||
inputs={'X': save_var_list}, | ||
outputs={}, | ||
attrs={'file_path': os.path.join(dirname, new_var.name)}) | ||
attrs={'file_path': os.path.join(dirname, save_file_name)}) | ||
|
||
executor.run(save_program) | ||
|
||
|
||
def save_params(executor, dirname, main_program=None): | ||
def save_params(executor, dirname, main_program=None, save_file_name=None): | ||
""" | ||
Save all parameters to directory with executor. | ||
""" | ||
|
@@ -106,10 +133,12 @@ def save_params(executor, dirname, main_program=None): | |
dirname=dirname, | ||
main_program=main_program, | ||
vars=None, | ||
predicate=is_parameter) | ||
predicate=is_parameter, | ||
save_file_name=save_file_name) | ||
|
||
|
||
def save_persistables(executor, dirname, main_program=None): | ||
def save_persistables(executor, dirname, main_program=None, | ||
save_file_name=None): | ||
""" | ||
Save all persistables to directory with executor. | ||
""" | ||
|
@@ -118,21 +147,30 @@ def save_persistables(executor, dirname, main_program=None): | |
dirname=dirname, | ||
main_program=main_program, | ||
vars=None, | ||
predicate=is_persistable) | ||
predicate=is_persistable, | ||
save_file_name=save_file_name) | ||
|
||
|
||
def load_vars(executor, dirname, main_program=None, vars=None, predicate=None): | ||
def load_vars(executor, | ||
dirname, | ||
main_program=None, | ||
vars=None, | ||
predicate=None, | ||
load_file_name=None): | ||
""" | ||
Load variables from directory by executor. | ||
|
||
:param executor: executor that save variable | ||
:param executor: executor that load variable | ||
:param dirname: directory path | ||
:param main_program: program. If vars is None, then filter all variables in this | ||
program which fit `predicate`. Default default_main_program(). | ||
:param predicate: The Predicate describes a callable that returns a variable | ||
as a bool. If it returns true, the variables will be loaded. | ||
:param vars: variables need to be loaded. If specify vars, program & | ||
as a bool. If it returns true, the corresponding input variable will be loaded. | ||
:param vars: variables need to be loaded. If vars is specified, program & | ||
predicate will be ignored | ||
:param load_file_name: The name of the single file that all vars are loaded from. | ||
If it is None, load variables from separate files. | ||
|
||
:return: None | ||
""" | ||
if vars is None: | ||
|
@@ -144,42 +182,64 @@ def load_vars(executor, dirname, main_program=None, vars=None, predicate=None): | |
load_vars( | ||
executor, | ||
dirname=dirname, | ||
vars=filter(predicate, main_program.list_vars())) | ||
vars=filter(predicate, main_program.list_vars()), | ||
load_file_name=load_file_name) | ||
else: | ||
load_prog = Program() | ||
load_block = load_prog.global_block() | ||
for each_var in vars: | ||
assert isinstance(each_var, Variable) | ||
new_var = _clone_var_in_block_(load_block, each_var) | ||
|
||
if load_file_name is None: | ||
for each_var in vars: | ||
assert isinstance(each_var, Variable) | ||
new_var = _clone_var_in_block_(load_block, each_var) | ||
load_block.append_op( | ||
type='load', | ||
inputs={}, | ||
outputs={'Out': [new_var]}, | ||
attrs={'file_path': os.path.join(dirname, new_var.name)}) | ||
else: | ||
load_var_map = {} | ||
for each_var in vars: | ||
assert isinstance(each_var, Variable) | ||
new_var = _clone_var_in_block_(load_block, each_var) | ||
load_var_map[new_var.name] = new_var | ||
|
||
load_var_list = [] | ||
for name in sorted(load_var_map.keys()): | ||
load_var_list.append(load_var_map[name]) | ||
|
||
load_block.append_op( | ||
type='load', | ||
type='load_combine', | ||
inputs={}, | ||
outputs={"Out": [new_var]}, | ||
attrs={'file_path': os.path.join(dirname, new_var.name)}) | ||
outputs={"Out": load_var_list}, | ||
attrs={'file_path': os.path.join(dirname, load_file_name)}) | ||
|
||
executor.run(load_prog) | ||
|
||
|
||
def load_params(executor, dirname, main_program=None): | ||
def load_params(executor, dirname, main_program=None, load_file_name=None): | ||
""" | ||
load all parameters from directory by executor. | ||
""" | ||
load_vars( | ||
executor, | ||
dirname=dirname, | ||
main_program=main_program, | ||
predicate=is_parameter) | ||
predicate=is_parameter, | ||
load_file_name=load_file_name) | ||
|
||
|
||
def load_persistables(executor, dirname, main_program=None): | ||
def load_persistables(executor, dirname, main_program=None, | ||
load_file_name=None): | ||
""" | ||
load all persistables from directory by executor. | ||
""" | ||
load_vars( | ||
executor, | ||
dirname=dirname, | ||
main_program=main_program, | ||
predicate=is_persistable) | ||
predicate=is_persistable, | ||
load_file_name=load_file_name) | ||
|
||
|
||
def get_inference_program(target_vars, main_program=None): | ||
|
@@ -234,11 +294,27 @@ def append_fetch_ops(inference_program, | |
attrs={'col': i}) | ||
|
||
|
||
def get_parameters(program): | ||
parameter_list = [] | ||
input_args = set() | ||
for block in program.blocks: | ||
for op in block.ops: | ||
if op.desc.type() != 'feed': | ||
input_args.update(op.desc.input_arg_names()) | ||
|
||
for var in program.list_vars(): | ||
if is_persistable(var) and var.name in input_args: | ||
parameter_list.append(var) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is actually not a parameter list but a persistable variable list. Normally, the program should not contain unreferenced variables, so There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the explanation. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But I think we still need to exclude 'feed' and 'fetch' variables right (because they have been added to the program desc)? They are also persistable and we don't want to store them. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried to remove There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this implementation can potentially solve the problem described in PR #8020 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will try including these changes to actually verify. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
@kexinzhao Can we change the function
@sidgoyal78 I think the problem in #8020 is that, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah. 👍 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I can think of three ways to redefine load/save_persistables():
def load_persistables(xxx):
parameter_list = get_parameters(program)
load_var(xxx, parameter_list) which basically moved the usage of
So we cannot use code like below to exclude 'feed' and 'fetch' def is_persistable(var):
if var.op.desc.type() == 'feed' or var.op.desc.type() == 'fetch':
return False
return var.persistable
def is_persistable(var):
if var.desc.name() == 'feed' or var.desc.name() == 'fetch':
return False
return var.persistable If we want to go with this option, we can first do a quick fix in this PR using the code above. Then fix the feed/fetch var name, modify the API accordingly, set some global const kFeedVarName in C++ and pybind it to Python, etc., in a future PR. @Xreki @luotao1 @sidgoyal78, which option do you prefer, or do you have other suggestions? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd like the 2nd method. But I am not sure whether it is suitable to use For the 3rd method. I think it is not suitable to use the name, but maybe we can use the type, which should be |
||
|
||
return parameter_list | ||
|
||
|
||
def save_inference_model(dirname, | ||
feeded_var_names, | ||
target_vars, | ||
executor, | ||
main_program=None): | ||
main_program=None, | ||
save_file_name=None): | ||
""" | ||
Build a model especially for inference, | ||
and save it to directory by the executor. | ||
|
@@ -249,6 +325,8 @@ def save_inference_model(dirname, | |
:param executor: executor that save inference model | ||
:param main_program: original program, which will be pruned to build the inference model. | ||
Default default_main_program(). | ||
:param save_file_name: The name of a single file that all parameters are saved to. | ||
If it is None, save parameters to separate files. | ||
|
||
:return: None | ||
""" | ||
|
@@ -283,25 +361,13 @@ def save_inference_model(dirname, | |
with open(model_file_name, "wb") as f: | ||
f.write(inference_program.desc.serialize_to_string()) | ||
|
||
save_params(executor, dirname, main_program) | ||
|
||
|
||
def load_persistables_if_exist(executor, dirname, main_program=None): | ||
filenames = next(os.walk(dirname))[2] | ||
filenames = set(filenames) | ||
|
||
def _is_presistable_and_exist_(var): | ||
if not is_persistable(var): | ||
return False | ||
else: | ||
return var.name in filenames | ||
|
||
load_vars( | ||
parameter_list = get_parameters(inference_program) | ||
save_vars( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can try to call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
executor, | ||
dirname, | ||
main_program=main_program, | ||
vars=None, | ||
predicate=_is_presistable_and_exist_) | ||
inference_program, | ||
parameter_list, | ||
save_file_name=save_file_name) | ||
|
||
|
||
def get_feed_targets_names(program): | ||
|
@@ -322,13 +388,15 @@ def get_fetch_targets_names(program): | |
return fetch_targets_names | ||
|
||
|
||
def load_inference_model(dirname, executor): | ||
def load_inference_model(dirname, executor, load_file_name=None): | ||
""" | ||
Load inference model from a directory | ||
|
||
:param dirname: directory path | ||
:param executor: executor that load inference model | ||
|
||
:param load_file_name: The name of the single file that all parameters are loaded from. | ||
If it is None, load parameters from separate files. | ||
|
||
:return: [program, feed_target_names, fetch_targets] | ||
program: program especially for inference. | ||
feed_target_names: Names of variables that need to feed data | ||
|
@@ -342,7 +410,13 @@ def load_inference_model(dirname, executor): | |
program_desc_str = f.read() | ||
|
||
program = Program.parse_from_string(program_desc_str) | ||
load_persistables_if_exist(executor, dirname, program) | ||
parameter_list = get_parameters(program) | ||
load_vars( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can also try to call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
executor, | ||
dirname, | ||
program, | ||
parameter_list, | ||
load_file_name=load_file_name) | ||
|
||
feed_target_names = get_feed_targets_names(program) | ||
fetch_target_names = get_fetch_targets_names(program) | ||
|
@@ -359,6 +433,7 @@ def get_parameter_value(para, executor): | |
|
||
:param executor: executor for retrieving the value | ||
:param para: the given parameter | ||
|
||
:return: the LoDTensor for the parameter | ||
""" | ||
assert is_parameter(para) | ||
|
@@ -377,6 +452,7 @@ def get_parameter_value_by_name(name, executor, program=None): | |
:param name: the name of the parameter | ||
:param program: the program where the variable is found | ||
Default default_main_program(). | ||
|
||
:return: the LoDTensor for the variable | ||
""" | ||
if program is None: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move the common code (lines 202-204) out of the `if` statement?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done