Skip to content

Commit

Permalink
Use aiofiles for file I/O in commands if possible
Browse files Browse the repository at this point in the history
aiohttp's ClientSession is now at the top level main.py which uses discord.Bot's running loop attribute
(still causes ssl.c errors, time to switch prod distros?)

Shutdown command: also close aiohttp Clientsession but its still buggy: aio-libs/aiohttp#1925

Other fixes
  • Loading branch information
zavocc committed Sep 25, 2024
1 parent 3f4eb03 commit 8e92bf4
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 74 deletions.
54 changes: 10 additions & 44 deletions cogs/admin.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import aiofiles
import aiofiles.os
import discord
import random
import subprocess
from discord.ext import commands
from os import environ, mkdir, remove
from os import environ

class Admin(commands.Cog):
def __init__(self, bot):
Expand All @@ -17,6 +19,10 @@ async def admin_shutdown(self, ctx):
return

await ctx.send("Shutting down...")
# Shutdown aiohttp client and the bot
if hasattr(self.bot, "_aiohttp_session"):
await self.bot._aiohttp_session.close()

await self.bot.close()

# Execute command
Expand Down Expand Up @@ -48,59 +54,19 @@ async def admin_execute(self, ctx, *shell_command):
if output.stdout:
# If the output exceeds 2000 characters, send it as a file
if len(output.stdout.decode('utf-8')) > 2000:
with open(_xfilepath, "w+") as f:
f.write(output.stdout.decode('utf-8'))
async with aiofiles.open(_xfilepath, "w+") as f:
await f.write(output.stdout.decode('utf-8'))

await ctx.respond(f"I executed `{pretty_shell_command}` and got:", file=discord.File(_xfilepath, "output.txt"))

# Delete the file
remove(_xfilepath)
await aiofiles.os.remove(_xfilepath)
else:
await ctx.respond(f"I executed `{pretty_shell_command}` and got:")
await ctx.send(f"```{output.stdout.decode('utf-8')}```")
else:
await ctx.respond(f"I executed `{pretty_shell_command}` and got no output")


# TODO: To write a better implementation of "eval" command, this code is very buggy
# Evaluate command
#@commands.command()
#async def admin_evaluate(self, ctx, *python_expression):
# """Evaluates a python code (owner only)"""
# if ctx.author.id != 1039885147761283113:
# await ctx.respond("Only my master can do that >:(")
# return
#
# # Check for arguments
# if not python_expression or len(python_expression) == 0:
# await ctx.respond("You need to provide a inline python expression to evaluate")
# return
#
# # Tuple to string as inline f-string doesn't work with these expressions
# pretty_py_exec = " ".join(python_expression)
#
# try:
# output = eval(f"{pretty_py_exec}")
# except Exception as e:
# await ctx.respond(f"I executed `{pretty_py_exec}` and got an error:\n{e}")
# return
#
# # Print the output
# if output is not None:
# # Send the output to file if it exceeds 2000 characters
# if len(str(output)) > 2000:
# # Check if temp folder exists
# if not Path("temp").exists(): mkdir("temp")
#
# with open("temp/py_output.txt", "w+") as f:
# f.write(output)
#
# await ctx.respond(f"I executed `{pretty_py_exec}` and got:", file=File(_xfilepath, "output.txt"))
#
# # Delete the file
# remove("temp/py_output.txt")
# else:
# await ctx.respond(f"I executed `{pretty_py_exec}` and got:\n```{str(output)}```")

def setup(bot):
bot.add_cog(Admin(bot))
21 changes: 7 additions & 14 deletions cogs/gemini/generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import google.api_core.exceptions
import aiohttp
import aiofiles
import aiofiles.os
import asyncio
import discord
import importlib
Expand All @@ -23,9 +24,6 @@ def __init__(self, bot):
self.bot: discord.Bot = bot
self.author = environ.get("BOT_NAME", "Jakey Bot")

self.bot.loop.create_task(self._initialize())

async def _initialize(self):
# Load the database and initialize the HistoryManagement class
# MongoDB database connection for chat history and possibly for other things
try:
Expand All @@ -45,11 +43,7 @@ async def _initialize(self):
self._assistants_system_prompt = Assistants()

# Media download shared session
self._download_session = aiohttp.ClientSession()

def cog_unload(self):
# Close media download session
self.bot.loop.create_task(self._download_session.close())
self._download_session: aiohttp.ClientSession = self.bot._aiohttp_session

###############################################
# Ask command
Expand Down Expand Up @@ -150,7 +144,7 @@ async def ask(self, ctx, prompt: str, attachment: discord.Attachment, model: str
except aiohttp.ClientError as httperror:
# Remove the file if it exists ensuring no data persists even on failure
if Path(_xfilename).exists():
remove(_xfilename)
await aiofiles.os.remove(_xfilename)
# Raise exception
raise httperror

Expand All @@ -174,7 +168,7 @@ async def ask(self, ctx, prompt: str, attachment: discord.Attachment, model: str
await ctx.respond(f"❌ An error has occured when uploading the file or the file format is not supported\nLog:\n```{e}```")
return
finally:
remove(_xfilename)
await aiofiles.os.remove(_xfilename)

# Immediately use the "used" status message to indicate that the file API is used
if verbose_logs:
Expand All @@ -185,7 +179,6 @@ async def ask(self, ctx, prompt: str, attachment: discord.Attachment, model: str

# Add caution that the attachment data would be lost in 48 hours
await ctx.send("> 📝 **Note:** The submitted file attachment will be deleted from the context after 48 hours.")
await _x_msgstatus.delete()

if not attachment and hasattr(_Tool, "file_uri"):
_Tool.tool_config = "NONE"
Expand Down Expand Up @@ -262,8 +255,8 @@ async def ask(self, ctx, prompt: str, attachment: discord.Attachment, model: str
if len(answer.text) > 4096:
# Send the response as file
response_file = f"{environ.get('TEMP_DIR')}/response{random.randint(6000,7000)}.md"
with open(response_file, "w+") as f:
f.write(answer.text)
async with aiofiles.open(response_file, "w+") as f:
await f.write(answer.text)
await ctx.respond("⚠️ Response is too long. But, I saved your response into a markdown file", file=discord.File(response_file, "response.md"))
elif len(answer.text) > 2000:
embed = discord.Embed(
Expand All @@ -281,7 +274,7 @@ async def ask(self, ctx, prompt: str, attachment: discord.Attachment, model: str
# Increment the prompt count
_prompt_count += 1
# Also save the ChatSession.history attribute to the context history chat history key so it will be saved through pickle
_chat_thread = asyncio.to_thread(jsonpickle.encode, chat_session.history, indent=4, keys=True)
_chat_thread = await asyncio.to_thread(jsonpickle.encode, chat_session.history, indent=4, keys=True)

# Print context size and model info
if append_history:
Expand Down
11 changes: 1 addition & 10 deletions cogs/gemini/message_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,6 @@ def __init__(self, bot):
self.bot: discord.Bot = bot
self.author = environ.get("BOT_NAME", "Jakey Bot")

# Run _initialize function using the discord.Bot.loop.create_task which seems to be the proper way to run async functions
# In constructor
self.bot.loop.create_task(self._initialize())

async def _initialize(self):
# Check for gemini API keys
if environ.get("GOOGLE_AI_TOKEN") is None or environ.get("GOOGLE_AI_TOKEN") == "INSERT_API_KEY":
raise Exception("GOOGLE_AI_TOKEN is not configured in the dev.env file. Please configure it and try again.")
Expand All @@ -33,11 +28,7 @@ async def _initialize(self):
self._system_prompt = Assistants()

# Media download shared session
self._media_download_session = aiohttp.ClientSession()

def cog_unload(self):
# Close media download session
self.bot.loop.create_task(self._media_download_session.close())
self._media_download_session: aiohttp.ClientSession = self.bot._aiohttp_session

async def _media_download(self, url, save_path):
# Check if the file size is too large (max 3MB)
Expand Down
7 changes: 4 additions & 3 deletions cogs/gemini/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from discord.ext import commands
from os import environ
import google.generativeai as genai
import aiofiles
import datetime
import discord
import inspect
Expand Down Expand Up @@ -148,9 +149,9 @@ async def summarize(self, ctx, before_date: str, after_date: str, around_date: s
if len(_summary.text) > 4096:
# Send the response as file
response_file = f"{environ.get('TEMP_DIR')}/response{random.randint(8000,9000)}.md"
with open(response_file, "a+") as f:
f.write(_app_title + "\n----------\n")
f.write(_summary.text)
async with aiofiles.open(response_file, "a+") as f:
await f.write(_app_title + "\n----------\n")
await f.write(_summary.text)
await ctx.respond(f"Here is the summary generated for this channel\n>✨ Model used: {model}", file=discord.File(response_file, "response.md"))
else:
_embed = discord.Embed(
Expand Down
3 changes: 3 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from inspect import cleandoc
from os import chdir, environ, mkdir
from pathlib import Path
import aiohttp
import discord
import importlib
import logging
Expand Down Expand Up @@ -40,6 +41,8 @@

# Bot
bot = bridge.Bot(command_prefix=commands.when_mentioned_or("$"), intents = intents)
# aiohttp session
bot._aiohttp_session = aiohttp.ClientSession(loop=bot.loop)

###############################################
# ON READY
Expand Down
3 changes: 2 additions & 1 deletion tools/audio_editor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Huggingface spaces endpoints
import google.generativeai as genai
import aiofiles.os
import asyncio
import discord
import importlib
Expand Down Expand Up @@ -74,5 +75,5 @@ async def _tool_function(self, prompt: str, edit_start_in_seconds: int = 3, edit
await self.ctx.send(file=discord.File(fp=result))

# Cleanup
os.remove(result)
await aiofiles.os.remove(result)
return "Audio editing success"
3 changes: 2 additions & 1 deletion tools/image_generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Huggingface spaces endpoints
import google.generativeai as genai
import aiofiles.os
import asyncio
import discord
import importlib
Expand Down Expand Up @@ -70,5 +71,5 @@ async def _tool_function(self, image_description: str, width: int, height: int):
await self.ctx.send(file=discord.File(fp=result[0]))

# Cleanup
os.remove(result[0])
await aiofiles.os.remove(result[0])
return "Image generation success and the file should be sent automatically"
3 changes: 2 additions & 1 deletion tools/web_browsing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from core.ai.embeddings import GeminiDocumentRetrieval
from google_labs_html_chunker.html_chunker import HtmlChunker
import google.generativeai as genai
import aiofiles
import asyncio
import discord
import importlib
Expand Down Expand Up @@ -54,7 +55,7 @@ async def _tool_function(self, query: str, max_results: int):
links = []

# Load excluded urls list
with open("data/excluded_urls.yaml") as x:
async with aiofiles.open("data/excluded_urls.yaml") as x:
excluded_url_list = yaml.safe_load(x)

# Iterate
Expand Down

0 comments on commit 8e92bf4

Please sign in to comment.