From 304b9c03fcafaa093e6fd83eb61d753e3cd0a562 Mon Sep 17 00:00:00 2001 From: Lars Holm Nielsen Date: Thu, 14 Dec 2023 02:16:10 +0100 Subject: [PATCH] global: new domain list feature * New domain list feature that can be used to block email domains from registering, and automatically verifying users from other domains. --- .../alembic/6ec5ce377ca3_create_domains.py | 95 +++++++++++ invenio_accounts/datastore.py | 12 +- invenio_accounts/domains.py | 32 ++++ invenio_accounts/ext.py | 5 +- invenio_accounts/forms.py | 13 +- invenio_accounts/models.py | 157 +++++++++++++++++- invenio_accounts/utils.py | 53 ++++++ invenio_accounts/views/rest.py | 12 +- tests/test_models.py | 74 ++++++++- tests/test_views.py | 24 ++- tests/test_views_rest.py | 20 ++- 11 files changed, 489 insertions(+), 8 deletions(-) create mode 100644 invenio_accounts/alembic/6ec5ce377ca3_create_domains.py create mode 100644 invenio_accounts/domains.py diff --git a/invenio_accounts/alembic/6ec5ce377ca3_create_domains.py b/invenio_accounts/alembic/6ec5ce377ca3_create_domains.py new file mode 100644 index 00000000..d0a74fe2 --- /dev/null +++ b/invenio_accounts/alembic/6ec5ce377ca3_create_domains.py @@ -0,0 +1,95 @@ +# +# This file is part of Invenio. +# Copyright (C) 2016-2018 CERN. +# +# Invenio is free software; you can redistribute it and/or modify it +# under the terms of the MIT License; see LICENSE file for more details. + +"""Create tables for domain list feature.""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql +from sqlalchemy_utils import JSONType + +# revision identifiers, used by Alembic. +revision = "6ec5ce377ca3" +down_revision = "037afe10e9ff" +branch_labels = () +depends_on = None + + +def upgrade(): + """Upgrade database.""" + op.create_table( + "accounts_domain_category", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("label", sa.String(length=255), nullable=True), + sa.PrimaryKeyConstraint("id", name=op.f("pk_accounts_domain_category")), + ) + op.create_table( + "accounts_domain_org", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("pid", sa.String(length=255), nullable=True), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column( + "json", + sa.JSON() + .with_variant(JSONType(), "mysql") + .with_variant( + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), "postgresql" + ) + .with_variant(JSONType(), "sqlite"), + nullable=False, + ), + sa.Column("parent_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["parent_id"], + ["accounts_domain_org.id"], + name=op.f("fk_accounts_domain_org_parent_id_accounts_domain_org"), + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_accounts_domain_org")), + sa.UniqueConstraint("pid", name=op.f("uq_accounts_domain_org_pid")), + ) + op.create_table( + "accounts_domains", + sa.Column("created", sa.DateTime(), nullable=False), + sa.Column("updated", sa.DateTime(), nullable=False), + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("domain", sa.String(length=255), nullable=False), + sa.Column("tld", sa.String(length=255), nullable=False), + sa.Column("status", sa.Integer(), nullable=False), + sa.Column("flagged", sa.Boolean(), nullable=False), + sa.Column("flagged_source", sa.String(length=255), nullable=False), + sa.Column("org_id", sa.Integer(), nullable=True), + sa.Column("category", sa.Integer(), nullable=True), + sa.Column("num_users", sa.Integer(), nullable=False), + sa.Column("num_active", sa.Integer(), nullable=False), + sa.Column("num_inactive", sa.Integer(), nullable=False), + sa.Column("num_confirmed", sa.Integer(), nullable=False), + sa.Column("num_verified", sa.Integer(), nullable=False), + sa.Column("num_blocked", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["category"], + ["accounts_domain_category.id"], + name=op.f("fk_accounts_domains_category_accounts_domain_category"), + ), + sa.ForeignKeyConstraint( + ["org_id"], + ["accounts_domain_org.id"], + name=op.f("fk_accounts_domains_org_id_accounts_domain_org"), + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_accounts_domains")), + sa.UniqueConstraint("domain", name=op.f("uq_accounts_domains_domain")), + ) + op.add_column( + "accounts_user", sa.Column("domain", sa.String(length=255), nullable=True) + ) + + +def downgrade(): + """Downgrade database.""" + op.drop_column("accounts_user", "domain") + op.drop_table("accounts_domains") + op.drop_table("accounts_domain_org") + op.drop_table("accounts_domain_category") diff --git a/invenio_accounts/datastore.py b/invenio_accounts/datastore.py index d736a818..4e274780 100644 --- a/invenio_accounts/datastore.py +++ b/invenio_accounts/datastore.py @@ -10,7 +10,7 @@ from flask_security import SQLAlchemyUserDatastore -from .models import Role +from .models import Domain, Role from .proxies import current_db_change_history from .sessions import delete_user_sessions from .signals import datastore_post_commit, datastore_pre_commit @@ -65,3 +65,13 @@ def create_role(self, **kwargs): def find_role_by_id(self, role_id): """Fetches roles searching by id.""" return self.role_model.query.filter_by(id=role_id).one_or_none() + + def find_domain(self, domain): + """Find a domain.""" + return Domain.query.filter_by(domain=domain).one_or_none() + + def create_domain(self, domain, **kwargs): + """Create a new domain.""" + domain = Domain.create(domain, **kwargs) + self.put(domain) + return domain diff --git a/invenio_accounts/domains.py b/invenio_accounts/domains.py new file mode 100644 index 00000000..8508316c --- /dev/null +++ b/invenio_accounts/domains.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# +# This file is part of Invenio. +# Copyright (C) 2023 CERN. +# +# Invenio is free software; you can redistribute it and/or modify it +# under the terms of the MIT License; see LICENSE file for more details. + +"""Domain blocking listener.""" + +from .models import DomainStatus + + +def on_user_confirmed(app, user): + """Listener for when a user is confirmed.""" + security = app.extensions["security"] + datastore = security.datastore + + # Domain is inserted on domain list when a user confirms their email. + domain = datastore.find_domain(user.domain) + if domain is None: + domain = datastore.create_domain(user.domain) + + # Verify user if domain is verified. + if domain.status == DomainStatus.verified: + user.verified_at = security.datetime_factory() + # Happens if e.g. user register an account, domain is later blocked, + # and user requests to resend email confirmation or link is still valid. + elif domain.status == DomainStatus.blocked: + user.blocked_at = security.datetime_factory() + user.active = False + user.verified_at = None diff --git a/invenio_accounts/ext.py b/invenio_accounts/ext.py index dc52298f..265de4d4 100644 --- a/invenio_accounts/ext.py +++ b/invenio_accounts/ext.py @@ -16,7 +16,7 @@ from flask_kvsession import KVSessionExtension from flask_login import LoginManager, user_logged_in, user_logged_out from flask_principal import AnonymousIdentity -from flask_security import Security +from flask_security import Security, user_confirmed from invenio_db import db from passlib.registry import register_crypt_handler from werkzeug.utils import cached_property @@ -30,6 +30,7 @@ from . import config from .datastore import SessionAwareSQLAlchemyUserDatastore +from .domains import on_user_confirmed from .hash import InvenioAesEncryptedEmail from .models import Role, User from .sessions import csrf_token_reset, login_listener, logout_listener @@ -198,6 +199,8 @@ def delay_security_email(msg): if app.config.get("ACCOUNTS_USERINFO_HEADERS"): request_finished.connect(set_session_info, app) + user_confirmed.connect(on_user_confirmed, app) + # Set Session KV store session_kvstore_factory = obj_or_import_string( app.config["ACCOUNTS_SESSION_STORE_FACTORY"] diff --git a/invenio_accounts/forms.py b/invenio_accounts/forms.py index 05859b4e..febe2405 100644 --- a/invenio_accounts/forms.py +++ b/invenio_accounts/forms.py @@ -17,7 +17,8 @@ from invenio_i18n import gettext as _ from wtforms import FormField, HiddenField -from invenio_accounts.proxies import current_datastore +from .proxies import current_datastore +from .utils import validate_domain class RegistrationFormRecaptcha(FlaskForm): @@ -44,6 +45,16 @@ def __init__(self, *args, **kwargs): if not self.next.data: self.next.data = request.args.get("next", "") + def validate(self, extra_validators=None): + """Validate domain on email list.""" + if not super().validate(extra_validators=extra_validators): + return False + + if not validate_domain(self.email.data): + self.email.errors.append(_("The email domain is blocked.")) + return False + return True + if app.config.get("RECAPTCHA_PUBLIC_KEY") and app.config.get( "RECAPTCHA_PRIVATE_KEY" ): diff --git a/invenio_accounts/models.py b/invenio_accounts/models.py index 58a405aa..006983b1 100644 --- a/invenio_accounts/models.py +++ b/invenio_accounts/models.py @@ -25,7 +25,7 @@ from .errors import AlreadyLinkedError from .profiles import UserPreferenceDict, UserProfileDict -from .utils import validate_username +from .utils import DomainStatus, split_emailaddr, validate_username json_field = ( db.JSON() @@ -103,6 +103,9 @@ class User(db.Model, Timestamp, UserMixin): _email = db.Column("email", db.String(255), unique=True) """User email.""" + domain = db.Column(db.String(255), nullable=True) + """Domain of email.""" + password = db.Column(db.String(255)) """User password.""" @@ -212,6 +215,8 @@ def email(self): def email(self, email): """Set lowercase email.""" self._email = email.lower() + prefix, domain = split_emailaddr(email) + self.domain = domain @hybrid_property def user_profile(self): @@ -463,3 +468,153 @@ def delete_by_user(cls, method, user): """Unlink a user from an external id.""" with db.session.begin_nested(): cls.query.filter_by(id_user=user.id, method=method).delete() + + +class DomainOrg(db.Model): + """Domain organisation.""" + + __tablename__ = "accounts_domain_org" + + id = db.Column(db.Integer(), primary_key=True, autoincrement=True) + + pid = db.Column(db.String(255), unique=True, nullable=True) + """Persistent identifier for organisation.""" + + name = db.Column(db.String(255), nullable=False) + """Name of organisation.""" + + json = db.Column( + json_field, + default=lambda: dict(), + nullable=False, + ) + """Store additional metadata about the organisation.""" + + parent_id = db.Column( + db.Integer(), db.ForeignKey("accounts_domain_org.id"), nullable=True + ) + """Link to parent organisation.""" + + parent = db.relationship("DomainOrg", remote_side=[id]) + """Relationship to parent.""" + + domains = db.relationship("Domain", back_populates="org") + """Relationship to domains for this organisation.""" + + @classmethod + def create(cls, pid, name, json=None, parent=None): + """Create a domain organisation.""" + obj = cls(pid=pid, name=name, json=json or {}, parent=parent) + db.session.add(obj) + return obj + + +class DomainCategory(db.Model): + """Model for storing different domain categories.""" + + __tablename__ = "accounts_domain_category" + + id = db.Column(db.Integer(), primary_key=True, autoincrement=True) + + label = db.Column(db.String(255)) + + @classmethod + def create(cls, label): + """Create a new domain category.""" + obj = cls(label=label) + db.session.add(obj) + return obj + + @classmethod + def get(cls, label): + """Get a domain category.""" + return cls.query.filter_by(label=label).one_or_none() + + +class Domain(db.Model, Timestamp): + """User domains model.""" + + __tablename__ = "accounts_domains" + + id = db.Column(db.Integer(), primary_key=True, autoincrement=True) + """Domain ID""" + + _domain = db.Column("domain", db.String(255), unique=True, nullable=False) + """Domain name.""" + + tld = db.Column(db.String(255), nullable=False) + """Top-level domain.""" + + status = db.Column(db.Enum(DomainStatus), default=DomainStatus.new, nullable=False) + """Status of domain. + + Use to control possibility and capability of users registering with this domain. + """ + + flagged = db.Column(db.Boolean(), default=False, nullable=False) + """Flag domain - used by automatic processes to flag domain.""" + + flagged_source = db.Column(db.String(255), default="", nullable=False) + """Source of flag.""" + + org_id = db.Column(db.Integer(), db.ForeignKey(DomainOrg.id), nullable=True) + """Organisation associated with domain.""" + + org = db.relationship("DomainOrg", back_populates="domains") + + # spammer, mail-provider, organisation, company + category = db.Column(db.Integer(), db.ForeignKey(DomainCategory.id), nullable=True) + """Category of domain.""" + + num_users = db.Column(db.Integer(), default=0, nullable=False) + """Computed property to store number of users in domain.""" + + num_active = db.Column(db.Integer(), default=0, nullable=False) + """Computed property to store number of active users in domain.""" + + num_inactive = db.Column(db.Integer(), default=0, nullable=False) + """Computed property to store number of inactive users in domain.""" + + num_confirmed = db.Column(db.Integer(), default=0, nullable=False) + """Computed property to store number of confirmed users in domain.""" + + num_verified = db.Column(db.Integer(), default=0, nullable=False) + """Computed property to store number of verified users in domain.""" + + num_blocked = db.Column(db.Integer(), default=0, nullable=False) + """Computed property to store number of blocked users in domain.""" + + @classmethod + def create( + cls, + domain, + status=DomainStatus.new, + flagged=False, + flagged_source="", + org=None, + category=None, + ): + """Create a new domain.""" + obj = cls( + domain=domain, + status=status, + flagged=flagged, + flagged_source=flagged_source, + org=org, + category=category, + ) + db.session.add(obj) + return obj + + @hybrid_property + def domain(self): + """Get domain name.""" + return self._domain + + @domain.setter + def domain(self, value): + """Set domain name.""" + if value[-1] == ".": + value = value[:-1] + self._domain = value.lower() + self.tld = self._domain.split(".")[-1] diff --git a/invenio_accounts/utils.py b/invenio_accounts/utils.py index 53b33c97..6701c7ec 100644 --- a/invenio_accounts/utils.py +++ b/invenio_accounts/utils.py @@ -8,6 +8,7 @@ """Utility function for ACCOUNTS.""" +import enum import re import uuid from datetime import datetime @@ -20,14 +21,36 @@ from flask_security.signals import password_changed, user_registered from flask_security.utils import config_value as security_config_value from flask_security.utils import get_security_endpoint_name, hash_password, send_mail +from invenio_db import db +from invenio_i18n import gettext as _ from jwt import DecodeError, ExpiredSignatureError, decode, encode from werkzeug.routing import BuildError from werkzeug.utils import import_string +from wtforms import ValidationError from .errors import JWTDecodeError, JWTExpiredToken from .proxies import current_datastore, current_security +class DomainStatus(enum.Enum): + """Domain status. + + The domain status controls if new users can register and their verification status. + """ + + new = 1 + """User registration is allowed - new domain requiring review.""" + + moderated = 2 + """User registration is allowed and users are automatically verified.""" + + verified = 3 + """User registration is allowed and users are automatically verified.""" + + blocked = 4 + """User registration from domain is blocked.""" + + def jwt_create_token(user_id=None, additional_data=None): """Encode the JWT token. @@ -216,3 +239,33 @@ def validate_username(username): # text explaining the validation rules. message = current_app.config["ACCOUNTS_USERNAME_RULES_TEXT"] raise ValueError(message) + + +def validate_domain_form(form, field): + """Validator for use with WTForm.""" + if not validate_domain(field.data): + raise ValidationError(_("The email domain is blocked.")) + + +def validate_domain(email): + """Validate the domain of email address.""" + email = email.lower() + try: + prefix, domain = split_emailaddr(email) + except ValueError: + return False + with db.session.no_autoflush: + domain = current_datastore.find_domain(domain) + if domain is not None and domain.status == DomainStatus.blocked: + return False + return True + + +def split_emailaddr(email): + """Split email address in prefix and domain.""" + prefix, domain = email.split("@", 1) + prefix = prefix.lower().strip() + domain = domain.lower().strip() + if domain[-1] == ".": + domain = domain[:-1] + return prefix, domain diff --git a/invenio_accounts/views/rest.py b/invenio_accounts/views/rest.py index d103aa31..7057d6dc 100644 --- a/invenio_accounts/views/rest.py +++ b/invenio_accounts/views/rest.py @@ -32,6 +32,7 @@ ) from flask_security.views import logout from invenio_db import db +from invenio_i18n import gettext as _ from invenio_rest.errors import FieldError, RESTValidationError from webargs import ValidationError, fields, validate from webargs.flaskparser import FlaskParser as FlaskParserBase @@ -46,6 +47,7 @@ default_reset_password_link_func, obj_or_import_string, register_user, + validate_domain, ) @@ -210,6 +212,12 @@ def default_user_payload(user): } +def validate_domain_rest(email): + """Validator for use with WTForm.""" + if not validate_domain(email): + raise ValidationError(_("The email domain is blocked.")) + + def _abort(message, field=None, status=None): if field: raise RESTValidationError([FieldError(field, message)]) @@ -313,7 +321,9 @@ class RegisterView(MethodView): decorators = [user_already_authenticated] post_args = { - "email": fields.Email(required=True, validate=[unique_user_email]), + "email": fields.Email( + required=True, validate=[unique_user_email, validate_domain_rest] + ), "password": fields.String( required=True, validate=[validate.Length(min=6, max=128)] ), diff --git a/tests/test_models.py b/tests/test_models.py index 67cee9a1..1c2ec41c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -16,7 +16,14 @@ from sqlalchemy import inspect from invenio_accounts import testutils -from invenio_accounts.models import SessionActivity, User +from invenio_accounts.models import ( + Domain, + DomainCategory, + DomainOrg, + DomainStatus, + SessionActivity, + User, +) class CustomProfile(Schema): @@ -137,3 +144,68 @@ def test_custom_profiles(app): user.user_profile["file_descriptor"] = "1" assert dict(user.user_profile) == {"file_descriptor": 1} + + +def test_user_domain_attr(app): + u = User(email="admin@CERN.CH") + db.session.commit() + assert u.domain == "cern.ch" + + +def test_domain_model(app): + d = Domain.create("CERN.CH") + assert d.domain == "cern.ch" + assert d.tld == "ch" + assert d.status == DomainStatus.new + assert d.flagged == False + assert d.flagged_source == "" + assert d.org is None + assert d.category is None + db.session.commit() + + # Support top level domains like + d = Domain.create("cern") + assert d.domain == "cern" + assert d.tld == "cern" + db.session.commit() + + # Normalise domain names + d = Domain.create("zenodo.org.") + assert d.domain == "zenodo.org" + assert d.tld == "org" + db.session.commit() + + with pytest.raises(Exception): + Domain.create("cern.ch.") + db.session.commit() + + +def test_domain_org(app): + parent = DomainOrg.create( + "https://ror.org/01cwqze88", + "National Institutes of Health", + json={"country": "us"}, + ) + + child = DomainOrg.create( + "https://ror.org/040gcmg81", + "National Cancer Institute", + json={"country": "us"}, + parent=parent, + ) + db.session.commit() + + d = Domain.create("cancer.gov", status=DomainStatus.verified, org=child) + db.session.commit() + + assert d.org == child + assert child.parent == parent + + +def test_domain_category(app): + c1 = DomainCategory.create("spammer") + c2 = DomainCategory.create("organisation") + db.session.commit() + + c = DomainCategory.get("spammer") + assert c.label == "spammer" diff --git a/tests/test_views.py b/tests/test_views.py index 6568300d..0b623650 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -15,9 +15,10 @@ from flask_security import url_for_security from flask_security.forms import LoginForm from flask_security.views import _security +from invenio_db import db from invenio_i18n import gettext as _ -from invenio_accounts.models import SessionActivity +from invenio_accounts.models import Domain, DomainStatus, SessionActivity from invenio_accounts.testutils import create_test_user @@ -70,6 +71,27 @@ def test_no_log_in_message_for_logged_in_users(app): assert resp.data == client.get(app.config["SECURITY_POST_LOGIN_VIEW"]).data +def test_registration_blocked(app): + """Test blocking of domain.""" + with app.app_context(): + forgot_password_url = url_for_security("forgot_password") + Domain.create("inveniosoftware.org", status=DomainStatus.blocked) + db.session.commit() + + with app.test_client() as client: + test_email = "info@inveniosoftware.org" + test_password = "test1234" + resp = client.post( + url_for_security("register"), + data=dict( + email=test_email, + password=test_password, + ), + environ_base={"REMOTE_ADDR": "127.0.0.1"}, + ) + assert "The email domain is blocked." in resp.text + + def test_view_list_sessions(app): """Test view list sessions.""" with app.test_request_context(): diff --git a/tests/test_views_rest.py b/tests/test_views_rest.py index abb5d9fe..3bfae899 100644 --- a/tests/test_views_rest.py +++ b/tests/test_views_rest.py @@ -16,6 +16,7 @@ from flask_security import current_user from invenio_db import db +from invenio_accounts.models import Domain, DomainStatus from invenio_accounts.testutils import create_test_user @@ -211,6 +212,21 @@ def test_custom_registration_view(app_with_flexible_registration): assert res.status_code == 200 +def test_registration_view_blocked_by_domain(api): + app = api + with app.app_context(): + # Block the domain + Domain.create("test.com", status=DomainStatus.blocked) + db.session.commit() + with app.test_client() as client: + url = url_for("invenio_accounts_rest_auth.register") + + res = client.post( + url, data=dict(email="new@test.com", password="123456", active=True) + ) + assert_error_resp(res, [("email", "blocked")]) + + def test_logout_view(api): app = api with app.app_context(): @@ -424,7 +440,7 @@ def test_confirm_email_view(api): confirmed_user = create_test_user( email="confirmed@test.com", confirmed_at=datetime.datetime.now() ) - + Domain.create("test.com", status=DomainStatus.verified) db.session.commit() # Generate token token = generate_confirmation_token(normal_user) @@ -448,6 +464,8 @@ def test_confirm_email_view(api): payload = get_json(res) assert "your email has been confirmed" in payload["message"].lower() assert normal_user.confirmed_at + # User is verified because domain is verified. + assert normal_user.verified_at def test_sessions_list_view(api):