Skip to content

Commit

Permalink
Update output format of MosaicML endpoint to be more flexible (langch…
Browse files Browse the repository at this point in the history
…ain-ai#6060)

There will likely be another change or two coming over the next couple
weeks as we stabilize the API, but putting this one in now which just
makes the integration a bit more flexible with the response output
format.

```
(langchain) danielking@MML-1B940F4333E2 langchain % pytest tests/integration_tests/llms/test_mosaicml.py tests/integration_tests/embeddings/test_mosaicml.py 
=================================================================================== test session starts ===================================================================================
platform darwin -- Python 3.10.11, pytest-7.3.1, pluggy-1.0.0
rootdir: /Users/danielking/github/langchain
configfile: pyproject.toml
plugins: asyncio-0.20.3, mock-3.10.0, dotenv-0.5.2, cov-4.0.0, anyio-3.6.2
asyncio: mode=strict
collected 12 items                                                                                                                                                                        

tests/integration_tests/llms/test_mosaicml.py ......                                                                                                                                [ 50%]
tests/integration_tests/embeddings/test_mosaicml.py ......                                                                                                                          [100%]

=================================================================================== slowest 5 durations ===================================================================================
4.76s call     tests/integration_tests/llms/test_mosaicml.py::test_retry_logic
4.74s call     tests/integration_tests/llms/test_mosaicml.py::test_mosaicml_llm_call
4.13s call     tests/integration_tests/llms/test_mosaicml.py::test_instruct_prompt
0.91s call     tests/integration_tests/llms/test_mosaicml.py::test_short_retry_does_not_loop
0.66s call     tests/integration_tests/llms/test_mosaicml.py::test_mosaicml_extra_kwargs
=================================================================================== 12 passed in 19.70s ===================================================================================
```

#### Who can review?

  @hwchase17
  @dev2049
  • Loading branch information
dakinggg authored Jun 16, 2023
1 parent 50d9c7d commit a9b97aa
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 13 deletions.
39 changes: 33 additions & 6 deletions langchain/embeddings/mosaicml.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
"""

endpoint_url: str = (
"https://models.hosted-on.mosaicml.hosting/instructor-large/v1/predict"
"https://models.hosted-on.mosaicml.hosting/instructor-xl/v1/predict"
)
"""Endpoint URL to use."""
embed_instruction: str = "Represent the document for retrieval: "
Expand Down Expand Up @@ -98,11 +98,38 @@ def _embed(
f"Error raised by inference API: {parsed_response['error']}"
)

if "data" not in parsed_response:
raise ValueError(
f"Error raised by inference API, no key data: {parsed_response}"
)
embeddings = parsed_response["data"]
# The inference API has changed a couple of times, so we add some handling
# to be robust to multiple response formats.
if isinstance(parsed_response, dict):
if "data" in parsed_response:
output_item = parsed_response["data"]
elif "output" in parsed_response:
output_item = parsed_response["output"]
else:
raise ValueError(
f"No key data or output in response: {parsed_response}"
)

if isinstance(output_item, list) and isinstance(output_item[0], list):
embeddings = output_item
else:
embeddings = [output_item]
elif isinstance(parsed_response, list):
first_item = parsed_response[0]
if isinstance(first_item, list):
embeddings = parsed_response
elif isinstance(first_item, dict):
if "output" in first_item:
embeddings = [item["output"] for item in parsed_response]
else:
raise ValueError(
f"No key data or output in response: {parsed_response}"
)
else:
raise ValueError(f"Unexpected response format: {parsed_response}")
else:
raise ValueError(f"Unexpected response type: {parsed_response}")

except requests.exceptions.JSONDecodeError as e:
raise ValueError(
f"Error raised by inference API: {e}.\nResponse: {response.text}"
Expand Down
41 changes: 34 additions & 7 deletions langchain/llms/mosaicml.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,18 +157,45 @@ def _call(
f"Error raised by inference API: {parsed_response['error']}"
)

if "data" not in parsed_response:
raise ValueError(
f"Error raised by inference API, no key data: {parsed_response}"
)
generated_text = parsed_response["data"]
# The inference API has changed a couple of times, so we add some handling
# to be robust to multiple response formats.
if isinstance(parsed_response, dict):
if "data" in parsed_response:
output_item = parsed_response["data"]
elif "output" in parsed_response:
output_item = parsed_response["output"]
else:
raise ValueError(
f"No key data or output in response: {parsed_response}"
)

if isinstance(output_item, list):
text = output_item[0]
else:
text = output_item
elif isinstance(parsed_response, list):
first_item = parsed_response[0]
if isinstance(first_item, str):
text = first_item
elif isinstance(first_item, dict):
if "output" in parsed_response:
text = first_item["output"]
else:
raise ValueError(
f"No key data or output in response: {parsed_response}"
)
else:
raise ValueError(f"Unexpected response format: {parsed_response}")
else:
raise ValueError(f"Unexpected response type: {parsed_response}")

text = text[len(prompt) :]

except requests.exceptions.JSONDecodeError as e:
raise ValueError(
f"Error raised by inference API: {e}.\nResponse: {response.text}"
)

text = generated_text[0][len(prompt) :]

# TODO: replace when MosaicML supports custom stop tokens natively
if stop is not None:
text = enforce_stop_tokens(text, stop)
Expand Down

0 comments on commit a9b97aa

Please sign in to comment.