Skip to content

Commit

Permalink
Add ability to specify ssl_context for the realtime connection
Browse files Browse the repository at this point in the history
  • Loading branch information
functionpointer committed Aug 8, 2024
1 parent c6a5af4 commit abb4150
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
6 changes: 5 additions & 1 deletion tibber/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import datetime as dt
import logging
import zoneinfo
from ssl import SSLContext
from typing import Any

import aiohttp
Expand Down Expand Up @@ -35,6 +36,7 @@ def __init__(
websession: aiohttp.ClientSession | None = None,
time_zone: dt.tzinfo | None = None,
user_agent: str | None = None,
ssl: SSLContext | bool = True,
):
"""Initialize the Tibber connection.
Expand All @@ -43,10 +45,11 @@ def __init__(
:param websession: The websession to use when communicating with the Tibber API.
:param time_zone: The time zone to display times in and to use.
:param user_agent: User agent identifier for the platform running this. Required if websession is None.
:param ssl: SSLContext to use.
"""

if websession is None:
websession = aiohttp.ClientSession()
websession = aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl=ssl))
elif user_agent is None:
user_agent = websession.headers.get(aiohttp.hdrs.USER_AGENT)
if user_agent is None:
Expand All @@ -60,6 +63,7 @@ def __init__(
self._access_token,
self.timeout,
self._user_agent,
ssl=ssl,
)

self.time_zone: dt.tzinfo = time_zone or zoneinfo.ZoneInfo("UTC")
Expand Down
4 changes: 4 additions & 0 deletions tibber/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import random
from typing import Any
from ssl import SSLContext

from gql import Client

Expand All @@ -26,6 +27,7 @@ def __init__(
access_token: str,
timeout: int,
user_agent: str,
ssl: SSLContext | bool
):
"""Initialize the Tibber connection.
Expand All @@ -36,6 +38,7 @@ def __init__(
self._access_token: str = access_token
self._timeout: int = timeout
self._user_agent: str = user_agent
self._ssl_context = ssl

self._sub_endpoint: str | None = None
self._homes: list[TibberHome] = []
Expand Down Expand Up @@ -90,6 +93,7 @@ def _create_sub_manager(self) -> None:
self.sub_endpoint,
self._access_token,
self._user_agent,
ssl=self._ssl_context,
),
)

Expand Down
4 changes: 3 additions & 1 deletion tibber/websocket_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import datetime as dt
import logging
from ssl import SSLContext

from gql.transport.exceptions import TransportClosed
from gql.transport.websockets import WebsocketsTransport
Expand All @@ -13,13 +14,14 @@
class TibberWebsocketsTransport(WebsocketsTransport):
"""Tibber websockets transport."""

def __init__(self, url: str, access_token: str, user_agent: str) -> None:
def __init__(self, url: str, access_token: str, user_agent: str, ssl: SSLContext | bool = True) -> None:
"""Initialize TibberWebsocketsTransport."""
super().__init__(
url=url,
init_payload={"token": access_token},
headers={"User-Agent": user_agent},
ping_interval=30,
ssl=ssl,
)
self._user_agent: str = user_agent
self._timeout: int = 90
Expand Down

0 comments on commit abb4150

Please sign in to comment.