Skip to content

Commit

Permalink
refactor(forge/llm): Create BaseOpenAIProvider -> deduplicate `Groq…
Browse files Browse the repository at this point in the history
…Provider` & `OpenAIProvider` implementation (#7178)

- Add `_BaseOpenAIProvider`, `BaseOpenAIChatProvider`, and `BaseOpenAIEmbeddingProvider`, which implement the shared functionality of OpenAI-like providers, e.g. `GroqProvider` and `OpenAIProvider`
- (Re)move as much code as possible from `GroqProvider` and `OpenAIProvider` by rebasing them on `BaseOpenAI(Chat|Embedding)Provider`

Also:
- Rename `get_available_models()` to `get_available_chat_models()` on `BaseChatModelProvider`
- Add `get_available_models()` to `BaseModelProvider`
- Add `get_available_embedding_models()` to `BaseEmbeddingModelProvider`
- Move common `fix_failed_parse_tries` config attribute into base `ModelProviderConfiguration`
  • Loading branch information
Pwuts authored Jun 2, 2024
1 parent cb9ad6f commit 4e76768
Show file tree
Hide file tree
Showing 8 changed files with 640 additions and 732 deletions.
2 changes: 1 addition & 1 deletion autogpt/autogpt/app/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def check_model(
) -> ModelName:
"""Check if model is available for use. If not, return gpt-3.5-turbo."""
multi_provider = MultiProvider()
models = await multi_provider.get_available_models()
models = await multi_provider.get_available_chat_models()

if any(model_name == m.name for m in models):
return model_name
Expand Down
4 changes: 2 additions & 2 deletions autogpt/tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def test_azure_config(config_with_azure: Config) -> None:
@pytest.mark.asyncio
async def test_create_config_gpt4only(config: Config) -> None:
with mock.patch(
"forge.llm.providers.multi.MultiProvider.get_available_models"
"forge.llm.providers.multi.MultiProvider.get_available_chat_models"
) as mock_get_models:
mock_get_models.return_value = [
ChatModelInfo(
Expand All @@ -164,7 +164,7 @@ async def test_create_config_gpt4only(config: Config) -> None:
@pytest.mark.asyncio
async def test_create_config_gpt3only(config: Config) -> None:
with mock.patch(
"forge.llm.providers.multi.MultiProvider.get_available_models"
"forge.llm.providers.multi.MultiProvider.get_available_chat_models"
) as mock_get_models:
mock_get_models.return_value = [
ChatModelInfo(
Expand Down
Loading

0 comments on commit 4e76768

Please sign in to comment.