Skip to content

Commit

Permalink
rename chat to ask (#78)
Browse files Browse the repository at this point in the history
* rename chat to ask

* fix changelog
  • Loading branch information
ebrehault authored May 30, 2024
1 parent 35650cc commit b236f33
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 35 deletions.
6 changes: 3 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Changelog

## 2.1.1 (unreleased)
## 3.0.0 (unreleased)

### Breaking change

- Nothing changed yet.

- Rename `chat()` to `ask()`

## 2.1.0 (2024-05-17)

Expand Down
2 changes: 1 addition & 1 deletion docs/03-kb.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ You can get the logs of a Knowledge Box. There are different types of logs:
- `NEW`: resource creation events,
- `PROCESSED`: processing events,
- `MODIFIED`: resource modification events,
- `CHAT`: asked questions and returned answers on the `/chat` endpoint,
- `CHAT`: asked questions and returned answers on the `/ask` endpoint,
- `SEARCH`: queries sent to `/search` or `/find`,
- `FEEDBACK`: user feedbacks on answers

Expand Down
8 changes: 4 additions & 4 deletions docs/05-search.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,20 @@ Based on a `find` request, Nuclia uses a generative AI to answer the question ba
- CLI:

```bash
nuclia kb search chat --query="My question"
nuclia kb search ask --query="My question"
```

- SDK:

```python
from nuclia import sdk
search = sdk.NucliaSearch()
search.chat(query="My question")
search.ask(query="My question")
```

## Filtering

Any endpoint that involves search (`search`, `find` and `chat`) also support more advanced filtering expressions. Expressions can have one of the following operators:
Any endpoint that involves search (`search`, `find` and `ask`) also support more advanced filtering expressions. Expressions can have one of the following operators:

- `all`: this is the default. Will make search return results containing all specified filter labels.
- `any`: returns results containing at least one of the labels.
Expand All @@ -81,7 +81,7 @@ Here are some examples:
from nucliadb_models.search import Filter

search = sdk.NucliaSearch()
search.chat(
search.ask(
query="My question",
filters=[Filter(any=['/classification.labels/region/Europe','/classification.labels/region/Asia'])],
)
Expand Down
6 changes: 3 additions & 3 deletions nuclia/sdk/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Optional

from nucliadb_models.search import ChatRequest
from nucliadb_models.search import AskRequest


class Agent:
Expand All @@ -15,8 +15,8 @@ def __init__(self, prompt: str, filters: List[str]):
self.search = NucliaSearch()

def ask(self, text: str) -> str:
chat_req = ChatRequest(query=text, prompt=self.prompt, filters=self.filters)
answer = self.search.chat(query=chat_req)
ask_req = AskRequest(query=text, prompt=self.prompt, filters=self.filters)
answer = self.search.ask(query=ask_req)
return answer.answer.decode()


Expand Down
38 changes: 17 additions & 21 deletions nuclia/sdk/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from nucliadb_models.search import (
AskRequest,
AskResponseItem,
ChatRequest,
Filter,
FindRequest,
KnowledgeboxFindResults,
Expand All @@ -25,7 +24,7 @@


@dataclass
class ChatAnswer:
class AskAnswer:
answer: bytes
learning_id: str
relations_result: Optional[Relations]
Expand Down Expand Up @@ -122,17 +121,17 @@ def find(
return ndb.ndb.find(req, kbid=ndb.kbid)

@kb
def chat(
def ask(
self,
*,
query: Union[str, dict, ChatRequest],
query: Union[str, dict, AskRequest],
filters: Optional[Union[List[str], List[Filter]]] = None,
**kwargs,
):
"""
Answer a question.
See https://docs.nuclia.dev/docs/api#tag/Search/operation/Chat_Knowledge_Box_kb__kbid__chat_post
See https://docs.nuclia.dev/docs/api#tag/Search/operation/Ask_Knowledge_Box_kb__kbid__ask_post
"""
ndb: NucliaDBClient = kwargs["ndb"]
if isinstance(query, str):
Expand All @@ -146,15 +145,14 @@ def chat(
except ValidationError as exc:
print(exc)
sys.exit(1)
elif isinstance(query, ChatRequest):
# Convert ChatRequest to AskRequest
req = AskRequest.parse_obj(query.dict())
elif isinstance(query, AskRequest):
req = query
else:
raise ValueError("Invalid query type. Must be str, dict or ChatRequest.")
raise ValueError("Invalid query type. Must be str, dict or AskRequest.")

ask_response: SyncAskResponse = ndb.ndb.ask(kbid=ndb.kbid, content=req)
# Convert to ChatAnswer
result = ChatAnswer(

result = AskAnswer(
answer=ask_response.answer.encode(),
learning_id=ask_response.learning_id,
relations_result=ask_response.relations,
Expand Down Expand Up @@ -252,18 +250,18 @@ async def find(
return await ndb.ndb.find(req, kbid=ndb.kbid)

@kb
async def chat(
async def ask(
self,
*,
query: Union[str, dict, ChatRequest],
query: Union[str, dict, AskRequest],
filters: Optional[List[str]] = None,
timeout: int = 100,
**kwargs,
):
"""
Answer a question.
See https://docs.nuclia.dev/docs/api#tag/Search/operation/Chat_Knowledge_Box_kb__kbid__chat_post
See https://docs.nuclia.dev/docs/api#tag/Search/operation/Ask_Knowledge_Box_kb__kbid__ask_post
"""
ndb: NucliaDBClient = kwargs["ndb"]
if isinstance(query, str):
Expand All @@ -277,14 +275,12 @@ async def chat(
except ValidationError as exc:
print(exc)
sys.exit(1)
elif isinstance(query, ChatRequest):
# Convert ChatRequest to AskRequest
req = AskRequest.parse_obj(query.dict())
elif isinstance(query, AskRequest):
req = query
else:
raise ValueError("Invalid query type. Must be str, dict or ChatRequest.")
raise ValueError("Invalid query type. Must be str, dict or AskRequest.")
ask_stream_response = await ndb.ask(req, timeout=timeout)
# Parse the stream response and convert to ChatAnswer
result = ChatAnswer(
result = AskAnswer(
answer=b"",
learning_id=ask_stream_response.headers.get("NUCLIA-LEARNING-ID", ""),
relations_result=None,
Expand Down Expand Up @@ -317,6 +313,6 @@ async def chat(
pass
else: # pragma: no cover
warnings.warn(
f"Unknown chat stream item type: {ask_response_item.type}"
f"Unknown ask stream item type: {ask_response_item.type}"
)
return result
6 changes: 3 additions & 3 deletions nuclia/tests/test_kb/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ def test_find_object(testing_config):
assert "Lamarr Lesson plan.pdf" in titles


def test_chat(testing_config):
def test_ask(testing_config):
if IS_PROD:
assert True
return
search = NucliaSearch()
results = search.chat(query="Who is hedy Lamarr?")
results = search.ask(query="Who is hedy Lamarr?")
answer = results.answer.decode()
print("Chat answer: ", answer)
print("Answer: ", answer)
assert "Lamarr" in answer


Expand Down

0 comments on commit b236f33

Please sign in to comment.