From da2c3f75d73a13934ebc58317f29a48297d6ef84 Mon Sep 17 00:00:00 2001 From: chufangao Date: Wed, 9 Oct 2024 21:13:01 -0500 Subject: [PATCH] updated_news --- gradio/gradio_app.py | 152 +++++++++++++++++++++ gradio/test.ipynb | 270 +++++++++++++++++++++++++++++++++++++ news_headlines/get_news.py | 96 +++++++------ update_news.sh | 17 +++ 4 files changed, 484 insertions(+), 51 deletions(-) create mode 100644 gradio/gradio_app.py create mode 100644 gradio/test.ipynb create mode 100644 update_news.sh diff --git a/gradio/gradio_app.py b/gradio/gradio_app.py new file mode 100644 index 0000000..a9e218a --- /dev/null +++ b/gradio/gradio_app.py @@ -0,0 +1,152 @@ +import gradio as gr +import pandas as pd +import glob +from rapidfuzz import process, fuzz +import os +import numpy as np +import json + +def process_label_preds(df): + cols_to_drop = ['phase_x','phase_y','hint','hint.1','hint.2'] + df.drop(cols_to_drop, axis=1, inplace=True) + columns = list(df.columns) + inds_to_keep = [] + current_columns = [] + for i, col in enumerate(columns): + if col not in current_columns: + current_columns.append(col) + inds_to_keep.append(i) + df = df.iloc[:, inds_to_keep] + # convert all columns not named nct_id to int + for col in df.columns: + if col != 'nct_id': + df[col] = df[col].astype(int) + # if row == 0, replace with "Failure", -1 with "Abstain", and 1 with "Success" + df[col] = df[col].replace(0, "Failure") + df[col] = df[col].replace(-1, "Abstain") + df[col] = df[col].replace(1, "Success") + + return df + +# # we always test on supervised TOP labels +# train_df = pd.read_csv('../labeling/pre_post_2020/train_pre2020_dp.csv') +# valid_df = pd.read_csv('../labeling/pre_post_2020/valid_pre2020_dp.csv') +# test_df = pd.read_csv('../labeling/pre_post_2020/test_pre2020_dp.csv') +# all_df = pd.concat([train_df, valid_df, test_df], ignore_index=True) +all_df = pd.read_csv('../CTTI/studies.txt', sep='|') +all_df = all_df[['nct_id', 'brief_title', 'official_title', 'overall_status', 'phase', 'enrollment', 'why_stopped', 'study_type', 'start_date','completion_date']] + +phase_1_label_preds = pd.read_csv('../labeling/weak_preds_by_phase/phase1_dp.csv') +phase_2_label_preds = pd.read_csv('../labeling/weak_preds_by_phase/phase2_dp.csv') +phase_3_label_preds = pd.read_csv('../labeling/weak_preds_by_phase/phase3_dp.csv') +phase_1_label_preds = process_label_preds(phase_1_label_preds) +phase_2_label_preds = process_label_preds(phase_2_label_preds) +phase_3_label_preds = process_label_preds(phase_3_label_preds) + +all_nct_ids = pd.concat([phase_1_label_preds['nct_id'], phase_2_label_preds['nct_id'], phase_3_label_preds['nct_id']]).values +all_nct_ids = set(all_nct_ids) +all_df = all_df[all_df['nct_id'].isin(all_nct_ids)] + +print(all_df.shape) +all_brief_titles = list(all_df['brief_title'].values) + +linkage_path = "/srv/local/data/chufan2/github/CTOD/supplementary/clinical_trial_linkage/Merged_(ALL)_trial_linkage_outcome_df_FDA_updated.csv" +linkage_df = pd.read_csv(linkage_path) + +gpt_decision_path = '../supplementary/llm_prediction_on_pubmed/gpt-35-decisions/' + +def get_gpt_decisions(nct_id): + nct_id = nct_id.strip() + if os.path.exists(gpt_decision_path + nct_id + '_gpt_response.json'): + try: + with open(gpt_decision_path + nct_id + '_gpt_response.json', 'r') as f: + json_dict = json.loads(f.read()) + except json.JSONDecodeError: + return {} + return json.dumps(json_dict, indent=4) + +def get_closest_nctids(title, n=5): + title = title.strip() + if title.startswith('NCT'): + return all_df.loc[all_df['nct_id'] == 
title] + # fuzzy string match brief_title + # print(title) + closest_titles = process.extract(title, all_brief_titles, scorer=fuzz.WRatio, limit=n) + # print(closest_titles) + closest_inds = [_[2] for _ in closest_titles] # only return the titles + return all_df.iloc[closest_inds,:] + +def get_lf_preds(nct_id): + nct_id = nct_id.strip() + # nct_id = value['nct_id'] + # get the label predictions for the nct_id + # if phase 1 + if nct_id not in all_df['nct_id'].values: + return None + phase = all_df.loc[all_df['nct_id'] == nct_id, 'phase'].values[0] + if 'Phase 1' in phase: + ret = phase_1_label_preds.loc[phase_1_label_preds['nct_id'] == nct_id] + elif 'Phase 2' in phase: + ret = phase_2_label_preds.loc[phase_2_label_preds['nct_id'] == nct_id] + elif 'Phase 3' in phase: + ret = phase_3_label_preds.loc[phase_3_label_preds['nct_id'] == nct_id] + else: + ret = None + print(ret) + return ret + +def get_lf_linkages(nct_id): + nct_id = nct_id.strip() + if nct_id not in linkage_df['nctid'].values: + return None + ret = linkage_df.loc[linkage_df['nctid'] == nct_id] + return ret + +with gr.Blocks(theme=gr.themes.Soft()) as demo: + input_title = gr.Textbox(label="Trial Title Search") + output = gr.DataFrame(label="Trial Search Results", wrap=True) + # greet_btn = gr.Button("Search") + input_nctid = gr.Textbox(label="View Specific NCT ID") + output2 = gr.DataFrame(label="Weakly Supervised Label Predictions", wrap=True) + output3 = gr.DataFrame(label="Next Phase Link Prediction", wrap=True) + output4 = gr.Textbox(label="GPT Decisions",) + + gr.on( + triggers=[input_title.submit], + fn=get_closest_nctids, + inputs=input_title, + outputs=output, + ) + + gr.on( + triggers=[input_nctid.submit], + fn=get_lf_preds, + inputs=input_nctid, + outputs=output2, + ) + gr.on( + triggers=[input_nctid.submit], + fn=get_lf_linkages, + inputs=input_nctid, + outputs=output3, + ) + gr.on( + triggers=[input_nctid.submit], + fn=get_gpt_decisions, + inputs=input_nctid, + outputs=output4, + ) + + # output.select(fn=get_lf_preds, inputs=output, outputs=output2) + + # also if enter key is pressed + + # @gr.render(inputs=output) + # def show_split(text): + # if len(text) == 0: + # gr.Markdown("## No Input Provided") + # else: + # for letter in text: + # gr.Textbox(letter) + +demo.launch(share=True) diff --git a/gradio/test.ipynb b/gradio/test.ipynb new file mode 100644 index 0000000..43889a8 --- /dev/null +++ b/gradio/test.ipynb @@ -0,0 +1,270 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
(rendered HTML DataFrame preview omitted; the same data appears in the text/plain output below: 152050 rows × 5 columns)
" + ], + "text/plain": [ + " nctid outcome phase \\\n", + "0 NCT02201381 Failure Phase 3 \n", + "1 NCT02914275 Success Phase 3 \n", + "2 NCT01941875 Success Phase 3 \n", + "3 NCT00527904 Success Phase 3 \n", + "4 NCT00489827 Not sure Phase 3 \n", + "... ... ... ... \n", + "152045 NCT03972280 Failure Phase 1 \n", + "152046 NCT04991766 Failure Phase 1 \n", + "152047 NCT02299778 Failure Phase 1 \n", + "152048 NCT02348307 Success Phase 1 \n", + "152049 NCT03656705 Failure Phase 1 \n", + "\n", + " connected next phase \\\n", + "0 [] \n", + "1 [{'trial': 'NCT03965195', 'cross_encoder_score... \n", + "2 [{'trial': 'NCT03619707', 'cross_encoder_score... \n", + "3 [] \n", + "4 [] \n", + "... ... \n", + "152045 [] \n", + "152046 [] \n", + "152047 [] \n", + "152048 [{'trial': 'NCT02344862', 'cross_encoder_score... \n", + "152049 [] \n", + "\n", + " weakly connected next phase \n", + "0 [] \n", + "1 [] \n", + "2 [] \n", + "3 [{'trial': 'NCT01544114', 'cross_encoder_score... \n", + "4 [{'trial': 'NCT05122780', 'cross_encoder_score... \n", + "... ... \n", + "152045 [] \n", + "152046 [] \n", + "152047 [] \n", + "152048 [] \n", + "152049 [] \n", + "\n", + "[152050 rows x 5 columns]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "path = \"/srv/local/data/chufan2/github/CTOD/supplementary/clinical_trial_linkage/Merged_(ALL)_trial_linkage_outcome_df_FDA_updated.csv\"\n", + "df = pd.read_csv(path)\n", + "df" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 12/163992 [00:00<00:12, 13245.17it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NCT03935217_gpt_response.json\n", + "NCT01703052_gpt_response.json\n", + "NCT02313844_gpt_response.json\n", + "NCT02607982_gpt_response.json\n", + "NCT06132724_gpt_response.json\n", + "NCT05799924_gpt_response.json\n", + "NCT04942938_gpt_response.json\n", + "NCT04210973_gpt_response.json\n", + "NCT01987206_gpt_response.json\n", + "NCT00911716_gpt_response.json\n", + "NCT05230030_gpt_response.json\n", + "NCT01409304_gpt_response.json\n", + "NCT04351204_gpt_response.txt\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "import os\n", + "import numpy as np\n", + "import json\n", + "from tqdm import trange\n", + "\n", + "files = os.listdir('../supplementary/llm_prediction_on_pubmed/gpt-35-decisions/')\n", + "# # randomly select 10 files\n", + "# np.random.seed(0)\n", + "# inds = np.random.choice(len(files), 10, replace=False)\n", + "# df = {}\n", + "for i in trange(len(files)):\n", + " print(files[i])\n", + " try:\n", + " with open('../supplementary/llm_prediction_on_pubmed/gpt-35-decisions/' + files[i], 'r') as f:\n", + " json_dict = json.loads(f.read())\n", + " except json.decoder.JSONDecodeError:\n", + " # print('Error in ' + files[i])\n", + " break\n", + " # df[files[i].split('_')[0]] = json_dict\n", + " # print(json.dumps(json_dict, indent=4))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dp", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git 
a/news_headlines/get_news.py b/news_headlines/get_news.py index 716a5a5..1998ec6 100644 --- a/news_headlines/get_news.py +++ b/news_headlines/get_news.py @@ -1,21 +1,20 @@ import os -# os.environ["HF_HOME"] = "/srv/local/data/chufan2/huggingface/" import sys -import os from tqdm.auto import tqdm, trange from datetime import datetime, timedelta import time import pandas as pd import numpy as np import json +import torch import argparse -import random from transformers import pipeline from sentence_transformers import SentenceTransformer, CrossEncoder -sys.path.append('./GNews/') +# add the local GNews folder (expected to sit next to this script) to the import path +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'GNews/')) from gnews import GNews - +# quit() #convert to datetime def convert_to_datetime(date_str): """ Convert a date string to a datetime object. @@ -49,13 +48,13 @@ def get_date_at_month(start_date, month_to_add): years -= 1 return (start_date[0] + years, month, start_date[2]) -def get_related_news(keyword, start_date, log_dir, num_months=240, last_results=None): +def get_related_news(keyword, start_date, log_path, num_months=240, last_results=None): """ Get news related to a keyword for num_months. Log the news to a json file. keyword: str, industry sponsor to search for start_date: tuple of (year, month, day) - log_dir: str, directory to save the news + log_path: str, path of the JSON file to save the news to num_months: int, number of months to get news for from start_date Returns: dict, news for each month @@ -78,11 +77,11 @@ def get_related_news(keyword, start_date, log_dir, num_months=240, last_results= # Get the news results results = google_news.get_news(keyword) # random sleep to avoid getting blocked, can be adjusted, but this works for me - time.sleep(random.randint(1, 5)) # time.sleep(1) + time.sleep(np.random.randint(1, 5)) # time.sleep(1) lens = len(results) print(f'Got {lens} news for {keyword} in {google_news.start_date} to {google_news.end_date}') all_results[str((start_time, end_time))] = results - with open(log_dir+'news.json', "w") as f: + with open(log_path, "w") as f: json.dump(all_results, f) # dump the results # sorted_results= sorted(results, key=lambda x: datetime.strptime(x['published date'], "%a, %d %b %Y %H:%M:%S %Z"), reverse=True) @@ -97,8 +96,8 @@ def get_top_sponsors(sponsors, studies): Returns: pd.DataFrame, top 1000 most popular phase 3 industry sponsors """ - # sponsors = pd.read_csv(data_path + './CTTI/sponsors.txt', sep='|') - # studies = pd.read_csv(data_path + './CTTI/studies.txt', sep='|', low_memory=False) + # sponsors = pd.read_csv(args.CTTI_PATH + './CTTI/sponsors.txt', sep='|') + # studies = pd.read_csv(args.CTTI_PATH + './CTTI/studies.txt', sep='|', low_memory=False) studies['study_first_submitted_date'] = pd.to_datetime(studies['study_first_submitted_date']) sponsors = pd.merge(sponsors, studies[['nct_id', 'phase', 'study_first_submitted_date']], on='nct_id', how='left') sponsors = sponsors[sponsors['agency_class']=='INDUSTRY'] @@ -115,34 +114,43 @@ def get_top_sponsors(sponsors, studies): if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--mode', type=str, default='get_news', help='get_news, process_news, correspond_news_and_studies') + parser.add_argument('--continue_from_prev_log', type=lambda x: str(x).lower() == 'true', default=False) + parser.add_argument('--CTTI_PATH', type=str, default='./CTTI/') + 
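# model names and output paths below mirror the variables defined in update_news.sh + 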
parser.add_argument('--SENTIMENT_MODEL', type=str, default="yiyanghkust/finbert-tone") + parser.add_argument('--SENTENCE_ENCODER', type=str, default="microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext") + parser.add_argument('--SENTENCE_CROSSENCODER', type=str, default="cross-encoder/ms-marco-MiniLM-L-12-v2") + parser.add_argument('--SAVE_NEWS_LOG_PATH', type=str, default='./news_logs/') + parser.add_argument('--SAVE_NEWS_EMBEDDING_PATH', type=str, default='./news_title_embeddings.npy') + parser.add_argument('--SAVE_STUDY_TITLE_EMBEDDING_PATH', type=str, default='./studies_title2_embeddings.npy') + parser.add_argument('--SAVE_NEWS_PATH', type=str, default='./news.csv') + parser.add_argument('--SAVE_STUDY_NEWS_PATH', type=str, default='./studies_with_news.csv') args = parser.parse_args() assert args.mode in ['get_news', 'process_news', 'correspond_news_and_studies'] print(f'args.mode: {args.mode}') - data_path = './CITT/' - log_dir = './news_logs/' - continue_from_prev_log = True - sponsors = pd.read_csv(data_path + 'sponsors.txt', sep='|') - studies = pd.read_csv(data_path + 'studies.txt', sep='|', low_memory=False) + continue_from_prev_log = args.continue_from_prev_log + sponsors = pd.read_csv(args.CTTI_PATH + 'sponsors.txt', sep='|') + studies = pd.read_csv(args.CTTI_PATH + 'studies.txt', sep='|', low_memory=False) combined = get_top_sponsors(sponsors, studies) - cache_folder = '/srv/local/data/chufan2/huggingface/' - sentiment_pipe = pipeline("text-classification", model="yiyanghkust/finbert-tone", device='cuda') - encoder = SentenceTransformer('microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext', cache_folder=cache_folder) - crossencoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2", max_length=512) + device = 'cuda' if torch.cuda.is_available() else 'cpu' + sentiment_model = pipeline("text-classification", model=args.SENTIMENT_MODEL, device=device) + encoder = SentenceTransformer(args.SENTENCE_ENCODER) + crossencoder = CrossEncoder(args.SENTENCE_CROSSENCODER, max_length=512) if args.mode == 'get_news': # warning: this will take a long time (multiple weeks) # get top 1000 most popular phase 3 industry sponsors global_i = 0 for name in tqdm(sorted(combined['name'])): - # if name.lower() in os.listdir(log_dir): - if os.path.exists(os.path.join(log_dir, name.lower()+".json")): + # if name.lower() in os.listdir(args.SAVE_NEWS_LOG_PATH): + last_results = None + if os.path.exists(os.path.join(args.SAVE_NEWS_LOG_PATH, name.lower()+".json")): print(f'{name} already exists') if not continue_from_prev_log: continue else: - last_results = json.load(open(os.path.join(log_dir, name.lower()+".json"))) + last_results = json.load(open(os.path.join(args.SAVE_NEWS_LOG_PATH, name.lower()+".json"))) all_dates = [eval(d)[0][0] for d in last_results.keys()] last_year = sorted(all_dates)[-1] last_results = {k: v for k, v in last_results.items() if eval(k)[0][0] < last_year} @@ -151,23 +159,18 @@ def get_top_sponsors(sponsors, studies): min_date = combined[combined['name']==name]['study_first_submitted_date'].min() # print(date.year, date.month, date.day) start_date = (int(min_date.year), int(min_date.month), 1) - os.makedirs(log_dir+name.lower(), exist_ok=True) - + os.makedirs(args.SAVE_NEWS_LOG_PATH+name.lower(), exist_ok=True) - news = get_related_news(name, start_date, num_months=12*50, log_path=log_dir + name.lower() + '.json', - last_results=last_results) + news = get_related_news(name, start_date, num_months=12*50, log_path=args.SAVE_NEWS_LOG_PATH + name.lower() + 
'.json', last_results=last_results) global_i += 1 - # break + # ======================== Process the news data ======================== elif args.mode == 'process_news': print('Processing news data') - # studies_df = pd.read_csv('./CITT/studies.txt', sep='|', low_memory=False) - # ticker_dict_df = pd.read_csv('stock_price/ticker_dict_642.csv') - # ticker_dict = {row['name'].lower(): row['ticker'] for _, row in ticker_dict_df.iterrows()} all_company_dfs = [] - for company in sorted(os.listdir(log_dir)): + for company in sorted(os.listdir(args.SAVE_NEWS_LOG_PATH)): with open(os.path.join(company), 'rb') as f: news = json.load(f) # print(company, ticker, news) @@ -188,15 +191,9 @@ def get_top_sponsors(sponsors, studies): all_publishers.append(news[k][i]['publisher']['title']) df = pd.DataFrame({'date': all_dates, 'title': all_titles, 'description': all_descriptions, 'publisher': all_publishers}) - # if company in ticker_dict.keys(): - # ticker = ticker_dict[company] - # else: - # ticker = pd.NA - # df['ticker'] = ticker df['company'] = company all_company_dfs.append(df) all_company_dfs = pd.concat(all_company_dfs) - # all_company_dfs.to_csv('./stock_price/news_tmp.csv', index=False); quit() # update old news.csv with company names print("Processing sentiment") batch_size = 512 @@ -204,7 +201,7 @@ def get_top_sponsors(sponsors, studies): all_label_probs = [] titles = all_company_dfs['title'].tolist() for i in trange(0, len(titles), batch_size): - out = sentiment_pipe(titles[i:i+batch_size], batch_size=batch_size) + out = sentiment_model(titles[i:i+batch_size], batch_size=batch_size) all_label_preds += [o['label'] for o in out] all_label_probs += [o['score'] for o in out] all_company_dfs['sentiment'] = all_label_preds @@ -212,20 +209,17 @@ def get_top_sponsors(sponsors, studies): # process title embeddings using pubmedbert print("Processing embeddings") - encoded_titles = encoder.encode(titles, convert_to_numpy=True, device='cuda', batch_size=batch_size) - np.save('./news_title_embeddings.npy', encoded_titles) - - # encoded_studies = encoder.encode(studies['brief_title'].tolist(), convert_to_numpy=True, device='cuda', batch_size=batch_size) - # np.save('./studies_title_embeddings.npy', encoded_studies) + encoded_titles = encoder.encode(titles, convert_to_numpy=True, device=device, batch_size=batch_size) + np.save(args.SAVE_NEWS_EMBEDDING_PATH, encoded_titles) - all_company_dfs.to_csv('./news.csv', index=False) + all_company_dfs.to_csv(args.SAVE_NEWS_PATH, index=False) elif args.mode == 'correspond_news_and_studies': - news_df = pd.read_csv('./news.csv') - news_title_embedding = np.load('./news_title_embeddings.npy') + news_df = pd.read_csv(args.SAVE_NEWS_PATH) + news_title_embedding = np.load(args.SAVE_NEWS_EMBEDDING_PATH) top_sponsors = combined - interventions = pd.read_csv(data_path+'interventions.txt', sep='|') - conditions = pd.read_csv(data_path+'conditions.txt', sep='|') + interventions = pd.read_csv(args.CTTI_PATH+'interventions.txt', sep='|') + conditions = pd.read_csv(args.CTTI_PATH+'conditions.txt', sep='|') studies = studies[studies['nct_id'].isin(top_sponsors['nct_id'])] studies = studies[studies['nct_id'].isin(interventions['nct_id'])] @@ -246,7 +240,7 @@ def get_top_sponsors(sponsors, studies): studies['title2'] = studies['intervention_name'] + ' ' + studies['condition_name'] studies_title2_embedding = encoder.encode(studies['title2'], convert_to_numpy=True, device='cuda', show_progress_bar=True) - np.save('./studies_title2_embeddings.npy', studies_title2_embedding) + 
np.save(args.SAVE_STUDY_TITLE_EMBEDDING_PATH, studies_title2_embedding) # print(news_df.shape, news_title_embedding.shape, studies.shape, studies_title2_embedding.shape) # # most relevant news for each study @@ -279,4 +273,4 @@ def get_top_sponsors(sponsors, studies): studies.iloc[i, column_ind:column_ind+len(inds)] = news_df_.iloc[inds].index studies.iloc[i, column_ind+top_k:column_ind+top_k+len(news_df_)] = sims - studies.to_csv('./studies_with_news.csv', index=False) + studies.to_csv(args.SAVE_STUDY_NEWS_PATH, index=False) diff --git a/update_news.sh b/update_news.sh new file mode 100644 index 0000000..7cd7987 --- /dev/null +++ b/update_news.sh @@ -0,0 +1,17 @@ +# get_news paths + +# HF_HOME="/srv/local/data/chufan2/huggingface/" +CTTI_PATH="./CTTI/" +SENTIMENT_MODEL="yiyanghkust/finbert-tone" +SENTENCE_ENCODER="microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext" +SENTENCE_CROSSENCODER="cross-encoder/ms-marco-MiniLM-L-12-v2" +SAVE_NEWS_LOG_PATH="./news_headlines/news_logs/" +SAVE_NEWS_EMBEDDING_PATH="./news_headlines/news_title_embeddings.npy" +SAVE_STUDY_TITLE_EMBEDDING_PATH="./news_headlines/studies_title_embeddings.npy" +SAVE_NEWS_PATH="./news_headlines/news.csv" +SAVE_STUDY_NEWS_PATH="./news_headlines/studies_with_news.csv" +continue_from_prev_log=True + +python news_headlines/get_news.py --mode=get_news --continue_from_prev_log=$continue_from_prev_log --CTTI_PATH=$CTTI_PATH --SENTIMENT_MODEL=$SENTIMENT_MODEL --SENTENCE_ENCODER=$SENTENCE_ENCODER --SENTENCE_CROSSENCODER=$SENTENCE_CROSSENCODER --SAVE_NEWS_LOG_PATH=$SAVE_NEWS_LOG_PATH --SAVE_NEWS_EMBEDDING_PATH=$SAVE_NEWS_EMBEDDING_PATH --SAVE_STUDY_TITLE_EMBEDDING_PATH=$SAVE_STUDY_TITLE_EMBEDDING_PATH --SAVE_NEWS_PATH=$SAVE_NEWS_PATH --SAVE_STUDY_NEWS_PATH=$SAVE_STUDY_NEWS_PATH +# python news_headlines/get_news.py --mode=process_news --continue_from_prev_log=$continue_from_prev_log --CTTI_PATH=$CTTI_PATH --SENTIMENT_MODEL=$SENTIMENT_MODEL --SENTENCE_ENCODER=$SENTENCE_ENCODER --SENTENCE_CROSSENCODER=$SENTENCE_CROSSENCODER --SAVE_NEWS_LOG_PATH=$SAVE_NEWS_LOG_PATH --SAVE_NEWS_EMBEDDING_PATH=$SAVE_NEWS_EMBEDDING_PATH --SAVE_STUDY_TITLE_EMBEDDING_PATH=$SAVE_STUDY_TITLE_EMBEDDING_PATH --SAVE_NEWS_PATH=$SAVE_NEWS_PATH --SAVE_STUDY_NEWS_PATH=$SAVE_STUDY_NEWS_PATH +# python news_headlines/get_news.py --mode=correspond_news_and_studies --continue_from_prev_log=$continue_from_prev_log --CTTI_PATH=$CTTI_PATH --SENTIMENT_MODEL=$SENTIMENT_MODEL --SENTENCE_ENCODER=$SENTENCE_ENCODER --SENTENCE_CROSSENCODER=$SENTENCE_CROSSENCODER --SAVE_NEWS_LOG_PATH=$SAVE_NEWS_LOG_PATH --SAVE_NEWS_EMBEDDING_PATH=$SAVE_NEWS_EMBEDDING_PATH --SAVE_STUDY_TITLE_EMBEDDING_PATH=$SAVE_STUDY_TITLE_EMBEDDING_PATH --SAVE_NEWS_PATH=$SAVE_NEWS_PATH --SAVE_STUDY_NEWS_PATH=$SAVE_STUDY_NEWS_PATH \ No newline at end of file
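
A note on the fuzzy title search in gradio/gradio_app.py: get_closest_nctids relies on rapidfuzz.process.extract, which returns (matched_title, similarity_score, index) tuples, so the _[2] element is the positional index used to slice all_df. Below is a minimal, self-contained sketch of that pattern; the toy DataFrame, its rows, and the closest_trials helper are made up for illustration and only mirror the column names the app uses.

import pandas as pd
from rapidfuzz import process, fuzz

# toy stand-in for the CTTI studies table loaded in gradio_app.py
toy_df = pd.DataFrame({
    'nct_id': ['NCT00000001', 'NCT00000002', 'NCT00000003'],
    'brief_title': ['Aspirin for Migraine', 'Metformin in Type 2 Diabetes', 'Aspirin After Stroke'],
})

def closest_trials(query, n=2):
    # each match is a (title, score, index) tuple; the index points back into the choices list
    matches = process.extract(query, list(toy_df['brief_title']), scorer=fuzz.WRatio, limit=n)
    return toy_df.iloc[[m[2] for m in matches]]

print(closest_trials('aspirin migraine'))

A related note on the --continue_from_prev_log flag: argparse hands the raw command-line string to the type callable, and bool("False") is True in Python, so type=bool cannot distinguish True from False; parsing the string explicitly (as done above with a small lambda) keeps the flag from being switched on by any non-empty value.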