From 39ad02b906b784bfb162d6bb7af3d1f7fd4e8146 Mon Sep 17 00:00:00 2001 From: William Kray Date: Thu, 10 Apr 2025 13:52:52 -0700 Subject: [PATCH] move verification state management to database table to survive restarts --- community/bot.py | 85 +++++++++++++++++++++++++++++++++++++++++++++--- community/db.py | 14 ++++++++ 2 files changed, 94 insertions(+), 5 deletions(-) diff --git a/community/bot.py b/community/bot.py index 097330b..cb5659c 100644 --- a/community/bot.py +++ b/community/bot.py @@ -8,6 +8,7 @@ import fnmatch import asyncio import random import asyncpg.exceptions +from datetime import datetime from mautrix.client import ( Client, @@ -90,6 +91,8 @@ class CommunityBot(Plugin): self.client.add_dispatcher(MembershipEventDispatcher) # Start background redaction task self._redaction_tasks = asyncio.create_task(self._redaction_loop()) + # Clean up stale verification states + await self.cleanup_stale_verification_states() async def stop(self) -> None: if self._redaction_tasks: @@ -925,13 +928,14 @@ class CommunityBot(Plugin): verification_phrase = random.choice(self.config["verification_phrases"]) # Store verification state - self._verification_states[dm_room] = { + verification_state = { "user": evt.sender, "target_room": evt.room_id, "phrase": verification_phrase, "attempts": self.config["verification_attempts"], "required_level": required_level } + await self.store_verification_state(dm_room, verification_state) # Send greeting greeting = self.config["verification_message"].format( @@ -950,10 +954,10 @@ class CommunityBot(Plugin): if evt.sender == self.client.mxid: return - if evt.room_id not in self._verification_states: + state = await self.get_verification_state(evt.room_id) + if not state: return - state = self._verification_states[evt.room_id] user_phrase = evt.content.body.strip().lower() expected_phrase = state["phrase"].lower() @@ -984,7 +988,7 @@ class CommunityBot(Plugin): ) finally: await self.client.leave_room(evt.room_id) - del self._verification_states[evt.room_id] + await self.delete_verification_state(evt.room_id) else: state["attempts"] -= 1 if state["attempts"] <= 0: @@ -998,8 +1002,9 @@ class CommunityBot(Plugin): f"User verification failed for {evt.sender} in room {evt.room_id}, you may need to manually verify them." ) await self.client.leave_room(evt.room_id) - del self._verification_states[evt.room_id] + await self.delete_verification_state(evt.room_id) else: + await self.store_verification_state(evt.room_id, state) await self.client.send_notice( evt.room_id, f"Phrase does not match, you have {state['attempts']} tries remaining." @@ -1946,6 +1951,76 @@ class CommunityBot(Plugin): self.log.error(error_msg) await evt.respond(error_msg, edits=msg) + async def store_verification_state(self, dm_room_id: str, state: dict) -> None: + """Store verification state in the database.""" + await self.database.execute( + """INSERT OR REPLACE INTO verification_states + (dm_room_id, user_id, target_room_id, verification_phrase, attempts_remaining, required_power_level) + VALUES ($1, $2, $3, $4, $5, $6)""", + dm_room_id, + state["user"], + state["target_room"], + state["phrase"], + state["attempts"], + state["required_level"] + ) + + async def get_verification_state(self, dm_room_id: str) -> Optional[dict]: + """Retrieve verification state from the database.""" + row = await self.database.fetchrow( + "SELECT * FROM verification_states WHERE dm_room_id = $1", + dm_room_id + ) + if not row: + return None + return { + "user": row["user_id"], + "target_room": row["target_room_id"], + "phrase": row["verification_phrase"], + "attempts": row["attempts_remaining"], + "required_level": row["required_power_level"] + } + + async def delete_verification_state(self, dm_room_id: str) -> None: + """Delete verification state from the database.""" + await self.database.execute( + "DELETE FROM verification_states WHERE dm_room_id = $1", + dm_room_id + ) + + async def cleanup_stale_verification_states(self) -> None: + """Clean up verification states that are no longer valid.""" + # Get all verification states + states = await self.database.fetch("SELECT * FROM verification_states") + + for state in states: + try: + # Check if DM room still exists and bot is still in it + try: + await self.client.get_state_event(state["dm_room_id"], EventType.ROOM_MEMBER, self.client.mxid) + except Exception: + # Bot is not in the DM room anymore, state is stale + await self.delete_verification_state(state["dm_room_id"]) + continue + + # Check if user is still in the target room + try: + await self.client.get_state_event(state["target_room_id"], EventType.ROOM_MEMBER, state["user_id"]) + except Exception: + # User is not in the target room anymore, state is stale + await self.delete_verification_state(state["dm_room_id"]) + continue + + # Check if verification is too old (older than 24 hours) + if (datetime.now() - state["created_at"]).total_seconds() > 86400: + await self.delete_verification_state(state["dm_room_id"]) + continue + + except Exception as e: + self.log.error(f"Error checking verification state {state['dm_room_id']}: {e}") + # If we can't check the state, assume it's stale + await self.delete_verification_state(state["dm_room_id"]) + @classmethod def get_db_upgrade_table(cls) -> None: return upgrade_table diff --git a/community/db.py b/community/db.py index 99f97df..7f059d1 100644 --- a/community/db.py +++ b/community/db.py @@ -22,3 +22,17 @@ async def upgrade_v2(conn: Connection) -> None: room_id TEXT NOT NULL )""" ) + +@upgrade_table.register(description="Add verification states table") +async def upgrade_v3(conn: Connection) -> None: + await conn.execute( + """CREATE TABLE verification_states ( + dm_room_id TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + target_room_id TEXT NOT NULL, + verification_phrase TEXT NOT NULL, + attempts_remaining INTEGER NOT NULL, + required_power_level INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + )""" + )