Override hparams from command line, save accum_gradients
lopuhin committed Mar 20, 2019
1 parent 93a1a16 · commit ab2c05f
Showing 1 changed file with 14 additions and 2 deletions: lm/gpt_2_tf/train.py
@@ -37,6 +37,11 @@ def train(
         config='default',
         accum_gradients=1,  # accumulate gradients N times
         clean=False,
+        # override hparams from config
+        n_ctx=None,
+        n_embd=None,
+        n_head=None,
+        n_layer=None,
         ):
 
     sp_model = spm.SentencePieceProcessor()
@@ -59,16 +64,23 @@ def train(
 
     hparams = model.HPARAMS[config]
     hparams.n_vocab = len(sp_model)
-    (run_path / 'params.json').write_text(json.dumps(dict(
+    if n_ctx is not None: hparams.n_ctx = n_ctx
+    if n_embd is not None: hparams.n_embd = n_embd
+    if n_head is not None: hparams.n_head = n_head
+    if n_layer is not None: hparams.n_layer = n_layer
+    params_text = json.dumps(dict(
         hparams=hparams.values(),
         dataset_path=str(dataset_path),
         sp_model_path=sp_model_path,
         batch_size=batch_size,
+        accum_gradients=accum_gradients,
         lr=lr,
         epochs=epochs,
         restore_from=str(restore_from),
         argv=sys.argv,
-    ), indent=4, sort_keys=True))
+    ), indent=4, sort_keys=True)
+    print(params_text)
+    (run_path / 'params.json').write_text(params_text)
 
     if sample_length is None:
         sample_length = hparams.n_ctx - 1
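For context, here is a minimal standalone sketch of the override-and-save pattern the diff introduces. The SimpleNamespace stand-in, its default values, and the trimmed train() signature are illustrative assumptions; in the repo, hparams comes from model.HPARAMS[config] and the saved params.json also records the dataset path, batch size, and other run settings.

    import json
    from types import SimpleNamespace

    def make_hparams():
        # Hypothetical defaults, roughly GPT-2-small sized; the repo reads
        # these from model.HPARAMS[config] instead.
        return SimpleNamespace(n_vocab=0, n_ctx=1024, n_embd=768,
                               n_head=12, n_layer=12)

    def train(config='default', accum_gradients=1,
              # override hparams from config; None means "keep the config value"
              n_ctx=None, n_embd=None, n_head=None, n_layer=None):
        hparams = make_hparams()
        if n_ctx is not None: hparams.n_ctx = n_ctx
        if n_embd is not None: hparams.n_embd = n_embd
        if n_head is not None: hparams.n_head = n_head
        if n_layer is not None: hparams.n_layer = n_layer
        params_text = json.dumps(dict(
            hparams=vars(hparams),            # the diff uses hparams.values()
            accum_gradients=accum_gradients,  # now saved alongside the hparams
        ), indent=4, sort_keys=True)
        print(params_text)  # echoed so the settings also land in the run log
        return params_text

    if __name__ == '__main__':
        train(n_ctx=512, n_layer=6)  # e.g. a smaller model: half context, half depth

The None sentinel is what makes the command-line override work: it lets the CLI layer distinguish "flag not passed" from an explicit value, so only the hparams the user actually supplies replace those of the chosen config.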
