Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use boto3 to get credentials. #39

Merged
merged 2 commits into from
Dec 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 27 additions & 42 deletions jupyter_drives/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import entrypoints
from traitlets import Enum, Unicode, default
from traitlets.config import Configurable
import boto3

# Supported third-party services
MANAGERS = {}
Expand Down Expand Up @@ -42,8 +43,9 @@ class DrivesConfig(Configurable):
)

region_name = Unicode(
"eu-north-1",
config = True,
None,
config = True,
allow_none=True,
help = "Region name.",
)

Expand All @@ -52,13 +54,6 @@ class DrivesConfig(Configurable):
help="Base URL of the provider service REST API.",
)

custom_credentials_path = Unicode(
None,
config = True,
allow_none = True,
help="Custom path of file where credentials are located. Extension automatically checks jupyter_notebook_config.py or directly in ~/.aws/credentials for AWS CLI users."
)

@default("api_base_url")
def set_default_api_base_url(self):
# for AWS S3 drives
Expand All @@ -80,25 +75,27 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
self._load_credentials()

def _load_credentials(self):
def _load_credentials(self):
# check if credentials were already set in jupyter_notebook_config.py
if self.access_key_id is not None and self.secret_access_key is not None:
return

# check if user provided custom path for credentials extraction
if self.custom_credentials_path is None and "JP_DRIVES_CUSTOM_CREDENTIALS_PATH" in os.environ:
self.custom_credentials_path = os.environ["JP_DRIVES_CUSTOM_CREDENTIALS_PATH"]
if self.custom_credentials_path is not None:
self.provider, self.access_key_id, self.secret_access_key, self.session_token = self._extract_credentials_from_file(self.custom_credentials_path)
return

# if not, try to load credentials from AWS CLI
aws_credentials_path = "~/.aws/credentials" #add read me about credentials path in windows: https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html
if os.path.exists(aws_credentials_path):
self.access_key_id, self.secret_access_key, self.session_token = self._extract_credentials_from_file(aws_credentials_path)
# automatically extract credentials for S3 drives
try:
s = boto3.Session()
c = s.get_credentials()
if c is not None:
self.access_key_id = c.access_key
self.secret_access_key = c.secret_key
self.region_name = s.region_name
self.session_token = c.token
self.provider = 's3'
return

# as a last resort, use environment variables
except:
# S3 credentials couldn't automatically be extracted through boto
pass

# use environment variables
if "JP_DRIVES_ACCESS_KEY_ID" in os.environ and "JP_DRIVES_SECRET_ACCESS_KEY" in os.environ:
self.access_key_id = os.environ["JP_DRIVES_ACCESS_KEY_ID"]
self.secret_access_key = os.environ["JP_DRIVES_SECRET_ACCESS_KEY"]
Expand All @@ -107,22 +104,10 @@ def _load_credentials(self):
if "JP_DRIVES_PROVIDER" in os.environ:
self.provider = os.environ["JP_DRIVES_PROVIDER"]
return

def _extract_credentials_from_file(self, file_path):
try:
with open(file_path, 'r') as file:
provider, access_key_id, secret_access_key, session_token = None, None, None, None
lines = file.readlines()
for line in lines:
if line.startswith("drives_provider ="):
provider = line.split("=")[1].strip()
elif line.startswith("drives_access_key_id ="):
access_key_id = line.split("=")[1].strip()
elif line.startswith("drives_secret_access_key ="):
secret_access_key = line.split("=")[1].strip()
elif line.startswith("drives_session_token ="):
session_token = line.split("=")[1].strip()
return provider, access_key_id, secret_access_key, session_token
except Exception as e:
print(f"Failed loading credentials from {file_path}: {e}")
return

s = boto3.Session()
c = s.get_credentials()
self.access_key_id = c.access_key
self.secret_access_key = c.secret_key
self.region_name = s.region_name
self.session_token = c.token
Loading