Commit
add monot5 tpu train doc (#108)
* add monot5 tpu train doc

* Update pygaggle/data/create_msmarco_t5_training_pairs.py

Co-authored-by: Rodrigo Frassetto Nogueira <[email protected]>

* update tpu experiment link in README and add training time info in tpu doc

Co-authored-by: Rodrigo Frassetto Nogueira <[email protected]>
MXueguang and rodrigonogueira4 authored Nov 11, 2020
1 parent 5e1e0dd commit da0fb61
Showing 3 changed files with 81 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -90,3 +90,4 @@ The following documents describe how to use Pygaggle on various IR test collecti
+ [Experiments on MS MARCO Document Retrieval](https://github.com/castorini/pygaggle/blob/master/docs/experiments-msmarco-document.md)
+ [Experiments on MS MARCO Passage Retrieval - Dev Subset](https://github.com/castorini/pygaggle/blob/master/docs/experiments-msmarco-passage-subset.md)
+ [Experiments on MS MARCO Passage Retrieval - Entire Dev Set](https://github.com/castorini/pygaggle/blob/master/docs/experiments-msmarco-passage-entire.md)
+ [Experiments on MS MARCO Passage Retrieval - with TPU](https://github.com/castorini/pygaggle/blob/master/docs/experiments-monot5-tpu.md)
59 changes: 59 additions & 0 deletions docs/experiments-monot5-tpu.md
@@ -211,4 +211,63 @@ You should see the same result.

If you were able to replicate these results, please submit a PR adding to the replication log! Please mention in your PR if you find any differences!

## Train monoT5

First, download the MS MARCO train triples:
```
wget https://storage.googleapis.com/duobert_git/triples.train.small.tar.gz
tar -xvf triples.train.small.tar.gz
rm triples.train.small.tar.gz
```

Then convert the training triples to the monoT5 input format:
```
python -m pygaggle.data.create_msmarco_t5_training_pairs --triples_train triples.train.small.tsv --output_to_t5 query_doc_pairs.train.tsv
```
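
Each training triple produces two lines in `query_doc_pairs.train.tsv`, one relevant and one non-relevant example; the label after `Relevant:` is tab-separated from the text. The query and passage below are purely illustrative:
```
Query: what is the capital of france Document: Paris is the capital and most populous city of France. Relevant:	true
Query: what is the capital of france Document: The Eiffel Tower opened in 1889. Relevant:	false
```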

Next, copy the input file to Google Cloud Storage. TPU training will read the data directly from `gs://`:
```
gsutil cp query_doc_pairs.train.tsv ${GS_FOLDER}/query_doc_pairs.train.tsv
```

Recall the environment variables:
```
export MODEL=<T5 pre-trained model size, e.g. base, large, 3B>
export GS_FOLDER=<gs folder to store checkpoints>
export PROJECT_NAME=<gcloud project name>
export TPU_NAME=<name of tpu to create>
export BASE_CKPT=<initial model checkpoint, e.g. 999900>
```

Copy the pre-trained checkpoint to our target model folder:
```
echo "model_checkpoint_path: \"model.ckpt-${BASE_CKPT}\"" > checkpoint
gsutil cp checkpoint ${GS_FOLDER}
gsutil cp gs://t5-data/pretrained_models/${MODEL}/model.ckpt-${BASE_CKPT}* ${GS_FOLDER}
```

Finally, start fine-tuning monoT5 on the TPU:
```
nohup t5_mesh_transformer \
--tpu="${TPU_NAME}" \
--gcp_project="${PROJECT_NAME}" \
--tpu_zone="europe-west4-a" \
--model_dir="${GS_FOLDER}" \
--gin_param="init_checkpoint = 'gs://t5-data/pretrained_models/${MODEL}/model.ckpt-${BASE_CKPT}'" \
--gin_file="dataset.gin" \
--gin_file="models/bi_v1.gin" \
--gin_file="gs://t5-data/pretrained_models/${MODEL}/operative_config.gin" \
--gin_param="utils.tpu_mesh_shape.model_parallelism = 1" \
--gin_param="utils.tpu_mesh_shape.tpu_topology = '2x2'" \
--gin_param="utils.run.train_dataset_fn = @t5.models.mesh_transformer.tsv_dataset_fn" \
--gin_param="tsv_dataset_fn.filename = '${GS_FOLDER}/query_doc_pairs.train.tsv'" \
--gin_file="learning_rate_schedules/constant_0_001.gin" \
--gin_param="run.train_steps = 1100000" \
--gin_param="tokens_per_batch = 65536" \
>> out.log_exp 2>&1 &
tail -100f out.log_exp
```

Training T5 base, large, and 3B takes approximately 12, 48, and 160 hours, respectively, on a single TPU.
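
For intuition on these settings, here is a rough sketch of the arithmetic behind them. The checkpoint step (999900 in the example above) and `run.train_steps` come from the commands above; the 512-token maximum input length is an assumption based on monoT5's usual configuration:
```
# Sketch: rough arithmetic behind the gin params above (example values).
base_ckpt_step = 999_900         # pre-trained checkpoint step (example BASE_CKPT)
total_train_steps = 1_100_000    # run.train_steps
tokens_per_batch = 65_536        # tokens_per_batch gin param
max_input_length = 512           # assumption: monoT5's maximum input length

finetune_steps = total_train_steps - base_ckpt_step        # ~100k fine-tuning steps
examples_per_batch = tokens_per_batch // max_input_length  # ~128 query-doc pairs per batch
print(f'{finetune_steps} fine-tuning steps, ~{examples_per_batch} examples per batch')
```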

## Replication Log
21 changes: 21 additions & 0 deletions pygaggle/data/create_msmarco_t5_training_pairs.py
@@ -0,0 +1,21 @@
"""
This script creates monoT5 input files for training.
Each line in the monoT5 input file follows the format:
    f'Query: {query} Document: {document} Relevant:\t{label}\n'
"""
import argparse
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument("--triples_train", type=str, required=True,
help="tsv file <query>, <positive_document>, <negative_document>")
parser.add_argument("--output_to_t5", type=str, required=True,
help="t5 train input file")
args = parser.parse_args()

with open(args.output_to_t5, 'w') as fout_t5, open(args.triples_train) as fin:
    for line in tqdm(fin):
        query, positive_document, negative_document = line.strip().split('\t')
        # Each triple yields one relevant ("true") and one non-relevant ("false") training example.
        fout_t5.write(f'Query: {query} Document: {positive_document} Relevant:\ttrue\n')
        fout_t5.write(f'Query: {query} Document: {negative_document} Relevant:\tfalse\n')
print('Done!')
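
As an optional follow-up, a minimal sketch for sanity-checking the generated file (assuming `query_doc_pairs.train.tsv` is in the current directory; this is not part of the script above):
```
# Sketch: check that the generated pairs alternate between "true" and "false" labels.
from itertools import islice

with open('query_doc_pairs.train.tsv') as f:
    for i, line in enumerate(islice(f, 10)):
        text, label = line.rstrip('\n').split('\t')
        expected = 'true' if i % 2 == 0 else 'false'
        assert label == expected, f'line {i}: expected {expected}, got {label}'
        assert text.startswith('Query: ') and text.endswith(' Relevant:')
print('First 10 lines look as expected')
```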
