-
Notifications
You must be signed in to change notification settings - Fork 16.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
community[minor]: add hugging face text-to-speech inference API (#18880)
Description: I implemented a tool to use Hugging Face text-to-speech inference API. Issue: n/a Dependencies: n/a Twitter handle: No Twitter, but do have [LinkedIn](https://www.linkedin.com/in/robby-horvath/) lol. --------- Co-authored-by: Robby <[email protected]> Co-authored-by: Eugene Yurtsev <[email protected]>
- Loading branch information
Showing
4 changed files
with
212 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from langchain_community.tools.audio.huggingface_text_to_speech_inference import ( | ||
HuggingFaceTextToSpeechModelInference, | ||
) | ||
|
||
__all__ = [ | ||
"HuggingFaceTextToSpeechModelInference", | ||
] |
118 changes: 118 additions & 0 deletions
118
libs/community/langchain_community/tools/audio/huggingface_text_to_speech_inference.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import logging | ||
import os | ||
import uuid | ||
from datetime import datetime | ||
from typing import Callable, Literal, Optional | ||
|
||
import requests | ||
from langchain_core.callbacks import CallbackManagerForToolRun | ||
from langchain_core.pydantic_v1 import SecretStr | ||
from langchain_core.tools import BaseTool | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class HuggingFaceTextToSpeechModelInference(BaseTool): | ||
"""HuggingFace Text-to-Speech Model Inference. | ||
Requirements: | ||
- Environment variable ``HUGGINGFACE_API_KEY`` must be set, | ||
or passed as a named parameter to the constructor. | ||
""" | ||
|
||
name: str = "openai_text_to_speech" | ||
"""Name of the tool.""" | ||
description: str = "A wrapper around OpenAI Text-to-Speech API. " | ||
"""Description of the tool.""" | ||
|
||
model: str | ||
"""Model name.""" | ||
file_extension: str | ||
"""File extension of the output audio file.""" | ||
destination_dir: str | ||
"""Directory to save the output audio file.""" | ||
file_namer: Callable[[], str] | ||
"""Function to generate unique file names.""" | ||
|
||
api_url: str | ||
huggingface_api_key: SecretStr | ||
|
||
_HUGGINGFACE_API_KEY_ENV_NAME = "HUGGINGFACE_API_KEY" | ||
_HUGGINGFACE_API_URL_ROOT = "https://api-inference.huggingface.co/models" | ||
|
||
def __init__( | ||
self, | ||
model: str, | ||
file_extension: str, | ||
*, | ||
destination_dir: str = "./tts", | ||
file_naming_func: Literal["uuid", "timestamp"] = "uuid", | ||
huggingface_api_key: Optional[SecretStr] = None, | ||
) -> None: | ||
if not huggingface_api_key: | ||
huggingface_api_key = SecretStr( | ||
os.getenv(self._HUGGINGFACE_API_KEY_ENV_NAME, "") | ||
) | ||
|
||
if ( | ||
not huggingface_api_key | ||
or not huggingface_api_key.get_secret_value() | ||
or huggingface_api_key.get_secret_value() == "" | ||
): | ||
raise ValueError( | ||
f"'{self._HUGGINGFACE_API_KEY_ENV_NAME}' must be or set or passed" | ||
) | ||
|
||
if file_naming_func == "uuid": | ||
file_namer = lambda: str(uuid.uuid4()) # noqa: E731 | ||
elif file_naming_func == "timestamp": | ||
file_namer = lambda: str(int(datetime.now().timestamp())) # noqa: E731 | ||
else: | ||
raise ValueError( | ||
f"Invalid value for 'file_naming_func': {file_naming_func}" | ||
) | ||
|
||
super().__init__( | ||
model=model, | ||
file_extension=file_extension, | ||
api_url=f"{self._HUGGINGFACE_API_URL_ROOT}/{model}", | ||
destination_dir=destination_dir, | ||
file_namer=file_namer, | ||
huggingface_api_key=huggingface_api_key, | ||
) | ||
|
||
def _run( | ||
self, | ||
query: str, | ||
run_manager: Optional[CallbackManagerForToolRun] = None, | ||
) -> str: | ||
response = requests.post( | ||
self.api_url, | ||
headers={ | ||
"Authorization": f"Bearer {self.huggingface_api_key.get_secret_value()}" | ||
}, | ||
json={"inputs": query}, | ||
) | ||
audio_bytes = response.content | ||
|
||
try: | ||
os.makedirs(self.destination_dir, exist_ok=True) | ||
except Exception as e: | ||
logger.error(f"Error creating directory '{self.destination_dir}': {e}") | ||
raise | ||
|
||
output_file = os.path.join( | ||
self.destination_dir, | ||
f"{str(self.file_namer())}.{self.file_extension}", | ||
) | ||
|
||
try: | ||
with open(output_file, mode="xb") as f: | ||
f.write(audio_bytes) | ||
except FileExistsError: | ||
raise ValueError("Output name must be unique") | ||
except Exception as e: | ||
logger.error(f"Error occurred while creating file: {e}") | ||
raise | ||
|
||
return output_file |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
"""Test Audio Tools.""" | ||
|
||
import os | ||
import tempfile | ||
import uuid | ||
from unittest.mock import Mock, mock_open, patch | ||
|
||
import pytest | ||
from langchain_core.pydantic_v1 import SecretStr | ||
|
||
from langchain_community.tools.audio import HuggingFaceTextToSpeechModelInference | ||
|
||
AUDIO_FORMAT_EXT = "wav" | ||
|
||
|
||
def test_huggingface_tts_constructor() -> None: | ||
with pytest.raises(ValueError): | ||
os.environ.pop("HUGGINGFACE_API_KEY", None) | ||
HuggingFaceTextToSpeechModelInference( | ||
model="test/model", | ||
file_extension=AUDIO_FORMAT_EXT, | ||
) | ||
|
||
with pytest.raises(ValueError): | ||
HuggingFaceTextToSpeechModelInference( | ||
model="test/model", | ||
file_extension=AUDIO_FORMAT_EXT, | ||
huggingface_api_key=SecretStr(""), | ||
) | ||
|
||
HuggingFaceTextToSpeechModelInference( | ||
model="test/model", | ||
file_extension=AUDIO_FORMAT_EXT, | ||
huggingface_api_key=SecretStr("foo"), | ||
) | ||
|
||
os.environ["HUGGINGFACE_API_KEY"] = "foo" | ||
HuggingFaceTextToSpeechModelInference( | ||
model="test/model", | ||
file_extension=AUDIO_FORMAT_EXT, | ||
) | ||
|
||
|
||
def test_huggingface_tts_run_with_requests_mock() -> None: | ||
os.environ["HUGGINGFACE_API_KEY"] = "foo" | ||
|
||
with tempfile.TemporaryDirectory() as tmp_dir, patch( | ||
"uuid.uuid4" | ||
) as mock_uuid, patch("requests.post") as mock_inference, patch( | ||
"builtins.open", mock_open() | ||
) as mock_file: | ||
input_query = "Dummy input" | ||
|
||
mock_uuid_value = uuid.UUID("00000000-0000-0000-0000-000000000000") | ||
mock_uuid.return_value = mock_uuid_value | ||
|
||
expected_output_file_base_name = os.path.join(tmp_dir, str(mock_uuid_value)) | ||
expected_output_file = f"{expected_output_file_base_name}.{AUDIO_FORMAT_EXT}" | ||
|
||
test_audio_content = b"test_audio_bytes" | ||
|
||
tts = HuggingFaceTextToSpeechModelInference( | ||
model="test/model", | ||
file_extension=AUDIO_FORMAT_EXT, | ||
destination_dir=tmp_dir, | ||
file_naming_func="uuid", | ||
) | ||
|
||
# Mock the requests.post response | ||
mock_response = Mock() | ||
mock_response.content = test_audio_content | ||
mock_inference.return_value = mock_response | ||
|
||
output_path = tts._run(input_query) | ||
|
||
assert output_path == expected_output_file | ||
|
||
mock_inference.assert_called_once_with( | ||
tts.api_url, | ||
headers={ | ||
"Authorization": f"Bearer {tts.huggingface_api_key.get_secret_value()}" | ||
}, | ||
json={"inputs": input_query}, | ||
) | ||
|
||
mock_file.assert_called_once_with(expected_output_file, mode="xb") | ||
mock_file.return_value.write.assert_called_once_with(test_audio_content) |