diff --git a/ChangeLog.md b/ChangeLog.md
index b26fc6b2..781a3d88 100644
--- a/ChangeLog.md
+++ b/ChangeLog.md
@@ -3,8 +3,10 @@
 Starting with v1.31.6, this file will contain a record of major features and updates made in each release of graph-notebook.

 ## Upcoming
+
 - Added `%reset_graph` line magic ([Link to PR](https://github.com/aws/graph-notebook/pull/610))
 - Added `%get_graph` line magic and enabled `%status` for Neptune Analytics ([Link to PR](https://github.com/aws/graph-notebook/pull/611))
+- Added `%%oc --plan-cache` support for Neptune DB ([Link to PR](https://github.com/aws/graph-notebook/pull/613))
 - Upgraded to Gremlin-Python 3.7 ([Link to PR](https://github.com/aws/graph-notebook/pull/597))

 ## Release 4.3.1 (June 3, 2024)
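For context, the new option is invoked through the `%%oc` cell magic in a notebook; a minimal sketch (the query is illustrative, and assumes `disabled` is among the accepted plan cache modes):

    %%oc --plan-cache disabled
    MATCH (n)
    RETURN n
    LIMIT 10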
diff --git a/src/graph_notebook/magics/graph_magic.py b/src/graph_notebook/magics/graph_magic.py
index 30c38e75..036a79ed 100644
--- a/src/graph_notebook/magics/graph_magic.py
+++ b/src/graph_notebook/magics/graph_magic.py
@@ -46,7 +46,7 @@
 from graph_notebook.magics.streams import StreamViewer
 from graph_notebook.neptune.client import ClientBuilder, Client, PARALLELISM_OPTIONS, PARALLELISM_HIGH, \
     LOAD_JOB_MODES, MODE_AUTO, FINAL_LOAD_STATUSES, SPARQL_ACTION, FORMAT_CSV, FORMAT_OPENCYPHER, FORMAT_NTRIPLE, \
-    DB_LOAD_TYPES, ANALYTICS_LOAD_TYPES, VALID_BULK_FORMATS, VALID_INCREMENTAL_FORMATS, \
+    DB_LOAD_TYPES, ANALYTICS_LOAD_TYPES, VALID_BULK_FORMATS, VALID_INCREMENTAL_FORMATS, \
     FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE, STREAM_RDF, STREAM_PG, STREAM_ENDPOINTS, \
     NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, \
     STATISTICS_LANGUAGE_INPUTS, STATISTICS_LANGUAGE_INPUTS_SPARQL, STATISTICS_MODES, SUMMARY_MODES, \
@@ -180,7 +180,7 @@
     MEDIA_TYPE_NTRIPLES_TEXT, MEDIA_TYPE_TURTLE, MEDIA_TYPE_N3, MEDIA_TYPE_TRIX, MEDIA_TYPE_TRIG, MEDIA_TYPE_RDF4J_BINARY]

-byte_units = {'B': 1, 'KB': 1024, 'MB': 1024**2, 'GB': 1024**3, 'TB': 1024**4}
+byte_units = {'B': 1, 'KB': 1024, 'MB': 1024 ** 2, 'GB': 1024 ** 3, 'TB': 1024 ** 4}


 class QueryMode(Enum):
@@ -521,11 +521,11 @@ def neptune_config_allowlist(self, line='', cell=''):

     @line_magic
     @neptune_db_only
-    def stream_viewer(self,line):
+    def stream_viewer(self, line):
         parser = argparse.ArgumentParser()
         parser.add_argument('language', nargs='?', default=STREAM_PG,
                             help=f'language (default={STREAM_PG}) [{STREAM_PG}|{STREAM_RDF}]',
-                            choices = [STREAM_PG, STREAM_RDF])
+                            choices=[STREAM_PG, STREAM_RDF])

         parser.add_argument('--limit', type=int, default=10,
                             help='Maximum number of rows to display at a time')
@@ -534,7 +534,7 @@ def stream_viewer(self,line):
         language = args.language
         limit = args.limit
         uri = self.client.get_uri_with_port()
-        viewer = StreamViewer(self.client,uri,language,limit=limit)
+        viewer = StreamViewer(self.client, uri, language, limit=limit)
         viewer.show()

     @line_magic
@@ -877,7 +877,8 @@ def sparql(self, line='', cell='', local_ns: dict = None):
             if query_type == 'CONSTRUCT' or query_type == 'DESCRIBE':
                 lines = []
                 for b in results['results']['bindings']:
-                    lines.append(f'{b["subject"]["value"]}\t{b["predicate"]["value"]}\t{b["object"]["value"]}')
+                    lines.append(
+                        f'{b["subject"]["value"]}\t{b["predicate"]["value"]}\t{b["object"]["value"]}')
                 raw_output = widgets.Output(layout=sparql_layout)
                 with raw_output:
                     html = sparql_construct_template.render(lines=lines)
@@ -1168,7 +1169,8 @@ def gremlin(self, line, cell, local_ns: dict = None):
            query_start = time.time() * 1000  # time.time() returns time in seconds w/high precision; x1000 to get in ms
            if self.graph_notebook_config.proxy_host != '' and self.client.is_neptune_domain():
                using_http = True
-                query_res_http = self.client.gremlin_http_query(cell, headers={'Accept': 'application/vnd.gremlin-v1.0+json;types=false'})
+                query_res_http = self.client.gremlin_http_query(cell, headers={
+                    'Accept': 'application/vnd.gremlin-v1.0+json;types=false'})
                query_res_http.raise_for_status()
                query_res_http_json = query_res_http.json()
                query_res = query_res_http_json['result']['data']
@@ -1603,7 +1605,7 @@ def on_button_delete_clicked(b):
                    with output:
                        job_status_output.clear_output()
                        interval_output.close()
-                        total_status_wait = max_status_retries*poll_interval
+                        total_status_wait = max_status_retries * poll_interval
                        print(result)
                        if interval_check_response.get("status") != 'healthy':
                            print(f"Could not retrieve the status of the reset operation within the allotted time of "
@@ -1849,7 +1851,7 @@ def load(self, line='', local_ns: dict = None):
            value=str(args.concurrency),
            placeholder=1,
            min=1,
-            max=2**16,
+            max=2 ** 16,
            disabled=False,
            layout=widgets.Layout(display=concurrency_hbox_visibility,
                                  width=widget_width)
@@ -2057,8 +2059,8 @@ def on_button_clicked(b):
            named_graph_uri_hbox.children = (named_graph_uri_hbox_label, named_graph_uri,)
            base_uri_hbox.children = (base_uri_hbox_label, base_uri,)
            dep_hbox.children = (dep_hbox_label, dependencies,)
-            concurrency_hbox.children = (concurrency_hbox_label, concurrency, )
-            periodic_commit_hbox.children = (periodic_commit_hbox_label, periodic_commit, )
+            concurrency_hbox.children = (concurrency_hbox_label, concurrency,)
+            periodic_commit_hbox.children = (periodic_commit_hbox_label, periodic_commit,)

            validated = True
            validation_label_style = DescriptionStyle(color='red')
@@ -2210,8 +2212,9 @@ def on_button_clicked(b):

                if poll_status.value == 'FALSE':
                    start_msg_label = widgets.Label(f'Load started successfully!')
-                    polling_msg_label = widgets.Label(f'You can run "%load_status {load_result["payload"]["loadId"]}" '
-                                                      f'in another cell to check the current status of your bulk load.')
+                    polling_msg_label = widgets.Label(
+                        f'You can run "%load_status {load_result["payload"]["loadId"]}" '
+                        f'in another cell to check the current status of your bulk load.')
                    start_msg_hbox = widgets.HBox([start_msg_label])
                    polling_msg_hbox = widgets.HBox([polling_msg_label])
                    vbox = widgets.VBox([start_msg_hbox, polling_msg_hbox])
@@ -2254,11 +2257,13 @@ def on_button_clicked(b):
                    with job_status_output:
                        # parse status & execution_time differently for Analytics and NeptuneDB
                        overall_status = \
-                            interval_check_response["payload"]["status"] if self.client.is_analytics_domain() \
-                            else interval_check_response["payload"]["overallStatus"]["status"]
+                            interval_check_response["payload"][
+                                "status"] if self.client.is_analytics_domain() \
+                                else interval_check_response["payload"]["overallStatus"]["status"]
                        total_time_spent = \
-                            interval_check_response["payload"]["timeElapsedSeconds"] if self.client.is_analytics_domain() \
-                            else interval_check_response["payload"]["overallStatus"]["totalTimeSpent"]
+                            interval_check_response["payload"][
+                                "timeElapsedSeconds"] if self.client.is_analytics_domain() \
+                                else interval_check_response["payload"]["overallStatus"]["totalTimeSpent"]
                        print(f'Overall Status: {overall_status}')
                        if overall_status in FINAL_LOAD_STATUSES:
                            execution_time = total_time_spent
@@ -3179,7 +3184,7 @@ def handle_opencypher_query(self, line, cell, local_ns):
        """
        parser = argparse.ArgumentParser()
        parser.add_argument('-pc', '--plan-cache', type=str.lower, default='auto',
-                            help=f'Neptune Analytics only. Specifies the plan cache mode to use. '
+                            help=f'Specifies the plan cache mode to use. '
                                 f'Accepted values: {OPENCYPHER_PLAN_CACHE_MODES}')
        parser.add_argument('-qt', '--query-timeout', type=int, default=None,
                            help=f'Neptune Analytics only. Specifies the maximum query timeout in milliseconds.')
@@ -3286,17 +3291,23 @@ def handle_opencypher_query(self, line, cell, local_ns):
                first_tab_html = opencypher_explain_template.render(table=explain,
                                                                    link=f"data:text/html;base64,{base64_str}")
        elif args.mode == 'query':
-            if not self.client.is_analytics_domain():
-                if args.plan_cache != 'auto':
-                    print("planCache is not supported for Neptune DB, ignoring.")
-                if args.query_timeout is not None:
-                    print("queryTimeoutMilliseconds is not supported for Neptune DB, ignoring.")
+            if not self.client.is_analytics_domain() and args.query_timeout is not None:
+                print("queryTimeoutMilliseconds is not supported for Neptune DB, ignoring.")
            query_start = time.time() * 1000  # time.time() returns time in seconds w/high precision; x1000 to get in ms
            oc_http = self.client.opencypher_http(cell, query_params=query_params,
                                                  plan_cache=args.plan_cache,
                                                  query_timeout=args.query_timeout)
            query_time = time.time() * 1000 - query_start
+            if oc_http.status_code == 400 and not self.client.is_analytics_domain() and args.plan_cache != "auto":
+                try:
+                    oc_http_ex = json.loads(oc_http.content.decode('utf-8'))
+                    if (oc_http_ex["code"] == "MalformedQueryException"
+                            and oc_http_ex["detailedMessage"].startswith("Invalid input")):
+                        print("Please ensure that you are on NeptuneDB 1.3.2.0 or later when attempting to use "
+                              "--plan-cache.")
+                except:
+                    pass
            oc_http.raise_for_status()
            try:
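The new 400 handler above keys off the error envelope that Neptune returns; an illustrative decoded response body that would trigger the version hint (the values are assumptions based only on the checks in the code):

    # Hypothetical JSON body from an older Neptune DB engine that
    # does not recognize the PLANCACHE query hint
    oc_http_ex = {
        "code": "MalformedQueryException",
        "detailedMessage": "Invalid input 'USING' ..."
    }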
diff --git a/src/graph_notebook/neptune/client.py b/src/graph_notebook/neptune/client.py
index f8c3c2dd..07219b68 100644
--- a/src/graph_notebook/neptune/client.py
+++ b/src/graph_notebook/neptune/client.py
@@ -163,6 +163,16 @@ def normalize_service_name(neptune_service: str):
        return NEPTUNE_DB_SERVICE_NAME


+def set_plan_cache_hint(query: str, plan_cache_value: str):
+    plan_cache_op_re = r"(?i)USING\s+QUERY:\s*PLANCACHE"
+    if re.search(plan_cache_op_re, query) is not None:
+        print("planCache hint is already present in query. Ignoring parameter value.")
+        return query
+    plan_cache_hint = f'USING QUERY: PLANCACHE "{plan_cache_value}"\n'
+    query_with_hint = plan_cache_hint + query
+    return query_with_hint
+
+
 class Client(object):
     def __init__(self, host: str, port: int = DEFAULT_PORT,
                  neptune_service: str = NEPTUNE_DB_SERVICE_NAME,
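The helper prepends a query-level hint unless one is already present; a quick REPL sketch of the intended transformation:

    >>> set_plan_cache_hint('MATCH (n) RETURN n LIMIT 1', 'disabled')
    'USING QUERY: PLANCACHE "disabled"\nMATCH (n) RETURN n LIMIT 1'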
Ignoring parameter value.") + return query + plan_cache_hint = f'USING QUERY: PLANCACHE "{plan_cache_value}"\n' + query_with_hint = plan_cache_hint + query + return query_with_hint + + class Client(object): def __init__(self, host: str, port: int = DEFAULT_PORT, neptune_service: str = NEPTUNE_DB_SERVICE_NAME, @@ -407,19 +417,23 @@ def opencypher_http(self, query: str, headers: dict = None, explain: str = None, if 'content-type' not in headers: headers['content-type'] = 'application/x-www-form-urlencoded' url += 'openCypher' - data = { - 'query': query - } + data = {} + if plan_cache: + if plan_cache not in OPENCYPHER_PLAN_CACHE_MODES: + print('Invalid --plan-cache mode specified, defaulting to auto.') + else: + if self.is_analytics_domain(): + data['planCache'] = plan_cache + elif plan_cache != 'auto': + query = set_plan_cache_hint(query, plan_cache) + data['query'] = query if explain: data['explain'] = explain headers['Accept'] = "text/html" if query_params: data['parameters'] = str(query_params).replace("'", '"') # '{"AUS_code":"AUS","WLG_code":"WLG"}' - if self.is_analytics_domain(): - if plan_cache: - data['planCache'] = plan_cache - if query_timeout: - data['queryTimeoutMilliseconds'] = str(query_timeout) + if query_timeout and self.is_analytics_domain(): + data['queryTimeoutMilliseconds'] = str(query_timeout) else: url += 'db/neo4j/tx/commit' headers['content-type'] = 'application/json'