From 68bbbd906cb24d412c4512a33ad0a9ce58b82d49 Mon Sep 17 00:00:00 2001
From: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com>
Date: Thu, 14 Sep 2023 01:33:52 +0530
Subject: [PATCH] Flex xpu bug fix (#26135)

flex gpu bug fix
---
 src/transformers/training_args.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 1a03a3448a91..7aae329f3d68 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1425,12 +1425,13 @@ def __post_init__(self):
             and is_torch_available()
             and (self.device.type != "cuda")
             and (self.device.type != "npu")
+            and (self.device.type != "xpu")
             and (get_xla_device_type(self.device) != "GPU")
             and (self.fp16 or self.fp16_full_eval)
         ):
             raise ValueError(
                 "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
-                " (`--fp16_full_eval`) can only be used on CUDA or NPU devices."
+                " (`--fp16_full_eval`) can only be used on CUDA or NPU devices or certain XPU devices (with IPEX)."
             )
 
         if (
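
Note (not part of the commit): a minimal sketch of the guard this patch relaxes, with the condition reduced to its device-type component. The helper name fp16_guard and the standalone device_type/fp16 parameters are illustrative stand-ins for self.device.type, self.fp16, and self.fp16_full_eval in TrainingArguments.__post_init__; the is_torch_available() and XLA checks are omitted.

def fp16_guard(device_type: str, fp16: bool, fp16_full_eval: bool = False) -> None:
    """Raise if FP16 is requested on a device type that cannot use it (sketch)."""
    # "xpu" is newly exempted by this patch, alongside "cuda" and "npu".
    unsupported_device = device_type not in ("cuda", "npu", "xpu")
    if unsupported_device and (fp16 or fp16_full_eval):
        raise ValueError(
            "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
            " (`--fp16_full_eval`) can only be used on CUDA or NPU devices or certain XPU devices (with IPEX)."
        )

# After this change an Intel "xpu" device passes the check, while e.g. plain CPU still raises:
fp16_guard("xpu", fp16=True)      # no error
# fp16_guard("cpu", fp16=True)    # would raise ValueError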