Skip to content

Commit

Permalink
Address code review comment
Browse files Browse the repository at this point in the history
  • Loading branch information
awarecan authored and balloob committed Sep 12, 2018
1 parent f180604 commit f097fab
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 21 deletions.
4 changes: 2 additions & 2 deletions homeassistant/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,9 @@ async def async_remove_refresh_token(self,
@callback
def async_create_access_token(self,
refresh_token: models.RefreshToken,
used_by: Optional[str] = None) -> str:
remote_ip: Optional[str] = None) -> str:
"""Create a new access token."""
self._store.async_log_refresh_token_usage(refresh_token, used_by)
self._store.async_log_refresh_token_usage(refresh_token, remote_ip)

# pylint: disable=no-self-use
now = dt_util.utcnow()
Expand Down
17 changes: 10 additions & 7 deletions homeassistant/auth/auth_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,10 @@ async def async_get_refresh_token_by_token(
@callback
def async_log_refresh_token_usage(
self, refresh_token: models.RefreshToken,
used_by: Optional[str] = None) -> models.RefreshToken:
remote_ip: Optional[str] = None) -> models.RefreshToken:
"""Update refresh token last used information."""
refresh_token.last_used_at = dt_util.utcnow()
refresh_token.last_used_by = used_by
refresh_token.last_used_ip = remote_ip
self._async_schedule_save()
return refresh_token

Expand Down Expand Up @@ -252,9 +252,11 @@ async def _async_load(self) -> None:
token_type = models.TOKEN_TYPE_NORMAL

# old refresh_token don't have last_used_at (pre-0.78)
last_used_at = rt_dict.get('last_used_at')
if last_used_at is not None:
last_used_at_str = rt_dict.get('last_used_at')
if last_used_at_str:
last_used_at = dt_util.parse_datetime(last_used_at_str)
else:
last_used_at = None

token = models.RefreshToken(
id=rt_dict['id'],
Expand All @@ -270,7 +272,7 @@ async def _async_load(self) -> None:
token=rt_dict['token'],
jwt_key=rt_dict['jwt_key'],
last_used_at=last_used_at,
last_used_by=rt_dict.get('last_used_by'),
last_used_ip=rt_dict.get('last_used_ip'),
)
users[rt_dict['user_id']].refresh_tokens[token.id] = token

Expand Down Expand Up @@ -325,9 +327,10 @@ def _data_to_save(self) -> Dict:
refresh_token.access_token_expiration.total_seconds(),
'token': refresh_token.token,
'jwt_key': refresh_token.jwt_key,
'last_used_at': refresh_token.last_used_at.isoformat()
'last_used_at':
refresh_token.last_used_at.isoformat()
if refresh_token.last_used_at else None,
'last_used_by': refresh_token.last_used_by,
'last_used_ip': refresh_token.last_used_ip,
}
for user in self._users.values()
for refresh_token in user.refresh_tokens.values()
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/auth/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class RefreshToken:
default=attr.Factory(lambda: generate_secret(64)))

last_used_at = attr.ib(type=Optional[datetime], default=None)
last_used_by = attr.ib(type=Optional[str], default=None)
last_used_ip = attr.ib(type=Optional[str], default=None)


@attr.s(slots=True)
Expand Down
23 changes: 12 additions & 11 deletions tests/auth/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ async def test_saving_loading(hass, hass_storage):
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
manager.async_create_access_token(refresh_token, '192.168.0.1')
# the second refresh token will not be used
await manager.async_create_refresh_token(user, CLIENT_ID)
await manager.async_create_refresh_token(user, 'dummy-client')

await flush_store(manager._store._store)

Expand All @@ -291,16 +291,17 @@ async def test_saving_loading(hass, hass_storage):
assert len(users) == 1
assert users[0] == user
assert len(users[0].refresh_tokens) == 2
# verify the first refresh token
r_token = list(users[0].refresh_tokens.values())[0]
assert r_token.client_id == CLIENT_ID
assert r_token.last_used_at is not None
assert r_token.last_used_by == '192.168.0.1'
# verify the second refresh token
r_token = list(users[0].refresh_tokens.values())[1]
assert r_token.client_id == CLIENT_ID
assert r_token.last_used_at is None
assert r_token.last_used_by is None
for r_token in users[0].refresh_tokens.values():
if r_token.client_id == CLIENT_ID:
# verify the first refresh token
assert r_token.last_used_at is not None
assert r_token.last_used_ip == '192.168.0.1'
elif r_token.client_id == 'dummy-client':
# verify the second refresh token
assert r_token.last_used_at is None
assert r_token.last_used_ip is None
else:
assert False, 'Unknown client_id: %s' % r_token.client_id


async def test_cannot_retrieve_expired_access_token(hass):
Expand Down

0 comments on commit f097fab

Please sign in to comment.