Change Paligemma import logging to work with modular (huggingface#34211)
* change import logging

* fix CI
yonigozlan authored and BernardZach committed Dec 5, 2024
1 parent 465f465 commit 94c3e1f
Showing 2 changed files with 7 additions and 38 deletions.
41 changes: 5 additions & 36 deletions src/transformers/models/glm/modeling_glm.py
@@ -25,7 +25,6 @@
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
@@ -921,6 +920,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
+        **kwargs,
     ):
         """
         Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
@@ -1071,18 +1071,7 @@ def forward(

         loss = None
         if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
+            loss = self.loss_function(logits, labels, self.vocab_size)

         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -1186,27 +1175,8 @@ def forward(

         loss = None
         if labels is not None:
-            labels = labels.to(logits.device)
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(pooled_logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = BCEWithLogitsLoss()
-                loss = loss_fct(pooled_logits, labels)
+            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
         return ((loss,) + output) if loss is not None else output
@@ -1289,8 +1259,7 @@ def forward(

         loss = None
         if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            loss = self.loss_function(logits, labels, self.config)

         if not return_dict:
             output = (logits,) + outputs[2:]
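Alongside the logging fix, the GLM hunks above fold the hand-written loss computations into the shared `self.loss_function` helper, which is why the direct `BCEWithLogitsLoss`, `CrossEntropyLoss`, and `MSELoss` imports can be dropped. For readers unfamiliar with what the deleted causal-LM block did, here is a minimal standalone sketch of that computation; the function name and plain-tensor signature are illustrative, not part of the library:

import torch
from torch.nn import CrossEntropyLoss

def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    # Upcast to float to avoid precision issues when computing the loss.
    logits = logits.float()
    # Shift so that tokens < n predict token n.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten to (batch * seq_len, vocab) and (batch * seq_len,), then apply cross entropy.
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1).to(shift_logits.device)
    return CrossEntropyLoss()(shift_logits, shift_labels)

The shared `self.loss_function` call centralizes this logic (and the sequence- and token-classification variants removed below it) so individual model files no longer repeat it.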
4 changes: 2 additions & 2 deletions src/transformers/models/paligemma/processing_paligemma.py
@@ -16,7 +16,6 @@
 Processor class for PaliGemma.
 """

-import logging
 from typing import List, Optional, Union

 from ...feature_extraction_utils import BatchFeature
@@ -34,9 +33,10 @@
     PreTokenizedInput,
     TextInput,
 )
+from ...utils import logging


-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)

 IMAGE_TOKEN = "<image>"
 EXTRA_TOKENS = [f"<loc{i:0>4}>" for i in range(1024)] + [f"<seg{i:0>3}>" for i in range(128)]
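The PaliGemma change itself is small: it swaps Python's stdlib `logging.getLogger` for the wrapper in `transformers.utils`, which the modular conversion expects and which hooks into the library's central verbosity controls. A minimal usage sketch, assuming an installed `transformers` package (the message string is illustrative):

from transformers.utils import logging

# get_logger returns a logger wired into transformers' verbosity settings
# rather than whatever the stdlib root logger happens to be configured with.
logger = logging.get_logger(__name__)

logging.set_verbosity_info()
logger.info("PaliGemma processor ready.")  # emitted once verbosity is at least INFO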
