Skip to content

Commit

Permalink
community[minor]: add hugging face text-to-speech inference API (#18880)
Browse files Browse the repository at this point in the history
Description: I implemented a tool that uses the Hugging Face text-to-speech
inference API.

Issue: n/a

Dependencies: n/a

Twitter handle: No Twitter, but do have
[LinkedIn](https://www.linkedin.com/in/robby-horvath/) lol.

---------

Co-authored-by: Robby <[email protected]>
Co-authored-by: Eugene Yurtsev <[email protected]>
  • Loading branch information
3 people authored and hinthornw committed Apr 26, 2024
1 parent 11e1c6a commit da010f9
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 0 deletions.
7 changes: 7 additions & 0 deletions libs/community/langchain_community/tools/audio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from langchain_community.tools.audio.huggingface_text_to_speech_inference import (
HuggingFaceTextToSpeechModelInference,
)

# Public API of the audio tools package: only the names listed here are
# exported by ``from langchain_community.tools.audio import *``.
__all__ = ["HuggingFaceTextToSpeechModelInference"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import logging
import os
import uuid
from datetime import datetime
from typing import Callable, Literal, Optional

import requests
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import SecretStr
from langchain_core.tools import BaseTool

logger = logging.getLogger(__name__)


class HuggingFaceTextToSpeechModelInference(BaseTool):
    """Tool that turns text into speech via the HuggingFace Inference API.

    The input text is POSTed to ``<api root>/<model>`` and the raw response
    body is written to a newly created file inside ``destination_dir``. The
    tool returns the path of that file.

    Requirements:
        - Environment variable ``HUGGINGFACE_API_KEY`` must be set,
          or passed as a named parameter to the constructor.
    """

    name: str = "huggingface_text_to_speech"
    """Name of the tool."""
    description: str = (
        "A wrapper around the HuggingFace Text-to-Speech Inference API. "
        "Input is the text to speak; output is the path of the audio file."
    )
    """Description of the tool."""

    model: str
    """Model name."""
    file_extension: str
    """File extension of the output audio file."""
    destination_dir: str
    """Directory to save the output audio file."""
    file_namer: Callable[[], str]
    """Function to generate unique file names."""

    api_url: str
    """Full inference endpoint URL for ``model``."""
    huggingface_api_key: SecretStr
    """API key sent as a bearer token with every request."""

    _HUGGINGFACE_API_KEY_ENV_NAME = "HUGGINGFACE_API_KEY"
    _HUGGINGFACE_API_URL_ROOT = "https://api-inference.huggingface.co/models"

    def __init__(
        self,
        model: str,
        file_extension: str,
        *,
        destination_dir: str = "./tts",
        file_naming_func: Literal["uuid", "timestamp"] = "uuid",
        huggingface_api_key: Optional[SecretStr] = None,
    ) -> None:
        """Create the tool.

        Args:
            model: Model identifier appended to the inference API root URL.
            file_extension: Extension (without the dot) for output files.
            destination_dir: Directory where audio files are written.
            file_naming_func: ``"uuid"`` names each file with a random UUID4;
                ``"timestamp"`` uses the current Unix time in whole seconds.
            huggingface_api_key: API key. When missing or empty, the
                ``HUGGINGFACE_API_KEY`` environment variable is used instead.

        Raises:
            ValueError: If no non-empty API key is available, or if
                ``file_naming_func`` is not one of the supported values.
        """
        # An absent or empty key falls back to the environment variable.
        if not huggingface_api_key:
            huggingface_api_key = SecretStr(
                os.getenv(self._HUGGINGFACE_API_KEY_ENV_NAME, "")
            )

        if not huggingface_api_key.get_secret_value():
            raise ValueError(
                f"'{self._HUGGINGFACE_API_KEY_ENV_NAME}' must be set or passed"
            )

        if file_naming_func == "uuid":
            file_namer = lambda: str(uuid.uuid4())  # noqa: E731
        elif file_naming_func == "timestamp":
            file_namer = lambda: str(int(datetime.now().timestamp()))  # noqa: E731
        else:
            raise ValueError(
                f"Invalid value for 'file_naming_func': {file_naming_func}"
            )

        super().__init__(
            model=model,
            file_extension=file_extension,
            api_url=f"{self._HUGGINGFACE_API_URL_ROOT}/{model}",
            destination_dir=destination_dir,
            file_namer=file_namer,
            huggingface_api_key=huggingface_api_key,
        )

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Synthesize ``query`` and save the audio to a new file.

        Args:
            query: Text to convert to speech.
            run_manager: Optional callback manager; not used by this tool.

        Returns:
            Path of the audio file that was written.

        Raises:
            requests.HTTPError: If the inference API answers with an error
                status code.
            ValueError: If the generated output file name already exists.
        """
        response = requests.post(
            self.api_url,
            headers={
                "Authorization": f"Bearer {self.huggingface_api_key.get_secret_value()}"
            },
            json={"inputs": query},
        )
        # Fail loudly on 4xx/5xx: otherwise the error payload would be saved
        # to disk and returned as if it were valid audio.
        response.raise_for_status()
        audio_bytes = response.content

        try:
            os.makedirs(self.destination_dir, exist_ok=True)
        except Exception as e:
            logger.error("Error creating directory '%s': %s", self.destination_dir, e)
            raise

        output_file = os.path.join(
            self.destination_dir,
            f"{self.file_namer()}.{self.file_extension}",
        )

        try:
            # "x" mode refuses to overwrite an existing file.
            with open(output_file, mode="xb") as f:
                f.write(audio_bytes)
        except FileExistsError as e:
            raise ValueError("Output name must be unique") from e
        except Exception as e:
            logger.error("Error occurred while creating file: %s", e)
            raise

        return output_file
Empty file.
87 changes: 87 additions & 0 deletions libs/community/tests/unit_tests/tools/audio/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Test Audio Tools."""

import os
import tempfile
import uuid
from unittest.mock import Mock, mock_open, patch

import pytest
from langchain_core.pydantic_v1 import SecretStr

from langchain_community.tools.audio import HuggingFaceTextToSpeechModelInference

AUDIO_FORMAT_EXT = "wav"


def test_huggingface_tts_constructor() -> None:
    """Check how the constructor resolves the HuggingFace API key.

    Covers a missing key, an explicitly empty key, an explicit key, and a key
    supplied through the ``HUGGINGFACE_API_KEY`` environment variable.
    """
    # patch.dict snapshots os.environ and restores it on exit, so the
    # environment changes made below cannot leak into other tests.
    with patch.dict(os.environ):
        os.environ.pop("HUGGINGFACE_API_KEY", None)

        # No key passed and none in the environment.
        with pytest.raises(ValueError):
            HuggingFaceTextToSpeechModelInference(
                model="test/model",
                file_extension=AUDIO_FORMAT_EXT,
            )

        # An explicitly empty key is rejected as well.
        with pytest.raises(ValueError):
            HuggingFaceTextToSpeechModelInference(
                model="test/model",
                file_extension=AUDIO_FORMAT_EXT,
                huggingface_api_key=SecretStr(""),
            )

        # Explicit key, environment variable still unset.
        HuggingFaceTextToSpeechModelInference(
            model="test/model",
            file_extension=AUDIO_FORMAT_EXT,
            huggingface_api_key=SecretStr("foo"),
        )

        # Key picked up from the environment.
        os.environ["HUGGINGFACE_API_KEY"] = "foo"
        HuggingFaceTextToSpeechModelInference(
            model="test/model",
            file_extension=AUDIO_FORMAT_EXT,
        )


def test_huggingface_tts_run_with_requests_mock() -> None:
    """Run the tool end to end with the network, UUIDs and file I/O mocked.

    Verifies the request sent to the inference endpoint, the path returned by
    ``_run`` and the bytes written to the output file.
    """
    # patch.dict restores os.environ on exit so the fake key set here does
    # not leak into other tests.
    with patch.dict(
        os.environ, {"HUGGINGFACE_API_KEY": "foo"}
    ), tempfile.TemporaryDirectory() as tmp_dir:
        with patch("uuid.uuid4") as mock_uuid, patch(
            "requests.post"
        ) as mock_inference, patch("builtins.open", mock_open()) as mock_file:
            input_query = "Dummy input"

            # Pin the generated file name to a known value.
            mock_uuid_value = uuid.UUID("00000000-0000-0000-0000-000000000000")
            mock_uuid.return_value = mock_uuid_value
            expected_output_file = os.path.join(
                tmp_dir, f"{mock_uuid_value}.{AUDIO_FORMAT_EXT}"
            )

            test_audio_content = b"test_audio_bytes"

            tts = HuggingFaceTextToSpeechModelInference(
                model="test/model",
                file_extension=AUDIO_FORMAT_EXT,
                destination_dir=tmp_dir,
                file_naming_func="uuid",
            )

            # Mock the requests.post response
            mock_response = Mock()
            mock_response.content = test_audio_content
            mock_inference.return_value = mock_response

            output_path = tts._run(input_query)

            assert output_path == expected_output_file

            api_key = tts.huggingface_api_key.get_secret_value()
            mock_inference.assert_called_once_with(
                tts.api_url,
                headers={"Authorization": f"Bearer {api_key}"},
                json={"inputs": input_query},
            )

            mock_file.assert_called_once_with(expected_output_file, mode="xb")
            mock_file.return_value.write.assert_called_once_with(test_audio_content)

0 comments on commit da010f9

Please sign in to comment.