From 540e25b72c4ff16775d3a46864b69e5abe2eec00 Mon Sep 17 00:00:00 2001 From: "yuxin.wang" Date: Mon, 11 Mar 2024 11:08:40 +0800 Subject: [PATCH] Azure bug (#50) * Add model parameter to AzureChat constructor --------- Co-authored-by: wangyuxin --- generate/chat_completion/models/azure.py | 10 +++++++++- generate/chat_completion/models/minimax.py | 1 - generate/chat_completion/models/openai_like.py | 2 +- generate/platforms/azure.py | 3 +++ generate/text_to_speech/models/minimax.py | 4 ++-- tests/test_chat_completion_model.py | 4 ---- tests/test_modifiers.py | 4 ++-- 7 files changed, 17 insertions(+), 11 deletions(-) diff --git a/generate/chat_completion/models/azure.py b/generate/chat_completion/models/azure.py index 700aa4f..4062866 100644 --- a/generate/chat_completion/models/azure.py +++ b/generate/chat_completion/models/azure.py @@ -9,6 +9,7 @@ from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput from generate.chat_completion.models.openai import OpenAIChatParameters, OpenAIChatParametersDict from generate.chat_completion.models.openai_like import convert_to_openai_message, process_openai_like_model_reponse +from generate.chat_completion.stream_manager import StreamManager from generate.http import HttpClient, HttpxPostKwargs from generate.platforms.azure import AzureSettings @@ -21,7 +22,7 @@ class AzureChat(RemoteChatCompletionModel): def __init__( self, - model: str, + model: str | None = None, parameters: OpenAIChatParameters | None = None, settings: AzureSettings | None = None, http_client: HttpClient | None = None, @@ -29,6 +30,9 @@ def __init__( parameters = parameters or OpenAIChatParameters() settings = settings or AzureSettings() # type: ignore http_client = http_client or HttpClient() + model = model or settings.chat_api_engine + if model is None: + raise ValueError('model must be provided or set in settings.chat_api_engine') super().__init__(model, parameters=parameters, settings=settings, http_client=http_client) @override @@ -75,3 +79,7 @@ def _get_request_parameters(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatPar @override def _process_reponse(self, response: dict[str, Any]) -> ChatCompletionOutput: return process_openai_like_model_reponse(response, model_type=self.model_type) + + @override + def _process_stream_line(self, line: str, stream_manager: StreamManager) -> ChatCompletionStreamOutput | None: + raise NotImplementedError diff --git a/generate/chat_completion/models/minimax.py b/generate/chat_completion/models/minimax.py index 39eafc4..32c466a 100644 --- a/generate/chat_completion/models/minimax.py +++ b/generate/chat_completion/models/minimax.py @@ -1,7 +1,6 @@ from __future__ import annotations import json -import uuid from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Optional from pydantic import PositiveInt, field_validator diff --git a/generate/chat_completion/models/openai_like.py b/generate/chat_completion/models/openai_like.py index 773ffff..d2c9ad1 100644 --- a/generate/chat_completion/models/openai_like.py +++ b/generate/chat_completion/models/openai_like.py @@ -2,12 +2,12 @@ import base64 import json +import uuid from abc import ABC from functools import partial from typing import Any, Callable, Dict, List, Literal, Type, Union, cast from typing_extensions import NotRequired, TypedDict, override -import uuid from generate.chat_completion.base import RemoteChatCompletionModel from generate.chat_completion.cost_caculator import GeneralCostCalculator diff --git a/generate/platforms/azure.py b/generate/platforms/azure.py index 9db7988..5362c31 100644 --- a/generate/platforms/azure.py +++ b/generate/platforms/azure.py @@ -1,3 +1,5 @@ +from typing import Optional + from pydantic import SecretStr from pydantic_settings import SettingsConfigDict @@ -10,4 +12,5 @@ class AzureSettings(PlatformSettings): api_key: SecretStr api_base: str api_version: str + chat_api_engine: Optional[str] = None platform_url: str = 'https://learn.microsoft.com/en-us/azure/ai-services/openai/' diff --git a/generate/text_to_speech/models/minimax.py b/generate/text_to_speech/models/minimax.py index 92f577b..f30d2a9 100644 --- a/generate/text_to_speech/models/minimax.py +++ b/generate/text_to_speech/models/minimax.py @@ -79,7 +79,7 @@ def _get_request_parameters(self, text: str, parameters: MinimaxSpeechParameters 'Content-Type': 'application/json', } return { - 'url': self.settings.api_base + 'text_to_speech', + 'url': self.settings.api_base + '/text_to_speech', 'json': json_data, 'headers': headers, 'params': {'GroupId': self.settings.group_id}, @@ -154,7 +154,7 @@ def _get_request_parameters(self, text: str, parameters: MinimaxProSpeechParamet 'Content-Type': 'application/json', } return { - 'url': self.settings.api_base + 't2a_pro', + 'url': self.settings.api_base + '/t2a_pro', 'json': json_data, 'headers': headers, 'params': {'GroupId': self.settings.group_id}, diff --git a/tests/test_chat_completion_model.py b/tests/test_chat_completion_model.py index 4ab4e02..7a88c63 100644 --- a/tests/test_chat_completion_model.py +++ b/tests/test_chat_completion_model.py @@ -13,7 +13,6 @@ RemoteChatCompletionModel, ) from generate.chat_completion.message import Prompt -from generate.chat_completion.models.azure import AzureChat from generate.test import get_pytest_params @@ -33,9 +32,6 @@ def test_model_type_is_unique() -> None: ], ) def test_http_chat_model(model_cls: Type[ChatCompletionModel], parameters: dict[str, Any]) -> None: - if issubclass(model_cls, AzureChat): - return - model = model_cls() prompt = '这是测试,只回复你好' sync_output = model.generate(prompt, **parameters) diff --git a/tests/test_modifiers.py b/tests/test_modifiers.py index 7829501..0e7ae1b 100644 --- a/tests/test_modifiers.py +++ b/tests/test_modifiers.py @@ -43,7 +43,7 @@ class Country(BaseModel): def test_session() -> None: model = OpenAIChat().session() - model.generate('I am bob') - reply = model.generate('who am i?').reply + model.generate('call me BOB') + reply = model.generate('TEST: my name is ?').reply assert 'bob' in reply.lower()