From 619046ef61d3a1856a386638fc9993c90b9f74dd Mon Sep 17 00:00:00 2001 From: dawn <34618110+dawn310826@users.noreply.github.com> Date: Thu, 30 Jun 2022 12:08:59 +0800 Subject: [PATCH] support checkpoint_path dir (#225) --- easy_rec/python/main.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/easy_rec/python/main.py b/easy_rec/python/main.py index caf4478ae..817f08387 100644 --- a/easy_rec/python/main.py +++ b/easy_rec/python/main.py @@ -222,7 +222,10 @@ def _check_model_dir(model_dir, continue_train): def _get_ckpt_path(pipeline_config, checkpoint_path): if checkpoint_path != '' and checkpoint_path is not None: - ckpt_path = checkpoint_path + if gfile.IsDirectory(checkpoint_path): + ckpt_path = estimator_utils.latest_checkpoint(checkpoint_path) + else: + ckpt_path = checkpoint_path elif gfile.IsDirectory(pipeline_config.model_dir): ckpt_path = tf.train.latest_checkpoint(pipeline_config.model_dir) logging.info('checkpoint_path is not specified, ' @@ -710,24 +713,21 @@ def export(export_dir, input_fn_kwargs['fg_json_path'] = pipeline_config.fg_json_path serving_input_fn = _get_input_fn(data_config, feature_configs, None, export_config, **input_fn_kwargs) + ckpt_path = _get_ckpt_path(pipeline_config, checkpoint_path) if 'oss_path' in extra_params: return export_big_model_to_oss(export_dir, pipeline_config, extra_params, - serving_input_fn, estimator, checkpoint_path, + serving_input_fn, estimator, ckpt_path, verbose) if 'redis_url' in extra_params: return export_big_model(export_dir, pipeline_config, extra_params, - serving_input_fn, estimator, checkpoint_path, + serving_input_fn, estimator, ckpt_path, verbose) - if not checkpoint_path: - checkpoint_path = estimator_utils.latest_checkpoint( - pipeline_config.model_dir) - final_export_dir = estimator.export_savedmodel( export_dir_base=export_dir, serving_input_receiver_fn=serving_input_fn, - checkpoint_path=checkpoint_path, + checkpoint_path=ckpt_path, strip_default_attrs=True) # add export ts as version info @@ -777,10 +777,11 @@ def export_checkpoint(pipeline_config=None, export_config = pipeline_config.export_config serving_input_fn = _get_input_fn(data_config, feature_configs, None, export_config, **input_fn_kwargs) + ckpt_path = _get_ckpt_path(pipeline_config, checkpoint_path) estimator.export_checkpoint( export_path=export_path, serving_input_receiver_fn=serving_input_fn, - checkpoint_path=checkpoint_path, + checkpoint_path=ckpt_path, mode=mode) logging.info('model checkpoint has been exported successfully')