-
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(github): add github integration
- Loading branch information
1 parent
362aef2
commit 0480459
Showing
21 changed files
with
465 additions
and
98 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,4 @@ praw==7.8.1 | |
py-cord==2.6.1 | ||
python-dotenv==1.0.1 | ||
requests==2.32.3 | ||
requests-oauthlib==2.0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# standard imports | ||
import os | ||
|
||
# lib imports | ||
from cryptography import x509 | ||
from cryptography.hazmat.backends import default_backend | ||
from cryptography.hazmat.primitives import hashes | ||
from cryptography.hazmat.primitives.asymmetric import rsa | ||
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption | ||
from datetime import datetime, timedelta, UTC | ||
|
||
# local imports | ||
from src.common import common | ||
|
||
CERT_FILE = os.path.join(common.data_dir, "cert.pem") | ||
KEY_FILE = os.path.join(common.data_dir, "key.pem") | ||
|
||
|
||
def check_expiration(cert_path: str) -> int: | ||
with open(cert_path, "rb") as cert_file: | ||
cert_data = cert_file.read() | ||
cert = x509.load_pem_x509_certificate(cert_data, default_backend()) | ||
expiry_date = cert.not_valid_after_utc | ||
return (expiry_date - datetime.now(UTC)).days | ||
|
||
|
||
def generate_certificate(): | ||
private_key = rsa.generate_private_key( | ||
public_exponent=65537, | ||
key_size=4096, | ||
) | ||
subject = issuer = x509.Name([ | ||
x509.NameAttribute(x509.NameOID.COMMON_NAME, u"localhost"), | ||
]) | ||
cert = x509.CertificateBuilder().subject_name( | ||
subject | ||
).issuer_name( | ||
issuer | ||
).public_key( | ||
private_key.public_key() | ||
).serial_number( | ||
x509.random_serial_number() | ||
).not_valid_before( | ||
datetime.now(UTC) | ||
).not_valid_after( | ||
datetime.now(UTC) + timedelta(days=365) | ||
).sign(private_key, hashes.SHA256()) | ||
|
||
with open(KEY_FILE, "wb") as f: | ||
f.write(private_key.private_bytes( | ||
encoding=Encoding.PEM, | ||
format=PrivateFormat.TraditionalOpenSSL, | ||
encryption_algorithm=NoEncryption(), | ||
)) | ||
|
||
with open(CERT_FILE, "wb") as f: | ||
f.write(cert.public_bytes(Encoding.PEM)) | ||
|
||
|
||
def initialize_certificate() -> tuple[str, str]: | ||
print("Initializing SSL certificate") | ||
if os.path.exists(CERT_FILE) and os.path.exists(KEY_FILE): | ||
cert_expires_in = check_expiration(CERT_FILE) | ||
print(f"Certificate expires in {cert_expires_in} days.") | ||
if cert_expires_in >= 90: | ||
return CERT_FILE, KEY_FILE | ||
print("Generating new certificate") | ||
generate_certificate() | ||
return CERT_FILE, KEY_FILE | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# standard imports | ||
import shelve | ||
import threading | ||
|
||
|
||
class Database: | ||
def __init__(self, db_path): | ||
self.db_path = db_path | ||
self.lock = threading.Lock() | ||
|
||
def __enter__(self): | ||
self.lock.acquire() | ||
self.db = shelve.open(self.db_path, writeback=True) | ||
return self.db | ||
|
||
def __exit__(self, exc_type, exc_val, exc_tb): | ||
self.sync() | ||
self.db.close() | ||
self.lock.release() | ||
|
||
def sync(self): | ||
self.db.sync() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
DISCORD_BOT = None | ||
REDDIT_BOT = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
# standard imports | ||
import asyncio | ||
import os | ||
from threading import Thread | ||
from typing import Tuple | ||
|
||
# lib imports | ||
import discord | ||
from flask import Flask, jsonify, redirect, request, Response | ||
from requests_oauthlib import OAuth2Session | ||
|
||
# local imports | ||
from src.common import crypto | ||
from src.common import globals | ||
|
||
|
||
DISCORD_CLIENT_ID = os.getenv("DISCORD_CLIENT_ID") | ||
DISCORD_CLIENT_SECRET = os.getenv("DISCORD_CLIENT_SECRET") | ||
DISCORD_REDIRECT_URI = os.getenv("DISCORD_REDIRECT_URI", "https://localhost:8080/discord/callback") | ||
|
||
app = Flask('LizardByte-bot') | ||
|
||
|
||
@app.route('/') | ||
def main(): | ||
return "LizardByte-bot is live!" | ||
|
||
|
||
@app.route("/discord/callback") | ||
def discord_callback(): | ||
# get all active states from the global state manager | ||
with globals.DISCORD_BOT.db as db: | ||
active_states = db['oauth_states'] | ||
|
||
discord_oauth = OAuth2Session(DISCORD_CLIENT_ID, redirect_uri=DISCORD_REDIRECT_URI) | ||
token = discord_oauth.fetch_token("https://discord.com/api/oauth2/token", | ||
client_secret=DISCORD_CLIENT_SECRET, | ||
authorization_response=request.url) | ||
|
||
# Fetch the user's Discord profile | ||
response = discord_oauth.get("https://discord.com/api/users/@me") | ||
discord_user = response.json() | ||
|
||
# if the user is not in the active states, return an error | ||
if discord_user['id'] not in active_states: | ||
return "Invalid state" | ||
|
||
# remove the user from the active states | ||
del active_states[discord_user['id']] | ||
|
||
# Fetch the user's connected accounts | ||
connections_response = discord_oauth.get("https://discord.com/api/users/@me/connections") | ||
connections = connections_response.json() | ||
|
||
with globals.DISCORD_BOT.db as db: | ||
db['discord_users'] = db.get('discord_users', {}) | ||
db['discord_users'][discord_user['id']] = { | ||
'discord_username': discord_user['username'], | ||
'discord_global_name': discord_user['global_name'], | ||
'github_id': None, | ||
'github_username': None, | ||
'token': token, # TODO: should we store the token at all? | ||
} | ||
|
||
for connection in connections: | ||
if connection['type'] == 'github': | ||
db['discord_users'][discord_user['id']]['github_id'] = connection['id'] | ||
db['discord_users'][discord_user['id']]['github_username'] = connection['name'] | ||
|
||
# Redirect to our main website | ||
return redirect("https://app.lizardbyte.dev") | ||
|
||
|
||
@app.route("/webhook/<source>", methods=["POST"]) | ||
def webhook(source: str) -> Tuple[Response, int]: | ||
""" | ||
Process webhooks from various sources. | ||
* GitHub sponsors: https://github.com/sponsors/LizardByte/dashboard/webhooks | ||
* GitHub status: https://www.githubstatus.com | ||
Parameters | ||
---------- | ||
source : str | ||
The source of the webhook (e.g., 'github_sponsors', 'github_status'). | ||
Returns | ||
------- | ||
flask.Response | ||
Response to the webhook request | ||
""" | ||
valid_sources = ["github_sponsors", "github_status"] | ||
|
||
if source not in valid_sources: | ||
return jsonify({"status": "error", "message": "Invalid source"}), 400 | ||
|
||
print(f"received webhook from {source}") | ||
data = request.json | ||
print(f"received webhook data: \n{data}") | ||
|
||
if source == "github_sponsors": | ||
# ensure the secret matches | ||
# if data['secret'] != os.getenv("GITHUB_SPONSORS_WEBHOOK_SECRET_KEY"): | ||
# return jsonify({"status": "error", "message": "Invalid secret"}), 400 | ||
|
||
# process the webhook data | ||
if data['action'] == "created": | ||
message = f'New GitHub sponsor: {data["sponsorship"]["sponsor"]["login"]}' | ||
|
||
# create a discord embed | ||
embed = discord.Embed( | ||
author=discord.EmbedAuthor( | ||
name=data["sponsorship"]["sponsor"]["login"], | ||
url=data["sponsorship"]["sponsor"]["url"], | ||
icon_url=data["sponsorship"]["sponsor"]["avatar_url"], | ||
), | ||
color=0x00ff00, | ||
description=message, | ||
footer=discord.EmbedFooter( | ||
text=f"Sponsored at {data['sponsorship']['created_at']}", | ||
), | ||
title="New GitHub Sponsor", | ||
) | ||
message = asyncio.run_coroutine_threadsafe( | ||
globals.DISCORD_BOT.send_message_to_channel( | ||
channel_id=os.getenv("DISCORD_SPONSORS_CHANNEL_ID"), | ||
embeds=[embed], | ||
), globals.DISCORD_BOT.loop) | ||
message.result() # wait for the message to be sent | ||
|
||
return jsonify({"status": "success"}), 200 | ||
|
||
|
||
def run(): | ||
cert_file, key_file = crypto.initialize_certificate() | ||
|
||
app.run( | ||
host="0.0.0.0", | ||
port=8080, | ||
ssl_context=(cert_file, key_file) | ||
) | ||
|
||
|
||
def start(): | ||
server = Thread( | ||
name="Flask", | ||
daemon=True, | ||
target=run, | ||
) | ||
server.start() | ||
Oops, something went wrong.