-
Notifications
You must be signed in to change notification settings - Fork 1
/
client.py
319 lines (246 loc) · 10.7 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
"""Class structures for communicating with the API server over HTTP
"""
import os
from collections import OrderedDict
import json
import numpy as np
import time
import requests
from exceptions import TCCError, HTTPCommunicationError, ServerError
class Client(object):
    """Main class for communication with the TeraChem Cloud API server."""

    def __init__(self,
                 user=None,
                 api_key=None,
                 url="http://localhost:80",
                 engine='TeraChem',
                 verbose=False):
        """Initialize a Client object and validate credentials with the server.

        Args:
            user (str): TeraChem Cloud user. Falls back to the
                TCCLOUD_USER environment variable when not given.
            api_key (str): TeraChem Cloud API key. Falls back to the
                TCCLOUD_API_KEY environment variable when not given.
            url (str): URL for the TeraChem API server
                (e.g. http://<hostname>:<port>)
            engine (str): Code to be used for ab initio calculation
            verbose (bool): Print extra info about API interactions

        Raises:
            ValueError: If user/api_key are neither passed in nor set
                in the environment.
            HTTPCommunicationError: If the login request cannot be sent.
            ServerError: If the server rejects the login.
        """
        # Try to get authentication from arguments, falling back to the
        # environment. Use .get() so a missing variable yields None instead
        # of raising KeyError, which lets the explicit ValueError messages
        # below fire as intended (bug fix: os.environ[...] made the
        # "is None" checks unreachable).
        if user is not None:
            self.user = str(user)
        else:
            self.user = os.environ.get('TCCLOUD_USER')
        if self.user is None:
            raise ValueError('"user" not specified and environment variable "TCCLOUD_USER" not set')
        if api_key is not None:
            self.api_key = str(api_key)
        else:
            self.api_key = os.environ.get('TCCLOUD_API_KEY')
        if self.api_key is None:
            raise ValueError('"api_key" not specified and environment variable "TCCLOUD_API_KEY" not set')
        # TCC server options
        self.engine = engine.lower()
        self.url = url
        self.submit_endpoint = "/v1/{}/".format(self.engine)
        self.results_endpoint = "/v1/job/"
        self.help_endpoint = "/v1/docs/"
        self.verbose = verbose
        # Validate credentials immediately by logging in to the server
        payload = {
            'api_key': self.api_key,
            'user_id': self.user
        }
        try:
            r = requests.post(self.url + '/login', json=payload)
        except requests.exceptions.RequestException as e:
            raise HTTPCommunicationError('Error while POSTing login', e)
        if r.status_code != requests.codes.ok:
            raise ServerError(r)
        if self.verbose:
            print('LOGIN> http code: {} response: {}'.format(r.status_code, r.text))

    def help(self):
        """Request allowed keywords from the API server and print them.

        Raises:
            HTTPCommunicationError: If the request cannot be sent.
            ServerError: If the server returns a non-OK status.
        """
        # Package data according to API server specifications
        payload = {
            'engine': self.engine,
            'api_key': self.api_key,
            'user_id': self.user
        }
        # Send HTTP request
        try:
            r = requests.get(self.url + self.help_endpoint, json=payload)
        except requests.exceptions.RequestException as e:
            # Message corrected: this is a GET, not a POST
            raise HTTPCommunicationError('Error while GETing for docs', e)
        if r.status_code != requests.codes.ok:
            raise ServerError(r)
        response = json.loads(r.text)
        print('API parameters for {} backend (with allowed types and values):'.format(self.engine))
        print(response['docs'])

    def submit(self, geom, options):
        """Send a job as a POST request to the API server.

        Args:
            geom (np.ndarray or list): Cartesian geometry at which to perform
                the calculation
            options (dict): Job options to pass to TeraChem Cloud server

        Returns:
            str: Job id

        Raises:
            HTTPCommunicationError: If the request cannot be sent.
            ServerError: If the server returns a non-OK status.
            TCCError: If the server response contains no job id.
        """
        # Flatten any arrays for JSON serialization; copy options so the
        # caller's dict is never mutated
        if isinstance(geom, np.ndarray):
            geom = list(geom.flatten())
        job_options = options.copy()
        for key, value in job_options.items():
            if isinstance(value, np.ndarray):
                job_options[key] = list(value.flatten())
        # Package data according to API server specifications
        payload = {
            'api_key': self.api_key,
            'user_id': self.user,
            'geom': geom,
            'config': job_options,
        }
        # Send HTTP request
        try:
            r = requests.post(self.url + self.submit_endpoint, json=payload)
        # Bug fix: requests.exceptions.RequestError does not exist; the
        # correct base class is RequestException (as used for login above)
        except requests.exceptions.RequestException as e:
            raise HTTPCommunicationError('Error while POSTing for job submission', e)
        if r.status_code != requests.codes.ok:
            raise ServerError(r)
        response = json.loads(r.text)
        if self.verbose:
            print("SUBMIT> http code: {} response: {}".format(r.status_code, response))
        try:
            job_id = response['job_id']
        except KeyError:
            raise TCCError("Unexpectedly did not receive job ID: {}".format(response))
        return job_id

    def is_finished(self, results):
        """Helper function to test whether a job is finished.

        Args:
            results (dict): Job results from self.get_results()

        Returns:
            bool: True if job succeeded/failed, False if job is
                running/submitted/pending (or if no status is present)
        """
        # .get() keeps this safe for an empty results dict (e.g. when a
        # poll loop ran zero iterations)
        return results.get('job_status') in ('SUCCESS', 'FAILURE')

    def get_results(self, job_id):
        """Query API for results of calculations.

        Recommended way to check for job completion:
        ::
            results = client.get_results(job_id)
            finished = client.is_finished(results)

        Args:
            job_id (str): Job id to check status of

        Returns:
            dict: Result dictionary from TCC server with job_id added for posterity

        Raises:
            HTTPCommunicationError: If the request cannot be sent.
            ServerError: If the server returns a non-OK status.
        """
        payload = {
            'api_key': self.api_key,
            'user_id': self.user,
            'job_id': job_id
        }
        try:
            r = requests.get(self.url + self.results_endpoint, json=payload)
        # Bug fix: requests.exceptions.RequestError does not exist; the
        # correct base class is RequestException
        except requests.exceptions.RequestException as e:
            raise HTTPCommunicationError('Error while GETing for job results', e)
        if r.status_code != requests.codes.ok:
            raise ServerError(r)
        results = json.loads(r.text)
        results['job_id'] = job_id
        if self.verbose:
            print("GET_RESULTS> job_id: {} current status: {}".format(
                job_id, results['job_status']))
            if self.is_finished(results):
                print(results)
        return results

    def poll_for_results(self, job_id, sleep_seconds=1, max_poll=200):
        """Send an HTTP request every sleep_seconds seconds until a finished
        job is returned or max_poll requests have been sent.

        Recommended way to check for job completion:
        ::
            results = client.poll_for_results(job_id)
            finished = client.is_finished(results)

        Args:
            job_id (str): Job id to poll for
            sleep_seconds (int): Number of seconds to wait between poll loops
            max_poll (int): Number of poll loops

        Returns:
            dict: Results dict as given by self.get_results()
        """
        results = {}
        for i in range(max_poll):
            if self.verbose:
                print('POLL_FOR_RESULTS> poll loop: {}'.format(i))
            results = self.get_results(job_id)
            if self.is_finished(results):
                break
            time.sleep(sleep_seconds)
        if self.verbose and not self.is_finished(results):
            print("!!!WARNING!!! {} did not finish during poll loop".format(job_id))
        return results

    def poll_for_bulk_results(self, job_ids, sleep_seconds=1, max_poll=200):
        """Send an HTTP request every sleep_seconds seconds until all jobs are
        finished or max_poll requests have been sent.

        Recommended way to check for job completion:
        ::
            results_list = client.poll_for_bulk_results(job_ids)
            finished = [client.is_finished(r) for r in results_list]

        Args:
            job_ids (list): Job ids to poll for
            sleep_seconds (int): Number of seconds to wait between poll loops
            max_poll (int): Number of poll loops

        Returns:
            list: List of results dicts as given by self.get_results(),
                in the same order as job_ids
        """
        # OrderedDict preserves the caller's job order in the returned list
        results_dict = OrderedDict()
        for j in job_ids:
            results_dict[j] = {}
        running_jobs = list(results_dict.keys())
        for i in range(max_poll):
            if self.verbose:
                print('POLL_FOR_BULK_RESULTS> poll loop: {}'.format(i))
            # Only re-query jobs that have not yet finished
            for job_id in running_jobs:
                results_dict[job_id] = self.get_results(job_id)
            # Update running jobs
            running_jobs = [k for k, v in list(results_dict.items()) if not self.is_finished(v)]
            if len(running_jobs) == 0:
                break
            time.sleep(sleep_seconds)
        if self.verbose:
            for job_id in running_jobs:
                print("!!!WARNING!!! {} did not finish during poll loop".format(job_id))
        # Pull results out into list
        results_list = [v for v in results_dict.values()]
        return results_list

    def compute(self, geom, options, sleep_seconds=1, max_poll=200):
        """Convenience routine for synchronous use: submit then poll.

        Check self.poll_for_results() for the recommended way to confirm
        job completion.

        Args:
            geom ((num_atom, 3) ndarray): Geometry to consider
            options (dict): Job options to pass to TeraChem Cloud server
            sleep_seconds (int): Number of seconds to wait between poll loops
                for self.poll_for_results()
            max_poll (int): Number of poll loops for self.poll_for_results()

        Returns:
            dict: Job results from TCC server
        """
        job_id = self.submit(geom, options)
        results = self.poll_for_results(job_id, sleep_seconds, max_poll)
        return results

    def compute_bulk(self, geoms, options, sleep_seconds=1, max_poll=200):
        """Convenience routine for multiple geometries: submit all then poll.

        Check self.poll_for_bulk_results() for the recommended way to confirm
        job completion.

        Args:
            geoms (list of (num_atom, 3) ndarray): Geometries to consider
            options (dict): Job options to pass to TeraChem Cloud server
            sleep_seconds (int): Number of seconds to wait between poll loops
                for self.poll_for_bulk_results()
            max_poll (int): Number of poll loops for self.poll_for_bulk_results()

        Returns:
            list: List of job results from TCC server
        """
        job_ids = [self.submit(g, options) for g in geoms]
        results_list = self.poll_for_bulk_results(job_ids, sleep_seconds, max_poll)
        return results_list