Skip to content

Commit

Permalink
Merge pull request #629 from freelawproject/threads-refresh-access-token
Browse files Browse the repository at this point in the history
feat(Threads): Refresh Threads access tokens automatically.
  • Loading branch information
mlissner authored Nov 14, 2024
2 parents ef2c97c + 523b160 commit e206fd4
Show file tree
Hide file tree
Showing 7 changed files with 251 additions and 8 deletions.
66 changes: 64 additions & 2 deletions bc/channel/models.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
import logging

from django.db import models
from django.urls import reverse
from redis.exceptions import LockError

from bc.core.models import AbstractDateTimeModel
from bc.core.utils.color import format_color_str
from bc.sponsorship.models import Sponsorship
from bc.users.models import User

from .utils.connectors.base import BaseAPIConnector
from ..core.utils.redis import make_redis_interface
from .utils.connectors.base import (
BaseAPIConnector,
RefreshableBaseAPIConnector,
)
from .utils.connectors.bluesky import BlueskyConnector
from .utils.connectors.masto import MastodonConnector, get_handle_parts
from .utils.connectors.threads import ThreadsConnector
from .utils.connectors.twitter import TwitterConnector

logger = logging.getLogger(__name__)

r = make_redis_interface("CACHE")


class Group(AbstractDateTimeModel):
name = models.CharField(
Expand Down Expand Up @@ -70,6 +81,7 @@ class Channel(AbstractDateTimeModel):
(BLUESKY, "Bluesky"),
(THREADS, "Threads"),
)
CHANNELS_TO_REFRESH = [THREADS]
service = models.PositiveSmallIntegerField(
help_text="Type of the service",
choices=CHANNELS,
Expand Down Expand Up @@ -107,7 +119,9 @@ class Channel(AbstractDateTimeModel):
blank=True,
)

def get_api_wrapper(self) -> BaseAPIConnector:
def get_api_wrapper(
self,
) -> BaseAPIConnector | RefreshableBaseAPIConnector:
match self.service:
case self.TWITTER:
return TwitterConnector(
Expand Down Expand Up @@ -145,6 +159,54 @@ def self_url(self):
f"Channel.self_url() not yet implemented for service {self.service}"
)

def validate_access_token(self):
"""
Validates and refreshes the access token for the channel if necessary.
This method implements a locking mechanism to avoid multiple tasks
from concurrently trying to validate the same token.
"""
if self.service not in self.CHANNELS_TO_REFRESH:
return
lock_key = self._get_refresh_lock_key()
lock = r.lock(lock_key, sleep=1, timeout=60)
blocking_timeout = 60
try:
# Use a blocking lock to wait until locking task is finished
lock.acquire(blocking=True, blocking_timeout=blocking_timeout)
# Then perform action to validate
self._refresh_access_token()
except LockError as e:
logger.error(
f"LockError while acquiring lock for channel {self}: {e}"
)
raise e
finally:
if not lock.owned():
return
try:
lock.release()
except Exception as e:
logger.error(f"Error releasing lock for channel {self}:\n{e}")

def _refresh_access_token(self):
api = self.get_api_wrapper()
try:
refreshed, access_token = api.validate_access_token()
if refreshed:
self.access_token = access_token
self.save()
except Exception as e:
logger.error(
f"Error when trying to refresh token for channel {self.pk}:\n{e}"
)

def _get_refresh_lock_key(self):
"""
Constructs the Redis key used for locking during access token refresh.
"""
return f"token_refresh_lock_{self.account_id}@{self.get_service_display()}"

def __str__(self) -> str:
if self.account:
return f"{self.pk}: {self.account}"
Expand Down
1 change: 1 addition & 0 deletions bc/channel/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def enqueue_text_status_for_channel(channel: Channel, text: str) -> None:
channel (Channel): The channel object.
text (str): Message for the new status.
"""
channel.validate_access_token()
api = channel.get_api_wrapper()
queue.enqueue(
api.add_status,
Expand Down
17 changes: 17 additions & 0 deletions bc/channel/utils/connectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,20 @@ def add_status(
int: The unique identifier for the new status.
"""
...


class RefreshableBaseAPIConnector(BaseAPIConnector, Protocol):
"""
Extends BaseAPIConnector to add logic to validate access tokens.
"""

def validate_access_token(self) -> tuple[bool, str]:
"""
Validates the access token and refreshes it if necessary.
Returns:
tuple[bool, str]: A tuple where the first element is a boolean
indicating if the token was refreshed, and the second element
is the current access token.
"""
...
16 changes: 15 additions & 1 deletion bc/channel/utils/connectors/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class ThreadsConnector:
"""
A connector for interfacing with the Threads API, which complies with
the BaseAPIConnector protocol.
the RefreshableBaseAPIConnector protocol.
"""

def __init__(
Expand All @@ -32,6 +32,20 @@ def get_api_object(self, _version=None) -> ThreadsAPI:
)
return api

def validate_access_token(self) -> tuple[bool, str]:
"""
Ensures that the access token used by the connector is up-to-date.
This method delegates the validation of the access token to the underlying
`ThreadsAPI` instance by checking the access token's expiration date and
refreshing it if necessary.
Returns:
tuple[bool, str]: A tuple where the first element is a boolean
indicating whether the token was refreshed, and the second element is the current access token.
"""
return self.api.validate_access_token()

def upload_media(self, media: bytes, _alt_text=None) -> str:
"""
Uploads media to public storage for Threads API compatibility.
Expand Down
112 changes: 111 additions & 1 deletion bc/channel/utils/connectors/threads_api/client.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import logging
import time
import uuid
from datetime import datetime, timedelta, timezone

import requests
from django.conf import settings

from bc.core.utils.images import convert_to_jpeg, resize_image
from bc.core.utils.redis import make_redis_interface
from bc.core.utils.s3 import put_object_in_bucket

logger = logging.getLogger(__name__)


_BASE_API_URL = "https://graph.threads.net/v1.0"

r = make_redis_interface("CACHE")


class ThreadsAPI:
"""
Expand Down Expand Up @@ -165,6 +168,113 @@ def attempt_post(
return None
return response

def validate_access_token(self) -> tuple[bool, str]:
"""
Validates the current access token and refreshes it if necessary.
This method checks the expiration date of the access token stored in the Redis cache.
If the expiration date is missing, expired, or will expire within two days,
it attempts to refresh the token by calling `refresh_access_token`.
Returns:
tuple[bool, str]: A tuple where the first element is a boolean
indicating whether the token was refreshed, and the second element is the current access token.
"""
refreshed = False

try:
cached_expiration_date = r.get(self._get_expiration_key())
except Exception as e:
logger.error(
f"Could not retrieve cached token, will attempt to refresh.\n"
f"Redis error: {e}"
)
return self.refresh_access_token(), self._access_token

if cached_expiration_date is None:
return self.refresh_access_token(), self._access_token

expiration_date = datetime.fromisoformat(str(cached_expiration_date))
delta = expiration_date - datetime.now(timezone.utc)
will_expire_soon = delta <= timedelta(days=2)

if will_expire_soon:
refreshed = self.refresh_access_token()

return refreshed, self._access_token

def refresh_access_token(self) -> bool:
"""
Refreshes the access token by making a request to the Threads API.
If the refresh is successful, it updates the access token and its expiration date in the cache.
Returns:
bool: `True` if the access token was successfully refreshed and updated; `False` otherwise.
"""
refresh_access_token_url = (
"https://graph.threads.net/refresh_access_token"
)
params = {
"grant_type": "th_refresh_token",
"access_token": self._access_token,
}
try:
response = requests.get(
refresh_access_token_url,
params=params,
timeout=10,
)
response.raise_for_status()
except requests.exceptions.RequestException as err:
logger.error(
f"Failed to refresh access token for Threads account {self._account_id}:\n"
f"{err}"
)
return False

data = response.json()
new_access_token = data.get("access_token")
expires_in = data.get("expires_in") # In seconds

if new_access_token is None or expires_in is None:
logger.error(
f"Missing 'access_token' or 'expires_in' in refresh access token response for Threads account {self._account_id}. "
f"If the issue persists, a new access token can be retrieved manually with the script again.\n"
f"Response data: {data}"
)
return False

self._access_token = new_access_token
self._set_token_expiration_in_cache(expires_in)

return True

def _set_token_expiration_in_cache(self, expires_in: int):
"""
Stores the access token's expiration date in the Redis cache.
Args:
expires_in (int): The number of seconds until the access token expires.
"""
delay = timedelta(seconds=expires_in)
expiration_date = (datetime.now(timezone.utc) + delay).isoformat()
key = self._get_expiration_key()
try:
r.set(
key,
expiration_date.encode("utf-8"),
ex=expires_in, # ensure the cache entry expires when the token does
)
except Exception as e:
logger.error(f"Could not set {key} in cache:\n{e}")

def _get_expiration_key(self) -> str:
"""
Returns the Redis key used for storing the access token's expiration date.
"""
return f"threads_token_expiration_{self._account_id}"

@staticmethod
def resize_and_upload_to_public_storage(media: bytes) -> str:
"""
Expand Down
3 changes: 3 additions & 0 deletions bc/subscription/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def enqueue_posts_for_new_case(
initial_complaint_link=initial_complaint_link,
)

channel.validate_access_token()
api = channel.get_api_wrapper()

sponsor_message = None
Expand Down Expand Up @@ -423,6 +424,8 @@ def make_post_for_webhook_event(
if sponsor_text and files:
files = add_sponsored_text_to_thumbnails(files, sponsor_text)

channel.validate_access_token()

api = channel.get_api_wrapper()
api_post_id = api.add_status(message, image, files)

Expand Down
Loading

0 comments on commit e206fd4

Please sign in to comment.