From 0ad38889a68cd491836a13b0d2f8dbe2b6346cd6 Mon Sep 17 00:00:00 2001
From: Zhong Hui
Date: Tue, 6 Apr 2021 09:49:51 +0800
Subject: [PATCH] [ErnieSage] Change the spawn to distributed.launch. (#226)

* [ErnieSage] Change the spawn to launch.

* use paddle distributed.

* optimize get_rank call times.
---
 examples/text_graph/erniesage/README.md       |  6 ++---
 .../config/erniesage_link_prediction.yaml     |  2 +-
 examples/text_graph/erniesage/data/dataset.py |  9 ++++---
 .../text_graph/erniesage/link_prediction.py   | 25 +++++++++----------
 4 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/examples/text_graph/erniesage/README.md b/examples/text_graph/erniesage/README.md
index 72c3ac3e3fbc1f..d7a7962c852b88 100644
--- a/examples/text_graph/erniesage/README.md
+++ b/examples/text_graph/erniesage/README.md
@@ -39,9 +39,9 @@ NLPCC2016-DBQA was released by the International Conference on Natural Language Processing and Chinese Computing (NLPCC) in 201
 # Data preprocessing: build the graph
 python ./preprocessing/dump_graph.py --conf ./config/erniesage_link_prediction.yaml
 # Run ErnieSage in single- or multi-GPU mode
-python link_prediction.py --conf ./config/erniesage_link_prediction.yaml
-# Predict embeddings for the graph nodes
-python link_prediction.py --conf ./config/erniesage_link_prediction.yaml --do_predict
+python -m paddle.distributed.launch --gpus "0" link_prediction.py --conf ./config/erniesage_link_prediction.yaml
+# Predict embeddings for the graph nodes, on a single GPU or multiple GPUs
+python -m paddle.distributed.launch --gpus "0" link_prediction.py --conf ./config/erniesage_link_prediction.yaml --do_predict
 ```
 
 ## Hyperparameter Settings
diff --git a/examples/text_graph/erniesage/config/erniesage_link_prediction.yaml b/examples/text_graph/erniesage/config/erniesage_link_prediction.yaml
index 3beb069111b093..509fe3159d1adf 100644
--- a/examples/text_graph/erniesage/config/erniesage_link_prediction.yaml
+++ b/examples/text_graph/erniesage/config/erniesage_link_prediction.yaml
@@ -1,7 +1,7 @@
 # Global Environment Settings
 
 # trainer config ------
-n_gpu: 1 # number of gpus used to train; delete it if using cpu
+device: "gpu" # use cpu or gpu device to train.
 seed: 2020
 
 task: "link_prediction"
diff --git a/examples/text_graph/erniesage/data/dataset.py b/examples/text_graph/erniesage/data/dataset.py
index 91eeed1af90303..f2e2003ad552d5 100644
--- a/examples/text_graph/erniesage/data/dataset.py
+++ b/examples/text_graph/erniesage/data/dataset.py
@@ -15,6 +15,7 @@
 import os
 
 import numpy as np
+import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle.io import Dataset, IterableDataset
@@ -32,8 +33,8 @@ class TrainData(Dataset):
     def __init__(self, graph_work_path):
-        trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
-        trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
+        trainer_id = paddle.distributed.get_rank()
+        trainer_count = paddle.distributed.get_world_size()
         print("trainer_id: %s, trainer_count: %s."
               % (trainer_id, trainer_count))
@@ -63,8 +64,8 @@ def __len__(self):
 
 class PredictData(Dataset):
     def __init__(self, num_nodes):
-        trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
-        trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
+        trainer_id = paddle.distributed.get_rank()
+        trainer_count = paddle.distributed.get_world_size()
         self.data = np.arange(trainer_id, num_nodes, trainer_count)
 
     def __getitem__(self, index):
diff --git a/examples/text_graph/erniesage/link_prediction.py b/examples/text_graph/erniesage/link_prediction.py
index 0c008ec86c7a11..6a08a525347d40 100644
--- a/examples/text_graph/erniesage/link_prediction.py
+++ b/examples/text_graph/erniesage/link_prediction.py
@@ -44,7 +44,7 @@ def load_data(graph_data_path):
 
 
 def do_train(config):
-    paddle.set_device("gpu" if config.n_gpu else "cpu")
+    paddle.set_device(config.device)
     if paddle.distributed.get_world_size() > 1:
         paddle.distributed.init_parallel_env()
     set_seed(config)
@@ -72,6 +72,7 @@ def do_train(config):
     optimizer = paddle.optimizer.Adam(
         learning_rate=config.lr, parameters=model.parameters())
 
+    rank = paddle.distributed.get_rank()
     global_step = 0
     tic_train = time.time()
     for epoch in range(config.epoch):
@@ -88,13 +89,13 @@ def do_train(config):
             optimizer.step()
             optimizer.clear_grad()
             if global_step % config.save_per_step == 0:
-                if (not config.n_gpu > 1) or paddle.distributed.get_rank() == 0:
+                if rank == 0:
                     output_dir = os.path.join(config.output_path,
                                               "model_%d" % global_step)
                     if not os.path.exists(output_dir):
                         os.makedirs(output_dir)
                     model._layers.save_pretrained(output_dir)
-    if (not config.n_gpu > 1) or paddle.distributed.get_rank() == 0:
+    if rank == 0:
         output_dir = os.path.join(config.output_path, "last")
         if not os.path.exists(output_dir):
             os.makedirs(output_dir)
@@ -107,7 +108,7 @@ def tostr(data_array):
 
 @paddle.no_grad()
 def do_predict(config):
-    paddle.set_device("gpu" if config.n_gpu else "cpu")
+    paddle.set_device(config.device)
     if paddle.distributed.get_world_size() > 1:
         paddle.distributed.init_parallel_env()
     set_seed(config)
@@ -136,7 +137,7 @@ def do_predict(config):
         num_workers=config.sample_workers,
         collate_fn=collate_fn)
 
-    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+    trainer_id = paddle.distributed.get_rank()
     id2str = io.open(
         os.path.join(config.graph_work_path, "terms.txt"),
         encoding=config.encoding).readlines()
@@ -172,14 +173,12 @@
     parser.add_argument("--do_predict", action='store_true', default=False)
     args = parser.parse_args()
     config = edict(yaml.load(open(args.conf), Loader=yaml.FullLoader))
+
+    assert config.device in [
+        "gpu", "cpu"
+    ], "Device should be gpu/cpu, but got %s." % config.device
     logger.info(config)
     if args.do_predict:
-        do_func = do_predict
-    else:
-        do_func = do_train
-
-    if config.n_gpu > 1 and paddle.fluid.core.is_compiled_with_cuda(
-    ) and paddle.fluid.core.get_cuda_device_count() > 1:
-        paddle.distributed.spawn(do_func, args=(config, ), nprocs=config.n_gpu)
+        do_predict(config)
     else:
-        do_func(config)
+        do_train(config)
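
Notes on the migration:

Why spawn -> launch: with paddle.distributed.spawn the script forked its own worker
processes from config.n_gpu, which is why the entry point needed the compiled-with-CUDA
and device-count checks removed above. With paddle.distributed.launch the launcher
starts one process per device listed in --gpus and injects the rank/world-size
metadata, so the same script body covers single- and multi-GPU runs and only calls
init_parallel_env() when the world size exceeds one. A multi-GPU invocation of the
patched entry point would look like this (the GPU ids are illustrative):

    python -m paddle.distributed.launch --gpus "0,1,2,3" link_prediction.py --conf ./config/erniesage_link_prediction.yaml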
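
Why get_rank()/get_world_size() can replace the environment-variable reads in
dataset.py: workers started by paddle.distributed.launch receive PADDLE_TRAINER_ID and
PADDLE_TRAINERS_NUM in their environment, and Paddle's distributed helpers report this
same process metadata, falling back to rank 0 in a world of size 1 when the script
runs standalone. A minimal sketch of the assumed equivalence (only meaningful inside a
launcher-started process):

    import os

    import paddle

    # Both pairs identify the same worker; the helper calls also degrade
    # gracefully to (0, 1) when no launcher set the variables.
    rank = paddle.distributed.get_rank()
    world_size = paddle.distributed.get_world_size()
    assert rank == int(os.getenv("PADDLE_TRAINER_ID", "0"))
    assert world_size == int(os.getenv("PADDLE_TRAINERS_NUM", "1"))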
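
The PredictData constructor keeps the same strided sharding as before: each worker
takes every trainer_count-th node starting at its own rank, so the workers partition
the node set disjointly and cover it exactly once. A standalone sketch with an
illustrative node count:

    import numpy as np

    num_nodes = 10     # illustrative; the real value comes from the graph
    trainer_id = 1     # this worker's rank
    trainer_count = 4  # world size

    # Worker 1 of 4 gets nodes 1, 5, 9; workers 0..3 together cover 0..9.
    data = np.arange(trainer_id, num_nodes, trainer_count)
    print(data)  # [1 5 9]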