forked from skypilot-org/skypilot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sky.yaml
114 lines (95 loc) · 3.84 KB
/
sky.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Example: a distributed DeepSpeed job (DeepSpeed-Chat) on 2 VMs.
#
# This takes care of constructing a "hostfile" to pass to DeepSpeed.
#
# Usage:
#
# $ sky launch sky.yaml -r --down -c ds
#
# # Optional: After the job starts running, you can log into the two nodes and
# # check gpustat:
# $ ssh ds
# $ gpustat -i
# $ ssh ds-worker1
# $ gpustat -i
resources:
  # Keep exactly ONE `accelerators` line active (a mapping cannot repeat a key);
  # uncomment the variant that matches the cloud you launch on.
  accelerators: A100:1  # GCP, Lambda
  # accelerators: A100-80GB:1  # Azure, GCP, SCP
  # accelerators: A10G:1  # AWS. Will OOM for (1) single_node/run_1.3b_lora.sh (2) multi_node/run_66b.sh.
  # accelerators: T4:1  # AWS, Azure, GCP. Will OOM for (1) single_node/run_1.3b_lora.sh (2) multi_node/run_66b.sh.

# Number of VMs. The run section uses the node with rank 0 as the DeepSpeed
# launcher; it reaches the remaining nodes through the generated hostfile.
num_nodes: 2
envs:
  MY_VAR_1: "hello"
  MY_VAR_2: "world"
  # Comma-separated names of env vars to propagate to every process DeepSpeed
  # spawns (the run section writes them into a .deepspeed_env file). If you add
  # an env above, add its name to this list too.
  # NOTE: values are read on the rank-0 node when the file is generated, so
  # SKYPILOT_NODE_RANK forwarded this way carries the head node's value, not
  # each worker's own rank.
  DEEPSPEED_ENVS: "MY_VAR_1,MY_VAR_2,SKYPILOT_NODE_RANK"
setup: |
  git clone https://github.com/microsoft/DeepSpeedExamples.git || true
  cd DeepSpeedExamples
  # Pin the examples repo to a fixed commit (presumably the one the training
  # flags in the run section were written against).
  git checkout d7c42b4f34df91035e7ed3e0c51500bb53d0bc71

  # The 'deepspeed' conda env doubles as a "setup already done" marker:
  # everything in the else-branch runs only when the env does not exist yet.
  if conda activate deepspeed; then
    echo 'conda env exists'
  else
    conda create -n deepspeed python=3.8 -y
    conda activate deepspeed
    pip install deepspeed

    cd applications/DeepSpeed-Chat
    pip install -r requirements.txt

    # Required by DeepSpeed in multi-node settings.
    #
    # NOTE(skypilot): DeepSpeed uses `pdsh` to log into each node and calls
    # `ninja --version`; so it has to be installed system-wide rather than in
    # the above 'deepspeed' conda env.
    #
    # Refresh the package index first: on images whose apt lists are stale or
    # empty, `apt-get install` otherwise fails to locate these packages.
    sudo apt-get update
    sudo apt-get -y install pdsh ninja-build
  fi
file_mounts:
  # Required for DeepSpeed's passwordless SSH (pdsh runs commands on the other
  # nodes): copies the local ~/.ssh/sky-key (SkyPilot's SSH key) to
  # ~/.ssh/id_rsa on every node, making it the default SSH identity there.
  ~/.ssh/id_rsa: ~/.ssh/sky-key
run: |
  # Make the `deepspeed ... | tee` pipeline below report deepspeed's exit status
  # instead of tee's, so a failed training run also fails the job.
  set -o pipefail

  cd DeepSpeedExamples
  conda activate deepspeed

  # Launch on the first node only; DeepSpeed reaches the other nodes via pdsh.
  if [ "${SKYPILOT_NODE_RANK}" = "0" ]; then
    # Prepare a hostfile: one "<node ip> slots=<gpus per node>" line per node.
    HOSTFILE_PATH=/tmp/hostfile.${SKYPILOT_TASK_ID}
    python -c "import os;n_gpus=os.environ['SKYPILOT_NUM_GPUS_PER_NODE'];print('\n'.join([f'{ip} slots={n_gpus}' for ip in os.environ['SKYPILOT_NODE_IPS'].splitlines()]))" > "${HOSTFILE_PATH}"

    # Generate .deepspeed_env to propagate env vars to all workers spawned by
    # DeepSpeed. The DeepSpeed launcher only looks for this file in the home
    # directory and in its current working directory; because we cd into the
    # training directory before launching, write it to the home directory so
    # it is found regardless of where the launch command runs.
    # Names listed in DEEPSPEED_ENVS are stripped, and empty names (e.g. from a
    # trailing comma) are skipped, so every line has the form NAME="value".
    echo "Generating .deepspeed_env"
    python3 -c 'import os, pathlib; names = [v.strip() for v in os.getenv("DEEPSPEED_ENVS", "").split(",") if v.strip()]; pathlib.Path.home().joinpath(".deepspeed_env").write_text("".join("{}=\"{}\"\n".format(v, os.getenv(v, "")) for v in names))'

    echo "*******************************************"
    echo "Hostfile: ${HOSTFILE_PATH}"
    cat "${HOSTFILE_PATH}"
    echo "*******************************************"

    ################ Your launch command goes here ################
    cd applications/DeepSpeed-Chat/training/step1_supervised_finetuning/
    # Adapted from: training_scripts/single_node/run_1.3b_lora.sh
    # Note the additional argument: --hostfile $HOSTFILE_PATH
    # Alternatively, move the hostfile to /job/hostfile (DeepSpeed's default
    # hostfile location) and drop --hostfile:
    #   sudo mkdir -p /job; sudo chmod 777 /job; mv ${HOSTFILE_PATH} /job/hostfile
    OUTPUT_PATH=./output
    mkdir -p "$OUTPUT_PATH"
    deepspeed \
      --hostfile "$HOSTFILE_PATH" \
      main.py \
      --data_path Dahoas/rm-static Dahoas/full-hh-rlhf Dahoas/synthetic-instruct-gptj-pairwise yitingxie/rlhf-reward-datasets \
      --data_split 2,4,4 \
      --model_name_or_path facebook/opt-1.3b \
      --per_device_train_batch_size 8 \
      --per_device_eval_batch_size 8 \
      --max_seq_len 512 \
      --learning_rate 1e-3 \
      --weight_decay 0.1 \
      --num_train_epochs 16 \
      --gradient_accumulation_steps 1 \
      --lr_scheduler_type cosine \
      --num_warmup_steps 0 \
      --seed 1234 \
      --zero_stage 0 \
      --lora_dim 128 \
      --lora_module_name decoder.layers. \
      --only_optimize_lora \
      --deepspeed \
      --output_dir "$OUTPUT_PATH" \
      | tee "$OUTPUT_PATH/training.log"
  fi