Skip to content

Commit

Permalink
fixed bug in train.py when using longer training examples than normal [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
dscripka committed Feb 23, 2024
1 parent fe57deb commit c40fe92
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions openwakeword/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,13 +813,13 @@ def convert_onnx_to_tflite(onnx_model_path, output_path):
# Create openwakeword model
if args.train_model is True:
F = openwakeword.utils.AudioFeatures(device='cpu')
input_shape = F.get_embedding_shape(config["total_length"]//16000) # training data is always 16 khz
input_shape = np.load(os.path.join(feature_save_dir, "positive_features_test.npy")).shape[1:]

oww = Model(n_classes=1, input_shape=input_shape, model_type=config["model_type"],
layer_dim=config["layer_size"], seconds_per_example=1280*input_shape[0]/16000)

# Create data transform function for batch generation to handle differ clip lengths (todo: write tests for this)
def f(x, n=16):
def f(x, n=input_shape[0]):
"""Simple transformation function to ensure negative data is the appropriate shape for the model size"""
if n > x.shape[1] or n < x.shape[1]:
x = np.vstack(x)
Expand Down

0 comments on commit c40fe92

Please sign in to comment.