Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FCM v1: use async version of google-auth and add HTTP proxy support #372

Merged
merged 9 commits into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/372.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
FCM v1: use async version of google-auth and add HTTP proxy support.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ dependencies = [
"attrs>=19.2.0",
"cryptography>=2.6.1",
"idna>=2.8",
"google-auth>=2.27.0",
"google-auth[aiohttp]>=2.27.0",
"jaeger-client>=4.0.0",
"matrix-common==1.3.0",
"opentracing>=2.2.0",
Expand All @@ -104,6 +104,7 @@ dev = [
"mypy-zope==1.0.1",
"towncrier",
"tox",
"google-auth-stubs==0.2.0",
"types-opentracing>=2.4.2",
"types-pyOpenSSL",
"types-PyYAML",
Expand Down
84 changes: 52 additions & 32 deletions sygnal/gcmpushkin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import logging
import os
import time
from enum import Enum
from io import BytesIO
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple

import google.auth.transport.requests
from google.oauth2 import service_account
import aiohttp
import google.auth.transport._aiohttp_requests
from google.auth._default_async import load_credentials_from_file
from google.oauth2._credentials_async import Credentials
MatMaul marked this conversation as resolved.
Show resolved Hide resolved
from opentracing import Span, logs, tags
from prometheus_client import Counter, Gauge, Histogram
from twisted.internet.defer import DeferredSemaphore
from twisted.internet.defer import Deferred, DeferredSemaphore
from twisted.web.client import FileBodyProducer, HTTPConnectionPool, readBody
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
Expand Down Expand Up @@ -153,6 +157,15 @@ def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]) -> None:
proxy_url_str=proxy_url,
)

# Use the fcm_options config dictionary as a foundation for the body;
# this lets the Sygnal admin choose custom FCM options
# (e.g. content_available).
self.base_request_body = self.get_config("fcm_options", dict, {})
if not isinstance(self.base_request_body, dict):
raise PushkinSetupException(
"Config field fcm_options, if set, must be a dictionary of options"
)

self.api_version = APIVersion.Legacy
version_str = self.get_config("api_version", str)
if not version_str:
Expand Down Expand Up @@ -180,19 +193,31 @@ def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]) -> None:
"Must configure `project_id` when using FCM api v1",
)

self.service_account_file = self.get_config("service_account_file", str)
if self.api_version is APIVersion.V1 and not self.service_account_file:
raise PushkinSetupException(
"Must configure `service_account_file` when using FCM api v1",
)
self.credentials: Credentials = None # type: ignore
MatMaul marked this conversation as resolved.
Show resolved Hide resolved

# Use the fcm_options config dictionary as a foundation for the body;
# this lets the Sygnal admin choose custom FCM options
# (e.g. content_available).
self.base_request_body = self.get_config("fcm_options", dict, {})
if not isinstance(self.base_request_body, dict):
raise PushkinSetupException(
"Config field fcm_options, if set, must be a dictionary of options"
if self.api_version is APIVersion.V1:
self.service_account_file = self.get_config("service_account_file", str)
if self.service_account_file:
try:
self.credentials, _ = load_credentials_from_file(
str(self.service_account_file),
scopes=AUTH_SCOPES,
)
except google.auth.exceptions.DefaultCredentialsError:
pass

if not self.credentials:
raise PushkinSetupException(
"Must configure valid `service_account_file` when using FCM api v1",
)

session = None
if proxy_url:
os.environ["HTTPS_PROXY"] = proxy_url
MatMaul marked this conversation as resolved.
Show resolved Hide resolved
session = aiohttp.ClientSession(trust_env=True, auto_decompress=False)

self.request = google.auth.transport._aiohttp_requests.Request(
MatMaul marked this conversation as resolved.
Show resolved Hide resolved
session=session
)

@classmethod
Expand Down Expand Up @@ -464,21 +489,19 @@ def _handle_v1_response(
f"Unknown GCM response code {response.code}"
)

def _get_access_token(self) -> str:
"""Retrieve a valid access token that can be used to authorize requests.
async def _get_auth_header(self) -> str:
"""Retrieve the auth header that can be used to authorize requests.

:return: Access token.
:return: Needed content of the `Authorization` header
"""
# TODO: Should we use the environment variable approach instead?
MatMaul marked this conversation as resolved.
Show resolved Hide resolved
# export GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json
# credentials, project = google.auth.default(scopes=AUTH_SCOPES)
credentials = service_account.Credentials.from_service_account_file(
str(self.service_account_file),
scopes=AUTH_SCOPES,
)
request = google.auth.transport.requests.Request()
credentials.refresh(request)
return credentials.token
if self.api_version is APIVersion.Legacy:
return "key=%s" % (self.api_key,)
else:
if not self.credentials.valid:
await Deferred.fromFuture(
asyncio.ensure_future(self.credentials.refresh(self.request))
)
return "Bearer %s" % self.credentials.token

async def _dispatch_notification_unlimited(
self, n: Notification, device: Device, context: NotificationContext
Expand Down Expand Up @@ -532,10 +555,7 @@ async def _dispatch_notification_unlimited(
"Content-Type": ["application/json"],
}

if self.api_version == APIVersion.Legacy:
headers["Authorization"] = ["key=%s" % (self.api_key,)]
elif self.api_version is APIVersion.V1:
headers["Authorization"] = ["Bearer %s" % (self._get_access_token(),)]
headers["Authorization"] = [await self._get_auth_header()]

body = self.base_request_body.copy()
body["data"] = data
Expand Down
12 changes: 9 additions & 3 deletions tests/test_gcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Tuple
from unittest.mock import MagicMock

from sygnal.gcmpushkin import GcmPushkin
from sygnal.gcmpushkin import GcmPushkin, PushkinSetupException

from tests import testutils
from tests.testutils import DummyResponse
Expand Down Expand Up @@ -86,12 +86,18 @@ class TestGcmPushkin(GcmPushkin):
"""

def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]):
super().__init__(name, sygnal, config)
self.preloaded_response = DummyResponse(0)
self.preloaded_response_payload: Dict[str, Any] = {}
self.last_request_body: Dict[str, Any] = {}
self.last_request_headers: Dict[AnyStr, List[AnyStr]] = {} # type: ignore[valid-type]
self.num_requests = 0
try:
super().__init__(name, sygnal, config)
except PushkinSetupException as e:
# for FCM v1 API we get an exception because the service account file
# does not exist, let's ignore it and move forward
if "service_account_file" not in str(e):
raise e

def preload_with_response(
self, code: int, response_payload: Dict[str, Any]
Expand All @@ -110,7 +116,7 @@ async def _perform_http_request( # type: ignore[override]
self.num_requests += 1
return self.preloaded_response, json.dumps(self.preloaded_response_payload)

def _get_access_token(self) -> str:
async def _get_auth_header(self) -> str:
return "token"


Expand Down
Loading