-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.sh
80 lines (69 loc) · 3.03 KB
/
train.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/bin/bash
# Training pipeline driver: sets up the Python environment, after which the
# commands below run the data-preparation and training scripts in order.
# Every path in this script is relative, so run it from the directory that
# contains src/, data/ and config/.

# Put the working directory on Python's module search path. The ${VAR:+...}
# form appends the caller's existing PYTHONPATH only when it is non-empty, so
# no dangling ':' is left behind when it is unset.
export PYTHONPATH="./${PYTHONPATH:+:${PYTHONPATH}}"

# Activate the tf_1 environment (presumably a conda env — confirm). This is
# checked explicitly, and done before strict mode is enabled, because the
# activation script is outside this file's control. Without the check a failed
# activation would let every step run under whatever `python` is on PATH.
if ! source activate /home/tione/notebook/envs/tf_1; then
  printf 'train.sh: cannot activate /home/tione/notebook/envs/tf_1\n' >&2
  exit 1
fi

# Later steps take files named as outputs of earlier steps (map.obj,
# tf_hash/feed.ckpt, ...), so stop at the first failing command instead of
# carrying on with missing or stale artifacts.
set -euo pipefail
# Generate the Map ID file.
# Both user_action.csv files are handed over as one comma-separated
# --action_csv value; --save_path names the map.obj that the later steps in
# this script receive as their map file.
map_id_actions=(
  data/wedata/wechat_algo_data1/user_action.csv
  data/wedata/wechat_algo_data2/user_action.csv
)
map_id_args=(
  # "${array[*]}" with IFS=',' joins the entries with commas (subshell only).
  --action_csv "$(IFS=,; printf '%s' "${map_id_actions[*]}")"
  --feed_csv data/wedata/wechat_algo_data2/feed_info.csv
  --save_path data/preprocess/map.obj
)
python -u src/prepare/map_id_process.py "${map_id_args[@]}"
# DeepWalk.
# Receives the action logs, the three test CSVs and the feed info, plus the
# map.obj path; --save_dir is data/deepwalk and the settings file is
# config/deepwalk.yaml.
deepwalk_actions=(
  data/wedata/wechat_algo_data1/user_action.csv
  data/wedata/wechat_algo_data2/user_action.csv
)
deepwalk_tests=(
  data/wedata/wechat_algo_data1/test_a.csv
  data/wedata/wechat_algo_data1/test_b.csv
  data/wedata/wechat_algo_data2/test_a.csv
)
deepwalk_args=(
  # Each list is passed as a single comma-separated argument; the IFS change
  # is confined to the command-substitution subshell.
  --action_csv "$(IFS=,; printf '%s' "${deepwalk_actions[*]}")"
  --test_csv "$(IFS=,; printf '%s' "${deepwalk_tests[*]}")"
  --feed_csv data/wedata/wechat_algo_data2/feed_info.csv
  --save_dir data/deepwalk
  --map_file data/preprocess/map.obj
  --config_file config/deepwalk.yaml
)
python -u src/prepare/deepwalk.py "${deepwalk_args[@]}"
# Text Word2Vec.
# Invoked with --word_dim 512 and --workers 50 on the feed info CSV; the
# --save_path target is data/deepwalk/word.npy.
word2vec_args=(
  --feed_csv data/wedata/wechat_algo_data2/feed_info.csv
  --word_dim 512
  --workers 50
  --map_path data/preprocess/map.obj
  --save_path data/deepwalk/word.npy
)
python -u src/prepare/nlp_process.py "${word2vec_args[@]}"
# Generate the TensorFlow hashtable.
# --save_path is data/preprocess/tf_hash/feed.ckpt, the same path the later
# steps in this script pass as --train_sess_path / --hashtable_ckpt.
feed_table_args=(
  --feed_csv data/wedata/wechat_algo_data2/feed_info.csv
  --map_file data/preprocess/map.obj
  --save_path data/preprocess/tf_hash/feed.ckpt
)
python -u src/prepare/feed_hashtable.py "${feed_table_args[@]}"
# Generate the training-set tfrecords, twice with identical inputs:
#   --do_eval true   -> with a validation set
#   --do_eval false  -> full-data training
# The two runs differ only in that flag, so the shared arguments are built
# once and the flag is appended last, matching the original argument order.
dataset_actions=(
  data/wedata/wechat_algo_data1/user_action.csv
  data/wedata/wechat_algo_data2/user_action.csv
)
dataset_args=(
  # Comma-join the action logs into one argument (IFS change is subshell-only).
  --csv_file "$(IFS=,; printf '%s' "${dataset_actions[*]}")"
  --feed_csv_file data/wedata/wechat_algo_data2/feed_info.csv
  --maps_file data/preprocess/map.obj
  --dataset_dir data/preprocess
  --pool_num 20
)
for with_eval in true false; do
  python -u src/prepare/multiprocessing_gene_dataset.py \
    "${dataset_args[@]}" \
    --do_eval "${with_eval}"
done
# Generate the TensorFlow hashtable for the test set.
# --train_sess_path is the feed.ckpt path given as --save_path to the
# feed-hashtable step; the result goes to tf_hash_test/user_feed.ckpt.
# NOTE(review): no command in this script names data_and_count.pkl as an
# output — presumably the dataset-generation step writes it into
# data/preprocess; confirm.
test_table_args=(
  --feed_csv_file data/wedata/wechat_algo_data2/feed_info.csv
  --features_file data/preprocess/data_and_count.pkl
  --maps_file data/preprocess/map.obj
  --pool_num 40
  --train_sess_path data/preprocess/tf_hash/feed.ckpt
  --save_path data/preprocess/tf_hash_test/user_feed.ckpt
)
python -u src/prepare/gene_test_hashtable.py "${test_table_args[@]}"
# Train the Match Tower.
# Invoked with both action logs, the feed info, map.obj and the feed.ckpt
# hashtable checkpoint; settings come from config/match_tower.yaml, the
# tfrecord directory is data/match_tower, and --full_run is passed as True.
match_actions=(
  data/wedata/wechat_algo_data1/user_action.csv
  data/wedata/wechat_algo_data2/user_action.csv
)
match_args=(
  # Comma-join the action logs into one argument (IFS change is subshell-only).
  --action_csv "$(IFS=,; printf '%s' "${match_actions[*]}")"
  --feed_csv data/wedata/wechat_algo_data2/feed_info.csv
  --maps_file data/preprocess/map.obj
  --hashtable_ckpt data/preprocess/tf_hash/feed.ckpt
  --config_file config/match_tower.yaml
  --tfrecord_dir data/match_tower
  --full_run True
)
python -u src/train/match_tower_run.py "${match_args[@]}"
# Train the multi-task model.
# Takes the main config plus the match-tower config, the training and eval
# tfrecord inputs under data/preprocess/tfrecord, and the feed.ckpt hashtable
# checkpoint.
# NOTE(review): no earlier command in this script names data/model/match_tower
# as an output — presumably the Match Tower training step writes its
# checkpoint there (e.g. via config/match_tower.yaml); confirm.
multitask_args=(
  --config_file config/config.yaml
  --match_config_file config/match_tower.yaml
  --train_input_file data/preprocess/tfrecord/all_in/no_kfold/1
  --eval_input_file data/preprocess/tfrecord/eval.tfrecord
  --hashtable_ckpt data/preprocess/tf_hash/feed.ckpt
  --match_tower_ckpt data/model/match_tower
)
python -u src/train/train_run.py "${multitask_args[@]}"