Make together embeddings.create() into OpenAI compatible format and allow providing a safety_model to Complete.create() #63

Merged: 12 commits, Dec 6, 2023
10 changes: 0 additions & 10 deletions README.md
@@ -453,16 +453,6 @@ print(output_text)
Space Robots are a great way to get your kids interested in science. After all, they are the future!
```

## Embeddings API

Embeddings are vector representations of sequences. You can use these vectors for measuring the overall similarity between texts. Embeddings are useful for tasks such as search and retrieval.

```python
resp = together.Embeddings.create("embed this sentence into a single vector", model="togethercomputer/bert-base-uncased")

print(resp['data'][0]['embedding']) # [0.06659205, 0.07896972, 0.007910785 ........]
```

## Colab Tutorial

Follow along in our Colab (Google Colaboratory) Notebook Tutorial [Example Finetuning Project](https://colab.research.google.com/drive/11DwtftycpDSgp3Z1vnV-Cy68zvkGZL4K?usp=sharing).
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api"

[tool.poetry]
name = "together"
version = "0.2.8"
version = "0.2.9"
authors = [
    "Together AI <[email protected]>"
]
23 changes: 23 additions & 0 deletions src/together/__init__.py
@@ -1,6 +1,7 @@
import os
import sys
import urllib.parse
from typing import Type

from .version import VERSION

@@ -41,6 +42,27 @@
from .models import Models


class Together:
    complete: Type[Complete]
    completion: Type[Completion]
    embeddings: Type[Embeddings]
    files: Type[Files]
    finetune: Type[Finetune]
    image: Type[Image]
    models: Type[Models]

    def __init__(
        self,
    ) -> None:
        self.complete = Complete
        self.completion = Completion
        self.embeddings = Embeddings
        self.files = Files
        self.finetune = Finetune
        self.image = Image
        self.models = Models


__all__ = [
    "api_key",
    "api_base",
@@ -63,4 +85,5 @@
    "MISSING_API_KEY_MESSAGE",
    "BACKOFF_FACTOR",
    "min_samples",
    "Together",
]
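
The new `Together` class is a thin namespace wrapper that exposes the existing module-level classes as attributes of a single client object. A minimal usage sketch, assuming the API key is supplied via `together.api_key` or the `TOGETHER_API_KEY` environment variable; the model name is only illustrative:

```python
import together

together.api_key = "..."  # or set TOGETHER_API_KEY in the environment

client = together.Together()

# client.complete is the same Complete class exported at module level,
# so this call is equivalent to together.Complete.create(...).
output = client.complete.create(
    prompt="Space robots are",
    model="togethercomputer/RedPajama-INCITE-7B-Base",  # illustrative model name
)
print(output)
```

Because the attributes hold the classes themselves rather than instances, the wrapper adds an OpenAI-style entry point without changing any existing call signatures.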
9 changes: 9 additions & 0 deletions src/together/commands/complete.py
@@ -90,6 +90,13 @@ def add_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser])
action="store_true",
help="temperature for the LM",
)
subparser.add_argument(
"--safety-model",
"-sm",
default=None,
type=str,
help="The name of the safety model to use for moderation.",
)
subparser.set_defaults(func=_run_complete)


@@ -142,6 +149,7 @@ def _run_complete(args: argparse.Namespace) -> None:
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
            logprobs=args.logprobs,
            safety_model=args.safety_model,
        )
    except together.AuthenticationError:
        logger.critical(together.MISSING_API_KEY_MESSAGE)
@@ -159,6 +167,7 @@ def _run_complete(args: argparse.Namespace) -> None:
            top_p=args.top_p,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
            safety_model=args.safety_model,
            raw=args.raw,
        ):
            if not args.raw:
3 changes: 1 addition & 2 deletions src/together/commands/embeddings.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import argparse
import json

import together
from together import Embeddings
@@ -42,7 +41,7 @@ def _run_complete(args: argparse.Namespace) -> None:
            model=args.model,
        )

        print(json.dumps(response, indent=4))
        print([e.embedding for e in response.data])
    except together.AuthenticationError:
        logger.critical(together.MISSING_API_KEY_MESSAGE)
        exit(0)
4 changes: 4 additions & 0 deletions src/together/complete.py
@@ -24,6 +24,7 @@ def create(
        logprobs: Optional[int] = None,
        api_key: Optional[str] = None,
        cast: bool = False,
        safety_model: Optional[str] = None,
    ) -> Union[Dict[str, Any], TogetherResponse]:
        if model == "":
            model = together.default_text_model
@@ -38,6 +39,7 @@
            "stop": stop,
            "repetition_penalty": repetition_penalty,
            "logprobs": logprobs,
            "safety_model": safety_model,
        }

        # send request
@@ -70,6 +72,7 @@ def create_streaming(
        raw: Optional[bool] = False,
        api_key: Optional[str] = None,
        cast: Optional[bool] = False,
        safety_model: Optional[str] = None,
    ) -> Union[Iterator[str], Iterator[TogetherResponse]]:
        """
        Prints streaming responses and returns the completed text.
@@ -88,6 +91,7 @@
            "stop": stop,
            "repetition_penalty": repetition_penalty,
            "stream_tokens": True,
            "safety_model": safety_model,
        }

        # send request
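
The new `safety_model` argument is simply forwarded in the request payload, so a caller can request moderation alongside generation. A minimal sketch, assuming the module-level API key is already configured; both model names are placeholders rather than values pinned by this diff:

```python
import together

output = together.Complete.create(
    prompt="Tell me about space robots.",
    model="togethercomputer/RedPajama-INCITE-7B-Base",  # placeholder generation model
    max_tokens=64,
    safety_model="Meta-Llama/Llama-Guard-7b",  # placeholder safety/moderation model
)
print(output)
```

The same keyword is accepted by `Complete.create_streaming()`, which forwards it in the streaming payload shown above.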
47 changes: 38 additions & 9 deletions src/together/embeddings.py
@@ -1,4 +1,5 @@
from typing import Any, Dict, Optional
import concurrent.futures
from typing import Any, Dict, List, Optional, Union

import together
from together.utils import create_post_request, get_logger
@@ -7,29 +8,57 @@
logger = get_logger(str(__name__))


class DataItem:
    def __init__(self, embedding: List[float]):
        self.embedding = embedding


class EmbeddingsOutput:
    def __init__(self, data: List[DataItem]):
        self.data = data


class Embeddings:
    @classmethod
    def create(
        self,
        input: str,
        cls,
        input: Union[str, List[str]],
        model: Optional[str] = "",
    ) -> Dict[str, Any]:
    ) -> EmbeddingsOutput:
        if model == "":
            model = together.default_embedding_model

        parameter_payload = {
            "input": input,
            "model": model,
        }
        if isinstance(input, str):
            parameter_payload = {
                "input": input,
                "model": model,
            }

            response = cls._process_input(parameter_payload)

            return EmbeddingsOutput([DataItem(response["data"][0]["embedding"])])

        elif isinstance(input, list):
            # If input is a list, process each string concurrently
            with concurrent.futures.ThreadPoolExecutor() as executor:
                parameter_payloads = [{"input": item, "model": model} for item in input]
                results = list(executor.map(cls._process_input, parameter_payloads))

            return EmbeddingsOutput(
                [DataItem(item["data"][0]["embedding"]) for item in results]
            )

    @classmethod
    def _process_input(cls, parameter_payload: Dict[str, Any]) -> Dict[str, Any]:
        # send request
        response = create_post_request(
            url=together.api_base_embeddings, json=parameter_payload
        )

        # return the json as a DotDict
        try:
            response_json = dict(response.json())

        except Exception as e:
            raise together.JSONError(e, http_status=response.status_code)

        return response_json
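
With this change, `Embeddings.create()` accepts either a single string or a list of strings and returns an `EmbeddingsOutput` whose `data` items each expose an `embedding` attribute, mirroring the OpenAI response shape; list inputs are embedded concurrently, one request per item. A minimal sketch (the model name is taken from the old README example and is only illustrative):

```python
import together

# Single input: one DataItem in resp.data
resp = together.Embeddings.create(
    "embed this sentence into a single vector",
    model="togethercomputer/bert-base-uncased",
)
print(resp.data[0].embedding[:3])

# List input: one DataItem per string, computed via a thread pool
batch = together.Embeddings.create(
    ["first sentence", "second sentence"],
    model="togethercomputer/bert-base-uncased",
)
print(len(batch.data))  # expected: 2
```

This is the access pattern the updated CLI uses as well (`[e.embedding for e in response.data]`), replacing the old dict-style `response["data"][0]["embedding"]` shown in the removed README section.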