Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(frontend): add oauth with microsoft #71

Merged
merged 19 commits into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 80 additions & 57 deletions nesis/api/core/services/management.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
import os
from typing import Optional

import bcrypt
import memcache
from sqlalchemy.orm import Session
from strgen import StringGenerator as SG

import logging
Expand Down Expand Up @@ -64,54 +66,96 @@ def create(self, **kwargs) -> UserSession:
self.__LOG.debug(f"Received session object {kwargs}")
email = user_session.get("email")
password = user_session.get("password")

if not all([email, password]):
# invalid auth case
# fail-safe. should never reach here.
raise UnauthorizedAccess("Missing email and password")
session_oauth_token_value = user_session.get(
os.environ.get("NESIS_OAUTH_TOKEN_KEY")
)
oauth_token_value = os.environ.get("NESIS_OAUTH_TOKEN_VALUE")

try:

users = session.query(User).filter_by(email=email).all()
if len(users) != 1:
raise UnauthorizedAccess("User not found")
user_dict = users[0].to_dict()
attributes = user_dict["attributes"]

# password based auth
db_pass = users[0].password
user_password = password.encode("utf-8")
if not bcrypt.checkpw(user_password, db_pass):
raise UnauthorizedAccess("Invalid email/password")

# update last login details.
db_user = users[0]
# session.add(db_user)
# session.commit()

token = SG("[\l\d]{128}").render()
session_token = self.__cache_key(token)
expiry = (
self.__config["memcache"].get("session", {"expiry": 0}).get("expiry", 0)
)
if self.__cache.get(session_token) is None:
self.__cache.set(session_token, user_dict, time=expiry)

while self.__cache.get(session_token)["id"] != user_dict["id"]:
token = SG("[\l\d]{128}").render()
session_token = self.__cache_key(token)
self.__cache.set(session_token, user_dict, time=expiry)
if all([email, password]):
users = session.query(User).filter_by(email=email).all()
if len(users) != 1:
raise UnauthorizedAccess("User not found")
user_dict = users[0].to_dict()
attributes = user_dict["attributes"]

# password based auth
db_pass = users[0].password
user_password = password.encode("utf-8")
if not bcrypt.checkpw(user_password, db_pass):
raise UnauthorizedAccess("Invalid email/password")
# update last login details.
db_user = users[0]

return self.__create_user_session(db_user)
elif all([email, session_oauth_token_value, oauth_token_value]):
if session_oauth_token_value != oauth_token_value:
raise UnauthorizedAccess("Invalid oauth token value")
secrets = SG(r"[\l\d]{30}").render_list(1, unique=True)

return UserSession(token=token, expiry=expiry, user=db_user)
try:
entity = _create_user(
root=False,
session=session,
user={**user_session, "password": secrets[0]},
)
# update last login details.
db_user = entity
except ConflictException:
db_user = session.query(User).filter_by(email=email).first()
return self.__create_user_session(db_user)
else:
raise UnauthorizedAccess("Missing email and password")

finally:
if session:
session.close()

def __create_user_session(self, db_user: User):
user_dict = db_user.to_dict()
token = SG("[\l\d]{128}").render()
session_token = self.__cache_key(token)
expiry = (
self.__config["memcache"].get("session", {"expiry": 0}).get("expiry", 0)
)
if self.__cache.get(session_token) is None:
self.__cache.set(session_token, user_dict, time=expiry)
while self.__cache.get(session_token)["id"] != user_dict["id"]:
token = SG("[\l\d]{128}").render()
session_token = self.__cache_key(token)
self.__cache.set(session_token, user_dict, time=expiry)
return UserSession(token=token, expiry=expiry, user=db_user)

def update(self, **kwargs):
raise NotImplementedError("Invalid operation on datasource")


def _create_user(session: Session, user: dict, root: bool):
name = user.get("name")
email = user.get("email")
password = user.get("password")
if not all([email, password, name]):
raise ServiceException("name, email and password must be supplied")
password = bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt())
entity = User(name=name, email=email, password=password, root=root)
session.add(entity)

try:
session.commit()
session.refresh(entity)
except Exception as exc:
session.rollback()
error_str = str(exc).lower()

if ("unique constraint" in error_str) and ("uq_user_email" in error_str):
# valid failure
raise ConflictException("User already exists")
else:
raise
return entity


class UserService(ServiceOperation):
"""
Manage system users
Expand Down Expand Up @@ -154,19 +198,7 @@ def create(self, **kwargs):
resource_type=self._resource_type,
)

name = user.get("name")
email = user.get("email")
password = user.get("password")

if not all([email, password, name]):
raise ServiceException("name, email and password must be supplied")

password = bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt())

entity = User(name=name, email=email, password=password, root=root)
session.add(entity)
session.commit()
session.refresh(entity)
entity = _create_user(root=root, session=session, user=user)

if not root:
for user_role in user.get("roles") or []:
Expand All @@ -175,15 +207,6 @@ def create(self, **kwargs):
)

return entity
except Exception as exc:
session.rollback()
error_str = str(exc).lower()

if ("unique constraint" in error_str) and ("uq_user_email" in error_str):
# valid failure
raise ConflictException("User already exists")
else:
raise
finally:
if session:
session.close()
Expand Down
73 changes: 71 additions & 2 deletions nesis/api/tests/core/controllers/test_management_users.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
import uuid

import pytest
from sqlalchemy.orm.session import Session
Expand Down Expand Up @@ -28,10 +29,9 @@ def client():
def test_create_user(client):
# Get the prediction
data = {
"name": "s3 documents",
"name": "Full Name",
"password": tests.admin_password,
"email": tests.admin_email,
"root": True,
}

response = client.post(
Expand Down Expand Up @@ -75,6 +75,75 @@ def test_create_user(client):
return response.json


def test_create_user_oauth(client):
# Test that if authentication with oauth key/value pair

oauth_token_key = "____nesis_test_oath_key___"
oauth_token_value = str(uuid.uuid4())
os.environ["NESIS_OAUTH_TOKEN_KEY"] = oauth_token_key
os.environ["NESIS_OAUTH_TOKEN_VALUE"] = oauth_token_value

user = test_create_user(client=client)

data = {
"name": "Full Name",
"email": user["email"],
}

# No oauth tokens supplied and no password supplied, so we must fail
response = client.post(
"/v1/sessions", headers=tests.get_header(), data=json.dumps(data)
)
assert 401 == response.status_code
assert response.json.get("token") is None

# Invalid oauth tokens supplied and no password supplied so we must fail
response = client.post(
"/v1/sessions",
headers=tests.get_header(),
data=json.dumps({**data, oauth_token_key: str(uuid.uuid4())}),
)
assert 401 == response.status_code
assert response.json.get("token") is None

response = client.post(
"/v1/sessions",
headers=tests.get_header(),
data=json.dumps({**data, oauth_token_key: oauth_token_value}),
)
assert 200 == response.status_code
assert response.json.get("token") is not None

admin_session = response.json

# Now get the list of users, we should have just the one
response = client.get(
f"/v1/users", headers=tests.get_header(token=admin_session["token"])
)
assert 1 == len(response.json["items"])

# Authenticate as another user
response = client.post(
"/v1/sessions",
headers=tests.get_header(),
data=json.dumps(
{
**data,
oauth_token_key: oauth_token_value,
"email": "[email protected]",
}
),
)
assert 200 == response.status_code
assert response.json.get("token") is not None

# Now get the list of users, we should have two users
response = client.get(
f"/v1/users", headers=tests.get_header(token=admin_session["token"])
)
assert 2 == len(response.json["items"])


def test_create_users(client):
# Get the prediction
user = test_create_user(client=client)
Expand Down