From 6023fe58858a982bbe82416843e1fd9b87db0537 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 17 Jun 2024 10:57:49 -0700 Subject: [PATCH] simplify launcher (#3398) --- composer/cli/launcher.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/composer/cli/launcher.py b/composer/cli/launcher.py index 08dd7b3921..91110c2add 100755 --- a/composer/cli/launcher.py +++ b/composer/cli/launcher.py @@ -197,8 +197,13 @@ def _parse_args(): if args.nproc < 1: raise ValueError('The nproc must be 1 or greater') - if args.world_size is None and 'WORLD_SIZE' in os.environ: - args.world_size = int(os.environ['WORLD_SIZE']) + if args.world_size is None: + if 'WORLD_SIZE' in os.environ and os.environ.get('LOCAL_WORLD_SIZE') != os.environ['WORLD_SIZE']: + # Use WORLD_SIZE env var if set and running multinode. Otherwise, default to nproc + # to enable easy overriding of number of processes when on a single node. + args.world_size = int(os.environ['WORLD_SIZE']) + else: + args.world_size = args.nproc if args.base_rank is None and 'BASE_RANK' in os.environ: args.base_rank = int(os.environ['BASE_RANK']) @@ -212,9 +217,6 @@ def _parse_args(): if args.master_port is None and 'MASTER_PORT' in os.environ: args.master_port = int(os.environ['MASTER_PORT']) - if args.world_size is None: - args.world_size = args.nproc - if args.world_size < args.nproc: raise ValueError(f'world_size({args.world_size}) cannot be less than nproc({args.nproc})')