Skip to content

Commit

Permalink
[refactor] Remove "list of audio paths" input type for get_embedding function (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
gudgud96 authored Apr 17, 2023
1 parent b000273 commit eba2d8e
Showing 1 changed file with 15 additions and 30 deletions.
45 changes: 15 additions & 30 deletions frechet_audio_distance/fad.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,39 +73,24 @@ def get_embeddings(self, x, sr=SAMPLE_RATE):
"""
Get embeddings using VGGish model.
Params:
-- x : Either
(i) a string which is the directory of a set of audio files, or
(ii) a list of np.ndarray audio samples
-- x : a list of np.ndarray audio samples
-- sr : Sampling rate of the audio samples in x. Default value is 16000.
"""
embd_lst = []
if isinstance(x, list):
try:
for audio in tqdm(x, disable=(not self.verbose)):
if self.model_name == "vggish":
embd = self.model.forward(audio, sr)
elif self.model_name == "pann":
with torch.no_grad():
out = self.model(torch.tensor(audio).float().unsqueeze(0), None)
embd = out['embedding'].data[0]
if self.device == torch.device('cuda'):
embd = embd.cpu()
embd = embd.detach().numpy()
embd_lst.append(embd)
except Exception as e:
print("[Frechet Audio Distance] get_embeddings throw an exception: {}".format(str(e)))
elif isinstance(x, str):
try:
for fname in tqdm(os.listdir(x), disable=(not self.verbose)):
embd = self.model.forward(os.path.join(x, fname))
if self.device == torch.device('cuda'):
embd = embd.cpu()
embd = embd.detach().numpy()
embd_lst.append(embd)
except Exception as e:
print("[Frechet Audio Distance] get_embeddings throw an exception: {}".format(str(e)))
else:
raise AttributeError
try:
for audio in tqdm(x, disable=(not self.verbose)):
if self.model_name == "vggish":
embd = self.model.forward(audio, sr)
elif self.model_name == "pann":
with torch.no_grad():
out = self.model(torch.tensor(audio).float().unsqueeze(0), None)
embd = out['embedding'].data[0]
if self.device == torch.device('cuda'):
embd = embd.cpu()
embd = embd.detach().numpy()
embd_lst.append(embd)
except Exception as e:
print("[Frechet Audio Distance] get_embeddings throw an exception: {}".format(str(e)))

return np.concatenate(embd_lst, axis=0)

Expand Down

0 comments on commit eba2d8e

Please sign in to comment.