From 8e4c6c24913d99510532315244477669b700508a Mon Sep 17 00:00:00 2001 From: "yuxin.wang" Date: Tue, 23 Jan 2024 10:38:45 +0800 Subject: [PATCH] Refactor calculate_cost function to include different cost calculations based on model name (#17) Co-authored-by: wangyuxin --- generate/chat_completion/models/minimax_pro.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/generate/chat_completion/models/minimax_pro.py b/generate/chat_completion/models/minimax_pro.py index 5b5b5c2..bb02dd9 100644 --- a/generate/chat_completion/models/minimax_pro.py +++ b/generate/chat_completion/models/minimax_pro.py @@ -199,7 +199,7 @@ def process(self, response: ResponseValue) -> ChatCompletionStreamOutput: model_info=self.model_info, message=self.message, finish_reason=response['choices'][0]['finish_reason'], - cost=calculate_cost(response['usage']), + cost=calculate_cost(model_name=self.model_info.name , usage=response['usage']), extra={ 'input_sensitive': response['input_sensitive'], 'output_sensitive': response['output_sensitive'], @@ -244,8 +244,16 @@ def update_existing_message(self, response: ResponseValue) -> str: return delta -def calculate_cost(usage: dict[str, int], num_web_search: int = 0) -> float: - return 0.015 * (usage['total_tokens'] / 1000) + (0.03 * num_web_search) +def calculate_cost(model_name: str, usage: dict[str, int], num_web_search: int = 0) -> float: + if model_name == 'abab6-chat': + model_cost = 0.1 * (usage['total_tokens'] / 1000) + elif model_name == 'abab5.5-chat': + model_cost = 0.015 * (usage['total_tokens'] / 1000) + elif model_name == 'abab5.5s-chat': + model_cost = 0.005 * (usage['total_tokens'] / 1000) + else: + return 0.0 + return model_cost + (0.03 * num_web_search) class MinimaxProChat(ChatCompletionModel): @@ -316,7 +324,7 @@ def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput: model_info=self.model_info, message=message, finish_reason=finish_reason, - cost=calculate_cost(response['usage'], num_web_search), + cost=calculate_cost(model_name=self.name, usage=response['usage'], num_web_search=num_web_search), extra={ 'input_sensitive': response['input_sensitive'], 'output_sensitive': response['output_sensitive'],