From c8861376adee4c0f962918416dd356fdab552189 Mon Sep 17 00:00:00 2001
From: Huazhong Ji
Date: Wed, 29 May 2024 18:57:54 +0800
Subject: [PATCH] Improve `transformers-cli env` reporting (#31003)

* Improve `transformers-cli env` reporting

* move the line `"Using GPU in script?": "<fill in>"` into the if conditional statement

* same option for npu
---
 src/transformers/commands/env.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/transformers/commands/env.py b/src/transformers/commands/env.py
index 8567bbcf5b6..da9ca6660be 100644
--- a/src/transformers/commands/env.py
+++ b/src/transformers/commands/env.py
@@ -26,6 +26,7 @@
     is_safetensors_available,
     is_tf_available,
     is_torch_available,
+    is_torch_npu_available,
 )
 
 from . import BaseTransformersCLICommand
@@ -88,6 +89,7 @@ def run(self):
 
             pt_version = torch.__version__
             pt_cuda_available = torch.cuda.is_available()
+            pt_npu_available = is_torch_npu_available()
 
         tf_version = "not installed"
         tf_cuda_available = "NA"
@@ -129,9 +131,15 @@ def run(self):
             "Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
             "Jax version": f"{jax_version}",
             "JaxLib version": f"{jaxlib_version}",
-            "Using GPU in script?": "<fill in>",
             "Using distributed or parallel set-up in script?": "<fill in>",
         }
+        if pt_cuda_available:
+            info["Using GPU in script?"] = "<fill in>"
+            info["GPU type"] = torch.cuda.get_device_name()
+        elif pt_npu_available:
+            info["Using NPU in script?"] = "<fill in>"
+            info["NPU type"] = torch.npu.get_device_name()
+            info["CANN version"] = torch.version.cann
 
         print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
         print(self.format_dict(info))
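
Note: the snippet below is a minimal standalone sketch of the device reporting this patch adds, for readers who want to try the logic outside the CLI. It assumes PyTorch is installed; `is_torch_npu_available`, the `torch.npu` namespace, and `torch.version.cann` are taken from the diff above and are only usable when the optional torch_npu (Ascend) plugin is present.

# Sketch only: mirrors the conditional added to EnvironmentCommand.run() above.
import torch

from transformers.utils import is_torch_npu_available

info = {
    "Using distributed or parallel set-up in script?": "<fill in>",
}

if torch.cuda.is_available():
    # Same keys as the patch: the user still fills in "Using GPU in script?".
    info["Using GPU in script?"] = "<fill in>"
    info["GPU type"] = torch.cuda.get_device_name()
elif is_torch_npu_available():
    # torch.npu and torch.version.cann are provided by the torch_npu plugin.
    info["Using NPU in script?"] = "<fill in>"
    info["NPU type"] = torch.npu.get_device_name()
    info["CANN version"] = torch.version.cann

for key, value in info.items():
    print(f"- {key}: {value}")

With this change, running `transformers-cli env` on a CUDA machine also reports the GPU type, and on an Ascend NPU machine it reports the NPU type and CANN version, so issue reports include the accelerator hardware without the user having to add it by hand.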