diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py index f62da5bc9..ab85e3f42 100644 --- a/src/llama_recipes/utils/train_utils.py +++ b/src/llama_recipes/utils/train_utils.py @@ -6,7 +6,6 @@ import yaml from contextlib import nullcontext from pathlib import Path -from pkg_resources import packaging from datetime import datetime import contextlib @@ -474,7 +473,7 @@ def get_policies(cfg, rank): verify_bfloat_support = (( torch.version.cuda and torch.cuda.is_bf16_supported() - and packaging.version.parse(torch.version.cuda).release >= (11, 0) + and torch.version.cuda >= "11.0" and dist.is_nccl_available() and nccl.version() >= (2, 10) ) or