From a4bc2e5f91d93269944c9bfdd08beb5ee6a8cc32 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Tue, 15 Oct 2024 15:33:32 +0200 Subject: [PATCH] [Inference Client] Factorize inference payload build (#2601) * Factorize inference payload build and add test * Add comments * Add method description * fix style * fix style again * fix prepare payload helper * experiment: try old version of workflow * revert experiment: try old version of workflow * Add docstring * update docstring * simplify json payload construction when inputs is a dict * ignore mypy str bytes warning * fix encoding condition * remove unnecessary checks for parameters --- src/huggingface_hub/inference/_client.py | 196 ++++++------------ src/huggingface_hub/inference/_common.py | 46 +++- .../inference/_generated/_async_client.py | 196 ++++++------------ tests/test_inference_client.py | 145 ++++++++++++- 4 files changed, 306 insertions(+), 277 deletions(-) diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index 38b37b71e3..8d5a8e6a38 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -57,6 +57,7 @@ _get_unsupported_text_generation_kwargs, _import_numpy, _open_as_binary, + _prepare_payload, _set_unsupported_text_generation_kwargs, _stream_chat_completion_response, _stream_text_generation_response, @@ -364,18 +365,8 @@ def audio_classification( ``` """ parameters = {"function_to_apply": function_to_apply, "top_k": top_k} - if all(parameter is None for parameter in parameters.values()): - # if no parameters are provided, send audio as raw data - data = audio - payload: Optional[Dict[str, Any]] = None - else: - # Or some parameters are provided -> send audio as base64 encoded string - data = None - payload = {"inputs": _b64_encode(audio)} - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = self.post(json=payload, data=data, model=model, task="audio-classification") + payload = _prepare_payload(audio, parameters=parameters, expect_binary=True) + response = self.post(**payload, model=model, task="audio-classification") return AudioClassificationOutputElement.parse_obj_as_list(response) def audio_to_audio( @@ -988,7 +979,7 @@ def document_question_answering( [DocumentQuestionAnsweringOutputElement(answer='us-001', end=16, score=0.9999666213989258, start=16, words=None)] ``` """ - payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} + inputs: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} parameters = { "doc_stride": doc_stride, "handle_impossible_answer": handle_impossible_answer, @@ -999,10 +990,8 @@ def document_question_answering( "top_k": top_k, "word_boxes": word_boxes, } - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = self.post(json=payload, model=model, task="document-question-answering") + payload = _prepare_payload(inputs, parameters=parameters) + response = self.post(**payload, model=model, task="document-question-answering") return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response) def feature_extraction( @@ -1060,17 +1049,14 @@ def feature_extraction( [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32) ``` """ - payload: Dict = {"inputs": text} parameters = { "normalize": normalize, "prompt_name": prompt_name, "truncate": truncate, "truncation_direction": truncation_direction, } - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = self.post(json=payload, model=model, task="feature-extraction") + payload = _prepare_payload(text, parameters=parameters) + response = self.post(**payload, model=model, task="feature-extraction") np = _import_numpy() return np.array(_bytes_to_dict(response), dtype="float32") @@ -1119,12 +1105,9 @@ def fill_mask( ] ``` """ - payload: Dict = {"inputs": text} parameters = {"targets": targets, "top_k": top_k} - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = self.post(json=payload, model=model, task="fill-mask") + payload = _prepare_payload(text, parameters=parameters) + response = self.post(**payload, model=model, task="fill-mask") return FillMaskOutputElement.parse_obj_as_list(response) def image_classification( @@ -1166,19 +1149,8 @@ def image_classification( ``` """ parameters = {"function_to_apply": function_to_apply, "top_k": top_k} - - if all(parameter is None for parameter in parameters.values()): - data = image - payload: Optional[Dict[str, Any]] = None - - else: - data = None - payload = {"inputs": _b64_encode(image)} - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - - response = self.post(json=payload, data=data, model=model, task="image-classification") + payload = _prepare_payload(image, parameters=parameters, expect_binary=True) + response = self.post(**payload, model=model, task="image-classification") return ImageClassificationOutputElement.parse_obj_as_list(response) def image_segmentation( @@ -1237,18 +1209,8 @@ def image_segmentation( "subtask": subtask, "threshold": threshold, } - if all(parameter is None for parameter in parameters.values()): - # if no parameters are provided, the image can be raw bytes, an image file, or URL to an online image - data = image - payload: Optional[Dict[str, Any]] = None - else: - # if parameters are provided, the image needs to be a base64-encoded string - data = None - payload = {"inputs": _b64_encode(image)} - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = self.post(json=payload, data=data, model=model, task="image-segmentation") + payload = _prepare_payload(image, parameters=parameters, expect_binary=True) + response = self.post(**payload, model=model, task="image-segmentation") output = ImageSegmentationOutputElement.parse_obj_as_list(response) for item in output: item.mask = _b64_to_image(item.mask) # type: ignore [assignment] @@ -1323,19 +1285,8 @@ def image_to_image( "guidance_scale": guidance_scale, **kwargs, } - if all(parameter is None for parameter in parameters.values()): - # Either only an image to send => send as raw bytes - data = image - payload: Optional[Dict[str, Any]] = None - else: - # if parameters are provided, the image needs to be a base64-encoded string - data = None - payload = {"inputs": _b64_encode(image)} - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - - response = self.post(json=payload, data=data, model=model, task="image-to-image") + payload = _prepare_payload(image, parameters=parameters, expect_binary=True) + response = self.post(**payload, model=model, task="image-to-image") return _bytes_to_image(response) def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput: @@ -1493,25 +1444,15 @@ def object_detection( ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() - >>> client.object_detection("people.jpg"): + >>> client.object_detection("people.jpg") [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...] ``` """ parameters = { "threshold": threshold, } - if all(parameter is None for parameter in parameters.values()): - # if no parameters are provided, the image can be raw bytes, an image file, or URL to an online image - data = image - payload: Optional[Dict[str, Any]] = None - else: - # if parameters are provided, the image needs to be a base64-encoded string - data = None - payload = {"inputs": _b64_encode(image)} - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = self.post(json=payload, data=data, model=model, task="object-detection") + payload = _prepare_payload(image, parameters=parameters, expect_binary=True) + response = self.post(**payload, model=model, task="object-detection") return ObjectDetectionOutputElement.parse_obj_as_list(response) def question_answering( @@ -1587,12 +1528,10 @@ def question_answering( "max_seq_len": max_seq_len, "top_k": top_k, } - payload: Dict[str, Any] = {"question": question, "context": context} - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value + inputs: Dict[str, Any] = {"question": question, "context": context} + payload = _prepare_payload(inputs, parameters=parameters) response = self.post( - json=payload, + **payload, model=model, task="question-answering", ) @@ -1700,19 +1639,14 @@ def summarization( SummarizationOutput(generated_text="The Eiffel tower is one of the most famous landmarks in the world....") ``` """ - payload: Dict[str, Any] = {"inputs": text} - if parameters is not None: - payload["parameters"] = parameters - else: + if parameters is None: parameters = { "clean_up_tokenization_spaces": clean_up_tokenization_spaces, "generate_parameters": generate_parameters, "truncation": truncation, } - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = self.post(json=payload, model=model, task="summarization") + payload = _prepare_payload(text, parameters=parameters) + response = self.post(**payload, model=model, task="summarization") return SummarizationOutput.parse_obj_as_list(response)[0] def table_question_answering( @@ -1757,15 +1691,13 @@ def table_question_answering( TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE') ``` """ - payload: Dict[str, Any] = { + inputs = { "query": query, "table": table, } - - if parameters is not None: - payload["parameters"] = parameters + payload = _prepare_payload(inputs, parameters=parameters) response = self.post( - json=payload, + **payload, model=model, task="table-question-answering", ) @@ -1813,7 +1745,11 @@ def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] ["5", "5", "5"] ``` """ - response = self.post(json={"table": table}, model=model, task="tabular-classification") + response = self.post( + json={"table": table}, + model=model, + task="tabular-classification", + ) return _bytes_to_list(response) def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]: @@ -1899,15 +1835,16 @@ def text_classification( ] ``` """ - payload: Dict[str, Any] = {"inputs": text} parameters = { "function_to_apply": function_to_apply, "top_k": top_k, } - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = self.post(json=payload, model=model, task="text-classification") + payload = _prepare_payload(text, parameters=parameters) + response = self.post( + **payload, + model=model, + task="text-classification", + ) return TextClassificationOutputElement.parse_obj_as_list(response)[0] # type: ignore [return-value] @overload @@ -2481,7 +2418,7 @@ def text_to_image( >>> image.save("better_astronaut.png") ``` """ - payload = {"inputs": prompt} + parameters = { "negative_prompt": negative_prompt, "height": height, @@ -2493,10 +2430,8 @@ def text_to_image( "seed": seed, **kwargs, } - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value # type: ignore - response = self.post(json=payload, model=model, task="text-to-image") + payload = _prepare_payload(prompt, parameters=parameters) + response = self.post(**payload, model=model, task="text-to-image") return _bytes_to_image(response) def text_to_speech( @@ -2599,7 +2534,6 @@ def text_to_speech( >>> Path("hello_world.flac").write_bytes(audio) ``` """ - payload: Dict[str, Any] = {"inputs": text} parameters = { "do_sample": do_sample, "early_stopping": early_stopping, @@ -2618,10 +2552,8 @@ def text_to_speech( "typical_p": typical_p, "use_cache": use_cache, } - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = self.post(json=payload, model=model, task="text-to-speech") + payload = _prepare_payload(text, parameters=parameters) + response = self.post(**payload, model=model, task="text-to-speech") return response def token_classification( @@ -2683,17 +2615,15 @@ def token_classification( ] ``` """ - payload: Dict[str, Any] = {"inputs": text} + parameters = { "aggregation_strategy": aggregation_strategy, "ignore_labels": ignore_labels, "stride": stride, } - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value + payload = _prepare_payload(text, parameters=parameters) response = self.post( - json=payload, + **payload, model=model, task="token-classification", ) @@ -2769,7 +2699,6 @@ def translation( if src_lang is None and tgt_lang is not None: raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.") - payload: Dict[str, Any] = {"inputs": text} parameters = { "src_lang": src_lang, "tgt_lang": tgt_lang, @@ -2777,10 +2706,8 @@ def translation( "truncation": truncation, "generate_parameters": generate_parameters, } - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = self.post(json=payload, model=model, task="translation") + payload = _prepare_payload(text, parameters=parameters) + response = self.post(**payload, model=model, task="translation") return TranslationOutput.parse_obj_as_list(response)[0] def visual_question_answering( @@ -2921,15 +2848,14 @@ def zero_shot_classification( ``` """ - parameters = {"candidate_labels": labels, "multi_label": multi_label} - if hypothesis_template is not None: - parameters["hypothesis_template"] = hypothesis_template - + parameters = { + "candidate_labels": labels, + "multi_label": multi_label, + "hypothesis_template": hypothesis_template, + } + payload = _prepare_payload(text, parameters=parameters) response = self.post( - json={ - "inputs": text, - "parameters": parameters, - }, + **payload, task="zero-shot-classification", model=model, ) @@ -2986,13 +2912,11 @@ def zero_shot_image_classification( if len(labels) < 2: raise ValueError("You must specify at least 2 classes to compare.") - payload = { - "inputs": {"image": _b64_encode(image), "candidateLabels": ",".join(labels)}, - } - if hypothesis_template is not None: - payload.setdefault("parameters", {})["hypothesis_template"] = hypothesis_template + inputs = {"image": _b64_encode(image), "candidateLabels": ",".join(labels)} + parameters = {"hypothesis_template": hypothesis_template} + payload = _prepare_payload(inputs, parameters=parameters) response = self.post( - json=payload, + **payload, model=model, task="zero-shot-image-classification", ) diff --git a/src/huggingface_hub/inference/_common.py b/src/huggingface_hub/inference/_common.py index a92d8fad4a..a19636a506 100644 --- a/src/huggingface_hub/inference/_common.py +++ b/src/huggingface_hub/inference/_common.py @@ -58,10 +58,7 @@ is_numpy_available, is_pillow_available, ) -from ._generated.types import ( - ChatCompletionStreamOutput, - TextGenerationStreamOutput, -) +from ._generated.types import ChatCompletionStreamOutput, TextGenerationStreamOutput if TYPE_CHECKING: @@ -259,6 +256,47 @@ def _bytes_to_image(content: bytes) -> "Image": return Image.open(io.BytesIO(content)) +## PAYLOAD UTILS + + +def _prepare_payload( + inputs: Union[str, Dict[str, Any], ContentT], + parameters: Optional[Dict[str, Any]], + expect_binary: bool = False, +) -> Dict[str, Any]: + """ + Used in `InferenceClient` and `AsyncInferenceClient` to prepare the payload for an API request, handling various input types and parameters. + `expect_binary` is set to `True` when the inputs are a binary object or a local path or URL. This is the case for image and audio inputs. + """ + if parameters is None: + parameters = {} + parameters = {k: v for k, v in parameters.items() if v is not None} + has_parameters = len(parameters) > 0 + + is_binary = isinstance(inputs, (bytes, Path)) + # If expect_binary is True, inputs must be a binary object or a local path or a URL. + if expect_binary and not is_binary and not isinstance(inputs, str): + raise ValueError(f"Expected binary inputs or a local path or a URL. Got {inputs}") # type: ignore + # Send inputs as raw content when no parameters are provided + if expect_binary and not has_parameters: + return {"data": inputs} + # If expect_binary is False, inputs must not be a binary object. + if not expect_binary and is_binary: + raise ValueError(f"Unexpected binary inputs. Got {inputs}") # type: ignore + + json: Dict[str, Any] = {} + # If inputs is a bytes-like object, encode it to base64 + if expect_binary: + json["inputs"] = _b64_encode(inputs) # type: ignore + # Otherwise (string, dict, list) send it as is + else: + json["inputs"] = inputs + # Add parameters to the json payload if any + if has_parameters: + json["parameters"] = parameters + return {"json": json} + + ## STREAMING UTILS diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index 8a1384a671..5c3a8044fc 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -45,6 +45,7 @@ _get_unsupported_text_generation_kwargs, _import_numpy, _open_as_binary, + _prepare_payload, _set_unsupported_text_generation_kwargs, raise_text_generation_error, ) @@ -398,18 +399,8 @@ async def audio_classification( ``` """ parameters = {"function_to_apply": function_to_apply, "top_k": top_k} - if all(parameter is None for parameter in parameters.values()): - # if no parameters are provided, send audio as raw data - data = audio - payload: Optional[Dict[str, Any]] = None - else: - # Or some parameters are provided -> send audio as base64 encoded string - data = None - payload = {"inputs": _b64_encode(audio)} - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = await self.post(json=payload, data=data, model=model, task="audio-classification") + payload = _prepare_payload(audio, parameters=parameters, expect_binary=True) + response = await self.post(**payload, model=model, task="audio-classification") return AudioClassificationOutputElement.parse_obj_as_list(response) async def audio_to_audio( @@ -1031,7 +1022,7 @@ async def document_question_answering( [DocumentQuestionAnsweringOutputElement(answer='us-001', end=16, score=0.9999666213989258, start=16, words=None)] ``` """ - payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} + inputs: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} parameters = { "doc_stride": doc_stride, "handle_impossible_answer": handle_impossible_answer, @@ -1042,10 +1033,8 @@ async def document_question_answering( "top_k": top_k, "word_boxes": word_boxes, } - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = await self.post(json=payload, model=model, task="document-question-answering") + payload = _prepare_payload(inputs, parameters=parameters) + response = await self.post(**payload, model=model, task="document-question-answering") return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response) async def feature_extraction( @@ -1104,17 +1093,14 @@ async def feature_extraction( [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32) ``` """ - payload: Dict = {"inputs": text} parameters = { "normalize": normalize, "prompt_name": prompt_name, "truncate": truncate, "truncation_direction": truncation_direction, } - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = await self.post(json=payload, model=model, task="feature-extraction") + payload = _prepare_payload(text, parameters=parameters) + response = await self.post(**payload, model=model, task="feature-extraction") np = _import_numpy() return np.array(_bytes_to_dict(response), dtype="float32") @@ -1164,12 +1150,9 @@ async def fill_mask( ] ``` """ - payload: Dict = {"inputs": text} parameters = {"targets": targets, "top_k": top_k} - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = await self.post(json=payload, model=model, task="fill-mask") + payload = _prepare_payload(text, parameters=parameters) + response = await self.post(**payload, model=model, task="fill-mask") return FillMaskOutputElement.parse_obj_as_list(response) async def image_classification( @@ -1212,19 +1195,8 @@ async def image_classification( ``` """ parameters = {"function_to_apply": function_to_apply, "top_k": top_k} - - if all(parameter is None for parameter in parameters.values()): - data = image - payload: Optional[Dict[str, Any]] = None - - else: - data = None - payload = {"inputs": _b64_encode(image)} - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - - response = await self.post(json=payload, data=data, model=model, task="image-classification") + payload = _prepare_payload(image, parameters=parameters, expect_binary=True) + response = await self.post(**payload, model=model, task="image-classification") return ImageClassificationOutputElement.parse_obj_as_list(response) async def image_segmentation( @@ -1284,18 +1256,8 @@ async def image_segmentation( "subtask": subtask, "threshold": threshold, } - if all(parameter is None for parameter in parameters.values()): - # if no parameters are provided, the image can be raw bytes, an image file, or URL to an online image - data = image - payload: Optional[Dict[str, Any]] = None - else: - # if parameters are provided, the image needs to be a base64-encoded string - data = None - payload = {"inputs": _b64_encode(image)} - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = await self.post(json=payload, data=data, model=model, task="image-segmentation") + payload = _prepare_payload(image, parameters=parameters, expect_binary=True) + response = await self.post(**payload, model=model, task="image-segmentation") output = ImageSegmentationOutputElement.parse_obj_as_list(response) for item in output: item.mask = _b64_to_image(item.mask) # type: ignore [assignment] @@ -1371,19 +1333,8 @@ async def image_to_image( "guidance_scale": guidance_scale, **kwargs, } - if all(parameter is None for parameter in parameters.values()): - # Either only an image to send => send as raw bytes - data = image - payload: Optional[Dict[str, Any]] = None - else: - # if parameters are provided, the image needs to be a base64-encoded string - data = None - payload = {"inputs": _b64_encode(image)} - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - - response = await self.post(json=payload, data=data, model=model, task="image-to-image") + payload = _prepare_payload(image, parameters=parameters, expect_binary=True) + response = await self.post(**payload, model=model, task="image-to-image") return _bytes_to_image(response) async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput: @@ -1549,25 +1500,15 @@ async def object_detection( # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() - >>> await client.object_detection("people.jpg"): + >>> await client.object_detection("people.jpg") [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...] ``` """ parameters = { "threshold": threshold, } - if all(parameter is None for parameter in parameters.values()): - # if no parameters are provided, the image can be raw bytes, an image file, or URL to an online image - data = image - payload: Optional[Dict[str, Any]] = None - else: - # if parameters are provided, the image needs to be a base64-encoded string - data = None - payload = {"inputs": _b64_encode(image)} - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = await self.post(json=payload, data=data, model=model, task="object-detection") + payload = _prepare_payload(image, parameters=parameters, expect_binary=True) + response = await self.post(**payload, model=model, task="object-detection") return ObjectDetectionOutputElement.parse_obj_as_list(response) async def question_answering( @@ -1644,12 +1585,10 @@ async def question_answering( "max_seq_len": max_seq_len, "top_k": top_k, } - payload: Dict[str, Any] = {"question": question, "context": context} - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value + inputs: Dict[str, Any] = {"question": question, "context": context} + payload = _prepare_payload(inputs, parameters=parameters) response = await self.post( - json=payload, + **payload, model=model, task="question-answering", ) @@ -1759,19 +1698,14 @@ async def summarization( SummarizationOutput(generated_text="The Eiffel tower is one of the most famous landmarks in the world....") ``` """ - payload: Dict[str, Any] = {"inputs": text} - if parameters is not None: - payload["parameters"] = parameters - else: + if parameters is None: parameters = { "clean_up_tokenization_spaces": clean_up_tokenization_spaces, "generate_parameters": generate_parameters, "truncation": truncation, } - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = await self.post(json=payload, model=model, task="summarization") + payload = _prepare_payload(text, parameters=parameters) + response = await self.post(**payload, model=model, task="summarization") return SummarizationOutput.parse_obj_as_list(response)[0] async def table_question_answering( @@ -1817,15 +1751,13 @@ async def table_question_answering( TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE') ``` """ - payload: Dict[str, Any] = { + inputs = { "query": query, "table": table, } - - if parameters is not None: - payload["parameters"] = parameters + payload = _prepare_payload(inputs, parameters=parameters) response = await self.post( - json=payload, + **payload, model=model, task="table-question-answering", ) @@ -1874,7 +1806,11 @@ async def tabular_classification(self, table: Dict[str, Any], *, model: Optional ["5", "5", "5"] ``` """ - response = await self.post(json={"table": table}, model=model, task="tabular-classification") + response = await self.post( + json={"table": table}, + model=model, + task="tabular-classification", + ) return _bytes_to_list(response) async def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]: @@ -1962,15 +1898,16 @@ async def text_classification( ] ``` """ - payload: Dict[str, Any] = {"inputs": text} parameters = { "function_to_apply": function_to_apply, "top_k": top_k, } - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = await self.post(json=payload, model=model, task="text-classification") + payload = _prepare_payload(text, parameters=parameters) + response = await self.post( + **payload, + model=model, + task="text-classification", + ) return TextClassificationOutputElement.parse_obj_as_list(response)[0] # type: ignore [return-value] @overload @@ -2546,7 +2483,7 @@ async def text_to_image( >>> image.save("better_astronaut.png") ``` """ - payload = {"inputs": prompt} + parameters = { "negative_prompt": negative_prompt, "height": height, @@ -2558,10 +2495,8 @@ async def text_to_image( "seed": seed, **kwargs, } - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value # type: ignore - response = await self.post(json=payload, model=model, task="text-to-image") + payload = _prepare_payload(prompt, parameters=parameters) + response = await self.post(**payload, model=model, task="text-to-image") return _bytes_to_image(response) async def text_to_speech( @@ -2665,7 +2600,6 @@ async def text_to_speech( >>> Path("hello_world.flac").write_bytes(audio) ``` """ - payload: Dict[str, Any] = {"inputs": text} parameters = { "do_sample": do_sample, "early_stopping": early_stopping, @@ -2684,10 +2618,8 @@ async def text_to_speech( "typical_p": typical_p, "use_cache": use_cache, } - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = await self.post(json=payload, model=model, task="text-to-speech") + payload = _prepare_payload(text, parameters=parameters) + response = await self.post(**payload, model=model, task="text-to-speech") return response async def token_classification( @@ -2750,17 +2682,15 @@ async def token_classification( ] ``` """ - payload: Dict[str, Any] = {"inputs": text} + parameters = { "aggregation_strategy": aggregation_strategy, "ignore_labels": ignore_labels, "stride": stride, } - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value + payload = _prepare_payload(text, parameters=parameters) response = await self.post( - json=payload, + **payload, model=model, task="token-classification", ) @@ -2837,7 +2767,6 @@ async def translation( if src_lang is None and tgt_lang is not None: raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.") - payload: Dict[str, Any] = {"inputs": text} parameters = { "src_lang": src_lang, "tgt_lang": tgt_lang, @@ -2845,10 +2774,8 @@ async def translation( "truncation": truncation, "generate_parameters": generate_parameters, } - for key, value in parameters.items(): - if value is not None: - payload.setdefault("parameters", {})[key] = value - response = await self.post(json=payload, model=model, task="translation") + payload = _prepare_payload(text, parameters=parameters) + response = await self.post(**payload, model=model, task="translation") return TranslationOutput.parse_obj_as_list(response)[0] async def visual_question_answering( @@ -2992,15 +2919,14 @@ async def zero_shot_classification( ``` """ - parameters = {"candidate_labels": labels, "multi_label": multi_label} - if hypothesis_template is not None: - parameters["hypothesis_template"] = hypothesis_template - + parameters = { + "candidate_labels": labels, + "multi_label": multi_label, + "hypothesis_template": hypothesis_template, + } + payload = _prepare_payload(text, parameters=parameters) response = await self.post( - json={ - "inputs": text, - "parameters": parameters, - }, + **payload, task="zero-shot-classification", model=model, ) @@ -3058,13 +2984,11 @@ async def zero_shot_image_classification( if len(labels) < 2: raise ValueError("You must specify at least 2 classes to compare.") - payload = { - "inputs": {"image": _b64_encode(image), "candidateLabels": ",".join(labels)}, - } - if hypothesis_template is not None: - payload.setdefault("parameters", {})["hypothesis_template"] = hypothesis_template + inputs = {"image": _b64_encode(image), "candidateLabels": ",".join(labels)} + parameters = {"hypothesis_template": hypothesis_template} + payload = _prepare_payload(inputs, parameters=parameters) response = await self.post( - json=payload, + **payload, model=model, task="zero-shot-image-classification", ) diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py index a4fa971a2b..b97c62d165 100644 --- a/tests/test_inference_client.py +++ b/tests/test_inference_client.py @@ -49,7 +49,11 @@ from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, MAIN_INFERENCE_API_FRAMEWORKS from huggingface_hub.errors import HfHubHTTPError, ValidationError from huggingface_hub.inference._client import _open_as_binary -from huggingface_hub.inference._common import _stream_chat_completion_response, _stream_text_generation_response +from huggingface_hub.inference._common import ( + _prepare_payload, + _stream_chat_completion_response, + _stream_text_generation_response, +) from huggingface_hub.utils import build_hf_headers from .testing_utils import with_production_testing @@ -1080,3 +1084,142 @@ def test_resolve_chat_completion_url( client = InferenceClient(model=client_model, base_url=client_base_url) url = client._resolve_chat_completion_url(model) assert url == expected_url + + +@pytest.mark.parametrize( + "inputs, parameters, expect_binary, expected_json, expected_data", + [ + # Case 1: inputs is a simple string without parameters + ( + "simple text", + None, + False, + {"inputs": "simple text"}, + None, + ), + # Case 2: inputs is a simple string with parameters + ( + "simple text", + {"param1": "value1"}, + False, + { + "inputs": "simple text", + "parameters": {"param1": "value1"}, + }, + None, + ), + # Case 3: inputs is a dict without parameters + ( + {"input_key": "input_value"}, + None, + False, + {"inputs": {"input_key": "input_value"}}, + None, + ), + # Case 4: inputs is a dict with parameters + ( + {"input_key": "input_value", "input_key2": "input_value2"}, + {"param1": "value1"}, + False, + { + "inputs": {"input_key": "input_value", "input_key2": "input_value2"}, + "parameters": {"param1": "value1"}, + }, + None, + ), + # Case 5: inputs is bytes without parameters + ( + b"binary data", + None, + True, + None, + b"binary data", + ), + # Case 6: inputs is bytes with parameters + ( + b"binary data", + {"param1": "value1"}, + True, + { + "inputs": "encoded_data", + "parameters": {"param1": "value1"}, + }, + None, + ), + # Case 7: inputs is a Path object without parameters + ( + Path("test_file.txt"), + None, + True, + None, + Path("test_file.txt"), + ), + # Case 8: inputs is a Path object with parameters + ( + Path("test_file.txt"), + {"param1": "value1"}, + True, + { + "inputs": "encoded_data", + "parameters": {"param1": "value1"}, + }, + None, + ), + # Case 9: inputs is a URL string without parameters + ( + "http://example.com", + None, + True, + None, + "http://example.com", + ), + # Case 10: inputs is a URL string without parameters but expect_binary is False + ( + "http://example.com", + None, + False, + { + "inputs": "http://example.com", + }, + None, + ), + # Case 11: inputs is a URL string with parameters + ( + "http://example.com", + {"param1": "value1"}, + True, + { + "inputs": "encoded_data", + "parameters": {"param1": "value1"}, + }, + None, + ), + # Case 12: inputs is a URL string with parameters but expect_binary is False + ( + "http://example.com", + {"param1": "value1"}, + False, + { + "inputs": "http://example.com", + "parameters": {"param1": "value1"}, + }, + None, + ), + # Case 13: parameters contain None values + ( + "simple text", + {"param1": None, "param2": "value2"}, + False, + { + "inputs": "simple text", + "parameters": {"param2": "value2"}, + }, + None, + ), + ], +) +def test_prepare_payload(inputs, parameters, expect_binary, expected_json, expected_data): + with patch("huggingface_hub.inference._common._b64_encode", return_value="encoded_data"): + payload = _prepare_payload(inputs, parameters, expect_binary=expect_binary) + assert payload.get("json") == expected_json + assert payload.get("data") == expected_data