forked from ray-project/docu-mentor
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
263 lines (209 loc) · 8.86 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import httpx
from dotenv import load_dotenv
import os
import openai
import logging
import string
import sys
import ray
from ray import serve
from utils import (
generate_jwt,
get_installation_access_token,
get_diff_url,
get_branch_files,
get_pr_head_branch,
parse_diff_to_line_numbers,
get_context_from_files,
)
# Log to stdout so messages show up in container / Ray Serve logs.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger("Docu Mentor")
# Greeting posted once when a PR is opened/reopened; explains how to invoke the bot.
GREETING = """
👋 Hi, I'm @docu-mentor, an LLM-powered GitHub app
powered by [Anyscale Endpoints](https://app.endpoints.anyscale.com/)
that gives you actionable feedback on your writing.
Simply create a new comment in this PR that says:
@docu-mentor run
and I will start my analysis. I only look at what you changed
in this PR. If you only want me to look at specific files or folders,
you can specify them like this:
@docu-mentor run doc/ README.md
In this example, I'll have a look at all files contained in the "doc/"
folder and the file "README.md". All good? Let's get started!
"""
load_dotenv()
# If the app was installed, retrieve the installation access token through the App's
# private key and app ID, by generating an intermediary JWT token.
APP_ID = os.environ.get("APP_ID")
PRIVATE_KEY = os.environ.get("PRIVATE_KEY", "")
# Point the (pre-1.0) openai client at Anyscale Endpoints instead of OpenAI.
ANYSCALE_API_ENDPOINT = "https://api.endpoints.anyscale.com/v1"
openai.api_base = ANYSCALE_API_ENDPOINT
openai.api_key = os.environ.get("ANYSCALE_API_KEY")
# System message steering the LLM toward concise, style-guide-based review feedback.
SYSTEM_CONTENT = """You are a helpful assistant.
Improve the following <content>. Criticise syntax, grammar, punctuation, style, etc.
Recommend common technical writing knowledge, such as used in Vale
and the Google developer documentation style guide.
For Python docstrings, make sure input arguments and return values are documented.
Also, docstrings should have good descriptions and come with examples.
If the content is good, don't comment on it.
You can use GitHub-flavored markdown syntax in your answer.
If you encounter several files, give very concise feedback per file.
"""
# User-prompt suffix appended after the serialized content (see `mentor`).
PROMPT = """Improve this content.
Don't comment on file names or other meta data, just the actual text.
The <content> will be in JSON format and contains file name keys and text values.
Make sure to give very concise feedback per file.
"""
def mentor(
    content,
    model="codellama/CodeLlama-34b-Instruct-hf",
    system_content=SYSTEM_CONTENT,
    prompt=PROMPT
):
    """Ask the LLM endpoint to review *content* and return its suggestions.

    Args:
        content: Text (or a dict of file name -> text) to be reviewed.
        model: Model identifier sent to the chat-completion endpoint.
        system_content: System message steering the review style.
        prompt: Instruction appended after the serialized content.

    Returns:
        A 4-tuple of (suggestions, model, prompt_tokens, completion_tokens).
    """
    messages = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": f"This is the content: {content}. {prompt}"},
    ]
    # temperature=0 keeps reviews deterministic for identical input.
    response = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        temperature=0,
    )
    token_usage = response.get("usage")
    suggestions = response["choices"][0]["message"]["content"]
    return (
        suggestions,
        model,
        token_usage.get("prompt_tokens"),
        token_usage.get("completion_tokens"),
    )
# Start Ray so multiple files can be reviewed in parallel (see `ray_mentor`).
# If startup fails we log and continue; the app falls back to serial `mentor`.
# Fix: the original used a bare `except:`, which would also swallow
# SystemExit/KeyboardInterrupt — narrowed to `except Exception`.
try:
    ray.init()
except Exception:
    logger.info("Ray init failed.")
@ray.remote
def mentor_task(content, model, system_content, prompt):
    # Ray remote wrapper around `mentor` so several files can be analysed
    # in parallel from `ray_mentor`.
    return mentor(content, model, system_content, prompt)
def ray_mentor(
    content: dict,
    model="codellama/CodeLlama-34b-Instruct-hf",
    system_content=SYSTEM_CONTENT,
    prompt="Improve this content."
):
    """Run `mentor` on every file in parallel via Ray and merge the results.

    Args:
        content: Mapping of file name -> file text to review.
        model: Model identifier forwarded to each `mentor_task`.
        system_content: System message forwarded to each `mentor_task`.
        prompt: Per-file instruction. NOTE(review): this default differs from
            the module-level PROMPT used by `mentor` — presumably intentional
            since each task sees a single file, not a JSON dict; confirm.

    Returns:
        A 4-tuple of (combined report text, model, total prompt tokens,
        total completion tokens).
    """
    # Fan out one remote task per file, then block for all results.
    futures = [
        mentor_task.remote(v, model, system_content, prompt)
        for v in content.values()
    ]
    suggestions = ray.get(futures)
    # Each suggestion tuple is (text, model, prompt_tokens, completion_tokens).
    content = {k: v[0] for k, v in zip(content.keys(), suggestions)}
    prompt_tokens = sum(v[2] for v in suggestions)
    completion_tokens = sum(v[3] for v in suggestions)
    # Bug fix: the original f-string used "\{v}" — an invalid escape sequence
    # that leaked a literal backslash into every report line (and warns on
    # modern CPython). "{v}" is the intended substitution.
    print_content = ""
    for k, v in content.items():
        print_content += f"{k}:\n\t{v}\n\n"
    logger.info(print_content)
    return print_content, model, prompt_tokens, completion_tokens
app = FastAPI()
async def handle_webhook(request: Request):
    """Process a GitHub webhook delivery.

    Handles two event shapes:
      1. A pull request being opened/reopened -> post the GREETING comment.
      2. An issue comment on a PR containing "@docu-mentor run" -> fetch the
         PR diff, gather file context, run the LLM review, and post the
         results back as a PR comment.

    Raises:
        ValueError: if the payload carries no app installation id (the bot
            cannot authenticate without one).
    """
    data = await request.json()
    installation = data.get("installation")
    if installation and installation.get("id"):
        installation_id = installation.get("id")
        logger.info(f"Installation ID: {installation_id}")
        # Exchange the app's JWT for an installation access token, then use
        # it for all subsequent GitHub API calls in this request.
        JWT_TOKEN = generate_jwt()
        installation_access_token = await get_installation_access_token(
            JWT_TOKEN, installation_id
        )
        # The diff Accept header makes GitHub return raw diffs for PR URLs.
        headers = {
            "Authorization": f"token {installation_access_token}",
            "User-Agent": "docu-mentor-bot",
            "Accept": "application/vnd.github.VERSION.diff",
        }
    else:
        raise ValueError("No app installation found.")
    # If PR exists and is opened
    if "pull_request" in data.keys() and (
        data["action"] in ["opened", "reopened"]
    ):  # use "synchronize" for tracking new commits
        pr = data.get("pull_request")
        # Greet the user and show instructions.
        async with httpx.AsyncClient() as client:
            await client.post(
                f"{pr['issue_url']}/comments",
                json={"body": GREETING},
                headers=headers,
            )
        return JSONResponse(content={}, status_code=200)
    # Check if the event is a new or modified issue comment
    if "issue" in data.keys() and data.get("action") in ["created", "edited"]:
        issue = data["issue"]
        # Check if the issue is a pull request
        if "/pull/" in issue["html_url"]:
            pr = issue.get("pull_request")
            # Get the comment body
            comment = data.get("comment")
            comment_body = comment.get("body")
            # Remove all whitespace characters except for regular spaces
            comment_body = comment_body.translate(
                str.maketrans("", "", string.whitespace.replace(" ", ""))
            )
            # Skip if the bot talks about itself
            author_handle = comment["user"]["login"]
            # Check if the bot is mentioned in the comment
            if (
                author_handle != "docu-mentor[bot]"
                and "@docu-mentor run" in comment_body
            ):
                async with httpx.AsyncClient() as client:
                    # Fetch diff from GitHub
                    # Anything after "@docu-mentor run" is treated as a list
                    # of file/folder filters (see GREETING).
                    files_to_keep = comment_body.replace(
                        "@docu-mentor run", ""
                    ).split(" ")
                    files_to_keep = [item for item in files_to_keep if item]
                    logger.info(files_to_keep)
                    url = get_diff_url(pr)
                    diff_response = await client.get(url, headers=headers)
                    diff = diff_response.text
                    files_with_lines = parse_diff_to_line_numbers(diff)
                    # Get head branch of the PR
                    # Switch Accept back to JSON for regular API responses.
                    headers["Accept"] = "application/vnd.github.full+json"
                    head_branch = await get_pr_head_branch(pr, headers)
                    # Get files from head branch
                    head_branch_files = await get_branch_files(pr, head_branch, headers)
                    print("HEAD FILES", head_branch_files)
                    # Enrich diff data with context from the head branch.
                    context_files = get_context_from_files(head_branch_files, files_with_lines)
                    # Filter the dictionary
                    # Keep a file if any requested filter is a substring of its path.
                    if files_to_keep:
                        context_files = {
                            k: context_files[k]
                            for k in context_files
                            if any(sub in k for sub in files_to_keep)
                        }
                    # Get suggestions from Docu Mentor
                    # Prefer the parallel Ray path when a Ray cluster is up.
                    content, model, prompt_tokens, completion_tokens = \
                        ray_mentor(context_files) if ray.is_initialized() else mentor(context_files)
                    # Let's comment on the PR
                    await client.post(
                        f"{comment['issue_url']}/comments",
                        json={
                            "body": f":rocket: Docu Mentor finished "
                            + "analysing your PR! :rocket:\n\n"
                            + "Take a look at your results:\n"
                            + f"{content}\n\n"
                            + "This bot is proudly powered by "
                            + "[Anyscale Endpoints](https://app.endpoints.anyscale.com/).\n"
                            + f"It used the model {model}, used {prompt_tokens} prompt tokens, "
                            + f"and {completion_tokens} completion tokens in total."
                        },
                        headers=headers,
                    )
# Ray Serve deployment wrapping the FastAPI app; routes are served at "/".
@serve.deployment(route_prefix="/")
@serve.ingress(app)
class ServeBot:
    @app.get("/")
    async def root(self):
        # Simple health/landing endpoint.
        return {"message": "Docu Mentor reporting for duty!"}
    @app.post("/webhook/")
    async def handle_webhook_route(self, request: Request):
        # GitHub webhook entry point; delegates to the module-level handler.
        return await handle_webhook(request)
# Run with: serve run main:bot
bot = ServeBot.bind()