diff --git a/src/transformers/pipelines/image_classification.py b/src/transformers/pipelines/image_classification.py index c54f372baa9d05..8aaa66e6c4586d 100644 --- a/src/transformers/pipelines/image_classification.py +++ b/src/transformers/pipelines/image_classification.py @@ -171,9 +171,9 @@ def _forward(self, model_inputs): def postprocess(self, model_outputs, function_to_apply=None, top_k=5): if function_to_apply is None: - if self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels == 1: + if self.model.config.problem_type == "single_label_classification" or self.model.config.num_labels == 1: function_to_apply = ClassificationFunction.SIGMOID - elif self.model.config.problem_type == "single_label_classification" or self.model.config.num_labels > 1: + elif self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels > 1: function_to_apply = ClassificationFunction.SOFTMAX elif hasattr(self.model.config, "function_to_apply") and function_to_apply is None: function_to_apply = self.model.config.function_to_apply