Skip to content

Commit

Permalink
Update benchmark.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yuh-zha authored Jun 20, 2023
1 parent 819fdd0 commit 7e18faa
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,36 @@ def eval_align_nlg(ckpt_path, comment='', base_model='roberta-large', batch_size
evaluator.evaluate()
timer.finish(name)

def eval_gptscore(api_key, gpt_model='davinci003', tasks=ALL_TASKS):
    """Run the evaluation tasks with a GPTScore scorer as the alignment function.

    Args:
        api_key: Forwarded to ``GPTScoreScorer`` as ``api_key``.
        gpt_model: Forwarded to ``GPTScoreScorer`` as ``gpt_model``; also
            used as the suffix of the result save name.
        tasks: Forwarded to ``Evaluator`` as ``eval_tasks``.
    """
    scorer_obj = GPTScoreScorer(api_key=api_key, gpt_model=gpt_model)
    runner = Evaluator(
        eval_tasks=tasks,
        align_func=scorer_obj.scorer,
        is_save_all_tables=IS_SAVE_ALL_TABLES,
    )
    # The save name is set on the evaluator before evaluate() is invoked.
    save_name = f"nlg_eval_fact/baselines/GPTScore-{gpt_model}"
    runner.result_save_name = save_name
    runner.evaluate()

def eval_chatgptluo2023(api_key, chat_model='gpt-3.5-turbo', tasks=None):
    """Run evaluation with a ``ChatGPTLuo2023Scorer`` as the alignment function.

    Args:
        api_key: Forwarded to ``ChatGPTLuo2023Scorer`` as ``api_key``.
        chat_model: Forwarded to the scorer as ``chat_model``; also used as
            the suffix of the result save name.
        tasks: Passed to the scorer as ``task`` and to ``Evaluator`` as
            ``eval_tasks``. ``None`` (the default) means ``['qags_cnndm']``.
    """
    # A list literal in the signature is created once and shared by every
    # call, and it is handed to both the scorer and the evaluator. Building
    # it per call keeps any downstream mutation from leaking into later calls.
    if tasks is None:
        tasks = ['qags_cnndm']
    chatgpt = ChatGPTLuo2023Scorer(task=tasks, api_key=api_key, chat_model=chat_model)
    evaluator = Evaluator(
        eval_tasks=tasks,
        align_func=chatgpt.scorer,
        is_save_all_tables=IS_SAVE_ALL_TABLES,
    )
    evaluator.result_save_name = f"nlg_eval_fact/baselines/ChatGPTLuo2023-{chat_model}"
    evaluator.evaluate()

def eval_chatgptgao2023(api_key, chat_model='gpt-3.5-turbo', tasks=None):
    """Run evaluation with a ``ChatGPTGao2023Scorer`` as the alignment function.

    Args:
        api_key: Forwarded to ``ChatGPTGao2023Scorer`` as ``api_key``.
        chat_model: Forwarded to the scorer as ``chat_model``; also used as
            the suffix of the result save name.
        tasks: Passed to the scorer as ``task`` and to ``Evaluator`` as
            ``eval_tasks``. ``None`` (the default) means ``['qags_cnndm']``.
    """
    # A list literal in the signature is created once and shared by every
    # call, and it is handed to both the scorer and the evaluator. Building
    # it per call keeps any downstream mutation from leaking into later calls.
    if tasks is None:
        tasks = ['qags_cnndm']
    chatgpt = ChatGPTGao2023Scorer(task=tasks, api_key=api_key, chat_model=chat_model)
    evaluator = Evaluator(
        eval_tasks=tasks,
        align_func=chatgpt.scorer,
        is_save_all_tables=IS_SAVE_ALL_TABLES,
    )
    evaluator.result_save_name = f"nlg_eval_fact/baselines/ChatGPTGao2023-{chat_model}"
    evaluator.evaluate()

def eval_chatgptyichen2023(api_key, chat_model='gpt-3.5-turbo', tasks=None):
    """Run evaluation with a ``ChatGPTYiChen2023Scorer`` as the alignment function.

    Args:
        api_key: Forwarded to ``ChatGPTYiChen2023Scorer`` as ``api_key``.
        chat_model: Forwarded to the scorer as ``chat_model``; also used as
            the suffix of the result save name.
        tasks: Passed to the scorer as ``task`` and to ``Evaluator`` as
            ``eval_tasks``. ``None`` (the default) means ``['qags_cnndm']``.
    """
    # A list literal in the signature is created once and shared by every
    # call, and it is handed to both the scorer and the evaluator. Building
    # it per call keeps any downstream mutation from leaking into later calls.
    if tasks is None:
        tasks = ['qags_cnndm']
    chatgpt = ChatGPTYiChen2023Scorer(task=tasks, api_key=api_key, chat_model=chat_model)
    evaluator = Evaluator(
        eval_tasks=tasks,
        align_func=chatgpt.scorer,
        is_save_all_tables=IS_SAVE_ALL_TABLES,
    )
    evaluator.result_save_name = f"nlg_eval_fact/baselines/ChatGPTYiChen2023-{chat_model}"
    evaluator.evaluate()

def eval_chatgptshiqichen2023(api_key, chat_model='gpt-3.5-turbo', tasks=None):
    """Run evaluation with a ``ChatGPTShiqiChen2023Scorer`` as the alignment function.

    Args:
        api_key: Forwarded to ``ChatGPTShiqiChen2023Scorer`` as ``api_key``.
        chat_model: Forwarded to the scorer as ``chat_model``; also used as
            the suffix of the result save name.
        tasks: Passed to the scorer as ``task`` and to ``Evaluator`` as
            ``eval_tasks``. ``None`` (the default) means ``['qags_cnndm']``.
    """
    # A list literal in the signature is created once and shared by every
    # call, and it is handed to both the scorer and the evaluator. Building
    # it per call keeps any downstream mutation from leaking into later calls.
    if tasks is None:
        tasks = ['qags_cnndm']
    chatgpt = ChatGPTShiqiChen2023Scorer(task=tasks, api_key=api_key, chat_model=chat_model)
    evaluator = Evaluator(
        eval_tasks=tasks,
        align_func=chatgpt.scorer,
        is_save_all_tables=IS_SAVE_ALL_TABLES,
    )
    evaluator.result_save_name = f"nlg_eval_fact/baselines/ChatGPTShiqiChen2023-{chat_model}"
    evaluator.evaluate()

def run_benchmarks(args, argugment_error):
os.makedirs('exp_results/baselines', exist_ok=True)
os.makedirs('exp_results/align_eval', exist_ok=True)
Expand Down

0 comments on commit 7e18faa

Please sign in to comment.