Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# Copyright 2019-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"). You 

4# may not use this file except in compliance with the License. A copy of 

5# the License is located at 

6# 

7# http://aws.amazon.com/apache2.0/ 

8# 

9# or in the "license" file accompanying this file. This file is 

10# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 

11# ANY KIND, either express or implied. See the License for the specific 

12# language governing permissions and limitations under the License. 

13 

14from __future__ import annotations 

15 

16import asyncio 

17import time 

18from functools import singledispatch 

19from logging import Logger, getLogger 

20from typing import Any, Dict, Optional, Union 

21 

22import boto3 

23from braket.annealing.problem import Problem 

24from braket.aws.aws_session import AwsSession 

25from braket.circuits.circuit import Circuit 

26from braket.tasks import AnnealingQuantumTaskResult, GateModelQuantumTaskResult, QuantumTask 

27 

28 

class AwsQuantumTask(QuantumTask):
    """Amazon Braket implementation of a quantum task. A task can be a circuit or an annealing
    problem.

    An instance tracks a single task, identified by its ARN. On construction it starts an
    asyncio task that polls Amazon Braket until the quantum task reaches a terminal state or
    the polling timeout elapses.
    """

    # TODO: Add API documentation that defines these states. Make it clear this is the contract.
    # Terminal states for which no result object can be fetched.
    NO_RESULT_TERMINAL_STATES = {"FAILED", "CANCELLED"}
    # States in which the result object is read from S3.
    RESULTS_READY_STATES = {"COMPLETED"}

    # Values sent as `irType` when creating a task and read back from the task metadata.
    GATE_IR_TYPE = "jaqcd"
    ANNEALING_IR_TYPE = "annealing"
    DEFAULT_SHOTS = 1_000

    # Polling defaults, both in seconds.
    DEFAULT_RESULTS_POLL_TIMEOUT = 120
    DEFAULT_RESULTS_POLL_INTERVAL = 0.25

    @staticmethod
    def create(
        aws_session: AwsSession,
        device_arn: str,
        task_specification: Union[Circuit, Problem],
        s3_destination_folder: AwsSession.S3DestinationFolder,
        shots: Optional[int] = None,
        backend_parameters: Optional[Dict[str, Any]] = None,
        *args,
        **kwargs,
    ) -> AwsQuantumTask:
        """AwsQuantumTask factory method that serializes a quantum task specification
        (either a quantum circuit or annealing problem), submits it to Amazon Braket,
        and returns back an AwsQuantumTask tracking the execution.

        Args:
            aws_session (AwsSession): AwsSession to connect to AWS with.

            device_arn (str): The ARN of the quantum device.

            task_specification (Union[Circuit, Problem]): The specification of the task
                to run on device.

            s3_destination_folder (AwsSession.S3DestinationFolder): NamedTuple, with bucket
                for index 0 and key for index 1, that specifies the Amazon S3 bucket and folder
                to store task results in.

            shots (int, optional): The number of times to run the task on the device. If the
                device is a simulator, this implies the state is sampled N times, where
                N = `shots`. When `None`, `AwsQuantumTask.DEFAULT_SHOTS` (1_000) is used.

            backend_parameters (Dict[str, Any], optional): Additional parameters to send to the
                device. For example, for D-Wave:
                `{"dWaveParameters": {"postprocessingType": "OPTIMIZATION"}}`
                When `None` (or empty), an empty dict is forwarded.

            *args: Extra positional arguments forwarded to the `AwsQuantumTask` constructor.

            **kwargs: Extra keyword arguments forwarded to the `AwsQuantumTask` constructor.

        Returns:
            AwsQuantumTask: AwsQuantumTask tracking the task execution on the device.

        Raises:
            ValueError: If `s3_destination_folder` does not have exactly two elements.

        Note:
            The following arguments are typically defined via clients of Device.
                - `task_specification`
                - `s3_destination_folder`
                - `shots`
        """
        if len(s3_destination_folder) != 2:
            raise ValueError(
                "s3_destination_folder must be of size 2 with a 'bucket' and 'key' respectively."
            )

        create_task_kwargs = _create_common_params(
            device_arn,
            s3_destination_folder,
            # Compare against None explicitly so that an explicit `shots=0` is passed through.
            shots if shots is not None else AwsQuantumTask.DEFAULT_SHOTS,
        )
        return _create_internal(
            task_specification,
            aws_session,
            create_task_kwargs,
            backend_parameters or {},
            *args,
            **kwargs,
        )

    def __init__(
        self,
        arn: str,
        aws_session: Optional[AwsSession] = None,
        poll_timeout_seconds: int = DEFAULT_RESULTS_POLL_TIMEOUT,
        poll_interval_seconds: float = DEFAULT_RESULTS_POLL_INTERVAL,
        logger: Logger = getLogger(__name__),
    ):
        """
        Args:
            arn (str): The ARN of the task.
            aws_session (AwsSession, optional): The `AwsSession` for connecting to AWS services.
                Default is `None`, in which case an `AwsSession` object will be created with the
                region of the task.
            poll_timeout_seconds (int): The polling timeout for result(), default is 120 seconds.
            poll_interval_seconds (float): The polling interval for result(), default is 0.25
                seconds.
            logger (Logger): Logger object with which to write logs, such as task statuses
                while waiting for task to be in a terminal state. Default is `getLogger(__name__)`

        Examples:
            >>> task = AwsQuantumTask(arn='task_arn')
            >>> task.state()
            'COMPLETED'
            >>> task.result()
            AnnealingQuantumTaskResult(...)

            >>> task = AwsQuantumTask(arn='task_arn', poll_timeout_seconds=300)
            >>> task.result()
            GateModelQuantumTaskResult(...)
        """

        self._arn: str = arn
        self._aws_session: AwsSession = aws_session or AwsQuantumTask._aws_session_for_task_arn(
            task_arn=arn
        )
        self._poll_timeout_seconds = poll_timeout_seconds
        self._poll_interval_seconds = poll_interval_seconds
        self._logger = logger

        # Cache of the last GetQuantumTask response and of the last fetched result.
        self._metadata: Dict[str, Any] = {}
        self._result: Optional[Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]] = None
        try:
            asyncio.get_event_loop()
        except Exception as e:
            # No usable loop is set for this thread: install a fresh one so polling can start.
            self._logger.debug(e)
            self._logger.info("No event loop found; creating new event loop")
            asyncio.set_event_loop(asyncio.new_event_loop())
        # Running `_create_future` on the loop schedules the polling coroutine as a Task and
        # returns that Task without waiting for the polling itself to finish.
        self._future = asyncio.get_event_loop().run_until_complete(self._create_future())

    @staticmethod
    def _aws_session_for_task_arn(task_arn: str) -> AwsSession:
        """
        Get an AwsSession for the Task ARN. The AWS session should be in the region of the task.

        Args:
            task_arn (str): The ARN of the task. The region is read from the fourth
                colon-separated field of this string.

        Returns:
            AwsSession: `AwsSession` object with default `boto_session` in task's region
        """
        task_region = task_arn.split(":")[3]
        boto_session = boto3.Session(region_name=task_region)
        return AwsSession(boto_session=boto_session)

    @property
    def id(self) -> str:
        """str: The ARN of the quantum task."""
        return self._arn

    def cancel(self) -> None:
        """Cancel the quantum task. This cancels the future and the task in Amazon Braket."""
        self._future.cancel()
        self._aws_session.cancel_quantum_task(self._arn)

    def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]:
        """
        Get task metadata defined in Amazon Braket.

        Args:
            use_cached_value (bool, optional): If `True`, uses the value most recently retrieved
                from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the
                `GetQuantumTask` operation to retrieve metadata, which also updates the cached
                value. Default = `False`.
        Returns:
            Dict[str, Any]: The response from the Amazon Braket `GetQuantumTask` operation.
            If `use_cached_value` is `True`, Amazon Braket is not called and the most recently
            retrieved value is used. The cache starts out as an empty dict.
        """
        if not use_cached_value:
            self._metadata = self._aws_session.get_quantum_task(self._arn)
        return self._metadata

    def state(self, use_cached_value: bool = False) -> str:
        """
        The state of the quantum task.

        Args:
            use_cached_value (bool, optional): If `True`, uses the value most recently retrieved
                from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the
                `GetQuantumTask` operation to retrieve metadata, which also updates the cached
                value. Default = `False`.
        Returns:
            str: The value of `status` in `metadata()`. This is the value of the `status` key
            in the Amazon Braket `GetQuantumTask` operation. If `use_cached_value` is `True`,
            the value most recently returned from the `GetQuantumTask` operation is used.
            `None` is returned when the metadata has no `status` key.
        See Also:
            `metadata()`
        """
        return self.metadata(use_cached_value).get("status")

    def result(self) -> Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]:
        """
        Get the quantum task result by polling Amazon Braket to see if the task is completed.
        Once the task is completed, the result is retrieved from S3 and returned as a
        `GateModelQuantumTaskResult` or `AnnealingQuantumTaskResult`

        This method is a blocking thread call and synchronously returns a result. Call
        async_result() if you require an asynchronous invocation.
        Consecutive calls to this method return a cached result from the preceding request.

        Returns:
            Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]: The task result, or
            `None` when polling ended without a result (timeout, or a state in
            `NO_RESULT_TERMINAL_STATES`). If the polling future was cancelled, the cached
            result, which may be `None`, is returned.
        """
        try:
            return asyncio.get_event_loop().run_until_complete(self.async_result())
        except asyncio.CancelledError:
            # Future was cancelled, return whatever is in self._result if anything
            self._logger.warning("Task future was cancelled")
            return self._result

    def async_result(self) -> asyncio.Task:
        """
        Get the quantum task result asynchronously. Consecutive calls to this method return
        the result cached from the most recent request.

        Returns:
            asyncio.Task: The task wrapping the polling coroutine. If the previous polling
            task finished without a result and the quantum task is not in one of
            `NO_RESULT_TERMINAL_STATES`, a new polling task is started and returned.
        """
        if self._future.done() and self._result is None:  # timed out and no result
            task_status = self.metadata()["status"]
            if task_status in self.NO_RESULT_TERMINAL_STATES:
                self._logger.warning(
                    f"Task is in terminal state {task_status} and no result is available"
                )
                # Return done future. Don't restart polling.
                return self._future
            else:
                self._future = asyncio.get_event_loop().run_until_complete(self._create_future())
        return self._future

    async def _create_future(self) -> asyncio.Task:
        """
        Wrap the `_wait_for_completion` coroutine inside a future-like object.
        Invoking this method starts the coroutine and returns back the future-like object
        that contains it. Note that this does not block on the coroutine to finish.

        Returns:
            asyncio.Task: An asyncio Task that contains the _wait_for_completion() coroutine.
        """
        return asyncio.create_task(self._wait_for_completion())

    def _get_results_formatter(
        self,
    ) -> Union[GateModelQuantumTaskResult.from_string, AnnealingQuantumTaskResult.from_string]:
        """
        Get results formatter based on irType of self.metadata()

        Note that this calls `metadata()` with its default arguments, so the task metadata is
        fetched again from Amazon Braket.

        Returns:
            Union[GateModelQuantumTaskResult.from_string, AnnealingQuantumTaskResult.from_string]:
            function that deserializes a string into a results structure

        Raises:
            ValueError: If the `irType` in the metadata is neither `ANNEALING_IR_TYPE` nor
                `GATE_IR_TYPE`.
        """
        current_metadata = self.metadata()
        ir_type = current_metadata["irType"]
        if ir_type == AwsQuantumTask.ANNEALING_IR_TYPE:
            return AnnealingQuantumTaskResult.from_string
        elif ir_type == AwsQuantumTask.GATE_IR_TYPE:
            return GateModelQuantumTaskResult.from_string
        else:
            raise ValueError("Unknown IR type")

    async def _wait_for_completion(
        self,
    ) -> Optional[Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]]:
        """
        Waits for the quantum task to be completed, then returns the result from the S3 bucket.

        Returns:
            Optional[Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]]: If the task
            is in the `AwsQuantumTask.RESULTS_READY_STATES` state within the specified time
            limit, the result from the S3 bucket is loaded and returned. `None` is returned if
            a timeout occurs or task state is in `AwsQuantumTask.NO_RESULT_TERMINAL_STATES`.
        Note:
            Timeout and sleep intervals are defined in the constructor fields
            `poll_timeout_seconds` and `poll_interval_seconds` respectively.
        """
        self._logger.debug(f"Task {self._arn}: start polling for completion")
        start_time = time.time()

        while (time.time() - start_time) < self._poll_timeout_seconds:
            current_metadata = self.metadata()
            task_status = current_metadata["status"]
            self._logger.debug(f"Task {self._arn}: task status {task_status}")
            if task_status in AwsQuantumTask.RESULTS_READY_STATES:
                result_string = self._aws_session.retrieve_s3_object_body(
                    current_metadata["resultsS3Bucket"], current_metadata["resultsS3ObjectKey"]
                )
                self._result = self._get_results_formatter()(result_string)
                return self._result
            elif task_status in AwsQuantumTask.NO_RESULT_TERMINAL_STATES:
                # Single f-string: the previous concatenation dropped the space before "and".
                self._logger.warning(
                    f"Task is in terminal state {task_status} and no result is available"
                )
                self._result = None
                return None
            else:
                # Not terminal yet: yield to the event loop until the next poll.
                await asyncio.sleep(self._poll_interval_seconds)

        # Timed out
        self._logger.warning(
            f"Task {self._arn}: polling for task completion timed out after "
            + f"{time.time()-start_time} secs"
        )
        self._result = None
        return None

    def __repr__(self) -> str:
        return f"AwsQuantumTask('id':{self.id})"

    def __eq__(self, other) -> bool:
        # Two tasks are equal when they track the same ARN.
        if isinstance(other, AwsQuantumTask):
            return self.id == other.id
        return NotImplemented

    def __hash__(self) -> int:
        # Kept consistent with __eq__: hash on the ARN only.
        return hash(self.id)

330 

331 

@singledispatch
def _create_internal(
    task_specification: Union[Circuit, Problem],
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """Fallback of the single-dispatch task factory.

    Dispatch is on the type of `task_specification`. This base implementation runs when no
    overload is registered for that type, and always fails.

    Raises:
        TypeError: Always.
    """
    unsupported_message = "Invalid task specification type"
    raise TypeError(unsupported_message)

342 

343 

@_create_internal.register
def _(
    circuit: Circuit,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """`_create_internal` overload for gate-model circuits.

    Adds the serialized circuit, the gate IR type and the circuit's qubit count to
    `create_task_kwargs` (the dict is modified in place), creates the quantum task through
    `aws_session`, and wraps the returned ARN in an `AwsQuantumTask`.

    `backend_parameters` is accepted for signature parity with the other overloads but is not
    used here. `*args` and `**kwargs` are forwarded to the `AwsQuantumTask` constructor.
    """
    create_task_kwargs.update(
        ir=circuit.to_ir().json(),
        irType=AwsQuantumTask.GATE_IR_TYPE,
        backendParameters={"gateModelParameters": {"qubitCount": circuit.qubit_count}},
    )
    created_arn = aws_session.create_quantum_task(**create_task_kwargs)
    return AwsQuantumTask(created_arn, aws_session, *args, **kwargs)

362 

363 

@_create_internal.register
def _(
    problem: Problem,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """`_create_internal` overload for annealing problems.

    Adds the serialized problem, the annealing IR type and the caller-supplied
    `backend_parameters` (nested under `annealingModelParameters`) to `create_task_kwargs`
    (the dict is modified in place), creates the quantum task through `aws_session`, and wraps
    the returned ARN in an `AwsQuantumTask`.

    `*args` and `**kwargs` are forwarded to the `AwsQuantumTask` constructor.
    """
    create_task_kwargs.update(
        ir=problem.to_ir().json(),
        irType=AwsQuantumTask.ANNEALING_IR_TYPE,
        backendParameters={"annealingModelParameters": backend_parameters},
    )
    created_arn = aws_session.create_quantum_task(**create_task_kwargs)
    return AwsQuantumTask(created_arn, aws_session, *args, **kwargs)

383 

384 

def _create_common_params(
    device_arn: str, s3_destination_folder: AwsSession.S3DestinationFolder, shots: int
) -> Dict[str, Any]:
    """Build the task-creation arguments that do not depend on the specification type.

    Args:
        device_arn (str): Stored under the `backendArn` key.
        s3_destination_folder (AwsSession.S3DestinationFolder): Index 0 is stored under
            `resultsS3Bucket` and index 1 under `resultsS3Prefix`.
        shots (int): Stored under the `shots` key.

    Returns:
        Dict[str, Any]: A new dict with the four keys described above.
    """
    results_bucket = s3_destination_folder[0]
    results_prefix = s3_destination_folder[1]
    return dict(
        backendArn=device_arn,
        resultsS3Bucket=results_bucket,
        resultsS3Prefix=results_prefix,
        shots=shots,
    )