
Commit

Merge pull request #116 from minimaxir/0.6
0.6
minimaxir authored Aug 28, 2019
2 parents ae93a41 + 776bb66 commit e6afb28
Showing 3 changed files with 78 additions and 38 deletions.
23 changes: 11 additions & 12 deletions README.md
@@ -2,7 +2,7 @@

![gen_demo](docs/gen_demo.png)

A simple Python package that wraps existing model fine-tuning and generation scripts for [OpenAI](https://openai.com)'s [GPT-2 text generation model](https://openai.com/blog/better-language-models/) (specifically the "small" 117M-parameter and "medium" 345M-parameter versions). Additionally, this package makes it easier to generate text, to write generated text to a file for easy curation, and to supply prefixes that force the text to start with a given phrase.
A simple Python package that wraps existing model fine-tuning and generation scripts for [OpenAI](https://openai.com)'s [GPT-2 text generation model](https://openai.com/blog/better-language-models/) (specifically the "small" 124M-parameter and "medium" 355M-parameter versions). Additionally, this package makes it easier to generate text, to write generated text to a file for easy curation, and to supply prefixes that force the text to start with a given phrase.

This package incorporates and makes minimal low-level changes to:

@@ -28,13 +28,13 @@ You will also need to install the corresponding TensorFlow for your system (e.g.

An example of downloading the model to the local system, finetuning it on a dataset, and generating some text.

Warning: the pretrained 117M model, and thus any finetuned model, is 500 MB! (the pretrained 345M model is 1.5 GB)
Warning: the pretrained 124M model, and thus any finetuned model, is 500 MB! (the pretrained 355M model is 1.5 GB)

```python
import gpt_2_simple as gpt2

model_name = "117M"
gpt2.download_gpt2(model_name=model_name) # model is saved into current directory under /models/117M/
model_name = "124M"
gpt2.download_gpt2(model_name=model_name) # model is saved into current directory under /models/124M/

sess = gpt2.start_tf_sess()
gpt2.finetune(sess,
@@ -99,19 +99,18 @@ The method GPT-2 uses to generate text is slightly different than those like oth
* When finetuning GPT-2, the model has no sense of the beginning or end of a document within a larger text. You'll need to use a bespoke character sequence to indicate the beginning and end of a document. Then, while generating, you can specify a `prefix` targeting the beginning token sequence, and a `truncate` targeting the end token sequence. You can also set `include_prefix=False` to discard the prefix token while generating (e.g. if it's something unwanted like `<|startoftext|>`). See the first sketch after this list.
* If you pass a single-column `.csv` file to `finetune()`, it will automatically parse the CSV into a format ideal for training with GPT-2 (including prepending `<|startoftext|>` and suffixing `<|endoftext|>` to every text document, so the `truncate` tricks above are helpful when generating output). This is necessary to handle both quotes and newlines in each text document correctly.
* GPT-2 allows you to generate texts in parallel by setting a `batch_size` that divides evenly into `nsamples`, resulting in much faster generation. This works very well with a GPU (you can set `batch_size` up to 20 on Colaboratory's K80)!
* Due to GPT-2's architecture, it scales up nicely with more powerful GPUs. For the 117M model, if you want to train for longer periods of time, GCP's P100 GPU is about 3x faster than a K80/T4 for only 3x the price, making it price-comparable (the V100 is about 1.5x faster than the P100 but about 2x the price). The P100 uses 100% of the GPU even with `batch_size=1`, and about 88% of the V100 GPU.
* Due to GPT-2's architecture, it scales up nicely with more powerful GPUs. For the 124M model, if you want to train for longer periods of time, GCP's P100 GPU is about 3x faster than a K80/T4 for only 3x the price, making it price-comparable (the V100 is about 1.5x faster than the P100 but about 2x the price). The P100 uses 100% of the GPU even with `batch_size=1`, and about 88% of the V100 GPU.
* If you have a partially trained GPT-2 model and want to continue finetuning it, you can pass `overwrite=True` to `finetune`, which will continue training and remove the previous iteration of the model without creating a duplicate copy. This can be especially useful for transfer learning (e.g. heavily finetune GPT-2 on one dataset, then finetune on another dataset to get a "merging" of both datasets).
* If your input text dataset is massive (>100 MB), you may want to preencode and compress the dataset using `gpt2.encode_dataset(file_path)`. The output is a compressed `.npz` file which will load much faster into the GPU for finetuning. See the second sketch after this list.
* The 774M "large" model does not currently support finetuning because it will cause modern GPUs to go out-of-memory. However, you can still generate from the default pretrained model using `gpt2.load_gpt2(sess, model_name='774M')` and `gpt2.generate(sess, model_name='774M')`.
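
A minimal sketch combining the tips above (prefix/truncate handling and batched generation), assuming a model already finetuned on documents delimited with `<|startoftext|>` and `<|endoftext|>` and saved under the default `run1` checkpoint name:

```python
import gpt_2_simple as gpt2

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name='run1')  # load the finetuned checkpoint

# Generate 20 samples in batches of 5; force each sample to start at the
# document-start token, cut it off at the document-end token, and drop
# the prefix token itself from the returned text.
texts = gpt2.generate(sess,
                      run_name='run1',
                      prefix='<|startoftext|>',
                      truncate='<|endoftext|>',
                      include_prefix=False,
                      nsamples=20,
                      batch_size=5,
                      return_as_list=True)
```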

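A sketch of the preencoding workflow for large datasets; `large_corpus.txt` is a hypothetical input file, and the resulting `.npz` is passed to `finetune()` in place of the raw text:

```python
import gpt_2_simple as gpt2

# Preencode and compress a large (>100 MB) text dataset once...
gpt2.encode_dataset('large_corpus.txt',
                    model_name='124M',
                    out_path='text_encoded.npz')

# ...then finetune on the compressed .npz instead of the raw text.
sess = gpt2.start_tf_sess()
gpt2.finetune(sess, 'text_encoded.npz', model_name='124M', steps=1000)
```
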
## Planned Work
## Interactive Apps Using gpt-2-simple

Note: this project is intended to have a very tight scope unless demand dictates otherwise.
* [gpt2-small](https://minimaxir.com/apps/gpt2-small/) — App using the default GPT-2 124M pretrained model
* [gpt2-reddit](https://minimaxir.com/apps/gpt2-reddit/) — App to generate Reddit titles based on a specified subreddit and/or keyword(s)
* [gpt2-mtg](https://minimaxir.com/apps/gpt2-mtg/) — App to generate Magic: The Gathering cards

* Allow users to generate texts longer than 1024 tokens. ([GitHub Issue](https://github.com/minimaxir/gpt-2-simple/issues/2))
* Allow users to use Colaboratory's TPU for finetuning. ([GitHub Issue](https://github.com/minimaxir/gpt-2-simple/issues/3))
* For Colaboratory, allow model to automatically save checkpoints to Google Drive during training to prevent timeouts.

## Examples Using gpt-2-simple
## Text Generation Examples Using gpt-2-simple

* [ResetEra](https://www.resetera.com/threads/i-trained-an-ai-on-thousands-of-resetera-thread-conversations-and-it-created-hot-gaming-shitposts.112167/) — Generated video game forum discussions ([GitHub w/ dumps](https://github.com/minimaxir/resetera-gpt-2))
* [/r/legaladvice](https://www.reddit.com/r/legaladviceofftopic/comments/bfqf22/i_trained_a_moreadvanced_ai_on_rlegaladvice/) — Title generation ([GitHub w/ dumps](https://github.com/minimaxir/legaladvice-gpt2))
87 changes: 64 additions & 23 deletions gpt_2_simple/gpt_2.py
@@ -36,7 +36,7 @@ def download_file_with_progress(url_base, sub_dir, model_name, file_name):
file_name : str
name of file to get e.g. "hparams.json"
sub_dir: str
subdirectory inside which to get and copy locally, e.g. "models/117M"
subdirectory inside which to get and copy locally, e.g. "models/124M"
no trailing slash
url_base : str
Start of URL location specifying server and any base directories no
@@ -56,7 +56,7 @@ def download_file_with_progress(url_base, sub_dir, model_name, file_name):
pbar.update(DOWNLOAD_CHUNK_SIZE)


def download_gpt2(model_dir='models', model_name='117M'):
def download_gpt2(model_dir='models', model_name='124M'):
"""Downloads the GPT-2 model into the current directory
from Google Cloud Storage.
@@ -67,7 +67,7 @@ def download_gpt2(model_dir='models', model_name='117M'):
model_name : str
name of the GPT-2 model to download.
As of 22 May 2019 one of "117M" or "345M" but may later include other
As of 22 May 2019 one of "124M" or "355M" but may later include other
model sizes
Adapted from https://github.com/openai/gpt-2/blob/master/download_model.py
@@ -105,10 +105,21 @@ def start_tf_sess(threads=-1, server=None):
return tf.compat.v1.Session(config=config)


def reset_session(sess, threads=-1, server=None):
"""Resets the current TensorFlow session, to clear memory
or load another model.
"""

tf.compat.v1.reset_default_graph()
sess.close()
sess = start_tf_sess(threads, server)
return sess

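For reference, a minimal sketch of how the new `reset_session()` helper might be used to clear graph and GPU memory between a finetuning run and generation (dataset name as in the README example):

```python
import gpt_2_simple as gpt2

sess = gpt2.start_tf_sess()
gpt2.finetune(sess, 'shakespeare.txt', model_name='124M',
              steps=100, run_name='run1')

# Discard the old graph and session, start a fresh one,
# then reload the finetuned checkpoint for generation.
sess = gpt2.reset_session(sess)
gpt2.load_gpt2(sess, run_name='run1')
gpt2.generate(sess, run_name='run1')
```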

def finetune(sess,
dataset,
steps=-1,
model_name='117M',
model_name='124M',
model_dir='models',
combine=50000,
batch_size=1,
@@ -125,12 +136,16 @@ def finetune(sess,
max_checkpoints=1,
use_memory_saving_gradients=False,
only_train_transformer_layers=False,
optimizer='adam',
overwrite=False):
"""Finetunes the model on the given dataset.
Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
See that file for parameter definitions.
"""

assert model_name not in ['774M', '1558M'], "Currently, modern GPUs cannot finetune the 774M GPT-2 model or larger."

SAMPLE_DIR = 'samples'

checkpoint_path = os.path.join(checkpoint_dir, run_name)
@@ -144,13 +159,12 @@ def maketree(path):
maketree(checkpoint_path)
files = [f for f in os.listdir(checkpoint_path)]
for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
if file not in files:
try:
shutil.copyfile(os.path.join(model_dir, model_name, file),
os.path.join(checkpoint_path, file))
except FileNotFoundError as fnf_error:
print("You need to download the GPT-2 model first via download_gpt2()")
raise(fnf_error)
try:
shutil.copyfile(os.path.join(model_dir, model_name, file),
os.path.join(checkpoint_path, file))
except FileNotFoundError as fnf_error:
print("You need to download the GPT-2 model first via download_gpt2()")
raise(fnf_error)

enc = encoder.get_encoder(checkpoint_path)
hparams = model.default_hparams()
@@ -161,7 +175,7 @@ def maketree(path):
raise ValueError(
"Can't get samples longer than window size: %s" % hparams.n_ctx)

if model_name != '117M':
if model_name not in ['117M', '124M']:
use_memory_saving_gradients = True
only_train_transformer_layers = True
accumulate_gradients = 1
@@ -182,18 +196,23 @@ def maketree(path):

all_vars = [v for v in tf.compat.v1.trainable_variables() if 'model' in v.name]
train_vars = [v for v in all_vars if '/h' in v.name] if only_train_transformer_layers else all_vars

if optimizer == 'adam':
opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
elif optimizer == 'sgd':
opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=learning_rate)

if accumulate_gradients > 1:
if use_memory_saving_gradients:
exit("Memory saving gradients are not implemented for gradient accumulation yet.")
opt = AccumulatingOptimizer(
opt=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
opt=opt,
var_list=train_vars)
opt_reset = opt.reset()
opt_compute = opt.compute_gradients(loss)
opt_apply = opt.apply_gradients()
summary_loss = tf.compat.v1.summary.scalar('loss', opt_apply)
else:
opt = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
if use_memory_saving_gradients:
opt_grads = memory_saving_gradients.gradients(loss, train_vars)
else:
@@ -330,12 +349,17 @@ def sample_batch():

def load_gpt2(sess,
run_name="run1",
checkpoint_dir="checkpoint"):
"""Loads the model checkpoint into a TensorFlow session
checkpoint_dir="checkpoint",
model_name=None,
model_dir='models'):
"""Loads the model checkpoint or existing model into a TensorFlow session
for repeated predictions.
"""

checkpoint_path = os.path.join(checkpoint_dir, run_name)
if model_name:
checkpoint_path = os.path.join(model_dir, model_name)
else:
checkpoint_path = os.path.join(checkpoint_dir, run_name)

hparams = model.default_hparams()
with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
@@ -348,13 +372,18 @@ def load_gpt2(sess,
saver = tf.compat.v1.train.Saver(allow_empty=True)
sess.run(tf.compat.v1.global_variables_initializer())

print('Loading checkpoint', ckpt)
if model_name:
print('Loading pretrained model', ckpt)
else:
print('Loading checkpoint', ckpt)
saver.restore(sess, ckpt)

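A minimal sketch of the new `model_name` path in `load_gpt2()` and `generate()`, which loads the original pretrained weights from `models/<model_name>` instead of a finetuned checkpoint (this assumes the weights were fetched beforehand with `download_gpt2()`):

```python
import gpt_2_simple as gpt2

gpt2.download_gpt2(model_name='774M')  # one-time download into models/774M/

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, model_name='774M')   # load pretrained weights directly
gpt2.generate(sess, model_name='774M')    # generate without any finetuning
```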

def generate(sess,
run_name='run1',
checkpoint_dir='checkpoint',
model_name=None,
model_dir='models',
sample_dir='samples',
return_as_list=False,
truncate=None,
@@ -384,7 +413,10 @@ def generate(sess,
if prefix == '':
prefix = None

checkpoint_path = os.path.join(checkpoint_dir, run_name)
if model_name:
checkpoint_path = os.path.join(model_dir, model_name)
else:
checkpoint_path = os.path.join(checkpoint_dir, run_name)

enc = encoder.get_encoder(checkpoint_path)
hparams = model.default_hparams()
@@ -452,6 +484,8 @@ def generate(sess,
def generate_to_file(sess,
run_name='run1',
checkpoint_dir='checkpoint',
model_name=None,
model_dir='models',
truncate=None,
destination_path='gpt_2_gen_texts.txt',
sample_delim='=' * 20 + '\n',
@@ -474,6 +508,8 @@ def generate_to_file(sess,
generate(sess=sess,
run_name=run_name,
checkpoint_dir=checkpoint_dir,
model_name=model_name,
model_dir=model_dir,
return_as_list=False,
truncate=truncate,
destination_path=destination_path,
@@ -557,7 +593,7 @@ def copy_file_from_gdrive(file_path):
shutil.copyfile("/content/drive/My Drive/" + file_path, file_path)


def is_gpt2_downloaded(model_dir='models', model_name='117M'):
def is_gpt2_downloaded(model_dir='models', model_name='124M'):
"""Checks if the original model + associated files are present in folder."""

for filename in ['checkpoint', 'encoder.json', 'hparams.json',
@@ -585,7 +621,7 @@ def encode_csv(csv_path, out_path='csv_encoded.txt', header=True,


def encode_dataset(file_path, model_dir='models', out_path='text_encoded.npz',
model_name="117M",
model_name="124M",
combine=50000):
"""Preencodes a text document into chunks and compresses it,
saving time when generated.
@@ -620,7 +656,7 @@ def cmd():
nargs='?', default='checkpoint')
parser.add_argument(
'--model_name', help="[finetune] Name of the GPT-2 model to finetune",
nargs='?', default='117M')
nargs='?', default='124M')
parser.add_argument(
'--model_dir', help="[finetune] Path of directory of the GPT-2 model to finetune",
nargs='?', default='models')
@@ -642,6 +678,9 @@ def cmd():
parser.add_argument(
'--print_every', help="[finetune] After how many steps to print progress",
nargs='?', default=10, type=int)
parser.add_argument(
'--optimizer', help="[finetune] Optimizer to use for finetuning (adam or sgd)",
nargs='?', default='adam')
parser.add_argument(
'--overwrite', help="[finetune] Overwrite existing model when continuing training",
nargs='?', default=False, type=lambda x: (str(x).lower() == 'true'))
@@ -701,6 +740,7 @@ def cmd():
sample_every=args.sample_every,
save_every=args.save_every,
print_every=args.print_every,
optimizer=args.optimizer,
overwrite=args.overwrite)
if args.mode == "generate":
cmd_generate(nfiles=args.nfiles, nsamples=args.nsamples,
@@ -715,7 +755,7 @@ def cmd():

def cmd_finetune(dataset, run_name, checkpoint_dir, model_name, model_dir, steps,
restore_from, sample_every,
save_every, print_every, overwrite):
save_every, print_every, optimizer, overwrite):
"""Wrapper script for finetuning the model via the CLI."""

if not is_gpt2_downloaded(model_dir=model_dir, model_name=model_name):
Expand All @@ -729,6 +769,7 @@ def cmd_finetune(dataset, run_name, checkpoint_dir, model_name, model_dir, steps
steps=steps, restore_from=restore_from,
sample_every=sample_every, save_every=save_every,
print_every=print_every,
optimizer=optimizer,
overwrite=overwrite)


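A minimal sketch of the new `optimizer` option added to `finetune()` in this release; `'sgd'` selects `GradientDescentOptimizer` in place of the default Adam, and the same switch is exposed on the CLI via the new `--optimizer` flag:

```python
import gpt_2_simple as gpt2

sess = gpt2.start_tf_sess()

# optimizer='sgd' swaps the default AdamOptimizer for
# tf.compat.v1.train.GradientDescentOptimizer, as added in this release.
gpt2.finetune(sess,
              'shakespeare.txt',
              model_name='124M',
              steps=1000,
              optimizer='sgd')
```
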
6 changes: 3 additions & 3 deletions setup.py
@@ -1,7 +1,7 @@
from setuptools import setup, find_packages

long_description = '''
A simple Python package that wraps existing model fine-tuning and generation scripts for OpenAI's GPT-2 text generation model (specifically the "small", 117M-parameter version). Additionally, this package makes it easier to generate text, to write generated text to a file for easy curation, and to supply prefixes that force the text to start with a given phrase.
A simple Python package that wraps existing model fine-tuning and generation scripts for OpenAI's GPT-2 text generation model (specifically the "small", 124M-parameter version). Additionally, this package makes it easier to generate text, to write generated text to a file for easy curation, and to supply prefixes that force the text to start with a given phrase.
## Usage
@@ -12,7 +12,7 @@
```python
import gpt_2_simple as gpt2
gpt2.download_gpt2() # model is saved into current directory under /models/117M/
gpt2.download_gpt2() # model is saved into current directory under /models/124M/
sess = gpt2.start_tf_sess()
gpt2.finetune(sess, 'shakespeare.txt', steps=1000) # steps is max number of training steps
@@ -47,7 +47,7 @@
setup(
name='gpt_2_simple',
packages=['gpt_2_simple'], # this must be the same as the name above
version='0.5.4',
version='0.6',
description="Python package to easily retrain OpenAI's GPT-2 " \
"text-generating model on new texts.",
long_description=long_description,
