Fallback to openai #58

Open · wants to merge 5 commits into base: main

146 changes: 90 additions & 56 deletions services/services.py
@@ -11,7 +11,11 @@ class BaseClient(ABC):
"""Base class for all clients"""

api_type: str = None
system_prompt = "You are a zsh shell expert, please help me complete the following command, you should only output the completed command, no need to include any other explanation. Do not put completed command in a code block."
system_prompt = (
"You are a zsh shell expert, please help me complete the following command, "
"you should only output the completed command, no need to include any other explanation. "
"Do not put completion in quotes."
)

@abstractmethod
def get_completion(self, full_command: str) -> str:
@@ -44,7 +48,7 @@ def __init__(self, config: dict):
self.config = config
self.config["model"] = self.config.get("model", self.default_model)
self.client = OpenAI(
api_key=self.config["api_key"],
api_key=os.getenv("OPENAI_API_KEY", self.config.get("api_key")),
base_url=self.config.get("base_url", "https://api.openai.com/v1"),
organization=self.config.get("organization"),
)
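
The change here resolves the API key from the OPENAI_API_KEY environment variable first and only falls back to the config file entry; the Google, Groq, and Mistral clients below adopt the same pattern. A minimal standalone sketch of that resolution order (the resolve_api_key helper is illustrative, not part of this diff):

import os

def resolve_api_key(env_var: str, config: dict) -> str | None:
    # The environment variable wins; otherwise fall back to the config entry.
    # Returns None when neither source provides a key.
    return os.getenv(env_var, config.get("api_key"))

# Example: with OPENAI_API_KEY unset, this returns "sk-from-config";
# with it set, the environment value is used instead.
resolve_api_key("OPENAI_API_KEY", {"api_key": "sk-from-config"})
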
@@ -82,7 +86,9 @@ def __init__(self, config: dict):
sys.exit(1)

self.config = config
genai.configure(api_key=self.config["api_key"])
genai.configure(
api_key=os.getenv("GOOGLE_GENAI_API_KEY", self.config.get("api_key"))
)
self.config["model"] = config.get("model", self.default_model)
self.model = genai.GenerativeModel(self.config["model"])

@@ -101,10 +107,10 @@ class GroqClient(BaseClient):
- model (optional): defaults to "llama-3.2-11b-text-preview"
- temperature (optional): defaults to 1.0.
"""

api_type = "groq"
default_model = os.getenv("GROQ_DEFAULT_MODEL", "llama-3.2-11b-text-preview")

def __init__(self, config: dict):
try:
from groq import Groq
@@ -117,9 +123,9 @@ def __init__(self, config: dict):
self.config = config
self.config["model"] = self.config.get("model", self.default_model)
self.client = Groq(
api_key=self.config["api_key"],
api_key=os.getenv("GROQ_API_KEY", self.config.get("api_key"))
)

def get_completion(self, full_command: str) -> str:
response = self.client.chat.completions.create(
model=self.config["model"],
@@ -140,10 +146,10 @@ class MistralClient(BaseClient):
- model (optional): defaults to "codestral-latest"
- temperature (optional): defaults to 1.0.
"""

api_type = "mistral"
default_model = os.getenv("MISTRAL_DEFAULT_MODEL", "codestral-latest")

def __init__(self, config: dict):
try:
from mistralai import Mistral
@@ -152,13 +158,13 @@ def __init__(self, config: dict):
"Mistral library is not installed. Please install it using 'pip install mistralai'"
)
sys.exit(1)

self.config = config
self.config["model"] = self.config.get("model", self.default_model)
self.client = Mistral(
api_key=self.config["api_key"],
api_key=os.getenv("MISTRAL_API_KEY", self.config.get("api_key"))
)

def get_completion(self, full_command: str) -> str:
response = self.client.chat.complete(
model=self.config["model"],
@@ -170,6 +176,7 @@ def get_completion(self, full_command: str) -> str:
)
return response.choices[0].message.content


class AmazonBedrock(BaseClient):
"""
config keys:
@@ -183,7 +190,9 @@ class AmazonBedrock(BaseClient):
"""

api_type = "bedrock"
default_model = os.getenv("BEDROCK_DEFAULT_MODEL", "anthropic.claude-3-5-sonnet-20240620-v1:0")
default_model = os.getenv(
"BEDROCK_DEFAULT_MODEL", "anthropic.claude-3-5-sonnet-20240620-v1:0"
)

def __init__(self, config: dict):
try:
@@ -197,24 +206,20 @@ def __init__(self, config: dict):
self.config = config
self.config["model"] = self.config.get("model", self.default_model)

session_kwargs = {}
if "aws_region" in config:
session_kwargs["region_name"] = config["aws_region"]
if "aws_access_key_id" in config:
session_kwargs["aws_access_key_id"] = config["aws_access_key_id"]
if "aws_secret_access_key" in config:
session_kwargs["aws_secret_access_key"] = config["aws_secret_access_key"]
if "aws_session_token" in config:
session_kwargs["aws_session_token"] = config["aws_session_token"]

self.client = boto3.client("bedrock-runtime", **session_kwargs)
session_kwargs = {
"region_name": config.get("aws_region"),
"aws_access_key_id": config.get("aws_access_key_id"),
"aws_secret_access_key": config.get("aws_secret_access_key"),
"aws_session_token": config.get("aws_session_token"),
}
self.client = boto3.client(
"bedrock-runtime", **{k: v for k, v in session_kwargs.items() if v}
)

def get_completion(self, full_command: str) -> str:
import json

messages = [
{"role": "user", "content": full_command}
]
messages = [{"role": "user", "content": full_command}]

# Format request body based on model type
if "claude" in self.config["model"].lower():
@@ -223,47 +228,76 @@ def get_completion(self, full_command: str) -> str:
"max_tokens": 1000,
"system": self.system_prompt,
"messages": messages,
"temperature": float(self.config.get("temperature", 1.0))
"temperature": float(self.config.get("temperature", 1.0)),
}
else:
raise ValueError(f"Unsupported model: {self.config['model']}")

response = self.client.invoke_model(
modelId=self.config["model"],
body=json.dumps(body)
modelId=self.config["model"], body=json.dumps(body)
)

response_body = json.loads(response['body'].read())
response_body = json.loads(response["body"].read())
return response_body["content"][0]["text"]
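
For reference, the decoded Bedrock response consumed above has roughly the following shape for Claude models; the values here are illustrative, and only content[0].text is read by the code:

# Sketch of a decoded invoke_model response body (example values only).
example_response_body = {
    "content": [{"type": "text", "text": "find . -mtime -1 -type f"}],
    "stop_reason": "end_turn",
}
completion = example_response_body["content"][0]["text"]
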



class ClientFactory:
api_types = [OpenAIClient.api_type, GoogleGenAIClient.api_type, GroqClient.api_type, MistralClient.api_type, AmazonBedrock.api_type]
api_types = [
OpenAIClient.api_type,
GoogleGenAIClient.api_type,
GroqClient.api_type,
MistralClient.api_type,
AmazonBedrock.api_type,
]

@classmethod
def create(cls):
config_parser = ConfigParser()
config_parser.read(CONFIG_PATH)
service = config_parser["service"]["service"]
try:
config = {k: v for k, v in config_parser[service].items()}
except KeyError:
raise KeyError(f"Config for service {service} is not defined")

api_type = config["api_type"]
match api_type:
case OpenAIClient.api_type:
return OpenAIClient(config)
case GoogleGenAIClient.api_type:
return GoogleGenAIClient(config)
case GroqClient.api_type:
return GroqClient(config)
case MistralClient.api_type:
return MistralClient(config)
case AmazonBedrock.api_type:
return AmazonBedrock(config)
case _:
raise KeyError(
f"Specified API type {api_type} is not one of the supported services {cls.api_types}"

if not os.path.exists(CONFIG_PATH):
# Default to OpenAI if config file is absent
config = {"api_type": "openai", "api_key": os.getenv("OPENAI_API_KEY")}
if not config["api_key"]:
print(
"API key for OpenAI is missing. Please set the OPENAI_API_KEY environment variable."
)
sys.exit(1)
return OpenAIClient(config)

config_parser.read(CONFIG_PATH)

if "service" not in config_parser or "service" not in config_parser["service"]:
print(
"Service section or service key is missing in the configuration file."
)
sys.exit(1)

service = config_parser["service"].get("service")
if not service or service not in config_parser:
print(
f"Config for service {service} is not defined in the configuration file."
)
sys.exit(1)

config = {k: v for k, v in config_parser[service].items()}
if "api_key" not in config:
config["api_key"] = os.getenv(f"{service.upper()}_API_KEY")
if not config["api_key"]:
print(f"API key is missing for the {service} service.")
sys.exit(1)

api_type = config.get("api_type")
if api_type not in cls.api_types:
print(
f"Specified API type {api_type} is not one of the supported services {cls.api_types}."
)
sys.exit(1)

client_classes = {
OpenAIClient.api_type: OpenAIClient,
GoogleGenAIClient.api_type: GoogleGenAIClient,
GroqClient.api_type: GroqClient,
MistralClient.api_type: MistralClient,
AmazonBedrock.api_type: AmazonBedrock,
}

return client_classes[api_type](config)
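
The rewritten factory reads an INI-style file at CONFIG_PATH, using a [service] section to pick the active provider and a per-provider section for its settings; when the file is missing it now falls back to an OpenAI client configured purely from OPENAI_API_KEY. An illustrative config and call site, with section and key names inferred from the parsing code above and placeholder values:

# Example CONFIG_PATH contents (placeholders; "model" is optional and
# defaults to each client's default_model).
EXAMPLE_CONFIG = """
[service]
service = openai

[openai]
api_type = openai
api_key = sk-...
"""

# Typical call site: build whichever client is configured and complete a command.
client = ClientFactory.create()
print(client.get_completion("list files modified in the last 24 hours"))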