Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom authentication credentials and serializers for Gremlin #356

Merged
merged 5 commits into from
Oct 26, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ Alternatively, the magic extensions can be manually reloaded for a single notebo

### Gremlin Server

In a new cell in the Jupyter notebook, change the configuration using `%%graph_notebook_config` and modify the fields for `host`, `port`, and `ssl`. Optionally, modify `traversal_source` if your graph traversal source name differs from the default value. For a local Gremlin server (HTTP or WebSockets), you can use the following command:
In a new cell in the Jupyter notebook, change the configuration using `%%graph_notebook_config` and modify the fields for `host`, `port`, and `ssl`. Optionally, modify `traversal_source` if your graph traversal source name differs from the default value, `username` and `password` if required by the graph store, or `message_serializer` for a specific data transfer format. For a local Gremlin server (HTTP or WebSockets), you can use the following command:

```
%%graph_notebook_config
Expand All @@ -172,7 +172,10 @@ In a new cell in the Jupyter notebook, change the configuration using `%%graph_n
"port": 8182,
"ssl": false,
"gremlin": {
"traversal_source": "g"
"traversal_source": "g",
"username": "",
"password": "",
"message_serializer": "graphsonv3"
}
}
```
Expand Down
8 changes: 7 additions & 1 deletion additional-databases/gremlin-server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@ Several of the steps below are optional but please read each step carefully and
{
"host": "localhost",
"port": 8182,
"ssl": false
"ssl": false,
"gremlin": {
"traversal_source": "g",
"username": "",
"password": "",
"message_serializer": "graphsonv3"
}
}
```
If the Gremlin Server you wish to connect to is remote, replacing `localhost` with the IP address or DNS of the remote server should work. This assumes you have access to that server from your local machine.
Expand Down
48 changes: 42 additions & 6 deletions src/graph_notebook/configuration/generate_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import os
from enum import Enum

from graph_notebook.neptune.client import SPARQL_ACTION, DEFAULT_PORT, DEFAULT_REGION, \
NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host
from graph_notebook.neptune.client import SPARQL_ACTION, DEFAULT_PORT, DEFAULT_REGION, DEFAULT_GREMLIN_SERIALIZER, \
DEFAULT_GREMLIN_TRAVERSAL_SOURCE, NEPTUNE_CONFIG_HOST_IDENTIFIERS, is_allowed_neptune_host, GRAPHSONV3_VARIANTS, \
GRAPHSONV2_VARIANTS, GRAPHBINARYV1_VARIANTS

DEFAULT_CONFIG_LOCATION = os.path.expanduser('~/graph_notebook_config.json')

Expand Down Expand Up @@ -49,16 +50,38 @@ class GremlinSection(object):
Used for gremlin-specific settings in a notebook's configuration
"""

def __init__(self, traversal_source: str = ''):
def __init__(self, traversal_source: str = '', username: str = '', password: str = '',
message_serializer: str = ''):
"""
:param traversal_source: used to specify the traversal source for a Gremlin traversal, in the case that we are
connected to an endpoint that can access multiple graphs.
:param username: used to specify a username for authenticating to Gremlin Server, if the endpoint supports it.
:param password: used to specify a password for authenticating to Gremlin Server, if the endpoint supports it.
:param message_serializer: used to specify a serializer for encoding the data to and from Gremlin Server.
"""

if traversal_source == '':
traversal_source = 'g'
traversal_source = DEFAULT_GREMLIN_TRAVERSAL_SOURCE

serializer_lower = message_serializer.lower()
if serializer_lower == '':
message_serializer = DEFAULT_GREMLIN_SERIALIZER
elif serializer_lower in GRAPHSONV3_VARIANTS:
message_serializer = 'graphsonv3'
elif serializer_lower in GRAPHSONV2_VARIANTS:
message_serializer = 'graphsonv2'
elif serializer_lower in GRAPHBINARYV1_VARIANTS:
message_serializer = 'graphbinaryv1'
else:
print(f'Invalid Gremlin serializer specified, defaulting to graphsonv3. '
f'Valid serializers: [graphsonv3, graphsonv2, graphbinaryv1].')
message_serializer = DEFAULT_GREMLIN_SERIALIZER


self.traversal_source = traversal_source
self.username = username
self.password = password
self.message_serializer = message_serializer

def to_dict(self):
return self.__dict__
Expand Down Expand Up @@ -141,10 +164,11 @@ def write_to_file(self, file_path=DEFAULT_CONFIG_LOCATION):

def generate_config(host, port, auth_mode: AuthModeEnum = AuthModeEnum.DEFAULT, ssl: bool = True, load_from_s3_arn='',
aws_region: str = DEFAULT_REGION, proxy_host: str = '', proxy_port: int = DEFAULT_PORT,
sparql_section: SparqlSection = SparqlSection(), gremlin_section: GremlinSection = GremlinSection(),
neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS):
use_ssl = False if ssl in [False, 'False', 'false', 'FALSE'] else True
c = Configuration(host, port, auth_mode, load_from_s3_arn, use_ssl, aws_region, proxy_host, proxy_port,
neptune_hosts=neptune_hosts)
sparql_section, gremlin_section, neptune_hosts)
return c


Expand All @@ -171,14 +195,26 @@ def generate_default_config():
parser.add_argument("--aws_region", help="aws region your ml cluster is in.", default=DEFAULT_REGION)
parser.add_argument("--proxy_host", help="the proxy host url to route a connection through", default='')
parser.add_argument("--proxy_port", help="the proxy port to use when creating proxy connection", default=8182)
parser.add_argument("--sparql_path", help="the namespace path to append to the SPARQL endpoint",
default=SPARQL_ACTION)
parser.add_argument("--gremlin_traversal_source", help="the traversal source to use for Gremlin queries",
default=DEFAULT_GREMLIN_TRAVERSAL_SOURCE)
parser.add_argument("--gremlin_username", help="the username to use when creating Gremlin connections", default='')
parser.add_argument("--gremlin_password", help="the password to use when creating Gremlin connections", default='')
parser.add_argument("--gremlin_serializer",
help="the serializer to use as the encoding format when creating Gremlin connections",
default=DEFAULT_GREMLIN_SERIALIZER)
parser.add_argument("--neptune_hosts", help="list of host snippets to use for identifying neptune endpoints",
default=DEFAULT_CONFIG_LOCATION)
args = parser.parse_args()

auth_mode_arg = args.auth_mode if args.auth_mode != '' else AuthModeEnum.DEFAULT.value
config = generate_config(args.host, int(args.port), AuthModeEnum(auth_mode_arg), args.ssl,
args.load_from_s3_arn, args.aws_region, args.proxy_host, int(args.proxy_port),
neptune_hosts=args.neptune_hosts)
SparqlSection(args.sparql_path, ''),
GremlinSection(args.gremlin_traversal_source, args.gremlin_username,
args.gremlin_serializer),
args.neptune_hosts)
config.write_to_file(args.config_destination)

exit(0)
2 changes: 1 addition & 1 deletion src/graph_notebook/configuration/get_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
def get_config_from_dict(data: dict, neptune_hosts: list = NEPTUNE_CONFIG_HOST_IDENTIFIERS) -> Configuration:

sparql_section = SparqlSection(**data['sparql']) if 'sparql' in data else SparqlSection('')
gremlin_section = GremlinSection(**data['gremlin']) if 'gremlin' in data else GremlinSection('')
gremlin_section = GremlinSection(**data['gremlin']) if 'gremlin' in data else GremlinSection()
proxy_host = str(data['proxy_host']) if 'proxy_host' in data else ''
proxy_port = int(data['proxy_port']) if 'proxy_port' in data else 8182

Expand Down
7 changes: 5 additions & 2 deletions src/graph_notebook/magics/graph_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ def _generate_client_from_config(self, config: Configuration):
.with_tls(config.ssl) \
.with_proxy_host(config.proxy_host) \
.with_proxy_port(config.proxy_port) \
.with_sparql_path(config.sparql.path)
.with_sparql_path(config.sparql.path) \
.with_gremlin_serializer(config.gremlin.message_serializer)
if config.auth_mode == AuthModeEnum.IAM:
builder = builder.with_iam(get_session())
if self.neptune_cfg_allowlist != NEPTUNE_CONFIG_HOST_IDENTIFIERS:
Expand All @@ -251,7 +252,9 @@ def _generate_client_from_config(self, config: Configuration):
.with_port(config.port) \
.with_tls(config.ssl) \
.with_sparql_path(config.sparql.path) \
.with_gremlin_traversal_source(config.gremlin.traversal_source)
.with_gremlin_traversal_source(config.gremlin.traversal_source) \
.with_gremlin_login(config.gremlin.username, config.gremlin.password) \
.with_gremlin_serializer(config.gremlin.message_serializer)

self.client = builder.build()

Expand Down
39 changes: 35 additions & 4 deletions src/graph_notebook/neptune/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@
from botocore.session import Session as botocoreSession
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from gremlin_python.driver import client
from gremlin_python.driver import client, serializer
from gremlin_python.driver.protocol import GremlinServerError
from neo4j import GraphDatabase
import nest_asyncio
# from graph_notebook.magics.graph_magic import NEPTUNE_CONFIG_HOST_IDENTIFIERS

# This patch is no longer needed when graph_notebook is using the a Gremlin Python
# client >= 3.5.0 as the HashableDict is now part of that client driver.
# import graph_notebook.neptune.gremlin.graphsonV3d0_MapType_objectify_patch # noqa F401

DEFAULT_GREMLIN_SERIALIZER = 'graphsonv3'
DEFAULT_GREMLIN_TRAVERSAL_SOURCE = 'g'
DEFAULT_SPARQL_CONTENT_TYPE = 'application/x-www-form-urlencoded'
DEFAULT_PORT = 8182
DEFAULT_REGION = 'us-east-1'
Expand Down Expand Up @@ -94,6 +95,9 @@

NEPTUNE_CONFIG_HOST_IDENTIFIERS = ["amazonaws.com"]

GRAPHSONV3_VARIANTS = ['graphsonv3', 'graphsonv3d0', 'graphsonserializersv3d0']
GRAPHSONV2_VARIANTS = ['graphsonv2', 'graphsonv2d0', 'graphsonserializersv2d0']
GRAPHBINARYV1_VARIANTS = ['graphbinaryv1', 'graphbinary', 'graphbinaryserializersv1']

def is_allowed_neptune_host(hostname: str, host_allowlist: list):
for host_snippet in host_allowlist:
Expand All @@ -102,16 +106,32 @@ def is_allowed_neptune_host(hostname: str, host_allowlist: list):
return False


def get_gremlin_serializer(serializer_str: str):
serializer_lower = serializer_str.lower()
if serializer_lower == 'graphbinaryv1':
return serializer.GraphBinarySerializersV1()
elif serializer_lower == 'graphsonv2':
return serializer.GraphSONSerializersV2d0()
else:
return serializer.GraphSONSerializersV3d0()


class Client(object):
def __init__(self, host: str, port: int = DEFAULT_PORT, ssl: bool = True, region: str = DEFAULT_REGION,
sparql_path: str = '/sparql', gremlin_traversal_source: str = 'g', auth=None, session: Session = None,
sparql_path: str = '/sparql', gremlin_traversal_source: str = DEFAULT_GREMLIN_TRAVERSAL_SOURCE,
gremlin_username: str = '', gremlin_password: str = '',
gremlin_serializer: str = DEFAULT_GREMLIN_SERIALIZER,
auth=None, session: Session = None,
proxy_host: str = '', proxy_port: int = DEFAULT_PORT,
neptune_hosts: list = None):
self.target_host = host
self.target_port = port
self.ssl = ssl
self.sparql_path = sparql_path
self.gremlin_traversal_source = gremlin_traversal_source
self.gremlin_username = gremlin_username
self.gremlin_password = gremlin_password
self.gremlin_serializer = get_gremlin_serializer(gremlin_serializer)
self.region = region
self._auth = auth
self._session = session
Expand Down Expand Up @@ -223,7 +243,9 @@ def get_gremlin_connection(self, transport_kwargs) -> client.Client:
ws_url = f'{self.get_uri_with_port(use_websocket=True)}/gremlin'
request = self._prepare_request('GET', ws_url)
traversal_source = 'g' if self.is_neptune_domain() else self.gremlin_traversal_source
return client.Client(ws_url, traversal_source, headers=dict(request.headers), **transport_kwargs)
return client.Client(ws_url, traversal_source, username=self.gremlin_username,
password=self.gremlin_password, message_serializer=self.gremlin_serializer,
headers=dict(request.headers), **transport_kwargs)

def gremlin_query(self, query, transport_args=None, bindings=None):
if transport_args is None:
Expand Down Expand Up @@ -735,6 +757,15 @@ def with_gremlin_traversal_source(self, traversal_source: str):
self.args['gremlin_traversal_source'] = traversal_source
return ClientBuilder(self.args)

def with_gremlin_login(self, username: str, password: str):
self.args['gremlin_username'] = username
self.args['gremlin_password'] = password
return ClientBuilder(self.args)

def with_gremlin_serializer(self, message_serializer: str):
self.args['gremlin_serializer'] = message_serializer
return ClientBuilder(self.args)

def with_tls(self, tls: bool):
self.args['ssl'] = tls
return ClientBuilder(self.args)
Expand Down
7 changes: 5 additions & 2 deletions test/integration/IntegrationTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def setup_client_builder(config: Configuration) -> ClientBuilder:
.with_proxy_host(config.proxy_host) \
.with_proxy_port(config.proxy_port) \
.with_sparql_path(config.sparql.path) \
.with_gremlin_traversal_source(config.gremlin.traversal_source)
.with_gremlin_traversal_source(config.gremlin.traversal_source) \
.with_gremlin_serializer(config.gremlin.message_serializer)
if config.auth_mode == AuthModeEnum.IAM:
builder = builder.with_iam(get_session())
else:
Expand All @@ -34,7 +35,9 @@ def setup_client_builder(config: Configuration) -> ClientBuilder:
.with_proxy_host(config.proxy_host) \
.with_proxy_port(config.proxy_port) \
.with_sparql_path(config.sparql.path) \
.with_gremlin_traversal_source(config.gremlin.traversal_source)
.with_gremlin_traversal_source(config.gremlin.traversal_source) \
.with_gremlin_login(config.gremlin.username, config.gremlin.password) \
.with_gremlin_serializer(config.gremlin.message_serializer)

return builder

Expand Down
5 changes: 5 additions & 0 deletions test/integration/iam/ml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def setup_iam_client(config: Configuration) -> Client:
.with_proxy_port(config.proxy_port) \
.with_sparql_path(config.sparql.path) \
.with_gremlin_traversal_source(config.gremlin.traversal_source) \
.with_gremlin_login(config.gremlin.username, config.gremlin.password) \
.with_gremlin_serializer(config.gremlin.message_serializer) \
.with_iam(get_session()) \
.build()

Expand All @@ -29,5 +31,8 @@ def setup_iam_client(config: Configuration) -> Client:
assert client.proxy_port == config.proxy_port
assert client.sparql_path == config.sparql.path
assert client.gremlin_traversal_source == config.gremlin.traversal_source
assert client.gremlin_username == config.gremlin.username
assert client.gremlin_password == config.gremlin.password
assert client.gremlin_serializer == config.gremlin.message_serializer
assert client.ssl is config.ssl
return client