diff --git a/services/services.py b/services/services.py index aa54991..37041c0 100644 --- a/services/services.py +++ b/services/services.py @@ -11,7 +11,11 @@ class BaseClient(ABC): """Base class for all clients""" api_type: str = None - system_prompt = "You are a zsh shell expert, please help me complete the following command, you should only output the completed command, no need to include any other explanation. Do not put completed command in a code block." + system_prompt = ( + "You are a zsh shell expert, please help me complete the following command, " + "you should only output the completed command, no need to include any other explanation. " + "Do not put completion in quotes." + ) @abstractmethod def get_completion(self, full_command: str) -> str: @@ -44,7 +48,7 @@ def __init__(self, config: dict): self.config = config self.config["model"] = self.config.get("model", self.default_model) self.client = OpenAI( - api_key=self.config["api_key"], + api_key=os.getenv("OPENAI_API_KEY", self.config.get("api_key")), base_url=self.config.get("base_url", "https://api.openai.com/v1"), organization=self.config.get("organization"), ) @@ -82,7 +86,9 @@ def __init__(self, config: dict): sys.exit(1) self.config = config - genai.configure(api_key=self.config["api_key"]) + genai.configure( + api_key=os.getenv("GOOGLE_GENAI_API_KEY", self.config.get("api_key")) + ) self.config["model"] = config.get("model", self.default_model) self.model = genai.GenerativeModel(self.config["model"]) @@ -101,10 +107,10 @@ class GroqClient(BaseClient): - model (optional): defaults to "llama-3.2-11b-text-preview" - temperature (optional): defaults to 1.0. """ - + api_type = "groq" default_model = os.getenv("GROQ_DEFAULT_MODEL", "llama-3.2-11b-text-preview") - + def __init__(self, config: dict): try: from groq import Groq @@ -117,9 +123,9 @@ def __init__(self, config: dict): self.config = config self.config["model"] = self.config.get("model", self.default_model) self.client = Groq( - api_key=self.config["api_key"], + api_key=os.getenv("GROQ_API_KEY", self.config.get("api_key")) ) - + def get_completion(self, full_command: str) -> str: response = self.client.chat.completions.create( model=self.config["model"], @@ -140,10 +146,10 @@ class MistralClient(BaseClient): - model (optional): defaults to "codestral-latest" - temperature (optional): defaults to 1.0. """ - + api_type = "mistral" default_model = os.getenv("MISTRAL_DEFAULT_MODEL", "codestral-latest") - + def __init__(self, config: dict): try: from mistralai import Mistral @@ -152,13 +158,13 @@ def __init__(self, config: dict): "Mistral library is not installed. Please install it using 'pip install mistralai'" ) sys.exit(1) - + self.config = config self.config["model"] = self.config.get("model", self.default_model) self.client = Mistral( - api_key=self.config["api_key"], + api_key=os.getenv("MISTRAL_API_KEY", self.config.get("api_key")) ) - + def get_completion(self, full_command: str) -> str: response = self.client.chat.complete( model=self.config["model"], @@ -170,6 +176,7 @@ def get_completion(self, full_command: str) -> str: ) return response.choices[0].message.content + class AmazonBedrock(BaseClient): """ config keys: @@ -183,7 +190,9 @@ class AmazonBedrock(BaseClient): """ api_type = "bedrock" - default_model = os.getenv("BEDROCK_DEFAULT_MODEL", "anthropic.claude-3-5-sonnet-20240620-v1:0") + default_model = os.getenv( + "BEDROCK_DEFAULT_MODEL", "anthropic.claude-3-5-sonnet-20240620-v1:0" + ) def __init__(self, config: dict): try: @@ -197,24 +206,20 @@ def __init__(self, config: dict): self.config = config self.config["model"] = self.config.get("model", self.default_model) - session_kwargs = {} - if "aws_region" in config: - session_kwargs["region_name"] = config["aws_region"] - if "aws_access_key_id" in config: - session_kwargs["aws_access_key_id"] = config["aws_access_key_id"] - if "aws_secret_access_key" in config: - session_kwargs["aws_secret_access_key"] = config["aws_secret_access_key"] - if "aws_session_token" in config: - session_kwargs["aws_session_token"] = config["aws_session_token"] - - self.client = boto3.client("bedrock-runtime", **session_kwargs) + session_kwargs = { + "region_name": config.get("aws_region"), + "aws_access_key_id": config.get("aws_access_key_id"), + "aws_secret_access_key": config.get("aws_secret_access_key"), + "aws_session_token": config.get("aws_session_token"), + } + self.client = boto3.client( + "bedrock-runtime", **{k: v for k, v in session_kwargs.items() if v} + ) def get_completion(self, full_command: str) -> str: import json - messages = [ - {"role": "user", "content": full_command} - ] + messages = [{"role": "user", "content": full_command}] # Format request body based on model type if "claude" in self.config["model"].lower(): @@ -223,47 +228,76 @@ def get_completion(self, full_command: str) -> str: "max_tokens": 1000, "system": self.system_prompt, "messages": messages, - "temperature": float(self.config.get("temperature", 1.0)) + "temperature": float(self.config.get("temperature", 1.0)), } else: raise ValueError(f"Unsupported model: {self.config['model']}") response = self.client.invoke_model( - modelId=self.config["model"], - body=json.dumps(body) + modelId=self.config["model"], body=json.dumps(body) ) - - response_body = json.loads(response['body'].read()) + response_body = json.loads(response["body"].read()) return response_body["content"][0]["text"] - class ClientFactory: - api_types = [OpenAIClient.api_type, GoogleGenAIClient.api_type, GroqClient.api_type, MistralClient.api_type, AmazonBedrock.api_type] + api_types = [ + OpenAIClient.api_type, + GoogleGenAIClient.api_type, + GroqClient.api_type, + MistralClient.api_type, + AmazonBedrock.api_type, + ] @classmethod def create(cls): config_parser = ConfigParser() - config_parser.read(CONFIG_PATH) - service = config_parser["service"]["service"] - try: - config = {k: v for k, v in config_parser[service].items()} - except KeyError: - raise KeyError(f"Config for service {service} is not defined") - - api_type = config["api_type"] - match api_type: - case OpenAIClient.api_type: - return OpenAIClient(config) - case GoogleGenAIClient.api_type: - return GoogleGenAIClient(config) - case GroqClient.api_type: - return GroqClient(config) - case MistralClient.api_type: - return MistralClient(config) - case AmazonBedrock.api_type: - return AmazonBedrock(config) - case _: - raise KeyError( - f"Specified API type {api_type} is not one of the supported services {cls.api_types}" + + if not os.path.exists(CONFIG_PATH): + # Default to OpenAI if config file is absent + config = {"api_type": "openai", "api_key": os.getenv("OPENAI_API_KEY")} + if not config["api_key"]: + print( + "API key for OpenAI is missing. Please set the OPENAI_API_KEY environment variable." ) + sys.exit(1) + return OpenAIClient(config) + + config_parser.read(CONFIG_PATH) + + if "service" not in config_parser or "service" not in config_parser["service"]: + print( + "Service section or service key is missing in the configuration file." + ) + sys.exit(1) + + service = config_parser["service"].get("service") + if not service or service not in config_parser: + print( + f"Config for service {service} is not defined in the configuration file." + ) + sys.exit(1) + + config = {k: v for k, v in config_parser[service].items()} + if "api_key" not in config: + config["api_key"] = os.getenv(f"{service.upper()}_API_KEY") + if not config["api_key"]: + print(f"API key is missing for the {service} service.") + sys.exit(1) + + api_type = config.get("api_type") + if api_type not in cls.api_types: + print( + f"Specified API type {api_type} is not one of the supported services {cls.api_types}." + ) + sys.exit(1) + + client_classes = { + OpenAIClient.api_type: OpenAIClient, + GoogleGenAIClient.api_type: GoogleGenAIClient, + GroqClient.api_type: GroqClient, + MistralClient.api_type: MistralClient, + AmazonBedrock.api_type: AmazonBedrock, + } + + return client_classes[api_type](config)