Skip to content

Commit

Permalink
Use flow_result in LoginFlow, not username
Browse files Browse the repository at this point in the history
Split async_finish to async_start_mfa and async_finish
Address other review comment
  • Loading branch information
awarecan committed Jul 16, 2018
1 parent eccf75d commit 6c255b2
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 122 deletions.
17 changes: 14 additions & 3 deletions homeassistant/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import logging
from collections import OrderedDict

import voluptuous as vol

from homeassistant import data_entry_flow
from homeassistant.core import callback

Expand Down Expand Up @@ -117,7 +119,7 @@ async def async_get_user(self, user_id):
return await self._store.async_get_user(user_id)

async def async_get_user_by_credentials(self, credentials):
"""Get a user by credential, raise ValueError if not found."""
"""Get a user by credential, return None if not found."""
for user in await self.async_get_users():
for creds in user.credentials:
if creds.id == credentials.id:
Expand Down Expand Up @@ -148,7 +150,9 @@ async def async_create_user(self, name):
async def async_get_or_create_user(self, credentials):
"""Get or create a user."""
if not credentials.is_new:
return await self.async_get_user_by_credentials(credentials)
user = await self.async_get_user_by_credentials(credentials)
if user is None:
raise ValueError('Unable to find the user.')

auth_provider = self._async_get_auth_provider(credentials)

Expand Down Expand Up @@ -199,7 +203,7 @@ async def async_remove_credentials(self, credentials):

await self._store.async_remove_credentials(credentials)

async def async_enable_user_mfa(self, user, mfa_module_id, data=None):
async def async_enable_user_mfa(self, user, mfa_module_id, data):
"""Enable a multi-factor auth module for user."""
if mfa_module_id not in self._mfa_modules:
raise ValueError('Unable find multi-factor auth module: {}'
Expand All @@ -209,6 +213,13 @@ async def async_enable_user_mfa(self, user, mfa_module_id, data=None):
'multi-factor auth module.')

module = self.get_auth_mfa_module(mfa_module_id)
if module.setup_schema is not None:
try:
# pylint: disable=not-callable
data = module.setup_schema(data)
except vol.Invalid as err:
raise ValueError('Data does not match schema: {}'.format(err))

result = await module.async_setup_user(user.id, data)
await self._store.async_enable_user_mfa(user, mfa_module_id)
return result
Expand Down
6 changes: 3 additions & 3 deletions homeassistant/auth/mfa_modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def name(self):
"""Return the name of the auth module."""
return self.config.get(CONF_NAME, self.DEFAULT_TITLE)

# Implement by extending class

@property
def input_schema(self):
"""Return a voluptuous schema to define mfa auth module's input."""
Expand All @@ -118,9 +120,7 @@ def setup_schema(self):
"""
return None

# Implement by extending class

async def async_setup_user(self, user_id, data=None):
async def async_setup_user(self, user_id, data):
"""Setup mfa auth module for user."""
raise NotImplementedError

Expand Down
20 changes: 4 additions & 16 deletions homeassistant/auth/mfa_modules/insecure_example.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Example auth module."""
import logging
from collections import OrderedDict

import voluptuous as vol

Expand All @@ -14,9 +13,6 @@
})]
}, extra=vol.PREVENT_EXTRA)

STORAGE_VERSION = 1
STORAGE_KEY = 'mfa_modules.insecure_example'

_LOGGER = logging.getLogger(__name__)


Expand All @@ -35,24 +31,16 @@ def __init__(self, hass, config):
@property
def input_schema(self):
"""Validate login flow input data."""
schema = OrderedDict()
schema['pin'] = str
return vol.Schema(schema)
return vol.Schema({'pin': str})

@property
def setup_schema(self):
"""Validate async_setup_user input data."""
schema = OrderedDict()
schema['pin'] = str
return vol.Schema(schema)
return vol.Schema({'pin': str})

async def async_setup_user(self, user_id, data=None):
async def async_setup_user(self, user_id, data):
"""Setup mfa module for user."""
try:
data = self.setup_schema(data) # pylint: disable=not-callable
except vol.Invalid as err:
raise ValueError('Data does not match schema: {}'.format(err))

# data shall has been validate in caller
pin = data['pin']

for user in self._users:
Expand Down
77 changes: 40 additions & 37 deletions homeassistant/auth/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,14 @@ def __init__(self, auth_provider: AuthProvider):
self._auth_modules = []
self._auth_manager = auth_provider.hass.auth
self._user = None
self._username = None
self._credential_flow_result = None
self.created_at = dt_util.utcnow()

async def async_step_init(self, user_input=None):
"""Handle the first step of login flow.
Return self.async_show_form(step_id='init') if user_input == None.
Return await self.async_finish(username) if login init step pass.
Return await self.async_start_mfa(flow_result) if login init step pass.
"""
raise NotImplementedError

Expand Down Expand Up @@ -194,7 +194,7 @@ async def async_step_mfa(self, user_input=None):
auth_module = self._auth_manager.get_auth_mfa_module(
self._auth_module_id)
if auth_module is None:
# Given an invalid input to async_ste[_select_mfa_module
# Given an invalid input to async_step_select_mfa_module
# will show invalid_auth_module error
return await self.async_step_select_mfa_module(
user_input={'multi_factor_auth_module': None})
Expand All @@ -210,49 +210,52 @@ async def async_step_mfa(self, user_input=None):
errors['base'] = 'invalid_auth'

if not errors:
return await self.async_finish(self._username, mfa_valid=True)
return await self.async_finish(self._credential_flow_result)

return self.async_show_form(
step_id='mfa',
data_schema=auth_module.input_schema,
errors=errors,
)

async def async_finish(self, username, mfa_valid=False):
"""Handle the pass of login flow."""
if not mfa_valid:
self._username = username

# We need get user from username, so we need get credential first
# async_get_or_create_credentials would not save data, it is safe
# to be called here. It will be called later in end of workflow.
credentials = await self._auth_provider.\
async_get_or_create_credentials({'username': username})

# multi-factor module cannot enabled for new credential
# which has not linked to a user yet
if not credentials.is_new:
self._user = await self._auth_manager.\
async_get_user_by_credentials(credentials)

# module in user.mfa_modules may not loaded
# the config may have changed after the user enabled module
# we need double check available mfa_modules for this user
if self._user and self._user.mfa_modules:
modules = [m_id for m_id in self._user.mfa_modules
if self._auth_manager.
get_auth_mfa_module(m_id)]

if modules:
self._auth_modules = modules
if len(self._auth_modules) == 1:
self._auth_module_id = self._auth_modules[0]
return await self.async_step_mfa()
# need select mfa module first
return await self.async_step_select_mfa_module()
async def async_start_mfa(self, flow_result):
"""Start mfa validation flow if need."""
self._credential_flow_result = flow_result

# We need get user from flow_result, so we need get credential first
# async_get_or_create_credentials would not save data, it is safe
# to be called here. It will be called later in end of workflow.
credentials = await self._auth_provider.\
async_get_or_create_credentials(flow_result)

# multi-factor module cannot enabled for new credential
# which has not linked to a user yet
if not credentials.is_new:
self._user = await self._auth_manager.\
async_get_user_by_credentials(credentials)

# module in user.mfa_modules may not loaded
# the config may have changed after the user enabled module
# we need double check available mfa_modules for this user
if self._user and self._user.mfa_modules:
modules = [m_id for m_id in self._user.mfa_modules
if self._auth_manager.
get_auth_mfa_module(m_id)]

if modules:
self._auth_modules = modules
if len(self._auth_modules) == 1:
self._auth_module_id = self._auth_modules[0]
return await self.async_step_mfa()
# need select mfa module first
return await self.async_step_select_mfa_module()

# new credential or no mfa_module enabled or passed mfa validate
return await self.async_finish(flow_result)

async def async_finish(self, flow_result):
"""Handle the pass of login flow."""
return self.async_create_entry(
title=self._auth_provider.name,
data={'username': username}
data=flow_result
)
7 changes: 3 additions & 4 deletions homeassistant/auth/providers/homeassistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,18 +205,17 @@ class HassLoginFlow(LoginFlow):
async def async_step_init(self, user_input=None):
"""Handle the step of username/password validation."""
errors = {}
result = None

if user_input is not None:
try:
await self._auth_provider.async_validate_login(
user_input['username'], user_input['password'])
result = user_input['username']
except InvalidAuth:
errors['base'] = 'invalid_auth'

if not errors and result:
return await self.async_finish(result)
if not errors:
user_input.pop('password')
return await self.async_start_mfa(user_input)

schema = OrderedDict()
schema['username'] = str
Expand Down
3 changes: 2 additions & 1 deletion homeassistant/auth/providers/insecure_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ async def async_step_init(self, user_input=None):
errors['base'] = 'invalid_auth'

if not errors:
return await self.async_finish(user_input['username'])
user_input.pop('password')
return await self.async_start_mfa(user_input)

schema = OrderedDict()
schema['username'] = str
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/auth/providers/legacy_api_password.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ async def async_step_init(self, user_input=None):
errors['base'] = 'invalid_auth'

if not errors:
return await self.async_finish(LEGACY_USER)
return await self.async_start_mfa({'username': LEGACY_USER})

schema = OrderedDict()
schema['password'] = str
Expand Down
21 changes: 3 additions & 18 deletions homeassistant/scripts/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ def run(args):
subparsers.required = True
parser_list = subparsers.add_parser('list')
parser_list.set_defaults(func=list_users)
parser_list.add_argument('-a', '--all', default=False,
help="Show all users included system user")

parser_add = subparsers.add_parser('add')
parser_add.add_argument('username', type=str)
Expand Down Expand Up @@ -67,21 +65,9 @@ async def run_command(hass, args):
async def list_users(hass, provider, args):
"""List the users."""
count = 0
if args.all:
for user in await hass.auth.async_get_users():
print("{}{}{}".format(
str(user.name).ljust(20),
str(user.id).ljust(34),
str(user.mfa_modules)
))
count += 1
for cred in user.credentials:
print(" - {}".format(cred.data.get('username')))

else:
for user in provider.data.users:
count += 1
print(user['username'])
for user in provider.data.users:
count += 1
print(user['username'])

print()
print("Total users:", count)
Expand Down Expand Up @@ -114,7 +100,6 @@ async def validate_login(hass, provider, args):
print("Auth valid")
except hass_auth.InvalidAuth:
print("Auth invalid")
return


async def change_password(hass, provider, args):
Expand Down
42 changes: 3 additions & 39 deletions tests/scripts/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@

import pytest

from homeassistant.auth.models import Credentials
from homeassistant.scripts import auth as script_auth
from homeassistant.auth.providers import homeassistant as hass_auth

from tests.common import register_auth_provider, MockUser
from homeassistant.scripts import auth as script_auth
from tests.common import register_auth_provider


@pytest.fixture
Expand All @@ -26,7 +24,7 @@ async def test_list_user(hass, provider, capsys):
data.add_auth('test-user', 'test-pass')
data.add_auth('second-user', 'second-pass')

await script_auth.list_users(hass, provider, Mock(all=False))
await script_auth.list_users(hass, provider, None)

captured = capsys.readouterr()

Expand All @@ -39,40 +37,6 @@ async def test_list_user(hass, provider, capsys):
])


async def test_list_all_user(hass, provider, capsys):
"""Test we can list all users."""
# Add fake user with credentials for example auth provider.
user = MockUser(
id='mock-user',
is_owner=False,
is_active=True,
name='Paulus',
mfa_modules=['homeassistant']
).add_to_auth_manager(hass.auth)
user.credentials.append(Credentials(
id='mock-id',
auth_provider_type='homeassistant',
auth_provider_id=None,
data={'username': 'test-user'},
is_new=False,
))

await script_auth.list_users(hass, provider, Mock(all=True))

captured = capsys.readouterr()

assert captured.out == '\n'.join([
'{}{}{}'.format(
'Paulus ',
'mock-user ',
"['homeassistant']"),
' - test-user',
'',
'Total users: 1',
''
])


async def test_add_user(hass, provider, capsys, hass_storage):
"""Test we can add a user."""
data = provider.data
Expand Down

0 comments on commit 6c255b2

Please sign in to comment.