Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# Copyright 2019-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"). You 

4# may not use this file except in compliance with the License. A copy of 

5# the License is located at 

6# 

7# http://aws.amazon.com/apache2.0/ 

8# 

9# or in the "license" file accompanying this file. This file is 

10# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 

11# ANY KIND, either express or implied. See the License for the specific 

12# language governing permissions and limitations under the License. 

13 

14from __future__ import annotations 

15 

16import asyncio 

17import time 

18from functools import singledispatch 

19from logging import Logger, getLogger 

20from typing import Any, Dict, Optional, Union 

21 

22import boto3 

23from braket.annealing.problem import Problem 

24from braket.aws.aws_session import AwsSession 

25from braket.circuits.circuit import Circuit 

26from braket.tasks import AnnealingQuantumTaskResult, GateModelQuantumTaskResult, QuantumTask 

27 

28 

class AwsQuantumTask(QuantumTask):
    """Amazon Braket implementation of a quantum task. A task can be a circuit or an annealing
    problem."""

    # TODO: Add API documentation that defines these states. Make it clear this is the contract.
    NO_RESULT_TERMINAL_STATES = {"FAILED", "CANCELLED"}
    RESULTS_READY_STATES = {"COMPLETED"}

    GATE_IR_TYPE = "jaqcd"
    ANNEALING_IR_TYPE = "annealing"

    # Shot count that `create()` falls back to when the caller passes `shots=None`.
    DEFAULT_SHOTS = 1_000

    DEFAULT_RESULTS_POLL_TIMEOUT = 120
    DEFAULT_RESULTS_POLL_INTERVAL = 0.25

    @staticmethod
    def create(
        aws_session: AwsSession,
        device_arn: str,
        task_specification: Union[Circuit, Problem],
        s3_destination_folder: AwsSession.S3DestinationFolder,
        shots: Optional[int] = None,
        backend_parameters: Optional[Dict[str, Any]] = None,
        *args,
        **kwargs,
    ) -> AwsQuantumTask:
        """AwsQuantumTask factory method that serializes a quantum task specification
        (either a quantum circuit or annealing problem), submits it to Amazon Braket,
        and returns back an AwsQuantumTask tracking the execution.

        Args:
            aws_session (AwsSession): AwsSession to connect to AWS with.

            device_arn (str): The ARN of the quantum device.

            task_specification (Union[Circuit, Problem]): The specification of the task
                to run on device.

            s3_destination_folder (AwsSession.S3DestinationFolder): NamedTuple, with bucket
                for index 0 and key for index 1, that specifies the Amazon S3 bucket and folder
                to store task results in.

            shots (int): The number of times to run the task on the device. If the device is a
                simulator, this implies the state is sampled N times, where N = `shots`. Default
                shots = 1_000 (`AwsQuantumTask.DEFAULT_SHOTS`), used when `shots` is `None`.

            backend_parameters (Dict[str, Any]): Additional parameters to send to the device.
                For example, for D-Wave:
                `{"dWaveParameters": {"postprocessingType": "OPTIMIZATION"}}`
                `None` is treated as an empty dictionary.

            *args: Extra positional arguments forwarded to the `AwsQuantumTask` constructor.

            **kwargs: Extra keyword arguments forwarded to the `AwsQuantumTask` constructor.

        Returns:
            AwsQuantumTask: AwsQuantumTask tracking the task execution on the device.

        Raises:
            ValueError: If `s3_destination_folder` does not have exactly two elements.
            TypeError: If `task_specification` is not of a supported type.

        Note:
            The following arguments are typically defined via clients of Device.
                - `task_specification`
                - `s3_destination_folder`
                - `shots`
        """
        if len(s3_destination_folder) != 2:
            raise ValueError(
                "s3_destination_folder must be of size 2 with a 'bucket' and 'key' respectively."
            )

        # Compare against None explicitly so that an explicit `shots=0` is passed through
        # instead of being silently replaced by the default.
        effective_shots = shots if shots is not None else AwsQuantumTask.DEFAULT_SHOTS
        create_task_kwargs = _create_common_params(
            device_arn,
            s3_destination_folder,
            effective_shots,
        )
        return _create_internal(
            task_specification,
            aws_session,
            create_task_kwargs,
            backend_parameters or {},
            *args,
            **kwargs,
        )

    def __init__(
        self,
        arn: str,
        aws_session: Optional[AwsSession] = None,
        poll_timeout_seconds: int = DEFAULT_RESULTS_POLL_TIMEOUT,
        poll_interval_seconds: float = DEFAULT_RESULTS_POLL_INTERVAL,
        logger: Logger = getLogger(__name__),
    ):
        """
        Args:
            arn (str): The ARN of the task.
            aws_session (AwsSession, optional): The `AwsSession` for connecting to AWS services.
                Default is `None`, in which case an `AwsSession` object will be created with the
                region of the task.
            poll_timeout_seconds (int): The polling timeout for result(), default is 120 seconds.
            poll_interval_seconds (float): The polling interval for result(), default is 0.25
                seconds.
            logger (Logger): Logger object with which to write logs, such as task statuses
                while waiting for task to be in a terminal state. Default is `getLogger(__name__)`

        Examples:
            >>> task = AwsQuantumTask(arn='task_arn')
            >>> task.state()
            'COMPLETED'
            >>> result = task.result()
            AnnealingQuantumTaskResult(...)

            >>> task = AwsQuantumTask(arn='task_arn', poll_timeout_seconds=300)
            >>> result = task.result()
            GateModelQuantumTaskResult(...)
        """

        self._arn: str = arn
        self._aws_session: AwsSession = aws_session or AwsQuantumTask._aws_session_for_task_arn(
            task_arn=arn
        )
        self._poll_timeout_seconds = poll_timeout_seconds
        self._poll_interval_seconds = poll_interval_seconds
        self._logger = logger

        self._metadata: Dict[str, Any] = {}
        self._result: Optional[Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]] = None
        # Polling is driven through asyncio, so make sure this thread has an event loop
        # before the polling task is scheduled. The lookup is best-effort: any failure is
        # logged and answered by installing a fresh loop.
        try:
            asyncio.get_event_loop()
        except Exception as e:
            self._logger.debug(e)
            self._logger.info("No event loop found; creating new event loop")
            asyncio.set_event_loop(asyncio.new_event_loop())
        # Schedules `_wait_for_completion` as a Task; this does not wait for polling to finish.
        self._future = asyncio.get_event_loop().run_until_complete(self._create_future())

    @staticmethod
    def _aws_session_for_task_arn(task_arn: str) -> AwsSession:
        """
        Get an AwsSession for the Task ARN. The AWS session should be in the region of the task.

        Args:
            task_arn (str): The ARN of the task. The region is read from the fourth
                colon-separated field of this string.

        Returns:
            AwsSession: `AwsSession` object with default `boto_session` in task's region
        """
        task_region = task_arn.split(":")[3]
        boto_session = boto3.Session(region_name=task_region)
        return AwsSession(boto_session=boto_session)

    @property
    def id(self) -> str:
        """str: The ARN of the quantum task."""
        return self._arn

    def cancel(self) -> None:
        """Cancel the quantum task. This cancels the future and the task in Amazon Braket."""
        self._future.cancel()
        self._aws_session.cancel_quantum_task(self._arn)

    def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]:
        """
        Get task metadata defined in Amazon Braket.

        Args:
            use_cached_value (bool, optional): If `True`, uses the value most recently retrieved
                from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the
                `GetQuantumTask` operation to retrieve metadata, which also updates the cached
                value. Default = `False`.
        Returns:
            Dict[str, Any]: The response from the Amazon Braket `GetQuantumTask` operation.
            If `use_cached_value` is `True`, Amazon Braket is not called and the most recently
            retrieved value is used. The cache starts out as an empty dictionary.
        """
        if not use_cached_value:
            self._metadata = self._aws_session.get_quantum_task(self._arn)
        return self._metadata

    def state(self, use_cached_value: bool = False) -> str:
        """
        The state of the quantum task.

        Args:
            use_cached_value (bool, optional): If `True`, uses the value most recently retrieved
                from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the
                `GetQuantumTask` operation to retrieve metadata, which also updates the cached
                value. Default = `False`.
        Returns:
            str: The value of `status` in `metadata()`. This is the value of the `status` key
            in the Amazon Braket `GetQuantumTask` operation. If `use_cached_value` is `True`,
            the value most recently returned from the `GetQuantumTask` operation is used.
            `None` is returned when the metadata has no `status` key.
        See Also:
            `metadata()`
        """
        return self.metadata(use_cached_value).get("status")

    def result(self) -> Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]:
        """
        Get the quantum task result by polling Amazon Braket to see if the task is completed.
        Once the task is completed, the result is retrieved from S3 and returned as a
        `GateModelQuantumTaskResult` or `AnnealingQuantumTaskResult`

        This method is a blocking thread call and synchronously returns a result. Call
        async_result() if you require an asynchronous invocation.
        Consecutive calls to this method return a cached result from the preceding request.

        Returns:
            Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]: The task result, or
            `None` when polling timed out, the task ended in one of
            `AwsQuantumTask.NO_RESULT_TERMINAL_STATES`, or the future was cancelled before a
            result was stored.
        """
        try:
            return asyncio.get_event_loop().run_until_complete(self.async_result())
        except asyncio.CancelledError:
            # Future was cancelled, return whatever is in self._result if anything
            self._logger.warning("Task future was cancelled")
            return self._result

    def async_result(self) -> asyncio.Task:
        """
        Get the quantum task result asynchronously. Consecutive calls to this method return
        the result cached from the most recent request.

        Returns:
            asyncio.Task: The task wrapping the polling coroutine. If the previous task finished
            without a result and Amazon Braket does not report a no-result terminal state,
            a new polling task is created and returned.
        """
        if self._future.done() and self._result is None:  # timed out and no result
            task_status = self.metadata()["status"]
            if task_status in self.NO_RESULT_TERMINAL_STATES:
                self._logger.warning(
                    f"Task is in terminal state {task_status} and no result is available"
                )
                # Return done future. Don't restart polling.
                return self._future
            else:
                self._future = asyncio.get_event_loop().run_until_complete(self._create_future())
        return self._future

    async def _create_future(self) -> asyncio.Task:
        """
        Wrap the `_wait_for_completion` coroutine inside a future-like object.
        Invoking this method starts the coroutine and returns back the future-like object
        that contains it. Note that this does not block on the coroutine to finish.

        Returns:
            asyncio.Task: An asyncio Task that contains the _wait_for_completion() coroutine.
        """
        return asyncio.create_task(self._wait_for_completion())

    def _get_results_formatter(
        self,
    ) -> Union[GateModelQuantumTaskResult.from_string, AnnealingQuantumTaskResult.from_string]:
        """
        Get results formatter based on irType of self.metadata()

        Returns:
            Union[GateModelQuantumTaskResult.from_string, AnnealingQuantumTaskResult.from_string]:
            function that deserializes a string into a results structure

        Raises:
            ValueError: If the `irType` in the task metadata is neither
                `AwsQuantumTask.ANNEALING_IR_TYPE` nor `AwsQuantumTask.GATE_IR_TYPE`.
        """
        current_metadata = self.metadata()
        ir_type = current_metadata["irType"]
        if ir_type == AwsQuantumTask.ANNEALING_IR_TYPE:
            return AnnealingQuantumTaskResult.from_string
        elif ir_type == AwsQuantumTask.GATE_IR_TYPE:
            return GateModelQuantumTaskResult.from_string
        else:
            raise ValueError("Unknown IR type")

    async def _wait_for_completion(
        self,
    ) -> Optional[Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]]:
        """
        Waits for the quantum task to be completed, then returns the result from the S3 bucket.

        Returns:
            Optional[Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]]: If the task
            is in the `AwsQuantumTask.RESULTS_READY_STATES` state within the specified time
            limit, the result from the S3 bucket is loaded and returned. `None` is returned if
            a timeout occurs or task state is in `AwsQuantumTask.NO_RESULT_TERMINAL_STATES`.
        Note:
            Timeout and sleep intervals are defined in the constructor fields
            `poll_timeout_seconds` and `poll_interval_seconds` respectively.
        """
        self._logger.debug(f"Task {self._arn}: start polling for completion")
        start_time = time.time()

        while (time.time() - start_time) < self._poll_timeout_seconds:
            current_metadata = self.metadata()
            task_status = current_metadata["status"]
            self._logger.debug(f"Task {self._arn}: task status {task_status}")
            if task_status in AwsQuantumTask.RESULTS_READY_STATES:
                result_string = self._aws_session.retrieve_s3_object_body(
                    current_metadata["resultsS3Bucket"], current_metadata["resultsS3ObjectKey"]
                )
                self._result = self._get_results_formatter()(result_string)
                return self._result
            elif task_status in AwsQuantumTask.NO_RESULT_TERMINAL_STATES:
                self._logger.warning(
                    f"Task is in terminal state {task_status} and no result is available"
                )
                self._result = None
                return None
            else:
                # Yield to the event loop between polls instead of blocking the thread.
                await asyncio.sleep(self._poll_interval_seconds)

        # Timed out
        self._logger.warning(
            f"Task {self._arn}: polling for task completion timed out after "
            + f"{time.time()-start_time} secs"
        )
        self._result = None
        return None

    def __repr__(self) -> str:
        return f"AwsQuantumTask('id':{self.id})"

    def __eq__(self, other) -> bool:
        # Two tasks are the same task when they refer to the same ARN.
        if isinstance(other, AwsQuantumTask):
            return self.id == other.id
        return NotImplemented

    def __hash__(self) -> int:
        # Kept consistent with __eq__: hash on the ARN only.
        return hash(self.id)

329 

330 

@singledispatch
def _create_internal(
    task_specification: Union[Circuit, Problem],
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """Fallback of the single-dispatch task factory.

    Dispatch happens on the type of `task_specification`; this base implementation is
    reached only when no overload is registered for that type.

    Raises:
        TypeError: Always, because the task specification type is not supported.
    """
    unsupported_type_message = "Invalid task specification type"
    raise TypeError(unsupported_type_message)

341 

342 

@_create_internal.register
def _(
    circuit: Circuit,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """Create a gate-model task from a `Circuit`.

    Adds the serialized circuit, the gate-model IR type and the qubit count to
    `create_task_kwargs` (the dictionary is updated in place), submits the task through
    `aws_session` and wraps the returned ARN in an `AwsQuantumTask`.

    Note:
        `backend_parameters` is accepted for signature parity with the other overloads
        but is not used by this implementation.
    """
    serialized_ir = circuit.to_ir().json()
    gate_model_parameters = {"qubitCount": circuit.qubit_count}
    create_task_kwargs.update(
        ir=serialized_ir,
        irType=AwsQuantumTask.GATE_IR_TYPE,
        backendParameters={"gateModelParameters": gate_model_parameters},
    )
    submitted_arn = aws_session.create_quantum_task(**create_task_kwargs)
    return AwsQuantumTask(submitted_arn, aws_session, *args, **kwargs)

361 

362 

@_create_internal.register
def _(
    problem: Problem,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    backend_parameters: Dict[str, Any],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    """Create an annealing task from a `Problem`.

    Adds the serialized problem, the annealing IR type and the caller-supplied
    `backend_parameters` to `create_task_kwargs` (the dictionary is updated in place),
    submits the task through `aws_session` and wraps the returned ARN in an
    `AwsQuantumTask`.
    """
    serialized_ir = problem.to_ir().json()
    create_task_kwargs.update(
        ir=serialized_ir,
        irType=AwsQuantumTask.ANNEALING_IR_TYPE,
        backendParameters={"annealingModelParameters": backend_parameters},
    )
    submitted_arn = aws_session.create_quantum_task(**create_task_kwargs)
    return AwsQuantumTask(submitted_arn, aws_session, *args, **kwargs)

382 

383 

def _create_common_params(
    device_arn: str, s3_destination_folder: AwsSession.S3DestinationFolder, shots: int
) -> Dict[str, Any]:
    """Build the task-creation arguments shared by every task specification type.

    Args:
        device_arn (str): Value stored under `backendArn`.
        s3_destination_folder (AwsSession.S3DestinationFolder): Index 0 is stored under
            `resultsS3Bucket` and index 1 under `resultsS3Prefix`.
        shots (int): Value stored under `shots`.

    Returns:
        Dict[str, Any]: A new dictionary holding the four entries described above.
    """
    results_bucket = s3_destination_folder[0]
    results_prefix = s3_destination_folder[1]
    return dict(
        backendArn=device_arn,
        resultsS3Bucket=results_bucket,
        resultsS3Prefix=results_prefix,
        shots=shots,
    )