This page outlines the steps to pretrain a model with T5X on common tasks defined with SeqIO.
Pretraining a model with T5X consists of the following steps:
- Choose the model architecture.
- Choose the SeqIO Task/Mixture to use for training.
- Write a Gin file that configures the model, SeqIO Task/Mixture and other details of your pretraining run.
- Launch your experiment locally or on XManager.
- Monitor your experiment.
These steps are explained in detail in the following sections. An example run that trains a T5 1.1 Small checkpoint from scratch on the C4 dataset using the span corruption pretraining objective is also showcased.
To train a model, you need a Gin config file that defines the model params. For your convenience, Gin configs for common models have been made available for use in T5X. Following is a list of these models and their Gin locations.
| Model | Gin File Location |
| --- | --- |
| T5 Small | t5_1_0/small.gin |
| T5 Base | t5_1_0/base.gin |
| T5 Large | t5_1_0/large.gin |
| T5 3B | t5_1_0/3B.gin |
| T5 11B | t5_1_0/11B.gin |
| T5 1.1 Small | t5_1_1/small.gin |
| T5 1.1 Base | t5_1_1/base.gin |
| T5 1.1 Large | t5_1_1/large.gin |
| T5 1.1 XL | t5_1_1/xl.gin |
| T5 1.1 XXL | t5_1_1/xxl.gin |
| MT5 Small | mt5/small.gin |
| MT5 Base | mt5/base.gin |
| MT5 Large | mt5/large.gin |
| MT5 XL | mt5/xl.gin |
| MT5 XXL | mt5/xxl.gin |
For the example run, you will use the T5 1.1 Small model. The Gin file for this model is located at `t5x/examples/t5/t5_1_1/small.gin`.
A SeqIO Task encapsulates the data source, the preprocessing logic to be performed on the data before querying the model, the postprocessing logic to be performed on model outputs, and the metrics to be computed given the postprocessed outputs and targets. A SeqIO Mixture denotes a collection of Tasks and enables pretraining a model on multiple Tasks simultaneously.
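To make these pieces concrete, here is a minimal, hypothetical sketch of how such a Task is typically registered with SeqIO; the task name, TFDS dataset, vocabulary path, and metric below are placeholders, not anything used in this tutorial:

```python
# Hypothetical SeqIO Task registration, shown only to illustrate the pieces a
# Task bundles together; the names and paths are placeholders.
import functools

import seqio
from t5.data import preprocessors
from t5.evaluation import metrics

VOCAB = seqio.SentencePieceVocabulary("/path/to/sentencepiece.model")

seqio.TaskRegistry.add(
    "my_example_task",
    # Data source: where the raw examples come from.
    source=seqio.TfdsDataSource(tfds_name="my_dataset:1.0.0"),
    # Preprocessing applied to examples before they are fed to the model.
    preprocessors=[
        functools.partial(
            preprocessors.rekey, key_map={"inputs": None, "targets": "text"}),
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features={
        "inputs": seqio.Feature(vocabulary=VOCAB),
        "targets": seqio.Feature(vocabulary=VOCAB),
    },
    # Metrics computed from postprocessed model outputs and targets.
    metric_fns=[metrics.accuracy],
)
```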
Many common datasets and benchmarks, e.g. GLUE, SuperGLUE, WMT, SQuAD, CNN/Daily Mail, etc., have been implemented as SeqIO Tasks/Mixtures and can be used directly. These Tasks/Mixtures are defined in `third_party/py/t5/data/tasks.py` and `third_party/py/t5/data/mixtures.py`.
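If you want to check which Task/Mixture names are registered in your environment (assuming the `seqio` and `t5` packages are installed), a quick sketch:

```python
# Importing these modules registers the Tasks/Mixtures as a side effect.
import seqio
import t5.data.tasks      # noqa: F401
import t5.data.mixtures   # noqa: F401

task_names = list(seqio.TaskRegistry.names())
print(f"{len(task_names)} Tasks registered")
print([name for name in task_names if name.startswith("c4")][:5])
print(f"{len(list(seqio.MixtureRegistry.names()))} Mixtures registered")
```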
For the example run, you will train the model on the `c4_v220_span_corruption` Task, which implements the span corruption pretraining objective using the C4 dataset. This is the final pretraining Task used in the T5 paper.

TIP: Want to use a custom Task or Mixture? See the "Adding SeqIO Task/Mixture modules and Gin files" section below.
After choosing the model architecture and SeqIO Task/Mixture for your run, the next step is to configure your run using Gin. If you're not familiar with Gin, reading the T5X Gin Primer is recommended.
T5X provides a Gin file that configures the T5X trainer for pretraining (located at `runs/pretrain.gin`), and expects a few params from you. These params can be specified in a separate Gin file or via commandline flags. Following are the required params:
- `TRAIN_STEPS`: Number of training steps. For the example run, set this to `100_000`.
- `MIXTURE_OR_TASK_NAME`: This is the SeqIO Task or Mixture name to run (from Step 2). For the example run, set this to `'c4_v220_span_corruption'`.
- `TASK_FEATURE_LENGTHS`: This is a dict mapping feature key to maximum int length for that feature. After preprocessing, features are truncated to the provided value. For the example run, set this to `{"inputs": 512, "targets": 114}`, following the original T5 pretraining setup (see the sketch after this list for a quick way to verify these values).
- `MODEL_DIR`: A path to write pretrained checkpoints to. When launching using XManager, this path is automatically set and can be accessed from the XManager Artifacts page. When running locally using Blaze, you can explicitly pass a directory using a flag. Launch commands are provided in the next step.
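Before launching a full run, you can optionally sanity-check that the Task name and feature lengths line up by loading a couple of preprocessed examples directly. This is only a sketch; it assumes `seqio` and `t5` are installed and that the C4 data is accessible to TFDS in your environment:

```python
# Load a couple of preprocessed examples to verify feature keys and lengths.
import seqio
import t5.data.mixtures  # noqa: F401  (registers c4_v220_span_corruption)

task = seqio.get_mixture_or_task("c4_v220_span_corruption")
ds = task.get_dataset(
    sequence_length={"inputs": 512, "targets": 114},
    split="train",
    shuffle=False)

for example in ds.take(2):
  print({key: value.shape for key, value in example.items()})
```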
In addition to the above params, you will need to import `pretrain.gin` and the Gin file for the model, which for the example run is `t5_1_1/small.gin`:

include 't5x/configs/runs/pretrain.gin'
include 't5x/examples/t5/t5_1_1/small.gin'
Note that the include statements in this example use relative paths, which requires passing an appropriate gin_search_paths flag to locate these files when launching your run. However, we recommend using absolute paths, because Gin files specified via relative paths can be harder to locate without inspecting the launch command.
You will also need to import the Python module(s) that register the SeqIO Tasks and Mixtures used in your run. For the example run, we add `import t5.data.mixtures`, which pulls in the module where `'c4_v220_span_corruption'` is registered. Note that this module must also be included as a dependency in the T5X trainer binary. Most common Task/Mixture modules, such as this one, are already included. If your module is not included, see the Advanced Topics section at the end of this tutorial for instructions to add it.
Finally, your Gin file should look like this:
include 't5x/examples/t5/t5_1_1/small.gin'
include 't5x/configs/runs/pretrain.gin'
# Register necessary SeqIO Tasks/Mixtures.
import t5.data.mixtures
MIXTURE_OR_TASK_NAME = "c4_v220_span_corruption"
TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114}
TRAIN_STEPS = 10000
DROPOUT_RATE = 0.0
BATCH_SIZE = 256
See `t5x/examples/t5/t5_1_1/examples/small_c4_pretrain.gin` for this example.
To launch your experiment locally (for debugging only; larger checkpoints may cause issues), run the following on the command line:
MODEL_DIR="/tmp/pretrain-model/"
python -m t5x.train \
--gin_file=t5x/examples/t5/t5_1_1/examples/small_c4_pretrain.gin \
--gin.MODEL_DIR=\"${MODEL_DIR}\" \
--alsologtostderr
Note that multiple comma-separated paths can be passed to the `gin_search_paths` flag, and these paths should contain all Gin files used or included in your experiment.
Now that you have successfully pretrained a model, here are some topics you might want to explore next.
We also touch upon a few advanced topics related to pretraining below that might be useful, especially when customizing your pretraining job.
A `DatasetConfig` object is used to configure loading SeqIO Tasks/Mixtures for training and eval. If you take a closer look at `runs/pretrain.gin`, you will see that there are two `DatasetConfig` objects defined and passed to the train function: `train_dataset_cfg` and `train_eval_dataset_cfg`. Here's a brief description of these configs:
- `train`: This configures the Task/Mixture that the model will be pretrained on.
- `train_eval`: This configures the Task/Mixture that is used to compute training metrics on the eval split, e.g. perplexity. These metrics are defined in the `Model` class, and the eval fn is implemented in the T5X trainer (a rough Python equivalent of both configs follows this list).
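For reference, here is a rough Python equivalent of how `runs/pretrain.gin` wires these up for the example run. Field names follow `t5x.utils.DatasetConfig`, but treat the values as illustrative rather than authoritative; the Gin file is the source of truth:

```python
# Illustrative only: approximates the two DatasetConfig objects built by
# runs/pretrain.gin for the example run.
from t5x import utils

train_dataset_cfg = utils.DatasetConfig(
    mixture_or_task_name="c4_v220_span_corruption",
    task_feature_lengths={"inputs": 512, "targets": 114},
    split="train",
    batch_size=256,
    shuffle=True,
    seed=None,  # a fresh shuffle seed each run unless you pin one
    pack=True,
    use_cached=False)

train_eval_dataset_cfg = utils.DatasetConfig(
    mixture_or_task_name="c4_v220_span_corruption",
    task_feature_lengths={"inputs": 512, "targets": 114},
    split="validation",  # training metrics (e.g. perplexity) use this split
    batch_size=256,
    shuffle=False,
    seed=42,
    pack=True,
    use_cached=False)
```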
A training run may consist of various randomized operations, e.g. dataset shuffling, dropout, etc. However, it is often useful to have deterministic training, meaning that the random operations are reproducible and robust to preemption/restarts. To make your pretraining deterministic, in addition to the params configured in `pretrain.gin`, you need to add the following configs:
- sets the dataset seed to a fixed value: `train/utils.DatasetConfig.seed = 42`.
- sets the dropout seed to a fixed value: `train_script.train.random_seed = 42`.
- enables dataset checkpointing: `utils.SaveCheckpointConfig.save_dataset = True`. This means that the dataset iterator is checkpointed periodically during training, and in case of preemptions, training resumes from the latest dataset checkpoint to ensure deterministic behavior. The checkpointing frequency is set using `utils.SaveCheckpointConfig.period` (`1000` by default), meaning that the dataset is checkpointed after processing `1000` batches (batches, not examples; batch size can be overridden using `train/DatasetConfig.batch_size` and is set to `128` by default).
Refer to the SeqIO documentation.