diff --git a/nesis/api/core/services/management.py b/nesis/api/core/services/management.py index 50e6741..ba18dec 100644 --- a/nesis/api/core/services/management.py +++ b/nesis/api/core/services/management.py @@ -1,8 +1,10 @@ import json +import os from typing import Optional import bcrypt import memcache +from sqlalchemy.orm import Session from strgen import StringGenerator as SG import logging @@ -64,54 +66,96 @@ def create(self, **kwargs) -> UserSession: self.__LOG.debug(f"Received session object {kwargs}") email = user_session.get("email") password = user_session.get("password") - - if not all([email, password]): - # invalid auth case - # fail-safe. should never reach here. - raise UnauthorizedAccess("Missing email and password") + session_oauth_token_value = user_session.get( + os.environ.get("NESIS_OAUTH_TOKEN_KEY") + ) + oauth_token_value = os.environ.get("NESIS_OAUTH_TOKEN_VALUE") try: - users = session.query(User).filter_by(email=email).all() - if len(users) != 1: - raise UnauthorizedAccess("User not found") - user_dict = users[0].to_dict() - attributes = user_dict["attributes"] - - # password based auth - db_pass = users[0].password - user_password = password.encode("utf-8") - if not bcrypt.checkpw(user_password, db_pass): - raise UnauthorizedAccess("Invalid email/password") - - # update last login details. - db_user = users[0] - # session.add(db_user) - # session.commit() - - token = SG("[\l\d]{128}").render() - session_token = self.__cache_key(token) - expiry = ( - self.__config["memcache"].get("session", {"expiry": 0}).get("expiry", 0) - ) - if self.__cache.get(session_token) is None: - self.__cache.set(session_token, user_dict, time=expiry) - - while self.__cache.get(session_token)["id"] != user_dict["id"]: - token = SG("[\l\d]{128}").render() - session_token = self.__cache_key(token) - self.__cache.set(session_token, user_dict, time=expiry) + if all([email, password]): + users = session.query(User).filter_by(email=email).all() + if len(users) != 1: + raise UnauthorizedAccess("User not found") + user_dict = users[0].to_dict() + attributes = user_dict["attributes"] + + # password based auth + db_pass = users[0].password + user_password = password.encode("utf-8") + if not bcrypt.checkpw(user_password, db_pass): + raise UnauthorizedAccess("Invalid email/password") + # update last login details. + db_user = users[0] + + return self.__create_user_session(db_user) + elif all([email, session_oauth_token_value, oauth_token_value]): + if session_oauth_token_value != oauth_token_value: + raise UnauthorizedAccess("Invalid oauth token value") + secrets = SG(r"[\l\d]{30}").render_list(1, unique=True) - return UserSession(token=token, expiry=expiry, user=db_user) + try: + entity = _create_user( + root=False, + session=session, + user={**user_session, "password": secrets[0]}, + ) + # update last login details. + db_user = entity + except ConflictException: + db_user = session.query(User).filter_by(email=email).first() + return self.__create_user_session(db_user) + else: + raise UnauthorizedAccess("Missing email and password") finally: if session: session.close() + def __create_user_session(self, db_user: User): + user_dict = db_user.to_dict() + token = SG("[\l\d]{128}").render() + session_token = self.__cache_key(token) + expiry = ( + self.__config["memcache"].get("session", {"expiry": 0}).get("expiry", 0) + ) + if self.__cache.get(session_token) is None: + self.__cache.set(session_token, user_dict, time=expiry) + while self.__cache.get(session_token)["id"] != user_dict["id"]: + token = SG("[\l\d]{128}").render() + session_token = self.__cache_key(token) + self.__cache.set(session_token, user_dict, time=expiry) + return UserSession(token=token, expiry=expiry, user=db_user) + def update(self, **kwargs): raise NotImplementedError("Invalid operation on datasource") +def _create_user(session: Session, user: dict, root: bool): + name = user.get("name") + email = user.get("email") + password = user.get("password") + if not all([email, password, name]): + raise ServiceException("name, email and password must be supplied") + password = bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()) + entity = User(name=name, email=email, password=password, root=root) + session.add(entity) + + try: + session.commit() + session.refresh(entity) + except Exception as exc: + session.rollback() + error_str = str(exc).lower() + + if ("unique constraint" in error_str) and ("uq_user_email" in error_str): + # valid failure + raise ConflictException("User already exists") + else: + raise + return entity + + class UserService(ServiceOperation): """ Manage system users @@ -154,19 +198,7 @@ def create(self, **kwargs): resource_type=self._resource_type, ) - name = user.get("name") - email = user.get("email") - password = user.get("password") - - if not all([email, password, name]): - raise ServiceException("name, email and password must be supplied") - - password = bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()) - - entity = User(name=name, email=email, password=password, root=root) - session.add(entity) - session.commit() - session.refresh(entity) + entity = _create_user(root=root, session=session, user=user) if not root: for user_role in user.get("roles") or []: @@ -175,15 +207,6 @@ def create(self, **kwargs): ) return entity - except Exception as exc: - session.rollback() - error_str = str(exc).lower() - - if ("unique constraint" in error_str) and ("uq_user_email" in error_str): - # valid failure - raise ConflictException("User already exists") - else: - raise finally: if session: session.close() diff --git a/nesis/api/tests/core/controllers/test_management_users.py b/nesis/api/tests/core/controllers/test_management_users.py index 145e738..571ea9c 100644 --- a/nesis/api/tests/core/controllers/test_management_users.py +++ b/nesis/api/tests/core/controllers/test_management_users.py @@ -1,5 +1,6 @@ import json import os +import uuid import pytest from sqlalchemy.orm.session import Session @@ -28,10 +29,9 @@ def client(): def test_create_user(client): # Get the prediction data = { - "name": "s3 documents", + "name": "Full Name", "password": tests.admin_password, "email": tests.admin_email, - "root": True, } response = client.post( @@ -75,6 +75,75 @@ def test_create_user(client): return response.json +def test_create_user_oauth(client): + # Test that if authentication with oauth key/value pair + + oauth_token_key = "____nesis_test_oath_key___" + oauth_token_value = str(uuid.uuid4()) + os.environ["NESIS_OAUTH_TOKEN_KEY"] = oauth_token_key + os.environ["NESIS_OAUTH_TOKEN_VALUE"] = oauth_token_value + + user = test_create_user(client=client) + + data = { + "name": "Full Name", + "email": user["email"], + } + + # No oauth tokens supplied and no password supplied, so we must fail + response = client.post( + "/v1/sessions", headers=tests.get_header(), data=json.dumps(data) + ) + assert 401 == response.status_code + assert response.json.get("token") is None + + # Invalid oauth tokens supplied and no password supplied so we must fail + response = client.post( + "/v1/sessions", + headers=tests.get_header(), + data=json.dumps({**data, oauth_token_key: str(uuid.uuid4())}), + ) + assert 401 == response.status_code + assert response.json.get("token") is None + + response = client.post( + "/v1/sessions", + headers=tests.get_header(), + data=json.dumps({**data, oauth_token_key: oauth_token_value}), + ) + assert 200 == response.status_code + assert response.json.get("token") is not None + + admin_session = response.json + + # Now get the list of users, we should have just the one + response = client.get( + f"/v1/users", headers=tests.get_header(token=admin_session["token"]) + ) + assert 1 == len(response.json["items"]) + + # Authenticate as another user + response = client.post( + "/v1/sessions", + headers=tests.get_header(), + data=json.dumps( + { + **data, + oauth_token_key: oauth_token_value, + "email": "another.user.email@domain.com", + } + ), + ) + assert 200 == response.status_code + assert response.json.get("token") is not None + + # Now get the list of users, we should have two users + response = client.get( + f"/v1/users", headers=tests.get_header(token=admin_session["token"]) + ) + assert 2 == len(response.json["items"]) + + def test_create_users(client): # Get the prediction user = test_create_user(client=client)