Skip to content

Commit

Permalink
Support login with matrixcore
Browse files Browse the repository at this point in the history
  • Loading branch information
nexy7574 committed Jan 11, 2025
1 parent 41be53b commit 46b4f73
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 56 deletions.
85 changes: 35 additions & 50 deletions src/niobot/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
CommandDisabledError,
CommandError,
GenericMatrixError,
LoginException,
MessageException,
NioBotException,
)
Expand Down Expand Up @@ -275,10 +274,12 @@ def __init__(
if self.auto_join_rooms:
self.log.info("Auto-joining rooms enabled.")
self.add_event_callback(self._auto_join_room_backlog_callback, nio.InviteMemberEvent) # type: ignore
self.core.on("room_invite")(self._auto_join_room_callback)

if self.auto_read_messages:
self.log.info("Auto-updating read receipts enabled.")
self.add_event_callback(self.update_read_receipts, nio.RoomMessage)
self.core.on("m.room.message")(self.update_read_receipts)

if import_keys:
keys_path, keys_password = import_keys
Expand Down Expand Up @@ -317,29 +318,7 @@ def __init__(
# New MatrixCore things
self.sync_filter: str | None = None

@property
def access_token(self) -> str | None:
return self.core.http.access_token or None

@access_token.setter
def access_token(self, value: str):
self.core.http.access_token = value

@property
def user_id(self) -> str:
return self.core.http.user_id

@user_id.setter
def user_id(self, value: str):
self.core.http.user_id = value

@property
def device_id(self) -> str:
return self.core.http.device_id

@device_id.setter
def device_id(self, value: str):
self.core.http.device_id = value
self.core.on("m.room.message")(self.process_message)

@property
def supported_server_versions(self) -> typing.List[typing.Tuple[int, int, int]]:
Expand Down Expand Up @@ -446,7 +425,7 @@ def dispatch(self, event_name: typing.Union[str, nio.Event], *args, **kwargs):
else:
self.log.debug("%r is not in registered events: %s", event_name, self._events)

def is_old(self, event: nio.Event) -> bool:
def is_old(self, event: matrixcore.ClientEventWithoutRoomID) -> bool:
"""Checks if an event was sent before the bot started. Always returns False when ignore_old_events is False"""
if not self.start_time:
self.log.warning("have not started yet, using relative age comparison")
Expand All @@ -455,7 +434,7 @@ def is_old(self, event: nio.Event) -> bool:
start_time = self.start_time
if self.ignore_old_events is False:
return False
return start_time - event.server_timestamp / 1000 > 0
return start_time - event.origin_server_ts / 1000 > 0

async def update_read_receipts(self, room: U[str, nio.MatrixRoom], event: nio.Event):
"""Moves the read indicator to the given event in the room.
Expand All @@ -480,11 +459,10 @@ async def update_read_receipts(self, room: U[str, nio.MatrixRoom], event: nio.Ev
else:
self.log.debug("Updated read receipts for %s to %s.", room, event_id)

async def process_message(self, room: nio.MatrixRoom, event: nio.RoomMessage) -> None:
async def process_message(self, room: matrixcore.Room, event: matrixcore.ClientEventWithoutRoomID) -> None:
"""Processes a message and runs the command it is trying to invoke if any."""
if self.start_time is None:
raise RuntimeError("Bot has not started yet!")
self.message_cache.append((room, event))
self.dispatch("message", room, event)
if not isinstance(event, nio.RoomMessageText):
self.log.debug("Ignoring non-text message %r", event.event_id)
Expand All @@ -493,11 +471,11 @@ async def process_message(self, room: nio.MatrixRoom, event: nio.RoomMessage) ->
self.log.debug("Ignoring message sent by self.")
return
if self.is_old(event):
age = self.start_time - event.server_timestamp / 1000
age = self.start_time - event.origin_server_ts / 1000
self.log.debug(f"Ignoring message sent {age:.0f} seconds before startup.")
return

body = event.body
body = event.content.get("body", "")
if self.process_edits and event.flattened().get("content.m.new_content"):
body = event.flattened()["content.m.new_content"].get("body", event.body)

Expand All @@ -523,7 +501,7 @@ def get_prefix(c: str) -> typing.Union[str, None]:
except IndexError:
self.log.info(
"Failed to parse message %r - message terminated early (was the content *just* the prefix?)",
event.body,
body,
)
return
command: typing.Optional[Command] = self.get_command(command_name)
Expand Down Expand Up @@ -576,7 +554,7 @@ def _task_callback(t: asyncio.Task):
try:
task = asyncio.create_task(
await command.invoke(context),
name=f"COMMAND_{event.sender}_{room.room_id}_{command.name}_{time.monotonic_ns()}",
name=f"COMMAND_{event.sender}_{room.id}_{command.name}_{time.monotonic_ns()}",
)
context._task = task
context._perf_timer = time.perf_counter()
Expand Down Expand Up @@ -1387,30 +1365,38 @@ async def start(
access_token: typing.Optional[str] = None,
) -> None:
"""Starts the bot, running the sync loop."""
if not any((password, access_token)):
raise ValueError("You must specify either a password or an access token.")
self.loop = asyncio.get_event_loop()
self.dispatch("event_loop_ready")
if self.__key_import:
self.log.info("Starting automatic key import")
await self.import_keys(*map(str, self.__key_import))

async with self.sync_store:
if password:
self.log.info("Logging in with password.")
await self.core.password_login(self.user_id, password, self.device_id)
elif access_token:
self.log.info("Logging in with existing access token.")
if self.store_path:
try:
self.load_store()
except FileNotFoundError:
self.log.warning("Failed to load store.")
except nio.LocalProtocolError as e:
self.log.warning("No store? %r", e, exc_info=e)
self.access_token = access_token
self.start_time = time.time()
else:
raise LoginException("You must specify either a password or an access token.")
if password is not None and access_token is None:
self.log.info("Logging in with password.")
r = await self.core.password_login(self.core.http.user_id, password, self.device_id)
self.access_token = r.access_token
self.user_id = r.user_id
elif access_token:
self.log.info("Logging in with existing access token.")
self.access_token = self.core.http.access_token = access_token

resp = await self.core.http.whoami()
self.user_id = resp.user_id
self.device_id = resp.device_id
self.core.http.device_id = resp.device_id

if self.store_path:
try:
self.load_store()
except FileNotFoundError:
self.log.warning("Failed to load store.")
except nio.LocalProtocolError as e:
self.log.warning("No store? %r", e, exc_info=e)
self.start_time = time.time()

async with self.sync_store:
if self.should_upload_keys:
self.log.info("Uploading encryption keys...")
response = await self.keys_upload()
Expand Down Expand Up @@ -1628,5 +1614,4 @@ def __init__(
use_fallback_replies=False,
force_initial_sync=False,
process_message_edits=True,
onsite_state_resolution=False,
)
7 changes: 3 additions & 4 deletions src/niobot/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings
from collections.abc import Callable

import matrixcore
import nio

from .context import Context
Expand Down Expand Up @@ -444,8 +445,8 @@ async def invoke(self, ctx: Context) -> typing.Coroutine:
def construct_context(
self,
client: "NioBot",
room: nio.MatrixRoom,
src_event: nio.RoomMessageText,
room: matrixcore.Room,
src_event: matrixcore.CustomBaseModel,
invoking_prefix: str,
meta: str,
cls: type = Context,
Expand All @@ -462,8 +463,6 @@ def construct_context(
:param cls: The class to construct the context with. Defaults to `Context`.
:return: The constructed Context.
"""
if not isinstance(src_event, (nio.RoomMessageText, nio.RoomMessageNotice)):
raise TypeError("src_event must be a textual event (i.e. m.text or m.notice).")
return cls(client, room, src_event, self, invoking_prefix=invoking_prefix, invoking_string=meta)


Expand Down
33 changes: 31 additions & 2 deletions src/niobot/utils/sync_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,35 @@ async def wrapper(self: "MatrixCoreSyncStore", *args, **kwargs):

return wrapper

@require_db
async def get_next_batch(self, user_id: str = None) -> str:
"""Returns the next batch token for the given user ID (or the client's user ID)"""
await self._init_db()
user_id = user_id or self._client.http.user_id
async with self._db.execute("SELECT next_batch FROM meta WHERE user_id=?", (user_id,)) as cursor:
result = await cursor.fetchone()
if result:
self.log.debug("Next batch record for %r: %r", user_id, result["next_batch"])
return result["next_batch"]
self.log.debug("No next batch stored, returning empty token.")
return ""

@require_db
async def set_next_batch(self, user_id: str, next_batch: str) -> None:
"""Sets the next batch token for the given user ID"""
await self._init_db()
self.log.debug("Setting next batch to %r for %r.", next_batch, user_id)
await self._db.execute(
"""
INSERT INTO meta (user_id, next_batch) VALUES (?, ?)
ON CONFLICT(user_id)
DO
UPDATE SET next_batch=excluded.next_batch
WHERE user_id=excluded.user_id
""",
(user_id, next_batch),
)

@require_db
async def get_invited_room(self, room_id: str) -> matrixcore.InvitedRoom | None:
"""Fetches an invited room from the database."""
Expand Down Expand Up @@ -796,7 +825,7 @@ async def generate_sync(self) -> matrixcore.SyncResponse:

async with self._db.cursor() as cursor:
self.log.debug("Constructing invited rooms...")
async for row in cursor.execute("SELECT room_id, invite_state FROM invited_rooms"):
async for row in await cursor.execute("SELECT room_id, invite_state FROM invited_rooms"):
invite_state = matrixcore.InvitedRoom(
invite_state=matrixcore.InvitedRoom.InviteState.model_validate(
await self.aloads(row["invite_state"])
Expand All @@ -805,7 +834,7 @@ async def generate_sync(self) -> matrixcore.SyncResponse:
rooms.invite[row["room_id"]] = matrixcore.InvitedRoom(invite_state=invite_state)
self.log.debug("Added room %r to invited rooms.", row["room_id"])

async for row in cursor.execute("SELECT room_id, state, account_data FROM joined_rooms"):
async for row in await cursor.execute("SELECT room_id, room_state, account_data FROM joined_rooms"):
room_state = matrixcore.State(
events=[
matrixcore.ClientEventWithoutRoomID.model_validate(x)
Expand Down

0 comments on commit 46b4f73

Please sign in to comment.