Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow system users to refresh tokens #15574

Merged
merged 1 commit into from
Jul 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions homeassistant/components/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,28 +252,27 @@ async def post(self, request):
hass = request.app['hass']
data = await request.post()

client_id = data.get('client_id')
if client_id is None or not indieauth.verify_client_id(client_id):
return self.json({
'error': 'invalid_request',
'error_description': 'Invalid client id',
}, status_code=400)

grant_type = data.get('grant_type')

if grant_type == 'authorization_code':
return await self._async_handle_auth_code(hass, client_id, data)
return await self._async_handle_auth_code(hass, data)

if grant_type == 'refresh_token':
return await self._async_handle_refresh_token(
hass, client_id, data)
return await self._async_handle_refresh_token(hass, data)

return self.json({
'error': 'unsupported_grant_type',
}, status_code=400)

async def _async_handle_auth_code(self, hass, client_id, data):
async def _async_handle_auth_code(self, hass, data):
"""Handle authorization code request."""
client_id = data.get('client_id')
if client_id is None or not indieauth.verify_client_id(client_id):
return self.json({
'error': 'invalid_request',
'error_description': 'Invalid client id',
}, status_code=400)

code = data.get('code')

if code is None:
Expand Down Expand Up @@ -309,8 +308,15 @@ async def _async_handle_auth_code(self, hass, client_id, data):
int(refresh_token.access_token_expiration.total_seconds()),
})

async def _async_handle_refresh_token(self, hass, client_id, data):
async def _async_handle_refresh_token(self, hass, data):
"""Handle authorization code request."""
client_id = data.get('client_id')
if client_id is not None and not indieauth.verify_client_id(client_id):
return self.json({
'error': 'invalid_request',
'error_description': 'Invalid client id',
}, status_code=400)

token = data.get('refresh_token')

if token is None:
Expand All @@ -320,11 +326,16 @@ async def _async_handle_refresh_token(self, hass, client_id, data):

refresh_token = await hass.auth.async_get_refresh_token(token)

if refresh_token is None or refresh_token.client_id != client_id:
if refresh_token is None:
return self.json({
'error': 'invalid_grant',
}, status_code=400)

if refresh_token.client_id != client_id:
return self.json({
'error': 'invalid_request',
}, status_code=400)

access_token = hass.auth.async_create_access_token(refresh_token)

return self.json({
Expand Down
65 changes: 65 additions & 0 deletions tests/components/auth/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,68 @@ async def test_cors_on_token(hass, aiohttp_client):
'origin': 'http://example.com'
})
assert resp.headers['Access-Control-Allow-Origin'] == 'http://example.com'


async def test_refresh_token_system_generated(hass, aiohttp_client):
"""Test that we can get access tokens for system generated user."""
client = await async_setup_auth(hass, aiohttp_client)
user = await hass.auth.async_create_system_user('Test System')
refresh_token = await hass.auth.async_create_refresh_token(user, None)

resp = await client.post('/auth/token', data={
'client_id': 'https://this-is-not-allowed-for-system-users.com/',
'grant_type': 'refresh_token',
'refresh_token': refresh_token.token,
})

assert resp.status == 400
result = await resp.json()
assert result['error'] == 'invalid_request'

resp = await client.post('/auth/token', data={
'grant_type': 'refresh_token',
'refresh_token': refresh_token.token,
})

assert resp.status == 200
tokens = await resp.json()
assert hass.auth.async_get_access_token(tokens['access_token']) is not None


async def test_refresh_token_different_client_id(hass, aiohttp_client):
"""Test that we verify client ID."""
client = await async_setup_auth(hass, aiohttp_client)
user = await hass.auth.async_create_user('Test User')
refresh_token = await hass.auth.async_create_refresh_token(user, CLIENT_ID)

# No client ID
resp = await client.post('/auth/token', data={
'grant_type': 'refresh_token',
'refresh_token': refresh_token.token,
})

assert resp.status == 400
result = await resp.json()
assert result['error'] == 'invalid_request'

# Different client ID
resp = await client.post('/auth/token', data={
'client_id': 'http://example-different.com',
'grant_type': 'refresh_token',
'refresh_token': refresh_token.token,
})

assert resp.status == 400
result = await resp.json()
assert result['error'] == 'invalid_request'

# Correct
resp = await client.post('/auth/token', data={
'client_id': CLIENT_ID,
'grant_type': 'refresh_token',
'refresh_token': refresh_token.token,
})

assert resp.status == 200
tokens = await resp.json()
assert hass.auth.async_get_access_token(tokens['access_token']) is not None