diff --git a/caikit_nlp/modules/text_classification/sequence_classification.py b/caikit_nlp/modules/text_classification/sequence_classification.py index a485c59b..dadc8dcc 100644 --- a/caikit_nlp/modules/text_classification/sequence_classification.py +++ b/caikit_nlp/modules/text_classification/sequence_classification.py @@ -185,7 +185,7 @@ def _get_scores(self, text: Union[str, List[str]]): softmax = torch.nn.Softmax(dim=1) raw_scores = softmax(logits) - scores = raw_scores.numpy() + scores = raw_scores.double().numpy() num_labels = self.model.num_labels num_texts = 1 # str if isinstance(text, List):