Skip to content

Commit

Permalink
Reverting strange changes
Browse files Browse the repository at this point in the history
  • Loading branch information
bachya committed Apr 13, 2019
1 parent dc71fbf commit d6ad354
Showing 1 changed file with 1 addition and 52 deletions.
53 changes: 1 addition & 52 deletions homeassistant/auth/permissions/util.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,16 @@
"""Helpers to deal with permissions."""
from functools import wraps

from typing import (
TYPE_CHECKING, Callable, Dict, List, Optional, Union, cast) # noqa: F401
from typing import Callable, Dict, List, Optional, Union, cast # noqa: F401

from homeassistant.exceptions import Unauthorized, UnknownUser

from .const import POLICY_CONTROL
from .models import PermissionLookup
from .types import CategoryType, SubCategoryDict, ValueType

if TYPE_CHECKING:
from homeassistant.core import HomeAssistant, Service, ServiceCall # noqa

LookupFunc = Callable[[PermissionLookup, SubCategoryDict, str],
Optional[ValueType]]
SubCatLookupType = Dict[str, LookupFunc]


def authorized_service_call(hass: 'HomeAssistant', domain: str) -> Callable:
"""Ensure user of a config entry-enabled service call has permission."""
def decorator(service: 'Service') -> Callable:
"""Decorate."""
@wraps(service)
async def check_permissions(call: 'ServiceCall') -> None:
"""Check user permission and raise before call if unauthorized."""
if not call.context.user_id:
return

user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(
context=call.context,
permission=POLICY_CONTROL
)

# If the user passes one or more entity IDs, check permissions
# there; otherwise, check permissions against entities registered
# to the domain:
if call.data.get('entity_id'):
if isinstance(call.data['entity'], str):
entities = [call.data['entity_id']]
else:
entities = call.data['entity_id']
else:
reg = await hass.helpers.entity_registry.async_get_registry()
entities = [
entity.entity_id for entity in reg.entities.values()
if entity.platform == domain
]

for entity_id in entities:
if user.permissions.check_entity(entity_id, POLICY_CONTROL):
return await service(call)

raise Unauthorized(
context=call.context,
permission=POLICY_CONTROL,
)
return check_permissions
return decorator


def lookup_all(perm_lookup: PermissionLookup, lookup_dict: SubCategoryDict,
object_id: str) -> ValueType:
"""Look up permission for all."""
Expand Down

0 comments on commit d6ad354

Please sign in to comment.