-
Notifications
You must be signed in to change notification settings - Fork 2
/
client.py
484 lines (399 loc) · 16.8 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
import copy
import functools
import getpass
import json
import time
import warnings
from typing import Any, Dict, List, Optional, Sequence

import requests

try:
    import bs4
except ImportError:
    bs4 = None

try:
    import stem
    from aids.app.obfuscate import get_tor_session, renew_connection
except ImportError:
    stem = None

from aids.app.models import Story, Scenario, ValidationError
from aids.app.writelogs import logged
from aids.app import settings, schemes
def check_errors(request):
    """
    Decorate a ``requests.Session.request``-style method with error handling.

    Every response is checked with ``raise_for_status()``:

    * Network failures (``requests.exceptions.ConnectionError``, which also
      covers ``SSLError``) are logged and retried after a growing delay.
      After ``max_retries`` failed retries the last exception is re-raised,
      instead of spinning forever.
    * HTTP error statuses are logged together with the server's ``errors``
      payload (or the whole JSON body when it has no ``errors`` key), the
      first 50 bytes of the raw response and the ``data`` payload that was
      sent. The original ``HTTPError`` is then re-raised.

    The first argument of the wrapped method is the owning object. It is
    named ``cls`` here although it is the instance, and it must provide
    ``logger`` and ``logger_err``.
    """
    max_retries = 5
    base_delay = 1  # seconds, doubled after every failed attempt
    max_delay = 30  # seconds, upper bound for a single wait

    @functools.wraps(request)
    def inner_func(cls, method, url, **kwargs):
        attempt = 0
        while True:
            try:
                response = request(cls, method, url, **kwargs)
                response.raise_for_status()
            except requests.exceptions.ConnectionError as exc:
                cls.logger_err.exception(exc)
                cls.logger_err.error(
                    "Server URL: %s, failed while trying to connect.", url
                )
                if attempt >= max_retries:
                    raise
                cls.logger.info("Network unstable. Retrying...")
                # Back off rather than hammering an unreachable server.
                time.sleep(min(base_delay * 2**attempt, max_delay))
                attempt += 1
            except requests.exceptions.HTTPError:
                try:
                    body = response.json()
                except ValueError:
                    # The body is not JSON. Depending on whether simplejson is
                    # installed, requests raises a json or a simplejson
                    # JSONDecodeError here; both are ValueError subclasses.
                    errors = "No errors"
                else:
                    try:
                        errors = body["errors"]
                    except (KeyError, TypeError):
                        # The JSON has no "errors" member, or is not an
                        # object at all: log the whole body instead.
                        errors = body
                raw_response = response.content[:50]
                payload = kwargs.get("data", "none")
                error_message = f"""
Server URL: {response.url},
failed with status code ({response.status_code}).
Errors: {errors}.
Raw response: {raw_response}
Request payload: {payload}
"""
                cls.logger.error(error_message)
                # Re-raise the original HTTPError. It still carries the
                # response and the status message, which a fresh, empty
                # HTTPError would lose.
                raise
            else:
                return response

    return inner_func
@logged
class Session(requests.Session):
    """
    A ``requests.Session`` whose ``request`` method goes through
    ``check_errors``.

    Connection failures are retried, and HTTP error responses are logged
    and raised, without each caller having to check for them.
    """

    @check_errors
    def request(self, method, url, **kwargs):
        """Forward the call unchanged to ``requests.Session.request``."""
        parent_request = super().request
        return parent_request(method, url, **kwargs)
@logged
class BaseClient:
    """
    Base client from which every other client must inherit.

    Owns a ``Session`` preloaded with the headers returned by
    ``settings.get_request_headers()``. Subclasses set ``self.url`` and
    implement ``login``.
    """

    def __init__(self):
        self.url = None
        self.session = Session()
        self.session.headers.update(settings.get_request_headers())
        self.logger.info("%s successfully initialized.", self.__class__.__name__)

    def __del__(self):
        # If __init__ failed before the session was created, there is
        # nothing to close; avoid an AttributeError during garbage collection.
        session = getattr(self, "session", None)
        if session is not None:
            session.close()

    def quit(self):
        """
        Kill the client by closing its HTTP session.
        """
        self.session.close()

    def renew(self):
        """
        Route the session through Tor to change our apparent IP address.

        Note that Cloudflare is going to be a PITA, so this method is pretty
        useless as it is.

        :raises ImportError: if the ``stem`` library is not installed.
        :raises stem.SocketError: if the Tor service is not reachable. It is
            logged first, then re-raised, because continuing would install a
            Tor session that cannot connect.
        """
        if not stem:
            raise ImportError("You need the stem library to use this method")
        try:
            renew_connection()
        except stem.SocketError:
            self.logger_err.error(
                "Socket Error: " "The Tor service is not up. Unable to continue."
            )
            raise
        self.session = get_tor_session(self.session)

    def login(self, credentials: dict):
        """
        Login into the site using a dict of credentials.
        """
        raise NotImplementedError

    def logout(self):
        """
        Clean up the session to "log out".

        Restores the default request headers (dropping any auth token
        headers) and clears all cookies.
        """
        # Mutate the existing header mapping instead of replacing it, so the
        # session keeps its case-insensitive header dict.
        self.session.headers.clear()
        self.session.headers.update(settings.get_request_headers())
        self.session.cookies.clear()
        self.logger.info("Logged out.")
class AIDScrapper(BaseClient):
    """
    AID Client to make API calls via requests.

    Talks to the AI Dungeon GraphQL endpoint stored in ``self.url``.
    Downloaded stories are collected in ``self.adventures`` and scenarios in
    ``self.prompts``.
    """

    # NOTE(review): these are class attributes, so every AIDScrapper instance
    # shares the same Story/Scenario containers. Confirm this is intended.
    adventures = Story()
    prompts = Scenario()

    def __init__(self):
        super().__init__()
        self.url = "https://api.aidungeon.io/graphql"
        # Get all settings. Each template is deep-copied because the methods
        # below mutate them in place (ids, search term, offset, credentials).
        # Without the copy those edits would leak into the shared ``schemes``
        # module and into every client created afterwards.
        self.stories_query = copy.deepcopy(schemes.stories_query)
        self.story_query = copy.deepcopy(schemes.story_query)
        self.create_scen_payload = copy.deepcopy(schemes.create_scen_payload)
        self.update_scen_payload = copy.deepcopy(schemes.update_scen_payload)
        self.make_WI_payload = copy.deepcopy(schemes.make_WI_payload)
        self.update_WI_payload = copy.deepcopy(schemes.update_WI_payload)
        self.scenarios_query = copy.deepcopy(schemes.scenarios_query)
        self.scenario_query = copy.deepcopy(schemes.scenario_query)
        self.wi_query = copy.deepcopy(schemes.wi_query)
        self.aid_loginpayload = copy.deepcopy(schemes.aid_loginpayload)
        # Pagination counter shared by get_stories and get_scenarios.
        self.offset = 0

    def login(self, credentials=None):
        """
        Log into AID and store the access token in the session headers.

        :param credentials: dict with ``username`` and ``password``. When it
            is omitted, the ``AID_USERNAME``/``AID_PASSWORD`` secrets are
            used, with an interactive console prompt as the fallback.
        :raises KeyError: if the server did not return an access token.
        """
        if not credentials:
            try:
                self.logger.info("Trying to log-in via file...")
                username = settings.get_secret("AID_USERNAME")
                password = settings.get_secret("AID_PASSWORD")
            except settings.ImproperlyConfigured:
                self.logger.info("File is not configured... logging in via console.")
                warnings.warn("Registering a user is the preferred way of logging in.")
                username = input("Your username or e-mail: ").strip()
                password = getpass.getpass("Your password: ")
            credentials = {"username": username, "password": password}
        else:
            self.logger.info("Credentials were passed to the function directly...")
        key = self.get_login_token(credentials)
        if not key:
            # get_login_token returns None when the response has no "data".
            # Installing that as the header would leave us silently logged
            # out while reporting success.
            raise KeyError("There was no token")
        self.session.headers.update({"x-access-token": key})
        self.logger.info(
            'User "%s" sucessfully logged into AID', credentials["username"]
        )

    def _get_story_content(self, story_id: str) -> Dict[str, Any]:
        """Return the ``adventure`` object for the story with public id *story_id*."""
        self.story_query.update({"variables": {"publicId": story_id}})
        response = self.session.post(self.url, json=self.story_query)
        return response.json()["data"]["adventure"]

    def _get_scenario_content(self, scenario_id: str) -> Dict[str, Any]:
        """Return the ``scenario`` object for *scenario_id*, with its ``worldInfo`` attached."""
        self.scenario_query.update({"variables": {"publicId": scenario_id}})
        wi = self._get_wi(scenario_id)
        response = self.session.post(self.url, json=self.scenario_query)
        scenario = response.json()["data"]["scenario"]
        scenario.update({"worldInfo": wi})
        return scenario

    def _get_wi(self, scenario_id: str) -> Dict[str, Any]:
        """Return the ``worldInfoType`` data for the content *scenario_id*."""
        self.wi_query["variables"].update({"contentPublicId": scenario_id})
        response = self.session.post(self.url, json=self.wi_query)
        return response.json()["data"]["worldInfoType"]

    def _query_objects(self, query: dict, term: str = "") -> Dict[str, Any]:
        """
        Run a user search query and return ``data.user.search``.

        The search term is *term*, or else the adventures' title filter, or
        else the prompts' title filter. *query* is modified in place.
        """
        query["variables"]["input"]["searchTerm"] = (
            term or self.adventures.title or self.prompts.title
        )
        response = self.session.post(self.url, data=json.dumps(query))
        return response.json()["data"]["user"]["search"]

    def get_stories(self):
        """
        Download the user's stories, page by page, into ``self.adventures``.

        Without a title filter, each story goes through
        ``self.adventures._add``, and the download stops at the first story
        that raises ``ValidationError`` (the "under min_act actions"
        cut-off). With a title filter, every result is added with
        ``self.adventures.add``.

        Pagination state (``self.offset`` and the query's ``offset``) is
        reset on exit, so a later call, or ``get_scenarios``, which shares
        ``self.offset``, starts from the first page.
        """
        query_input = self.stories_query["variables"]["input"]
        try:
            while True:
                result: List[Dict[str, Any]] = self._query_objects(self.stories_query)
                if not result or not any(result):
                    self.logger.info("All stories downloaded")
                    return
                for story in result:
                    content = self._get_story_content(story["publicId"])
                    self.offset += 1
                    if not self.adventures.title:
                        # To optimize queries -- stop when we are under
                        # self.adventures.min_act actions.
                        try:
                            self.adventures._add(content)
                        except ValidationError as exc:
                            self.logger.debug(exc)
                            # actions are under the limit. Abort.
                            return
                    else:
                        self.adventures.add(content)
                    self.logger.info('Loaded story: "%s"', story["title"])
                self.logger.debug("Got %d stories so far", len(self.adventures))
                query_input["offset"] = self.offset
        finally:
            self.offset = 0
            query_input["offset"] = 0

    def get_scenarios(self):
        """
        Download the user's scenarios, and their option children, into
        ``self.prompts``, page by page.

        Pagination state (``self.offset`` and the query's ``offset``) is
        reset on exit, so later calls start from the first page.
        """
        query_input = self.scenarios_query["variables"]["input"]
        try:
            while True:
                result: List[Dict[str, Any]] = self._query_objects(self.scenarios_query)
                if not result or not any(result):
                    self.logger.info("All scenarios downloaded")
                    return
                for scenario in result:
                    self.add_all_scenarios(scenario["publicId"])
                self.logger.debug("Got %d scenarios so far", len(self.prompts))
                query_input["offset"] = self.offset
        finally:
            self.offset = 0
            query_input["offset"] = 0

    def add_all_scenarios(self, pubid, isOption=False) -> None:
        """
        Add the scenario *pubid* and, recursively, all of its options to
        ``self.prompts``.

        :param isOption: True for nested option scenarios. Only top-level
            scenarios advance ``self.offset``.
        """
        scenario: Dict[str, Any] = self._get_scenario_content(pubid)
        scenario["isOption"] = isOption
        options = scenario.get("options")
        if isinstance(options, Sequence):
            for option in options:
                self.add_all_scenarios(option["publicId"], True)
        self.prompts.add(scenario)
        if not isOption:
            # Only top-level results count towards the search offset,
            # presumably because options are not separate search results.
            self.offset += 1
        self.logger.info("Added %s to memory", scenario["title"])

    def get_login_token(self, credentials: Dict[str, Any]):
        """
        Send the login mutation and return the access token.

        :returns: the token, or None when the response has no ``data`` key.
        :raises KeyError: if ``data`` is present but holds no usable token.
        """
        variables = self.aid_loginpayload["variables"]
        variables["identifier"] = variables["email"] = credentials["username"]
        variables["password"] = credentials["password"]
        res = self.session.post(self.url, data=json.dumps(self.aid_loginpayload)).json()
        if "data" in res:
            try:
                token = res["data"]["login"]["accessToken"]
            except (KeyError, TypeError) as exc:
                # TypeError covers a GraphQL response with "login": null.
                raise KeyError("There was no token") from exc
            if not token:
                # Explicit check instead of `assert`, which -O would strip.
                raise KeyError("There was no token")
            return token
        self.logger_err.error("There was no data")
        return None

    def upload_in_bulk(self, scenarios: Dict[str, Any]):
        """
        Create one AID scenario per value in *scenarios*, then fill it in.

        An empty scenario is created first to obtain its ``publicId``, which
        is written back into the caller's dict. The scenario is then updated
        with only the fields that the update payload template accepts.
        """
        for scenario in scenarios.values():
            assert isinstance(scenario, dict)
            created = self.session.post(
                self.url, data=json.dumps(self.create_scen_payload)
            ).json()["data"]["createScenario"]
            scenario.update({"publicId": created["publicId"]})
            new_scenario = self.update_scen_payload.copy()
            # (XXX) This process have been delegated to the
            # data models. Maybe wait for me to make a proper "Scenario" object
            # to refactor it?
            accepted_fields = new_scenario["variables"]["input"]
            clean_scenario = {
                k: v for k, v in scenario.items() if k in accepted_fields
            }
            new_scenario.update({"variables": {"input": clean_scenario}})
            self.session.post(self.url, data=json.dumps(new_scenario))
            self.logger.info("%s successfully uploaded...", scenario["title"])
class ClubClient(BaseClient):
    """
    Client for prompts.aidg.club.

    Form posts carry the hidden ``__RequestVerificationToken`` input scraped
    from the target page, which requires BeautifulSoup (``bs4``).
    """

    def __init__(self):
        super().__init__()
        warnings.warn(
            "The Club Client is should be out of service due to "
            "changes in aidg.club.com back-end.",
            stacklevel=2,
        )
        self.url = "https://prompts.aidg.club/"
        if not bs4:
            raise ImportError(
                "You must install the BeautifulSoup library to use the Club client."
            )

    def _post(self, obj_url, params):
        """
        POST *params* as form data to ``self.url + obj_url``.

        Sets the Referer header to the target URL and adds that page's
        ``__RequestVerificationToken`` to *params*, mutating it.

        :returns: the ``requests.Response`` of the POST.
        """
        url = self.url + obj_url
        self.session.headers.update(dict(Referer=url))
        params["__RequestVerificationToken"] = self.get_secret_token(url)
        return self.session.post(url, data=params)

    @staticmethod
    def reformat_tags(tags):
        """
        Convert a list of tags into the club's form fields.

        :returns: ``{"nsfw": "true"|"false", "tags": "a, b, c"}``. ``nsfw`` is
            ``"true"`` when one of the tags is exactly ``"nsfw"``.
        """
        return {
            "nsfw": "true" if "nsfw" in tags else "false",
            "tags": ", ".join(tags),
        }

    def get_secret_token(self, url):
        """
        Fetch *url* and return the value of its hidden
        ``__RequestVerificationToken`` input.

        :raises ValueError: if the page has no such input.
        """
        res = self.session.get(url)
        # Name the parser explicitly. Otherwise bs4 guesses one (and warns),
        # and the result depends on which parsers happen to be installed.
        body = bs4.BeautifulSoup(res.text, "html.parser")
        hidden_token = body.find("input", {"name": "__RequestVerificationToken"})
        if hidden_token is None:
            raise ValueError(f"No __RequestVerificationToken input found at {url}")
        return hidden_token.attrs["value"]

    def register(self, credentials=None):
        """
        Register a new club user.

        :param credentials: form fields (``Username``, ``Password``,
            ``PasswordConfirm``). When empty, they are asked for on the console.
        """
        credentials = credentials or {}
        params = {
            "ReturnUrl": "/",
            "Honey": "",
            "Username": "",
            "Password": "",
            "PasswordConfirm": "",
        }
        params.update(credentials)
        if not credentials:
            params.update(
                {
                    "Username": input("Username: "),
                    "Password": getpass.getpass("Password: "),
                    "PasswordConfirm": getpass.getpass("Password(Again): "),
                }
            )
        self._post("user/register/", params)

    def login(self, credentials=None):
        """
        Log into the club.

        :param credentials: form fields (``Username``, ``Password``). When
            empty, they are asked for on the console.
        """
        credentials = credentials or {}
        # "Honey" is sent as an empty string, as in register/publish_scenario.
        # A None value would make requests drop the field entirely.
        params = {"ReturnUrl": "", "Honey": "", "Username": "", "Password": ""}
        params.update(credentials)
        if not credentials:
            params.update(
                {
                    "Username": input("Username: "),
                    "Password": getpass.getpass("Password: "),
                }
            )
        self._post("user/login/", params)

    def publish_scenario(self, title: str = ""):
        """
        Publish a scenario with a given name to the club.

        Scenarios are read from ``scenario.json`` in the working directory.
        The scenario whose title equals *title* is published, or all of them
        when *title* is ``"*"``.
        """
        # variables
        variables = ("?savedraft=true", "?confirm=false#")
        with open("scenario.json", encoding="utf-8") as file:
            infile = json.load(file)
        for scenario in infile["scenarios"]:
            if title not in (scenario["title"], "*"):
                continue
            # prepare the request
            # prepare tags
            tags = self.reformat_tags(scenario["tags"])
            try:
                quests = "\n".join(scenario["quests"]["quest"])
            except KeyError:
                quests = ""
            params = {
                "Honey": "",
                "Command.ParentId": "",
                "Command.Title": scenario["title"],
                "Command.Description": scenario["description"],
                "Command.promptsContent": scenario["prompts"],
                "Command.promptsTags": tags["tags"],
                "Command.Memory": scenario["memory"],
                "Command.Quests": quests,
                "Command.AuthorsNote": scenario["authorsNote"],
                "Command.Nsfw": tags["nsfw"],
                "ScriptZip": "",  # file
                "WorldInfoFile": "",  # file
            }
            # prepare WI. A missing "worldInfo", or the first entry lacking
            # "keys"/"entry", ends WI collection (best effort).
            try:
                for counter, wi_entry in enumerate(scenario["worldInfo"]):
                    params[f"Command.WorldInfos[{counter}].Keys"] = wi_entry["keys"]
                    params[f"Command.WorldInfos[{counter}].Entry"] = wi_entry[
                        "entry"
                    ]
            except KeyError:
                pass
            # NOTE(review): "?confirm=false#" is a relative URL, and requests
            # raises MissingSchema for it. The real endpoint path, and the
            # __RequestVerificationToken that _post adds, appear to be
            # missing. Confirm against the club's current form.
            res = self.session.post(variables[1], params)
            print(f'Your prompts number is {res.url.split("/")[-1]}')
            # I don't want to overload his servers...
            time.sleep(1)
class HoloClient(BaseClient):
    """
    Client for the writeholo.com completion API.
    """

    def __init__(self):
        super().__init__()
        self.base_url = "https://writeholo.com/"
        self.url = self.base_url + "api/"
        # Get all settings. The payload is copied because generate_output
        # mutates it, and the schemes template must stay pristine for other
        # instances.
        self.generate_holo = copy.deepcopy(schemes.generate_holo)
        # Story id reused across generate_output calls; created lazily.
        self.curr_story_id = ""

    def login(self, credentials=None):
        """
        Fetch the site's landing page so the session picks up the cookies
        the API needs. Credential-based login is not implemented.
        """
        # we need to get the cookies to interact with the API
        self.session.get(self.base_url)
        if credentials:
            # TODO
            raise NotImplementedError
        assert self.session.cookies

    def create_scenario(self):
        """Create a new story on the server and return its ``story_id``."""
        res = self.session.post(self.url + "create_story")
        return res.json()["story_id"]

    def generate_output(self, context: Optional[Dict[str, Any]] = None):
        """
        Request completions for the current story.

        :param context: extra payload fields merged into the request; may be
            omitted.
        :returns: the ``outputs`` field of the API response.
        """
        if not self.curr_story_id:
            self.curr_story_id = self.create_scenario()
        self.generate_holo["story_id"] = self.curr_story_id
        # dict.update(None) raises TypeError, so skip the merge when there is
        # no context (the default).
        if context:
            self.generate_holo.update(context)
        payload = json.dumps(self.generate_holo)
        res = self.session.post(self.url + "draw_completions", data=payload)
        return res.json()["outputs"]