Skip to content

Commit

Permalink
Allow client to send steering requests to /chat/completions
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloiyu committed Jan 13, 2025
1 parent 15979ea commit 69125db
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
1 change: 1 addition & 0 deletions aleph_alpha_client/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class ChatRequest:
top_k: int = 0
top_p: float = 0.0
stream_options: Optional[StreamOptions] = None
steering_concepts: Optional[List[str]] = None

def to_json(self) -> Mapping[str, Any]:
payload = {k: v for k, v in asdict(self).items() if v is not None}
Expand Down
23 changes: 23 additions & 0 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,26 @@ async def test_stream_options(async_client: AsyncClient, chat_model_name: str):
# Then the last chunks has information about usage
assert all(isinstance(item, ChatStreamChunk) for item in stream_items[:-1])
assert isinstance(stream_items[-1], Usage)


def test_steering_chat(sync_client: Client, chat_model_name: str):
base_request = ChatRequest(
messages=[Message(role=Role.User, content="Hello, how are you?")],
model="llama-3.1-8b-instruct",
)

steered_request = ChatRequest(
messages=[Message(role=Role.User, content="Hello, how are you?")],
model="llama-3.1-8b-instruct",
steering_concepts=["slang"],
)

base_response = sync_client.chat(base_request, model=chat_model_name)
steered_response = sync_client.chat(steered_request, model=chat_model_name)

base_completion_result = base_response.message.content
steered_completion_result = steered_response.message.content

assert base_completion_result
assert steered_completion_result
assert base_completion_result != steered_completion_result

0 comments on commit 69125db

Please sign in to comment.