import json
import time
import os
import logging
import threading
from collections import defaultdict
from datetime import datetime, timezone

import requests
from flask import Flask, request, Response

app = Flask(__name__)

# ============================================================
# CONFIG
# ============================================================

# Strings (compared case-insensitively) that count as "enabled" for boolean env vars.
_TRUTHY = ("1", "true", "yes", "on")


def _env_flag(key, default):
    """Read a boolean environment variable.

    Returns ``default`` when the variable is unset; otherwise True only when
    the value (trimmed, case-insensitive) is one of ``_TRUTHY``.
    """
    val = os.getenv(key)
    if val is None:
        return default
    return val.strip().lower() in _TRUTHY


def _normalize_domain(domain):
    """Lowercase a server name and drop trailing dots; non-strings become ''."""
    if not isinstance(domain, str):
        return ""
    return domain.lower().rstrip(".")


class Config:
    """Runtime configuration, read once from environment variables."""

    REQUIRED_KEYS = ["TUWUNEL_URL", "LOCAL_DOMAIN"]

    def __init__(self):
        self.tuwunel_url = os.getenv("TUWUNEL_URL")
        self.local_domain = os.getenv("LOCAL_DOMAIN")

        self.admin_users = self._parse_list_env("ADMIN_USERS", [])
        self.admin_token = os.getenv("ADMIN_TOKEN", "")

        # Normalized the same way extract_domain() normalizes sender domains,
        # so that whitelist membership tests are case-insensitive.
        self.domain_whitelist = set(
            _normalize_domain(d)
            for d in self._parse_list_env("DOMAIN_WHITELIST", [])
        )

        self.block_external_dms = self._parse_bool("BLOCK_EXTERNAL_DMS", True)
        self.allow_room_creation = self._parse_bool("ALLOW_ROOM_CREATION", False)

        self.cache_ttl = int(os.getenv("CACHE_TTL_SECONDS", 604800))
        self.rate_limit_per_minute = int(os.getenv("RATE_LIMIT_PER_MINUTE", 20))
        self.http_timeout = int(os.getenv("HTTP_TIMEOUT", 5))

        self.fail_open = self._parse_bool("FAIL_OPEN", True)
        self.debug = self._parse_bool("DEBUG", False)

    def validate(self):
        """Return the names of required environment variables that are unset or empty."""
        return [k for k in self.REQUIRED_KEYS if not os.getenv(k)]

    def _parse_list_env(self, key, default=None):
        """Parse a comma-separated env var into a list of trimmed, non-empty items."""
        raw = os.getenv(key)
        if not raw:
            # Copy so callers never share (and mutate) the caller's default list.
            return list(default) if default else []
        return [x.strip() for x in raw.split(",") if x.strip()]

    def _parse_bool(self, key, default):
        """Parse a boolean env var; see _env_flag for the accepted values."""
        return _env_flag(key, default)


# ============================================================
# LOGGING
# ============================================================

# Uses the same parser as Config.debug so DEBUG=1/yes/on also enables DEBUG-level
# output (previously only the literal "true" did, which silenced debug_log()).
log_level = logging.DEBUG if _env_flag("DEBUG", False) else logging.INFO
logging.basicConfig(level=log_level)
logging.getLogger("werkzeug").setLevel(logging.ERROR)

logger = logging.getLogger("matrix-interceptor")


def now_iso():
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def log_event(event: str, **kwargs):
    """Log one INFO line of the form ``<iso-time> EVENT=<event> k=v ...``."""
    parts = [f"{now_iso()} EVENT={event}"]
    parts.extend(f"{k}={v}" for k, v in kwargs.items())
    logger.info(" ".join(parts))


def debug_log(title, data):
    """Log ``data`` as JSON at DEBUG level when config.debug is enabled."""
    if config.debug:
        logger.debug(f"{title}: {json.dumps(data, default=str)}")


# ============================================================
# INIT
# ============================================================

config = Config()

missing = config.validate()
if missing:
    logger.error(f"Missing env vars: {missing}")
    raise SystemExit(1)

# ============================================================
# STATE + CACHE
# ============================================================

# Maps user_id -> UNIX timestamp of the last time the user was remembered.
KNOWN_EXTERNAL_USERS = {}
# Maps sender domain -> timestamps of requests seen in the last 60 seconds.
# NOTE(review): keys come from request payloads and are never removed, so this
# can grow with the number of distinct sender domains — consider pruning.
RATE_LIMIT = defaultdict(list)

CACHE_FILE = "/app/cache/known_users.json"

# Serializes cache-file writes: save_cache() is reachable both from request
# handlers and from the periodic_seed background thread.
_CACHE_WRITE_LOCK = threading.Lock()


def load_cache():
    """Load KNOWN_EXTERNAL_USERS from CACHE_FILE if it exists.

    An unreadable, malformed or wrongly-shaped file results in an empty cache
    (logged as a warning) rather than an exception. Entries whose value is not
    a number are dropped so later timestamp arithmetic cannot fail.
    """
    global KNOWN_EXTERNAL_USERS

    if not os.path.exists(CACHE_FILE):
        return

    try:
        with open(CACHE_FILE, "r") as f:
            loaded = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError is a ValueError subclass.
        logger.warning(f"Ignoring unreadable cache file {CACHE_FILE}: {e}")
        KNOWN_EXTERNAL_USERS = {}
        return

    if not isinstance(loaded, dict):
        logger.warning(f"Ignoring cache file {CACHE_FILE}: expected a JSON object")
        KNOWN_EXTERNAL_USERS = {}
        return

    KNOWN_EXTERNAL_USERS = {
        user_id: ts
        for user_id, ts in loaded.items()
        if isinstance(ts, (int, float)) and not isinstance(ts, bool)
    }
    logger.info(f"Loaded cache with {len(KNOWN_EXTERNAL_USERS)} users")


def save_cache():
    """Persist KNOWN_EXTERNAL_USERS to CACHE_FILE; failures are logged, not raised."""
    tmp_path = f"{CACHE_FILE}.tmp"
    try:
        with _CACHE_WRITE_LOCK:
            os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)
            # Dump a copy so a concurrent insert cannot change the dict while
            # json.dump is iterating it.
            snapshot = dict(KNOWN_EXTERNAL_USERS)
            with open(tmp_path, "w") as f:
                json.dump(snapshot, f)
            # Write-then-rename so an interrupted write cannot leave a
            # truncated cache file behind.
            os.replace(tmp_path, CACHE_FILE)
    except Exception as e:
        logger.error(f"Failed to save cache: {e}")


# ============================================================
# HELPERS
# ============================================================

def extract_domain(user_id):
    """Return the normalized server-name part of a Matrix ID, or "unknown".

    Everything after the FIRST colon is treated as the server name, so an
    explicit port (``@user:example.org:8448``) is kept instead of being cut off.
    The result is lowercased and stripped of trailing dots.
    """
    if not isinstance(user_id, str) or ":" not in user_id:
        return "unknown"
    domain = _normalize_domain(user_id.split(":", 1)[1])
    return domain or "unknown"


def is_external(user_id):
    """Return True when the user's domain differs from the configured local domain."""
    return extract_domain(user_id) != _normalize_domain(config.local_domain)


def is_known_user(user_id):
    """Return True when the user is cached and the entry is younger than cache_ttl.

    Expired entries are removed and the cache file is rewritten.
    """
    ts = KNOWN_EXTERNAL_USERS.get(user_id)
    if not ts or not isinstance(ts, (int, float)):
        return False

    if time.time() - ts > config.cache_ttl:
        # pop() instead of del: another thread may have removed it already.
        KNOWN_EXTERNAL_USERS.pop(user_id, None)
        save_cache()
        return False

    return True


def remember_user(user_id):
    """Record (or refresh) the user in the cache with the current time and persist it."""
    KNOWN_EXTERNAL_USERS[user_id] = time.time()
    save_cache()


def is_local_room(room_id):
    """Return True when the room ID's server name equals the configured local domain."""
    local = _normalize_domain(config.local_domain)
    return bool(local) and extract_domain(room_id) == local


def get_role(user_id):
    """Return "admin" for users listed in config.admin_users, otherwise "user"."""
    return "admin" if user_id in config.admin_users else "user"


def matrix_error(status, errcode, message):
    """Build a JSON error response with ``errcode`` and ``error`` fields."""
    return Response(
        json.dumps({"errcode": errcode, "error": message}),
        status=status,
        mimetype="application/json",
    )


# ============================================================
# RATE LIMIT
# ============================================================

def is_rate_limited(domain):
    """Sliding 60-second window limiter keyed by sender domain.

    Returns True (without recording the request) when the domain already has
    config.rate_limit_per_minute requests in the window; otherwise records the
    request and returns False.
    """
    now = time.time()
    RATE_LIMIT[domain] = [
        t for t in RATE_LIMIT[domain]
        if now - t < 60
    ]

    if len(RATE_LIMIT[domain]) >= config.rate_limit_per_minute:
        return True

    RATE_LIMIT[domain].append(now)
    return False


# ============================================================
# SEED
# ============================================================

def _iter_local_room_members(timeout):
    """Yield ``(room_id, joined_members)`` for each local room the admin token has joined.

    Rooms whose member list cannot be fetched (non-200) are skipped. Network
    and JSON errors propagate to the caller.
    """
    headers = {"Authorization": f"Bearer {config.admin_token}"}

    rooms_res = requests.get(
        f"{config.tuwunel_url}/_matrix/client/v3/joined_rooms",
        headers=headers,
        timeout=timeout
    )
    if rooms_res.status_code != 200:
        logger.warning(f"joined_rooms lookup returned HTTP {rooms_res.status_code}")
        return

    for room_id in rooms_res.json().get("joined_rooms", []):
        if not is_local_room(room_id):
            continue

        members_res = requests.get(
            f"{config.tuwunel_url}/_matrix/client/v3/rooms/{room_id}/joined_members",
            headers=headers,
            timeout=timeout
        )
        if members_res.status_code != 200:
            continue

        yield room_id, members_res.json().get("joined", {})


def seed_known_users():
    """Remember every external member of the local rooms visible to the admin token.

    Does nothing when no admin token is configured. The cache file is written
    once per pass (not once per user). Errors are logged and swallowed so the
    periodic seeding thread keeps running.
    """
    if not config.admin_token:
        return

    seeded = 0

    try:
        now = time.time()
        for _room_id, joined in _iter_local_room_members(timeout=10):
            for user_id in joined:
                if is_external(user_id):
                    KNOWN_EXTERNAL_USERS[user_id] = now
                    seeded += 1

        logger.info(f"Seed refresh completed ({seeded} users)")

    except Exception as e:
        logger.error(f"Seed failed: {e}")

    finally:
        # Persist whatever was collected, even if a later room lookup failed.
        if seeded:
            save_cache()


def periodic_seed():
    """Run seed_known_users() forever, sleeping 300 seconds between passes."""
    while True:
        seed_known_users()
        time.sleep(300)


# ============================================================
# FALLBACK
# ============================================================

def is_user_in_local_rooms(user_id):
    """Return True when ``user_id`` is a joined member of any local room.

    Live lookup against the homeserver using the admin token; returns False
    when no token is configured or when any lookup fails.
    """
    if not config.admin_token:
        return False

    try:
        return any(
            user_id in joined
            for _room_id, joined in _iter_local_room_members(timeout=5)
        )
    except Exception as e:
        logger.warning(f"Local room membership lookup failed: {e}")
        return False


# ============================================================
# DM DETECTION
# ============================================================

def is_likely_dm_create(payload):
    """Heuristically decide whether a createRoom payload describes a direct message.

    True when ``is_direct`` is exactly True, when exactly one user is invited,
    or when ``visibility`` is "private" and the invite list is non-empty.
    """
    if payload.get("is_direct") is True:
        return True

    invite = payload.get("invite", [])
    if isinstance(invite, list) and len(invite) == 1:
        return True

    if payload.get("visibility") == "private" and invite:
        return True

    return False


def is_likely_dm_event(event):
    """Heuristically decide whether an invite event belongs to a direct message.

    True when the content has ``is_direct`` set to True, or when the stripped
    ``invite_room_state`` holds at most two ``m.room.member`` events (which
    includes the case where no room state is supplied at all).

    NOTE(review): not called anywhere in the visible source.
    """
    content = event.get("content", {})
    if content.get("is_direct") is True:
        return True

    unsigned = event.get("unsigned", {})
    state = unsigned.get("invite_room_state", []) or []

    members = [
        e for e in state
        if isinstance(e, dict) and e.get("type") == "m.room.member"
    ]
    return len(members) <= 2


# ============================================================
# ROUTES
# ============================================================

@app.route("/healthz")
def health():
    """Liveness probe."""
    return {"status": "ok"}


def _resolve_user_id(headers):
    """Resolve the caller's Matrix user ID from its access token.

    Asks the homeserver's ``whoami`` endpoint using the caller's own
    Authorization header. Returns "unknown" when there is no header or the
    lookup fails. The Authorization header itself is a bearer token, not a
    user ID, so it must never be used (or logged) as the actor.
    """
    auth = headers.get("Authorization")
    if not auth:
        return "unknown"

    try:
        res = requests.get(
            f"{config.tuwunel_url}/_matrix/client/v3/account/whoami",
            headers={"Authorization": auth},
            timeout=config.http_timeout
        )
        if res.status_code == 200:
            user_id = res.json().get("user_id")
            if isinstance(user_id, str) and user_id:
                return user_id
    except Exception as e:
        logger.warning(f"whoami lookup failed: {e}")

    return "unknown"


@app.route('/_matrix/client/v3/createRoom', methods=['POST'])
def create_room():
    """Gate room creation, then proxy allowed requests to the homeserver.

    A request is allowed when the caller is an admin, when room creation is
    enabled globally, or when the payload looks like a direct message.
    """
    payload = request.get_json(silent=True) or {}
    if not isinstance(payload, dict):
        return matrix_error(400, "M_BAD_JSON", "Request body must be a JSON object")

    user_id = _resolve_user_id(request.headers)
    domain = extract_domain(user_id)
    role = get_role(user_id)

    is_dm = is_likely_dm_create(payload)

    allowed = (
        user_id in config.admin_users
        or config.allow_room_creation
        or is_dm
    )

    if not allowed:
        log_event(
            "create_room_blocked",
            actor=user_id,
            domain=domain,
            role=role,
            reason="room_creation_disabled"
        )
        return matrix_error(403, "M_FORBIDDEN", "Room creation is disabled")

    if is_dm:
        log_event(
            "create_room_allowed",
            actor=user_id,
            domain=domain,
            role=role,
            room_type="dm"
        )

    return forward_request(
        "POST",
        f"{config.tuwunel_url}/_matrix/client/v3/createRoom",
        request.headers,
        payload
    )


@app.route('/_matrix/federation/v2/invite/<room_id>/<event_id>', methods=['PUT'])
def invite(room_id, event_id):
    """Filter incoming federation invites, then proxy accepted ones.

    Order of checks: malformed body (400), per-domain rate limit (429),
    domain whitelist (always accepted), and — when BLOCK_EXTERNAL_DMS is on —
    rejection (403) of external senders that are neither cached as known nor
    currently joined to a local room.
    """
    payload = request.get_json(force=True, silent=True)
    if not isinstance(payload, dict):
        return matrix_error(400, "M_BAD_JSON", "Request body must be a JSON object")

    event = payload.get("event", {})
    if not isinstance(event, dict):
        return matrix_error(400, "M_BAD_JSON", "Invalid invite event")

    sender = event.get("sender", "")
    if not isinstance(sender, str) or not sender:
        return matrix_error(400, "M_BAD_JSON", "Invite event has no sender")

    domain = extract_domain(sender)

    if is_rate_limited(domain):
        return matrix_error(429, "M_LIMIT_EXCEEDED", "Too many requests")

    whitelisted = domain in config.domain_whitelist
    external = is_external(sender)

    if not whitelisted and config.block_external_dms and external:
        # Cheap cache check first; fall back to a live membership lookup.
        if not is_known_user(sender) and not is_user_in_local_rooms(sender):
            log_event(
                "invite_blocked",
                actor=sender,
                domain=domain,
                reason="unknown_external_user"
            )
            return matrix_error(403, "M_FORBIDDEN", "Invites from unknown users are not accepted")

    # Only external senders belong in the known-external-users cache.
    if whitelisted or external:
        remember_user(sender)

    # NOTE(review): the path is rebuilt from Flask's decoded route variables;
    # confirm the upstream accepts it if it validates request signatures
    # against the originally requested URI.
    return forward_request(
        "PUT",
        f"{config.tuwunel_url}/_matrix/federation/v2/invite/{room_id}/{event_id}",
        request.headers,
        payload
    )


# ============================================================
# FORWARD
# ============================================================

def forward_request(method, url, headers, body):
    """Proxy a JSON request upstream and relay the upstream response.

    Only the Authorization header is forwarded. The upstream status, body and
    Content-Type are relayed to the caller. If the upstream call raises, the
    result is an empty 200 when FAIL_OPEN is enabled, otherwise an empty 502.

    NOTE(review): with FAIL_OPEN the caller receives a 200 even though the
    upstream never processed the request — confirm this is intended.
    """
    try:
        proxy_headers = {"Content-Type": "application/json"}

        if "Authorization" in headers:
            proxy_headers["Authorization"] = headers["Authorization"]

        res = requests.request(
            method=method,
            url=url,
            headers=proxy_headers,
            json=body,
            timeout=config.http_timeout
        )

        # Relay the upstream Content-Type; without it Flask would label the
        # JSON body as text/html.
        return Response(
            res.content,
            status=res.status_code,
            content_type=res.headers.get("Content-Type", "application/json")
        )

    except Exception as e:
        logger.error(f"proxy error: {e}")

        if config.fail_open:
            logger.warning("FAIL_OPEN → allowing request")
            return Response(status=200)

        return Response(status=502)


# ============================================================
# START
# ============================================================

if __name__ == '__main__':
    load_cache()
    seed_known_users()

    threading.Thread(target=periodic_seed, daemon=True).start()

    app.run(host='0.0.0.0', port=5000)