diff --git a/train.py b/train.py index 88a2652c..c95855e0 100644 --- a/train.py +++ b/train.py @@ -54,6 +54,9 @@ def main(args): args.distributed = args.world_size > 1 or args.multiprocessing_distributed ngpus_per_node = torch.cuda.device_count() # number of gpus of each node + # Divide the batch size by the number of nodes + args.batch_size = int(args.batch_size / args.world_size) + if args.multiprocessing_distributed: # now, args.world_size means num of total processes in all nodes args.world_size = ngpus_per_node * args.world_size