Skip to content

Commit

Permalink
[tokenizers] Support import zero-shot-classification to model zoo
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Dec 31, 2024
1 parent 94fbc93 commit ea8c0b2
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import onnx
from torch import nn
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.modeling_outputs import TokenClassifierOutput, Seq2SeqSequenceClassifierOutput

from djl_converter.safetensors_convert import convert_file
import torch
Expand Down Expand Up @@ -58,8 +58,10 @@ def forward(self,
output = self.model(input_ids, attention_mask)
else:
output = self.model(input_ids, attention_mask, token_type_ids)
if isinstance(output, TokenClassifierOutput):
# TokenClassifierOutput may contains mix of Tensor and Tuple(Tensor)
if isinstance(output, TokenClassifierOutput) or isinstance(
output, Seq2SeqSequenceClassifierOutput):
# TokenClassifierOutput/Seq2SeqSequenceClassifierOutput
# may contain a mix of Tensor and Tuple(Tensor)
return {"logits": output["logits"]}

return output
Expand Down Expand Up @@ -303,14 +305,13 @@ def jit_trace_model(self, hf_pipeline, model_id: str, temp_dir: str,

# noinspection PyBroadException
try:
wrapper = ModelWrapper(hf_pipeline.model, include_types)
if include_types:
script_module = torch.jit.trace(
ModelWrapper(hf_pipeline.model, include_types),
(input_ids, attention_mask, token_type_ids),
wrapper, (input_ids, attention_mask, token_type_ids),
strict=False)
else:
script_module = torch.jit.trace(ModelWrapper(
hf_pipeline.model, include_types),
script_module = torch.jit.trace(wrapper,
(input_ids, attention_mask),
strict=False)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from djl_converter.sentence_similarity_converter import SentenceSimilarityConverter
from djl_converter.text_classification_converter import TextClassificationConverter
from djl_converter.token_classification_converter import TokenClassificationConverter
from djl_converter.zero_shot_classification_converter import ZeroShotClassificationConverter

ARCHITECTURES_2_TASK = {
"ForQuestionAnswering": "question-answering",
Expand All @@ -42,6 +43,7 @@
"sentence-similarity": SentenceSimilarityConverter(),
"text-classification": TextClassificationConverter(),
"token-classification": TokenClassificationConverter(),
"zero-shot-classification": ZeroShotClassificationConverter(),
}


Expand Down Expand Up @@ -127,6 +129,9 @@ def list_models(self, args: Namespace) -> List[dict]:
if not task:
if "sentence-similarity" in model_info.tags:
task = "sentence-similarity"
else:
if "zero-shot-classification" in model_info.tags:
task = "zero-shot-classification"

if not task:
logging.info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ def main():
if not task:
if "sentence-similarity" in model_info.tags:
task = "sentence-similarity"
else:
if "zero-shot-classification" in model_info.tags:
task = "zero-shot-classification"

if not task:
logging.error(
f"Unsupported model architecture: {arch} for {args.model_id}.")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#!/usr/bin/env python
#
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import math

from djl_converter.huggingface_converter import HuggingfaceConverter


class ZeroShotClassificationConverter(HuggingfaceConverter):
    """Converter settings and checks for the zero-shot-classification task.

    Holds the task/application/translator identifiers for this task, a
    sample sentence with one candidate label used to build the model input,
    and the check that compares the converted model's output with the
    Hugging Face pipeline result.
    """

    # Absolute tolerance used when comparing the traced model's score with
    # the score reported by the Hugging Face pipeline.
    _SCORE_ABS_TOL = 1e-3

    def __init__(self):
        super().__init__()
        self.task = "zero-shot-classification"
        self.application = "nlp/zero_shot_classification"
        self.translator = "ai.djl.huggingface.translator.ZeroShotClassificationTranslatorFactory"
        # Sample text and candidate labels; used both to encode the model
        # input and to obtain the reference score from the pipeline.
        self.inputs = "one day I will see the world"
        self.labels = ['travel']

    def encode_inputs(self, tokenizer):
        """Encode the sample text paired with a sentence built from the first label.

        Args:
            tokenizer: callable tokenizer; invoked with the sample text, the
                sentence ``"This example is <label>."`` and
                ``return_tensors='pt'``.

        Returns:
            Whatever the tokenizer returns for that sentence pair.
        """
        return tokenizer(self.inputs,
                         f"This example is {self.labels[0]}.",
                         return_tensors='pt')

    def verify_jit_output(self, hf_pipeline, encoding, out):
        """Compare the converted model's score with the pipeline's score.

        Args:
            hf_pipeline: callable pipeline; invoked with the sample text and
                labels to obtain the reference result.
            encoding: the encoded inputs (not used by this check).
            out: mapping holding the model output under the ``'logits'`` key.

        Returns:
            ``(True, None)`` when both scores agree within the absolute
            tolerance, otherwise ``(False, message)`` where the message
            reports the expected and the actual score.
        """
        logits = out['logits']
        # NOTE(review): columns 0 and 2 are presumably the contradiction and
        # entailment logits (going by the variable name) - confirm that the
        # model's label order matches before relying on this for new models.
        entail_contradiction_logits = logits[:, [0, 2]]
        probs = entail_contradiction_logits.softmax(dim=1)
        # .item() requires a single-element tensor, i.e. one encoded pair.
        score = probs[:, 1].item()

        pipeline_output = hf_pipeline(self.inputs, self.labels)
        expected = pipeline_output["scores"][0]

        if math.isclose(expected, score, abs_tol=self._SCORE_ABS_TOL):
            return True, None

        # Report both values so a failed verification can be diagnosed.
        return False, (f"Unexpected inference result: expected score "
                       f"{expected}, actual {score} "
                       f"(abs_tol={self._SCORE_ABS_TOL})")

    def get_extra_arguments(self, hf_pipeline, model_id: str,
                            temp_dir: str) -> dict:
        """Return the extra arguments for this task.

        The parameters are not used here; the returned dict always contains
        the same padding and truncation settings.
        """
        return {
            "padding": "true",
            "truncation": "only_first",
        }

0 comments on commit ea8c0b2

Please sign in to comment.