diff --git a/cogs/admin.py b/cogs/admin.py index f9b88c4e..19c566de 100644 --- a/cogs/admin.py +++ b/cogs/admin.py @@ -1,8 +1,10 @@ +import aiofiles +import aiofiles.os import discord import random import subprocess from discord.ext import commands -from os import environ, mkdir, remove +from os import environ class Admin(commands.Cog): def __init__(self, bot): @@ -17,6 +19,10 @@ async def admin_shutdown(self, ctx): return await ctx.send("Shutting down...") + # Shutdown aiohttp client and the bot + if hasattr(self.bot, "_aiohttp_session"): + await self.bot._aiohttp_session.close() + await self.bot.close() # Execute command @@ -48,13 +54,13 @@ async def admin_execute(self, ctx, *shell_command): if output.stdout: # If the output exceeds 2000 characters, send it as a file if len(output.stdout.decode('utf-8')) > 2000: - with open(_xfilepath, "w+") as f: - f.write(output.stdout.decode('utf-8')) + async with aiofiles.open(_xfilepath, "w+") as f: + await f.write(output.stdout.decode('utf-8')) await ctx.respond(f"I executed `{pretty_shell_command}` and got:", file=discord.File(_xfilepath, "output.txt")) # Delete the file - remove(_xfilepath) + await aiofiles.os.remove(_xfilepath) else: await ctx.respond(f"I executed `{pretty_shell_command}` and got:") await ctx.send(f"```{output.stdout.decode('utf-8')}```") @@ -62,45 +68,5 @@ async def admin_execute(self, ctx, *shell_command): await ctx.respond(f"I executed `{pretty_shell_command}` and got no output") - # TODO: To write a better implementation of "eval" command, this code is very buggy - # Evaluate command - #@commands.command() - #async def admin_evaluate(self, ctx, *python_expression): - # """Evaluates a python code (owner only)""" - # if ctx.author.id != 1039885147761283113: - # await ctx.respond("Only my master can do that >:(") - # return - # - # # Check for arguments - # if not python_expression or len(python_expression) == 0: - # await ctx.respond("You need to provide a inline python expression to evaluate") - # return - # - # # Tuple to string as inline f-string doesn't work with these expressions - # pretty_py_exec = " ".join(python_expression) - # - # try: - # output = eval(f"{pretty_py_exec}") - # except Exception as e: - # await ctx.respond(f"I executed `{pretty_py_exec}` and got an error:\n{e}") - # return - # - # # Print the output - # if output is not None: - # # Send the output to file if it exceeds 2000 characters - # if len(str(output)) > 2000: - # # Check if temp folder exists - # if not Path("temp").exists(): mkdir("temp") - # - # with open("temp/py_output.txt", "w+") as f: - # f.write(output) - # - # await ctx.respond(f"I executed `{pretty_py_exec}` and got:", file=File(_xfilepath, "output.txt")) - # - # # Delete the file - # remove("temp/py_output.txt") - # else: - # await ctx.respond(f"I executed `{pretty_py_exec}` and got:\n```{str(output)}```") - def setup(bot): bot.add_cog(Admin(bot)) diff --git a/cogs/gemini/generative.py b/cogs/gemini/generative.py index 2bc1bcc3..fd160cd9 100644 --- a/cogs/gemini/generative.py +++ b/cogs/gemini/generative.py @@ -9,6 +9,7 @@ import google.api_core.exceptions import aiohttp import aiofiles +import aiofiles.os import asyncio import discord import importlib @@ -23,9 +24,6 @@ def __init__(self, bot): self.bot: discord.Bot = bot self.author = environ.get("BOT_NAME", "Jakey Bot") - self.bot.loop.create_task(self._initialize()) - - async def _initialize(self): # Load the database and initialize the HistoryManagement class # MongoDB database connection for chat history and possibly for other things try: @@ -45,11 +43,7 @@ async def _initialize(self): self._assistants_system_prompt = Assistants() # Media download shared session - self._download_session = aiohttp.ClientSession() - - def cog_unload(self): - # Close media download session - self.bot.loop.create_task(self._download_session.close()) + self._download_session: aiohttp.ClientSession = self.bot._aiohttp_session ############################################### # Ask command @@ -150,7 +144,7 @@ async def ask(self, ctx, prompt: str, attachment: discord.Attachment, model: str except aiohttp.ClientError as httperror: # Remove the file if it exists ensuring no data persists even on failure if Path(_xfilename).exists(): - remove(_xfilename) + await aiofiles.os.remove(_xfilename) # Raise exception raise httperror @@ -174,7 +168,7 @@ async def ask(self, ctx, prompt: str, attachment: discord.Attachment, model: str await ctx.respond(f"❌ An error has occured when uploading the file or the file format is not supported\nLog:\n```{e}```") return finally: - remove(_xfilename) + await aiofiles.os.remove(_xfilename) # Immediately use the "used" status message to indicate that the file API is used if verbose_logs: @@ -185,7 +179,6 @@ async def ask(self, ctx, prompt: str, attachment: discord.Attachment, model: str # Add caution that the attachment data would be lost in 48 hours await ctx.send("> 📝 **Note:** The submitted file attachment will be deleted from the context after 48 hours.") - await _x_msgstatus.delete() if not attachment and hasattr(_Tool, "file_uri"): _Tool.tool_config = "NONE" @@ -262,8 +255,8 @@ async def ask(self, ctx, prompt: str, attachment: discord.Attachment, model: str if len(answer.text) > 4096: # Send the response as file response_file = f"{environ.get('TEMP_DIR')}/response{random.randint(6000,7000)}.md" - with open(response_file, "w+") as f: - f.write(answer.text) + async with aiofiles.open(response_file, "w+") as f: + await f.write(answer.text) await ctx.respond("⚠️ Response is too long. But, I saved your response into a markdown file", file=discord.File(response_file, "response.md")) elif len(answer.text) > 2000: embed = discord.Embed( @@ -281,7 +274,7 @@ async def ask(self, ctx, prompt: str, attachment: discord.Attachment, model: str # Increment the prompt count _prompt_count += 1 # Also save the ChatSession.history attribute to the context history chat history key so it will be saved through pickle - _chat_thread = asyncio.to_thread(jsonpickle.encode, chat_session.history, indent=4, keys=True) + _chat_thread = await asyncio.to_thread(jsonpickle.encode, chat_session.history, indent=4, keys=True) # Print context size and model info if append_history: diff --git a/cogs/gemini/message_actions.py b/cogs/gemini/message_actions.py index d9c7c620..fdf3d3bd 100644 --- a/cogs/gemini/message_actions.py +++ b/cogs/gemini/message_actions.py @@ -15,11 +15,6 @@ def __init__(self, bot): self.bot: discord.Bot = bot self.author = environ.get("BOT_NAME", "Jakey Bot") - # Run _initialize function using the discord.Bot.loop.create_task which seems to be the proper way to run async functions - # In constructor - self.bot.loop.create_task(self._initialize()) - - async def _initialize(self): # Check for gemini API keys if environ.get("GOOGLE_AI_TOKEN") is None or environ.get("GOOGLE_AI_TOKEN") == "INSERT_API_KEY": raise Exception("GOOGLE_AI_TOKEN is not configured in the dev.env file. Please configure it and try again.") @@ -33,11 +28,7 @@ async def _initialize(self): self._system_prompt = Assistants() # Media download shared session - self._media_download_session = aiohttp.ClientSession() - - def cog_unload(self): - # Close media download session - self.bot.loop.create_task(self._media_download_session.close()) + self._media_download_session: aiohttp.ClientSession = self.bot._aiohttp_session async def _media_download(self, url, save_path): # Check if the file size is too large (max 3MB) diff --git a/cogs/gemini/summarize.py b/cogs/gemini/summarize.py index b65b5f91..20b06aca 100644 --- a/cogs/gemini/summarize.py +++ b/cogs/gemini/summarize.py @@ -3,6 +3,7 @@ from discord.ext import commands from os import environ import google.generativeai as genai +import aiofiles import datetime import discord import inspect @@ -148,9 +149,9 @@ async def summarize(self, ctx, before_date: str, after_date: str, around_date: s if len(_summary.text) > 4096: # Send the response as file response_file = f"{environ.get('TEMP_DIR')}/response{random.randint(8000,9000)}.md" - with open(response_file, "a+") as f: - f.write(_app_title + "\n----------\n") - f.write(_summary.text) + async with aiofiles.open(response_file, "a+") as f: + await f.write(_app_title + "\n----------\n") + await f.write(_summary.text) await ctx.respond(f"Here is the summary generated for this channel\n>✨ Model used: {model}", file=discord.File(response_file, "response.md")) else: _embed = discord.Embed( diff --git a/main.py b/main.py index 3138eb03..da99ffc8 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ from inspect import cleandoc from os import chdir, environ, mkdir from pathlib import Path +import aiohttp import discord import importlib import logging @@ -40,6 +41,8 @@ # Bot bot = bridge.Bot(command_prefix=commands.when_mentioned_or("$"), intents = intents) +# aiohttp session +bot._aiohttp_session = aiohttp.ClientSession(loop=bot.loop) ############################################### # ON READY diff --git a/tools/audio_editor.py b/tools/audio_editor.py index ce41292e..822c611a 100644 --- a/tools/audio_editor.py +++ b/tools/audio_editor.py @@ -1,5 +1,6 @@ # Huggingface spaces endpoints import google.generativeai as genai +import aiofiles.os import asyncio import discord import importlib @@ -74,5 +75,5 @@ async def _tool_function(self, prompt: str, edit_start_in_seconds: int = 3, edit await self.ctx.send(file=discord.File(fp=result)) # Cleanup - os.remove(result) + await aiofiles.os.remove(result) return "Audio editing success" diff --git a/tools/image_generator.py b/tools/image_generator.py index d672658c..d3dcd175 100644 --- a/tools/image_generator.py +++ b/tools/image_generator.py @@ -1,5 +1,6 @@ # Huggingface spaces endpoints import google.generativeai as genai +import aiofiles.os import asyncio import discord import importlib @@ -70,5 +71,5 @@ async def _tool_function(self, image_description: str, width: int, height: int): await self.ctx.send(file=discord.File(fp=result[0])) # Cleanup - os.remove(result[0]) + await aiofiles.os.remove(result[0]) return "Image generation success and the file should be sent automatically" diff --git a/tools/web_browsing.py b/tools/web_browsing.py index ec989525..56b96558 100644 --- a/tools/web_browsing.py +++ b/tools/web_browsing.py @@ -1,6 +1,7 @@ from core.ai.embeddings import GeminiDocumentRetrieval from google_labs_html_chunker.html_chunker import HtmlChunker import google.generativeai as genai +import aiofiles import asyncio import discord import importlib @@ -54,7 +55,7 @@ async def _tool_function(self, query: str, max_results: int): links = [] # Load excluded urls list - with open("data/excluded_urls.yaml") as x: + async with aiofiles.open("data/excluded_urls.yaml") as x: excluded_url_list = yaml.safe_load(x) # Iterate