Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add atproto proxy and atproto labelers support #345

Merged
merged 3 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
atproto\_client.client.methods\_mixin.headers
=============================================

.. automodule:: atproto_client.client.methods_mixin.headers
:members:
:undoc-members:
:show-inheritance:
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ Submodules
.. toctree::
:maxdepth: 4

atproto_client.client.methods_mixin.headers
atproto_client.client.methods_mixin.session
atproto_client.client.methods_mixin.time
27 changes: 16 additions & 11 deletions docs/source/dm.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,35 +18,40 @@ PASSWORD = 'hunter2' # never hardcode your password in a real application


def main() -> None:
# create resolver instance with in-memory cache
id_resolver = IdResolver()

# create client instance and login
client = Client()
client.login(USERNAME, PASSWORD)
client.login(USERNAME, PASSWORD) # use App Password with access to Direct Messages!

# create client proxied to Bluesky Chat service
dm_client = client.with_bsky_chat_proxy()
# create shortcut to convo methods
dm = dm_client.chat.bsky.convo

convo_list = client.chat.bsky.convo.list_convos() # use limit and cursor to paginate
convo_list = dm.list_convos() # use limit and cursor to paginate
print(f'Your conversations ({len(convo_list.convos)}):')
for convo in convo_list.convos:
members = ', '.join(member.display_name for member in convo.members)
print(f'- ID: {convo.id} ({members})')

# create resolver instance with in-memory cache
id_resolver = IdResolver()
# resolve DID
chat_to = id_resolver.handle.resolve('test.marshal.dev')

# create or get conversation with chat_to
convo = client.chat.bsky.convo.get_convo_for_members(
convo = dm.get_convo_for_members(
models.ChatBskyConvoGetConvoForMembers.Params(members=[chat_to]),
)
print(f'\nConvo ID: {convo.convo.id}')
).convo

print(f'\nConvo ID: {convo.id}')
print('Convo members:')
for member in convo.convo.members:
for member in convo.members:
print(f'- {member.display_name} ({member.did})')

# send a message to the conversation
client.chat.bsky.convo.send_message(
dm.send_message(
models.ChatBskyConvoSendMessage.Data(
convo_id=convo.convo.id,
convo_id=convo.id,
message=models.ChatBskyConvoDefs.MessageInput(
text='Hello from Python SDK!',
),
Expand Down
27 changes: 16 additions & 11 deletions examples/advanced_usage/direct_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,40 @@


def main() -> None:
# create resolver instance with in-memory cache
id_resolver = IdResolver()

# create client instance and login
client = Client()
client.login(USERNAME, PASSWORD)
client.login(USERNAME, PASSWORD) # use App Password with access to Direct Messages!

# create client proxied to Bluesky Chat service
dm_client = client.with_bsky_chat_proxy()
# create shortcut to convo methods
dm = dm_client.chat.bsky.convo

convo_list = client.chat.bsky.convo.list_convos() # use limit and cursor to paginate
convo_list = dm.list_convos() # use limit and cursor to paginate
print(f'Your conversations ({len(convo_list.convos)}):')
for convo in convo_list.convos:
members = ', '.join(member.display_name for member in convo.members)
print(f'- ID: {convo.id} ({members})')

# create resolver instance with in-memory cache
id_resolver = IdResolver()
# resolve DID
chat_to = id_resolver.handle.resolve('test.marshal.dev')

# create or get conversation with chat_to
convo = client.chat.bsky.convo.get_convo_for_members(
convo = dm.get_convo_for_members(
models.ChatBskyConvoGetConvoForMembers.Params(members=[chat_to]),
)
print(f'\nConvo ID: {convo.convo.id}')
).convo

print(f'\nConvo ID: {convo.id}')
print('Convo members:')
for member in convo.convo.members:
for member in convo.members:
print(f'- {member.display_name} ({member.did})')

# send a message to the conversation
client.chat.bsky.convo.send_message(
dm.send_message(
models.ChatBskyConvoSendMessage.Data(
convo_id=convo.convo.id,
convo_id=convo.id,
message=models.ChatBskyConvoDefs.MessageInput(
text='Hello from Python SDK!',
),
Expand Down
18 changes: 17 additions & 1 deletion packages/atproto_client/client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
import typing as t
from asyncio import Lock

import typing_extensions as te
from atproto_core.uri import AtUri

from atproto_client import models
from atproto_client.client.async_raw import AsyncClientRaw
from atproto_client.client.methods_mixin import SessionMethodsMixin, TimeMethodsMixin
from atproto_client.client.methods_mixin.headers import HeadersConfigurationMethodsMixin
from atproto_client.client.methods_mixin.session import AsyncSessionDispatchMixin
from atproto_client.client.session import Session, SessionEvent, SessionResponse
from atproto_client.exceptions import LoginRequiredError
Expand All @@ -24,7 +26,9 @@
from atproto_client.request import Response


class AsyncClient(AsyncSessionDispatchMixin, SessionMethodsMixin, TimeMethodsMixin, AsyncClientRaw):
class AsyncClient(
AsyncSessionDispatchMixin, SessionMethodsMixin, TimeMethodsMixin, HeadersConfigurationMethodsMixin, AsyncClientRaw
):
"""High-level client for XRPC of ATProto."""

def __init__(self, base_url: t.Optional[str] = None, *args: t.Any, **kwargs: t.Any) -> None:
Expand Down Expand Up @@ -73,6 +77,18 @@ async def _import_session_string(self, session_string: str) -> Session:

return import_session

async def clone(self) -> te.Self:
"""Clone the client instance.

Used to customize atproto proxy and set of labeler services.

Returns:
Cloned client instance.
"""
cloned_client = super().clone()
cloned_client.me = self.me
return cloned_client

async def login(
self, login: t.Optional[str] = None, password: t.Optional[str] = None, session_string: t.Optional[str] = None
) -> 'models.AppBskyActorDefs.ProfileViewDetailed':
Expand Down
41 changes: 26 additions & 15 deletions packages/atproto_client/client/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import typing as t
from enum import Enum

import typing_extensions as te

from atproto_client.models.utils import get_model_as_dict, get_model_as_json
from atproto_client.request import AsyncRequest, Request, Response

Expand Down Expand Up @@ -54,19 +56,16 @@ def _handle_base_url(base_url: t.Optional[str] = None) -> str:
return base_url


class ClientBase:
"""Low-level methods are here."""

def __init__(self, base_url: t.Optional[str] = None, request: t.Optional[Request] = None) -> None:
if request is None:
request = Request()
class _ClientCommonMethodsMixin:
def clone(self) -> te.Self:
"""Clone the client instance.

self._request = request
self._base_url = _handle_base_url(base_url)
Used to customize atproto proxy and set of labeler services.

@property
def request(self) -> Request:
return self._request
Returns:
Cloned client instance.
"""
return type(self)(base_url=self._base_url, request=self.request.clone())

def update_base_url(self, base_url: t.Optional[str] = None) -> None:
"""Update XRPC base URL.
Expand All @@ -82,6 +81,21 @@ def update_base_url(self, base_url: t.Optional[str] = None) -> None:
def _build_url(self, nsid: str) -> str:
return f'{self._base_url}/{nsid}'


class ClientBase(_ClientCommonMethodsMixin):
"""Low-level methods are here."""

def __init__(self, base_url: t.Optional[str] = None, request: t.Optional[Request] = None) -> None:
if request is None:
request = Request()

self._request = request
self._base_url = _handle_base_url(base_url)

@property
def request(self) -> Request:
return self._request

def invoke_query(
self,
nsid: str,
Expand All @@ -108,7 +122,7 @@ def _invoke(self, invoke_type: InvokeType, **kwargs: t.Any) -> Response:
return self.request.post(**kwargs)


class AsyncClientBase:
class AsyncClientBase(_ClientCommonMethodsMixin):
"""Low-level methods are here."""

def __init__(self, base_url: t.Optional[str] = None, request: t.Optional[AsyncRequest] = None) -> None:
Expand All @@ -122,9 +136,6 @@ def __init__(self, base_url: t.Optional[str] = None, request: t.Optional[AsyncRe
def request(self) -> AsyncRequest:
return self._request

def _build_url(self, nsid: str) -> str:
return f'{self._base_url}/{nsid}'

async def invoke_query(
self,
nsid: str,
Expand Down
16 changes: 15 additions & 1 deletion packages/atproto_client/client/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import typing as t
from threading import Lock

import typing_extensions as te
from atproto_core.uri import AtUri

from atproto_client import models
from atproto_client.client.methods_mixin import SessionMethodsMixin, TimeMethodsMixin
from atproto_client.client.methods_mixin.headers import HeadersConfigurationMethodsMixin
from atproto_client.client.methods_mixin.session import SessionDispatchMixin
from atproto_client.client.raw import ClientRaw
from atproto_client.client.session import Session, SessionEvent, SessionResponse
Expand All @@ -17,7 +19,7 @@
from atproto_client.request import Response


class Client(SessionDispatchMixin, SessionMethodsMixin, TimeMethodsMixin, ClientRaw):
class Client(SessionDispatchMixin, SessionMethodsMixin, TimeMethodsMixin, HeadersConfigurationMethodsMixin, ClientRaw):
"""High-level client for XRPC of ATProto."""

def __init__(self, base_url: t.Optional[str] = None, *args: t.Any, **kwargs: t.Any) -> None:
Expand Down Expand Up @@ -66,6 +68,18 @@ def _import_session_string(self, session_string: str) -> Session:

return import_session

def clone(self) -> te.Self:
"""Clone the client instance.

Used to customize atproto proxy and set of labeler services.

Returns:
Cloned client instance.
"""
cloned_client = super().clone()
cloned_client.me = self.me
return cloned_client

def login(
self, login: t.Optional[str] = None, password: t.Optional[str] = None, session_string: t.Optional[str] = None
) -> 'models.AppBskyActorDefs.ProfileViewDetailed':
Expand Down
92 changes: 92 additions & 0 deletions packages/atproto_client/client/methods_mixin/headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import typing as t
from enum import Enum

import typing_extensions as te
from atproto_core.exceptions import AtProtocolError

_MAX_LABELERS_COUNT = 10
_ATPROTO_PROXY_HEADER = 'atproto-proxy'
_ATPROTO_ACCEPT_LABELERS_HEADER = 'atproto-accept-labelers'


class HeadersConfigurationMethodsMixin:
BSKY_CHAT_DID: t.ClassVar[t.Literal['did:web:api.bsky.chat']] = 'did:web:api.bsky.chat'

# Bluesky hardcoded labeler https://docs.bsky.app/blog/blueskys-moderation-architecture
BSKY_LABELER_DID: t.ClassVar[t.Literal['did:plc:ar7c4by46qjdydhdevvrndac']] = 'did:plc:ar7c4by46qjdydhdevvrndac'

class AtprotoServiceType(Enum):
"""The type of atproto service."""

ATPROTO_LABELER = 'atproto_labeler'
BSKY_CHAT = 'bsky_chat'

def with_proxy(self, service_type: t.Union[AtprotoServiceType, str], did: str) -> te.Self:
"""Get a new client instance with the atproto-proxy header configured.

Args:
service_type: The type of service.
did: The DID of the proxy.

Returns:
:obj:`self`: Configured client instance.
"""
cloned_client = self.clone()
cloned_client.configure_proxy_header(service_type, did)
return cloned_client

def with_labelers(self, labeler_dids: t.List[str]) -> te.Self:
"""Get a new client instance with the atproto-accept-labelers header configured.

Args:
labeler_dids: The DIDs of the labelers.

Returns:
:obj:`self`: Configured client instance.
"""
cloned_client = self.clone()
cloned_client.configure_labelers_header(labeler_dids)
return cloned_client

def configure_proxy_header(self, service_type: t.Union[AtprotoServiceType, str], did: str) -> None:
"""Configure the atproto-proxy header to be applied on requests.

Args:
service_type: The type of service.
did: The DID of the proxy.
"""
if not did.startswith('did:'):
raise AtProtocolError('Invalid DID format')

if isinstance(service_type, self.AtprotoServiceType):
service_type = service_type.value

proxy_header = f'{did}#{service_type}'
self.request.add_additional_header(_ATPROTO_PROXY_HEADER, proxy_header)

def configure_labelers_header(self, labeler_dids: t.List[str]) -> None:
"""Configure the atproto-labelers header to be applied on requests.

Args:
labeler_dids: The DIDs of the labelers.
"""
labelers_prepared = [f'{labeler_did};redact' for labeler_did in labeler_dids if labeler_did.startswith('did:')]
labelers_header_value = ','.join(labelers_prepared[:_MAX_LABELERS_COUNT])

self.request.add_additional_header(_ATPROTO_ACCEPT_LABELERS_HEADER, labelers_header_value)

def with_bsky_chat_proxy(self) -> te.Self:
"""Get a new client instance with the atproto-proxy header configured for bsky.chat.

Returns:
:obj:`self`: Configured client instance.
"""
return self.with_proxy(self.AtprotoServiceType.BSKY_CHAT, self.BSKY_CHAT_DID)

def with_bsky_labeler(self) -> te.Self:
"""Get a new client instance with the atproto-accept-labelers header configured for Bluesky Labeler.

Returns:
:obj:`self`: Configured client instance.
"""
return self.with_labelers([self.BSKY_LABELER_DID])
3 changes: 2 additions & 1 deletion packages/atproto_client/client/methods_mixin/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def _get_auth_headers(token: str) -> t.Dict[str, str]:
return {'Authorization': f'Bearer {token}'}

def _set_auth_headers(self, token: str) -> None:
self.request.set_additional_headers(self._get_auth_headers(token))
for header_name, header_value in self._get_auth_headers(token).items():
self.request.add_additional_header(header_name, header_value)

def _update_pds_endpoint(self, pds_endpoint: str) -> None:
self.update_base_url(pds_endpoint)
Expand Down
Loading
Loading