diff --git a/torchelastic/distributed/launch.py b/torchelastic/distributed/launch.py index bb77e08e..f07b06ba 100644 --- a/torchelastic/distributed/launch.py +++ b/torchelastic/distributed/launch.py @@ -8,14 +8,16 @@ import os -from torch.distributed.run import parse_args, run +os.environ["LOGLEVEL"] = "INFO" + +# Since logger initialized during imoprt statement +# the log level should be set first +from torch.distributed.run import main as run_main def main(args=None): - args = parse_args(args) - run(args) + run_main(args) if __name__ == "__main__": - os.environ["LOGLEVEL"] = "INFO" main()