diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 14da9e697af9e9..9cc485e4601c6e 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -502,7 +502,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa class StoppingCriteriaList(list): @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: - is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device) + is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device, dtype=torch.bool) for criteria in self: is_done = is_done | criteria(input_ids, scores, **kwargs) return is_done