Skip to content

Commit

Permalink
Added/fixed reset methods
Browse files Browse the repository at this point in the history
  • Loading branch information
dscripka committed Feb 11, 2024
1 parent c633844 commit 68e88c1
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
4 changes: 3 additions & 1 deletion openwakeword/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,10 @@ def get_parent_model_from_label(self, label):
return parent_model

def reset(self):
"""Reset the prediction buffer"""
"""Reset the prediction and audio feature buffers. Useful for re-initializing the model, though may not be efficient
when called too frequently."""
self.prediction_buffer = defaultdict(partial(deque, maxlen=30))
self.preprocessor.reset()

def predict(self, x: np.ndarray, patience: dict = {},
threshold: dict = {}, debounce_time: float = 0.0, timing: bool = False):
Expand Down
10 changes: 9 additions & 1 deletion openwakeword/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def tflite_embedding_predict(x):

self.embedding_model_predict = tflite_embedding_predict

# Create databuffers
# Create databuffers with empty/random data
self.raw_data_buffer: Deque = deque(maxlen=sr*10)
self.melspectrogram_buffer = np.ones((76, 32)) # n_frames x num_features
self.melspectrogram_max_len = 10*97 # 97 is the number of frames in 1 second of 16hz audio
Expand All @@ -169,6 +169,14 @@ def tflite_embedding_predict(x):
self.feature_buffer = self._get_embeddings(np.random.randint(-1000, 1000, 16000*4).astype(np.int16))
self.feature_buffer_max_len = 120 # ~10 seconds of feature buffer history

def reset(self):
"""Reset the internal buffers"""
self.raw_data_buffer.clear()
self.melspectrogram_buffer = np.ones((76, 32))
self.accumulated_samples = 0
self.raw_data_remainder = np.empty(0)
self.feature_buffer = self._get_embeddings(np.random.randint(-1000, 1000, 16000*4).astype(np.int16))

def _get_melspectrogram(self, x: Union[np.ndarray, List], melspec_transform: Callable = lambda x: x/10 + 2):
"""
Function to compute the mel-spectrogram of the provided audio samples.
Expand Down

0 comments on commit 68e88c1

Please sign in to comment.