From 66e9f12c78eb98ca02e0ecdcc5ca003bf5af8f58 Mon Sep 17 00:00:00 2001 From: miro Date: Thu, 25 Jul 2024 00:55:39 +0100 Subject: [PATCH] docstrs --- ovos_plugin_manager/templates/embeddings.py | 98 +++++++++++---------- 1 file changed, 51 insertions(+), 47 deletions(-) diff --git a/ovos_plugin_manager/templates/embeddings.py b/ovos_plugin_manager/templates/embeddings.py index 407ea6cb..4014f32b 100644 --- a/ovos_plugin_manager/templates/embeddings.py +++ b/ovos_plugin_manager/templates/embeddings.py @@ -11,21 +11,22 @@ class EmbeddingsDB: - """Base plugin for embeddings database""" + """Base class for an embeddings database that supports storage, retrieval, and querying of embeddings.""" @abc.abstractmethod def add_embeddings(self, key: str, embedding: EmbeddingsArray, - metadata: Optional[Dict] = None) -> EmbeddingsArray: + metadata: Optional[Dict[str, any]] = None) -> EmbeddingsArray: """Store 'embedding' under 'key' with associated metadata. Args: key (str): The unique key for the embedding. embedding (np.ndarray): The embedding vector to store. + metadata (Optional[Dict[str, any]]): Optional metadata associated with the embedding. Returns: np.ndarray: The stored embedding. """ - return NotImplemented + raise NotImplementedError @abc.abstractmethod def get_embeddings(self, key: str) -> EmbeddingsArray: @@ -37,7 +38,7 @@ def get_embeddings(self, key: str) -> EmbeddingsArray: Returns: np.ndarray: The retrieved embedding. """ - return NotImplemented + raise NotImplementedError @abc.abstractmethod def delete_embeddings(self, key: str) -> EmbeddingsArray: @@ -49,23 +50,22 @@ def delete_embeddings(self, key: str) -> EmbeddingsArray: Returns: np.ndarray: The deleted embedding. """ - return NotImplemented + raise NotImplementedError @abc.abstractmethod def query(self, embeddings: EmbeddingsArray, top_k: int = 5, return_metadata: bool = False) -> List[EmbeddingsTuple]: - """Return top_k embeddings closest to the given 'embeddings'. + """Return the top_k embeddings closest to the given 'embeddings'. Args: embeddings (np.ndarray): The embedding vector to query. top_k (int, optional): The number of top results to return. Defaults to 5. + return_metadata (bool, optional): Whether to include metadata in the results. Defaults to False. Returns: - List[Tuple[str, float]]: List of tuples containing the key and distance. - if return_metadata is True - List[Tuple[str, float, dict]]: List of tuples containing the key, distance and metadata. + List[EmbeddingsTuple]: List of tuples containing the key and distance, and optionally metadata. """ - return NotImplemented + raise NotImplementedError def distance(self, embeddings_a: EmbeddingsArray, embeddings_b: EmbeddingsArray, metric: str = "cosine", alpha: float = 0.5, # for alpha_divergence and tversky metrics @@ -316,7 +316,7 @@ def distance(self, embeddings_a: EmbeddingsArray, embeddings_b: EmbeddingsArray, class TextEmbeddingsStore: - """A store for text embeddings interfacing with the embeddings database""" + """A store for text embeddings interfacing with the embeddings database.""" def __init__(self, db: EmbeddingsDB): """Initialize the text embeddings store. @@ -336,16 +336,17 @@ def get_text_embeddings(self, text: str) -> EmbeddingsArray: Returns: np.ndarray: The resulting embeddings. """ - return NotImplemented + raise NotImplementedError - def add_document(self, document: str) -> None: + def add_document(self, document: str, metadata: Optional[Dict[str, any]] = None) -> None: """Add a document and its embeddings to the database. Args: document (str): The document to add. + metadata (Optional[Dict[str, any]]): Optional metadata associated with the document. """ embeddings = self.get_text_embeddings(document) - self.db.add_embeddings(document, embeddings) + self.db.add_embeddings(document, embeddings, metadata) def delete_document(self, document: str) -> None: """Delete a document and its embeddings from the database. @@ -379,13 +380,13 @@ def distance(self, text_a: str, text_b: str, metric: str = "cosine") -> float: Returns: float: The calculated distance. """ - emb: EmbeddingsArray = self.get_text_embeddings(text_a) - emb2: EmbeddingsArray = self.get_text_embeddings(text_b) - return self.db.distance(emb, emb2, metric) + emb_a = self.get_text_embeddings(text_a) + emb_b = self.get_text_embeddings(text_b) + return self.db.distance(emb_a, emb_b, metric) class FaceEmbeddingsStore: - """A store for face embeddings interfacing with the embeddings database""" + """A store for face embeddings interfacing with the embeddings database.""" def __init__(self, db: EmbeddingsDB): """Initialize the face embeddings store. @@ -405,22 +406,23 @@ def get_face_embeddings(self, frame: EmbeddingsArray) -> EmbeddingsArray: Returns: np.ndarray: The resulting face embeddings. """ - return NotImplemented + raise NotImplementedError - def add_face(self, user_id: str, frame: EmbeddingsArray): + def add_face(self, user_id: str, frame: EmbeddingsArray, metadata: Optional[Dict[str, any]] = None) -> EmbeddingsArray: """Add a face and its embeddings to the database. Args: user_id (str): The unique user ID. frame (np.ndarray): The image frame containing the face. + metadata (Optional[Dict[str, any]]): Optional metadata associated with the face. Returns: np.ndarray: The stored face embeddings. """ - emb: EmbeddingsArray = self.get_face_embeddings(frame) - return self.db.add_embeddings(user_id, emb) + embeddings = self.get_face_embeddings(frame) + return self.db.add_embeddings(user_id, embeddings, metadata) - def delete_face(self, user_id: str): + def delete_face(self, user_id: str) -> EmbeddingsArray: """Delete a face and its embeddings from the database. Args: @@ -445,10 +447,10 @@ def predict(self, frame: EmbeddingsArray, top_k: int = 3, thresh: float = 0.15) matches = self.query(frame, top_k) if not matches: return None - best = min(matches, key=lambda k: k[1]) - if best[1] > thresh: + best_match = min(matches, key=lambda k: k[1]) + if best_match[1] > thresh: return None - return best[0] + return best_match[0] def query(self, frame: EmbeddingsArray, top_k: int = 5) -> List[Tuple[str, float]]: """Query the database for the top_k closest face embeddings to the frame. @@ -460,8 +462,8 @@ def query(self, frame: EmbeddingsArray, top_k: int = 5) -> List[Tuple[str, float Returns: List[Tuple[str, float]]: List of tuples containing the user ID and distance. """ - emb = self.get_face_embeddings(frame) - return self.db.query(emb, top_k) + embeddings = self.get_face_embeddings(frame) + return self.db.query(embeddings, top_k) def distance(self, face_a: EmbeddingsArray, face_b: EmbeddingsArray, metric: str = "cosine") -> float: """Calculate the distance between embeddings of two faces. @@ -474,13 +476,13 @@ def distance(self, face_a: EmbeddingsArray, face_b: EmbeddingsArray, metric: str Returns: float: The calculated distance. """ - emb: EmbeddingsArray = self.get_face_embeddings(face_a) - emb2: EmbeddingsArray = self.get_face_embeddings(face_b) - return self.db.distance(emb, emb2, metric) + emb_a = self.get_face_embeddings(face_a) + emb_b = self.get_face_embeddings(face_b) + return self.db.distance(emb_a, emb_b, metric) class VoiceEmbeddingsStore: - """A store for voice embeddings interfacing with the embeddings database""" + """A store for voice embeddings interfacing with the embeddings database.""" def __init__(self, db: EmbeddingsDB): """Initialize the voice embeddings store. @@ -504,8 +506,7 @@ def audiochunk2array(audio_bytes: bytes) -> EmbeddingsArray: audio_as_np_float32 = audio_as_np_int16.astype(np.float32) # Normalise float32 array so that values are between -1.0 and +1.0 max_int16 = 2 ** 15 - data = audio_as_np_float32 / max_int16 - return data + return audio_as_np_float32 / max_int16 @abc.abstractmethod def get_voice_embeddings(self, audio_data: EmbeddingsArray) -> EmbeddingsArray: @@ -517,22 +518,23 @@ def get_voice_embeddings(self, audio_data: EmbeddingsArray) -> EmbeddingsArray: Returns: np.ndarray: The resulting voice embeddings. """ - return NotImplemented + raise NotImplementedError - def add_voice(self, user_id: str, audio_data: EmbeddingsArray): + def add_voice(self, user_id: str, audio_data: EmbeddingsArray, metadata: Optional[Dict[str, any]] = None) -> EmbeddingsArray: """Add a voice and its embeddings to the database. Args: user_id (str): The unique user ID. audio_data (np.ndarray): The input audio data. + metadata (Optional[Dict[str, any]]): Optional metadata associated with the voice. Returns: np.ndarray: The stored voice embeddings. """ - emb: EmbeddingsArray = self.get_voice_embeddings(audio_data) - return self.db.add_embeddings(user_id, emb) + embeddings = self.get_voice_embeddings(audio_data) + return self.db.add_embeddings(user_id, embeddings, metadata) - def delete_voice(self, user_id: str): + def delete_voice(self, user_id: str) -> EmbeddingsArray: """Delete a voice and its embeddings from the database. Args: @@ -555,10 +557,12 @@ def predict(self, audio_data: EmbeddingsArray, top_k: int = 3, thresh: float = 0 Optional[str]: The predicted user ID or None if the best match exceeds the threshold. """ matches = self.query(audio_data, top_k) - best = min(matches, key=lambda k: k[1]) - if best[1] > thresh: + if not matches: return None - return best[0] + best_match = min(matches, key=lambda k: k[1]) + if best_match[1] > thresh: + return None + return best_match[0] def query(self, audio_data: EmbeddingsArray, top_k: int = 5) -> List[Tuple[str, float]]: """Query the database for the top_k closest voice embeddings to the audio_data. @@ -570,8 +574,8 @@ def query(self, audio_data: EmbeddingsArray, top_k: int = 5) -> List[Tuple[str, Returns: List[Tuple[str, float]]: List of tuples containing the user ID and distance. """ - emb = self.get_voice_embeddings(audio_data) - return self.db.query(emb, top_k) + embeddings = self.get_voice_embeddings(audio_data) + return self.db.query(embeddings, top_k) def distance(self, voice_a: EmbeddingsArray, voice_b: EmbeddingsArray, metric: str = "cosine") -> float: """Calculate the distance between embeddings of two voices. @@ -584,6 +588,6 @@ def distance(self, voice_a: EmbeddingsArray, voice_b: EmbeddingsArray, metric: s Returns: float: The calculated distance. """ - emb = self.get_voice_embeddings(voice_a) - emb2 = self.get_voice_embeddings(voice_b) - return self.db.distance(emb, emb2, metric) + emb_a = self.get_voice_embeddings(voice_a) + emb_b = self.get_voice_embeddings(voice_b) + return self.db.distance(emb_a, emb_b, metric)