Skip to content

Commit

Permalink
feat(python/sdk): allow input_text into lemur (#40)
Browse files Browse the repository at this point in the history
Co-authored-by: Patrick Loeber <[email protected]>
Co-authored-by: Niels Swimberghe <[email protected]>
  • Loading branch information
3 people authored Nov 20, 2023
1 parent d12a626 commit e4d3379
Show file tree
Hide file tree
Showing 16 changed files with 457 additions and 511 deletions.
37 changes: 34 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
[![AssemblyAI Twitter](https://img.shields.io/twitter/follow/AssemblyAI?label=%40AssemblyAI&style=social)](https://twitter.com/AssemblyAI)
[![AssemblyAI YouTube](https://img.shields.io/youtube/channel/subscribers/UCtatfZMf-8EkIwASXM4ts0A)](https://www.youtube.com/@AssemblyAI)
[![Discord](https://img.shields.io/discord/875120158014853141?logo=discord&label=Discord&link=https%3A%2F%2Fdiscord.com%2Fchannels%2F875120158014853141&style=social)
](https://discord.gg/5aQNZyq3)
](https://assemblyai.com/discord)

# AssemblyAI's Python SDK

Expand Down Expand Up @@ -266,6 +266,37 @@ print(result.response)

</details>


<details>
<summary>Use LeMUR with Input Text</summary>

```python
import assemblyai as aai

transcriber = aai.Transcriber()
config = aai.TranscriptionConfig(
speaker_labels=True,
)
transcript = transcriber.transcribe("https://example.org/customer.mp3", config=config)

# Example converting speaker label utterances into LeMUR input text
text = ""

for utt in transcript.utterances:
text += f"Speaker {utt.speaker}:\n{utt.text}\n"

result = aai.Lemur().task(
"You are a helpful coach. Provide an analysis of the transcript "
"and offer areas to improve with exact quotes. Include no preamble. "
"Start with an overall summary then get into the examples with feedback.",
input_text=text
)

print(result.response)
```

</details>

<details>
<summary>Delete data previously sent to LeMUR</summary>

Expand Down Expand Up @@ -460,7 +491,7 @@ for sentiment_result in transcript.sentiment_analysis:
print(sentiment_result.text)
print(sentiment_result.sentiment) # POSITIVE, NEUTRAL, or NEGATIVE
print(sentiment_result.confidence)
print(f"Timestamp: {sentiment_result.timestamp.start} - {sentiment_result.timestamp.end}")
print(f"Timestamp: {sentiment_result.start} - {sentiment_result.end}")
```

If `speaker_labels` is also enabled, then each sentiment analysis result will also include a `speaker` field.
Expand Down Expand Up @@ -493,7 +524,7 @@ transcript = transcriber.transcribe(

for entity in transcript.entities:
print(entity.text) # i.e. "Dan Gilbert"
print(entity.type) # i.e. EntityType.person
print(entity.entity_type) # i.e. EntityType.person
print(f"Timestamp: {entity.start} - {entity.end}")
```

Expand Down
24 changes: 22 additions & 2 deletions assemblyai/lemur.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ def __init__(
) -> None:
self._client = client

self._sources = [types.LemurSourceRequest.from_lemur_source(s) for s in sources]
self._sources = (
[types.LemurSourceRequest.from_lemur_source(s) for s in sources]
if sources is not None
else []
)

def question(
self,
Expand All @@ -24,6 +28,7 @@ def question(
final_model: Optional[types.LemurModel],
max_output_size: Optional[int],
temperature: Optional[float],
input_text: Optional[str],
) -> types.LemurQuestionResponse:
response = api.lemur_question(
client=self._client.http_client,
Expand All @@ -34,6 +39,7 @@ def question(
final_model=final_model,
max_output_size=max_output_size,
temperature=temperature,
input_text=input_text,
),
http_timeout=timeout,
)
Expand All @@ -48,6 +54,7 @@ def summarize(
max_output_size: Optional[int],
timeout: Optional[float],
temperature: Optional[float],
input_text: Optional[str],
) -> types.LemurSummaryResponse:
response = api.lemur_summarize(
client=self._client.http_client,
Expand All @@ -58,6 +65,7 @@ def summarize(
final_model=final_model,
max_output_size=max_output_size,
temperature=temperature,
input_text=input_text,
),
http_timeout=timeout,
)
Expand All @@ -72,6 +80,7 @@ def action_items(
max_output_size: Optional[int],
timeout: Optional[float],
temperature: Optional[float],
input_text: Optional[str],
) -> types.LemurActionItemsResponse:
response = api.lemur_action_items(
client=self._client.http_client,
Expand All @@ -82,6 +91,7 @@ def action_items(
final_model=final_model,
max_output_size=max_output_size,
temperature=temperature,
input_text=input_text,
),
http_timeout=timeout,
)
Expand All @@ -95,6 +105,7 @@ def task(
max_output_size: Optional[int],
timeout: Optional[float],
temperature: Optional[float],
input_text: Optional[str],
):
response = api.lemur_task(
client=self._client.http_client,
Expand All @@ -104,6 +115,7 @@ def task(
final_model=final_model,
max_output_size=max_output_size,
temperature=temperature,
input_text=input_text,
),
http_timeout=timeout,
)
Expand All @@ -121,7 +133,7 @@ class Lemur:

def __init__(
self,
sources: List[types.LemurSource],
sources: Optional[List[types.LemurSource]] = None,
client: Optional[_client.Client] = None,
) -> None:
"""
Expand All @@ -147,6 +159,7 @@ def question(
max_output_size: Optional[int] = None,
timeout: Optional[float] = None,
temperature: Optional[float] = None,
input_text: Optional[str] = None,
) -> types.LemurQuestionResponse:
"""
Question & Answer allows you to ask free form questions about one or many transcripts.
Expand Down Expand Up @@ -178,6 +191,7 @@ def question(
max_output_size=max_output_size,
timeout=timeout,
temperature=temperature,
input_text=input_text,
)

def summarize(
Expand All @@ -188,6 +202,7 @@ def summarize(
max_output_size: Optional[int] = None,
timeout: Optional[float] = None,
temperature: Optional[float] = None,
input_text: Optional[str] = None,
) -> types.LemurSummaryResponse:
"""
Summary allows you to distill a piece of audio into a few impactful sentences.
Expand All @@ -214,6 +229,7 @@ def summarize(
max_output_size=max_output_size,
timeout=timeout,
temperature=temperature,
input_text=input_text,
)

def action_items(
Expand All @@ -224,6 +240,7 @@ def action_items(
max_output_size: Optional[int] = None,
timeout: Optional[float] = None,
temperature: Optional[float] = None,
input_text: Optional[str] = None,
) -> types.LemurActionItemsResponse:
"""
Action Items allows you to generate action items from one or many transcripts.
Expand Down Expand Up @@ -251,6 +268,7 @@ def action_items(
max_output_size=max_output_size,
timeout=timeout,
temperature=temperature,
input_text=input_text,
)

def task(
Expand All @@ -260,6 +278,7 @@ def task(
max_output_size: Optional[int] = None,
timeout: Optional[float] = None,
temperature: Optional[float] = None,
input_text: Optional[str] = None,
) -> types.LemurTaskResponse:
"""
Task feature allows you to submit a custom prompt to the model.
Expand All @@ -282,6 +301,7 @@ def task(
max_output_size=max_output_size,
timeout=timeout,
temperature=temperature,
input_text=input_text,
)

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion assemblyai/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,7 @@ def __init__(
client: _client.Client,
) -> None:
self._client = client
self._websocket: Optional[websockets_client.ClientConnection] = None
self._websocket: Optional[websockets.sync.client.ClientConnection] = None

self._on_open = on_open
self._on_data = on_data
Expand Down
9 changes: 5 additions & 4 deletions assemblyai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ def auto_chapters(self, enable: Optional[bool]) -> None:
"Enable Auto Chapters."

# Validate required params are also set
if enable and self.punctuate == False:
if enable and self.punctuate is False:
raise ValueError(
"If `auto_chapters` is enabled, then `punctuate` must not be disabled"
)
Expand Down Expand Up @@ -1146,11 +1146,11 @@ def set_summarize(
return self

# Validate that required parameters are also set
if self._raw_transcription_config.punctuate == False:
if self._raw_transcription_config.punctuate is False:
raise ValueError(
"If `summarization` is enabled, then `punctuate` must not be disabled"
)
if self._raw_transcription_config.format_text == False:
if self._raw_transcription_config.format_text is False:
raise ValueError(
"If `summarization` is enabled, then `format_text` must not be disabled"
)
Expand Down Expand Up @@ -1666,7 +1666,7 @@ def __init__(
"""
from . import Transcript

if type(transcript) == str:
if isinstance(transcript, str):
transcript = Transcript(transcript_id=transcript)

super().__init__(transcript)
Expand Down Expand Up @@ -1773,6 +1773,7 @@ class BaseLemurRequest(BaseModel):
final_model: Optional[LemurModel]
max_output_size: Optional[int]
temperature: Optional[float]
input_text: Optional[str]


class LemurTaskRequest(BaseLemurRequest):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup(
name="assemblyai",
version="0.19.0",
version="0.20.0",
description="AssemblyAI Python SDK",
author="AssemblyAI",
author_email="[email protected]",
Expand Down
65 changes: 5 additions & 60 deletions tests/unit/test_auto_chapters.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import json
from typing import Any, Dict, Tuple

import factory
import httpx
import pytest
from pytest_httpx import HTTPXMock

import tests.unit.unit_test_utils as unit_test_utils
import assemblyai as aai
from assemblyai.api import ENDPOINT_TRANSCRIPT
from tests.unit import factories

aai.settings.api_key = "test"
Expand All @@ -17,65 +13,14 @@ class AutoChaptersResponseFactory(factories.TranscriptCompletedResponseFactory):
chapters = factory.List([factory.SubFactory(factories.ChapterFactory)])


def __submit_mock_request(
httpx_mock: HTTPXMock,
mock_response: Dict[str, Any],
config: aai.TranscriptionConfig,
) -> Tuple[Dict[str, Any], aai.Transcript]:
"""
Helper function to abstract mock transcriber calls with given `TranscriptionConfig`,
and perform some common assertions.
"""

mock_transcript_id = mock_response.get("id", "mock_id")

# Mock initial submission response (transcript is processing)
mock_processing_response = factories.generate_dict_factory(
factories.TranscriptProcessingResponseFactory
)()

httpx_mock.add_response(
url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}",
status_code=httpx.codes.OK,
method="POST",
json={
**mock_processing_response,
"id": mock_transcript_id, # inject ID from main mock response
},
)

# Mock polling-for-completeness response, with completed transcript
httpx_mock.add_response(
url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{mock_transcript_id}",
status_code=httpx.codes.OK,
method="GET",
json=mock_response,
)

# == Make API request via SDK ==
transcript = aai.Transcriber().transcribe(
data="https://example.org/audio.wav",
config=config,
)

# Check that submission and polling requests were made
assert len(httpx_mock.get_requests()) == 2

# Extract body of initial submission request
request = httpx_mock.get_requests()[0]
request_body = json.loads(request.content.decode())

return request_body, transcript


def test_auto_chapters_fails_without_punctuation(httpx_mock: HTTPXMock):
"""
Tests whether the SDK raises an error before making a request
if `auto_chapters` is enabled and `punctuation` is disabled
"""

with pytest.raises(ValueError) as error:
__submit_mock_request(
unit_test_utils.submit_mock_transcription_request(
httpx_mock,
mock_response={}, # response doesn't matter, since it shouldn't occur
config=aai.TranscriptionConfig(
Expand All @@ -98,7 +43,7 @@ def test_auto_chapters_disabled_by_default(httpx_mock: HTTPXMock):
Tests that excluding `auto_chapters` from the `TranscriptionConfig` will
result in the default behavior of it being excluded from the request body
"""
request_body, transcript = __submit_mock_request(
request_body, transcript = unit_test_utils.submit_mock_transcription_request(
httpx_mock,
mock_response=factories.generate_dict_factory(
factories.TranscriptCompletedResponseFactory
Expand All @@ -116,14 +61,14 @@ def test_auto_chapters_enabled(httpx_mock: HTTPXMock):
response is properly parsed into a `Transcript` object
"""
mock_response = factories.generate_dict_factory(AutoChaptersResponseFactory)()
request_body, transcript = __submit_mock_request(
request_body, transcript = unit_test_utils.submit_mock_transcription_request(
httpx_mock,
mock_response=mock_response,
config=aai.TranscriptionConfig(auto_chapters=True),
)

# Check that request body was properly defined
assert request_body.get("auto_chapters") == True
assert request_body.get("auto_chapters") is True

# Check that transcript was properly parsed from JSON response
assert transcript.error is None
Expand Down
Loading

0 comments on commit e4d3379

Please sign in to comment.