Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# Copyright 2019-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"). You 

4# may not use this file except in compliance with the License. A copy of 

5# the License is located at 

6# 

7# http://aws.amazon.com/apache2.0/ 

8# 

9# or in the "license" file accompanying this file. This file is 

10# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 

11# ANY KIND, either express or implied. See the License for the specific 

12# language governing permissions and limitations under the License. 

13 

14from __future__ import annotations 

15 

16import asyncio 

17import time 

18from functools import singledispatch 

19from logging import Logger, getLogger 

20from typing import Any, Dict, Union 

21 

22import boto3 

23from braket.annealing.problem import Problem 

24from braket.aws.aws_session import AwsSession 

25from braket.circuits.circuit import Circuit 

26from braket.circuits.circuit_helpers import validate_circuit_and_shots 

27from braket.tasks import AnnealingQuantumTaskResult, GateModelQuantumTaskResult, QuantumTask 

28 

29 

30class AwsQuantumTask(QuantumTask): 

31 """Amazon Braket implementation of a quantum task. A task can be a circuit or an annealing 

32 problem.""" 

33 

34 # TODO: Add API documentation that defines these states. Make it clear this is the contract. 

35 NO_RESULT_TERMINAL_STATES = {"FAILED", "CANCELLED"} 

36 RESULTS_READY_STATES = {"COMPLETED"} 

37 

38 GATE_IR_TYPE = "jaqcd" 

39 ANNEALING_IR_TYPE = "annealing" 

40 

41 DEFAULT_RESULTS_POLL_TIMEOUT = 120 

42 DEFAULT_RESULTS_POLL_INTERVAL = 0.25 

43 

44 @staticmethod 

45 def create( 

46 aws_session: AwsSession, 

47 device_arn: str, 

48 task_specification: Union[Circuit, Problem], 

49 s3_destination_folder: AwsSession.S3DestinationFolder, 

50 shots: int, 

51 backend_parameters: Dict[str, Any] = None, 

52 *args, 

53 **kwargs, 

54 ) -> AwsQuantumTask: 

55 """AwsQuantumTask factory method that serializes a quantum task specification 

56 (either a quantum circuit or annealing problem), submits it to Amazon Braket, 

57 and returns back an AwsQuantumTask tracking the execution. 

58 

59 Args: 

60 aws_session (AwsSession): AwsSession to connect to AWS with. 

61 

62 device_arn (str): The ARN of the quantum device. 

63 

64 task_specification (Union[Circuit, Problem]): The specification of the task 

65 to run on device. 

66 

67 s3_destination_folder (AwsSession.S3DestinationFolder): NamedTuple, with bucket 

68 for index 0 and key for index 1, that specifies the Amazon S3 bucket and folder 

69 to store task results in. 

70 

71 shots (int): The number of times to run the task on the device. If the device is a 

72 simulator, this implies the state is sampled N times, where N = `shots`. 

73 `shots=0` is only available on simulators and means that the simulator 

74 will compute the exact results based on the task specification. 

75 

76 backend_parameters (Dict[str, Any]): Additional parameters to send to the device. 

77 For example, for D-Wave: 

78 `{"dWaveParameters": {"postprocessingType": "OPTIMIZATION"}}` 

79 

80 Returns: 

81 AwsQuantumTask: AwsQuantumTask tracking the task execution on the device. 

82 

83 Note: 

84 The following arguments are typically defined via clients of Device. 

85 - `task_specification` 

86 - `s3_destination_folder` 

87 - `shots` 

88 

89 See Also: 

90 `braket.aws.aws_quantum_simulator.AwsQuantumSimulator.run()` 

91 `braket.aws.aws_qpu.AwsQpu.run()` 

92 """ 

93 if len(s3_destination_folder) != 2: 

94 raise ValueError( 

95 "s3_destination_folder must be of size 2 with a 'bucket' and 'key' respectively." 

96 ) 

97 

98 create_task_kwargs = _create_common_params( 

99 device_arn, 

100 s3_destination_folder, 

101 shots if shots is not None else AwsQuantumTask.DEFAULT_SHOTS, 

102 ) 

103 return _create_internal( 

104 task_specification, 

105 aws_session, 

106 create_task_kwargs, 

107 backend_parameters or {}, 

108 *args, 

109 **kwargs, 

110 ) 

111 

112 def __init__( 

113 self, 

114 arn: str, 

115 aws_session: AwsSession = None, 

116 poll_timeout_seconds: int = DEFAULT_RESULTS_POLL_TIMEOUT, 

117 poll_interval_seconds: int = DEFAULT_RESULTS_POLL_INTERVAL, 

118 logger: Logger = getLogger(__name__), 

119 ): 

120 """ 

121 Args: 

122 arn (str): The ARN of the task. 

123 aws_session (AwsSession, optional): The `AwsSession` for connecting to AWS services. 

124 Default is `None`, in which case an `AwsSession` object will be created with the 

125 region of the task. 

126 poll_timeout_seconds (int): The polling timeout for result(), default is 120 seconds. 

127 poll_interval_seconds (int): The polling interval for result(), default is 0.25 

128 seconds. 

129 logger (Logger): Logger object with which to write logs, such as task statuses 

130 while waiting for task to be in a terminal state. Default is `getLogger(__name__)` 

131 

132 Examples: 

133 >>> task = AwsQuantumTask(arn='task_arn') 

134 >>> task.state() 

135 'COMPLETED' 

136 >>> result = task.result() 

137 AnnealingQuantumTaskResult(...) 

138 

139 >>> task = AwsQuantumTask(arn='task_arn', poll_timeout_seconds=300) 

140 >>> result = task.result() 

141 GateModelQuantumTaskResult(...) 

142 """ 

143 

144 self._arn: str = arn 

145 self._aws_session: AwsSession = aws_session or AwsQuantumTask._aws_session_for_task_arn( 

146 task_arn=arn 

147 ) 

148 self._poll_timeout_seconds = poll_timeout_seconds 

149 self._poll_interval_seconds = poll_interval_seconds 

150 self._logger = logger 

151 

152 self._metadata: Dict[str, Any] = {} 

153 self._result: Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult] = None 

154 try: 

155 asyncio.get_event_loop() 

156 except Exception as e: 

157 self._logger.debug(e) 

158 self._logger.info("No event loop found; creating new event loop") 

159 asyncio.set_event_loop(asyncio.new_event_loop()) 

160 self._future = asyncio.get_event_loop().run_until_complete(self._create_future()) 

161 

162 @staticmethod 

163 def _aws_session_for_task_arn(task_arn: str) -> AwsSession: 

164 """ 

165 Get an AwsSession for the Task ARN. The AWS session should be in the region of the task. 

166 

167 Returns: 

168 AwsSession: `AwsSession` object with default `boto_session` in task's region 

169 """ 

170 task_region = task_arn.split(":")[3] 

171 boto_session = boto3.Session(region_name=task_region) 

172 return AwsSession(boto_session=boto_session) 

173 

174 @property 

175 def id(self) -> str: 

176 """str: The ARN of the quantum task.""" 

177 return self._arn 

178 

179 def cancel(self) -> None: 

180 """Cancel the quantum task. This cancels the future and the task in Amazon Braket.""" 

181 self._future.cancel() 

182 self._aws_session.cancel_quantum_task(self._arn) 

183 

184 def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]: 

185 """ 

186 Get task metadata defined in Amazon Braket. 

187 

188 Args: 

189 use_cached_value (bool, optional): If `True`, uses the value most recently retrieved 

190 from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the 

191 `GetQuantumTask` operation to retrieve metadata, which also updates the cached 

192 value. Default = `False`. 

193 Returns: 

194 Dict[str, Any]: The response from the Amazon Braket `GetQuantumTask` operation. 

195 If `use_cached_value` is `True`, Amazon Braket is not called and the most recently 

196 retrieved value is used. 

197 """ 

198 if not use_cached_value: 

199 self._metadata = self._aws_session.get_quantum_task(self._arn) 

200 return self._metadata 

201 

202 def state(self, use_cached_value: bool = False) -> str: 

203 """ 

204 The state of the quantum task. 

205 

206 Args: 

207 use_cached_value (bool, optional): If `True`, uses the value most recently retrieved 

208 from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the 

209 `GetQuantumTask` operation to retrieve metadata, which also updates the cached 

210 value. Default = `False`. 

211 Returns: 

212 str: The value of `status` in `metadata()`. This is the value of the `status` key 

213 in the Amazon Braket `GetQuantumTask` operation. If `use_cached_value` is `True`, 

214 the value most recently returned from the `GetQuantumTask` operation is used. 

215 See Also: 

216 `metadata()` 

217 """ 

218 return self.metadata(use_cached_value).get("status") 

219 

220 def result(self) -> Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]: 

221 """ 

222 Get the quantum task result by polling Amazon Braket to see if the task is completed. 

223 Once the task is completed, the result is retrieved from S3 and returned as a 

224 `GateModelQuantumTaskResult` or `AnnealingQuantumTaskResult` 

225 

226 This method is a blocking thread call and synchronously returns a result. Call 

227 async_result() if you require an asynchronous invocation. 

228 Consecutive calls to this method return a cached result from the preceding request. 

229 """ 

230 try: 

231 return asyncio.get_event_loop().run_until_complete(self.async_result()) 

232 except asyncio.CancelledError: 

233 # Future was cancelled, return whatever is in self._result if anything 

234 self._logger.warning("Task future was cancelled") 

235 return self._result 

236 

237 def async_result(self) -> asyncio.Task: 

238 """ 

239 Get the quantum task result asynchronously. Consecutive calls to this method return 

240 the result cached from the most recent request. 

241 """ 

242 if self._future.done() and self._result is None: # timed out and no result 

243 task_status = self.metadata()["status"] 

244 if task_status in self.NO_RESULT_TERMINAL_STATES: 

245 self._logger.warning( 

246 f"Task is in terminal state {task_status} and no result is available" 

247 ) 

248 else: 

249 self._future = asyncio.get_event_loop().run_until_complete(self._create_future()) 

250 return self._future 

251 

252 async def _create_future(self) -> asyncio.Task: 

253 """ 

254 Wrap the `_wait_for_completion` coroutine inside a future-like object. 

255 Invoking this method starts the coroutine and returns back the future-like object 

256 that contains it. Note that this does not block on the coroutine to finish. 

257 

258 Returns: 

259 asyncio.Task: An asyncio Task that contains the _wait_for_completion() coroutine. 

260 """ 

261 return asyncio.create_task(self._wait_for_completion()) 

262 

263 def _get_results_formatter( 

264 self, 

265 ) -> Union[GateModelQuantumTaskResult.from_string, AnnealingQuantumTaskResult.from_string]: 

266 """ 

267 Get results formatter based on irType of self.metadata() 

268 

269 Returns: 

270 Union[GateModelQuantumTaskResult.from_string, AnnealingQuantumTaskResult.from_string]: 

271 function that deserializes a string into a results structure 

272 """ 

273 current_metadata = self.metadata() 

274 ir_type = current_metadata["irType"] 

275 if ir_type == AwsQuantumTask.ANNEALING_IR_TYPE: 

276 return AnnealingQuantumTaskResult.from_string 

277 elif ir_type == AwsQuantumTask.GATE_IR_TYPE: 

278 return GateModelQuantumTaskResult.from_string 

279 else: 

280 raise ValueError("Unknown IR type") 

281 

282 async def _wait_for_completion( 

283 self, 

284 ) -> Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]: 

285 """ 

286 Waits for the quantum task to be completed, then returns the result from the S3 bucket. 

287 

288 Returns: 

289 Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]: If the task is in the 

290 `AwsQuantumTask.RESULTS_READY_STATES` state within the specified time limit, 

291 the result from the S3 bucket is loaded and returned. 

292 `None` is returned if a timeout occurs or task state is in 

293 `AwsQuantumTask.NO_RESULT_TERMINAL_STATES`. 

294 Note: 

295 Timeout and sleep intervals are defined in the constructor fields 

296 `poll_timeout_seconds` and `poll_interval_seconds` respectively. 

297 """ 

298 self._logger.debug(f"Task {self._arn}: start polling for completion") 

299 start_time = time.time() 

300 

301 while (time.time() - start_time) < self._poll_timeout_seconds: 

302 current_metadata = self.metadata() 

303 task_status = current_metadata["status"] 

304 self._logger.debug(f"Task {self._arn}: task status {task_status}") 

305 if task_status in AwsQuantumTask.RESULTS_READY_STATES: 

306 result_string = self._aws_session.retrieve_s3_object_body( 

307 current_metadata["resultsS3Bucket"], current_metadata["resultsS3ObjectKey"] 

308 ) 

309 self._result = self._get_results_formatter()(result_string) 

310 return self._result 

311 elif task_status in AwsQuantumTask.NO_RESULT_TERMINAL_STATES: 

312 self._logger.warning( 

313 f"Task is in terminal state {task_status} and no result is available" 

314 ) 

315 self._result = None 

316 return None 

317 else: 

318 await asyncio.sleep(self._poll_interval_seconds) 

319 

320 # Timed out 

321 self._logger.warning( 

322 f"Task {self._arn}: polling for task completion timed out after " 

323 + f"{time.time()-start_time} seconds. Please increase the timeout; " 

324 + "this can be done by creating a new AwsQuantumTask with this task's ARN " 

325 + "and a higher value for the `poll_timeout_seconds` parameter." 

326 ) 

327 self._result = None 

328 return None 

329 

330 def __repr__(self) -> str: 

331 return f"AwsQuantumTask('id':{self.id})" 

332 

333 def __eq__(self, other) -> bool: 

334 if isinstance(other, AwsQuantumTask): 

335 return self.id == other.id 

336 return NotImplemented 

337 

338 def __hash__(self) -> int: 

339 return hash(self.id) 

340 

341 

342@singledispatch 

343def _create_internal( 

344 task_specification: Union[Circuit, Problem], 

345 aws_session: AwsSession, 

346 create_task_kwargs: Dict[str, Any], 

347 backend_parameters: Dict[str, Any], 

348 *args, 

349 **kwargs, 

350) -> AwsQuantumTask: 

351 raise TypeError("Invalid task specification type") 

352 

353 

354@_create_internal.register 

355def _( 

356 circuit: Circuit, 

357 aws_session: AwsSession, 

358 create_task_kwargs: Dict[str, Any], 

359 backend_parameters: Dict[str, Any], 

360 *args, 

361 **kwargs, 

362) -> AwsQuantumTask: 

363 validate_circuit_and_shots(circuit, create_task_kwargs["shots"]) 

364 create_task_kwargs.update( 

365 { 

366 "ir": circuit.to_ir().json(), 

367 "irType": AwsQuantumTask.GATE_IR_TYPE, 

368 "backendParameters": {"gateModelParameters": {"qubitCount": circuit.qubit_count}}, 

369 } 

370 ) 

371 task_arn = aws_session.create_quantum_task(**create_task_kwargs) 

372 return AwsQuantumTask(task_arn, aws_session, *args, **kwargs) 

373 

374 

375@_create_internal.register 

376def _( 

377 problem: Problem, 

378 aws_session: AwsSession, 

379 create_task_kwargs: Dict[str, Any], 

380 backend_parameters: Dict[str, Any], 

381 *args, 

382 **kwargs, 

383) -> AwsQuantumTask: 

384 create_task_kwargs.update( 

385 { 

386 "ir": problem.to_ir().json(), 

387 "irType": AwsQuantumTask.ANNEALING_IR_TYPE, 

388 "backendParameters": {"annealingModelParameters": backend_parameters}, 

389 } 

390 ) 

391 

392 task_arn = aws_session.create_quantum_task(**create_task_kwargs) 

393 return AwsQuantumTask(task_arn, aws_session, *args, **kwargs) 

394 

395 

396def _create_common_params( 

397 device_arn: str, s3_destination_folder: AwsSession.S3DestinationFolder, shots: int 

398) -> Dict[str, Any]: 

399 return { 

400 "backendArn": device_arn, 

401 "resultsS3Bucket": s3_destination_folder[0], 

402 "resultsS3Prefix": s3_destination_folder[1], 

403 "shots": shots, 

404 }