Skip to content

Commit

Permalink
Fix typo in classification function selection logic to improve code c…
Browse files Browse the repository at this point in the history
…onsistency (#32031)

Make problem_type condition consistent with num_labels condition

The latter condition generally overrides the former, so this is more of a code reading issue. I'm not sure the bug would ever actually get triggered under normal use.
  • Loading branch information
moses authored and itazap committed Jul 25, 2024
1 parent 75918d5 commit a99fa54
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/transformers/pipelines/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,9 @@ def _forward(self, model_inputs):

def postprocess(self, model_outputs, function_to_apply=None, top_k=5):
if function_to_apply is None:
if self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels == 1:
if self.model.config.problem_type == "single_label_classification" or self.model.config.num_labels == 1:
function_to_apply = ClassificationFunction.SIGMOID
elif self.model.config.problem_type == "single_label_classification" or self.model.config.num_labels > 1:
elif self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels > 1:
function_to_apply = ClassificationFunction.SOFTMAX
elif hasattr(self.model.config, "function_to_apply") and function_to_apply is None:
function_to_apply = self.model.config.function_to_apply
Expand Down

0 comments on commit a99fa54

Please sign in to comment.