-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathtrain_npsn.sh
65 lines (55 loc) · 1.44 KB
/
train_npsn.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/bin/bash
echo "Start training task queues"
# Hyperparameters (defaults; each can be overridden from the command line)
dataset_array=("eth" "hotel" "univ" "zara1" "zara2")
device_id_array=(0 1 2 3 4)
baseline="sgcn"
tag="npsn"

# Print the usage line to stderr and exit with status 1.
usage() {
  echo "usage: $0 [-t TAG] [-b BASELINE] [-d \"eth hotel univ zara1 zara2\"] [-i \"0 1 2 3 4\"]" >&2
  exit 1
}

#######################################
# Arguments: parse -t/-b/-d/-i into the globals above.
# Globals:   tag, baseline, dataset_array, device_id_array (written)
# Arguments: the script's command-line options
# -d and -i take a whitespace-separated list. `read -a` splits it into an
# array WITHOUT pathname expansion (an unquoted `(${OPTARG})` would glob,
# e.g. `-i "*"` would become the file names in the current directory).
#######################################
parse_args() {
  local flag
  local OPTIND=1   # local so parse_args can be called more than once
  while getopts t:b:d:i: flag; do
    case "${flag}" in
      t) tag=${OPTARG} ;;
      b) baseline=${OPTARG} ;;
      # `-d ''` reads the whole string (newlines included); read returns 1
      # at EOF, but the array is still assigned.
      d) read -r -d '' -a dataset_array <<< "${OPTARG}" || true ;;
      i) read -r -d '' -a device_id_array <<< "${OPTARG}" || true ;;
      *) usage ;;
    esac
  done
}
parse_args "$@"

# Each dataset is trained on the GPU at the same index, so the lists must pair up 1:1.
if (( ${#dataset_array[@]} != ${#device_id_array[@]} )); then
  printf 'Arrays must all be same length. ' >&2
  printf 'len(dataset_array)=%s and len(device_id_array)=%s\n' \
    "${#dataset_array[@]}" "${#device_id_array[@]}" >&2
  exit 1
fi
# With no datasets nothing would be trained, yet the test step would still run.
if (( ${#dataset_array[@]} == 0 )); then
  echo "No datasets given." >&2
  usage
fi
# Signal handler
# PIDs of the background training jobs, appended as each job is launched.
pid_array=()

#######################################
# Terminate every training job launched so far, then exit.
# Globals:   pid_array (read)
# Arguments: $1 - optional exit status (default 0). The traps below pass
#            128 + signal number so callers can tell the run was interrupted.
#######################################
sighdl ()
{
  local status=${1:-0}
  local pid
  echo "Kill training processes"
  # Iterate over the PIDs actually recorded, not over dataset_array: if the
  # signal arrives mid-launch, later entries do not exist yet and an empty
  # `kill` argument would only print a usage error.
  for pid in "${pid_array[@]}"; do
    kill "${pid}"
  done
  echo "Done."
  exit "${status}"
}
trap 'sighdl 130' SIGINT
trap 'sighdl 143' SIGTERM
# Start training tasks: one background job per dataset, on the GPU at the
# same index in device_id_array.
for (( i=0; i<${#dataset_array[@]}; i++ ))
do
  # Data goes through %s, never into the printf format string, so names
  # containing '%' or '\' are printed literally.
  printf 'Training %s ' "${dataset_array[$i]}"
  python3 train_npsn.py --dataset "${dataset_array[$i]}" --tag "${tag}-${baseline}" --baseline "${baseline}" \
    --use_lrschd --gpu_num "${device_id_array[$i]}" &
  pid_array[$i]=$!
  printf 'job %s pid %s\n' "${#pid_array[@]}" "${pid_array[$i]}"
done
# Wait for every job and collect the datasets whose training exited non-zero.
# If a trapped SIGINT/SIGTERM arrives here, `wait` returns early and the trap runs.
failed_array=()
for (( i=0; i<${#pid_array[@]}; i++ ))
do
  if ! wait "${pid_array[$i]}"; then
    failed_array+=("${dataset_array[$i]}")
  fi
done
echo "Training end."
# Testing after a failed training run would evaluate missing or stale models,
# so stop here and report which datasets failed.
if (( ${#failed_array[@]} > 0 )); then
  echo "Training failed for: ${failed_array[*]}. Skipping test." >&2
  exit 1
fi
# Start test
# NOTE(review): training is tagged "${tag}-${baseline}" but the test only
# receives "${tag}"; presumably test_npsn.py appends the baseline itself —
# confirm against test_npsn.py.
python3 test_npsn.py --tag "${tag}" --baseline "${baseline}"
test_status=$?
echo "Done."
# Propagate the test's exit status rather than echo's (always 0).
exit "${test_status}"