From 2f3dc2ddd4b05a320fc492911ba10f146af3c6e8 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Tue, 11 Jun 2024 22:04:19 -0400
Subject: [PATCH 1/2] remove torch 113

---
 composer/distributed/dist_strategy.py | 24 ------------------------
 1 file changed, 24 deletions(-)

diff --git a/composer/distributed/dist_strategy.py b/composer/distributed/dist_strategy.py
index 1cc1044a02..6f09c67998 100644
--- a/composer/distributed/dist_strategy.py
+++ b/composer/distributed/dist_strategy.py
@@ -334,30 +334,6 @@ def sync_hook(*args):
         keep_low_precision_grads=keep_low_precision_grads,
     )
 
-    # Note: FSDP does support the use of torch.float32 with sharding.
-    # They just never expected a user to pass in torch.float32 into mixed_precision as a param_dtype.
-    # See: https://github.com/pytorch/pytorch/issues/90584
-    # The PR fixing this bug is merged into PyTorch, but it hasn't made its way into a release yet.
-    # Instead a user needs to pass in `None` as param_dtype to have the parameters as torch.float32.
-    # TODO: remove these checks when PyTorch has a release that includes the fix.
-    if sharding_map_key != 'NO_SHARD':
-        if (
-            precision == Precision.AMP_FP16 and param_dtype not in [torch.float16, None] or
-            precision == Precision.AMP_BF16 and param_dtype not in [torch.bfloat16, None]
-        ):
-            raise ValueError(
-                f'FSDP in PyTorch 1.13 does not support precision `{precision}` with sharding strategy `{sharding_strategy}` '
-                f'and param_dtype `{param_dtype}.` Consider using one of the predefined mixed_precision strategies '
-                "(choose: `'FULL'`, `'DEFAULT'`, `'PURE'`)",
-            )
-
-        if param_dtype == torch.float32:
-            raise ValueError(
-                f'FSDP in PyTorch 1.13 does not support param_dtype `{param_dtype}` with sharding_strategy `{sharding_map_key}` '
-                f'Consider using `amp` or `bf16` for precision or setting param_dtype in mixed_precision to `None` '
-                f'with sharding strategy `{sharding_map_key}.`',
-            )
-
     process_group = None
     if fsdp_config.process_group is not None:
         process_group_dict = {'process_group': fsdp_config.process_group}

From 253cf502aa95bbb580a61bf1fb4aa7f36e00906e Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Wed, 12 Jun 2024 00:11:50 -0400
Subject: [PATCH 2/2] lint

---
 composer/distributed/dist_strategy.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/composer/distributed/dist_strategy.py b/composer/distributed/dist_strategy.py
index 6f09c67998..be81652881 100644
--- a/composer/distributed/dist_strategy.py
+++ b/composer/distributed/dist_strategy.py
@@ -328,7 +328,7 @@ def sync_hook(*args):
     mixed_precision = fsdp_config.mixed_precision
     keep_low_precision_grads = fsdp_config.keep_low_precision_grads
 
-    mixed_precision, param_dtype, _, _ = get_mixed_precision(
+    mixed_precision, _, _, _ = get_mixed_precision(
         precision,
         mixed_precision=mixed_precision,
         keep_low_precision_grads=keep_low_precision_grads,
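
Reviewer note, appended for context (not part of either commit above): the
deleted guard existed only because FSDP in PyTorch 1.13 rejected
torch.float32 as a mixed_precision param_dtype under sharded strategies
(pytorch/pytorch#90584); per the removed comment, the workaround was to pass
`None` as param_dtype so parameters stay in torch.float32, and the TODO asked
for this check to be dropped once a release shipped the fix, presumably the
case for the PyTorch versions Composer now supports. Below is a minimal
sketch of that configuration against the public FSDP API; the toy module and
the assumption of an already-initialized process group are illustrative, not
Composer code:

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

    # param_dtype=None: parameters are not cast, so they remain float32,
    # while gradient reduction and buffers use bfloat16.
    mp = MixedPrecision(
        param_dtype=None,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )

    # Assumes torch.distributed is already initialized (e.g. via torchrun).
    model = FSDP(
        torch.nn.Linear(8, 8),  # illustrative toy module, not Composer's model
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=mp,
    )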