Skip to content

Commit

Permalink
global: new domain list feature
Browse files Browse the repository at this point in the history
* New domain list feature that can be used to block email domains from
  registering, and to automatically verify users from other domains.
  • Loading branch information
lnielsen committed Dec 14, 2023
1 parent 04d5171 commit 304b9c0
Show file tree
Hide file tree
Showing 11 changed files with 489 additions and 8 deletions.
95 changes: 95 additions & 0 deletions invenio_accounts/alembic/6ec5ce377ca3_create_domains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#
# This file is part of Invenio.
# Copyright (C) 2016-2018 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Create tables for domain list feature."""

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
from sqlalchemy_utils import JSONType

# revision identifiers, used by Alembic.
revision = "6ec5ce377ca3"
down_revision = "037afe10e9ff"
branch_labels = ()
depends_on = None


def upgrade():
    """Upgrade database.

    Creates the ``accounts_domain_category``, ``accounts_domain_org`` and
    ``accounts_domains`` tables, then adds a nullable ``domain`` column to
    ``accounts_user``.
    """
    # JSON column type with a dedicated variant per supported dialect.
    json_type = (
        sa.JSON()
        .with_variant(JSONType(), "mysql")
        .with_variant(
            postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), "postgresql"
        )
        .with_variant(JSONType(), "sqlite")
    )

    # The per-domain user counters all share the same column definition.
    counter_columns = [
        sa.Column(counter_name, sa.Integer(), nullable=False)
        for counter_name in (
            "num_users",
            "num_active",
            "num_inactive",
            "num_confirmed",
            "num_verified",
            "num_blocked",
        )
    ]

    # Category and organisation tables come first because
    # "accounts_domains" holds foreign keys to both of them.
    op.create_table(
        "accounts_domain_category",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("label", sa.String(length=255), nullable=True),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_accounts_domain_category")),
    )

    op.create_table(
        "accounts_domain_org",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("pid", sa.String(length=255), nullable=True),
        sa.Column("name", sa.String(length=255), nullable=False),
        sa.Column("json", json_type, nullable=False),
        sa.Column("parent_id", sa.Integer(), nullable=True),
        # Self-referencing foreign key: an organisation may have a parent.
        sa.ForeignKeyConstraint(
            ["parent_id"],
            ["accounts_domain_org.id"],
            name=op.f("fk_accounts_domain_org_parent_id_accounts_domain_org"),
        ),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_accounts_domain_org")),
        sa.UniqueConstraint("pid", name=op.f("uq_accounts_domain_org_pid")),
    )

    op.create_table(
        "accounts_domains",
        sa.Column("created", sa.DateTime(), nullable=False),
        sa.Column("updated", sa.DateTime(), nullable=False),
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("domain", sa.String(length=255), nullable=False),
        sa.Column("tld", sa.String(length=255), nullable=False),
        sa.Column("status", sa.Integer(), nullable=False),
        sa.Column("flagged", sa.Boolean(), nullable=False),
        sa.Column("flagged_source", sa.String(length=255), nullable=False),
        sa.Column("org_id", sa.Integer(), nullable=True),
        sa.Column("category", sa.Integer(), nullable=True),
        *counter_columns,
        sa.ForeignKeyConstraint(
            ["category"],
            ["accounts_domain_category.id"],
            name=op.f("fk_accounts_domains_category_accounts_domain_category"),
        ),
        sa.ForeignKeyConstraint(
            ["org_id"],
            ["accounts_domain_org.id"],
            name=op.f("fk_accounts_domains_org_id_accounts_domain_org"),
        ),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_accounts_domains")),
        sa.UniqueConstraint("domain", name=op.f("uq_accounts_domains_domain")),
    )

    # Nullable: rows that already exist in "accounts_user" get no value here.
    op.add_column(
        "accounts_user", sa.Column("domain", sa.String(length=255), nullable=True)
    )


def downgrade():
    """Downgrade database.

    Removes the ``domain`` column from ``accounts_user`` and drops the three
    domain tables.
    """
    op.drop_column("accounts_user", "domain")

    # "accounts_domains" references the other two tables, so it goes first.
    for table_name in (
        "accounts_domains",
        "accounts_domain_org",
        "accounts_domain_category",
    ):
        op.drop_table(table_name)
12 changes: 11 additions & 1 deletion invenio_accounts/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from flask_security import SQLAlchemyUserDatastore

from .models import Role
from .models import Domain, Role
from .proxies import current_db_change_history
from .sessions import delete_user_sessions
from .signals import datastore_post_commit, datastore_pre_commit
Expand Down Expand Up @@ -65,3 +65,13 @@ def create_role(self, **kwargs):
def find_role_by_id(self, role_id):
    """Fetch a role by its id.

    :param role_id: Value matched against the role model's ``id``.
    :returns: The result of ``one_or_none()`` on the filtered role query.
    """
    role_query = self.role_model.query.filter_by(id=role_id)
    return role_query.one_or_none()

def find_domain(self, domain):
    """Look up a domain list entry by domain name.

    :param domain: Domain name to filter on; it is passed to the query as-is.
    :returns: The result of ``one_or_none()`` on the filtered ``Domain`` query.
    """
    domain_query = Domain.query.filter_by(domain=domain)
    return domain_query.one_or_none()

def create_domain(self, domain, **kwargs):
    """Create a new domain list entry and hand it to the datastore.

    :param domain: Domain name for the new entry.
    :param kwargs: Extra keyword arguments forwarded to ``Domain.create``.
    :returns: The newly created ``Domain`` object.
    """
    new_entry = Domain.create(domain, **kwargs)
    self.put(new_entry)
    return new_entry
32 changes: 32 additions & 0 deletions invenio_accounts/domains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
#
# This file is part of Invenio.
# Copyright (C) 2023 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Domain blocking listener."""

from .models import DomainStatus


def on_user_confirmed(app, user):
    """Listener for when a user is confirmed.

    Makes sure the user's email domain is on the domain list, then applies
    the domain's status to the user: users of a verified domain are marked as
    verified, users of a blocked domain are blocked and deactivated.

    :param app: Application whose ``security`` extension provides the
        datastore and the datetime factory.
    :param user: The user that was just confirmed; modified in place.
    """
    # The user's domain is nullable (e.g. rows that existed before the column
    # was added carry no value). Without a domain there is nothing to look up,
    # and creating a domain entry from an empty value would raise.
    if not user.domain:
        return

    security = app.extensions["security"]
    datastore = security.datastore

    # Domain is inserted on domain list when a user confirms their email.
    domain = datastore.find_domain(user.domain)
    if domain is None:
        domain = datastore.create_domain(user.domain)

    # Verify user if domain is verified.
    if domain.status == DomainStatus.verified:
        user.verified_at = security.datetime_factory()
    # Happens if e.g. a user registers an account, the domain is later
    # blocked, and the user requests to resend the email confirmation or the
    # link is still valid.
    elif domain.status == DomainStatus.blocked:
        user.blocked_at = security.datetime_factory()
        user.active = False
        user.verified_at = None
5 changes: 4 additions & 1 deletion invenio_accounts/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from flask_kvsession import KVSessionExtension
from flask_login import LoginManager, user_logged_in, user_logged_out
from flask_principal import AnonymousIdentity
from flask_security import Security
from flask_security import Security, user_confirmed
from invenio_db import db
from passlib.registry import register_crypt_handler
from werkzeug.utils import cached_property
Expand All @@ -30,6 +30,7 @@

from . import config
from .datastore import SessionAwareSQLAlchemyUserDatastore
from .domains import on_user_confirmed
from .hash import InvenioAesEncryptedEmail
from .models import Role, User
from .sessions import csrf_token_reset, login_listener, logout_listener
Expand Down Expand Up @@ -198,6 +199,8 @@ def delay_security_email(msg):
if app.config.get("ACCOUNTS_USERINFO_HEADERS"):
request_finished.connect(set_session_info, app)

user_confirmed.connect(on_user_confirmed, app)

# Set Session KV store
session_kvstore_factory = obj_or_import_string(
app.config["ACCOUNTS_SESSION_STORE_FACTORY"]
Expand Down
13 changes: 12 additions & 1 deletion invenio_accounts/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from invenio_i18n import gettext as _
from wtforms import FormField, HiddenField

from invenio_accounts.proxies import current_datastore
from .proxies import current_datastore
from .utils import validate_domain


class RegistrationFormRecaptcha(FlaskForm):
Expand All @@ -44,6 +45,16 @@ def __init__(self, *args, **kwargs):
if not self.next.data:
self.next.data = request.args.get("next", "")

def validate(self, extra_validators=None):
    """Run the regular form validation, then check the email's domain.

    :param extra_validators: Forwarded unchanged to the parent ``validate``.
    :returns: ``True`` when both the parent validation and the domain check
        pass, ``False`` otherwise.
    """
    parent_ok = super().validate(extra_validators=extra_validators)
    if not parent_ok:
        return False

    if validate_domain(self.email.data):
        return True

    # Domain check failed: report the failure on the email field.
    self.email.errors.append(_("The email domain is blocked."))
    return False

if app.config.get("RECAPTCHA_PUBLIC_KEY") and app.config.get(
"RECAPTCHA_PRIVATE_KEY"
):
Expand Down
157 changes: 156 additions & 1 deletion invenio_accounts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from .errors import AlreadyLinkedError
from .profiles import UserPreferenceDict, UserProfileDict
from .utils import validate_username
from .utils import DomainStatus, split_emailaddr, validate_username

json_field = (
db.JSON()
Expand Down Expand Up @@ -103,6 +103,9 @@ class User(db.Model, Timestamp, UserMixin):
_email = db.Column("email", db.String(255), unique=True)
"""User email."""

domain = db.Column(db.String(255), nullable=True)
"""Domain of email."""

password = db.Column(db.String(255))
"""User password."""

Expand Down Expand Up @@ -212,6 +215,8 @@ def email(self):
def email(self, email):
    """Set lowercase email and store the domain part of the address."""
    self._email = email.lower()
    # Only the domain half of the split is kept on the user.
    _prefix, self.domain = split_emailaddr(email)

@hybrid_property
def user_profile(self):
Expand Down Expand Up @@ -463,3 +468,153 @@ def delete_by_user(cls, method, user):
"""Unlink a user from an external id."""
with db.session.begin_nested():
cls.query.filter_by(id_user=user.id, method=method).delete()


class DomainOrg(db.Model):
    """Organisation that domains can be associated with."""

    __tablename__ = "accounts_domain_org"

    id = db.Column(db.Integer(), primary_key=True, autoincrement=True)

    pid = db.Column(db.String(255), unique=True, nullable=True)
    """Persistent identifier of the organisation (unique, optional)."""

    name = db.Column(db.String(255), nullable=False)
    """Organisation name."""

    json = db.Column(json_field, default=lambda: dict(), nullable=False)
    """Additional metadata about the organisation."""

    parent_id = db.Column(
        db.Integer(),
        db.ForeignKey("accounts_domain_org.id"),
        nullable=True,
    )
    """Id of the parent organisation, if any."""

    parent = db.relationship("DomainOrg", remote_side=[id])
    """Relationship to the parent organisation."""

    domains = db.relationship("Domain", back_populates="org")
    """Relationship to the domains of this organisation."""

    @classmethod
    def create(cls, pid, name, json=None, parent=None):
        """Create a domain organisation and add it to the session.

        :param pid: Persistent identifier of the organisation.
        :param name: Organisation name.
        :param json: Optional metadata; a falsy value is stored as ``{}``.
        :param parent: Optional parent organisation.
        :returns: The new ``DomainOrg`` object.
        """
        metadata = json or {}
        org = cls(pid=pid, name=name, json=metadata, parent=parent)
        db.session.add(org)
        return org


class DomainCategory(db.Model):
    """Category that can be assigned to a domain."""

    __tablename__ = "accounts_domain_category"

    id = db.Column(db.Integer(), primary_key=True, autoincrement=True)

    label = db.Column(db.String(255))

    @classmethod
    def create(cls, label):
        """Create a new domain category and add it to the session.

        :param label: Label of the category.
        :returns: The new ``DomainCategory`` object.
        """
        category = cls(label=label)
        db.session.add(category)
        return category

    @classmethod
    def get(cls, label):
        """Get a domain category by its label.

        :param label: Label to filter on.
        :returns: The result of ``one_or_none()`` on the filtered query.
        """
        by_label = cls.query.filter_by(label=label)
        return by_label.one_or_none()


class Domain(db.Model, Timestamp):
    """Domain list entry for an email domain of users."""

    __tablename__ = "accounts_domains"

    id = db.Column(db.Integer(), primary_key=True, autoincrement=True)
    """Domain ID"""

    _domain = db.Column("domain", db.String(255), unique=True, nullable=False)
    """Domain name (written through the ``domain`` property)."""

    tld = db.Column(db.String(255), nullable=False)
    """Top-level domain, derived from the domain name by the setter."""

    # NOTE(review): the create_domains migration declares "status" as
    # sa.Integer() while this column is db.Enum(DomainStatus) - confirm that
    # the two definitions are meant to agree.
    status = db.Column(
        db.Enum(DomainStatus),
        default=DomainStatus.new,
        nullable=False,
    )
    """Status of domain.
    Use to control possibility and capability of users registering with this domain.
    """

    flagged = db.Column(db.Boolean(), default=False, nullable=False)
    """Flag set by automatic processes to mark the domain."""

    flagged_source = db.Column(db.String(255), default="", nullable=False)
    """Source of the flag."""

    org_id = db.Column(db.Integer(), db.ForeignKey(DomainOrg.id), nullable=True)
    """Id of the organisation associated with the domain."""

    org = db.relationship("DomainOrg", back_populates="domains")

    # spammer, mail-provider, organisation, company
    category = db.Column(
        db.Integer(),
        db.ForeignKey(DomainCategory.id),
        nullable=True,
    )
    """Category of the domain (foreign key to the category table)."""

    num_users = db.Column(db.Integer(), default=0, nullable=False)
    """Stored count of users in the domain."""

    num_active = db.Column(db.Integer(), default=0, nullable=False)
    """Stored count of active users in the domain."""

    num_inactive = db.Column(db.Integer(), default=0, nullable=False)
    """Stored count of inactive users in the domain."""

    num_confirmed = db.Column(db.Integer(), default=0, nullable=False)
    """Stored count of confirmed users in the domain."""

    num_verified = db.Column(db.Integer(), default=0, nullable=False)
    """Stored count of verified users in the domain."""

    num_blocked = db.Column(db.Integer(), default=0, nullable=False)
    """Stored count of blocked users in the domain."""

    @classmethod
    def create(
        cls,
        domain,
        status=DomainStatus.new,
        flagged=False,
        flagged_source="",
        org=None,
        category=None,
    ):
        """Create a new domain and add it to the session.

        :param domain: Domain name; normalised by the ``domain`` setter.
        :param status: Initial status of the domain.
        :param flagged: Whether the domain is flagged.
        :param flagged_source: Source of the flag.
        :param org: Optional organisation associated with the domain.
        :param category: Optional value for the ``category`` column.
        :returns: The new ``Domain`` object.
        """
        entry = cls(
            domain=domain,
            status=status,
            flagged=flagged,
            flagged_source=flagged_source,
            org=org,
            category=category,
        )
        db.session.add(entry)
        return entry

    @hybrid_property
    def domain(self):
        """Get domain name."""
        return self._domain

    @domain.setter
    def domain(self, value):
        """Set domain name.

        A single trailing dot is removed and the name is lowercased; ``tld``
        is set to the part after the last dot of the stored name.
        """
        without_root_dot = value[:-1] if value[-1] == "." else value
        self._domain = without_root_dot.lower()
        self.tld = self._domain.rsplit(".", 1)[-1]
Loading

0 comments on commit 304b9c0

Please sign in to comment.