From 4b850078a214b234db281b9eb92fafc5363a944c Mon Sep 17 00:00:00 2001 From: Niels Date: Sun, 19 May 2024 11:00:40 +0200 Subject: [PATCH 1/3] Update ignore index --- src/transformers/models/idefics2/modeling_idefics2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 17ed18f6b99d..6acabad0635b 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1857,7 +1857,7 @@ def forward( shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss(ignore_index=self.image_token_id) + loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) if not return_dict: From 5f56c1432949429ba1e3e84f6a876e3c221e66d8 Mon Sep 17 00:00:00 2001 From: Niels Date: Mon, 20 May 2024 14:50:02 +0200 Subject: [PATCH 2/3] Update docs --- docs/source/en/model_doc/idefics2.md | 52 ++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/docs/source/en/model_doc/idefics2.md b/docs/source/en/model_doc/idefics2.md index 31a7a1cdeb6e..f7fdf363a6e8 100644 --- a/docs/source/en/model_doc/idefics2.md +++ b/docs/source/en/model_doc/idefics2.md @@ -87,6 +87,58 @@ generated_text = processor.batch_decode(generated_text, skip_special_tokens=True print("Generated text:", generated_text) ``` +- During training, it's important to determine which tokens should be excluded from the loss, i.e. which tokens the model should not learn to predict. For Idefics2, this typically comes down to the image and padding tokens. 
Setting a label to -100 excludes that position from the loss (-100 is the default `ignore_index` of PyTorch's `CrossEntropyLoss`), so one can create the labels as follows: + +```python +import requests +from PIL import Image +from transformers import Idefics2Processor, Idefics2ForConditionalGeneration +import torch + +url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg" +url_2 = "http://images.cocodataset.org/val2017/000000219578.jpg" + +image_1 = Image.open(requests.get(url_1, stream=True).raw) +image_2 = Image.open(requests.get(url_2, stream=True).raw) +images = [image_1, image_2] + +messages = [{ + "role": "user", + "content": [ + {"type": "text", "text": "What’s the difference between these two images?"}, + {"type": "image"}, + {"type": "image"}, + ], +}, +{ + "role": "assistant", + "content": [ + {"type": "text", "text": "The difference is that one image is about dogs and the other one about cats."}, + ], +}] + +device = "cuda" if torch.cuda.is_available() else "cpu" + +processor = Idefics2Processor.from_pretrained("HuggingFaceM4/idefics2-8b") +model = Idefics2ForConditionalGeneration.from_pretrained("HuggingFaceM4/idefics2-8b") +model.to(device) + +text = processor.apply_chat_template(messages, add_generation_prompt=False) +inputs = processor(images=images, text=text, return_tensors="pt").to(device) + +labels = inputs.input_ids.clone() +labels[labels == processor.tokenizer.pad_token_id] = -100 +labels[labels == processor.image_processor.image_token_id] = -100 + +inputs["labels"] = labels + +outputs = model(**inputs) +loss = outputs.loss +loss.backward() +``` + +Note that when training Idefics2 on multi-turn conversations between a user and an assistant, one typically also sets the labels of all tokens belonging to the user messages to -100, so that the loss is computed on the assistant responses only. + ## Model optimizations: Flash Attention The code snippets above showcase inference without any optimization tricks. However, one can drastically speed up the model by leveraging [Flash Attention](../perf_train_gpu_one.md#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model. 
From 910d11bde7116008414ffa81247edbb4c7acdd03 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Mon, 20 May 2024 16:14:50 +0200 Subject: [PATCH 3/3] Update docs --- docs/source/en/model_doc/idefics2.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/idefics2.md b/docs/source/en/model_doc/idefics2.md index f7fdf363a6e8..5ad56b7b5c52 100644 --- a/docs/source/en/model_doc/idefics2.md +++ b/docs/source/en/model_doc/idefics2.md @@ -128,7 +128,7 @@ inputs = processor(images=images, text=text, return_tensors="pt").to(device) labels = inputs.input_ids.clone() labels[labels == processor.tokenizer.pad_token_id] = -100 -labels[labels == processor.image_processor.image_token_id] = -100 +labels[labels == model.config.image_token_id] = -100 inputs["labels"] = labels