diff --git a/krs/krs.py b/krs/krs.py index e85abca..ad2b46d 100755 --- a/krs/krs.py +++ b/krs/krs.py @@ -4,176 +4,66 @@ from krs.main import KrsMain from krs.utils.constants import KRSSTATE_PICKLE_FILEPATH, KRS_DATA_DIRECTORY -app = typer.Typer() # Main typer app -krs = KrsMain() # Main class object +app = typer.Typer() +krs = KrsMain() +def check_initialized(): + if not os.path.exists(KRSSTATE_PICKLE_FILEPATH): + typer.echo("KRS is not initialized. Please run 'krs init' first.") + raise typer.Exit() -def check_initialized() -> None: - - """ - Checks if KRS is initialized or not. - """ - try: - if not os.path.exists(KRSSTATE_PICKLE_FILEPATH): - typer.echo("KRS is not initialized. Please run 'krs init' first.") - raise typer.Exit() - except Exception as e: - typer.echo("Error: ", e) - raise typer.Abort() - except KeyboardInterrupt: - typer.echo("\nExiting...") - raise typer.Abort() - except typer.Exit: - raise typer.Abort() - except typer.Abort: - raise typer.Abort() - except: - typer.echo("An error occured. Please try again.") - raise typer.Abort() - -try: - if not os.path.exists(KRS_DATA_DIRECTORY): # Create data directory if not exists - os.mkdir(KRS_DATA_DIRECTORY) -except Exception as e: - typer.echo("Error: ", e) - raise typer.Abort() -except KeyboardInterrupt as e: - typer.echo("\nInterruption: ", e) - raise typer.Abort() -except typer.Exit as e: - typer.echo("\nExiting: ", e) - raise typer.Abort() -except typer.Abort as e: - typer.echo("\nAborting: ", e) - raise typer.Abort() -except: - typer.echo("An error occured. Please try again.") - raise typer.Abort() - +if not os.path.exists(KRS_DATA_DIRECTORY): + os.mkdir(KRS_DATA_DIRECTORY) -@app.command() # Command to initialize the services -def init() -> None: # Function to initialize the services +@app.command() +def init(): """ Initializes the services and loads the scanner. """ - try: - krs.initialize() - typer.echo("Services initialized and scanner loaded.") - except Exception as e: - typer.echo("Error: ", e) - raise typer.Abort() - except KeyboardInterrupt as e: - typer.echo("\nInterruption: ", e) - raise typer.Abort() - except typer.Exit as e: - typer.echo("\nExiting: ", e) - raise typer.Abort() - except typer.Abort as e: - typer.echo("\nAborting: ", e) - raise typer.Abort() - except: - typer.echo("An error occured. Please try again.") - raise typer.Abort() + krs.initialize() + typer.echo("Services initialized and scanner loaded.") -# Command to scan the cluster @app.command() -def scan() -> None: +def scan(): """ Scans the cluster and extracts a list of tools that are currently used. """ - try: - check_initialized() - krs.scan_cluster() - except Exception as e: - typer.echo("Error: ", e) - raise typer.Abort() - except KeyboardInterrupt as e: - typer.echo("\nInterruption: ", e) - raise typer.Abort() - except typer.Exit as e: - typer.echo("\nExiting: ", e) - raise typer.Abort() - except typer.Abort as e: - typer.echo("\nAborting: ", e) - raise typer.Abort() - except: - typer.echo("An error occured. Please try again.") - raise typer.Abort() + check_initialized() + krs.scan_cluster() + -# Command to list all the namespaces @app.command() -def namespaces() -> None: +def namespaces(): """ Lists all the namespaces. """ - - try: - check_initialized() - namespaces = krs.list_namespaces() - typer.echo("Namespaces in your cluster are: \n") - for i, namespace in enumerate(namespaces): - typer.echo(str(i+1)+ ". "+ namespace) - except Exception as e: - typer.echo("Error: ", e) - raise typer.Abort() - except KeyboardInterrupt as e: - typer.echo("\nInterruption: ", e) - raise typer.Abort() - except typer.Exit as e: - typer.echo("\nExiting: ", e) - raise typer.Abort() - except typer.Abort as e: - typer.echo("\nAborting: ", e) - raise typer.Abort() - except: - typer.echo("An error occured. Please try again.") - raise typer.Abort() + check_initialized() + namespaces = krs.list_namespaces() + typer.echo("Namespaces in your cluster are: \n") + for i, namespace in enumerate(namespaces): + typer.echo(str(i+1)+ ". "+ namespace) -# Command to list all the pods @app.command() -def pods(namespace: str = typer.Option(None, help="Specify namespace to list pods from")) -> None: - +def pods(namespace: str = typer.Option(None, help="Specify namespace to list pods from")): """ Lists all the pods with namespaces, or lists pods under a specified namespace. + """ + check_initialized() + if namespace: + pods = krs.list_pods(namespace) + if pods == 'wrong namespace name': + typer.echo(f"\nWrong namespace name entered, try again!\n") + raise typer.Abort() + typer.echo(f"\nPods in namespace '{namespace}': \n") + else: + pods = krs.list_pods_all() + typer.echo("\nAll pods in the cluster: \n") - Args: - namespace: str: Namespace name to list pods from. - Returns: - None - """ - try: - check_initialized() - if namespace: - pods = krs.list_pods(namespace) - if pods == 'wrong namespace name': - typer.echo(f"\nWrong namespace name entered, try again!\n") - raise typer.Abort() - typer.echo(f"\nPods in namespace '{namespace}': \n") - else: - pods = krs.list_pods_all() - typer.echo("\nAll pods in the cluster: \n") - - for i, pod in enumerate(pods): - typer.echo(str(i+1)+ '. '+ pod) - except Exception as e: - typer.echo("Error: ", e) - raise typer.Abort() - except KeyboardInterrupt as e: - typer.echo("\nInterruption: ", e) - raise typer.Abort() - except typer.Exit as e: - typer.echo("\nExiting: ", e) - raise typer.Abort() - except typer.Abort as e: - typer.echo("\nAborting: ", e) - raise typer.Abort() - except: - typer.echo("An error occured. Please try again.") - raise typer.Abort() - + for i, pod in enumerate(pods): + typer.echo(str(i+1)+ '. '+ pod) @app.command() -def recommend(): # Command to recommend tools +def recommend(): """ Generates a table of recommended tools from our ranking database and their CNCF project status. """ diff --git a/krs/main.py b/krs/main.py index 35e88c7..d116aeb 100644 --- a/krs/main.py +++ b/krs/main.py @@ -1,5 +1,4 @@ -from math import e -from krs.utils.fetch_tools_krs import krs_tool_ranking_info +from krs.utils.fetch_tools_krs import krs_tool_ranking_info from krs.utils.cluster_scanner import KubetoolsScanner from krs.utils.llm_client import KrsGPTClient from krs.utils.functional import extract_log_entries, CustomJSONEncoder @@ -8,530 +7,240 @@ from tabulate import tabulate from krs.utils.constants import (KRSSTATE_PICKLE_FILEPATH, LLMSTATE_PICKLE_FILEPATH, POD_INFO_FILEPATH, KRS_DATA_DIRECTORY) - class KrsMain: - # Class to handle the main functionality of the KRS tool - def __init__(self) -> None: - - """ - Initialize the KrsMain class. - - """ - - try: - self.pod_info = None - self.pod_list = None - self.namespaces = None - self.deployments = None - self.state_file = KRSSTATE_PICKLE_FILEPATH - self.isClusterScanned = False - self.continue_chat = False - self.logs_extracted = [] - self.scanner = None - self.get_events = True - self.get_logs = True - self.cluster_tool_list = None - self.detailed_cluster_tool_list = None - self.category_cluster_tools_dict = None - - self.load_state() - - except Exception as e: - print(f"An error occurred during initialization: {e}") - raise e - except: - print("An error occurred during initialization.") - raise Exception("An error occurred during initialization.") - - - def initialize(self, config_file : str = '~/.kube/config'): - - """ - Initialize the KrsMain class. - - Args: - config_file (str): The path to the kubeconfig file. - Returns: - None - - """ - - try: - self.config_file = config_file - self.tools_dict, self.category_dict, cncf_status_dict = krs_tool_ranking_info() # Get the tools and their rankings - self.cncf_status = cncf_status_dict['cncftools'] # Get the CNCF status of the tools - self.scanner = KubetoolsScanner(self.get_events, self.get_logs, self.config_file) # Initialize the scanner - self.save_state() # Save the state - except Exception as e: - print(f"An error occurred during initialization: {e}") - raise e - except: - print("An error occurred during initialization.") - raise Exception("An error occurred during initialization.") - - - def save_state(self) -> None: - - """ - - Save the state of the KrsMain class. - - Args: - None - Returns: - None - - """ - - try: - state = { - 'pod_info': self.pod_info, - 'pod_list': self.pod_list, - 'namespaces': self.namespaces, - 'deployments': self.deployments, - 'cncf_status': self.cncf_status, - 'tools_dict': self.tools_dict, - 'category_tools_dict': self.category_dict, - 'extracted_logs': self.logs_extracted, - 'kubeconfig': self.config_file, - 'isScanned': self.isClusterScanned, - 'cluster_tool_list': self.cluster_tool_list, - 'detailed_tool_list': self.detailed_cluster_tool_list, - 'category_tool_list': self.category_cluster_tools_dict - } - except Exception as e: - print(f"An error occurred during saving the state: {e}") - raise e - except: - print("An error occurred during saving the state.") - raise Exception("An error occurred during saving the state.") - - try: - os.makedirs(os.path.dirname(self.state_file), exist_ok=True) # Create the directory if it doesn't exist - with open(self.state_file, 'wb') as f: # Open the file in write mode - pickle.dump(state, f) # Dump the state to the file - except Exception as e: - print(f"An error occurred during saving the state: {e}") - raise e - except: - print("An error occurred during saving the state.") - raise Exception("An error occurred during saving the state.") - - - def load_state(self)-> None: - - """ - Load the state of the KrsMain class. - - Args: - None - Returns: - None - """ - - try: - if os.path.exists(self.state_file): - with open(self.state_file, 'rb') as f: - state = pickle.load(f) - self.pod_info = state.get('pod_info') - self.pod_list = state.get('pod_list') - self.namespaces = state.get('namespaces') - self.deployments = state.get('deployments') - self.cncf_status = state.get('cncf_status') - self.tools_dict = state.get('tools_dict') - self.category_dict = state.get('category_tools_dict') - self.logs_extracted = state.get('extracted_logs') - self.config_file = state.get('kubeconfig') - self.isClusterScanned = state.get('isScanned') - self.cluster_tool_list = state.get('cluster_tool_list') - self.detailed_cluster_tool_list = state.get('detailed_tool_list') - self.category_cluster_tools_dict = state.get('category_tool_list') - self.scanner = KubetoolsScanner(self.get_events, self.get_logs, self.config_file) # Reinitialize the scanner - except Exception as e: - print(f"An error occurred during loading the state: {e}") - raise e - except: - print("An error occurred during loading the state.") - raise Exception("An error occurred during loading the state.") - - - def check_scanned(self) -> None: - - """ - Check if the cluster has been scanned. - - Args: - None - Returns: - None - """ - - try: - if not self.isClusterScanned: - self.pod_list, self.pod_info, self.deployments, self.namespaces = self.scanner.scan_kubernetes_deployment() # Scan the cluster - self.save_state() - except Exception as e: - print(f"An error occurred during scanning the cluster: {e}") - raise e - except: - print("An error occurred during scanning the cluster.") - raise Exception("An error occurred during scanning the cluster.") - + def __init__(self): + + self.pod_info = None + self.pod_list = None + self.namespaces = None + self.deployments = None + self.state_file = KRSSTATE_PICKLE_FILEPATH + self.isClusterScanned = False + self.continue_chat = False + self.logs_extracted = [] + self.scanner = None + self.get_events = True + self.get_logs = True + self.cluster_tool_list = None + self.detailed_cluster_tool_list = None + self.category_cluster_tools_dict = None + + self.load_state() + + def initialize(self, config_file='~/.kube/config'): + self.config_file = config_file + self.tools_dict, self.category_dict, cncf_status_dict = krs_tool_ranking_info() + self.cncf_status = cncf_status_dict['cncftools'] + self.scanner = KubetoolsScanner(self.get_events, self.get_logs, self.config_file) + self.save_state() + + def save_state(self): + state = { + 'pod_info': self.pod_info, + 'pod_list': self.pod_list, + 'namespaces': self.namespaces, + 'deployments': self.deployments, + 'cncf_status': self.cncf_status, + 'tools_dict': self.tools_dict, + 'category_tools_dict': self.category_dict, + 'extracted_logs': self.logs_extracted, + 'kubeconfig': self.config_file, + 'isScanned': self.isClusterScanned, + 'cluster_tool_list': self.cluster_tool_list, + 'detailed_tool_list': self.detailed_cluster_tool_list, + 'category_tool_list': self.category_cluster_tools_dict + } + os.makedirs(os.path.dirname(self.state_file), exist_ok=True) + with open(self.state_file, 'wb') as f: + pickle.dump(state, f) + + def load_state(self): + if os.path.exists(self.state_file): + with open(self.state_file, 'rb') as f: + state = pickle.load(f) + self.pod_info = state.get('pod_info') + self.pod_list = state.get('pod_list') + self.namespaces = state.get('namespaces') + self.deployments = state.get('deployments') + self.cncf_status = state.get('cncf_status') + self.tools_dict = state.get('tools_dict') + self.category_dict = state.get('category_tools_dict') + self.logs_extracted = state.get('extracted_logs') + self.config_file = state.get('kubeconfig') + self.isClusterScanned = state.get('isScanned') + self.cluster_tool_list = state.get('cluster_tool_list') + self.detailed_cluster_tool_list = state.get('detailed_tool_list') + self.category_cluster_tools_dict = state.get('category_tool_list') + self.scanner = KubetoolsScanner(self.get_events, self.get_logs, self.config_file) + + def check_scanned(self): + if not self.isClusterScanned: + self.pod_list, self.pod_info, self.deployments, self.namespaces = self.scanner.scan_kubernetes_deployment() + self.save_state() - def list_namespaces(self) -> list: - - """ - List all the namespaces in the cluster. - - Args: - None - Returns: - list: List of namespaces in the cluster. - """ - try: - self.check_scanned() - return self.scanner.list_namespaces() - except Exception as e: - print(f"An error occurred during listing the namespaces: {e}") - raise e - except: - print("An error occurred during listing the namespaces.") - raise Exception("An error occurred during listing the namespaces.") - + def list_namespaces(self): + self.check_scanned() + return self.scanner.list_namespaces() - def list_pods(self, namespace: str) -> list: - - """ - List all the pods in a given namespace. - - Args: - namespace (str): The namespace to list the pods from. - Returns: - list: List of pods in the given namespace. - """ - - try: - self.check_scanned() - if namespace not in self.list_namespaces(): - return "wrong namespace name" - return self.scanner.list_pods(namespace) - except Exception as e: - print(f"An error occurred during listing the pods: {e}") - raise e - except: - print("An error occurred during listing the pods.") - raise Exception("An error occurred during listing the pods.") - + def list_pods(self, namespace): + self.check_scanned() + if namespace not in self.list_namespaces(): + return "wrong namespace name" + return self.scanner.list_pods(namespace) - def list_pods_all(self) -> list: - - """ - list all the pods in the cluster. - - Args: - None - Returns: - list: List of all the pods in the cluster. - """ - - try: - self.check_scanned() - return self.scanner.list_pods_all() # List all the pods in the cluster. - except Exception as e: - print(f"An error occurred during listing all the pods: {e}") - raise e - except: - print("An error occurred during listing all the pods.") - raise Exception("An error occurred during listing all the pods.") + def list_pods_all(self): + self.check_scanned() + return self.scanner.list_pods_all() - def detect_tools_from_repo(self) -> list: - - """ - Detect the tools used in the cluster from the repository. - - Args: - None - Returns: - list: List of tools used in the cluster. - """ - - try: - tool_set = set() - - # Detect tools from the pods - for pod in self.pod_list: - for service_name in pod.split('-'): - if service_name in self.tools_dict.keys(): - tool_set.add(service_name) - - # Detect tools from the deployments - for dep in self.deployments: - for service_name in dep.split('-'): - if service_name in self.tools_dict.keys(): - tool_set.add(service_name) - - return list(tool_set) - except Exception as e: - print(f"An error occurred during detecting the tools: {e}") - raise e - except: - print("An error occurred during detecting the tools.") - raise Exception("An error occurred during detecting the tools.") - + def detect_tools_from_repo(self): + tool_set = set() + for pod in self.pod_list: + for service_name in pod.split('-'): + if service_name in self.tools_dict.keys(): + tool_set.add(service_name) + + for dep in self.deployments: + for service_name in dep.split('-'): + if service_name in self.tools_dict.keys(): + tool_set.add(service_name) + + return list(tool_set) - def extract_rankings(self) -> tuple: - - """ - Extract the rankings of the tools used in the cluster. - - Args: - None - Returns: - tuple: A tuple containing the detailed tool list and the category tool list. - """ - - try: - tool_dict = {} - category_tools_dict = {} - - # Extract the rankings of the tools used in the cluster - for tool in self.cluster_tool_list: - tool_details = self.tools_dict[tool] - for detail in tool_details: - rank = detail['rank'] - category = detail['category'] - if category not in category_tools_dict: - category_tools_dict[category] = [] - category_tools_dict[category].append(rank) - - tool_dict[tool] = tool_details - - return tool_dict, category_tools_dict - except Exception as e: - print(f"An error occurred during extracting the rankings: {e}") - raise e - except: - print("An error occurred during extracting the rankings.") - raise Exception("An error occurred during extracting the rankings.") + def extract_rankings(self): + tool_dict = {} + category_tools_dict = {} + for tool in self.cluster_tool_list: + tool_details = self.tools_dict[tool] + for detail in tool_details: + rank = detail['rank'] + category = detail['category'] + if category not in category_tools_dict: + category_tools_dict[category] = [] + category_tools_dict[category].append(rank) + + tool_dict[tool] = tool_details + + return tool_dict, category_tools_dict - def generate_recommendations(self) -> None: + def generate_recommendations(self): - """ - Generate recommendations for the tools used in the cluster. - - Args: - None - Returns: - None - """ - - try: - if not self.isClusterScanned: - self.scan_cluster() + if not self.isClusterScanned: + self.scan_cluster() - self.print_recommendations() - except Exception as e: - print(f"An error occurred during generating recommendations: {e}") - raise e - except: - print("An error occurred during generating recommendations.") - raise Exception("An error occurred during generating recommendations.") - + self.print_recommendations() - def scan_cluster(self) -> None: - - """ - Scan the cluster and extract the tools used in the cluster. - - Args: - None - Returns: - None - """ - - try: - print("\nScanning your cluster...\n") - self.pod_list, self.pod_info, self.deployments, self.namespaces = self.scanner.scan_kubernetes_deployment() # Scan the cluster - self.isClusterScanned = True - print("Cluster scanned successfully...\n") - self.cluster_tool_list = self.detect_tools_from_repo() - print("Extracted tools used in cluster...\n") - self.detailed_cluster_tool_list, self.category_cluster_tools_dict = self.extract_rankings() # Extract the rankings of the tools used in the cluster - - self.print_scan_results() - self.save_state() - except Exception as e: - print(f"An error occurred during scanning the cluster: {e}") - raise e - except: - print("An error occurred during scanning the cluster.") - raise Exception("An error occurred during scanning the cluster.") - - def print_scan_results(self) -> None: - - """ - Print the scan results of the cluster. - - Args: - None - Returns: - None - """ - - try: - scan_results = [] # List to store the scan results - - for tool, details in self.detailed_cluster_tool_list.items(): # Iterate over the tools and their details - first_entry = True - for detail in details: - row = [tool if first_entry else "", detail['rank'], detail['category'], self.cncf_status.get(tool, 'unlisted')] # Create a row with the tool name, rank, category, and CNCF status - scan_results.append(row) - first_entry = False - - print("\nThe cluster is using the following tools:\n") - print(tabulate(scan_results, headers=["Tool Name", "Rank", "Category", "CNCF Status"], tablefmt="grid")) - except Exception as e: - print(f"An error occurred during printing the scan results: {e}") - raise e - except: - print("An error occurred during printing the scan results.") - raise Exception("An error occurred during printing the scan results.") + def scan_cluster(self): + + print("\nScanning your cluster...\n") + self.pod_list, self.pod_info, self.deployments, self.namespaces = self.scanner.scan_kubernetes_deployment() + self.isClusterScanned = True + print("Cluster scanned successfully...\n") + self.cluster_tool_list = self.detect_tools_from_repo() + print("Extracted tools used in cluster...\n") + self.detailed_cluster_tool_list, self.category_cluster_tools_dict = self.extract_rankings() + + self.print_scan_results() + self.save_state() + + def print_scan_results(self): + scan_results = [] + + for tool, details in self.detailed_cluster_tool_list.items(): + first_entry = True + for detail in details: + row = [tool if first_entry else "", detail['rank'], detail['category'], self.cncf_status.get(tool, 'unlisted')] + scan_results.append(row) + first_entry = False + + print("\nThe cluster is using the following tools:\n") + print(tabulate(scan_results, headers=["Tool Name", "Rank", "Category", "CNCF Status"], tablefmt="grid")) + + def print_recommendations(self): + recommendations = [] + + for category, ranks in self.category_cluster_tools_dict.items(): + rank = ranks[0] + recommended_tool = self.category_dict[category][1]['name'] + status = self.cncf_status.get(recommended_tool, 'unlisted') + if rank == 1: + row = [category, "Already using the best", recommended_tool, status] + else: + row = [category, "Recommended tool", recommended_tool, status] + recommendations.append(row) - def print_recommendations(self) -> None: - - """ - Print the recommendations for the tools used in the cluster. - - Args: - None - Returns: - None - """ - - try: - - recommendations = [] - - for category, ranks in self.category_cluster_tools_dict.items(): - rank = ranks[0] - recommended_tool = self.category_dict[category][1]['name'] # Get the recommended tool for the category - status = self.cncf_status.get(recommended_tool, 'unlisted') # Get the CNCF status of the recommended tool - if rank == 1: - row = [category, "Already using the best", recommended_tool, status] - else: - row = [category, "Recommended tool", recommended_tool, status] - recommendations.append(row) + print("\nOur recommended tools for this deployment are:\n") + print(tabulate(recommendations, headers=["Category", "Recommendation", "Tool Name", "CNCF Status"], tablefmt="grid")) - print("\nOur recommended tools for this deployment are:\n") - # Print the recommendations - print(tabulate(recommendations, headers=["Category", "Recommendation", "Tool Name", "CNCF Status"], tablefmt="grid")) - except Exception as e: - print(f"An error occurred during printing the recommendations: {e}") - raise e - except: - print("An error occurred during printing the recommendations.") - raise Exception("An error occurred during printing the recommendations.") - def health_check(self, change_model: bool = False) -> None: - + def health_check(self, change_model=False): - """ - Check the health of the cluster, and also start an interactive terminal to chat with the user. - - Args: - change_model (bool): Option to reinitialize/change the LLM. - Returns: - None - """ - - try: - - if os.path.exists(LLMSTATE_PICKLE_FILEPATH) and not change_model: # Check if the LLM state file exists and the model is not to be changed - continue_previous_chat = input("\nDo you want to continue fixing the previously selected pod ? (y/n): >> ") - while True: - if continue_previous_chat not in ['y', 'n']: - continue_previous_chat = input("\nPlease enter one of the given options ? (y/n): >> ") - else: - break - - if continue_previous_chat=='y': - krsllmclient = KrsGPTClient() # Initialize the LLM client - self.continue_chat = True # Set the continue chat flag to True + if os.path.exists(LLMSTATE_PICKLE_FILEPATH) and not change_model: + continue_previous_chat = input("\nDo you want to continue fixing the previously selected pod ? (y/n): >> ") + while True: + if continue_previous_chat not in ['y', 'n']: + continue_previous_chat = input("\nPlease enter one of the given options ? (y/n): >> ") else: - krsllmclient = KrsGPTClient(reset_history=True) # Initialize the LLM client - - else: - krsllmclient = KrsGPTClient(reinitialize=True) # Initialize the LLM client - self.continue_chat = False # Set the continue chat flag to False + break - if not self.continue_chat: - - self.check_scanned() # Check if the cluster has been scanned + if continue_previous_chat=='y': + krsllmclient = KrsGPTClient() + self.continue_chat = True + else: + krsllmclient = KrsGPTClient(reset_history=True) + + else: + krsllmclient = KrsGPTClient(reinitialize=True) + self.continue_chat = False - print("\nNamespaces in the cluster:\n") - namespaces = self.list_namespaces() - namespace_len = len(namespaces) - for i, namespace in enumerate(namespaces, start=1): - print(f"{i}. {namespace}") + if not self.continue_chat: - # Select a namespace - self.selected_namespace_index = int(input("\nWhich namespace do you want to check the health for? Select a namespace by entering its number: >> ")) - while True: - if self.selected_namespace_index not in list(range(1, namespace_len+1)): - self.selected_namespace_index = int(input(f"\nWrong input! Select a namespace number between {1} to {namespace_len}: >> ")) - else: - break + self.check_scanned() - self.selected_namespace = namespaces[self.selected_namespace_index - 1] - pod_list = self.list_pods(self.selected_namespace) - pod_len = len(pod_list) - print(f"\nPods in the namespace {self.selected_namespace}:\n") - for i, pod in enumerate(pod_list, start=1): - print(f"{i}. {pod}") - self.selected_pod_index = int(input(f"\nWhich pod from {self.selected_namespace} do you want to check the health for? Select a pod by entering its number: >> ")) + print("\nNamespaces in the cluster:\n") + namespaces = self.list_namespaces() + namespace_len = len(namespaces) + for i, namespace in enumerate(namespaces, start=1): + print(f"{i}. {namespace}") - while True: - if self.selected_pod_index not in list(range(1, pod_len+1)): - self.selected_pod_index = int(input(f"\nWrong input! Select a pod number between {1} to {pod_len}: >> ")) - else: - break + self.selected_namespace_index = int(input("\nWhich namespace do you want to check the health for? Select a namespace by entering its number: >> ")) + while True: + if self.selected_namespace_index not in list(range(1, namespace_len+1)): + self.selected_namespace_index = int(input(f"\nWrong input! Select a namespace number between {1} to {namespace_len}: >> ")) + else: + break + + self.selected_namespace = namespaces[self.selected_namespace_index - 1] + pod_list = self.list_pods(self.selected_namespace) + pod_len = len(pod_list) + print(f"\nPods in the namespace {self.selected_namespace}:\n") + for i, pod in enumerate(pod_list, start=1): + print(f"{i}. {pod}") + self.selected_pod_index = int(input(f"\nWhich pod from {self.selected_namespace} do you want to check the health for? Select a pod by entering its number: >> ")) + + while True: + if self.selected_pod_index not in list(range(1, pod_len+1)): + self.selected_pod_index = int(input(f"\nWrong input! Select a pod number between {1} to {pod_len}: >> ")) + else: + break - print("\nChecking status of the pod...") + print("\nChecking status of the pod...") - print("\nExtracting logs and events from the pod...") + print("\nExtracting logs and events from the pod...") - logs_from_pod = self.get_logs_from_pod(self.selected_namespace_index, self.selected_pod_index) + logs_from_pod = self.get_logs_from_pod(self.selected_namespace_index, self.selected_pod_index) - self.logs_extracted = extract_log_entries(logs_from_pod) + self.logs_extracted = extract_log_entries(logs_from_pod) - print("\nLogs and events from the pod extracted successfully!\n") + print("\nLogs and events from the pod extracted successfully!\n") - prompt_to_llm = self.create_prompt(self.logs_extracted) + prompt_to_llm = self.create_prompt(self.logs_extracted) - krsllmclient.interactive_session(prompt_to_llm) + krsllmclient.interactive_session(prompt_to_llm) - self.save_state() - except Exception as e: - print(f"An error occurred during the health check: {e}") - raise e - except: - print("An error occurred during the health check.") - raise Exception("An error occurred during the health check.") - + self.save_state() - def get_logs_from_pod(self, namespace_index: int, pod_index: int) -> list: - - """ - Get the logs from a pod. - - Args: - namespace_index (int): The index of the namespace. - pod_index (int): The index of the pod. - Returns: - list: List of logs from the pod. - """ - + def get_logs_from_pod(self, namespace_index, pod_index): try: namespace_index -= 1 pod_index -= 1 @@ -539,80 +248,30 @@ def get_logs_from_pod(self, namespace_index: int, pod_index: int) -> list: return list(self.pod_info[namespace][pod_index]['info']['Logs'].values())[0] except KeyError as e: print("\nKindly enter a value from the available namespaces and pods") - return [] - except Exception as e: - print(f"An error occurred during getting logs from the pod: {e}") - raise e - except: - print("An error occurred during getting logs from the pod.") - raise Exception("An error occurred during getting logs from the pod.") - - - def create_prompt(self, log_entries: list) -> str: - - """ - Create a prompt for the LLM. - - Args: - log_entries (list): List of log entries. - Returns: - str: The prompt for the LLM. - """ - - try: - prompt = "You are a DevOps expert with experience in Kubernetes. Analyze the following log entries:\n{\n" - for entry in sorted(log_entries): # Sort to maintain consistent order - prompt += f"{entry}\n" - prompt += "}\nIf there is nothing of concern in between { }, return a message stating that 'Everything looks good!'. Explain the warnings and errors and the steps that should be taken to resolve the issues, only if they exist." - return prompt - except Exception as e: - print(f"An error occurred during creating the prompt: {e}") - raise e - except: - print("An error occurred during creating the prompt.") - raise Exception("An error occurred during creating the prompt.") - + return None + + def create_prompt(self, log_entries): + prompt = "You are a DevOps expert with experience in Kubernetes. Analyze the following log entries:\n{\n" + for entry in sorted(log_entries): # Sort to maintain consistent order + prompt += f"{entry}\n" + prompt += "}\nIf there is nothing of concern in between { }, return a message stating that 'Everything looks good!'. Explain the warnings and errors and the steps that should be taken to resolve the issues, only if they exist." + return prompt - def export_pod_info(self) -> None: + def export_pod_info(self): - """ - Export the pod info with logs and events. - - Args: - None - Returns: - None - """ - - try: - self.check_scanned() # Check if the cluster has been scanned + self.check_scanned() - with open(POD_INFO_FILEPATH, 'w') as f: # Open the file in write mode - json.dump(self.pod_info, f, cls=CustomJSONEncoder) # Dump the pod info to the file - except Exception as e: - print(f"An error occurred during exporting the pod info: {e}") - raise e - except: - print("An error occurred during exporting the pod info.") - raise Exception("An error occurred during exporting the pod info.") + with open(POD_INFO_FILEPATH, 'w') as f: + json.dump(self.pod_info, f, cls=CustomJSONEncoder) - def exit(self)-> None: - - """ - Exit the tool. - - Args: - None - Returns: - None - """ + def exit(self): try: # List all files and directories in the given directory - files = os.listdir(KRS_DATA_DIRECTORY) # Get all the files in the directory + files = os.listdir(KRS_DATA_DIRECTORY) for file in files: - file_path = os.path.join(KRS_DATA_DIRECTORY, file) + file_path = os.path.join(KRS_DATA_DIRECTORY, file) # Check if it's a file and not a directory if os.path.isfile(file_path): os.remove(file_path) # Delete the file @@ -620,20 +279,16 @@ def exit(self)-> None: except Exception as e: print(f"Error occurred: {e}") - raise e - except: - print("An error occurred during deleting the files.") - raise Exception("An error occurred during deleting the files.") def main(self): - self.scan_cluster() # Scan the cluster - self.generate_recommendations() # Generate recommendations - self.health_check() # Check the health of the cluster + self.scan_cluster() + self.generate_recommendations() + self.health_check() if __name__=='__main__': - recommender = KrsMain() # Initialize the KrsMain class - recommender.main() # Run the main function + recommender = KrsMain() + recommender.main() # logs_info = recommender.get_logs_from_pod(4,2) # print(logs_info) # logs = recommender.extract_log_entries(logs_info) diff --git a/krs/utils/cluster_scanner.py b/krs/utils/cluster_scanner.py index 372dc4b..922836b 100644 --- a/krs/utils/cluster_scanner.py +++ b/krs/utils/cluster_scanner.py @@ -1,25 +1,8 @@ -from math import e from kubernetes import client, config import logging - -# Define the KubetoolsScanner class class KubetoolsScanner: - - - def __init__(self, get_events: bool = True, get_logs: bool =True, config_file: str ='~/.kube/config') -> None: - """ - __init__ method for the KubetoolsScanner class. - - Args: - get_events (bool): Flag indicating whether to fetch events associated with the pod. - get_logs (bool): Flag indicating whether to fetch logs of the pod. - config_file (str): The path to the Kubernetes configuration file. - - Returns: - None - - """ + def __init__(self, get_events=True, get_logs=True, config_file='~/.kube/config'): self.get_events = get_events self.get_logs = get_logs self.config_file = config_file @@ -27,128 +10,46 @@ def __init__(self, get_events: bool = True, get_logs: bool =True, config_file: s self.v2 = None self.setup_kubernetes_client() - def setup_kubernetes_client(self) -> None: - - """ - Sets up the Kubernetes client using the configuration file provided. - - Args: - None - - Returns: - None - - """ - + def setup_kubernetes_client(self): try: - # Load the Kubernetes configuration - config.load_kube_config(config_file=self.config_file) # Load the Kubernetes configuration - self.v1 = client.AppsV1Api() # Create an instance of the Kubernetes AppsV1 API - self.v2 = client.CoreV1Api() # Create an instance of the Kubernetes CoreV1 API + config.load_kube_config(config_file=self.config_file) + self.v1 = client.AppsV1Api() + self.v2 = client.CoreV1Api() except Exception as e: logging.error("Failed to load Kubernetes configuration: %s", e) raise - except: - logging.error("An error occurred while setting up the Kubernetes client.") - raise - def scan_kubernetes_deployment(self) -> tuple: - - """ - Scans the Kubernetes deployment for pods, deployments, and namespaces. - - Args: - None - Returns: - tuple: A tuple containing the list of pods, pod information, deployments, and namespaces. - - """ - + def scan_kubernetes_deployment(self): try: - - # Fetch the list of pods, pod information, deployments, and namespaces - deployments = self.v1.list_deployment_for_all_namespaces() # Fetch deployments + deployments = self.v1.list_deployment_for_all_namespaces() namespaces = self.list_namespaces() - - pod_dict = {} - pod_list = [] - for name in namespaces: - pods = self.list_pods(name) - pod_list += pods - pod_dict[name] = [{'name': pod, 'info': self.get_pod_info(name, pod)} for pod in pods] # Fetch pod info - - # Extract the names of the pods and deployments - deployment_list = [dep.metadata.name for dep in deployments.items] # List deployment names - return pod_list, pod_dict, deployment_list, namespaces # Return the list of pods, pod info, deployments, and namespaces - - except client.rest.ApiException as e: - logging.error("Error fetching data from Kubernetes API: %s", e) - return [], {}, [], [] except Exception as e: logging.error("Error fetching data from Kubernetes API: %s", e) return {}, {}, [] - except: - logging.error("An error occurred while fetching data from the Kubernetes API.") - return {}, {}, [] + pod_dict = {} + pod_list = [] + for name in namespaces: + pods = self.list_pods(name) + pod_list += pods + pod_dict[name] = [{'name': pod, 'info': self.get_pod_info(name, pod)} for pod in pods] - def list_namespaces(self) -> list: - - """ - Lists all the namespaces in the Kubernetes cluster. + deployment_list = [dep.metadata.name for dep in deployments.items] + return pod_list, pod_dict, deployment_list, namespaces - Args: - None - - Returns: - list: A list of namespace names. - """ - - try: - namespaces = self.v2.list_namespace() # List all namespaces - return [namespace.metadata.name for namespace in namespaces.items] # Extract namespace names - except Exception as e: - logging.error("Failed to list namespaces: %s", e) - return [] - except: - logging.error("An error occurred while listing namespaces.") - return [] + def list_namespaces(self): + namespaces = self.v2.list_namespace() + return [namespace.metadata.name for namespace in namespaces.items] - def list_pods_all(self) -> list: - - """ - Lists all the pods in all namespaces in the Kubernetes cluster. - - Args: - None - - Returns: - list: A list of pod names. - """ - - try: - pods = self.v2.list_pod_for_all_namespaces() # List all pods - return [pod.metadata.name for pod in pods.items] # Extract pod names - except Exception as e: - logging.error("Failed to list pods: %s", e) - return [] - except: - logging.error("An error occurred while listing pods.") - return [] + def list_pods_all(self): + pods = self.v2.list_pod_for_all_namespaces() + return [pod.metadata.name for pod in pods.items] - def list_pods(self, namespace : str) -> list: - - try: - pods = self.v2.list_namespaced_pod(namespace) # List pods in a specific namespace - return [pod.metadata.name for pod in pods.items] # Extract pod names - except Exception as e: - logging.error("Failed to list pods in namespace %s: %s", namespace, e) - return [] - except: - logging.error("An error occurred while listing pods in namespace %s.", namespace) - return [] + def list_pods(self, namespace): + pods = self.v2.list_namespaced_pod(namespace) + return [pod.metadata.name for pod in pods.items] - def get_pod_info(self, namespace : str, pod: str, include_events: bool =True, include_logs: bool = True) -> dict: + def get_pod_info(self, namespace, pod, include_events=True, include_logs=True): """ Retrieves information about a specific pod in a given namespace. @@ -161,80 +62,42 @@ def get_pod_info(self, namespace : str, pod: str, include_events: bool =True, in Returns: dict: A dictionary containing the pod information, events (if include_events is True), and logs (if include_logs is True). """ - - try: - - # Fetch pod information - pod_info = self.v2.read_namespaced_pod(pod, namespace) # Read pod information - pod_info_map = pod_info.to_dict() # Convert to dictionary - pod_info_map["metadata"]["managed_fields"] = None # Clean up metadata - - info = {'PodInfo': pod_info_map} - - if include_events: - info['Events'] = self.fetch_pod_events(namespace, pod) - - if include_logs: - # Retrieve logs for all containers within the pod - container_logs = {} - - # Fetch logs for each container in the pod - for container in pod_info.spec.containers: - try: - logs = self.v2.read_namespaced_pod_log(name=pod, namespace=namespace, container=container.name) # Read logs - container_logs[container.name] = logs # Store logs in a dictionary - except Exception as e: - logging.error("Failed to fetch logs for container %s in pod %s: %s", container.name, pod, e) - container_logs[container.name] = "Error fetching logs: " + str(e) - info['Logs'] = container_logs - - return info - except client.rest.ApiException as e: - logging.error("Error fetching pod info: %s", e) - return {} - except Exception as e: - logging.error("Error fetching pod info: %s", e) - return {} - except: - logging.error("An error occurred while fetching pod info.") - return {} - - def fetch_pod_events(self, namespace: str, pod: str) -> list: - - """ - Fetches the events associated with a specific pod in a given namespace. - - Args: - namespace (str): The namespace of the pod. - pod (str): The name of the pod. - - Returns: - list: A list of event details. - - """ - - try: - events = self.v2.list_namespaced_event(namespace) - return [{ - 'Name': event.metadata.name, - 'Message': event.message, - 'Reason': event.reason - } for event in events.items if event.involved_object.name == pod] - except client.rest.ApiException as e: - logging.error("Error fetching events for pod %s in namespace %s: %s", pod, namespace, e) - return [] - except Exception as e: - logging.error("Error fetching events for pod %s in namespace %s: %s", pod, namespace, e) - return [] - except: - logging.error("An error occurred while fetching events for pod %s in namespace %s.", pod, namespace) - return [] + pod_info = self.v2.read_namespaced_pod(pod, namespace) + pod_info_map = pod_info.to_dict() + pod_info_map["metadata"]["managed_fields"] = None # Clean up metadata + + info = {'PodInfo': pod_info_map} + + if include_events: + info['Events'] = self.fetch_pod_events(namespace, pod) + + if include_logs: + # Retrieve logs for all containers within the pod + container_logs = {} + for container in pod_info.spec.containers: + try: + logs = self.v2.read_namespaced_pod_log(name=pod, namespace=namespace, container=container.name) + container_logs[container.name] = logs + except Exception as e: + logging.error("Failed to fetch logs for container %s in pod %s: %s", container.name, pod, e) + container_logs[container.name] = "Error fetching logs: " + str(e) + info['Logs'] = container_logs + + return info + + def fetch_pod_events(self, namespace, pod): + events = self.v2.list_namespaced_event(namespace) + return [{ + 'Name': event.metadata.name, + 'Message': event.message, + 'Reason': event.reason + } for event in events.items if event.involved_object.name == pod] if __name__ == '__main__': - scanner = KubetoolsScanner() # Initialize the KubetoolsScanner - pod_list, pod_info, deployments, namespaces = scanner.scan_kubernetes_deployment() # Scan the Kubernetes deployment + scanner = KubetoolsScanner() + pod_list, pod_info, deployments, namespaces = scanner.scan_kubernetes_deployment() print("POD List: \n\n", pod_list) print("\n\nPOD Info: \n\n", pod_info.keys()) print("\n\nNamespaces: \n\n", namespaces) diff --git a/krs/utils/fetch_tools_krs.py b/krs/utils/fetch_tools_krs.py index 8d28d29..af20e27 100644 --- a/krs/utils/fetch_tools_krs.py +++ b/krs/utils/fetch_tools_krs.py @@ -3,227 +3,79 @@ import yaml from krs.utils.constants import (KUBETOOLS_DATA_JSONURL, KUBETOOLS_JSONPATH, CNCF_YMLPATH, CNCF_YMLURL, CNCF_TOOLS_JSONPATH, TOOLS_RANK_JSONPATH, CATEGORY_RANK_JSONPATH) - - # Function to convert 'githubStars' to a float, or return 0 if it cannot be converted -def get_github_stars(tool: dict) -> float: - - """ - get_github_stars checks for the tool’s star rating, if one exists, - it returns the star rating as a float data type, else, it returns 0. - - Args: - -tool(dict) - Arg that holds info about a tool. - - Returns: - -float: A tool’s star rating - - """ +def get_github_stars(tool): + stars = tool.get('githubStars', 0) try: - stars = tool.get('githubStars', 0) return float(stars) - except ValueError as e: - print(f"Error: {e}") - return {} - except TypeError as e: - print(f"Error: {e}") - return {} - except Exception as e: - print(f"Error: {e}") - return {} - except: - print("An error occurred while converting the star rating.") - return {} - + except ValueError: + return 0.0 # Function to download and save a file -def download_file(url: str, filename: str) -> None: - - """ - Downloads a file from the specified URL and saves it to the given filename. - - Args: - url (str): The URL of the file to download. - filename (str): The path to save the downloaded file. - - Returns: - None - """ - try: - response = requests.get(url) - response.raise_for_status() # Ensure we notice bad responses - with open(filename, 'wb') as file: - file.write(response.content) - except requests.exceptions.RequestException as e: - print(f"Error: {e}") - return {} - except FileNotFoundError as e: - print(f"Error: {e}") - return {} - except Exception as e: - print(f"Error: {e}") - return {} - except: - print("An error occurred while downloading the file.") - return {} - -def parse_yaml_to_dict(yaml_file_path: str) -> dict: - - """ - Parses a YAML file and returns a dictionary with tool names as keys and their statuses as values. - - This function specifically targets the structure of the CNCF landscape YAML file, extracting tool names and their project statuses. - - Args: - yaml_file_path (str): The file path of the YAML file to parse. - - Returns: - dict: A dictionary with tool names as keys and their project statuses as values. - """ - - try: - with open(yaml_file_path, 'r') as file: - data = yaml.safe_load(file) - except FileNotFoundError: - print(f"Error: The file {yaml_file_path} was not found.") - return {} - except yaml.YAMLError as e: - print(f"Error parsing the YAML file: {e}") - return {} - except Exception as e: - print(f"Error: {e}") - return {} - except: - print("An error occurred while parsing the YAML file.") - return {} +def download_file(url, filename): + response = requests.get(url) + response.raise_for_status() # Ensure we notice bad responses + with open(filename, 'wb') as file: + file.write(response.content) + +def parse_yaml_to_dict(yaml_file_path): + with open(yaml_file_path, 'r') as file: + data = yaml.safe_load(file) cncftools = {} - try: - - # Extract tool names and project statuses from the CNCF landscape YAML file - for category in data.get('landscape', []): - for subcategory in category.get('subcategories', []): - for item in subcategory.get('items', []): - item_name = item.get('name').lower() - project_status = item.get('project', 'listed') - cncftools[item_name] = project_status - except AttributeError as e: - print(f"Error processing the YAML file: {e}") - return {} - except ValueError as e: - print(f"Error processing data: {e}") - return {} - except Exception as e: - print(f"Error: {e}") - return {} - except: - print("An error occurred while processing the YAML file.") - return {} + for category in data.get('landscape', []): + for subcategory in category.get('subcategories', []): + for item in subcategory.get('items', []): + item_name = item.get('name').lower() + project_status = item.get('project', 'listed') + cncftools[item_name] = project_status return {'cncftools': cncftools} -def save_json_file(jsondict: dict, jsonpath: str) -> None: - - """ - Saves a dictionary to a JSON file at the specified path. - - This function takes a dictionary and writes it to a JSON file, formatting the output for readability. +def save_json_file(jsondict, jsonpath): - Args: - jsondict (dict): The dictionary to save. - jsonpath (str): The file path where the JSON file will be saved. + # Write the category dictionary to a new JSON file + with open(jsonpath, 'w') as f: + json.dump(jsondict, f, indent=4) - Returns: - None - """ - try: - # Write the category dictionary to a new JSON file - with open(jsonpath, 'w') as f: - json.dump(jsondict, f, indent=4) - except FileNotFoundError as e: - print(f"Error: {e}") - return {} - except Exception as e: - print(f"Error: {e}") - return {} - except: - print("An error occurred while saving the JSON file.") - return {} - - -def krs_tool_ranking_info()-> tuple: - - """ - krs_tool_ranking_info fetches the tool ranking data from the KubeTools API and the CNCF landscape YAML file, - processes this data to rank tools within categories based on GitHub stars, and saves the processed data to JSON files. - - Returns: - Tuple containing three dictionaries: - - tools_dict: Dictionary mapping tool names to their rankings, categories, and URLs. - - category_tools_dict: Dictionary mapping category names to dictionaries of tools ranked within the category. - - cncf_tools_dict: Dictionary representing the parsed CNCF landscape YAML file. - """ - - +def krs_tool_ranking_info(): # New dictionaries tools_dict = {} category_tools_dict = {} - try: - download_file(KUBETOOLS_DATA_JSONURL, KUBETOOLS_JSONPATH) # Download the KubeTools JSON file - download_file(CNCF_YMLURL, CNCF_YMLPATH) # Download the CNCF landscape YAML file - except Exception as e: - print(f"Error: {e}") - return {} - except: - print("An error occurred while downloading the files.") - return {} - - try: - with open(KUBETOOLS_JSONPATH) as f: - data = json.load(f) + download_file(KUBETOOLS_DATA_JSONURL, KUBETOOLS_JSONPATH) + download_file(CNCF_YMLURL, CNCF_YMLPATH) + + with open(KUBETOOLS_JSONPATH) as f: + data = json.load(f) + + for category in data: + # Sort the tools in the current category by the number of GitHub stars + sorted_tools = sorted(category['tools'], key=get_github_stars, reverse=True) + + for i, tool in enumerate(sorted_tools, start=1): + tool["name"] = tool['name'].replace("\t", "").lower() + tool['ranking'] = i + + # Update tools_dict + tools_dict.setdefault(tool['name'], []).append({ + 'rank': i, + 'category': category['category']['name'], + 'url': tool['link'] + }) + + # Update ranked_tools_dict + category_tools_dict.setdefault(category['category']['name'], {}).update({i: {'name': tool['name'], 'url': tool['link']}}) - # Process the KubeTools data to rank tools within categories based on GitHub stars - for category in data: - # Sort the tools in the current category by the number of GitHub stars - sorted_tools = sorted(category['tools'], key=get_github_stars, reverse=True) - - for i, tool in enumerate(sorted_tools, start=1): - tool["name"] = tool['name'].replace("\t", "").lower() - tool['ranking'] = i - - # Update tools_dict - tools_dict.setdefault(tool['name'], []).append({ - 'rank': i, - 'category': category['category']['name'], - 'url': tool['link'] - }) - - # Update category_tools_dict - category_tools_dict.setdefault(category['category']['name'], {}).update({i: {'name': tool['name'], 'url': tool['link']}}) - - - cncf_tools_dict = parse_yaml_to_dict(CNCF_YMLPATH) # Parse the CNCF landscape YAML file - save_json_file(cncf_tools_dict, CNCF_TOOLS_JSONPATH) # Save the CNCF landscape dictionary to a JSON file - save_json_file(tools_dict, TOOLS_RANK_JSONPATH) # Save the tools dictionary to a JSON file - save_json_file(category_tools_dict, CATEGORY_RANK_JSONPATH) # Save the category dictionary to a JSON file - - return tools_dict, category_tools_dict, cncf_tools_dict - - except FileNotFoundError as e: - print(f"Error: {e}") - return {} - except json.JSONDecodeError as e: - print(f"Error decoding JSON: {e}") - return {} - except Exception as e: - print(f"Error: {e}") - return {} - except: - print("An error occurred while processing the data.") - return {} + + cncf_tools_dict = parse_yaml_to_dict(CNCF_YMLPATH) + save_json_file(cncf_tools_dict, CNCF_TOOLS_JSONPATH) + save_json_file(tools_dict, TOOLS_RANK_JSONPATH) + save_json_file(category_tools_dict, CATEGORY_RANK_JSONPATH) + + return tools_dict, category_tools_dict, cncf_tools_dict if __name__=='__main__': tools_dict, category_tools_dict, cncf_tools_dict = krs_tool_ranking_info() diff --git a/krs/utils/functional.py b/krs/utils/functional.py index e5d2c0d..b4e9b17 100644 --- a/krs/utils/functional.py +++ b/krs/utils/functional.py @@ -1,152 +1,75 @@ from difflib import SequenceMatcher -from math import e import re, json from datetime import datetime class CustomJSONEncoder(json.JSONEncoder): - """ - JSON Encoder for complex objects not serializable by default json code. - """ - def default(self, obj: object) -> object: - - """ - Serialize datetime objects to ISO 8601 format. - - Args: - obj (object): Object to serialize. - - Returns: - object: Serialized object. - """ - try: - if isinstance(obj, datetime): - # Format datetime object as a string in ISO 8601 format - return obj.isoformat() - # Let the base class default method raise the TypeError - return json.JSONEncoder.default(self, obj) - except TypeError as e: - return str(obj) - except Exception as e: - raise e - except: - print("An error occurred during serialization.") - raise + """JSON Encoder for complex objects not serializable by default json code.""" + def default(self, obj): + if isinstance(obj, datetime): + # Format datetime object as a string in ISO 8601 format + return obj.isoformat() + # Let the base class default method raise the TypeError + return json.JSONEncoder.default(self, obj) -def similarity(a : str, b: str) -> float: - """ - Calculate the similarity ratio between two strings using SequenceMatcher. - - Args: - a (str): First string. - b (str): Second string. - Returns: - float: Similarity ratio between the two strings. - - """ - - try: - return SequenceMatcher(None, a, b).ratio() - except Exception as e: - print(f"An error occurred during similarity calculation: {e}") - return 0.0 - except: - print("An error occurred during similarity calculation.") - return 0.0 +def similarity(a, b): + return SequenceMatcher(None, a, b).ratio() -def filter_similar_entries(log_entries : list) -> list: - - """ - Filter out highly similar log entries from a list of log entries. - - Args: - log_entries (list): List of log entries. - Returns: - list: Filtered list of log entries. - """ - - try: - unique_entries = list(log_entries) - to_remove = set() +def filter_similar_entries(log_entries): + unique_entries = list(log_entries) + to_remove = set() - # Compare each pair of log entries - for i in range(len(unique_entries)): - for j in range(i + 1, len(unique_entries)): - if similarity(unique_entries[i], unique_entries[j]) > 0.85: - # Choose the shorter entry to remove, or either if they are the same length - if len(unique_entries[i]) > len(unique_entries[j]): - to_remove.add(unique_entries[i]) - else: - to_remove.add(unique_entries[j]) + # Compare each pair of log entries + for i in range(len(unique_entries)): + for j in range(i + 1, len(unique_entries)): + if similarity(unique_entries[i], unique_entries[j]) > 0.85: + # Choose the shorter entry to remove, or either if they are the same length + if len(unique_entries[i]) > len(unique_entries[j]): + to_remove.add(unique_entries[i]) + else: + to_remove.add(unique_entries[j]) - # Filter out the highly similar entries - filtered_entries = {entry for entry in unique_entries if entry not in to_remove} - return filtered_entries - except Exception as e: - print(f"An error occurred during filtering of log entries: {e}") - return [] - except: - print("An error occurred during filtering of log entries.") - return [] + # Filter out the highly similar entries + filtered_entries = {entry for entry in unique_entries if entry not in to_remove} + return filtered_entries - -def extract_log_entries(log_contents : str) -> list: - - """ - Extract log entries from a string containing log data. - - Args: - log_contents (str): String containing log data. - Returns: - list: List of extracted log entries. - """ - - +def extract_log_entries(log_contents): # Patterns to match different log formats - - try: - patterns = [ - re.compile(r'\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.\d{6}Z\s+(warn|error)\s+\S+\s+(.*)', re.IGNORECASE), - re.compile(r'[WE]\d{4} \d{2}:\d{2}:\d{2}.\d+\s+\d+\s+(.*)'), - re.compile(r'({.*})') - ] # JSON log entry pattern - - log_entries = set() - # Attempt to match each line with all patterns - for line in log_contents.split('\n'): - for pattern in patterns: - match = pattern.search(line) - if match: - if match.groups()[0].startswith('{'): - # Handle JSON formatted log entries - try: - log_json = json.loads(match.group(1)) # Extract JSON object - if 'severity' in log_json and log_json['severity'].lower() in ['error', 'warning']: # Check for severity - level = "Error" if log_json['severity'] == "ERROR" else "Warning" # Map severity to Error or Warning - message = log_json.get('error', '') if 'error' in log_json.keys() else line # Extract error message - log_entries.add(f"{level}: {message.strip()}") # Add formatted log entry - elif 'level' in log_json: # Check for level - level = "Error" if log_json['level'] == "error" else "Warning" # Map level to Error or Warning - message = log_json.get('msg', '') + log_json.get('error', '') # Extract message - log_entries.add(f"{level}: {message.strip()}") # Add formatted log entry - except json.JSONDecodeError: # Skip if JSON is not valid - continue # Skip if JSON is not valid - else: - if len(match.groups()) == 2: - level, message = match.groups() - elif len(match.groups()) == 1: - message = match.group(1) # Assuming error as default - level = "ERROR" # Default if not specified in the log + patterns = [ + re.compile(r'\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.\d{6}Z\s+(warn|error)\s+\S+\s+(.*)', re.IGNORECASE), + re.compile(r'[WE]\d{4} \d{2}:\d{2}:\d{2}.\d+\s+\d+\s+(.*)'), + re.compile(r'({.*})') + ] - level = "Error" if "error" in level.lower() else "Warning" # Map level to Error or Warning - formatted_message = f"{level}: {message.strip()}" # Format log entry - log_entries.add(formatted_message) # Add formatted log entry - break # Stop after the first match + log_entries = set() + # Attempt to match each line with all patterns + for line in log_contents.split('\n'): + for pattern in patterns: + match = pattern.search(line) + if match: + if match.groups()[0].startswith('{'): + # Handle JSON formatted log entries + try: + log_json = json.loads(match.group(1)) + if 'severity' in log_json and log_json['severity'].lower() in ['error', 'warning']: + level = "Error" if log_json['severity'] == "ERROR" else "Warning" + message = log_json.get('error', '') if 'error' in log_json.keys() else line + log_entries.add(f"{level}: {message.strip()}") + elif 'level' in log_json: + level = "Error" if log_json['level'] == "error" else "Warning" + message = log_json.get('msg', '') + log_json.get('error', '') + log_entries.add(f"{level}: {message.strip()}") + except json.JSONDecodeError: + continue # Skip if JSON is not valid + else: + if len(match.groups()) == 2: + level, message = match.groups() + elif len(match.groups()) == 1: + message = match.group(1) # Assuming error as default + level = "ERROR" # Default if not specified in the log - return filter_similar_entries(log_entries) # Filter out highly similar log entries + level = "Error" if "error" in level.lower() else "Warning" + formatted_message = f"{level}: {message.strip()}" + log_entries.add(formatted_message) + break # Stop after the first match - except Exception as e: - print(f"An error occurred during pattern creation: {e}") - return [] - except: - print("An error occurred during pattern creation.") - return [] \ No newline at end of file + return filter_similar_entries(log_entries) \ No newline at end of file diff --git a/krs/utils/llm_client.py b/krs/utils/llm_client.py index dcb9de3..f46265a 100644 --- a/krs/utils/llm_client.py +++ b/krs/utils/llm_client.py @@ -3,135 +3,51 @@ import os, time from krs.utils.constants import (MAX_OUTPUT_TOKENS, LLMSTATE_PICKLE_FILEPATH) - -# This class is used to interact with the OpenAI API or Huggingface API to generate responses for the given prompts. class KrsGPTClient: - # The constructor initializes the client and pipeline objects, and loads the state from the pickle file. - - def __init__(self, reinitialize : bool = False, reset_history : bool = False) -> None: - - - """ - - Initializes the KrsGPTClient object. - - Args: - reinitialize (bool): Flag to indicate whether to reinitialize the client. - reset_history (bool): Flag to indicate whether to reset the chat history. - Returns: - None - - """ - - try: - self.reinitialize = reinitialize - self.client = None - self.pipeline = None - self.provider = None - self.model = None - self.openai_api_key = None - self.continue_chat = False - self.history = [] - self.max_tokens = MAX_OUTPUT_TOKENS - - except ValueError as e: - print(f"An error occurred during initialization: {e}") - raise - except TypeError as e: - print(f"An error occurred during initialization: {e}") - raise - except AttributeError as e: - print(f"An error occurred during initialization: {e}") - raise - except KeyError as e: - print(f"An error occurred during initialization: {e}") - raise - except Exception as e: - print(f"An error occurred during initialization: {e}") - raise - except: - print("An error occurred during initialization.") - raise - - - try: - if not self.reinitialize: - print("\nLoading LLM State..") - self.load_state() - print("\nModel: ", self.model) - if not self.model: - self.initialize_client() - - self.history = [] if reset_history == True else self.history - - if self.history: - continue_chat = input("\n\nDo you want to continue previous chat ? (y/n) >> ") - while continue_chat not in ['y', 'n']: - print("Please enter either y or n!") - continue_chat = input("\nDo you want to continue previous chat ? (y/n) >> ") - if continue_chat == 'No': - self.history = [] - else: - self.continue_chat = True - except ValueError as e: - print(f"An error occurred during initialization: {e}") - raise - except TypeError as e: - print(f"An error occurred during initialization: {e}") - raise - except AttributeError as e: - print(f"An error occurred during initialization: {e}") - raise - except KeyError as e: - print(f"An error occurred during initialization: {e}") - raise - except Exception as e: - print(f"An error occurred during initialization: {e}") - raise - except: - print("An error occurred during initialization.") - raise - - def save_state(self, filename: str = LLMSTATE_PICKLE_FILEPATH) -> None: - - """ - save_state saves the state of the client to a pickle file. - - Args: - filename (str): The path to the pickle file. - Returns: - None - """ - - try: - state = { - 'provider': self.provider, - 'model': self.model, - 'history': self.history, - 'openai_api_key': self.openai_api_key - } - with open(filename, 'wb') as output: - pickle.dump(state, output, pickle.HIGHEST_PROTOCOL) # Save the state to a pickle file - except FileNotFoundError as e: - print(f"Error: {e}") - return {} - except Exception as e: - print(f"Error: {e}") - return {} - - - def load_state(self) -> None: - - """ - load_state loads the state of the client from a pickle file. - - Args: - None - Returns: - None - """ - + def __init__(self, reinitialize=False, reset_history=False): + + self.reinitialize = reinitialize + self.client = None + self.pipeline = None + self.provider = None + self.model = None + self.openai_api_key = None + self.continue_chat = False + self.history = [] + self.max_tokens = MAX_OUTPUT_TOKENS + + + if not self.reinitialize: + print("\nLoading LLM State..") + self.load_state() + print("\nModel: ", self.model) + if not self.model: + self.initialize_client() + + self.history = [] if reset_history == True else self.history + + if self.history: + continue_chat = input("\n\nDo you want to continue previous chat ? (y/n) >> ") + while continue_chat not in ['y', 'n']: + print("Please enter either y or n!") + continue_chat = input("\nDo you want to continue previous chat ? (y/n) >> ") + if continue_chat == 'No': + self.history = [] + else: + self.continue_chat = True + + def save_state(self, filename=LLMSTATE_PICKLE_FILEPATH): + state = { + 'provider': self.provider, + 'model': self.model, + 'history': self.history, + 'openai_api_key': self.openai_api_key + } + with open(filename, 'wb') as output: + pickle.dump(state, output, pickle.HIGHEST_PROTOCOL) + + def load_state(self): try: with open(LLMSTATE_PICKLE_FILEPATH, 'rb') as f: state = pickle.load(f) @@ -144,414 +60,143 @@ def load_state(self) -> None: elif self.provider == 'huggingface': self.init_huggingface_client(reinitialize=True) except (FileNotFoundError, EOFError): - print("No previous state found.") - except Exception as e: - print(f"Error loading state: {e}") - return {} - except: - print("An error occurred while loading the state.") - return {} - - def install_package(self, package_name: str) -> None: - - """ - - install_package installs the required package using pip. - - Args: - package_name (str): The name of the package to install. - - Returns: - None - - """ - + pass + + def install_package(self, package_name): import importlib try: importlib.import_module(package_name) print(f"\n{package_name} is already installed.") except ImportError: print(f"\nInstalling {package_name}...", end='', flush=True) - result = subprocess.run(['pip', 'install', package_name], stdout=subprocess.PIPE, stderr=subprocess.PIPE) # Install the package using pip + result = subprocess.run(['pip', 'install', package_name], stdout=subprocess.PIPE, stderr=subprocess.PIPE) if result.returncode == 0: print(f" \n{package_name} installed successfully.") else: print(f" \nFailed to install {package_name}.") - except Exception as e: - print(f"An error occurred while installing the package: {e}") - raise - except: - print("An error occurred while installing the package.") - raise - def initialize_client(self) -> None: - - """ - - initialize_client initializes the client based on the user's choice of model provider. - - Args: - None - Returns: - None - """ - - try: - if not self.client and not self.pipeline: - choice = input("\nChoose the model provider for healthcheck: \n\n[1] OpenAI \n[2] Huggingface\n\n>> ") - if choice == '1': - self.init_openai_client() - elif choice == '2': - self.init_huggingface_client() - else: - raise ValueError("Invalid option selected") - except ValueError as e: - print(f"An error occurred during initialization: {e}") - raise - except TypeError as e: - print(f"An error occurred during initialization: {e}") - raise - except AttributeError as e: - print(f"An error occurred during initialization: {e}") - raise - except KeyError as e: - print(f"An error occurred during initialization: {e}") - raise - except Exception as e: - print(f"An error occurred during initialization: {e}") - raise - except: - print("An error occurred during initialization.") - raise - - def init_openai_client(self, reinitialize: bool = False) -> None: - - - """ - - init_openai_client initializes the OpenAI client. - - Args: - reinitialize (bool): Flag to indicate whether to reinitialize the client. - - Returns: - None - - """ + def initialize_client(self): + if not self.client and not self.pipeline: + choice = input("\nChoose the model provider for healthcheck: \n\n[1] OpenAI \n[2] Huggingface\n\n>> ") + if choice == '1': + self.init_openai_client() + elif choice == '2': + self.init_huggingface_client() + else: + raise ValueError("Invalid option selected") - try: - if not reinitialize: - print("\nInstalling necessary libraries..........") - self.install_package('openai') - except ValueError as e: - print(f"An error occurred during initialization: {e}") - raise - except TypeError as e: - print(f"An error occurred during initialization: {e}") - raise - except AttributeError as e: - print(f"An error occurred during initialization: {e}") - raise - except KeyError as e: - print(f"An error occurred during initialization: {e}") - raise - except Exception as e: - print(f"An error occurred during initialization: {e}") - raise - except: - print("An error occurred during initialization.") - raise - - try: - - import openai - from openai import OpenAI - - self.provider = 'OpenAI' - self.openai_api_key = input("\nEnter your OpenAI API key: ") if not reinitialize else self.openai_api_key - self.model = input("\nEnter the OpenAI model name: ") if not reinitialize else self.model - - self.client = OpenAI(api_key=self.openai_api_key) - - if not reinitialize or self.reinitialize: - while True: - try: - self.validate_openai_key() - break - except openai.error.AuthenticationError: - self.openai_api_key = input("\nInvalid Key! Please enter the correct OpenAI API key: ") - except openai.error.InvalidRequestError as e: - print(e) - self.model = input("\nEnter an OpenAI model name from latest OpenAI docs: ") - except openai.APIConnectionError as e: - print(e) - self.init_openai_client(reinitialize=False) - - self.save_state() - except ValueError as e: - print(f"An error occurred during initialization: {e}") - raise - except TypeError as e: - print(f"An error occurred during initialization: {e}") - raise - except AttributeError as e: - print(f"An error occurred during initialization: {e}") - raise - except KeyError as e: - print(f"An error occurred during initialization: {e}") - raise - except Exception as e: - print(f"An error occurred during initialization: {e}") - raise - except: - print("An error occurred during initialization.") - raise - - def init_huggingface_client(self, reinitialize: bool = False) -> None: - - """ - - init_huggingface_client initializes the Huggingface client. - - Args: - reinitialize (bool): Flag to indicate whether to reinitialize the client. - - Returns: - None - - """ - - try: + def init_openai_client(self, reinitialize=False): - if not reinitialize: - print("\nInstalling necessary libraries..........") - self.install_package('transformers') - self.install_package('torch') - - except ValueError as e: - print(f"An error occurred during initialization: {e}") - raise - except TypeError as e: - print(f"An error occurred during initialization: {e}") - raise - except AttributeError as e: - print(f"An error occurred during initialization: {e}") - raise - except KeyError as e: - print(f"An error occurred during initialization: {e}") - raise - - try: - os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Suppress TensorFlow warnings - - import warnings - from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer - - warnings.filterwarnings("ignore", category=FutureWarning) # Suppress FutureWarnings - - self.provider = 'huggingface' - self.model = input("\nEnter the Huggingface model name: ") if not reinitialize else self.model - - try: - self.tokenizer = AutoTokenizer.from_pretrained(self.model) # Load the tokenizer - self.model_hf = AutoModelForCausalLM.from_pretrained(self.model) # Load the model - self.pipeline = pipeline('text-generation', model=self.model_hf, tokenizer=self.tokenizer) # Create a pipeline - - except OSError as e: - print("\nError loading model: ", e) - print("\nPlease enter a valid Huggingface model name.") - self.init_huggingface_client(reinitialize=True) - - self.save_state() - except ValueError as e: - print(f"An error occurred during initialization: {e}") - raise - except TypeError as e: - print(f"An error occurred during initialization: {e}") - raise - except AttributeError as e: - print(f"An error occurred during initialization: {e}") - raise - except KeyError as e: - print(f"An error occurred during initialization: {e}") - raise - except Exception as e: - print(f"An error occurred during initialization: {e}") - raise - - - - def validate_openai_key(self) -> None: - - """ - Validate the OpenAI API key by attempting a small request. - - Args: - None - Returns: - None - - """ - - try: - - # Test the API key by sending a small request - response = self.client.chat.completions.create( - model=self.model, - messages=[{"role": "user", "content": "Test prompt, do nothing"}], - max_tokens=5 - ) - print("API key and model are valid.") - except openai.error.InvalidRequestError as e: - print(f"Error: {e}") - raise - except openai.error.AuthenticationError as e: - print(f"Error: {e}") - raise - except openai.APIConnectionError as e: - print(f"Error: {e}") - raise - except Exception as e: - print(f"Error: {e}") - raise - - - def infer(self, prompt: str) -> None: + if not reinitialize: + print("\nInstalling necessary libraries..........") + self.install_package('openai') - """ - infer generates a response for the given prompt. - - Args: - prompt (str): The user prompt. - - Returns: - None - """ - - try: - self.history.append({"role": "user", "content": prompt}) - input_prompt = self.history_to_prompt() + import openai + from openai import OpenAI - if self.provider == 'OpenAI': - response = self.client.chat.completions.create( # Generate a response - model=self.model, - messages=input_prompt, - max_tokens = self.max_tokens - ) - output = response.choices[0].message.content.strip() - - elif self.provider == 'huggingface': - responses = self.pipeline(input_prompt, max_new_tokens=self.max_tokens) # Generate a response - output = responses[0]['generated_text'] - - self.history.append({"role": "assistant", "content": output}) - print(">> ", output) - except ValueError as e: - print(f"An error occurred during inference: {e}") - raise - except TypeError as e: - print(f"An error occurred during inference: {e}") - raise - except AttributeError as e: - print(f"An error occurred during inference: {e}") - raise - except KeyError as e: - print(f"An error occurred during inference: {e}") - raise - except Exception as e: - print(f"An error occurred during inference: {e}") - raise - except: - print("An error occurred during inference.") - raise - - def interactive_session(self, prompt_input: str) -> None: - - """ - interactive_session starts an interactive session with the user. - - Args: - prompt_input (str): The initial prompt to start the conversation. - Returns: - None - - """ - - try: - - print("\nInteractive session started. Type 'end chat' to exit from the session!\n") + self.provider = 'OpenAI' + self.openai_api_key = input("\nEnter your OpenAI API key: ") if not reinitialize else self.openai_api_key + self.model = input("\nEnter the OpenAI model name: ") if not reinitialize else self.model - if self.continue_chat: - print('>> ', self.history[-1]['content']) - else: - initial_prompt = prompt_input - self.infer(initial_prompt) # Generate a response for the initial prompt + self.client = OpenAI(api_key=self.openai_api_key) + if not reinitialize or self.reinitialize: while True: - prompt = input("\n>> ") - if prompt.lower() == 'end chat': + try: + self.validate_openai_key() break - self.infer(prompt) - self.save_state() - - except ValueError as e: - print(f"An error occurred during the interactive session: {e}") - raise - except TypeError as e: - print(f"An error occurred during the interactive session: {e}") - raise - except AttributeError as e: - print(f"An error occurred during the interactive session: {e}") - raise - except KeyError as e: - print(f"An error occurred during the interactive session: {e}") - raise - except Exception as e: - print(f"An error occurred during the interactive session: {e}") - raise - except: - print("An error occurred during the interactive session.") - raise - - - - def history_to_prompt(self) -> list: - - """ - history_to_prompt converts the chat history to a prompt format. - - Args: - None - Returns: - list: The chat history in prompt format. - """ + except openai.error.AuthenticationError: + self.openai_api_key = input("\nInvalid Key! Please enter the correct OpenAI API key: ") + except openai.error.InvalidRequestError as e: + print(e) + self.model = input("\nEnter an OpenAI model name from latest OpenAI docs: ") + except openai.APIConnectionError as e: + print(e) + self.init_openai_client(reinitialize=False) + + self.save_state() + + def init_huggingface_client(self, reinitialize=False): + + if not reinitialize: + print("\nInstalling necessary libraries..........") + self.install_package('transformers') + self.install_package('torch') + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + + import warnings + from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer + + warnings.filterwarnings("ignore", category=FutureWarning) + + self.provider = 'huggingface' + self.model = input("\nEnter the Huggingface model name: ") if not reinitialize else self.model + try: - if self.provider == 'OpenAI': - return self.history - elif self.provider == 'huggingface': - return " ".join([item["content"] for item in self.history]) - except ValueError as e: - print(f"An error occurred during history conversion: {e}") - raise - except TypeError as e: - print(f"An error occurred during history conversion: {e}") - raise - except AttributeError as e: - print(f"An error occurred during history conversion: {e}") - raise - except KeyError as e: - print(f"An error occurred during history conversion: {e}") - raise - except Exception as e: - print(f"An error occurred during history conversion: {e}") - raise - except: - print("An error occurred during history conversion.") - raise + self.tokenizer = AutoTokenizer.from_pretrained(self.model) + self.model_hf = AutoModelForCausalLM.from_pretrained(self.model) + self.pipeline = pipeline('text-generation', model=self.model_hf, tokenizer=self.tokenizer) + + except OSError as e: + print("\nError loading model: ", e) + print("\nPlease enter a valid Huggingface model name.") + self.init_huggingface_client(reinitialize=True) + + self.save_state() + + def validate_openai_key(self): + """Validate the OpenAI API key by attempting a small request.""" + response = self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": "Test prompt, do nothing"}], + max_tokens=5 + ) + print("API key and model are valid.") + + def infer(self, prompt): + self.history.append({"role": "user", "content": prompt}) + input_prompt = self.history_to_prompt() + + if self.provider == 'OpenAI': + response = self.client.chat.completions.create( + model=self.model, + messages=input_prompt, + max_tokens = self.max_tokens + ) + output = response.choices[0].message.content.strip() + + elif self.provider == 'huggingface': + responses = self.pipeline(input_prompt, max_new_tokens=self.max_tokens) + output = responses[0]['generated_text'] + + self.history.append({"role": "assistant", "content": output}) + print(">> ", output) + + def interactive_session(self, prompt_input): + print("\nInteractive session started. Type 'end chat' to exit from the session!\n") + + if self.continue_chat: + print('>> ', self.history[-1]['content']) + else: + initial_prompt = prompt_input + self.infer(initial_prompt) + + while True: + prompt = input("\n>> ") + if prompt.lower() == 'end chat': + break + self.infer(prompt) + self.save_state() + + def history_to_prompt(self): + if self.provider == 'OpenAI': + return self.history + elif self.provider == 'huggingface': + return " ".join([item["content"] for item in self.history]) if __name__ == "__main__": - client = KrsGPTClient(reinitialize=False) # Initialize the client + client = KrsGPTClient(reinitialize=False) # client.interactive_session("You are an 8th grade math tutor. Ask questions to gauge my expertise so that you can generate a training plan for me.")