Add %%oc --plan-cache support for Neptune DB (#613)
* Add --plan-cache support for NeptuneDB

* update changelog
michaelnchin authored Jun 7, 2024
1 parent f07e887 commit a047733
Showing 3 changed files with 58 additions and 31 deletions.
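
For orientation, the new flag is used from a notebook cell along these lines (a sketch, not part of the commit; the cell body is an arbitrary openCypher query, and the mode value must be one of OPENCYPHER_PLAN_CACHE_MODES):

%%oc --plan-cache enabled
MATCH (n) RETURN n LIMIT 10

On Neptune Analytics the mode is sent as the planCache request parameter; on Neptune DB (engine 1.3.2.0 or later) it is injected into the query text as a USING QUERY: PLANCACHE hint, as the client.py changes below show.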
2 changes: 2 additions & 0 deletions ChangeLog.md
@@ -3,8 +3,10 @@
Starting with v1.31.6, this file will contain a record of major features and updates made in each release of graph-notebook.

## Upcoming

- Added `%reset_graph` line magic ([Link to PR](https://github.com/aws/graph-notebook/pull/610))
- Added `%get_graph` line magic and enabled `%status` for Neptune Analytics ([Link to PR](https://github.com/aws/graph-notebook/pull/611))
- Added `%%oc --plan-cache` support for Neptune DB ([Link to PR](https://github.com/aws/graph-notebook/pull/613))
- Upgraded to Gremlin-Python 3.7 ([Link to PR](https://github.com/aws/graph-notebook/pull/597))

## Release 4.3.1 (June 3, 2024)
57 changes: 34 additions & 23 deletions src/graph_notebook/magics/graph_magic.py
@@ -46,7 +46,7 @@
from graph_notebook.magics.streams import StreamViewer
from graph_notebook.neptune.client import ClientBuilder, Client, PARALLELISM_OPTIONS, PARALLELISM_HIGH, \
LOAD_JOB_MODES, MODE_AUTO, FINAL_LOAD_STATUSES, SPARQL_ACTION, FORMAT_CSV, FORMAT_OPENCYPHER, FORMAT_NTRIPLE, \
DB_LOAD_TYPES, ANALYTICS_LOAD_TYPES, VALID_BULK_FORMATS, VALID_INCREMENTAL_FORMATS, \
DB_LOAD_TYPES, ANALYTICS_LOAD_TYPES, VALID_BULK_FORMATS, VALID_INCREMENTAL_FORMATS, \
FORMAT_NQUADS, FORMAT_RDFXML, FORMAT_TURTLE, STREAM_RDF, STREAM_PG, STREAM_ENDPOINTS, \
NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, \
STATISTICS_LANGUAGE_INPUTS, STATISTICS_LANGUAGE_INPUTS_SPARQL, STATISTICS_MODES, SUMMARY_MODES, \
@@ -180,7 +180,7 @@
MEDIA_TYPE_NTRIPLES_TEXT, MEDIA_TYPE_TURTLE, MEDIA_TYPE_N3, MEDIA_TYPE_TRIX,
MEDIA_TYPE_TRIG, MEDIA_TYPE_RDF4J_BINARY]

byte_units = {'B': 1, 'KB': 1024, 'MB': 1024**2, 'GB': 1024**3, 'TB': 1024**4}
byte_units = {'B': 1, 'KB': 1024, 'MB': 1024 ** 2, 'GB': 1024 ** 3, 'TB': 1024 ** 4}


class QueryMode(Enum):
@@ -521,11 +521,11 @@ def neptune_config_allowlist(self, line='', cell=''):

@line_magic
@neptune_db_only
def stream_viewer(self,line):
def stream_viewer(self, line):
parser = argparse.ArgumentParser()
parser.add_argument('language', nargs='?', default=STREAM_PG,
help=f'language (default={STREAM_PG}) [{STREAM_PG}|{STREAM_RDF}]',
choices = [STREAM_PG, STREAM_RDF])
choices=[STREAM_PG, STREAM_RDF])

parser.add_argument('--limit', type=int, default=10, help='Maximum number of rows to display at a time')

@@ -534,7 +534,7 @@ def stream_viewer(self,line):
language = args.language
limit = args.limit
uri = self.client.get_uri_with_port()
viewer = StreamViewer(self.client,uri,language,limit=limit)
viewer = StreamViewer(self.client, uri, language, limit=limit)
viewer.show()

@line_magic
@@ -877,7 +877,8 @@ def sparql(self, line='', cell='', local_ns: dict = None):
if query_type == 'CONSTRUCT' or query_type == 'DESCRIBE':
lines = []
for b in results['results']['bindings']:
lines.append(f'{b["subject"]["value"]}\t{b["predicate"]["value"]}\t{b["object"]["value"]}')
lines.append(
f'{b["subject"]["value"]}\t{b["predicate"]["value"]}\t{b["object"]["value"]}')
raw_output = widgets.Output(layout=sparql_layout)
with raw_output:
html = sparql_construct_template.render(lines=lines)
@@ -1168,7 +1169,8 @@ def gremlin(self, line, cell, local_ns: dict = None):
query_start = time.time() * 1000 # time.time() returns time in seconds w/high precision; x1000 to get in ms
if self.graph_notebook_config.proxy_host != '' and self.client.is_neptune_domain():
using_http = True
query_res_http = self.client.gremlin_http_query(cell, headers={'Accept': 'application/vnd.gremlin-v1.0+json;types=false'})
query_res_http = self.client.gremlin_http_query(cell, headers={
'Accept': 'application/vnd.gremlin-v1.0+json;types=false'})
query_res_http.raise_for_status()
query_res_http_json = query_res_http.json()
query_res = query_res_http_json['result']['data']
@@ -1603,7 +1605,7 @@ def on_button_delete_clicked(b):
with output:
job_status_output.clear_output()
interval_output.close()
total_status_wait = max_status_retries*poll_interval
total_status_wait = max_status_retries * poll_interval
print(result)
if interval_check_response.get("status") != 'healthy':
print(f"Could not retrieve the status of the reset operation within the allotted time of "
@@ -1849,7 +1851,7 @@ def load(self, line='', local_ns: dict = None):
value=str(args.concurrency),
placeholder=1,
min=1,
max=2**16,
max=2 ** 16,
disabled=False,
layout=widgets.Layout(display=concurrency_hbox_visibility,
width=widget_width)
@@ -2057,8 +2059,8 @@ def on_button_clicked(b):
named_graph_uri_hbox.children = (named_graph_uri_hbox_label, named_graph_uri,)
base_uri_hbox.children = (base_uri_hbox_label, base_uri,)
dep_hbox.children = (dep_hbox_label, dependencies,)
concurrency_hbox.children = (concurrency_hbox_label, concurrency, )
periodic_commit_hbox.children = (periodic_commit_hbox_label, periodic_commit, )
concurrency_hbox.children = (concurrency_hbox_label, concurrency,)
periodic_commit_hbox.children = (periodic_commit_hbox_label, periodic_commit,)

validated = True
validation_label_style = DescriptionStyle(color='red')
@@ -2210,8 +2212,9 @@ def on_button_clicked(b):

if poll_status.value == 'FALSE':
start_msg_label = widgets.Label(f'Load started successfully!')
polling_msg_label = widgets.Label(f'You can run "%load_status {load_result["payload"]["loadId"]}" '
f'in another cell to check the current status of your bulk load.')
polling_msg_label = widgets.Label(
f'You can run "%load_status {load_result["payload"]["loadId"]}" '
f'in another cell to check the current status of your bulk load.')
start_msg_hbox = widgets.HBox([start_msg_label])
polling_msg_hbox = widgets.HBox([polling_msg_label])
vbox = widgets.VBox([start_msg_hbox, polling_msg_hbox])
@@ -2254,11 +2257,13 @@ def on_button_clicked(b):
with job_status_output:
# parse status & execution_time differently for Analytics and NeptuneDB
overall_status = \
interval_check_response["payload"]["status"] if self.client.is_analytics_domain() \
else interval_check_response["payload"]["overallStatus"]["status"]
interval_check_response["payload"][
"status"] if self.client.is_analytics_domain() \
else interval_check_response["payload"]["overallStatus"]["status"]
total_time_spent = \
interval_check_response["payload"]["timeElapsedSeconds"] if self.client.is_analytics_domain() \
else interval_check_response["payload"]["overallStatus"]["totalTimeSpent"]
interval_check_response["payload"][
"timeElapsedSeconds"] if self.client.is_analytics_domain() \
else interval_check_response["payload"]["overallStatus"]["totalTimeSpent"]
print(f'Overall Status: {overall_status}')
if overall_status in FINAL_LOAD_STATUSES:
execution_time = total_time_spent
@@ -3179,7 +3184,7 @@ def handle_opencypher_query(self, line, cell, local_ns):
"""
parser = argparse.ArgumentParser()
parser.add_argument('-pc', '--plan-cache', type=str.lower, default='auto',
help=f'Neptune Analytics only. Specifies the plan cache mode to use. '
help=f'Specifies the plan cache mode to use. '
f'Accepted values: {OPENCYPHER_PLAN_CACHE_MODES}')
parser.add_argument('-qt', '--query-timeout', type=int, default=None,
help=f'Neptune Analytics only. Specifies the maximum query timeout in milliseconds.')
@@ -3286,17 +3291,23 @@ def handle_opencypher_query(self, line, cell, local_ns):
first_tab_html = opencypher_explain_template.render(table=explain,
link=f"data:text/html;base64,{base64_str}")
elif args.mode == 'query':
if not self.client.is_analytics_domain():
if args.plan_cache != 'auto':
print("planCache is not supported for Neptune DB, ignoring.")
if args.query_timeout is not None:
print("queryTimeoutMilliseconds is not supported for Neptune DB, ignoring.")
if not self.client.is_analytics_domain() and args.query_timeout is not None:
print("queryTimeoutMilliseconds is not supported for Neptune DB, ignoring.")

query_start = time.time() * 1000 # time.time() returns time in seconds w/high precision; x1000 to get in ms
oc_http = self.client.opencypher_http(cell, query_params=query_params,
plan_cache=args.plan_cache,
query_timeout=args.query_timeout)
query_time = time.time() * 1000 - query_start
if oc_http.status_code == 400 and not self.client.is_analytics_domain() and args.plan_cache != "auto":
try:
oc_http_ex = json.loads(oc_http.content.decode('utf-8'))
if (oc_http_ex["code"] == "MalformedQueryException"
and oc_http_ex["detailedMessage"].startswith("Invalid input")):
print("Please ensure that you are on NeptuneDB 1.3.2.0 or later when attempting to use "
"--plan-cache.")
except:
pass
oc_http.raise_for_status()

try:
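
Read in isolation, the 400 handling added above amounts to the following sketch (hypothetical helper name; the error body shape, a JSON object with code and detailedMessage fields, matches what the commit's code expects):

import json

def warn_if_plan_cache_unsupported(response_content: bytes):
    # Standalone restatement of the check added in handle_opencypher_query.
    # Neptune DB engines older than 1.3.2.0 do not recognize the injected
    # 'USING QUERY: PLANCACHE' prefix and reject the query as malformed.
    try:
        error = json.loads(response_content.decode('utf-8'))
        if (error["code"] == "MalformedQueryException"
                and error["detailedMessage"].startswith("Invalid input")):
            print("Please ensure that you are on NeptuneDB 1.3.2.0 or later "
                  "when attempting to use --plan-cache.")
    except (ValueError, KeyError):
        pass  # some other 400; raise_for_status() will surface it
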
30 changes: 22 additions & 8 deletions src/graph_notebook/neptune/client.py
@@ -163,6 +163,16 @@ def normalize_service_name(neptune_service: str):
return NEPTUNE_DB_SERVICE_NAME


def set_plan_cache_hint(query: str, plan_cache_value: str):
plan_cache_op_re = r"(?i)USING\s+QUERY:\s*PLANCACHE"
if re.search(plan_cache_op_re, query) is not None:
print("planCache hint is already present in query. Ignoring parameter value.")
return query
plan_cache_hint = f'USING QUERY: PLANCACHE "{plan_cache_value}"\n'
query_with_hint = plan_cache_hint + query
return query_with_hint


class Client(object):
def __init__(self, host: str, port: int = DEFAULT_PORT,
neptune_service: str = NEPTUNE_DB_SERVICE_NAME,
@@ -407,19 +417,23 @@ def opencypher_http(self, query: str, headers: dict = None, explain: str = None,
if 'content-type' not in headers:
headers['content-type'] = 'application/x-www-form-urlencoded'
url += 'openCypher'
data = {
'query': query
}
data = {}
if plan_cache:
if plan_cache not in OPENCYPHER_PLAN_CACHE_MODES:
print('Invalid --plan-cache mode specified, defaulting to auto.')
else:
if self.is_analytics_domain():
data['planCache'] = plan_cache
elif plan_cache != 'auto':
query = set_plan_cache_hint(query, plan_cache)
data['query'] = query
if explain:
data['explain'] = explain
headers['Accept'] = "text/html"
if query_params:
data['parameters'] = str(query_params).replace("'", '"') # '{"AUS_code":"AUS","WLG_code":"WLG"}'
if self.is_analytics_domain():
if plan_cache:
data['planCache'] = plan_cache
if query_timeout:
data['queryTimeoutMilliseconds'] = str(query_timeout)
if query_timeout and self.is_analytics_domain():
data['queryTimeoutMilliseconds'] = str(query_timeout)
else:
url += 'db/neo4j/tx/commit'
headers['content-type'] = 'application/json'
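
To illustrate the hint injection end to end (a sketch, assuming set_plan_cache_hint is importable from graph_notebook.neptune.client, where the commit adds it):

from graph_notebook.neptune.client import set_plan_cache_hint

query = set_plan_cache_hint('MATCH (n) RETURN n LIMIT 1', 'disabled')
print(query)
# USING QUERY: PLANCACHE "disabled"
# MATCH (n) RETURN n LIMIT 1

# A query that already carries a PLANCACHE hint is returned unchanged,
# and the helper prints a notice that the parameter value was ignored:
hinted = 'USING QUERY: PLANCACHE "enabled"\nMATCH (n) RETURN n'
assert set_plan_cache_hint(hinted, 'disabled') == hinted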
