Skip to content

Commit

Permalink
Merge pull request #10878 from seiriosPlus/new_api_about_cpkt
Browse files Browse the repository at this point in the history
New api about checkpoint and models
  • Loading branch information
seiriosPlus authored Jun 10, 2018
2 parents 7bcc980 + bf2c53a commit d896134
Show file tree
Hide file tree
Showing 4 changed files with 370 additions and 83 deletions.
1 change: 1 addition & 0 deletions python/paddle/fluid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from trainer import EndEpochEvent
from trainer import BeginStepEvent
from trainer import EndStepEvent
from trainer import CheckpointConfig

import inferencer
from inferencer import Inferencer
Expand Down
246 changes: 174 additions & 72 deletions python/paddle/fluid/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables', 'save_inference_model', 'load_inference_model',
'get_inference_program', 'save_checkpoint', 'load_checkpoint',
'clean_checkpoint'
'clean_checkpoint', 'load_persist_vars_without_grad',
'save_persist_vars_without_grad', 'get_latest_checkpoint_serial'
]


Expand Down Expand Up @@ -457,95 +458,161 @@ def get_parameter_value_by_name(name, executor, program=None):

SUCCESS_MARK_FILENAME = "_SUCCESS"
CHECKPOINT_PREFIX = "checkpoint"
MODEL_DIR = "__model__"
TRAINER_PREFIX = "trainer"
CHECKPOINT_SEPARATOR = "_"


def save_checkpoint(executor,
checkpoint_dir=None,
max_num_checkpoints=3,
save_interval_secs=600,
main_program=None):
checkpoint_dir,
trainer_id,
trainer_args=None,
main_program=None,
max_num_checkpoints=3):
"""
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most,
The interval between two saved checkpoints must greater than save_interval_secs.
:param executor
:param checkpoint_dir
:param max_num_checkpoints
:param save_interval_secs
:param main_program
:param executor executor for save the value
:param checkpoint_dir the checkpoint directory
:param trainer_id currect trainer id, if id is equal to 0, the trainer is chief
:param main_program will save all variables in program
:param max_num_checkpoints will keep numbers of checkpoint serials not bigger than max_num_checkpoints
"""
if checkpoint_dir is None:
checkpoint_dir = os.getcwd()
raise ValueError("'checkpoint_dir' should not be None")

if trainer_args:
assert isinstance(trainer_args, dict)

if not os.path.isdir(checkpoint_dir):
os.makedirs(checkpoint_dir)

serial = _get_lastest_checkpoint_dir(checkpoint_dir)
if serial >= 0 and not _interval_secs_exceed(
_get_serial_dir(serial, checkpoint_dir), save_interval_secs):
return
serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial)

serial += 1
cur_dir = _get_serial_dir(serial, checkpoint_dir)
save_trainer_args(cur_dir, trainer_id, trainer_args)

save_vars(
executor,
dirname=cur_dir,
main_program=main_program,
vars=None,
predicate=_is_checkpoint_var,
filename=None)
_write_success(cur_dir)
_lru_delete(checkpoint_dir, max_num_checkpoints)
if trainer_id == 0:
save_persist_vars_without_grad(executor, cur_dir, main_program)

_scroll_delete(checkpoint_dir, max_num_checkpoints)


def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
def load_checkpoint(executor, checkpoint_dir, serial, main_program):
"""
Load checkpoint from a directory by executor,
it will find the most recent saved checkpoint file and load it auto.
:param executor
:param checkpoint_dir
:param main_program
:param executor executor for load the value
:param checkpoint_dir the checkpoint directory
:param serial the serial folder in checkpoint directory will be load
:param main_program will load all variables in program
"""

if checkpoint_dir is None:
checkpoint_dir = os.getcwd()
raise ValueError("'checkpoint_dir' should not be None")

serial = _get_lastest_checkpoint_dir(checkpoint_dir)
if serial is None or serial < 0:
raise ValueError("'serial' should not be None or <0 ")

if serial < 0:
return
if main_program is None:
raise ValueError('main_program should not be None.')

cur_dir = _get_serial_dir(serial, checkpoint_dir)

load_vars(
executor,
dirname=cur_dir,
main_program=main_program,
predicate=_is_checkpoint_var,
filename=None)
cur_dir = _get_serial_dir(checkpoint_dir, serial)
load_persist_vars_without_grad(executor, cur_dir, main_program, True)


def clean_checkpoint(checkpoint_dir, delete_dir=False):
"""
clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before.
delete_dir only works when the directory is empty, otherwise, OSError is raised.
:param checkpoint_dir
:param delete_dir
"""

if checkpoint_dir is None:
checkpoint_dir = os.getcwd()
_lru_delete(checkpoint_dir, max_num_checkpoints=0)
raise ValueError("'checkpoint_dir' should not be None")
_scroll_delete(checkpoint_dir, max_num_checkpoints=0)

if delete_dir and not os.listdir(checkpoint_dir):
os.rmdir(checkpoint_dir)


def _get_serial_dir(serial, checkpoint_dir):
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
return os.path.join(checkpoint_dir, serial_folder)
def load_persist_vars_without_grad(executor,
dirname,
program,
has_model_dir=False):
"""
load_persist_vars_without_grad will load variables from a directory by an executor,
the variable named end with "@GRAD" will not be loaded.
:param executor executor for load the value
:param dirname the checkpoint directory
:param program will load all variables in program
:param has_model_dir if has_model_dir is True, will load variables from sub directory named __model__
"""

if has_model_dir:
dirname = _get_model_dir(dirname)

load_vars(
executor,
dirname=dirname,
main_program=program,
predicate=_is_checkpoint_var,
filename=None)


def save_persist_vars_without_grad(executor, dirname, program):
"""
save_persist_vars_without_grad will save variables to a directory by an executor,
the variable named end with "@GRAD" will not be saved.
:param executor executor for load the value
:param dirname the checkpoint directory
:param program will load all variables in program
"""
cur_dir = _get_model_dir(dirname)
save_vars(
executor,
dirname=cur_dir,
main_program=program,
vars=None,
predicate=_is_checkpoint_var,
filename=None)
_write_success(cur_dir)


def save_trainer_args(dirname, trainer_id, trainer_args):
assert isinstance(trainer_args, dict)

cur_dir = _get_trainer_dir(dirname, trainer_id)

for name, value in trainer_args.iteritems():
args_file = os.path.join(cur_dir, name)
with open(args_file, 'w') as f:
f.write(str(value))
_write_success(cur_dir)


def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
assert isinstance(trainer_args, list)

cur_dir = _get_serial_dir(checkpoint_dir, serial)
cur_dir = _get_trainer_dir(cur_dir, trainer_id)

ret_values = []

for arg in trainer_args:
cur_file = os.path.join(cur_dir, arg)
with open(cur_file, 'r') as f:
contents = f.read()
ret_values.append(contents.strip())
return ret_values


def _is_checkpoint_var(var):
Expand All @@ -559,36 +626,74 @@ def _is_checkpoint_var(var):
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.RAW:
return False
# @GRAD are named for gradient variables, checkpoint will not save it.
if "@GRAD" in var.name:
return False
# .trainer_ are named for distribute train variables, checkpoint will not save it.
if ".trainer_" in var.name:
return False

if var.name.endswith("@GRAD"):
# .block is named for distribute train variables, checkpoint will not save it.
if ".block" in var.name:
return False

return var.persistable


def _interval_secs_exceed(dirname, save_interval_secs):
dir_time = os.path.getmtime(dirname)
if save_interval_secs > (time.time() - dir_time):
return False
return True
def _get_dir_serial(dirname):
_, serial = dirname.split(CHECKPOINT_SEPARATOR)

try:
serial_num = int(serial)
except ValueError:
serial_num = -1
return serial_num


def _get_serial_dir(dirname, serial):
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
serial_dir = os.path.join(dirname, serial_folder)

if not os.path.isdir(serial_dir):
os.makedirs(serial_dir)

return serial_dir


def _get_model_dir(dirname):
model_dir = os.path.join(dirname, MODEL_DIR)

def _lru_delete(dirname, max_num_checkpoints=3):
if not os.path.isdir(model_dir):
os.makedirs(model_dir)

return model_dir


def _get_trainer_dir(dirname, trainer_id):
trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id)
trainer_dir = os.path.join(dirname, trainer_folder)

if not os.path.isdir(trainer_dir):
os.makedirs(trainer_dir)

return trainer_dir


def _scroll_delete(dirname, max_num_checkpoints=3):
dirs = os.listdir(dirname)
serials = []
serial_map = {}
for serial in dirs:
try:
serials.append(int(serial))
except ValueError:
continue
serial_num = _get_dir_serial(serial)
serial_map[serial_num] = serial

if len(serials) <= max_num_checkpoints:
if len(serial_map.keys()) <= max_num_checkpoints:
return

serials = serial_map.keys()
serials.sort(reverse=True)
serials = serials[max_num_checkpoints:]
for serial in serials:
cur_dir = os.path.join(dirname, str(serial))
cur_dir = _get_serial_dir(dirname, serial)
shutil.rmtree(cur_dir)


Expand All @@ -604,33 +709,30 @@ def _write_success(dirname):
f.write(now)


def _get_lastest_checkpoint_dir(checkpoint_dir):
def get_latest_checkpoint_serial(checkpoint_dir):
"""
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
:param checkpoint_dir
"""
if not checkpoint_dir.strip():
if not checkpoint_dir:
return -1

def has_success(checkpoint_dir, cur_dir):
"""
is _SUCCESS in this dir
"""
_, serial = cur_dir.split(CHECKPOINT_SEPARATOR)

try:
int(serial)
except ValueError:
return -1

if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
serial = _get_dir_serial(cur_dir)
if serial == -1 or not os.path.isdir(
os.path.join(checkpoint_dir, cur_dir)):
return -1

success_path = os.path.join(
_get_serial_dir(serial, checkpoint_dir), SUCCESS_MARK_FILENAME)
_get_serial_dir(checkpoint_dir, serial), MODEL_DIR,
SUCCESS_MARK_FILENAME)
if os.path.isfile(success_path):
return int(serial)
return serial

if not os.path.isdir(checkpoint_dir):
return -1
Expand Down
Loading

0 comments on commit d896134

Please sign in to comment.