Skip to content

Commit

Permalink
Check that all arguments are valid
Browse files Browse the repository at this point in the history
  • Loading branch information
lopuhin committed Mar 20, 2019
1 parent ab2c05f commit 694bab6
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
35 changes: 35 additions & 0 deletions lm/fire_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
Helpers methods for interacting with python fire.
https://gist.github.com/trhodeos/5a20b438480c880f7e15f08987bd9c0f
adjusted for keyword only arguments support
"""
import functools
import inspect


def only_allow_defined_args(function_to_decorate):
"""
Decorator which only allows arguments defined to be used.
Note, we need to specify this, as Fire allows method chaining. This means
that extra kwargs are kept around and passed to future methods that are
called. We don't need this, and should fail early if this happens.
Args:
function_to_decorate: Function which to decorate.
Returns:
Wrapped function.
"""

@functools.wraps(function_to_decorate)
def _return_wrapped(*args, **kwargs):
"""Internal wrapper function."""
argspec = inspect.getfullargspec(function_to_decorate)
valid_names = set(argspec.args + argspec.kwonlyargs)
if "self" in valid_names:
valid_names.remove("self")
for arg_name in kwargs:
if arg_name not in valid_names:
raise ValueError("Unknown argument seen '%s', expected: [%s]" %
(arg_name, ", ".join(valid_names)))
return function_to_decorate(*args, **kwargs)

return _return_wrapped
2 changes: 2 additions & 0 deletions lm/gpt_2_tf/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@

from . import model, sample
from lm.data import END_OF_TEXT
from lm.fire_utils import only_allow_defined_args


def main():
return fire.Fire(train)


@only_allow_defined_args
def train(
run_path,
dataset_path,
Expand Down

0 comments on commit 694bab6

Please sign in to comment.