-
-
Notifications
You must be signed in to change notification settings - Fork 32.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
254 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
"""Trusted Networks auth provider. | ||
It shows list of users if access from trusted network. | ||
Abort login flow if not access from trusted network. | ||
""" | ||
import voluptuous as vol | ||
|
||
from homeassistant import data_entry_flow | ||
from homeassistant.core import callback | ||
from homeassistant.exceptions import HomeAssistantError | ||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS | ||
|
||
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({ | ||
}, extra=vol.PREVENT_EXTRA) | ||
|
||
|
||
class InvalidAuthError(HomeAssistantError): | ||
"""Raised when try to access from untrusted networks.""" | ||
|
||
|
||
class InvalidUserError(HomeAssistantError): | ||
"""Raised when try to login as invalid user.""" | ||
|
||
|
||
@AUTH_PROVIDERS.register('trusted_networks') | ||
class TrustedNetworksAuthProvider(AuthProvider): | ||
"""Trusted Networks auth provider. | ||
Allow passwordless access from trusted network. | ||
""" | ||
|
||
DEFAULT_TITLE = 'Trusted Networks' | ||
|
||
async def async_credential_flow(self): | ||
"""Return a flow to login.""" | ||
users = await self.store.async_get_users() | ||
available_users = {user.id: user.name | ||
for user in users | ||
if not user.system_generated and user.is_active} | ||
|
||
return LoginFlow(self, available_users) | ||
|
||
async def async_get_or_create_credentials(self, flow_result): | ||
"""Get credentials based on the flow result.""" | ||
user_id = flow_result['user'] | ||
|
||
users = await self.store.async_get_users() | ||
for user in users: | ||
if (not user.system_generated and | ||
user.is_active and | ||
user.id == user_id): | ||
for credential in await self.async_credentials(): | ||
if credential.data['user_id'] == user_id: | ||
return credential | ||
cred = self.async_create_credentials({'user_id': user_id}) | ||
await self.store.async_link_user(user, cred) | ||
return cred | ||
|
||
# We only allow login as exist user | ||
raise InvalidUserError | ||
|
||
async def async_user_meta_for_credentials(self, credentials): | ||
"""Return extra user metadata for credentials. | ||
Trusted network auth provider should never create new user. | ||
""" | ||
raise NotImplementedError | ||
|
||
@callback | ||
def async_validate_access(self, ip_address): | ||
"""Make sure the access from trusted networks. | ||
Raise InvalidAuthError if not. | ||
Raise InvalidAuthError if trusted_networks is not config | ||
""" | ||
if (not hasattr(self.hass, 'http') or | ||
not self.hass.http or not self.hass.http.trusted_networks): | ||
raise InvalidAuthError('trusted_networks is not configured') | ||
|
||
if not any(ip_address in trusted_network for trusted_network | ||
in self.hass.http.trusted_networks): | ||
raise InvalidAuthError('Not in trusted_networks') | ||
|
||
|
||
class LoginFlow(data_entry_flow.FlowHandler): | ||
"""Handler for the login flow.""" | ||
|
||
def __init__(self, auth_provider, available_users): | ||
"""Initialize the login flow.""" | ||
self._auth_provider = auth_provider | ||
self._available_users = available_users | ||
|
||
async def async_step_init(self, user_input=None): | ||
"""Handle the step of the form.""" | ||
errors = {} | ||
try: | ||
self._auth_provider.async_validate_access(self.source) | ||
|
||
except InvalidAuthError: | ||
errors['base'] = 'invalid_auth' | ||
return self.async_show_form( | ||
step_id='init', | ||
data_schema=None, | ||
errors=errors, | ||
) | ||
|
||
if user_input is not None: | ||
user_id = user_input['user'] | ||
if user_id not in self._available_users: | ||
errors['base'] = 'invalid_auth' | ||
|
||
if not errors: | ||
return self.async_create_entry( | ||
title=self._auth_provider.name, | ||
data=user_input | ||
) | ||
|
||
schema = {'user': vol.In(self._available_users)} | ||
|
||
return self.async_show_form( | ||
step_id='init', | ||
data_schema=vol.Schema(schema), | ||
errors=errors, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
"""Test the Trusted Networks auth provider.""" | ||
from unittest.mock import Mock | ||
|
||
import pytest | ||
import voluptuous as vol | ||
|
||
from homeassistant import auth | ||
from homeassistant.auth import auth_store | ||
from homeassistant.auth.providers import trusted_networks as tn_auth | ||
|
||
|
||
@pytest.fixture | ||
def store(hass): | ||
"""Mock store.""" | ||
return auth_store.AuthStore(hass) | ||
|
||
|
||
@pytest.fixture | ||
def provider(hass, store): | ||
"""Mock provider.""" | ||
return tn_auth.TrustedNetworksAuthProvider(hass, store, { | ||
'type': 'trusted_networks' | ||
}) | ||
|
||
|
||
@pytest.fixture | ||
def manager(hass, store, provider): | ||
"""Mock manager.""" | ||
return auth.AuthManager(hass, store, { | ||
(provider.type, provider.id): provider | ||
}) | ||
|
||
|
||
async def test_trusted_networks_credentials(manager, provider): | ||
"""Test trusted_networks credentials related functions.""" | ||
owner = await manager.async_create_user("test-owner") | ||
tn_owner_cred = await provider.async_get_or_create_credentials({ | ||
'user_id': owner.id | ||
}) | ||
assert tn_owner_cred.is_new is False | ||
assert any(cred.id == tn_owner_cred.id for cred in owner.credentials) | ||
|
||
user = await manager.async_create_user("test-user") | ||
tn_user_cred = await provider.async_get_or_create_credentials({ | ||
'user_id': user.id | ||
}) | ||
assert tn_user_cred.id != tn_owner_cred.id | ||
assert tn_user_cred.is_new is False | ||
assert any(cred.id == tn_user_cred.id for cred in user.credentials) | ||
|
||
with pytest.raises(tn_auth.InvalidUserError): | ||
await provider.async_get_or_create_credentials({ | ||
'user_id': 'invalid-user' | ||
}) | ||
|
||
|
||
async def test_validate_access(provider): | ||
"""Test validate access from trusted networks.""" | ||
with pytest.raises(tn_auth.InvalidAuthError): | ||
provider.async_validate_access('192.168.0.1') | ||
|
||
provider.hass.http = Mock(trusted_networks=['192.168.0.1']) | ||
provider.async_validate_access('192.168.0.1') | ||
|
||
with pytest.raises(tn_auth.InvalidAuthError): | ||
provider.async_validate_access('127.0.0.1') | ||
|
||
|
||
async def test_login_flow(manager, provider): | ||
"""Test login flow.""" | ||
owner = await manager.async_create_user("test-owner") | ||
user = await manager.async_create_user("test-user") | ||
|
||
flow = await provider.async_credential_flow() | ||
|
||
# trusted network didn't loaded | ||
step = await flow.async_step_init() | ||
assert step['step_id'] == 'init' | ||
assert step['errors']['base'] == 'invalid_auth' | ||
|
||
provider.hass.http = Mock(trusted_networks=['192.168.0.1']) | ||
|
||
# not from trusted network | ||
flow.source = '127.0.0.1' | ||
step = await flow.async_step_init() | ||
assert step['step_id'] == 'init' | ||
assert step['errors']['base'] == 'invalid_auth' | ||
|
||
# from trusted network, list users | ||
flow.source = '192.168.0.1' | ||
step = await flow.async_step_init() | ||
assert step['step_id'] == 'init' | ||
|
||
schema = step['data_schema'] | ||
assert schema({'user': owner.id}) | ||
with pytest.raises(vol.Invalid): | ||
assert schema({'user': 'invalid-user'}) | ||
|
||
# login with invalid user | ||
step = await flow.async_step_init({'user': 'invalid-user'}) | ||
assert step['step_id'] == 'init' | ||
assert step['errors']['base'] == 'invalid_auth' | ||
|
||
# login with valid user | ||
step = await flow.async_step_init({'user': user.id}) | ||
assert step['type'] == 'create_entry' | ||
assert step['data']['user'] == user.id |