Skip to content

Commit

Permalink
Cap gen length if prefix to prevent OOB (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
minimaxir committed Jun 16, 2019
1 parent e715c8d commit d9b673e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
13 changes: 7 additions & 6 deletions gpt_2_simple/gpt_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,6 @@ def generate(sess,
if prefix == '':
prefix = None

if prefix:
context = tf.placeholder(tf.int32, [batch_size, None])

CHECKPOINT_DIR = 'checkpoint'
SAMPLE_DIR = 'samples'

Expand All @@ -391,11 +388,17 @@ def generate(sess,
with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
hparams.override_from_dict(json.load(f))

if prefix:
context = tf.placeholder(tf.int32, [batch_size, None])
context_tokens = enc.encode(prefix)
assert len(context_tokens) < length

np.random.seed(seed)
tf.set_random_seed(seed)

output = sample.sample_sequence(
hparams=hparams, length=length,
hparams=hparams,
length=min(length, 1023 - (len(context_tokens) if prefix else 0)),
start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
context=context if prefix else None,
batch_size=batch_size,
Expand All @@ -404,8 +407,6 @@ def generate(sess,

if destination_path:
f = open(destination_path, 'w')
if prefix:
context_tokens = enc.encode(prefix)
generated = 0
gen_texts = []
while generated < nsamples:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
setup(
name='gpt_2_simple',
packages=['gpt_2_simple'], # this must be the same as the name above
version='0.5.1',
version='0.5.2',
description="Python package to easily retrain OpenAI's GPT-2 " \
"text-generating model on new texts.",
long_description=long_description,
Expand Down

0 comments on commit d9b673e

Please sign in to comment.