Skip to content

Commit

Permalink
docstrs
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl committed Jul 24, 2024
1 parent 599b36b commit 66e9f12
Showing 1 changed file with 51 additions and 47 deletions.
98 changes: 51 additions & 47 deletions ovos_plugin_manager/templates/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,22 @@


class EmbeddingsDB:
"""Base plugin for embeddings database"""
"""Base class for an embeddings database that supports storage, retrieval, and querying of embeddings."""

@abc.abstractmethod
def add_embeddings(self, key: str, embedding: EmbeddingsArray,
metadata: Optional[Dict] = None) -> EmbeddingsArray:
metadata: Optional[Dict[str, any]] = None) -> EmbeddingsArray:
"""Store 'embedding' under 'key' with associated metadata.
Args:
key (str): The unique key for the embedding.
embedding (np.ndarray): The embedding vector to store.
metadata (Optional[Dict[str, any]]): Optional metadata associated with the embedding.
Returns:
np.ndarray: The stored embedding.
"""
return NotImplemented
raise NotImplementedError

@abc.abstractmethod
def get_embeddings(self, key: str) -> EmbeddingsArray:
Expand All @@ -37,7 +38,7 @@ def get_embeddings(self, key: str) -> EmbeddingsArray:
Returns:
np.ndarray: The retrieved embedding.
"""
return NotImplemented
raise NotImplementedError

@abc.abstractmethod
def delete_embeddings(self, key: str) -> EmbeddingsArray:
Expand All @@ -49,23 +50,22 @@ def delete_embeddings(self, key: str) -> EmbeddingsArray:
Returns:
np.ndarray: The deleted embedding.
"""
return NotImplemented
raise NotImplementedError

@abc.abstractmethod
def query(self, embeddings: EmbeddingsArray, top_k: int = 5,
return_metadata: bool = False) -> List[EmbeddingsTuple]:
"""Return top_k embeddings closest to the given 'embeddings'.
"""Return the top_k embeddings closest to the given 'embeddings'.
Args:
embeddings (np.ndarray): The embedding vector to query.
top_k (int, optional): The number of top results to return. Defaults to 5.
return_metadata (bool, optional): Whether to include metadata in the results. Defaults to False.
Returns:
List[Tuple[str, float]]: List of tuples containing the key and distance.
if return_metadata is True
List[Tuple[str, float, dict]]: List of tuples containing the key, distance and metadata.
List[EmbeddingsTuple]: List of tuples containing the key and distance, and optionally metadata.
"""
return NotImplemented
raise NotImplementedError

def distance(self, embeddings_a: EmbeddingsArray, embeddings_b: EmbeddingsArray, metric: str = "cosine",
alpha: float = 0.5, # for alpha_divergence and tversky metrics
Expand Down Expand Up @@ -316,7 +316,7 @@ def distance(self, embeddings_a: EmbeddingsArray, embeddings_b: EmbeddingsArray,


class TextEmbeddingsStore:
"""A store for text embeddings interfacing with the embeddings database"""
"""A store for text embeddings interfacing with the embeddings database."""

def __init__(self, db: EmbeddingsDB):
"""Initialize the text embeddings store.
Expand All @@ -336,16 +336,17 @@ def get_text_embeddings(self, text: str) -> EmbeddingsArray:
Returns:
np.ndarray: The resulting embeddings.
"""
return NotImplemented
raise NotImplementedError

def add_document(self, document: str) -> None:
def add_document(self, document: str, metadata: Optional[Dict[str, any]] = None) -> None:
"""Add a document and its embeddings to the database.
Args:
document (str): The document to add.
metadata (Optional[Dict[str, any]]): Optional metadata associated with the document.
"""
embeddings = self.get_text_embeddings(document)
self.db.add_embeddings(document, embeddings)
self.db.add_embeddings(document, embeddings, metadata)

def delete_document(self, document: str) -> None:
"""Delete a document and its embeddings from the database.
Expand Down Expand Up @@ -379,13 +380,13 @@ def distance(self, text_a: str, text_b: str, metric: str = "cosine") -> float:
Returns:
float: The calculated distance.
"""
emb: EmbeddingsArray = self.get_text_embeddings(text_a)
emb2: EmbeddingsArray = self.get_text_embeddings(text_b)
return self.db.distance(emb, emb2, metric)
emb_a = self.get_text_embeddings(text_a)
emb_b = self.get_text_embeddings(text_b)
return self.db.distance(emb_a, emb_b, metric)


class FaceEmbeddingsStore:
"""A store for face embeddings interfacing with the embeddings database"""
"""A store for face embeddings interfacing with the embeddings database."""

def __init__(self, db: EmbeddingsDB):
"""Initialize the face embeddings store.
Expand All @@ -405,22 +406,23 @@ def get_face_embeddings(self, frame: EmbeddingsArray) -> EmbeddingsArray:
Returns:
np.ndarray: The resulting face embeddings.
"""
return NotImplemented
raise NotImplementedError

def add_face(self, user_id: str, frame: EmbeddingsArray):
def add_face(self, user_id: str, frame: EmbeddingsArray, metadata: Optional[Dict[str, any]] = None) -> EmbeddingsArray:
"""Add a face and its embeddings to the database.
Args:
user_id (str): The unique user ID.
frame (np.ndarray): The image frame containing the face.
metadata (Optional[Dict[str, any]]): Optional metadata associated with the face.
Returns:
np.ndarray: The stored face embeddings.
"""
emb: EmbeddingsArray = self.get_face_embeddings(frame)
return self.db.add_embeddings(user_id, emb)
embeddings = self.get_face_embeddings(frame)
return self.db.add_embeddings(user_id, embeddings, metadata)

def delete_face(self, user_id: str):
def delete_face(self, user_id: str) -> EmbeddingsArray:
"""Delete a face and its embeddings from the database.
Args:
Expand All @@ -445,10 +447,10 @@ def predict(self, frame: EmbeddingsArray, top_k: int = 3, thresh: float = 0.15)
matches = self.query(frame, top_k)
if not matches:
return None
best = min(matches, key=lambda k: k[1])
if best[1] > thresh:
best_match = min(matches, key=lambda k: k[1])
if best_match[1] > thresh:
return None
return best[0]
return best_match[0]

def query(self, frame: EmbeddingsArray, top_k: int = 5) -> List[Tuple[str, float]]:
"""Query the database for the top_k closest face embeddings to the frame.
Expand All @@ -460,8 +462,8 @@ def query(self, frame: EmbeddingsArray, top_k: int = 5) -> List[Tuple[str, float
Returns:
List[Tuple[str, float]]: List of tuples containing the user ID and distance.
"""
emb = self.get_face_embeddings(frame)
return self.db.query(emb, top_k)
embeddings = self.get_face_embeddings(frame)
return self.db.query(embeddings, top_k)

def distance(self, face_a: EmbeddingsArray, face_b: EmbeddingsArray, metric: str = "cosine") -> float:
"""Calculate the distance between embeddings of two faces.
Expand All @@ -474,13 +476,13 @@ def distance(self, face_a: EmbeddingsArray, face_b: EmbeddingsArray, metric: str
Returns:
float: The calculated distance.
"""
emb: EmbeddingsArray = self.get_face_embeddings(face_a)
emb2: EmbeddingsArray = self.get_face_embeddings(face_b)
return self.db.distance(emb, emb2, metric)
emb_a = self.get_face_embeddings(face_a)
emb_b = self.get_face_embeddings(face_b)
return self.db.distance(emb_a, emb_b, metric)


class VoiceEmbeddingsStore:
"""A store for voice embeddings interfacing with the embeddings database"""
"""A store for voice embeddings interfacing with the embeddings database."""

def __init__(self, db: EmbeddingsDB):
"""Initialize the voice embeddings store.
Expand All @@ -504,8 +506,7 @@ def audiochunk2array(audio_bytes: bytes) -> EmbeddingsArray:
audio_as_np_float32 = audio_as_np_int16.astype(np.float32)
# Normalise float32 array so that values are between -1.0 and +1.0
max_int16 = 2 ** 15
data = audio_as_np_float32 / max_int16
return data
return audio_as_np_float32 / max_int16

@abc.abstractmethod
def get_voice_embeddings(self, audio_data: EmbeddingsArray) -> EmbeddingsArray:
Expand All @@ -517,22 +518,23 @@ def get_voice_embeddings(self, audio_data: EmbeddingsArray) -> EmbeddingsArray:
Returns:
np.ndarray: The resulting voice embeddings.
"""
return NotImplemented
raise NotImplementedError

def add_voice(self, user_id: str, audio_data: EmbeddingsArray):
def add_voice(self, user_id: str, audio_data: EmbeddingsArray, metadata: Optional[Dict[str, any]] = None) -> EmbeddingsArray:
"""Add a voice and its embeddings to the database.
Args:
user_id (str): The unique user ID.
audio_data (np.ndarray): The input audio data.
metadata (Optional[Dict[str, any]]): Optional metadata associated with the voice.
Returns:
np.ndarray: The stored voice embeddings.
"""
emb: EmbeddingsArray = self.get_voice_embeddings(audio_data)
return self.db.add_embeddings(user_id, emb)
embeddings = self.get_voice_embeddings(audio_data)
return self.db.add_embeddings(user_id, embeddings, metadata)

def delete_voice(self, user_id: str):
def delete_voice(self, user_id: str) -> EmbeddingsArray:
"""Delete a voice and its embeddings from the database.
Args:
Expand All @@ -555,10 +557,12 @@ def predict(self, audio_data: EmbeddingsArray, top_k: int = 3, thresh: float = 0
Optional[str]: The predicted user ID or None if the best match exceeds the threshold.
"""
matches = self.query(audio_data, top_k)
best = min(matches, key=lambda k: k[1])
if best[1] > thresh:
if not matches:
return None
return best[0]
best_match = min(matches, key=lambda k: k[1])
if best_match[1] > thresh:
return None
return best_match[0]

def query(self, audio_data: EmbeddingsArray, top_k: int = 5) -> List[Tuple[str, float]]:
"""Query the database for the top_k closest voice embeddings to the audio_data.
Expand All @@ -570,8 +574,8 @@ def query(self, audio_data: EmbeddingsArray, top_k: int = 5) -> List[Tuple[str,
Returns:
List[Tuple[str, float]]: List of tuples containing the user ID and distance.
"""
emb = self.get_voice_embeddings(audio_data)
return self.db.query(emb, top_k)
embeddings = self.get_voice_embeddings(audio_data)
return self.db.query(embeddings, top_k)

def distance(self, voice_a: EmbeddingsArray, voice_b: EmbeddingsArray, metric: str = "cosine") -> float:
"""Calculate the distance between embeddings of two voices.
Expand All @@ -584,6 +588,6 @@ def distance(self, voice_a: EmbeddingsArray, voice_b: EmbeddingsArray, metric: s
Returns:
float: The calculated distance.
"""
emb = self.get_voice_embeddings(voice_a)
emb2 = self.get_voice_embeddings(voice_b)
return self.db.distance(emb, emb2, metric)
emb_a = self.get_voice_embeddings(voice_a)
emb_b = self.get_voice_embeddings(voice_b)
return self.db.distance(emb_a, emb_b, metric)

0 comments on commit 66e9f12

Please sign in to comment.