Skip to content

Commit

Permalink
support checkpoint_path dir (#225)
Browse files Browse the repository at this point in the history
  • Loading branch information
dawn310826 authored Jun 30, 2022
1 parent a0fd524 commit 619046e
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions easy_rec/python/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,10 @@ def _check_model_dir(model_dir, continue_train):

def _get_ckpt_path(pipeline_config, checkpoint_path):
if checkpoint_path != '' and checkpoint_path is not None:
ckpt_path = checkpoint_path
if gfile.IsDirectory(checkpoint_path):
ckpt_path = estimator_utils.latest_checkpoint(checkpoint_path)
else:
ckpt_path = checkpoint_path
elif gfile.IsDirectory(pipeline_config.model_dir):
ckpt_path = tf.train.latest_checkpoint(pipeline_config.model_dir)
logging.info('checkpoint_path is not specified, '
Expand Down Expand Up @@ -710,24 +713,21 @@ def export(export_dir,
input_fn_kwargs['fg_json_path'] = pipeline_config.fg_json_path
serving_input_fn = _get_input_fn(data_config, feature_configs, None,
export_config, **input_fn_kwargs)
ckpt_path = _get_ckpt_path(pipeline_config, checkpoint_path)
if 'oss_path' in extra_params:
return export_big_model_to_oss(export_dir, pipeline_config, extra_params,
serving_input_fn, estimator, checkpoint_path,
serving_input_fn, estimator, ckpt_path,
verbose)

if 'redis_url' in extra_params:
return export_big_model(export_dir, pipeline_config, extra_params,
serving_input_fn, estimator, checkpoint_path,
serving_input_fn, estimator, ckpt_path,
verbose)

if not checkpoint_path:
checkpoint_path = estimator_utils.latest_checkpoint(
pipeline_config.model_dir)

final_export_dir = estimator.export_savedmodel(
export_dir_base=export_dir,
serving_input_receiver_fn=serving_input_fn,
checkpoint_path=checkpoint_path,
checkpoint_path=ckpt_path,
strip_default_attrs=True)

# add export ts as version info
Expand Down Expand Up @@ -777,10 +777,11 @@ def export_checkpoint(pipeline_config=None,
export_config = pipeline_config.export_config
serving_input_fn = _get_input_fn(data_config, feature_configs, None,
export_config, **input_fn_kwargs)
ckpt_path = _get_ckpt_path(pipeline_config, checkpoint_path)
estimator.export_checkpoint(
export_path=export_path,
serving_input_receiver_fn=serving_input_fn,
checkpoint_path=checkpoint_path,
checkpoint_path=ckpt_path,
mode=mode)

logging.info('model checkpoint has been exported successfully')

0 comments on commit 619046e

Please sign in to comment.