diff --git a/src/Dataset/multi_dataset_test.py b/src/Dataset/multi_dataset_test.py index 6b9b1bb..c2e68df 100644 --- a/src/Dataset/multi_dataset_test.py +++ b/src/Dataset/multi_dataset_test.py @@ -154,7 +154,7 @@ def __init__(self, text_tokenizer, test_split = 'close', max_seq = 2048, max_img # self.data_whole_2D = self.data_whole_2D + [{'medpix_qa_dataset':i} for i in range(len(medpix_qa_dataset))] # print('medpix_qa_dataset loaded') - pmcvqa_dataset = PMCVQA_Dataset(csv_path = '/gpfs/home/cs/leijiayu/wuchaoyi/wangyingjie/src/New_Dataset/data_csv/pmcvqa_test.csv') + pmcvqa_dataset = VQA_Dataset(csv_path = '/gpfs/home/cs/leijiayu/wuchaoyi/wangyingjie/src/New_Dataset/data_csv/pmcvqa_test.csv') self.dataset_reflect['pmcvqa_dataset'] = pmcvqa_dataset self.data_whole_2D = self.data_whole_2D + [{'pmcvqa_dataset':i} for i in range(len(pmcvqa_dataset))] print('pmcvqa_dataset loaded') @@ -165,12 +165,12 @@ def __init__(self, text_tokenizer, test_split = 'close', max_seq = 2048, max_img self.data_whole_2D = self.data_whole_2D + [{'casereport_dataset':i} for i in range(len(casereport_dataset))] print('casereport_dataset loaded') - vqarad_dataset = PMCVQA_Dataset(csv_path = '/gpfs/home/cs/leijiayu/wuchaoyi/wangyingjie/src/New_Dataset/data_csv/vqarad_test.csv') + vqarad_dataset = VQA_Dataset(csv_path = '/gpfs/home/cs/leijiayu/wuchaoyi/wangyingjie/src/New_Dataset/data_csv/vqarad_test.csv') self.dataset_reflect['vqarad_dataset'] = vqarad_dataset self.data_whole_2D = self.data_whole_2D + [{'vqarad_dataset':i} for i in range(len(vqarad_dataset))] print('vqarad_dataset loaded') - slake_dataset = PMCVQA_Dataset(csv_path = '/gpfs/home/cs/leijiayu/wuchaoyi/wangyingjie/src/New_Dataset/data_csv/slakevqa_test.csv') + slake_dataset = VQA_Dataset(csv_path = '/gpfs/home/cs/leijiayu/wuchaoyi/wangyingjie/src/New_Dataset/data_csv/slakevqa_test.csv') self.dataset_reflect['slake_dataset'] = slake_dataset self.data_whole_2D = self.data_whole_2D + [{'slake_dataset':i} for i in range(len(slake_dataset))] print('slake_dataset loaded') @@ -313,4 +313,4 @@ def text_add_image(self,images,question,answer): # print(len(dataset)) # for i in range(10): # dataset[i] -# input() \ No newline at end of file +# input()