diff --git a/tibber/__init__.py b/tibber/__init__.py index 1d5e800..21ed9b3 100644 --- a/tibber/__init__.py +++ b/tibber/__init__.py @@ -4,6 +4,7 @@ import datetime as dt import logging import zoneinfo +from ssl import SSLContext from typing import Any import aiohttp @@ -35,6 +36,7 @@ def __init__( websession: aiohttp.ClientSession | None = None, time_zone: dt.tzinfo | None = None, user_agent: str | None = None, + ssl: SSLContext | bool = True, ): """Initialize the Tibber connection. @@ -43,10 +45,11 @@ def __init__( :param websession: The websession to use when communicating with the Tibber API. :param time_zone: The time zone to display times in and to use. :param user_agent: User agent identifier for the platform running this. Required if websession is None. + :param ssl: SSLContext to use. """ if websession is None: - websession = aiohttp.ClientSession() + websession = aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=ssl)) elif user_agent is None: user_agent = websession.headers.get(aiohttp.hdrs.USER_AGENT) if user_agent is None: @@ -60,6 +63,7 @@ def __init__( self._access_token, self.timeout, self._user_agent, + ssl=ssl, ) self.time_zone: dt.tzinfo = time_zone or zoneinfo.ZoneInfo("UTC") diff --git a/tibber/realtime.py b/tibber/realtime.py index 7c2297b..0b78df5 100644 --- a/tibber/realtime.py +++ b/tibber/realtime.py @@ -4,6 +4,7 @@ import datetime as dt import logging import random +from ssl import SSLContext from typing import Any from gql import Client @@ -21,12 +22,7 @@ class TibberRT: """Class to handle real time connection with the Tibber api.""" # pylint: disable=too-many-instance-attributes - def __init__( - self, - access_token: str, - timeout: int, - user_agent: str, - ): + def __init__(self, access_token: str, timeout: int, user_agent: str, ssl: SSLContext | bool): """Initialize the Tibber connection. :param access_token: The access token to access the Tibber API with. @@ -36,6 +32,7 @@ def __init__( self._access_token: str = access_token self._timeout: int = timeout self._user_agent: str = user_agent + self._ssl_context = ssl self._sub_endpoint: str | None = None self._homes: list[TibberHome] = [] @@ -90,6 +87,7 @@ def _create_sub_manager(self) -> None: self.sub_endpoint, self._access_token, self._user_agent, + ssl=self._ssl_context, ), ) diff --git a/tibber/websocket_transport.py b/tibber/websocket_transport.py index d4c04f7..4734b76 100644 --- a/tibber/websocket_transport.py +++ b/tibber/websocket_transport.py @@ -3,6 +3,7 @@ import asyncio import datetime as dt import logging +from ssl import SSLContext from gql.transport.exceptions import TransportClosed from gql.transport.websockets import WebsocketsTransport @@ -13,13 +14,14 @@ class TibberWebsocketsTransport(WebsocketsTransport): """Tibber websockets transport.""" - def __init__(self, url: str, access_token: str, user_agent: str) -> None: + def __init__(self, url: str, access_token: str, user_agent: str, ssl: SSLContext | bool = True) -> None: """Initialize TibberWebsocketsTransport.""" super().__init__( url=url, init_payload={"token": access_token}, headers={"User-Agent": user_agent}, ping_interval=30, + ssl=ssl, ) self._user_agent: str = user_agent self._timeout: int = 90