diff --git a/agixt/SeedImports.py b/agixt/SeedImports.py
index ed287e042c3c..bef75c980458 100644
--- a/agixt/SeedImports.py
+++ b/agixt/SeedImports.py
@@ -1,13 +1,10 @@
 import os
 import json
-import yaml
 import logging
 from DB import (
     get_session,
     Provider,
     ProviderSetting,
-    Conversation,
-    Message,
     Prompt,
     PromptCategory,
     Argument,
@@ -26,13 +23,34 @@
 )
 
 
+def ensure_default_user():
+    """Ensure default admin user exists"""
+    session = get_session()
+    user = session.query(User).filter_by(email=DEFAULT_USER).first()
+    if not user:
+        logging.info("Creating default admin user...")
+        user = User(email=DEFAULT_USER, admin=True)
+        session.add(user)
+        session.commit()
+        logging.info("Default user created.")
+    session.close()
+    return user
+
+
 def import_agents(user=DEFAULT_USER):
     agents = [
         f.name
         for f in os.scandir("agents")
         if f.is_dir() and not f.name.startswith("__")
     ]
+    session = get_session()
     for agent_name in agents:
+        # Check if agent already exists
+        agent_exists = session.query(Agent).filter_by(name=agent_name).first()
+        if agent_exists:
+            logging.info(f"Agent {agent_name} already exists, skipping...")
+            continue
+
         config_path = f"agents/{agent_name}/config.json"
         with open(config_path) as f:
             config = json.load(f)
@@ -43,6 +61,7 @@ def import_agents(user=DEFAULT_USER):
             user=user,
         )
         logging.info(f"Imported agent: {agent_name}")
+    session.close()
 
 
 def import_extensions():
@@ -55,82 +74,69 @@ def import_extensions():
         del extensions_data["AGiXT Chains"]
     extension_settings_data = Extensions().get_extension_settings()
     session = get_session()
-    # Get the existing extensions and commands from the database
-    existing_extensions = session.query(Extension).all()
-    existing_commands = session.query(Command).all()
-    # Add new extensions and commands, and update existing commands
+
+    # Process each extension
     for extension_data in extensions_data:
         extension_name = extension_data["extension_name"]
-        description = extension_data.get(
-            "description", ""
-        )  # Assign an empty string if description is missing
-
-        # Find the existing extension or create a new one
-        extension = next(
-            (ext for ext in existing_extensions if ext.name == extension_name),
-            None,
-        )
-        if extension is None:
+        description = extension_data.get("description", "")
+
+        # Find or create extension
+        extension = session.query(Extension).filter_by(name=extension_name).first()
+        if not extension:
             extension = Extension(name=extension_name, description=description)
             session.add(extension)
             session.flush()
-            existing_extensions.append(extension)
+            logging.info(f"Imported extension: {extension_name}")
+
+        # Process commands
         commands = extension_data["commands"]
         for command_data in commands:
             if "friendly_name" not in command_data:
                 continue
+
             command_name = command_data["friendly_name"]
-            # Find the existing command or create a new one
-            command = next(
-                (
-                    cmd
-                    for cmd in existing_commands
-                    if cmd.extension_id == extension.id and cmd.name == command_name
-                ),
-                None,
+            # Find or create command
+            command = (
+                session.query(Command)
+                .filter_by(extension_id=extension.id, name=command_name)
+                .first()
             )
-            if command is None:
+
+            if not command:
                 command = Command(
                     extension_id=extension.id,
                     name=command_name,
                 )
                 session.add(command)
                 session.flush()
-                existing_commands.append(command)
                 logging.info(f"Imported command: {command_name}")
-            # Add command arguments
+
+            # Process command arguments
            if "command_args" in command_data:
-                command_args = command_data["command_args"]
-                for arg, arg_type in command_args.items():
-                    if (
+                for arg, arg_type in command_data["command_args"].items():
+                    existing_arg = (
                         session.query(Argument)
                         .filter_by(command_id=command.id, name=arg)
                         .first()
-                    ):
-                        continue
-                    command_arg = Argument(
-                        command_id=command.id,
-                        name=arg,
                     )
-                    session.add(command_arg)
-                    logging.info(f"Imported argument: {arg} to command: {command_name}")
-    session.commit()
-    # Add extensions to the database if they don't exist
-    for extension_name in extension_settings_data.keys():
+                    if not existing_arg:
+                        command_arg = Argument(
+                            command_id=command.id,
+                            name=arg,
+                        )
+                        session.add(command_arg)
+                        logging.info(
+                            f"Imported argument: {arg} to command: {command_name}"
+                        )
+
+    # Process extension settings
+    for extension_name, settings in extension_settings_data.items():
         extension = session.query(Extension).filter_by(name=extension_name).first()
         if not extension:
             extension = Extension(name=extension_name)
             session.add(extension)
             session.flush()
-            existing_extensions.append(extension)
             logging.info(f"Imported extension: {extension_name}")
-    session.commit()
-    # Migrate extension settings
-    for extension_name, settings in extension_settings_data.items():
-        extension = session.query(Extension).filter_by(name=extension_name).first()
-        if not extension:
-            logging.info(f"Extension '{extension_name}' not found.")
-            continue
 
         for setting_name, setting_value in settings.items():
             setting = (
@@ -138,12 +144,7 @@ def import_extensions():
                 .filter_by(extension_id=extension.id, name=setting_name)
                 .first()
             )
-            if setting:
-                setting.value = setting_value
-                logging.info(
-                    f"Updating setting: {setting_name} for extension: {extension_name}"
-                )
-            else:
+            if not setting:
                 setting = Setting(
                     extension_id=extension.id,
                     name=setting_name,
@@ -153,6 +154,12 @@ def import_extensions():
                 logging.info(
                     f"Imported setting: {setting_name} for extension: {extension_name}"
                 )
+            else:
+                setting.value = setting_value
+                logging.info(
+                    f"Updated setting: {setting_name} for extension: {extension_name}"
+                )
+
     session.commit()
     session.close()
 
@@ -167,12 +174,22 @@ def import_chains(user=DEFAULT_USER):
     if not chain_files:
         logging.info(f"No JSON files found in chains directory.")
         return
+
     from Chain import Chain
 
     chain_importer = Chain(user=user)
+    session = get_session()
+    failures = []
     for file in chain_files:
         chain_name = os.path.splitext(file)[0]
+
+        # Check if chain already exists
+        existing_chain = session.query(Chain).filter_by(name=chain_name).first()
+        if existing_chain:
+            logging.info(f"Chain {chain_name} already exists, skipping...")
+            continue
+
         file_path = os.path.join(chain_dir, file)
         with open(file_path, "r") as f:
             try:
@@ -183,9 +200,14 @@ def import_chains(user=DEFAULT_USER):
             except Exception as e:
                 logging.info(f"(1/3) Error importing chain from '{file}': {str(e)}")
                 failures.append(file)
-    if failures:
-        # Try each that failed again just in case it had a dependency on another chain
-        for file in failures:
+
+    # Retry failed imports twice more
+    for retry in range(2):
+        if not failures:
+            break
+        retry_failures = failures.copy()
+        failures = []
+        for file in retry_failures:
             chain_name = os.path.splitext(file)[0]
             file_path = os.path.join(chain_dir, file)
             with open(file_path, "r") as f:
@@ -193,39 +215,29 @@ def import_chains(user=DEFAULT_USER):
                     chain_data = json.load(f)
                     result = chain_importer.import_chain(chain_name, chain_data)
                     logging.info(result)
-                    failures.remove(file)
                 except Exception as e:
-                    logging.info(f"(2/3) Error importing chain from '{file}': {str(e)}")
-    if failures:
-        # Try one more time.
-        for file in failures:
-            chain_name = os.path.splitext(file)[0]
-            file_path = os.path.join(chain_dir, file)
-            with open(file_path, "r") as f:
-                try:
-                    chain_data = json.load(f)
-                    result = chain_importer.import_chain(chain_name, chain_data)
-                    logging.info(result)
-                    failures.remove(file)
-                except Exception as e:
-                    logging.info(
-                        f"(3/3) Error importing chain from '{file}': {str(e)}"
-                    )
+                    logging.info(
+                        f"({retry + 2}/3) Error importing chain from '{file}': {str(e)}"
+                    )
+                    failures.append(file)
+
     if failures:
         logging.info(
             f"Failed to import the following chains: {', '.join([os.path.splitext(file)[0] for file in failures])}"
         )
+    session.close()
+
 
 def import_prompts(user=DEFAULT_USER):
     session = get_session()
-    # Add default category if it doesn't exist
     user_data = session.query(User).filter(User.email == user).first()
     user_id = user_data.id
+
+    # Ensure default category exists
     default_category = (
         session.query(PromptCategory).filter_by(name="Default", user_id=user_id).first()
     )
-
     if not default_category:
         default_category = PromptCategory(
             name="Default", description="Default category", user_id=user_id
@@ -234,12 +246,10 @@ def import_prompts(user=DEFAULT_USER):
         session.commit()
         logging.info("Imported Default prompt category")
 
-    # Get all prompt files in the specified folder
     for root, dirs, files in os.walk("prompts"):
         for file in files:
             prompt_category = None
            if root != "prompts":
-                # Use subfolder name as the prompt category
                 category_name = os.path.basename(root)
                 prompt_category = (
                     session.query(PromptCategory)
@@ -255,143 +265,91 @@ def import_prompts(user=DEFAULT_USER):
                     session.add(prompt_category)
                     session.commit()
             else:
-                # Assign to "Uncategorized" category if prompt is in the root folder
                 prompt_category = default_category
 
-            # Read the prompt content from the file
-            with open(os.path.join(root, file), "r") as f:
-                prompt_content = f.read()
-
-            # Check if prompt with the same name and category already exists
             prompt_name = os.path.splitext(file)[0]
-            prompt = (
+
+            # Check if prompt already exists
+            existing_prompt = (
                 session.query(Prompt)
                 .filter_by(
-                    name=prompt_name, prompt_category=prompt_category, user_id=user_id
+                    name=prompt_name,
+                    prompt_category_id=prompt_category.id,
+                    user_id=user_id,
                 )
                 .first()
             )
-            prompt_args = []
-            for word in prompt_content.split():
-                if word.startswith("{") and word.endswith("}"):
-                    prompt_args.append(word[1:-1])
-            if not prompt:
-                # Create the prompt entry in the database
-                prompt = Prompt(
-                    name=prompt_name,
-                    description="",
-                    content=prompt_content,
-                    prompt_category=prompt_category,
-                    user_id=user_id,
+
+            if existing_prompt:
+                logging.info(
+                    f"Prompt {prompt_name} already exists in category {prompt_category.name}, skipping..."
                 )
-                session.add(prompt)
-                session.commit()
-                logging.info(f"Imported prompt: {prompt_name}")
+                continue
+
+            with open(os.path.join(root, file), "r") as f:
+                prompt_content = f.read()
+
+            # Create new prompt
+            prompt = Prompt(
+                name=prompt_name,
+                description="",
+                content=prompt_content,
+                prompt_category=prompt_category,
+                user_id=user_id,
+            )
+            session.add(prompt)
+            session.commit()
+            logging.info(f"Imported prompt: {prompt_name}")
+
+            # Add prompt arguments
+            prompt_args = [
+                word[1:-1]
+                for word in prompt_content.split()
+                if word.startswith("{") and word.endswith("}")
+            ]
 
-            # Populate prompt arguments
             for arg in prompt_args:
                 if (
-                    session.query(Argument)
+                    not session.query(Argument)
                    .filter_by(prompt_id=prompt.id, name=arg)
                     .first()
                 ):
-                    continue
-                argument = Argument(
-                    prompt_id=prompt.id,
-                    name=arg,
-                )
-                session.add(argument)
-                session.commit()
-                logging.info(f"Imported prompt argument: {arg} for {prompt_name}")
-    session.close()
-
-
-def get_conversations():
-    conversation_dir = os.path.join("conversations")
-    if os.path.exists(conversation_dir):
-        conversations = os.listdir(conversation_dir)
-        return [conversation.split(".")[0] for conversation in conversations]
-    return []
-
-
-def get_conversation(conversation_name):
-    history = {"interactions": []}
-    try:
-        history_file = os.path.join("conversations", f"{conversation_name}.yaml")
-        if os.path.exists(history_file):
-            with open(history_file, "r") as file:
-                history = yaml.safe_load(file)
-    except:
-        history = {"interactions": []}
-    return history
+                    argument = Argument(
+                        prompt_id=prompt.id,
+                        name=arg,
+                    )
+                    session.add(argument)
+                    logging.info(f"Imported prompt argument: {arg} for {prompt_name}")
+            session.commit()
 
-def import_conversations(user=DEFAULT_USER):
-    session = get_session()
-    user_data = session.query(User).filter(User.email == user).first()
-    user_id = user_data.id
-    conversations = get_conversations()
-    for conversation_name in conversations:
-        conversation = get_conversation(conversation_name=conversation_name)
-        if not conversation:
-            logging.info(f"Conversation '{conversation_name}' is empty, skipping.")
-            continue
-        if "interactions" in conversation:
-            for interaction in conversation["interactions"]:
-                agent_name = interaction["role"]
-                message = interaction["message"]
-                timestamp = interaction["timestamp"]
-                conversation = (
-                    session.query(Conversation)
-                    .filter(
-                        Conversation.name == conversation_name,
-                        Conversation.user_id == user_id,
-                    )
-                    .first()
-                )
-                if not conversation:
-                    # Create the conversation
-                    conversation = Conversation(name=conversation_name, user_id=user_id)
-                    session.add(conversation)
-                    session.commit()
-                message = Message(
-                    role=agent_name,
-                    content=message,
-                    timestamp=timestamp,
-                    conversation_id=conversation.id,
-                )
-                session.add(message)
-                session.commit()
-        logging.info(f"Imported conversation: {conversation_name}")
     session.close()
 
 
 def import_providers():
     session = get_session()
     providers = get_providers()
+
     for provider_name in providers:
         provider_options = get_provider_options(provider_name)
-        provider = session.query(Provider).filter_by(name=provider_name).one_or_none()
-        if provider:
-            logging.info(f"Updating provider: {provider_name}")
-        else:
+
+        # Find or create provider
+        provider = session.query(Provider).filter_by(name=provider_name).first()
+        if not provider:
             provider = Provider(name=provider_name)
             session.add(provider)
-            logging.info(f"Imported provider: {provider_name}")
             session.commit()
+            logging.info(f"Imported provider: {provider_name}")
 
+        # Update provider settings
         for option_name, option_value in provider_options.items():
            provider_setting = (
                 session.query(ProviderSetting)
                 .filter_by(provider_id=provider.id, name=option_name)
-                .one_or_none()
+                .first()
             )
-            if provider_setting:
-                provider_setting.value = str(option_value)
-                logging.info(
-                    f"Updating provider setting: {option_name} for provider: {provider_name}"
-                )
-            else:
+
+            if not provider_setting:
                 provider_setting = ProviderSetting(
                     provider_id=provider.id,
                     name=option_name,
@@ -401,29 +359,29 @@ def import_providers():
                 logging.info(
                     f"Imported provider setting: {option_name} for provider: {provider_name}"
                 )
+            else:
+                provider_setting.value = str(option_value)
+                logging.info(
+                    f"Updated provider setting: {option_name} for provider: {provider_name}"
+                )
+
     session.commit()
     session.close()
 
 
 def import_all_data():
-    session = get_session()
-    user_count = session.query(User).count()
-    if user_count == 0:
-        # Create the default user
-        logging.info("Creating default admin user...")
-        user = User(email=DEFAULT_USER, admin=True)
-        session.add(user)
-        session.commit()
-        logging.info("Default user created.")
-    logging.info("Importing providers...")
-    import_providers()
-    logging.info("Importing extensions...")
-    import_extensions()
-    logging.info("Importing prompts...")
-    import_prompts()
-    logging.info("Importing agents...")
-    import_agents()
-    logging.info("Importing chains...")
-    import_chains()
-    logging.info("Imports complete.")
-    session.close()
+    # Ensure default user exists
+    ensure_default_user()
+
+    # Import all data types
+    logging.info("Importing providers...")
+    import_providers()
+    logging.info("Importing extensions...")
+    import_extensions()
+    logging.info("Importing prompts...")
+    import_prompts()
+    logging.info("Importing agents...")
+    import_agents()
+    logging.info("Importing chains...")
+    import_chains()
+    logging.info("Imports complete.")
diff --git a/agixt/readers/file.py b/agixt/readers/file.py
index e0f6833ef07c..027b8ed94ce4 100644
--- a/agixt/readers/file.py
+++ b/agixt/readers/file.py
@@ -7,7 +7,8 @@
 import shutil
 import logging
 from datetime import datetime
-import nbformat  # Import nbformat for reading .ipynb files
+import nbformat
+
 
 class FileReader(Memories):
     def __init__(
@@ -83,14 +84,14 @@ async def write_file_to_memory(self, file_path: str):
                 command_name="Transcribe Audio from File",
                 command_args={"filename": file_path},
             )
-        # If file extension is ipynb, extract code and markdown cells
+        # If file extension is ipynb, extract code and markdown cells
         elif file_path.endswith(".ipynb"):
-            with open(file_path, 'r', encoding='utf-8') as f:
+            with open(file_path, "r", encoding="utf-8") as f:
                 nb = nbformat.read(f, as_version=4)
             for cell in nb.cells:
-                if cell.cell_type == 'markdown':
+                if cell.cell_type == "markdown":
                     content += cell.source + "\n"
-                elif cell.cell_type == 'code':
+                elif cell.cell_type == "code":
                     content += cell.source + "\n"
         # Otherwise just read the file
         else:
@@ -111,4 +112,4 @@ async def write_file_to_memory(self, file_path: str):
             return True
         except Exception as e:
             logging.error(f"Error reading file {file_path}: {e}")
-            return False
\ No newline at end of file
+            return False