forked from daveshap/Quickly_Extract_Science_Papers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchat.py
94 lines (72 loc) · 2.7 KB
/
chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import openai
from time import time, sleep
from halo import Halo
import textwrap
import yaml
### file operations
def save_file(filepath, content):
    """Write ``content`` to ``filepath`` as UTF-8 text.

    Any existing file at ``filepath`` is truncated and replaced.
    """
    with open(filepath, mode="w", encoding="utf-8") as handle:
        handle.write(content)
def open_file(filepath):
    """Return the entire contents of ``filepath`` decoded as UTF-8.

    Bytes that are not valid UTF-8 are silently dropped (``errors="ignore"``)
    rather than raising ``UnicodeDecodeError``.
    """
    with open(filepath, mode="r", encoding="utf-8", errors="ignore") as source:
        text = source.read()
    return text
def save_yaml(filepath, data):
    """Serialize ``data`` as YAML into ``filepath`` (UTF-8, overwriting the file).

    ``allow_unicode=True`` keeps non-ASCII characters literal instead of
    escaping them.
    """
    with open(filepath, mode="w", encoding="utf-8") as stream:
        yaml.dump(data, stream, allow_unicode=True)
def open_yaml(filepath):
    """Parse the YAML document in ``filepath`` (read as UTF-8) and return it.

    NOTE(review): this uses ``yaml.FullLoader``, which accepts more than plain
    YAML data. If this file could ever come from an untrusted source,
    consider ``yaml.safe_load`` instead. Confirm the files involved are
    always local and trusted.
    """
    with open(filepath, mode="r", encoding="utf-8") as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)
# NOTE(review): `functions` is not referenced by any code visible in this file.
# Presumably a leftover placeholder (perhaps for OpenAI function-calling
# definitions); confirm nothing else imports it before removing.
functions = [{}]
### API functions
def chatbot(conversation, model="gpt-4-0613", temperature=0, max_retry=7):
    """Send ``conversation`` to the OpenAI chat-completion API and return the reply.

    Failed requests are retried with exponential backoff (5 s, 10 s, 20 s, ...).
    If the error message reports that the maximum context length was exceeded,
    the oldest message is removed from ``conversation`` in place and the request
    is retried immediately, without counting against ``max_retry``. A leading
    system message (in this script, the paper text) is preserved when possible.

    Args:
        conversation: List of ``{"role": ..., "content": ...}`` message dicts.
            Mutated in place when messages have to be trimmed.
        model: Chat model name passed to the API.
        temperature: Sampling temperature passed to the API.
        max_retry: Number of failed attempts (not counting context-length
            trims) after which the process exits.

    Returns:
        Tuple ``(text, total_tokens)``: the assistant reply and the total token
        usage the API reported for this request.

    Raises:
        SystemExit: with status 1 once ``max_retry`` failures have occurred.
    """
    import sys

    retry = 0
    while True:
        spinner = Halo(text="Thinking...", spinner="dots")
        spinner.start()
        try:
            response = openai.ChatCompletion.create(
                model=model, messages=conversation, temperature=temperature
            )
            text = response["choices"][0]["message"]["content"]
            total_tokens = response["usage"]["total_tokens"]
        except Exception as oops:  # top-level retry boundary: any API/parsing failure
            error = oops
        else:
            return text, total_tokens
        finally:
            # Always stop the spinner. Previously it kept running after a failed
            # attempt, and every retry started another one on top of it.
            spinner.stop()

        print(f'\n\nError communicating with OpenAI: "{error}"')
        if "maximum context length" in str(error) and conversation:
            # Drop the oldest message, but keep a leading system prompt (the
            # document being discussed) when there is something else to drop.
            keep_system = (
                len(conversation) > 1 and conversation[0].get("role") == "system"
            )
            conversation.pop(1 if keep_system else 0)
            print("\n\n DEBUG: Trimming oldest message")
            continue

        retry += 1
        if retry >= max_retry:
            print(f"\n\nExiting due to excessive errors in API: {error}")
            # sys.exit, not the site-provided exit(), which is absent under -S / frozen builds.
            sys.exit(1)
        delay = 2 ** (retry - 1) * 5  # exponential backoff: 5, 10, 20, ... seconds
        print(f"\n\nRetrying in {delay} seconds...")
        sleep(delay)
def chat_print(text):
    """Print a chatbot reply under a ``CHATBOT:`` header.

    Each input line is word-wrapped independently to 120 columns, and every
    wrapped line is prefixed with the indent string. Empty lines stay empty.
    """
    wrapper = textwrap.TextWrapper(
        width=120, initial_indent=" ", subsequent_indent=" "
    )
    body = "\n".join(map(wrapper.fill, text.split("\n")))
    print(f"\n\n\nCHATBOT:\n\n{body}")
if __name__ == "__main__":
    # Interactive Q&A over one document: the (truncated) text of input.txt is
    # sent as the system message, followed by the running chat history.
    MAX_PAPER_CHARS = 22000  # cap on how much of input.txt is sent as context
    TRIM_AT_TOKENS = 7800  # total-token usage at which the oldest chat message is dropped

    openai.api_key = open_file("key_openai.txt").strip()
    # Slicing is a no-op for shorter text, so no explicit length check is needed.
    paper = open_file("input.txt")[:MAX_PAPER_CHARS]
    ALL_MESSAGES = [{"role": "system", "content": paper}]
    while True:
        # get user input
        try:
            text = input("\n\n\nUSER:\n\n")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session cleanly instead of
            # dumping a traceback.
            print()
            break
        if not text.strip():
            # empty or whitespace-only submission, probably by accident
            continue
        ALL_MESSAGES.append({"role": "user", "content": text})
        # get response
        response, tokens = chatbot(ALL_MESSAGES)
        if tokens >= TRIM_AT_TOKENS:
            # Drop the oldest message after the system prompt so the paper is kept.
            ALL_MESSAGES.pop(1)
        chat_print(response)
        ALL_MESSAGES.append({"role": "assistant", "content": response})