Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context to login flow #15914

Merged
merged 4 commits into from
Aug 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions homeassistant/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,9 @@ async def _async_create_login_flow(self, handler, *, context, data):
"""Create a login flow."""
auth_provider = self._providers[handler]

return await auth_provider.async_credential_flow()
return await auth_provider.async_credential_flow(context)

async def _async_finish_login_flow(self, result):
async def _async_finish_login_flow(self, context, result):
"""Result of a credential login flow."""
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return None
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/auth/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def async_create_credentials(self, data):

# Implement by extending class

async def async_credential_flow(self):
async def async_credential_flow(self, context):
"""Return the data flow for logging in with auth provider."""
raise NotImplementedError

Expand Down
2 changes: 1 addition & 1 deletion homeassistant/auth/providers/homeassistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ async def async_initialize(self):
self.data = Data(self.hass)
await self.data.async_load()

async def async_credential_flow(self):
async def async_credential_flow(self, context):
"""Return a flow to login."""
return LoginFlow(self)

Expand Down
2 changes: 1 addition & 1 deletion homeassistant/auth/providers/insecure_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class InvalidAuthError(HomeAssistantError):
class ExampleAuthProvider(AuthProvider):
"""Example auth provider based on hardcoded usernames and passwords."""

async def async_credential_flow(self):
async def async_credential_flow(self, context):
"""Return a flow to login."""
return LoginFlow(self)

Expand Down
2 changes: 1 addition & 1 deletion homeassistant/auth/providers/legacy_api_password.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class LegacyApiPasswordAuthProvider(AuthProvider):

DEFAULT_TITLE = 'Legacy API Password'

async def async_credential_flow(self):
async def async_credential_flow(self, context):
"""Return a flow to login."""
return LoginFlow(self)

Expand Down
3 changes: 1 addition & 2 deletions homeassistant/components/auth/login_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
"flow_id": "8f7e42faab604bcab7ac43c44ca34d58",
"handler": ["insecure_example", null],
"result": "411ee2f916e648d691e937ae9344681e",
"source": "user",
"title": "Example",
"type": "create_entry",
"version": 1
Expand Down Expand Up @@ -152,7 +151,7 @@ async def post(self, request, data):
handler = data['handler']

try:
result = await self._flow_mgr.async_init(handler)
result = await self._flow_mgr.async_init(handler, context={})
except data_entry_flow.UnknownHandler:
return self.json_message('Invalid handler specified', 404)
except data_entry_flow.UnknownStep:
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/config/config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def get(self, request):

return self.json([
flw for flw in hass.config_entries.flow.async_progress()
if flw['source'] != config_entries.SOURCE_USER])
if flw['context']['source'] != config_entries.SOURCE_USER])


class ConfigManagerFlowResourceView(FlowManagerResourceView):
Expand Down
17 changes: 5 additions & 12 deletions homeassistant/config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,26 +372,23 @@ async def async_forward_entry_unload(self, entry, component):
return await entry.async_unload(
self.hass, component=getattr(self.hass.components, component))

async def _async_finish_flow(self, result):
async def _async_finish_flow(self, context, result):
"""Finish a config flow and add an entry."""
# If no discovery config entries in progress, remove notification.
if not any(ent['source'] in DISCOVERY_SOURCES for ent
if not any(ent['context']['source'] in DISCOVERY_SOURCES for ent
in self.hass.config_entries.flow.async_progress()):
self.hass.components.persistent_notification.async_dismiss(
DISCOVERY_NOTIFICATION_ID)

if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return None

source = result['source']
if source is None:
source = SOURCE_USER
entry = ConfigEntry(
version=result['version'],
domain=result['handler'],
title=result['title'],
data=result['data'],
source=source,
source=context['source'],
)
self._entries.append(entry)
await self._async_schedule_save()
Expand All @@ -406,7 +403,7 @@ async def _async_finish_flow(self, result):
self.hass, entry.domain, self._hass_config)

# Return Entry if they not from a discovery request
if result['source'] not in DISCOVERY_SOURCES:
if context['source'] not in DISCOVERY_SOURCES:
return entry

return entry
Expand All @@ -422,10 +419,7 @@ async def _async_create_flow(self, handler_key, *, context, data):
if handler is None:
raise data_entry_flow.UnknownHandler

if context is not None:
source = context.get('source', SOURCE_USER)
else:
source = SOURCE_USER
source = context['source']

# Make sure requirements and dependencies of component are resolved
await async_process_deps_reqs(
Expand All @@ -442,7 +436,6 @@ async def _async_create_flow(self, handler_key, *, context, data):
)

flow = handler()
flow.source = source
flow.init_step = source
return flow

Expand Down
8 changes: 4 additions & 4 deletions homeassistant/data_entry_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def async_progress(self) -> List[Dict]:
return [{
'flow_id': flow.flow_id,
'handler': flow.handler,
'source': flow.source,
'context': flow.context,
} for flow in self._progress.values()]

async def async_init(self, handler: Callable, *, context: Dict = None,
Expand All @@ -57,6 +57,7 @@ async def async_init(self, handler: Callable, *, context: Dict = None,
flow.hass = self.hass
flow.handler = handler
flow.flow_id = uuid.uuid4().hex
flow.context = context
self._progress[flow.flow_id] = flow

return await self._async_handle_step(flow, flow.init_step, data)
Expand Down Expand Up @@ -108,7 +109,7 @@ async def _async_handle_step(self, flow: Any, step_id: str,
self._progress.pop(flow.flow_id)

# We pass a copy of the result because we're mutating our version
entry = await self._async_finish_flow(dict(result))
entry = await self._async_finish_flow(flow.context, dict(result))

if result['type'] == RESULT_TYPE_CREATE_ENTRY:
result['result'] = entry
Expand All @@ -122,8 +123,8 @@ class FlowHandler:
flow_id = None
hass = None
handler = None
source = None
cur_step = None
context = None

# Set by _async_create_flow callback
init_step = 'init'
Expand Down Expand Up @@ -156,7 +157,6 @@ def async_create_entry(self, *, title: str, data: Dict) -> Dict:
'handler': self.handler,
'title': title,
'data': data,
'source': self.source,
}

@callback
Expand Down
5 changes: 3 additions & 2 deletions tests/components/cast/test_init.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Tests for the Cast config flow."""
from unittest.mock import patch

from homeassistant import data_entry_flow
from homeassistant import config_entries, data_entry_flow
from homeassistant.setup import async_setup_component
from homeassistant.components import cast

Expand All @@ -15,7 +15,8 @@ async def test_creating_entry_sets_up_media_player(hass):
MockDependency('pychromecast', 'discovery'), \
patch('pychromecast.discovery.discover_chromecasts',
return_value=True):
result = await hass.config_entries.flow.async_init(cast.DOMAIN)
result = await hass.config_entries.flow.async_init(
cast.DOMAIN, context={'source': config_entries.SOURCE_USER})
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY

await hass.async_block_till_done()
Expand Down
4 changes: 1 addition & 3 deletions tests/components/config/test_config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def async_step_user(self, user_input=None):
'handler': 'test',
'title': 'Test Entry',
'type': 'create_entry',
'source': 'user',
'version': 1,
}

Expand Down Expand Up @@ -264,7 +263,6 @@ def async_step_account(self, user_input=None):
'type': 'create_entry',
'title': 'user-title',
'version': 1,
'source': 'user',
}


Expand Down Expand Up @@ -295,7 +293,7 @@ def async_step_account(self, user_input=None):
{
'flow_id': form['flow_id'],
'handler': 'test',
'source': 'hassio'
'context': {'source': 'hassio'}
}
]

Expand Down
5 changes: 3 additions & 2 deletions tests/components/sonos/test_init.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Tests for the Sonos config flow."""
from unittest.mock import patch

from homeassistant import data_entry_flow
from homeassistant import config_entries, data_entry_flow
from homeassistant.setup import async_setup_component
from homeassistant.components import sonos

Expand All @@ -13,7 +13,8 @@ async def test_creating_entry_sets_up_media_player(hass):
with patch('homeassistant.components.media_player.sonos.async_setup_entry',
return_value=mock_coro(True)) as mock_setup, \
patch('soco.discover', return_value=True):
result = await hass.config_entries.flow.async_init(sonos.DOMAIN)
result = await hass.config_entries.flow.async_init(
sonos.DOMAIN, context={'source': config_entries.SOURCE_USER})
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY

await hass.async_block_till_done()
Expand Down
3 changes: 2 additions & 1 deletion tests/helpers/test_config_entry_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ async def test_user_init_trumps_discovery(hass, flow_conf):
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM

# User starts flow
result = await hass.config_entries.flow.async_init('test', data={})
result = await hass.config_entries.flow.async_init(
'test', context={'source': config_entries.SOURCE_USER}, data={})
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY

# Discovery flow has been aborted
Expand Down
9 changes: 6 additions & 3 deletions tests/test_config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def async_step_user(self, user_input=None):
})

with patch.dict(config_entries.HANDLERS, {'comp': TestFlow, 'beer': 5}):
yield from manager.flow.async_init('comp')
yield from manager.flow.async_init(
'comp', context={'source': config_entries.SOURCE_USER})
yield from hass.async_block_till_done()

assert len(mock_setup_entry.mock_calls) == 1
Expand Down Expand Up @@ -171,7 +172,8 @@ def async_step_user(self, user_input=None):
)

with patch.dict(config_entries.HANDLERS, {'test': TestFlow}):
await hass.config_entries.flow.async_init('test')
await hass.config_entries.flow.async_init(
'test', context={'source': config_entries.SOURCE_USER})

class Test2Flow(data_entry_flow.FlowHandler):
VERSION = 3
Expand All @@ -187,7 +189,8 @@ def async_step_user(self, user_input=None):

with patch('homeassistant.config_entries.HANDLERS.get',
return_value=Test2Flow):
await hass.config_entries.flow.async_init('test')
await hass.config_entries.flow.async_init(
'test', context={'source': config_entries.SOURCE_USER})

# To trigger the call_later
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
Expand Down
6 changes: 4 additions & 2 deletions tests/test_data_entry_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ async def async_create_flow(handler_name, *, context, data):
if context is not None else 'user_input'
return flow

async def async_add_entry(result):
async def async_add_entry(context, result):
if (result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY):
result['source'] = context.get('source') \
if context is not None else 'user'
entries.append(result)

manager = data_entry_flow.FlowManager(
Expand Down Expand Up @@ -168,7 +170,7 @@ async def async_step_init(self, user_input=None):
assert entry['handler'] == 'test'
assert entry['title'] == 'Test Title'
assert entry['data'] == 'Test Data'
assert entry['source'] == 'user_input'
assert entry['source'] == 'user'


async def test_discovery_init_flow(manager):
Expand Down