Fix llava half precision and autocast issues #29721

Merged
1 change: 1 addition & 0 deletions src/transformers/models/llava/modeling_llava.py
@@ -438,6 +438,7 @@ def forward(
)

image_features = self.multi_modal_projector(selected_image_feature)
inputs_embeds = inputs_embeds.to(image_features.dtype)
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids, attention_mask, labels
)
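For context on the one-line cast above: under torch.autocast, linear layers such as the multi-modal projector produce float16 outputs, while the embedding lookup that produces inputs_embeds keeps the float32 dtype of its weights, so scattering the half-precision image features into the full-precision embedding buffer during the merge step fails with a dtype mismatch. The following is a minimal sketch of that mismatch and of the cast that avoids it; it uses stand-in modules and made-up shapes rather than the actual merge code, and it assumes a CUDA device:

    import torch
    import torch.nn as nn

    embed = nn.Embedding(100, 32).cuda()   # stand-in for the text embedding table (float32 weights)
    proj = nn.Linear(32, 32).cuda()        # stand-in for the multi-modal projector

    input_ids = torch.randint(0, 100, (1, 8), device="cuda")
    vision_out = torch.randn(4, 32, device="cuda")            # 4 fake image patch features
    mask = torch.zeros(1, 12, dtype=torch.bool, device="cuda")
    mask[:, :4] = True                                         # positions that receive image features

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        inputs_embeds = embed(input_ids)    # stays float32: embeddings are not autocast to fp16
        image_features = proj(vision_out)   # float16: linear layers run under autocast

        # Without the cast, scattering fp16 features into an fp32 buffer raises a dtype error:
        # buffer = inputs_embeds.new_zeros(1, 12, 32)
        # buffer[mask] = image_features     # RuntimeError: source/destination dtypes must match

        # The fix: align inputs_embeds with the projector output's dtype before the merge.
        inputs_embeds = inputs_embeds.to(image_features.dtype)
        buffer = inputs_embeds.new_zeros(1, 12, 32)            # now float16
        buffer[mask] = image_features                           # works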
10 changes: 7 additions & 3 deletions src/transformers/models/llava_next/modeling_llava_next.py
@@ -14,6 +14,7 @@
# limitations under the License.
""" PyTorch Llava-NeXT model."""

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

@@ -306,8 +307,8 @@ def __init__(self, config: LlavaNextConfig):
self.vision_tower = AutoModel.from_config(config.vision_config)

self.multi_modal_projector = LlavaNextMultiModalProjector(config)

self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size, dtype=self.dtype))
embed_std = 1 / math.sqrt(config.text_config.hidden_size)
self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)

self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(
@@ -543,7 +544,9 @@ def forward(
image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
self.image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.dtype),
),
dim=-1,
)
@@ -554,6 +557,7 @@
image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
inputs_embeds = inputs_embeds.to(image_features.dtype)

inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids, attention_mask, labels
34 changes: 28 additions & 6 deletions tests/models/llava/test_modeling_llava.py
@@ -157,6 +157,19 @@ def prepare_config_and_inputs_for_common(self):
}
return config, inputs_dict

def create_and_check_llava_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask):
model = LlavaForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
with torch.autocast(device_type="cuda", dtype=torch.float16):
Contributor Author

I had some concerns about putting this here given I haven't seen autocasting tested in the unit tests elsewhere in the repo. Let me know if you prefer this being tested through an integration test with accelerate.

Collaborator

Yes, your intuition is correct :) We shouldn't need to do this autocasting here. Without it, I'm assuming it fails?

Could you specify what you're thinking re an accelerate integration test and the need for it? Other fp16 tests don't specify a particular accelerate version, so I'm not sure what it would be addressing here.

Contributor Author

It's kind of the opposite: the model works fine in .half() without autocasting. The issue is that when using the Trainer with the fp16 or bf16 flag, the model returns a type error. The Trainer uses autocasting behind the scenes through accelerate, so I wrote these tests as the simplest case that captures the failure. I was not quite sure how to handle this in these tests, given we are not testing autocasting behavior elsewhere.

Collaborator

OK, just to make sure I've understood, the issue arises when passing the model to Trainer? Does the following test pass?

    def create_and_check_llava_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask):
        model = LlavaForConditionalGeneration(config=config).to(torch_device).half().eval()
        output = model(input_ids, attention_mask=attention_mask)["last_hidden_state"]
        self.parent.assertFalse(torch.isnan(output).any().item())

Contributor Author

Correct, the issue arises when passing the model to the Trainer with the fp16 or bf16 flags. The test you included would work just fine. As far as I can tell this is due to how the model behaves when autocast is used within the Trainer (indirectly through accelerate). Using the autocast block, I was able to replicate in this test the same bug I see when I run the model through the Trainer.

Collaborator

Ah, OK. I think we should do this in two steps. Adding an integration test for trainer & autocast for models sounds like a good idea. I suspect it might throw up quite a few things to address, and the design of the test is important to make sure it's as lightweight as possible. By splitting this up, we can add the fix in quickly and then iterate on the test design / fix the errors it throws.

What I would suggest is:

  • Keep the change to llava next and the create_and_check_llava_next_model_fp16_forward test in this PR
  • Open a new PR for adding an integration test. This would possibly sit under tests/trainer/test_trainer.py - @muellerzr will be able to advise here re where and design :)
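For reference, a rough sketch of what such a trainer integration test could look like, under the assumption that it exercises the Trainer's fp16 flag (which routes the forward pass through autocast via accelerate). The tiny config values, the TinyLlavaDataset helper, and the test/class names below are illustrative assumptions, not part of this PR:

    import tempfile
    import unittest

    import torch
    from torch.utils.data import Dataset

    from transformers import LlavaConfig, LlavaForConditionalGeneration, Trainer, TrainingArguments
    from transformers.testing_utils import require_torch_gpu


    class TinyLlavaDataset(Dataset):
        """A few dummy multimodal samples shaped for a tiny Llava config (illustrative only)."""

        def __len__(self):
            return 4

        def __getitem__(self, idx):
            input_ids = torch.randint(2, 99, (16,))
            input_ids[0] = 1  # image token id, matching image_token_index in the config below
            return {
                "input_ids": input_ids,
                "attention_mask": torch.ones(16, dtype=torch.long),
                "pixel_values": torch.rand(3, 30, 30),
                "labels": input_ids.clone(),
            }


    @require_torch_gpu
    class LlavaTrainerIntegrationTest(unittest.TestCase):
        def test_fp16_training_step(self):
            # Trainer(fp16=True) runs the forward pass under autocast through accelerate,
            # which is the code path this PR fixes.
            config = LlavaConfig(
                text_config={"hidden_size": 32, "intermediate_size": 37, "num_hidden_layers": 2,
                             "num_attention_heads": 4, "vocab_size": 99},
                vision_config={"hidden_size": 32, "intermediate_size": 37, "num_hidden_layers": 2,
                               "num_attention_heads": 4, "image_size": 30, "patch_size": 6},
                image_token_index=1,
            )
            model = LlavaForConditionalGeneration(config)
            with tempfile.TemporaryDirectory() as tmp_dir:
                args = TrainingArguments(
                    output_dir=tmp_dir,
                    fp16=True,
                    per_device_train_batch_size=2,
                    max_steps=1,
                    report_to=[],
                )
                trainer = Trainer(model=model, args=args, train_dataset=TinyLlavaDataset())
                trainer.train()  # would hit the dtype error without the inputs_embeds cast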

Contributor

Fully agree with this way forward @frasermince !

Contributor

Let's also document the test with some comments explaining why it's needed

Contributor Author

Sounds like a plan! I'll get to work on that! Would we want this new integration test still namespaced under the model, something like tests/llava/trainer/test_trainer.py, or are you suggesting we add testing more generally to the trainer?

Contributor

Maybe test/modeling/llava/test_trainer_llava.py

since for now it's model-specific?

logits = model(
input_ids=input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values.to(torch.bfloat16),
return_dict=True,
)["logits"]
self.parent.assertFalse(torch.isnan(logits).any().item())


@require_torch
class LlavaForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
@@ -225,7 +238,7 @@ def test_small_model_integration_test(self):

@slow
@require_bitsandbytes
def test_small_model_integration_test_llama(self):
def test_small_model_integration_test_llama_single(self):
# Let' s make sure we test the preprocessing to replace what is used
model_id = "llava-hf/llava-1.5-7b-hf"

@@ -238,7 +251,7 @@ def test_small_model_integration_test_llama(self):
inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)

output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Finally, be respectful of the environment and other visitors, and follow any posted rules or guidelines for the area." # fmt: skip
EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Lastly, be respectful of the environment and other visitors, as the pier is a shared space where people can enjoy the view, relax, or engage in recreational activities." # fmt: skip

self.assertEqual(
processor.decode(output[0], skip_special_tokens=True),
@@ -267,7 +280,10 @@ def test_small_model_integration_test_llama_batched(self):

EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, you', 'USER: \nWhat is this? ASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip

self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
self.assertEqual(
processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)

@slow
@require_bitsandbytes
@@ -287,7 +303,10 @@ def test_small_model_integration_test_batch(self):
output = model.generate(**inputs, max_new_tokens=20)

EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring along', 'USER: \nWhat is this?\nASSISTANT: Cats'] # fmt: skip
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
self.assertEqual(
self.processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)

@slow
@require_bitsandbytes
@@ -314,7 +333,10 @@ def test_small_model_integration_test_llama_batched_regression(self):

EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip

self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
self.assertEqual(
processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)

@slow
@require_torch
@@ -342,7 +364,7 @@ def test_batched_generation(self):
model = model.eval()

EXPECTED_OUTPUT = [
"\n \nUSER: What's the the difference of two images?\nASSISTANT: In the two images, the primary difference is the presence of a small dog holding a flower in one",
"\n \nUSER: What's the the difference of two images?\nASSISTANT: In the two images, the primary difference is the presence of a small dog in one and a ll",
"\nUSER: Describe the image.\nASSISTANT: The image features a small, fluffy dog sitting on a sidewalk. The dog is holding",
"\nUSER: Describe the image.\nASSISTANT: The image features a lone, adult llama standing on a grassy hill. The llama",
]
73 changes: 66 additions & 7 deletions tests/models/llava_next/test_modeling_llava_next.py
@@ -27,11 +27,21 @@
is_torch_available,
is_vision_available,
)
from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
from transformers.testing_utils import (
require_bitsandbytes,
require_torch,
slow,
torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_modeling_common import (
ModelTesterMixin,
_config_zero_init,
floats_tensor,
ids_tensor,
)


if is_torch_available():
@@ -157,6 +167,39 @@ def prepare_config_and_inputs_for_common(self):
}
return config, inputs_dict

def create_and_check_llava_next_model_fp16_forward(
self, config, input_ids, pixel_values, attention_mask, image_sizes
):
model = LlavaNextForConditionalGeneration(config=config)
model.to(torch_device)
model.half()
model.eval()
logits = model(
input_ids=input_ids,
attention_mask=attention_mask,
image_sizes=image_sizes,
pixel_values=pixel_values.to(torch.bfloat16),
return_dict=True,
)["logits"]
self.parent.assertFalse(torch.isnan(logits).any().item())

def create_and_check_llava_next_model_fp16_autocast_forward(
self, config, input_ids, pixel_values, attention_mask, image_sizes
):
config.torch_dtype = torch.float16
model = LlavaNextForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
with torch.autocast(device_type="cuda", dtype=torch.float16):
logits = model(
input_ids=input_ids,
attention_mask=attention_mask,
image_sizes=image_sizes,
pixel_values=pixel_values.to(torch.bfloat16),
return_dict=True,
)["logits"]
self.parent.assertFalse(torch.isnan(logits).any().item())


@require_torch
class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@@ -239,14 +282,20 @@ def test_small_model_integration_test(self):
inputs = self.processor(self.prompt, self.image, return_tensors="pt")

# verify inputs against original implementation
filepath = hf_hub_download(repo_id="nielsr/test-image", filename="llava_1_6_input_ids.pt", repo_type="dataset")
filepath = hf_hub_download(
repo_id="nielsr/test-image",
filename="llava_1_6_input_ids.pt",
repo_type="dataset",
)
original_input_ids = torch.load(filepath, map_location="cpu")
# replace -200 by image_token_index (since we use token ID = 32000 for the image token)
original_input_ids[original_input_ids == -200] = model.config.image_token_index
assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist()

filepath = hf_hub_download(
repo_id="nielsr/test-image", filename="llava_1_6_pixel_values.pt", repo_type="dataset"
repo_id="nielsr/test-image",
filename="llava_1_6_pixel_values.pt",
repo_type="dataset",
)
original_pixel_values = torch.load(filepath, map_location="cpu")
assert torch.allclose(original_pixel_values, inputs.pixel_values.half())
@@ -257,7 +306,11 @@ def test_small_model_integration_test(self):
output = model(**inputs)

expected_slice = torch.tensor(
[[-4.7695, -4.5664, -0.2786], [-10.6250, -10.8906, -2.5254], [-6.7383, -7.2461, -0.6787]],
[
[-4.7695, -4.5664, -0.2786],
[-10.6250, -10.8906, -2.5254],
[-6.7383, -7.2461, -0.6787],
],
dtype=torch.float32,
device=torch_device,
)
@@ -282,7 +335,10 @@ def test_small_model_integration_test_batch(self):
cats_image = Image.open(requests.get(url, stream=True).raw)

inputs = self.processor(
[self.prompt, self.prompt], images=[self.image, cats_image], return_tensors="pt", padding=True
[self.prompt, self.prompt],
images=[self.image, cats_image],
return_tensors="pt",
padding=True,
).to(torch_device)

# make sure image_sizes are the same
@@ -292,7 +348,10 @@ def test_small_model_integration_test_batch(self):
output = model.generate(**inputs, max_new_tokens=20)

EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush'] # fmt: skip
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
self.assertEqual(
self.processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)

@slow
@require_bitsandbytes