Skip to content

Commit

Permalink
Add case insensitive constraint to username (#27266)
Browse files Browse the repository at this point in the history
This helps us to recognize usernames properly

(cherry picked from commit 1d25105)
  • Loading branch information
ephraimbuddy committed Nov 9, 2022
1 parent 9f6c9e4 commit 51194c7
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Add case-insensitive unique constraint for username
Revision ID: e07f49787c9d
Revises: b0d31815b5a6
Create Date: 2022-10-25 17:29:46.432326
"""

from __future__ import annotations

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = 'e07f49787c9d'
down_revision = 'b0d31815b5a6'
branch_labels = None
depends_on = None
airflow_version = '2.4.3'


def upgrade():
"""Apply Add case-insensitive unique constraint"""
conn = op.get_bind()
if conn.dialect.name == 'postgresql':
op.create_index('idx_ab_user_username', 'ab_user', [sa.text('LOWER(username)')], unique=True)
op.create_index(
"idx_ab_register_user_username", 'ab_register_user', [sa.text('LOWER(username)')], unique=True
)
elif conn.dialect.name == 'sqlite':
with op.batch_alter_table('ab_user') as batch_op:
batch_op.alter_column(
'username',
existing_type=sa.String(64),
_type=sa.String(64, collation='NOCASE'),
unique=True,
nullable=False,
)
with op.batch_alter_table('ab_register_user') as batch_op:
batch_op.alter_column(
'username',
existing_type=sa.String(64),
_type=sa.String(64, collation='NOCASE'),
unique=True,
nullable=False,
)


def downgrade():
"""Unapply Add case-insensitive unique constraint"""
conn = op.get_bind()
if conn.dialect.name == 'postgresql':
op.drop_index('idx_ab_user_username', table_name='ab_user')
op.drop_index('idx_ab_register_user_username', table_name='ab_register_user')
elif conn.dialect.name == 'sqlite':
with op.batch_alter_table('ab_user') as batch_op:
batch_op.alter_column(
'username',
existing_type=sa.String(64, collation='NOCASE'),
_type=sa.String(64),
unique=True,
nullable=False,
)
with op.batch_alter_table('ab_register_user') as batch_op:
batch_op.alter_column(
'username',
existing_type=sa.String(64, collation='NOCASE'),
_type=sa.String(64),
unique=True,
nullable=False,
)
33 changes: 32 additions & 1 deletion airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,36 @@ def check_conn_id_duplicates(session: Session) -> Iterable[str]:
)


def check_username_duplicates(session: Session) -> Iterable[str]:
"""
Check unique username in User & RegisterUser table
:param session: session of the sqlalchemy
:rtype: str
"""
from airflow.www.fab_security.sqla.models import RegisterUser, User

for model in [User, RegisterUser]:
dups = []
try:
dups = (
session.query(model.username) # type: ignore[attr-defined]
.group_by(model.username) # type: ignore[attr-defined]
.having(func.count() > 1)
.all()
)
except (exc.OperationalError, exc.ProgrammingError):
# fallback if tables hasn't been created yet
session.rollback()
if dups:
yield (
f'Seems you have mixed case usernames in {model.__table__.name} table.\n' # type: ignore
'You have to rename or delete those mixed case usernames '
'before upgrading the database.\n'
f'usernames with mixed cases: {[dup.username for dup in dups]}'
)


def reflect_tables(tables: list[Base | str] | None, session):
"""
When running checks prior to upgrades, we use reflection to determine current state of the
Expand Down Expand Up @@ -1393,6 +1423,7 @@ def _check_migration_errors(session: Session = NEW_SESSION) -> Iterable[str]:
check_conn_type_null,
check_run_id_null,
check_bad_references,
check_username_duplicates,
)
for check_fn in check_functions:
log.debug("running check function %s", check_fn.__name__)
Expand Down Expand Up @@ -1679,7 +1710,7 @@ def drop_flask_models(connection):
:param connection: SQLAlchemy Connection
:return: None
"""
from flask_appbuilder.models.sqla import Base
from airflow.www.fab_security.sqla.models import Base

Base.metadata.drop_all(connection)

Expand Down
2 changes: 1 addition & 1 deletion airflow/www/fab_security/sqla/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def find_user(self, username=None, email=None):
else:
return (
self.get_session.query(self.user_model)
.filter(self.user_model.username == username)
.filter(func.lower(self.user_model.username) == func.lower(username))
.one_or_none()
)
except MultipleResultsFound:
Expand Down
36 changes: 33 additions & 3 deletions airflow/www/fab_security/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,19 @@

from flask import current_app, g
from flask_appbuilder.models.sqla import Model
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Table, UniqueConstraint
from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Index,
Integer,
String,
Table,
UniqueConstraint,
event,
func,
)
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import backref, relationship

Expand Down Expand Up @@ -135,7 +147,9 @@ class User(Model):
id = Column(Integer, primary_key=True)
first_name = Column(String(64), nullable=False)
last_name = Column(String(64), nullable=False)
username = Column(String(256), unique=True, nullable=False)
username = Column(
String(256).with_variant(String(256, collation='NOCASE'), "sqlite"), unique=True, nullable=False
)
password = Column(String(256))
active = Column(Boolean)
email = Column(String(256), unique=True, nullable=False)
Expand Down Expand Up @@ -228,8 +242,24 @@ class RegisterUser(Model):
id = Column(Integer, primary_key=True)
first_name = Column(String(64), nullable=False)
last_name = Column(String(64), nullable=False)
username = Column(String(256), unique=True, nullable=False)
username = Column(
String(256).with_variant(String(256, collation='NOCASE'), "sqlite"), unique=True, nullable=False
)
password = Column(String(256))
email = Column(String(256), nullable=False)
registration_date = Column(DateTime, default=datetime.datetime.now, nullable=True)
registration_hash = Column(String(256))


@event.listens_for(User.__table__, "before_create")
def add_index_on_ab_user_username_postgres(table, conn, **kw):
if conn.dialect.name != "postgresql":
return
table.indexes.add(Index("idx_ab_user_username", func.lower(table.c.username), unique=True))


@event.listens_for(RegisterUser.__table__, "before_create")
def add_index_on_ab_register_user_username_postgres(table, conn, **kw):
if conn.dialect.name != "postgresql":
return
table.indexes.add(Index("idx_ab_register_user_username", func.lower(table.c.username), unique=True))
2 changes: 1 addition & 1 deletion docs/apache-airflow/img/airflow_erd.sha256
Original file line number Diff line number Diff line change
@@ -1 +1 @@
a522ff773bc6403318e4a15be19f25d48323208de3acc8a970f7f40ad4782563
88fe8ef077e673c69080ac260baa0bab33d9189035c0dea5cc652c990ef44ee8
4 changes: 3 additions & 1 deletion docs/apache-airflow/migrations-ref.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
| Revision ID | Revises ID | Airflow Version | Description |
+=================================+===================+===================+==============================================================+
| ``b0d31815b5a6`` (head) | ``ecb43d2a1842`` | ``2.4.2`` | Add missing auto-increment to columns on FAB tables |
| ``e07f49787c9d`` (head) | ``b0d31815b5a6`` | ``2.4.3`` | Add case-insensitive unique constraint for username |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
| ``b0d31815b5a6`` | ``ecb43d2a1842`` | ``2.4.2`` | Add missing auto-increment to columns on FAB tables |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
| ``ecb43d2a1842`` | ``1486deb605b4`` | ``2.4.0`` | Add processor_subdir column to DagModel, SerializedDagModel |
| | | | and CallbackRequest tables. |
Expand Down
21 changes: 20 additions & 1 deletion tests/www/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@
from airflow.www.fab_security.manager import AnonymousUser
from airflow.www.fab_security.sqla.models import User, assoc_permission_role
from airflow.www.utils import CustomSQLAInterface
from tests.test_utils.api_connexion_utils import create_user_scope, delete_role, set_user_single_role
from tests.test_utils.api_connexion_utils import (
create_user,
create_user_scope,
delete_role,
delete_user,
set_user_single_role,
)
from tests.test_utils.asserts import assert_queries_count
from tests.test_utils.db import clear_db_dags, clear_db_runs
from tests.test_utils.mock_security_manager import MockSecurityManager
Expand Down Expand Up @@ -935,3 +941,16 @@ def test_update_user_auth_stat_subsequent_unsuccessful_auth(mock_security_manage
assert old_user.fail_login_count == 10
assert old_user.last_login == datetime.datetime(1984, 12, 1, 0, 0, 0)
assert mock_security_manager.update_user.called_once


def test_users_can_be_found(app, security_manager, session, caplog):
"""Test that usernames are case insensitive"""
create_user(app, "Test")
create_user(app, "test")
create_user(app, "TEST")
create_user(app, "TeSt")
assert security_manager.find_user("Test")
users = security_manager.get_all_users()
assert len(users) == 1
delete_user(app, "Test")
assert "Error adding new user to database" in caplog.text

0 comments on commit 51194c7

Please sign in to comment.