From b87b4905c5eff5db8b11491f341f25a52dc0e6a5 Mon Sep 17 00:00:00 2001
From: Xin Pan
Date: Mon, 4 Jun 2018 13:57:07 +0800
Subject: [PATCH] mem_leak

---
 benchmark/fluid/fluid_benchmark.py | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/benchmark/fluid/fluid_benchmark.py b/benchmark/fluid/fluid_benchmark.py
index c1d458970a58b..f2b356765c8e2 100644
--- a/benchmark/fluid/fluid_benchmark.py
+++ b/benchmark/fluid/fluid_benchmark.py
@@ -16,6 +16,7 @@
 import cProfile
 import time
 import os
+import sys
 
 import numpy as np
 
@@ -201,6 +202,7 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
         exe.run(train_prog)
         return
 
+    sys.stderr.write('train with Executor\n')
     if args.use_fake_data:
         raise Exception(
             "fake data is not supported in single GPU test for now.")
@@ -231,6 +233,8 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
             train_losses.append(loss)
             print("Pass: %d, Iter: %d, Loss: %f\n" %
                   (pass_id, iters, np.mean(train_losses)))
+            if batch_id == 2:
+                break
         train_elapsed = time.time() - start_time
         examples_per_sec = num_samples / train_elapsed
         print('\nTotal examples: %d, total time: %.5f, %.5f examples/sec\n' %
@@ -243,7 +247,7 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
             print(", Test Accuracy: %f" % pass_test_acc)
         print("\n")
         # TODO(wuyi): add warmup passes to get better perf data.
-        exit(0)
+        # exit(0)
 
 
 # TODO(wuyi): replace train, train_parallel, test functions with new trainer
@@ -361,10 +365,18 @@ def main():
             raise Exception(
                 "Must configure correct environments to run dist train.")
         train_args.extend([train_prog, startup_prog])
-        if args.gpus > 1 and os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER":
-            train_args.extend([nccl_id_var, num_trainers, trainer_id])
-            train_parallel(*train_args)
-        train(*train_args)
+
+        role = os.getenv("PADDLE_TRAINING_ROLE")
+        if role == "TRAINER":
+            if args.gpus > 1:
+                train_args.extend([nccl_id_var, num_trainers, trainer_id])
+                train_parallel(*train_args)
+            else:
+                train(*train_args)
+        elif role == "PSERVER":
+            train(*train_args)
+        else:
+            raise Exception("Unknown PADDLE_TRAINING_ROLE: %s" % role)
         exit(0)
 
     # for other update methods, use default programs
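
For readers skimming the patch, the sketch below restates the PADDLE_TRAINING_ROLE dispatch that the hunk in main() introduces. It is a standalone illustration under assumed names, not code from the repository: the dispatch() wrapper, its parameters, and the stub train()/train_parallel() functions are hypothetical, and the real benchmark passes far more arguments.

    # Minimal sketch (hypothetical names) of the role-based dispatch added to main():
    # trainers choose single- or multi-GPU training, parameter servers fall back to
    # train(), and any other role is rejected early.
    import os


    def train(*train_args):
        # stand-in for the benchmark's Executor-based training loop
        print("train:", train_args)


    def train_parallel(*train_args):
        # stand-in for the multi-GPU ParallelExecutor path
        print("train_parallel:", train_args)


    def dispatch(train_args, gpus, nccl_id_var=None, num_trainers=1, trainer_id=0):
        role = os.getenv("PADDLE_TRAINING_ROLE")
        if role == "TRAINER":
            if gpus > 1:
                train_args.extend([nccl_id_var, num_trainers, trainer_id])
                train_parallel(*train_args)
            else:
                train(*train_args)
        elif role == "PSERVER":
            train(*train_args)
        else:
            raise Exception("Unknown PADDLE_TRAINING_ROLE: %s" % role)


    if __name__ == "__main__":
        os.environ.setdefault("PADDLE_TRAINING_ROLE", "TRAINER")
        dispatch(train_args=["loss", "prog"], gpus=1)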