From 8467d56610b9d04c7faf4c5cb2f593127cff92dd Mon Sep 17 00:00:00 2001
From: Vishal Agarwal
Date: Wed, 15 May 2024 14:21:33 +0530
Subject: [PATCH] fix dtype of config.head_dim to int for directml llm
 example (#1159)

Fix `config.head_dim` to be of type `int`. The change in #1138 made the
`head_dim` parameter a `float`.
---
 examples/directml/llm/llm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/directml/llm/llm.py b/examples/directml/llm/llm.py
index 1aa745a41..f96690c56 100644
--- a/examples/directml/llm/llm.py
+++ b/examples/directml/llm/llm.py
@@ -52,7 +52,7 @@ def set_config_parameters(tokenizer: transformers.AutoTokenizer, repo_id: str, n
 
     config.hidden_size = llm_model.config.hidden_size
     config.num_heads = llm_model.config.num_attention_heads
-    config.head_dim = getattr(llm_model.config, "head_dim", config.hidden_size / config.num_heads)
+    config.head_dim = getattr(llm_model.config, "head_dim", config.hidden_size // config.num_heads)
     config.num_layers = num_layers or llm_model.config.num_hidden_layers
     config.vocab_size = llm_model.config.vocab_size
     config.model_type = main_model.config.model_type
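
Note on the fix: a minimal sketch of the behavior difference, using hypothetical
config values (4096 and 32 are illustrative, not taken from the example). In
Python 3, `/` is true division and always returns a float, even when both
operands are ints, while `//` is floor division and returns an int for int
operands.

    # Hypothetical values for illustration only.
    hidden_size = 4096
    num_heads = 32

    # True division always produces a float, so head_dim would be 128.0
    # and break any consumer that expects an int dtype.
    print(hidden_size / num_heads)   # 128.0 (float)

    # Floor division keeps the result an int, which is safe here because
    # hidden_size is an exact multiple of num_heads in transformer configs.
    print(hidden_size // num_heads)  # 128 (int)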