diff --git a/.env.example b/.env.example index 45a531d..afee151 100644 --- a/.env.example +++ b/.env.example @@ -9,4 +9,6 @@ SUPABASE_KEY=your_key_here CATBOX_USERHASH=your_hash_here # Provide a secret to trusted minecraft servers to autheticate users -SYNERGY_SECRET=your_secret_here \ No newline at end of file +SYNERGY_SECRET=your_secret_here + +OPENAI_API_KEY=your_key_here \ No newline at end of file diff --git a/api.py b/api.py index d4fe3bc..52742f2 100644 --- a/api.py +++ b/api.py @@ -1,4 +1,5 @@ """Simple FastAPI server to generate verification codes for users.""" + import os import random from typing import Annotated @@ -31,10 +32,20 @@ async def get_verification_code(user: User, authorization: Annotated[str, Header db = DatabaseManager() # Invalidate existing codes for this user - await db.table("verification_codes").update({"valid": False}).eq("minecraft_uuid", str(user.uuid)).gt("expires", utcnow()).execute() + await ( + db.table("verification_codes") + .update({"valid": False}) + .eq("minecraft_uuid", str(user.uuid)) + .gt("expires", utcnow()) + .execute() + ) code = random.randint(100000, 999999) - await db.table("verification_codes").insert({"minecraft_uuid": str(user.uuid), "username": username, "code": code}).execute() + await ( + db.table("verification_codes") + .insert({"minecraft_uuid": str(user.uuid), "username": username, "code": code}) + .execute() + ) return code diff --git a/bot/submission/submit.py b/bot/submission/submit.py index ec9a6ea..49ea5a9 100644 --- a/bot/submission/submit.py +++ b/bot/submission/submit.py @@ -2,10 +2,11 @@ # from __future__ import annotations # dpy cannot resolve FlagsConverter with forward references :( from collections.abc import Sequence +from textwrap import dedent from typing import Literal, cast, TYPE_CHECKING, Any import discord -from discord import InteractionResponse, Guild +from discord import InteractionResponse, Guild, Message from discord.ext import commands from discord.ext.commands import ( Context, 
@@ -16,6 +17,7 @@ flag, ) from postgrest import APIResponse +from pydantic import ValidationError from bot import utils, config from bot.submission.ui import BuildSubmissionForm, ConfirmationView @@ -24,7 +26,7 @@ from database.database import DatabaseManager from database.enums import Status, Category from bot._types import SubmissionCommandResponse, GuildMessageable -from bot.utils import RunningMessage, parse_dimensions +from bot.utils import RunningMessage, parse_dimensions, parse_build_title, remove_markdown from database.message import get_build_id_by_message from database.schema import TypeRecord from database.server_settings import get_server_setting @@ -373,7 +375,9 @@ async def list_patterns(self, ctx: Context): async with RunningMessage(ctx) as sent_message: patterns: APIResponse[TypeRecord] = await DatabaseManager().table("types").select("*").execute() names = [pattern["name"] for pattern in patterns.data] - await sent_message.edit(content="Here are the available patterns:", embed=utils.info_embed("Patterns", ", ".join(names))) + await sent_message.edit( + content="Here are the available patterns:", embed=utils.info_embed("Patterns", ", ".join(names)) + ) @Cog.listener(name="on_raw_reaction_add") async def confirm_record(self, payload: discord.RawReactionActionEvent): @@ -424,6 +428,34 @@ async def confirm_record(self, payload: discord.RawReactionActionEvent): # TODO: Add a check when adding vote channels to the database raise ValueError(f"Invalid channel type for a vote channel: {type(vote_channel)}") + @Cog.listener(name="on_message") + async def infer_build_from_title(self, message: Message): + """Infer a build from a message.""" + if message.author.bot: + return + + if message.channel.id not in [726156829629087814, 667401499554611210, 536004554743873556]: + return + + title_str = remove_markdown(message.content).splitlines()[0] + try: + title = await parse_build_title(title_str, mode="ai" if len(title_str) <= 300 else "manual") + except 
ValidationError: + return + + build = Build() + build.record_category = title.record_category + build.category = "Door" + build.component_restrictions = title.component_restrictions + build.door_width = title.door_width + build.door_height = title.door_height + build.door_depth = title.door_depth + build.wiring_placement_restrictions = title.wiring_placement_restrictions + build.door_types = title.door_types + build.door_orientation_type = title.orientation + # print(title) + await self.bot.get_channel(536004554743873556).send(embed=build.generate_embed()) + def format_submission_input(ctx: Context, data: SubmissionCommandResponse) -> dict[str, Any]: """Formats the submission data from what is passed in commands to something recognizable by Build.""" diff --git a/bot/utils.py b/bot/utils.py index 3478a4d..541b689 100644 --- a/bot/utils.py +++ b/bot/utils.py @@ -1,15 +1,28 @@ +"""Utility functions for the bot.""" + +from __future__ import annotations + import re +from io import StringIO +from textwrap import dedent from traceback import format_tb from types import TracebackType -from typing import overload, Literal, Any +from typing import overload, Literal, TYPE_CHECKING import discord +from async_lru import alru_cache from discord import Message, Webhook from discord.abc import Messageable +from markdown import Markdown +from openai import AsyncOpenAI +from pydantic import BaseModel, Field from bot.config import OWNER_ID, PRINT_TRACEBACKS from database.database import DatabaseManager -from database.schema import RECORD_CATEGORIES, DOOR_ORIENTATION_NAMES +from database.schema import DoorOrientationName, RecordCategory, DOOR_ORIENTATION_NAMES + +if TYPE_CHECKING: + from xml.etree.ElementTree import Element discord_red = 0xF04747 discord_yellow = 0xFAA61A @@ -45,9 +58,7 @@ def parse_dimensions(dim_str: str) -> tuple[int, int, int | None]: ... 
@overload -def parse_dimensions( - dim_str: str, *, min_dim: int, max_dim: Literal[3] -) -> tuple[int, int | None, int | None]: ... +def parse_dimensions(dim_str: str, *, min_dim: int, max_dim: Literal[3]) -> tuple[int, int | None, int | None]: ... def parse_dimensions(dim_str: str, *, min_dim: int = 2, max_dim: int = 3) -> tuple[int | None, ...]: @@ -161,31 +172,244 @@ async def __aexit__( return False -async def parse_build_title(title: str) -> dict[str, Any]: - """Parses a title into a category and a name. +# See https://stackoverflow.com/questions/761824/python-how-to-convert-markdown-formatted-text-to-text +def _unmark_element(element: Element, stream=None): + if stream is None: + stream = StringIO() + if element.text: + stream.write(element.text) + for sub in element: + _unmark_element(sub, stream) + if element.tail: + stream.write(element.tail) + return stream.getvalue() + + +# patching Markdown +Markdown.output_formats["plain"] = _unmark_element # type: ignore +__md = Markdown(output_format="plain") # type: ignore +__md.stripTopLevelTags = False + + +def remove_markdown(text: str) -> str: + """Removes markdown formatting from a string.""" + return __md.convert(text) + + +async def parse_build_title(title: str, mode: Literal["ai", "manual"] = "manual") -> DoorTitle: + """Parses a title into its components. A build title should be in the format of: ``` - [Record Category] [component restrictions]+ [wiring placement restrictions]+ + [Record Category] [component restrictions]+ [wiring placement restrictions]+ + ``` Args: title: The title to parse + mode: The mode to parse the title in. Either "ai" or "manual". 
+ + Returns: + The parsed door title; in "manual" mode the unparsed remainder is discarded + """ + if "\n" in title: + raise ValueError("Title cannot contain newlines") + + if mode == "ai": + return await ai_parse_piston_door_title(title) + elif mode == "manual": + title, _ = await manual_parse_piston_door_title(title) + return title + + +class DoorTitle(BaseModel): + record_category: RecordCategory | None = Field(..., description="The record category of the door") + component_restrictions: list[str] = Field(..., description="The restrictions on the components of the door") + door_width: int | None = Field(..., description="the width of the door") + door_height: int | None = Field(..., description="the height of the door") + door_depth: int | None = Field(..., description="the depth of the door") + wiring_placement_restrictions: list[str] = Field( + ..., description="The restrictions on the wiring placement of the door" + ) + door_types: list[str] = Field(..., description="The patterns of the door") + orientation: DoorOrientationName = Field(..., description="The orientation of the door") + + +def replace_insensitive(string: str, old: str, new: str) -> str: + """Replaces a substring in a string case-insensitively. + + Args: + string: The string to search and replace in. + old: The substring to search for. + new: The substring to replace with. Returns: - A dictionary containing the parsed information. + The modified string. 
""" - data = {} + pattern = re.compile(re.escape(old), re.IGNORECASE) + return pattern.sub(new, string) + + +async def manual_parse_piston_door_title(title: str) -> tuple[DoorTitle, str]: + """Parses a piston door title into its components.""" + title = title.lower() + + # Define record categories + record_categories = ["smallest", "fastest", "first"] + + # Check for record category + record_category = None + for category in record_categories: + if title.startswith(category): + record_category = category.capitalize() + title = title[len(category) :].strip() + break + + # Extract door size + door_size_match = re.search(r"\d+x\d+(x\d+)?", title) + door_size = (None, None, None) + if door_size_match: + door_size_str = door_size_match.group() + door_size = tuple(map(int, door_size_str.split("x"))) + if len(door_size) == 2: + door_size = (*door_size, None) + title = replace_insensitive(title, door_size_str, "").strip() + + # Split the remaining title by known door types + door_types = [] + for door_type in await get_valid_door_types(): + if door_type.lower() in title.lower(): + door_types.append(door_type) + title = replace_insensitive(title, door_type, "").strip() + + # Split remaining by orientation + orientation: DoorOrientationName | None = None + for orient in DOOR_ORIENTATION_NAMES: + if orient.lower() in title: + orientation = orient + title = replace_insensitive(title, orient, "").strip() + break + if orientation is None: + orientation = "Door" + + # Remaining words are restrictions words = title.split() - if words[0].title() in RECORD_CATEGORIES: - data["record_category"] = words.pop(0) - if words[-1].title() not in DOOR_ORIENTATION_NAMES: - raise ValueError(f"Invalid orientation. 
Expected one of {DOOR_ORIENTATION_NAMES}, found {words[-1]}") - else: - data["category"] = "Door" + component_restrictions = [] + wiring_placement_restrictions = [] + unparsed = [] + for word in words: + if word.title() in await get_valid_restrictions("component"): + component_restrictions.append(word.title()) + elif word.title() in await get_valid_restrictions("wiring-placement"): + wiring_placement_restrictions.append(word.title()) + else: + unparsed.append(word) + + assert orientation is not None + return DoorTitle( + record_category=record_category, + component_restrictions=component_restrictions, + door_width=door_size[0], + door_height=door_size[1], + door_depth=door_size[2], + wiring_placement_restrictions=wiring_placement_restrictions, + door_types=door_types, + orientation=orientation, + ), ", ".join(unparsed) + + +async def ai_parse_piston_door_title(title: str) -> DoorTitle: + """Parses a piston door title into its components using AI.""" + client = AsyncOpenAI() + system_prompt = dedent(""" + You are an expert at structured data extraction. You will be given unstructured text from a minecraft piston door name and should convert it into the given structure. 
+ A build title is in the format of: + ``` + [Record Category] [component restrictions]+ [wiring placement restrictions]+ + + ``` + + Examples: + Title: "Smallest 5 high triangle piston door" + Parsed: {"record_category": "Smallest", "component_restrictions": [], "door_width": null, "door_height": 5, "door_depth": null, "wiring_placement_restrictions": [], "door_types": ["Triangle"], "orientation": "Door"} + """) + + completion = await client.beta.chat.completions.parse( + model="gpt-4o-mini", + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"Parse the following door title: {title}"}, + ], + response_format=DoorTitle, + ) + return completion.choices[0].message.parsed + + +@alru_cache() +async def get_valid_restrictions(type: Literal["component", "wiring-placement"]) -> list[str]: + """Gets a list of valid restrictions for a given type. + + Args: + type: The type of restriction. Either "component" or "wiring-placement" + Returns: + A list of valid restrictions for the given type. + """ db = DatabaseManager() - # Parse component restrictions - component_restrictions = await db.table("restrictions").select("name").eq("build_category", data["category"]).eq("type", "component").execute() + valid_restrictions_response = await db.table("restrictions").select("name").eq("type", type).execute() + return [restriction["name"] for restriction in valid_restrictions_response.data] + +@alru_cache() +async def get_valid_door_types() -> list[str]: + """Gets a list of valid door types. + + Returns: + A list of valid door types. 
+ """ + db = DatabaseManager() + valid_door_types_response = await db.table("types").select("name").eq("build_category", "Door").execute() + return [door_type["name"] for door_type in valid_door_types_response.data] + + +# --- Unused --- +async def validate_restrictions(restrictions: list[str], type: Literal["component", "wiring-placement"]) -> list[str]: + """Validates a list of restrictions to ensure all of them are valid. + + Args: + restrictions: The list of restrictions to validate + type: The type of restriction. Either "component" or "wiring-placement" + + Returns: + The original list of restrictions if all of them are valid. + + Raises: + ValueError: If any of the restrictions are invalid. + """ + valid_restrictions = await get_valid_restrictions(type) + + invalid_restrictions = [r for r in restrictions if r not in valid_restrictions] + if invalid_restrictions: + raise ValueError( + f"Invalid {type} restrictions. Found {invalid_restrictions} which are not one of the restrictions in the database." + ) + return restrictions + + +async def validate_door_types(door_types: list[str]) -> list[str]: + """Validates a list of door types to ensure all of them are valid. + + Args: + door_types: The list of door types to validate + + Returns: + The original list of door types if all of them are valid. + + Raises: + ValueError: If any of the door types are invalid. + """ + invalid_door_types = [dt for dt in door_types if dt not in await get_valid_door_types()] + if invalid_door_types: + raise ValueError( + f"Invalid door types. Found {invalid_door_types} which are not one of the door types in the database." 
+ ) + return door_types diff --git a/bot/verify.py b/bot/verify.py index 300f2a5..049ea8d 100644 --- a/bot/verify.py +++ b/bot/verify.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from bot.main import RedstoneSquid + class VerifyCog(Cog, name="verify"): def __init__(self, bot: RedstoneSquid): self.bot = bot diff --git a/database/user.py b/database/user.py index 14218fb..15649fb 100644 --- a/database/user.py +++ b/database/user.py @@ -1,4 +1,5 @@ """Handles user data and operations.""" + from uuid import UUID import requests @@ -37,16 +38,32 @@ async def link_minecraft_account(user_id: int, code: str) -> bool: """ db = DatabaseManager() - response = await db.table("verification_codes").select("minecraft_uuid", "minecraft_username").eq("code", code).gt("expires", utcnow()).maybe_single().execute() + response = ( + await db.table("verification_codes") + .select("minecraft_uuid", "minecraft_username") + .eq("code", code) + .gt("expires", utcnow()) + .maybe_single() + .execute() + ) if response is None: return False minecraft_uuid = response.data["minecraft_uuid"] minecraft_username = response.data["minecraft_username"] # TODO: This currently does not check if the ign is already in use without a UUID or discord ID given. 
- response = await db.table("users").update({"minecraft_uuid": minecraft_uuid, "ign": minecraft_username}).eq("discord_id", user_id).execute() + response = ( + await db.table("users") + .update({"minecraft_uuid": minecraft_uuid, "ign": minecraft_username}) + .eq("discord_id", user_id) + .execute() + ) if not response.data: - await db.table("users").insert({"discord_id": user_id, "minecraft_uuid": minecraft_uuid, "ign": minecraft_username}).execute() + await ( + db.table("users") + .insert({"discord_id": user_id, "minecraft_uuid": minecraft_uuid, "ign": minecraft_username}) + .execute() + ) return True @@ -80,4 +97,6 @@ def get_minecraft_username(user_uuid: str | UUID) -> str | None: elif response.status_code == 204: # No content return None else: - raise ValueError(f"Failed to get username for UUID {user_uuid}. The Mojang API returned status code {response.status_code}.") + raise ValueError( + f"Failed to get username for UUID {user_uuid}. The Mojang API returned status code {response.status_code}." 
+ ) diff --git a/database/utils.py b/database/utils.py index 66364e9..3eef8fb 100644 --- a/database/utils.py +++ b/database/utils.py @@ -7,6 +7,7 @@ import requests from requests_toolbelt import MultipartEncoder + def utcnow() -> str: """Returns the current time in UTC in the format of a string.""" current_utc = datetime.now(tz=timezone.utc) @@ -28,15 +29,11 @@ def upload_to_catbox(filename: str, file: bytes, mimetype: str) -> str: """ catbox_url = "https://catbox.moe/user/api.php" data = { - 'reqtype': 'fileupload', - 'userhash': os.getenv('CATBOX_USERHASH'), - 'fileToUpload': (filename, file, mimetype) + "reqtype": "fileupload", + "userhash": os.getenv("CATBOX_USERHASH"), + "fileToUpload": (filename, file, mimetype), } encoder = MultipartEncoder(fields=data) - response = requests.post( - catbox_url, - data=encoder, - headers={'Content-Type': encoder.content_type} - ) + response = requests.post(catbox_url, data=encoder, headers={"Content-Type": encoder.content_type}) return response.text diff --git a/requirements.in b/requirements.in index af30cc9..a5fcffe 100644 --- a/requirements.in +++ b/requirements.in @@ -7,6 +7,6 @@ jishaku requests-toolbelt fastapi uvicorn -langchain -langchain-openai +openai async_lru +markdown diff --git a/requirements.txt b/requirements.txt index 6d30faf..802d290 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,12 +4,11 @@ # # pip-compile # -aiohappyeyeballs==2.3.4 +aiohappyeyeballs==2.3.5 # via aiohttp -aiohttp==3.10.1 +aiohttp==3.10.3 # via # discord-py - # langchain # supabase-py-async aiosignal==1.3.1 # via aiohttp @@ -20,13 +19,13 @@ anyio==4.4.0 # httpx # openai # starlette -argcomplete==3.4.0 +argcomplete==3.5.0 # via commitizen astunparse==1.6.3 # via import-expression async-lru==2.0.4 # via -r requirements.in -attrs==24.1.0 +attrs==24.2.0 # via aiohttp braceexpand==0.1.7 # via jishaku @@ -50,7 +49,7 @@ colorama==0.4.6 # click # commitizen # tqdm -commitizen==3.28.0 +commitizen==3.29.0 # via supabase-py-async 
decli==0.6.2 # via commitizen @@ -70,16 +69,14 @@ frozenlist==1.4.1 # via # aiohttp # aiosignal -google-auth==2.32.0 +google-auth==2.33.0 # via # google-auth-oauthlib # gspread google-auth-oauthlib==1.2.1 # via gspread -gotrue==2.6.1 +gotrue==2.6.2 # via supabase-py-async -greenlet==3.0.3 - # via sqlalchemy gspread==6.1.2 # via -r requirements.in h11==0.14.0 @@ -115,46 +112,26 @@ jinja2==3.1.4 # via commitizen jishaku==2.5.2 # via -r requirements.in -jsonpatch==1.33 - # via langchain-core -jsonpointer==3.0.0 - # via jsonpatch -langchain==0.2.12 - # via -r requirements.in -langchain-core==0.2.28 - # via - # langchain - # langchain-openai - # langchain-text-splitters -langchain-openai==0.1.20 +jiter==0.5.0 + # via openai +markdown==3.6 # via -r requirements.in -langchain-text-splitters==0.2.2 - # via langchain -langsmith==0.1.96 - # via - # langchain - # langchain-core markupsafe==2.1.5 # via jinja2 multidict==6.0.5 # via # aiohttp # yarl -numpy==1.26.4 - # via langchain oauth2client==4.1.3 # via -r requirements.in oauthlib==3.2.2 # via requests-oauthlib -openai==1.38.0 - # via langchain-openai -orjson==3.10.6 - # via langsmith +openai==1.40.3 + # via -r requirements.in packaging==24.1 # via # commitizen # deprecation - # langchain-core postgrest==0.16.9 # via supabase-py-async prompt-toolkit==3.0.36 @@ -172,9 +149,6 @@ pydantic==2.8.2 # via # fastapi # gotrue - # langchain - # langchain-core - # langsmith # openai # postgrest pydantic-core==2.20.1 @@ -187,24 +161,16 @@ python-dateutil==2.9.0.post0 # storage3 python-dotenv==1.0.1 # via -r requirements.in -pyyaml==6.0.1 - # via - # commitizen - # langchain - # langchain-core +pyyaml==6.0.2 + # via commitizen questionary==2.0.1 # via commitizen realtime==1.0.6 # via supabase-py-async -regex==2024.7.24 - # via tiktoken requests==2.32.3 # via - # langchain - # langsmith # requests-oauthlib # requests-toolbelt - # tiktoken requests-oauthlib==2.0.0 # via google-auth-oauthlib requests-toolbelt==1.0.0 @@ -223,8 +189,6 @@ 
sniffio==1.3.1 # anyio # httpx # openai -sqlalchemy==2.0.31 - # via langchain starlette==0.37.2 # via fastapi storage3==0.7.7 @@ -235,14 +199,8 @@ supabase-py-async==2.5.6 # via -r requirements.in supafunc==0.4.0 # via supabase-py-async -tenacity==8.5.0 - # via - # langchain - # langchain-core termcolor==2.4.0 # via commitizen -tiktoken==0.7.0 - # via langchain-openai tomlkit==0.13.0 # via commitizen tqdm==4.66.5 @@ -250,12 +208,10 @@ tqdm==4.66.5 typing-extensions==4.12.2 # via # fastapi - # langchain-core # openai # pydantic # pydantic-core # realtime - # sqlalchemy # storage3 urllib3==2.2.2 # via requests