Skip to content

Commit

Permalink
feat: add optional metadata arg to get_utterances
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam committed Jan 9, 2025
1 parent f8e26f1 commit c63c1ac
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 136 deletions.
28 changes: 19 additions & 9 deletions semantic_router/index/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,34 +73,44 @@ async def aadd(
**kwargs,
)

def get_utterances(self) -> List[Utterance]:
def get_utterances(self, include_metadata: bool = False) -> List[Utterance]:
"""Gets a list of route and utterance objects currently stored in the
index, optionally including additional metadata.
:return: A list of tuples, each containing route, utterance, function
schema and additional metadata.
:rtype: List[Tuple]
:param include_metadata: Whether to include function schemas and metadata in
the returned Utterance objects.
:type include_metadata: bool
:return: A list of Utterance objects.
:rtype: List[Utterance]
"""
if self.index is None:
logger.warning("Index is None, could not retrieve utterances.")
return []
_, metadata = self._get_all(include_metadata=True)
route_tuples = parse_route_info(metadata=metadata)
if not include_metadata:
# we remove the metadata from the tuples (ie only keep 0, 1 items)
route_tuples = [x[:2] for x in route_tuples]
return [Utterance.from_tuple(x) for x in route_tuples]

async def aget_utterances(self) -> List[Utterance]:
async def aget_utterances(self, include_metadata: bool = False) -> List[Utterance]:
"""Gets a list of route and utterance objects currently stored in the
index, optionally including additional metadata.
:return: A list of tuples, each containing route, utterance, function
schema and additional metadata.
:rtype: List[Tuple]
:param include_metadata: Whether to include function schemas and metadata in
the returned Utterance objects.
:type include_metadata: bool
:return: A list of Utterance objects.
:rtype: List[Utterance]
"""
if self.index is None:
logger.warning("Index is None, could not retrieve utterances.")
return []
_, metadata = await self._async_get_all(include_metadata=True)
route_tuples = parse_route_info(metadata=metadata)
if not include_metadata:
# we remove the metadata from the tuples (ie only keep 0, 1 items)
route_tuples = [x[:2] for x in route_tuples]
return [Utterance.from_tuple(x) for x in route_tuples]

def get_routes(self) -> List[Route]:
Expand All @@ -109,7 +119,7 @@ def get_routes(self) -> List[Route]:
:return: A list of Route objects.
:rtype: List[Route]
"""
utterances = self.get_utterances()
utterances = self.get_utterances(include_metadata=True)
routes_dict: Dict[str, Route] = {}
# first create a dictionary of route names to Route objects
for utt in utterances:
Expand Down
11 changes: 8 additions & 3 deletions semantic_router/index/hybrid_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,17 @@ def add(
self.routes = np.concatenate([self.routes, routes_arr])
self.utterances = np.concatenate([self.utterances, utterances_arr])

def get_utterances(self) -> List[Utterance]:
def get_utterances(self, include_metadata: bool = False) -> List[Utterance]:
"""Gets a list of route and utterance objects currently stored in the index.
Returns:
List[Tuple]: A list of (route_name, utterance) objects.
:param include_metadata: Whether to include function schemas and metadata in
the returned Utterance objects - HybridLocalIndex only supports False.
:type include_metadata: bool
:return: A list of Utterance objects.
:rtype: List[Utterance]
"""
if include_metadata:
raise ValueError("include_metadata is not supported for HybridLocalIndex.")
if self.routes is None or self.utterances is None:
return []
return [Utterance.from_tuple(x) for x in zip(self.routes, self.utterances)]
Expand Down
16 changes: 10 additions & 6 deletions semantic_router/index/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,17 @@ def _remove_and_sync(self, routes_to_delete: dict) -> np.ndarray:
# return what was removed
return route_utterances[~mask]

def get_utterances(self) -> List[Utterance]:
"""
Gets a list of route and utterance objects currently stored in the index.
Returns:
List[Tuple]: A list of (route_name, utterance) objects.
def get_utterances(self, include_metadata: bool = False) -> List[Utterance]:
"""Gets a list of route and utterance objects currently stored in the index.
:param include_metadata: Whether to include function schemas and metadata in
the returned Utterance objects - LocalIndex only supports False.
:type include_metadata: bool
:return: A list of Utterance objects.
:rtype: List[Utterance]
"""
if include_metadata:
raise ValueError("include_metadata is not supported for HybridLocalIndex.")
if self.routes is None or self.utterances is None:
return []
return [Utterance.from_tuple(x) for x in zip(self.routes, self.utterances)]
Expand Down
19 changes: 13 additions & 6 deletions semantic_router/index/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,20 @@ def add(
batch_size=batch_size,
)

def get_utterances(self) -> List[Utterance]:
"""
Gets a list of route and utterance objects currently stored in the index.
Returns:
List[Tuple]: A list of (route_name, utterance, function_schema, metadata) objects.
def get_utterances(self, include_metadata: bool = False) -> List[Utterance]:
"""Gets a list of route and utterance objects currently stored in the index.
:param include_metadata: Whether to include function schemas and metadata in
the returned Utterance objects - QdrantIndex only supports False.
:type include_metadata: bool
:return: A list of Utterance objects.
:rtype: List[Utterance]
"""
if include_metadata:
raise NotImplementedError(
"include_metadata is not supported for QdrantIndex. If required please "
"reach out to maintainers on GitHub via an issue or PR."
)

# Check if collection exists first
if not self.client.collection_exists(self.index_name):
Expand Down
18 changes: 11 additions & 7 deletions semantic_router/routers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def from_index(
:param encoder_name: The name of the encoder to use, defaults to None.
:type encoder_name: Optional[str], optional
"""
remote_routes = index.get_utterances()
remote_routes = index.get_utterances(include_metadata=True)
return cls.from_tuples(
route_tuples=[utt.to_tuple() for utt in remote_routes],
encoder_type=encoder_type,
Expand Down Expand Up @@ -380,7 +380,7 @@ def _init_index_state(self):
# run auto sync if active
if self.auto_sync:
local_utterances = self.to_config().to_utterances()
remote_utterances = self.index.get_utterances()
remote_utterances = self.index.get_utterances(include_metadata=True)
diff = UtteranceDiff.from_utterances(
local_utterances=local_utterances,
remote_utterances=remote_utterances,
Expand Down Expand Up @@ -576,7 +576,7 @@ def sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]:
try:
# first creating a diff
local_utterances = self.to_config().to_utterances()
remote_utterances = self.index.get_utterances()
remote_utterances = self.index.get_utterances(include_metadata=True)
diff = UtteranceDiff.from_utterances(
local_utterances=local_utterances,
remote_utterances=remote_utterances,
Expand Down Expand Up @@ -632,7 +632,9 @@ async def async_sync(
try:
# first creating a diff
local_utterances = self.to_config().to_utterances()
remote_utterances = await self.index.aget_utterances()
remote_utterances = await self.index.aget_utterances(
include_metadata=True
)
diff = UtteranceDiff.from_utterances(
local_utterances=local_utterances,
remote_utterances=remote_utterances,
Expand Down Expand Up @@ -1016,7 +1018,7 @@ def get_utterance_diff(self, include_metadata: bool = False) -> List[str]:
"route2: utterance4", which do not exist locally.
"""
# first we get remote and local utterances
remote_utterances = self.index.get_utterances()
remote_utterances = self.index.get_utterances(include_metadata=include_metadata)
local_utterances = self.to_config().to_utterances()

diff_obj = UtteranceDiff.from_utterances(
Expand Down Expand Up @@ -1046,7 +1048,9 @@ async def aget_utterance_diff(self, include_metadata: bool = False) -> List[str]
"route2: utterance4", which do not exist locally.
"""
# first we get remote and local utterances
remote_utterances = await self.index.aget_utterances()
remote_utterances = await self.index.aget_utterances(
include_metadata=include_metadata
)
local_utterances = self.to_config().to_utterances()

diff_obj = UtteranceDiff.from_utterances(
Expand Down Expand Up @@ -1318,7 +1322,7 @@ def fit(
# Switch to a local index for fitting
from semantic_router.index.local import LocalIndex

remote_routes = self.index.get_utterances()
remote_routes = self.index.get_utterances(include_metadata=True)
# TODO Enhance by retrieving directly the vectors instead of embedding all utterances again
routes, utterances, function_schemas, metadata = map(
list, zip(*remote_routes)
Expand Down
2 changes: 1 addition & 1 deletion semantic_router/routers/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def fit(
# Switch to a local index for fitting
from semantic_router.index.hybrid_local import HybridLocalIndex

remote_routes = self.index.get_utterances()
remote_routes = self.index.get_utterances(include_metadata=True)
# TODO Enhance by retrieving directly the vectors instead of embedding all utterances again
routes, utterances, function_schemas, metadata = map(
list, zip(*remote_routes)
Expand Down
Loading

0 comments on commit c63c1ac

Please sign in to comment.