Skip to content

Commit

Permalink
tf: support checkpoint path (instead of directory) in dp freeze (#3254)
Browse files Browse the repository at this point in the history
To have the same behavior between TF and PT.

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Feb 11, 2024
1 parent 0e2304f commit dc63793
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def main_parser() -> argparse.ArgumentParser:
"--checkpoint",
type=str,
default=".",
help="Path to checkpoint. TensorFlow backend: a folder; PyTorch backend: either a folder containing checkpoint, or a pt file",
help="Path to checkpoint, either a folder containing checkpoint or the checkpoint prefix",
)
parser_frz.add_argument(
"-o",
Expand Down
12 changes: 9 additions & 3 deletions deepmd/tf/entrypoints/freeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from os.path import (
abspath,
)
from pathlib import (
Path,
)
from typing import (
List,
Optional,
Expand Down Expand Up @@ -479,7 +482,7 @@ def freeze(
Parameters
----------
checkpoint_folder : str
location of the folder with model
location of either the folder with checkpoint or the checkpoint prefix
output : str
output file name
node_names : Optional[str], optional
Expand All @@ -492,8 +495,11 @@ def freeze(
other arguments
"""
# We retrieve our checkpoint fullpath
checkpoint = tf.train.get_checkpoint_state(checkpoint_folder)
input_checkpoint = checkpoint.model_checkpoint_path
if Path(checkpoint_folder).is_dir():
checkpoint = tf.train.get_checkpoint_state(checkpoint_folder)
input_checkpoint = checkpoint.model_checkpoint_path
else:
input_checkpoint = checkpoint_folder

# expand the output file to full path
output_graph = abspath(output)
Expand Down

0 comments on commit dc63793

Please sign in to comment.