#!/bin/bash
# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Llama2-7b.
# The flow of this file is as follows:
# 1. Download the checkpoint from Meta (https://llama.meta.com/llama-downloads/) into a local directory. Convert this PyTorch checkpoint into the Orbax checkpoint format used by MaxText.
# 2. Run decoding and finetuning of Llama2-7b with this converted checkpoint. Also run pretraining of Llama2-7b.
# 3. Run decoding from the finetuned weights.
# 4. Convert the scanned checkpoint from step 1 into an unscanned checkpoint format and run more efficient decoding.
set -ex
idx=$(date +%Y-%m-%d-%H-%M)
# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own; this bucket will store all the files generated by MaxText during a run
export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
export DATASET_PATH=gs://maxtext-dataset
export ASYNC_CHECKPOINTING=false
# We install the CPU-only build of torch because the checkpoint conversion script MaxText/llama_or_mistral_ckpt.py does not need a TPU/GPU
pip install torch --index-url https://download.pytorch.org/whl/cpu
# We define a variable for the path to the Meta checkpoint. Non-Googlers please remember to point `META_CHECKPOINT_PATH` to the GCS bucket where you have stored your Meta checkpoint
export META_CHECKPOINT_PATH=gs://maxtext-llama/llama2-7b/meta-ckpt
# In the following command, we copy Meta's checkpoint into the local directory `/tmp/`.
# You can use a different local directory than /tmp/; if you do so, please use the same local path for `base-model-path` when running `python3 MaxText/llama_or_mistral_ckpt.py`
gcloud storage cp -r ${META_CHECKPOINT_PATH} /tmp/
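# Optional sanity check (not part of the original test flow): list the local copy to confirm the download
# completed. The exact contents depend on the Meta release, but you should typically see the model weight
# shard(s) (e.g. consolidated.*.pth), params.json, and the tokenizer files.
ls -lh /tmp/meta-ckpt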
# `CONVERTED_CHECKPOINT_PATH` is the path to the GCS bucket where we want to save our converted (Orbax) checkpoint. Non-Googlers please remember to point `CONVERTED_CHECKPOINT_PATH` to a GCS bucket that you own
export CONVERTED_CHECKPOINT_PATH=gs://maxtext-llama/test/${idx}/decode-ckpt-maxtext
# Next, run the conversion script `MaxText/llama_or_mistral_ckpt.py` to convert Meta's PyTorch checkpoint at `base-model-path` and save the new converted (Orbax) checkpoint at `maxtext-model-path`
python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/meta-ckpt --model-size llama2-7b --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH}
# We define `CONVERTED_CHECKPOINT` to refer to the exact checkpoint subdirectory inside `CONVERTED_CHECKPOINT_PATH`. This makes it easier to use this path in the `train.py` and `decode.py` commands below
export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items
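# Optional sanity check (not part of the original test flow): confirm the converted (Orbax) checkpoint was
# written to GCS before moving on. The listing should show the checkpoint items under this path.
gcloud storage ls ${CONVERTED_CHECKPOINT}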
# Note that `CONVERTED_CHECKPOINT` is in a `scanned` format, which is great for training, but for efficient decoding performance we want the checkpoint in an `unscanned` format.
# We can produce one by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint_${idx}
python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' force_unroll=true
# Like before, we define `UNSCANNED_CKPT_PATH` to refer to the exact checkpoint subdirectory
export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items
# We run decoding on `UNSCANNED_CKPT_PATH`, the unscanned version of the checkpoint converted directly from Meta's PyTorch checkpoint (aka `CONVERTED_CHECKPOINT`), for efficient decoding. Note that this checkpoint only has parameters and no optimizer state, so we load it by specifying `load_parameters_path=${UNSCANNED_CKPT_PATH}`
python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=runner_decode_unscanned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false
# We can also run decoding (albeit in a less optimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state, so we load it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`
python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${CONVERTED_CHECKPOINT} run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product
# Alternatively, we can skip straight to finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we load it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that the scanned checkpoint format helps with efficient finetuning
python3 MaxText/train.py MaxText/configs/base.yml load_parameters_path=${CONVERTED_CHECKPOINT} run_name=runner_finetuning_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} async_checkpointing=${ASYNC_CHECKPOINTING} per_device_batch_size=1 model_name='llama2-7b' ici_tensor_parallelism=4 steps=10 max_target_length=1024 checkpoint_period=5
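# Optional sanity check (not part of the original test flow): with steps=10 and checkpoint_period=5, the
# finetune run above is expected to save a checkpoint at step 5, which the parameter-only conversion further
# below relies on. You can confirm it exists with:
gcloud storage ls ${BASE_OUTPUT_DIRECTORY}/runner_finetuning_${idx}/checkpoints/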
# We also run pretraining of Llama2-7b; this is similar to the finetuning command except that we don't pass any checkpoint directory to load parameters from
python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_pretraining_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} async_checkpointing=${ASYNC_CHECKPOINTING} per_device_batch_size=1 model_name='llama2-7b' ici_tensor_parallelism=4 steps=10 max_target_length=1024
# Now, run decoding on the checkpoint generated from our finetune run. Note that the finetune run saves a `full state` checkpoint, which has both parameters and optimizer state. For decoding, we only need the parameters, so we can use `MaxText/generate_param_only_checkpoint.py` to convert
# the full state checkpoint into a parameter-only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside `BASE_OUTPUT_DIRECTORY` from our previous finetuning run, here the checkpoint saved at finetuning step 5
# Also, `force_unroll=true` converts the output parameter-only checkpoint into an unscanned format for efficient decoding
export PARAMETER_CHECKPOINT_RUN=generate_param_only_checkpoint_${idx}
python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/runner_finetuning_${idx}/checkpoints/5/items run_name=${PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' force_unroll=true
# Like before, we define `NEW_CKPT_PATH` to refer to the exact checkpoint subdirectory
export NEW_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items
# We run decoding on the finetuned parameter-only checkpoint
python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${NEW_CKPT_PATH} run_name=runner_decode_finetuned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false
# We also test whether the forward pass logits match the golden logits for Llama2-7b
python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=llama2-7b ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 scan_layers=false
# Convert the MaxText Orbax checkpoint to a Hugging Face checkpoint
JAX_PLATFORMS=cpu python3 MaxText/llama_mistral_mixtral_orbax_to_hf.py MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs load_parameters_path=${CONVERTED_CHECKPOINT} run_name=convert_to_hf model_name=llama2-7b hf_model_path=/tmp/hf_llama2
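# Optional sanity check (not part of the original test flow): list the local Hugging Face output directory.
# The exact file names depend on the converter and transformers version, but you should typically see a
# config.json, tokenizer files, and the model weight file(s).
ls -lh /tmp/hf_llama2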
# Test whether the forward pass logits match the golden logits for the Hugging Face checkpoint converted from the MaxText Orbax checkpoint
python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_hf ici_tensor_parallelism=4 model_name=llama2-7b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=false --hf_model_path=/tmp/hf_llama2 --max_kl_div=1e-4