From bcdfabab076b4f2f8cbb2aa4d473752cd45f7ec8 Mon Sep 17 00:00:00 2001 From: alby11 Date: Fri, 29 Nov 2024 16:34:27 +0100 Subject: [PATCH 1/5] Refactor API key handling using env vars --- services/services.py | 72 +++++++++++++++++++++++++------------------- 1 file changed, 41 insertions(+), 31 deletions(-) diff --git a/services/services.py b/services/services.py index aa54991..7a4ba6d 100644 --- a/services/services.py +++ b/services/services.py @@ -11,7 +11,7 @@ class BaseClient(ABC): """Base class for all clients""" api_type: str = None - system_prompt = "You are a zsh shell expert, please help me complete the following command, you should only output the completed command, no need to include any other explanation. Do not put completed command in a code block." + system_prompt = "You are a zsh shell expert, please help me complete the following command, you should only output the completed command, no need to include any other explanation. Do not put comments." @abstractmethod def get_completion(self, full_command: str) -> str: @@ -44,7 +44,7 @@ def __init__(self, config: dict): self.config = config self.config["model"] = self.config.get("model", self.default_model) self.client = OpenAI( - api_key=self.config["api_key"], + api_key=os.getenv("OPENAI_API_KEY", self.config.get("api_key")), base_url=self.config.get("base_url", "https://api.openai.com/v1"), organization=self.config.get("organization"), ) @@ -82,7 +82,9 @@ def __init__(self, config: dict): sys.exit(1) self.config = config - genai.configure(api_key=self.config["api_key"]) + genai.configure( + api_key=os.getenv("GOOGLE_GENAI_API_KEY", self.config.get("api_key")) + ) self.config["model"] = config.get("model", self.default_model) self.model = genai.GenerativeModel(self.config["model"]) @@ -101,10 +103,10 @@ class GroqClient(BaseClient): - model (optional): defaults to "llama-3.2-11b-text-preview" - temperature (optional): defaults to 1.0. """ - + api_type = "groq" default_model = os.getenv("GROQ_DEFAULT_MODEL", "llama-3.2-11b-text-preview") - + def __init__(self, config: dict): try: from groq import Groq @@ -117,9 +119,9 @@ def __init__(self, config: dict): self.config = config self.config["model"] = self.config.get("model", self.default_model) self.client = Groq( - api_key=self.config["api_key"], + api_key=os.getenv("GROQ_API_KEY", self.config.get("api_key")), ) - + def get_completion(self, full_command: str) -> str: response = self.client.chat.completions.create( model=self.config["model"], @@ -140,10 +142,10 @@ class MistralClient(BaseClient): - model (optional): defaults to "codestral-latest" - temperature (optional): defaults to 1.0. """ - + api_type = "mistral" default_model = os.getenv("MISTRAL_DEFAULT_MODEL", "codestral-latest") - + def __init__(self, config: dict): try: from mistralai import Mistral @@ -152,13 +154,13 @@ def __init__(self, config: dict): "Mistral library is not installed. Please install it using 'pip install mistralai'" ) sys.exit(1) - + self.config = config self.config["model"] = self.config.get("model", self.default_model) self.client = Mistral( - api_key=self.config["api_key"], + api_key=os.getenv("MISTRAL_API_KEY", self.config.get("api_key")), ) - + def get_completion(self, full_command: str) -> str: response = self.client.chat.complete( model=self.config["model"], @@ -170,6 +172,7 @@ def get_completion(self, full_command: str) -> str: ) return response.choices[0].message.content + class AmazonBedrock(BaseClient): """ config keys: @@ -183,7 +186,9 @@ class AmazonBedrock(BaseClient): """ api_type = "bedrock" - default_model = os.getenv("BEDROCK_DEFAULT_MODEL", "anthropic.claude-3-5-sonnet-20240620-v1:0") + default_model = os.getenv( + "BEDROCK_DEFAULT_MODEL", "anthropic.claude-3-5-sonnet-20240620-v1:0" + ) def __init__(self, config: dict): try: @@ -197,24 +202,25 @@ def __init__(self, config: dict): self.config = config self.config["model"] = self.config.get("model", self.default_model) - session_kwargs = {} - if "aws_region" in config: - session_kwargs["region_name"] = config["aws_region"] - if "aws_access_key_id" in config: - session_kwargs["aws_access_key_id"] = config["aws_access_key_id"] - if "aws_secret_access_key" in config: - session_kwargs["aws_secret_access_key"] = config["aws_secret_access_key"] - if "aws_session_token" in config: - session_kwargs["aws_session_token"] = config["aws_session_token"] + session_kwargs = { + "region_name": os.getenv("AWS_REGION", config.get("aws_region")), + "aws_access_key_id": os.getenv( + "AWS_ACCESS_KEY_ID", config.get("aws_access_key_id") + ), + "aws_secret_access_key": os.getenv( + "AWS_SECRET_ACCESS_KEY", config.get("aws_secret_access_key") + ), + "aws_session_token": os.getenv( + "AWS_SESSION_TOKEN", config.get("aws_session_token") + ), + } self.client = boto3.client("bedrock-runtime", **session_kwargs) def get_completion(self, full_command: str) -> str: import json - messages = [ - {"role": "user", "content": full_command} - ] + messages = [{"role": "user", "content": full_command}] # Format request body based on model type if "claude" in self.config["model"].lower(): @@ -223,23 +229,27 @@ def get_completion(self, full_command: str) -> str: "max_tokens": 1000, "system": self.system_prompt, "messages": messages, - "temperature": float(self.config.get("temperature", 1.0)) + "temperature": float(self.config.get("temperature", 1.0)), } else: raise ValueError(f"Unsupported model: {self.config['model']}") response = self.client.invoke_model( - modelId=self.config["model"], - body=json.dumps(body) + modelId=self.config["model"], body=json.dumps(body) ) - response_body = json.loads(response['body'].read()) + response_body = json.loads(response["body"].read()) return response_body["content"][0]["text"] - class ClientFactory: - api_types = [OpenAIClient.api_type, GoogleGenAIClient.api_type, GroqClient.api_type, MistralClient.api_type, AmazonBedrock.api_type] + api_types = [ + OpenAIClient.api_type, + GoogleGenAIClient.api_type, + GroqClient.api_type, + MistralClient.api_type, + AmazonBedrock.api_type, + ] @classmethod def create(cls): From fcdf6886bb1339a95b10d1da73c4e1c2d340a7b8 Mon Sep 17 00:00:00 2001 From: alby11 Date: Fri, 29 Nov 2024 17:32:49 +0100 Subject: [PATCH 2/5] set openai as default, if config file is present check for service --- services/services.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/services/services.py b/services/services.py index 7a4ba6d..bce1a45 100644 --- a/services/services.py +++ b/services/services.py @@ -253,13 +253,17 @@ class ClientFactory: @classmethod def create(cls): - config_parser = ConfigParser() - config_parser.read(CONFIG_PATH) - service = config_parser["service"]["service"] - try: - config = {k: v for k, v in config_parser[service].items()} - except KeyError: - raise KeyError(f"Config for service {service} is not defined") + + service = "openai_service" + + if os.path.exists(CONFIG_PATH): + config_parser = ConfigParser() + config_parser.read(CONFIG_PATH) + service = config_parser.get("service", "service", fallback="openai_service") + try: + config = {k: v for k, v in config_parser[service].items()} + except KeyError: + raise KeyError(f"Config for service {service} is not defined") api_type = config["api_type"] match api_type: From 2a5f966b4a78a69d22352096549d737c36345bc4 Mon Sep 17 00:00:00 2001 From: alby11 Date: Fri, 29 Nov 2024 17:56:08 +0100 Subject: [PATCH 3/5] Set default API type in ClientFactory and update configuration handling --- services/services.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/services/services.py b/services/services.py index bce1a45..931351d 100644 --- a/services/services.py +++ b/services/services.py @@ -255,6 +255,8 @@ class ClientFactory: def create(cls): service = "openai_service" + api_type = "openai" + config = {"api_type": api_type} # Default configuration if os.path.exists(CONFIG_PATH): config_parser = ConfigParser() @@ -266,6 +268,7 @@ def create(cls): raise KeyError(f"Config for service {service} is not defined") api_type = config["api_type"] + match api_type: case OpenAIClient.api_type: return OpenAIClient(config) From 51bfa42171772e10bcf83566545e19e2c6330f15 Mon Sep 17 00:00:00 2001 From: alby11 Date: Wed, 4 Dec 2024 00:25:14 +0100 Subject: [PATCH 4/5] default to openai if no config file is present --- services/services.py | 85 ++++++++++++++++++++++++++++---------------- 1 file changed, 54 insertions(+), 31 deletions(-) diff --git a/services/services.py b/services/services.py index 931351d..725bfca 100644 --- a/services/services.py +++ b/services/services.py @@ -11,7 +11,7 @@ class BaseClient(ABC): """Base class for all clients""" api_type: str = None - system_prompt = "You are a zsh shell expert, please help me complete the following command, you should only output the completed command, no need to include any other explanation. Do not put comments." + system_prompt = "You are a zsh shell expert, please help me complete the following command, you should only output the completed command, no need to include any other explanation. Do not put completed command in a code block." @abstractmethod def get_completion(self, full_command: str) -> str: @@ -82,8 +82,10 @@ def __init__(self, config: dict): sys.exit(1) self.config = config - genai.configure( - api_key=os.getenv("GOOGLE_GENAI_API_KEY", self.config.get("api_key")) + ( + genai.configure( + api_key=os.getenv("GOOGLE_GENAI_API_KEY", self.config.get("api_key")) + ), ) self.config["model"] = config.get("model", self.default_model) self.model = genai.GenerativeModel(self.config["model"]) @@ -119,7 +121,7 @@ def __init__(self, config: dict): self.config = config self.config["model"] = self.config.get("model", self.default_model) self.client = Groq( - api_key=os.getenv("GROQ_API_KEY", self.config.get("api_key")), + api_key=self.config["api_key"], ) def get_completion(self, full_command: str) -> str: @@ -158,7 +160,7 @@ def __init__(self, config: dict): self.config = config self.config["model"] = self.config.get("model", self.default_model) self.client = Mistral( - api_key=os.getenv("MISTRAL_API_KEY", self.config.get("api_key")), + api_key=self.config["api_key"], ) def get_completion(self, full_command: str) -> str: @@ -202,18 +204,15 @@ def __init__(self, config: dict): self.config = config self.config["model"] = self.config.get("model", self.default_model) - session_kwargs = { - "region_name": os.getenv("AWS_REGION", config.get("aws_region")), - "aws_access_key_id": os.getenv( - "AWS_ACCESS_KEY_ID", config.get("aws_access_key_id") - ), - "aws_secret_access_key": os.getenv( - "AWS_SECRET_ACCESS_KEY", config.get("aws_secret_access_key") - ), - "aws_session_token": os.getenv( - "AWS_SESSION_TOKEN", config.get("aws_session_token") - ), - } + session_kwargs = {} + if "aws_region" in config: + session_kwargs["region_name"] = config["aws_region"] + if "aws_access_key_id" in config: + session_kwargs["aws_access_key_id"] = config["aws_access_key_id"] + if "aws_secret_access_key" in config: + session_kwargs["aws_secret_access_key"] = config["aws_secret_access_key"] + if "aws_session_token" in config: + session_kwargs["aws_session_token"] = config["aws_session_token"] self.client = boto3.client("bedrock-runtime", **session_kwargs) @@ -253,21 +252,44 @@ class ClientFactory: @classmethod def create(cls): + config_parser = ConfigParser() + + if not os.path.exists(CONFIG_PATH): + # Default to OpenAI if config file is absent + config = {"api_type": "openai", "api_key": os.getenv("OPENAI_API_KEY")} + if not config["api_key"]: + print( + "API key for OpenAI is missing. Please set the OPENAI_API_KEY environment variable." + ) + sys.exit(1) + return OpenAIClient(config) - service = "openai_service" - api_type = "openai" - config = {"api_type": api_type} # Default configuration + config_parser.read(CONFIG_PATH) - if os.path.exists(CONFIG_PATH): - config_parser = ConfigParser() - config_parser.read(CONFIG_PATH) - service = config_parser.get("service", "service", fallback="openai_service") - try: - config = {k: v for k, v in config_parser[service].items()} - except KeyError: - raise KeyError(f"Config for service {service} is not defined") + if "service" not in config_parser or "service" not in config_parser["service"]: + print( + "Service section or service key is missing in the configuration file." + ) + sys.exit(1) - api_type = config["api_type"] + service = config_parser["service"].get("service") + if not service or service not in config_parser: + print( + f"Config for service {service} is not defined in the configuration file." + ) + sys.exit(1) + + config = {k: v for k, v in config_parser[service].items()} + if "api_key" not in config: + print(f"API key is missing for the {service} service.") + sys.exit(1) + + api_type = config.get("api_type") + if api_type not in cls.api_types: + print( + f"Specified API type {api_type} is not one of the supported services {cls.api_types}." + ) + sys.exit(1) match api_type: case OpenAIClient.api_type: @@ -281,6 +303,7 @@ def create(cls): case AmazonBedrock.api_type: return AmazonBedrock(config) case _: - raise KeyError( - f"Specified API type {api_type} is not one of the supported services {cls.api_types}" + print( + f"Specified API type {api_type} is not one of the supported services {cls.api_types}." ) + sys.exit(1) From a45b4a3e55cc4fedf73d95589b5425adb99e0f86 Mon Sep 17 00:00:00 2001 From: alby11 Date: Wed, 4 Dec 2024 00:46:41 +0100 Subject: [PATCH 5/5] Implement fallback to OpenAI and default values for services --- services/services.py | 68 ++++++++++++++++++++------------------------ 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/services/services.py b/services/services.py index 725bfca..37041c0 100644 --- a/services/services.py +++ b/services/services.py @@ -11,7 +11,11 @@ class BaseClient(ABC): """Base class for all clients""" api_type: str = None - system_prompt = "You are a zsh shell expert, please help me complete the following command, you should only output the completed command, no need to include any other explanation. Do not put completed command in a code block." + system_prompt = ( + "You are a zsh shell expert, please help me complete the following command, " + "you should only output the completed command, no need to include any other explanation. " + "Do not put completion in quotes." + ) @abstractmethod def get_completion(self, full_command: str) -> str: @@ -82,10 +86,8 @@ def __init__(self, config: dict): sys.exit(1) self.config = config - ( - genai.configure( - api_key=os.getenv("GOOGLE_GENAI_API_KEY", self.config.get("api_key")) - ), + genai.configure( + api_key=os.getenv("GOOGLE_GENAI_API_KEY", self.config.get("api_key")) ) self.config["model"] = config.get("model", self.default_model) self.model = genai.GenerativeModel(self.config["model"]) @@ -121,7 +123,7 @@ def __init__(self, config: dict): self.config = config self.config["model"] = self.config.get("model", self.default_model) self.client = Groq( - api_key=self.config["api_key"], + api_key=os.getenv("GROQ_API_KEY", self.config.get("api_key")) ) def get_completion(self, full_command: str) -> str: @@ -160,7 +162,7 @@ def __init__(self, config: dict): self.config = config self.config["model"] = self.config.get("model", self.default_model) self.client = Mistral( - api_key=self.config["api_key"], + api_key=os.getenv("MISTRAL_API_KEY", self.config.get("api_key")) ) def get_completion(self, full_command: str) -> str: @@ -204,17 +206,15 @@ def __init__(self, config: dict): self.config = config self.config["model"] = self.config.get("model", self.default_model) - session_kwargs = {} - if "aws_region" in config: - session_kwargs["region_name"] = config["aws_region"] - if "aws_access_key_id" in config: - session_kwargs["aws_access_key_id"] = config["aws_access_key_id"] - if "aws_secret_access_key" in config: - session_kwargs["aws_secret_access_key"] = config["aws_secret_access_key"] - if "aws_session_token" in config: - session_kwargs["aws_session_token"] = config["aws_session_token"] - - self.client = boto3.client("bedrock-runtime", **session_kwargs) + session_kwargs = { + "region_name": config.get("aws_region"), + "aws_access_key_id": config.get("aws_access_key_id"), + "aws_secret_access_key": config.get("aws_secret_access_key"), + "aws_session_token": config.get("aws_session_token"), + } + self.client = boto3.client( + "bedrock-runtime", **{k: v for k, v in session_kwargs.items() if v} + ) def get_completion(self, full_command: str) -> str: import json @@ -236,7 +236,6 @@ def get_completion(self, full_command: str) -> str: response = self.client.invoke_model( modelId=self.config["model"], body=json.dumps(body) ) - response_body = json.loads(response["body"].read()) return response_body["content"][0]["text"] @@ -281,8 +280,10 @@ def create(cls): config = {k: v for k, v in config_parser[service].items()} if "api_key" not in config: - print(f"API key is missing for the {service} service.") - sys.exit(1) + config["api_key"] = os.getenv(f"{service.upper()}_API_KEY") + if not config["api_key"]: + print(f"API key is missing for the {service} service.") + sys.exit(1) api_type = config.get("api_type") if api_type not in cls.api_types: @@ -291,19 +292,12 @@ def create(cls): ) sys.exit(1) - match api_type: - case OpenAIClient.api_type: - return OpenAIClient(config) - case GoogleGenAIClient.api_type: - return GoogleGenAIClient(config) - case GroqClient.api_type: - return GroqClient(config) - case MistralClient.api_type: - return MistralClient(config) - case AmazonBedrock.api_type: - return AmazonBedrock(config) - case _: - print( - f"Specified API type {api_type} is not one of the supported services {cls.api_types}." - ) - sys.exit(1) + client_classes = { + OpenAIClient.api_type: OpenAIClient, + GoogleGenAIClient.api_type: GoogleGenAIClient, + GroqClient.api_type: GroqClient, + MistralClient.api_type: MistralClient, + AmazonBedrock.api_type: AmazonBedrock, + } + + return client_classes[api_type](config)