
Commit

Added updates to feature extraction
digbose92 committed Mar 16, 2023
1 parent 4757220 commit 76430a9
Showing 20 changed files with 772 additions and 9 deletions.
44 changes: 44 additions & 0 deletions configs/config_LSTM_social_message.yaml
@@ -0,0 +1,44 @@
data:
csv_file: '/data/digbose92/ads_complete_repo/ads_codes/SAIM-ADS/data/SAIM_ads_data_message_tone_train_test_val_clip_features.csv'

parameters:
batch_size: 16
train_shuffle: True
val_shuffle: False
epochs: 50
early_stop: 5
max_length: 333
fps: 4
base_fps: 24
num_workers: 4

device:
is_cuda: True

loss:
loss_option: 'bce_cross_entropy_loss'

optimizer:
choice: 'Adam'
lr: 1e-4
gamma: 0.5
step_size: 15
scheduler: 'step_lr'
mode: 'max'
decay: 0.001
patience: 5
factor: 0.5
verbose: True

model:
option: 'LSTM_multi_layer_social_message_model'
model_type: 'LSTM'
embedding_dim: 512
n_hidden: 128
n_layers: 2
n_classes: 2
batch_first: True

output:
model_dir: '/data/digbose92/ads_complete_repo/ads_codes/model_files/recent_models/model_dir'
log_dir: '/data/digbose92/ads_complete_repo/ads_codes/model_files/recent_models/log_dir'
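
A minimal sketch (not part of this commit) of how this config might be loaded in the training script, assuming PyYAML; the actual loading code is not shown in this diff:

import yaml

with open('configs/config_LSTM_social_message.yaml') as f:
    config = yaml.safe_load(f)

# e.g. pull the loader/dataset parameters from the parsed config
batch_size = config['parameters']['batch_size']   # 16
max_length = config['parameters']['max_length']   # 333
fps, base_fps = config['parameters']['fps'], config['parameters']['base_fps']  # 4, 24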
Binary file added datasets/__pycache__/dataset.cpython-37.pyc
Binary file not shown.
64 changes: 64 additions & 0 deletions datasets/dataset.py
@@ -75,6 +75,70 @@ def __getitem__(self,idx):
return(clip_feature_array_padded,transition_label,feat_len)


class SAIM_social_message_clip_features_dataset(Dataset):
def __init__(self,csv_data,label_map,num_classes,max_length,fps,base_fps):

self.csv_data=csv_data
self.num_classes=num_classes
self.max_length=max_length
self.fps=fps
self.base_fps=base_fps
self.division_factor=self.base_fps//self.fps # 24/4=6
self.clip_feature_list=self.csv_data['clip_feature_path'].tolist()
self.label_map=label_map

def __len__(self):
return(len(self.clip_feature_list))

def subsample_feature(self,feature_array):

#subsample the feature array by taking every self.division_factor-th frame
feature_array_subsampled=feature_array[::self.division_factor]
#print(feature_array_subsampled.shape,feature_array.shape)
return(feature_array_subsampled)

def pad_data(self,feat_data):
padded=np.zeros((self.max_length,feat_data.shape[1]))
if(feat_data.shape[0]>self.max_length):
padded=feat_data[:self.max_length,:]
else:
padded[:feat_data.shape[0],:]=feat_data
return(padded)

def __getitem__(self,idx):

#get the path of the feature file
clip_feature_file=self.clip_feature_list[idx]

#load the feature file
with open(clip_feature_file, 'rb') as f:
clip_features = pickle.load(f)

#get the features
clip_feature_array=clip_features['Features']

#subsample the features
clip_feature_array_subsampled=self.subsample_feature(clip_feature_array)

#return the length of the features
#print(clip_feature_array_subsampled.shape)
if(clip_feature_array_subsampled.shape[0]>=self.max_length):
feat_len=self.max_length
else:
feat_len=clip_feature_array_subsampled.shape[0]

#pad the features
clip_feature_array_padded=self.pad_data(clip_feature_array_subsampled)

#get the label
social_msg=self.label_map[self.csv_data['social_message'].iloc[idx]]

social_msg_label=np.zeros((self.num_classes))
social_msg_label[social_msg]=1


return(clip_feature_array_padded,social_msg_label,feat_len)

#test the dataset
# csv_file="/data/digbose92/ads_complete_repo/ads_codes/SAIM-ADS/data/SAIM_ads_data_message_tone_train_test_val_clip_features.csv"
# csv_data=pd.read_csv(csv_file)
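# Illustrative only (not part of this commit): a minimal sketch of wiring the
# dataset above to a DataLoader with the values from the LSTM social message
# config; the label_map keys here are hypothetical.
# from torch.utils.data import DataLoader
# label_map={'No':0,'Yes':1}
# dataset=SAIM_social_message_clip_features_dataset(csv_data,label_map,num_classes=2,max_length=333,fps=4,base_fps=24)
# loader=DataLoader(dataset,batch_size=16,shuffle=True,num_workers=4)
# feats,labels,feat_lens=next(iter(loader)) #feats: (16, 333, 512) for the CLIP features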
119 changes: 119 additions & 0 deletions feature_extraction/extract_ast_features.py
@@ -0,0 +1,119 @@
#use the transformers library to extract Audio Spectrogram Transformer (AST) features
from transformers import AutoProcessor, ASTModel, AutoFeatureExtractor
import torch
from datasets import load_dataset
import torchaudio
import json
from tqdm import tqdm
import os
import pickle

def generate_file_list(json_data,folder):

wav_file_names=[folder+"/"+i.split("/")[-1] for i in json_data]
return wav_file_names

#load the model
model_option="MIT/ast-finetuned-audioset-10-10-0.4593"
wav_file_list="/data/digbose92/ads_complete_repo/ads_codes/SAIM-ADS/data/jwt_ads_of_world_wav_files.json"
folder="/data/digbose92/ads_complete_repo/ads_wav_files/cvpr_wav_files"
option="cvpr_ads"
save_folder="/data/digbose92/ads_complete_repo/ads_features/audio_embeddings/ast_embeddings/cvpr_ads"
#save_folder="/data/digbose92/ads_complete_repo/ads_features/audio_embeddings/ast_embeddings/jwt_ads_of_world"

if(option=="jwt_ads_of_world"):
with open(wav_file_list) as f:
wav_file_list = json.load(f)

wav_file_names_json_data=[wav_file_list["data"][i]["wav"] for i in range(len(wav_file_list['data']))]
wav_file_names=generate_file_list(wav_file_names_json_data,folder)

elif(option=="cvpr_ads"):

wav_file_names=os.listdir(folder)
wav_file_names=[os.path.join(folder,i) for i in wav_file_names]
#print(len(wav_file_names))
#print(wav_file_names)

# wav_file="/data/digbose92/ads_complete_repo/ads_wav_files/jwt_ads_of_world_wav_files/2k_sports_never_say_never_1.wav"

#define feature extractor and model
feature_extractor = AutoFeatureExtractor.from_pretrained(model_option)
#print(feature_extractor.max_length)
model = ASTModel.from_pretrained(model_option)
device=torch.device("cuda:0")
model.to(device)
sampling_rate=16000
file_list_failure=[]

for wav_file in tqdm(wav_file_names):
try:
waveform, sampling_rate = torchaudio.load(wav_file) #read the audio using torchaudio
#print(wav_file)
waveform=waveform[0].cpu().numpy()
inputs=feature_extractor(waveform, sampling_rate=sampling_rate, return_tensors="pt") #extract features using transformers
inputs['input_values']=inputs['input_values'].to(device)
#print(inputs.keys())
with torch.no_grad():
outputs=model(**inputs)

last_hidden_state=outputs.last_hidden_state
pooler_output=outputs.pooler_output

#create dictionary to save
save_dict={'last_hidden_state':last_hidden_state.cpu().numpy(),'pooler_output':pooler_output.cpu().numpy()}

file_name_id=os.path.splitext(wav_file.split("/")[-1])[0]+".pkl"
destination_filename=os.path.join(save_folder,file_name_id)

with open(destination_filename, 'wb') as f:
pickle.dump(save_dict, f)
#dict_filename=os.path.join(save_folder,wav_file.split("/")[-1]+".npy")
except Exception:
#keep track of files that fail to load or run through the model
file_list_failure.append(wav_file)



#create the save file name

#save_file_name=wav_file.split("/")[-1].split(".")[0]+".npy"



#the output sequence length is 1214
#the 128x1024 spectrogram is split into 16x16 patches with stride 10: (128-16)//10+1=12 frequency patches and (1024-16)//10+1=101 time patches, i.e. 12*101=1212 patches
#the two extra tokens (the [CLS] token and the distillation token) bring the total to 1214
#print(last_hidden_state.shape)
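#a quick sanity check of the patch arithmetic above (illustrative only, not part of this commit):
# patch_size, stride, mel_bins, frames = 16, 10, 128, 1024
# freq_patches = (mel_bins - patch_size) // stride + 1   # 12
# time_patches = (frames - patch_size) // stride + 1     # 101
# seq_len = freq_patches * time_patches + 2              # 1212 + 2 = 1214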


# waveform, sampling_rate = torchaudio.load(wav_file)
# waveform=waveform[0].cpu().numpy()
# #model option and model
# feature_extractor = AutoFeatureExtractor.from_pretrained(model_option)
# model = ASTModel.from_pretrained(model_option)
# sampling_rate=16000





# # # #generate the datasets
# # dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
# # dataset = dataset.sort("id")
# # print(type(dataset[0]["audio"]["array"]))
# # #read the audio file
# inputs = feature_extractor(waveform, sampling_rate=sampling_rate, return_tensors="pt")


# #inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")

# #generate outputs
# with torch.no_grad():
# outputs=model(**inputs)

# print(outputs.keys())
# # #last hidden state
# last_hidden_state=outputs.last_hidden_state
# print(last_hidden_state.shape)

34 changes: 34 additions & 0 deletions feature_extraction/torchaudio_examples_check.py
@@ -0,0 +1,34 @@
import torchaudio.compliance.kaldi as ta_kaldi
import torchaudio
import torch

file="/data/digbose92/ads_complete_repo/ads_wav_files/jwt_ads_of_world_wav_files/459528.wav"
waveform,sampling_rate=torchaudio.load(file)
waveform=waveform[0].cpu().numpy()

waveform = torch.from_numpy(waveform).unsqueeze(0)
num_mel_bins=128
max_length=1024
fbank = ta_kaldi.fbank(
waveform,
htk_compat=True,
sample_frequency=sampling_rate,
use_energy=False,
window_type="hanning",
num_mel_bins=num_mel_bins,
dither=0.0,
frame_shift=10
)

n_frames = fbank.shape[0]
difference = max_length - n_frames

# pad or truncate, depending on difference
if difference > 0:
pad_module = torch.nn.ZeroPad2d((0, 0, 0, difference))
fbank = pad_module(fbank)
elif difference < 0:
fbank = fbank[0:max_length, :]
fbank = fbank.numpy()

print(fbank.shape)
Binary file added figures/LSTM_social_message_confusion_matrix.png
Binary file added figures/LSTM_tone_transition_confusion_matrix.png
Binary file added losses/__pycache__/loss_functions.cpython-37.pyc
Binary file not shown.
Binary file added models/__pycache__/LSTM_models.cpython-37.pyc
Binary file not shown.
Binary file added optimizers/__pycache__/optimizer.cpython-37.pyc
Binary file not shown.
3 changes: 2 additions & 1 deletion requirements.txt
@@ -8,4 +8,5 @@ Pillow
ipython
jupyter
bertopic
importlib-resources
importlib-resources
transformers
Binary file added scripts/__pycache__/evaluate_model.cpython-37.pyc
Binary file not shown.
Binary file modified scripts/__pycache__/evaluate_model.cpython-38.pyc
Binary file not shown.
66 changes: 62 additions & 4 deletions scripts/evaluate_model.py
@@ -3,10 +3,11 @@
import torch
from statistics import mean
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix
from scipy.stats.stats import pearsonr
import sys
import os
from collections import Counter

def sort_batch(X, y, lengths):
lengths, indx = lengths.sort(dim=0, descending=True)
@@ -46,12 +47,69 @@ def gen_validate_score_LSTM_tone_transition_model(model,loader,device,criterion)
pred_label_val=torch.cat(pred_labels).numpy()
#print(target_label_val.shape,pred_label_val.shape)
pred_labels_discrete=np.where(pred_label_val>=0.5,1,0)

pred_labels_array=np.argmax(pred_labels_discrete,axis=1)
target_labels_array=np.argmax(target_label_val,axis=1)

#print(len(target_labels),len(pred_labels_discrete))
val_acc=accuracy_score(target_label_val,pred_labels_discrete)
val_f1=f1_score(target_label_val,pred_labels_discrete,average='macro')
val_acc=accuracy_score(target_labels_array,pred_labels_array)
val_f1=f1_score(target_labels_array,pred_labels_array,average='macro')

#classification_rep=classification_report(target_label_val,pred_labels_discrete)

cm=confusion_matrix(list(target_labels_array),list(pred_labels_array),labels=[0,1])


return(mean(val_loss_list),val_acc,val_f1,cm)


#same as the previous function... need to merge later
def gen_validate_score_LSTM_social_message_model(model,loader,device,criterion):

print("starting validation")
Sig = nn.Sigmoid()
model.eval()
target_labels=[]
pred_labels=[]
step=0
val_loss_list=[]

with torch.no_grad():
for i, (vid_feat,label,lens) in enumerate(tqdm(loader)):

vid_feat=vid_feat.float()
label=label.float()
vid_feat=vid_feat.to(device)
label=label.to(device)

vid_feat,label,lens = sort_batch(vid_feat,label,lens)
logits=model(vid_feat,lens.cpu().numpy())
logits_sig=Sig(logits)

loss=criterion(logits,label)
val_loss_list.append(loss.item())
target_labels.append(label.cpu())
pred_labels.append(logits_sig.cpu())
step=step+1

target_label_val=torch.cat(target_labels).numpy()
pred_label_val=torch.cat(pred_labels).numpy()
#print(target_label_val.shape,pred_label_val.shape)
pred_labels_discrete=np.where(pred_label_val>=0.5,1,0)

return(mean(val_loss_list),val_acc,val_f1)
#convert the one-hot predictions to class indices using argmax
pred_labels_array=np.argmax(pred_labels_discrete,axis=1)
target_labels_array=np.argmax(target_label_val,axis=1)

#target_label_discrete=np.where(target_label_val>=0.5,1,0)
#print(target_labels_array.shape,pred_labels_array.shape)
#print(len(target_labels),len(pred_labels_discrete))
val_acc=accuracy_score(target_labels_array,pred_labels_array)
val_f1=f1_score(target_labels_array,pred_labels_array,average='macro')

#classification_rep=classification_report(target_label_val,pred_labels_discrete)
cm=confusion_matrix(list(target_labels_array),list(pred_labels_array),labels=[0,1])
print(Counter(list(pred_labels_array)))

return(mean(val_loss_list),val_acc,val_f1,cm)
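
# Illustrative only (not part of this commit): one way the returned cm could be
# rendered into a figure such as figures/LSTM_social_message_confusion_matrix.png,
# assuming matplotlib and scikit-learn; the display labels here are hypothetical.
# import matplotlib.pyplot as plt
# from sklearn.metrics import ConfusionMatrixDisplay
# disp=ConfusionMatrixDisplay(confusion_matrix=cm,display_labels=['no_social_message','social_message'])
# disp.plot()
# plt.savefig('figures/LSTM_social_message_confusion_matrix.png')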

