import json
import time
import os
import logging
import threading
from collections import defaultdict
from datetime import datetime, timezone
from threading import Lock

import requests
from flask import Flask, request, Response

app = Flask(__name__)


# ============================================================
# CONFIG
# ============================================================

class Config:
    """Runtime settings for the interceptor, read from environment variables.

    All values are read once, when the instance is created. ``validate()``
    reports which of ``REQUIRED_KEYS`` are unset or empty.
    """

    REQUIRED_KEYS = ["TUWUNEL_URL", "LOCAL_DOMAIN"]

    # Spellings accepted as "true" for every boolean environment variable.
    TRUTHY_VALUES = ("1", "true", "yes", "on")

    def __init__(self):
        upstream = os.getenv("TUWUNEL_URL")
        # Request URLs are built as f"{tuwunel_url}/_matrix/...", so a
        # trailing slash in the setting would produce "//_matrix/...".
        self.tuwunel_url = upstream.rstrip("/") if upstream else upstream
        self.local_domain = os.getenv("LOCAL_DOMAIN")

        self.admin_users = self._parse_list_env("ADMIN_USERS", [])
        self.admin_token = os.getenv("ADMIN_TOKEN", "")

        # Whitelist entries are compared against domains that have been
        # lower-cased and stripped of a trailing dot, so store them in the
        # same form; otherwise "Example.org" could never match.
        self.domain_whitelist = {
            entry.lower().rstrip(".")
            for entry in self._parse_list_env("DOMAIN_WHITELIST", [])
        }

        self.block_external_dms = self._parse_bool("BLOCK_EXTERNAL_DMS", True)
        self.allow_room_creation = self._parse_bool("ALLOW_ROOM_CREATION", False)

        self.cache_ttl = self._parse_int("CACHE_TTL_SECONDS", 604800)
        self.rate_limit_per_minute = self._parse_int("RATE_LIMIT_PER_MINUTE", 20)
        self.http_timeout = self._parse_int("HTTP_TIMEOUT", 5)

        self.fail_open = self._parse_bool("FAIL_OPEN", True)
        self.debug = self._parse_bool("DEBUG", False)

        self.bot_webhook_url = os.getenv("BOT_WEBHOOK_URL", "")
        self.bot_webhook_secret = os.getenv("BOT_WEBHOOK_SECRET", "")

    def validate(self):
        """Return the names of required environment variables that are unset or empty."""
        return [k for k in self.REQUIRED_KEYS if not os.getenv(k)]

    @staticmethod
    def is_truthy(val):
        """Return True when *val* is one of the accepted "true" spellings."""
        return val.strip().lower() in Config.TRUTHY_VALUES

    def _parse_list_env(self, key, default=None):
        """Split a comma-separated variable into stripped, non-empty items."""
        raw = os.getenv(key)
        if not raw:
            return default or []
        return [x.strip() for x in raw.split(",") if x.strip()]

    def _parse_bool(self, key, default):
        """Read a boolean variable; *default* applies only when it is unset."""
        val = os.getenv(key)
        if val is None:
            return default
        return self.is_truthy(val)

    def _parse_int(self, key, default):
        """Read an integer variable.

        Returns *default* when the variable is unset or blank.

        Raises:
            ValueError: the variable is set but is not an integer; the
                message names the offending variable.
        """
        raw = os.getenv(key)
        if raw is None or not raw.strip():
            return default
        try:
            return int(raw)
        except ValueError:
            raise ValueError(
                f"Environment variable {key} must be an integer, got {raw!r}"
            ) from None


# ============================================================
# LOGGING
# ============================================================

# Use the same truthy spellings as Config._parse_bool so that DEBUG=1 (or
# yes/on) enables debug logging as well as config.debug.
log_level = logging.DEBUG if Config.is_truthy(os.getenv("DEBUG", "false")) else logging.INFO
logging.basicConfig(level=log_level, format="%(message)s")
logging.getLogger("werkzeug").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("gunicorn.access").setLevel(logging.WARNING)

logger = logging.getLogger("matrix-interceptor")


def now_iso():
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def log_event(event: str, **kwargs):
    """Log one ``EVENT=<name> key=value ...`` line at INFO level.

    Values can come from request payloads, so CR/LF characters are escaped
    to keep a single event on a single log line.
    """
    def _one_line(value):
        return str(value).replace("\r", "\\r").replace("\n", "\\n")

    base = f"{now_iso()} EVENT={event}"
    details = " ".join(f"{k}={_one_line(v)}" for k, v in kwargs.items())
    logger.info(f"{base} {details}")


# ============================================================
# INIT CONFIG
# ============================================================

config = Config()
missing = config.validate()

if missing:
    logger.error(f"Missing env vars: {missing}")
    raise SystemExit(1)


# ============================================================
# STATE & CACHE
# ============================================================

KNOWN_EXTERNAL_USERS = {}          # user_id -> time.time() of last sighting
RATE_LIMIT = defaultdict(list)     # "<domain>:<sender>" -> recent timestamps
METRICS = defaultdict(int)         # counter name -> count
DM_NOTIFY = defaultdict(list)      # "<sender>:<room_id>" -> notify timestamps

METRICS_LOCK = Lock()
CACHE_LOCK = Lock()

CACHE_FILE = "/app/cache/known_users.json"
CACHE_DIRTY = False                # True when the in-memory cache has unsaved changes


# ============================================================
# CACHE HANDLING
# ============================================================

def load_cache():
    """Load KNOWN_EXTERNAL_USERS from CACHE_FILE.

    Creates an empty cache file when none exists. Any read/parse problem is
    logged and leaves an empty in-memory cache; entries whose value is not a
    number are dropped because callers do arithmetic on the timestamps.
    """
    global KNOWN_EXTERNAL_USERS

    try:
        os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)

        if not os.path.exists(CACHE_FILE):
            with open(CACHE_FILE, "w", encoding="utf-8") as f:
                json.dump({}, f)
            logger.info("Initialized empty cache file")
            return

        with open(CACHE_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError.
        logger.error(f"Failed to load cache, starting empty: {e}")
        KNOWN_EXTERNAL_USERS = {}
        return

    if not isinstance(data, dict):
        logger.error("Cache file does not contain a JSON object, starting empty")
        KNOWN_EXTERNAL_USERS = {}
        return

    KNOWN_EXTERNAL_USERS = {
        user_id: ts
        for user_id, ts in data.items()
        if isinstance(ts, (int, float)) and not isinstance(ts, bool)
    }
    logger.info(f"Loaded cache with {len(KNOWN_EXTERNAL_USERS)} users")


def save_cache():
    """Persist KNOWN_EXTERNAL_USERS to CACHE_FILE.

    The data is written to a per-process temporary file and then moved over
    the real file with os.replace, so an interrupted write cannot leave a
    truncated cache behind. Failures are logged, never raised.
    """
    global CACHE_DIRTY

    tmp_path = f"{CACHE_FILE}.{os.getpid()}.tmp"
    try:
        with CACHE_LOCK:
            # Clear the flag before taking the snapshot: a change made while
            # we are writing sets it again and is picked up by the next save.
            CACHE_DIRTY = False
            snapshot = dict(KNOWN_EXTERNAL_USERS)

            os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(snapshot, f)
            os.replace(tmp_path, CACHE_FILE)
    except Exception as e:
        CACHE_DIRTY = True  # nothing was persisted; retry on the next cycle
        logger.error(f"Failed to save cache: {e}")


def periodic_cache_save():
    """Loop forever, saving the cache every 30 seconds when it has changed."""
    while True:
        if CACHE_DIRTY:
            save_cache()
        time.sleep(30)
# ============================================================
# HELPERS
# ============================================================

def extract_domain(user_id):
    """Return the lower-cased server part of a Matrix ID.

    The ID is split on ":" and the second piece is returned without a
    trailing dot. Returns "unknown" when *user_id* is not a string or has
    no ":".
    """
    if not isinstance(user_id, str):
        return "unknown"
    parts = user_id.split(":")
    if len(parts) < 2:
        return "unknown"
    return parts[1].lower().rstrip(".")


def _local_domain():
    """Return config.local_domain in the same normal form extract_domain produces."""
    return (config.local_domain or "").strip().lower().rstrip(".")


def is_external(user_id):
    """Return True when *user_id* does not belong to the local domain."""
    return extract_domain(user_id) != _local_domain()


def is_known_user(user_id):
    """Return True when *user_id* is cached and its entry is younger than cache_ttl.

    An expired entry is removed and the cache is marked dirty.
    """
    global CACHE_DIRTY

    ts = KNOWN_EXTERNAL_USERS.get(user_id)
    if not ts:
        return False

    if time.time() - ts > config.cache_ttl:
        # pop() instead of del: another thread may have removed it already.
        KNOWN_EXTERNAL_USERS.pop(user_id, None)
        CACHE_DIRTY = True
        return False

    return True


def remember_user(user_id):
    """Record *user_id* as known as of now and persist the cache immediately."""
    global CACHE_DIRTY
    KNOWN_EXTERNAL_USERS[user_id] = time.time()
    CACHE_DIRTY = True
    save_cache()


def is_local_room(room_id):
    """Return True when the server part of *room_id* is the local domain."""
    if not isinstance(room_id, str) or ":" not in room_id:
        return False
    return extract_domain(room_id) == _local_domain()


def get_role(user_id):
    """Return "admin" for users listed in config.admin_users, else "user"."""
    return "admin" if user_id in config.admin_users else "user"


def notify_bot(event_type, sender, room_id):
    """POST an event to the bot webhook, if one is configured.

    Best effort: any failure is logged and swallowed.
    """
    if not config.bot_webhook_url:
        return

    try:
        headers = {}
        if config.bot_webhook_secret:
            headers["Authorization"] = f"Bearer {config.bot_webhook_secret}"

        requests.post(
            config.bot_webhook_url,
            json={
                "type": event_type,
                "sender": sender,
                "room_id": room_id,
                "timestamp": time.time()
            },
            headers=headers,
            timeout=2
        )
    except Exception as e:
        logger.error(f"Failed to notify bot: {e}")


def _matrix_error(status, errcode, message):
    """Build a JSON error response with ``errcode`` and ``error`` fields."""
    return Response(
        json.dumps({"errcode": errcode, "error": message}),
        status=status,
        mimetype="application/json",
    )


# ============================================================
# RATE LIMIT
# ============================================================

def is_rate_limited(domain, sender, record=True):
    """Return True when *sender* already has rate_limit_per_minute hits in the last 60 s.

    Args:
        domain: server part of the sender, used in the bucket key.
        sender: full sender ID.
        record: when True (default) a non-limited call is counted as a hit.
            Pass False to inspect the bucket without counting again.
    """
    key = f"{domain}:{sender}"
    now = time.time()

    recent = [t for t in RATE_LIMIT.get(key, ()) if now - t < 60]
    limited = len(recent) >= config.rate_limit_per_minute

    if not limited and record:
        recent.append(now)

    if recent:
        RATE_LIMIT[key] = recent
    else:
        # Do not keep empty buckets around.
        RATE_LIMIT.pop(key, None)

    return limited


# ============================================================
# SEED
# ============================================================

def _iter_local_room_members(timeout):
    """Yield the ``joined`` mapping of every local room the admin token is joined to.

    Rooms whose member list cannot be fetched (non-200) are skipped.

    Raises:
        RuntimeError: the joined_rooms request did not return HTTP 200.
        requests.RequestException / ValueError: network or JSON errors.
    """
    headers = {"Authorization": f"Bearer {config.admin_token}"}

    rooms_res = requests.get(
        f"{config.tuwunel_url}/_matrix/client/v3/joined_rooms",
        headers=headers,
        timeout=timeout
    )
    if rooms_res.status_code != 200:
        raise RuntimeError(f"joined_rooms returned HTTP {rooms_res.status_code}")

    for room_id in rooms_res.json().get("joined_rooms", []):
        if not is_local_room(room_id):
            continue

        members_res = requests.get(
            f"{config.tuwunel_url}/_matrix/client/v3/rooms/{room_id}/joined_members",
            headers=headers,
            timeout=timeout
        )
        if members_res.status_code != 200:
            continue

        yield members_res.json().get("joined", {})


def seed_known_users():
    """Mark every external member of the admin account's local rooms as known.

    Does nothing without an admin token. Users found before a failure are
    still recorded. The cache file is written once per run rather than once
    per user.
    """
    global CACHE_DIRTY

    if not config.admin_token:
        return

    external_ids = set()
    failed = False
    try:
        for joined in _iter_local_room_members(timeout=10):
            external_ids.update(u for u in joined if is_external(u))
    except Exception as e:
        # Runs in a background thread loop; never let an error kill it.
        logger.error(f"Seed failed: {e}")
        failed = True

    if external_ids:
        now = time.time()
        for user_id in external_ids:
            KNOWN_EXTERNAL_USERS[user_id] = now
        CACHE_DIRTY = True
        save_cache()

    if not failed:
        logger.info(f"Seed refresh completed ({len(external_ids)} users)")


def periodic_seed():
    """Loop forever, re-seeding the known-user cache every 60 seconds."""
    while True:
        seed_known_users()
        time.sleep(60)


# ============================================================
# FALLBACK
# ============================================================

def is_user_in_local_rooms(user_id):
    """Return True when *user_id* is a joined member of any local room
    visible to the admin token. Returns False without a token or on error.
    """
    if not config.admin_token:
        return False

    try:
        return any(
            user_id in joined
            for joined in _iter_local_room_members(timeout=5)
        )
    except Exception as e:
        logger.error(f"Fallback check failed: {e}")
        return False


# ============================================================
# FALLBACK RETRY (CRITICAL FOR CONSISTENCY)
# ============================================================

def fallback_check_with_retry(user_id, attempts=2, delay=0.5):
    """
    Retry membership check to handle Matrix eventual consistency.
    User join may not be immediately visible via API.

    Args:
        user_id: user to look for.
        attempts: how many times to check (default 2).
        delay: seconds to wait between attempts (default 0.5).
    """
    for attempt in range(attempts):
        if is_user_in_local_rooms(user_id):
            return True
        # Only wait when another attempt follows; sleeping after the last
        # one would just delay the response.
        if attempt < attempts - 1:
            time.sleep(delay)

    return False


# ============================================================
# DM DETECTION
# ============================================================

def is_likely_dm_create(payload):
    """Heuristically decide whether a createRoom body describes a DM.

    True when ``is_direct`` is true, when exactly one user is invited, or
    when ``visibility`` is "private" and at least one user is invited.
    """
    if payload.get("is_direct") is True:
        return True

    invite = payload.get("invite", [])

    if isinstance(invite, list) and len(invite) == 1:
        return True

    if payload.get("visibility") == "private" and invite:
        return True

    return False


def is_likely_dm_event(event, invite_room_state=None):
    """Heuristically decide whether an invite event belongs to a DM.

    True when the event content has ``is_direct`` true, or when the invite
    room state lists at most two ``m.room.member`` events (which includes
    the case where no room state is available at all).

    Args:
        event: the invite event.
        invite_room_state: stripped state sent next to the event. The v2
            invite request carries it at the top level of the body; when it
            is missing or empty, ``event["unsigned"]["invite_room_state"]``
            is used instead.
    """
    content = event.get("content")
    if isinstance(content, dict) and content.get("is_direct") is True:
        return True

    state = invite_room_state
    if not isinstance(state, list) or not state:
        unsigned = event.get("unsigned")
        state = unsigned.get("invite_room_state", []) if isinstance(unsigned, dict) else []
        if not isinstance(state, list):
            state = []

    members = [
        e for e in state
        if isinstance(e, dict) and e.get("type") == "m.room.member"
    ]

    return len(members) <= 2


# ============================================================
# INIT (critical fix for gunicorn)
# ============================================================

def init_app():
    """Load the cache, seed it once, and start the two background threads."""
    logger.info("Initializing interceptor (gunicorn mode)...")

    load_cache()
    seed_known_users()

    threading.Thread(target=periodic_seed, daemon=True).start()
    threading.Thread(target=periodic_cache_save, daemon=True).start()

    logger.info("Initialization complete")


# Called at import time so initialization also happens when the module is
# loaded by a WSGI server rather than run as a script.
init_app()


# ============================================================
# ROUTES
# ============================================================

@app.route("/healthz")
def health():
    """Liveness endpoint: status, cache size and current counters."""
    with METRICS_LOCK:
        counters = dict(METRICS)
    return {
        "status": "ok",
        "known_users": len(KNOWN_EXTERNAL_USERS),
        "metrics": counters
    }


@app.route("/metrics")
def metrics():
    """Return the current counters as JSON."""
    with METRICS_LOCK:
        return dict(METRICS)


# ============================================================
# CREATE ROOM
# ============================================================

def _resolve_user_id(auth_header):
    """Ask the homeserver which user owns the caller's access token.

    Returns the ``user_id`` reported by ``/account/whoami`` or "unknown"
    when there is no header, the lookup fails, or the token is rejected.
    The header value itself is never returned or logged.
    """
    if not auth_header:
        return "unknown"

    try:
        res = requests.get(
            f"{config.tuwunel_url}/_matrix/client/v3/account/whoami",
            headers={"Authorization": auth_header},
            timeout=config.http_timeout
        )
        if res.status_code == 200:
            body = res.json()
            if isinstance(body, dict) and isinstance(body.get("user_id"), str):
                return body["user_id"]
    except (requests.RequestException, ValueError) as e:
        logger.error(f"whoami lookup failed: {e}")

    return "unknown"


@app.route('/_matrix/client/v3/createRoom', methods=['POST'])
def create_room():
    """Gate room creation, then proxy allowed requests to the homeserver.

    A request is allowed when the caller is an admin, when room creation is
    enabled, or when the body looks like a DM. Everything else gets a 403
    M_FORBIDDEN.
    """
    payload = request.get_json(silent=True) or {}
    if not isinstance(payload, dict):
        return _matrix_error(400, "M_BAD_JSON", "Request body must be a JSON object")

    # The Authorization header holds an access token, not a user ID; using
    # it directly would make the admin check unmatchable and would write
    # the token into the logs.
    user_id = _resolve_user_id(request.headers.get("Authorization"))

    domain = extract_domain(user_id)
    role = get_role(user_id)
    is_dm = is_likely_dm_create(payload)

    allowed = (
        role == "admin"
        or config.allow_room_creation
        or is_dm
    )

    if not allowed:
        with METRICS_LOCK:
            METRICS["create_room_blocked"] += 1

        log_event(
            "create_room_blocked",
            actor=user_id,
            domain=domain,
            role=role,
            reason="room_creation_disabled"
        )

        return _matrix_error(403, "M_FORBIDDEN", "Room creation is disabled")

    with METRICS_LOCK:
        METRICS["create_room_allowed"] += 1

    log_event(
        "create_room_allowed",
        actor=user_id,
        domain=domain,
        role=role,
        room_type="dm" if is_dm else "room"
    )

    return forward_request(
        "POST",
        f"{config.tuwunel_url}/_matrix/client/v3/createRoom",
        request.headers,
        payload
    )


# ============================================================
# INVITE
# ============================================================

def _allow_invite(sender, room_id, event_id, payload):
    """Remember *sender*, count the invite as allowed and proxy it upstream."""
    remember_user(sender)

    with METRICS_LOCK:
        METRICS["invite_allowed"] += 1

    return forward_request(
        "PUT",
        f"{config.tuwunel_url}/_matrix/federation/v2/invite/{room_id}/{event_id}",
        request.headers,
        payload
    )


def _report_dm_spam(domain, sender, room_id):
    """Notify the bot about an unwanted DM invite and record the detection."""
    key = f"{sender}:{room_id}"
    now = time.time()

    # Deduplicated notify: at most one "dm_spam" per sender/room every 5 s.
    # Only the latest timestamp is kept so the entry does not grow.
    stamps = DM_NOTIFY.get(key)
    last = stamps[-1] if stamps else 0
    if now - last > 5:
        notify_bot("dm_spam", sender, room_id)
        DM_NOTIFY[key] = [now]

    # Heavy detection, independent of the dedupe above. The invite was
    # already counted at the top of the handler, so only inspect the bucket.
    if is_rate_limited(domain, sender, record=False):
        notify_bot("dm_spam_heavy", sender, room_id)

    with METRICS_LOCK:
        METRICS["dm_detected"] += 1

    log_event(
        "dm_detected",
        actor=sender,
        domain=domain,
        room_id=room_id
    )


@app.route('/_matrix/federation/v2/invite/<room_id>/<event_id>', methods=['PUT'])
def invite(room_id, event_id):
    """Filter incoming federation invites before they reach the homeserver.

    Order of checks: malformed body (400), per-sender rate limit (429),
    whitelisted domain (allow), DM protection (403 for an unknown external
    sender whose invite looks like a DM), otherwise allow.
    """
    payload = request.get_json(force=True)
    if not isinstance(payload, dict):
        return Response(status=400)

    event = payload.get("event", {})
    if not isinstance(event, dict):
        return Response(status=400)

    sender = event.get("sender", "")
    if not sender or not isinstance(sender, str):
        return Response(status=400)

    domain = extract_domain(sender)

    # Rate limit
    if is_rate_limited(domain, sender):
        log_event(
            "rate_limited",
            actor=sender,
            domain=domain
        )
        return Response(status=429)

    # Whitelist
    if domain in config.domain_whitelist:
        return _allow_invite(sender, room_id, event_id, payload)

    is_dm = is_likely_dm_event(event, payload.get("invite_room_state"))

    # DM protection
    if (
        config.block_external_dms
        and is_dm
        and is_external(sender)
        and not is_known_user(sender)
    ):
        # Retry fallback (eventual consistency fix): the sender may have
        # joined a local room after the last seed run.
        if fallback_check_with_retry(sender):
            return _allow_invite(sender, room_id, event_id, payload)

        _report_dm_spam(domain, sender, room_id)

        # Reject here. Falling through to the default path would forward
        # the invite and mark the sender as known, which would switch the
        # protection off for that sender.
        # NOTE(review): confirm that rejecting (rather than detect-only) is
        # the intended meaning of BLOCK_EXTERNAL_DMS.
        return _matrix_error(
            403, "M_FORBIDDEN", "Direct invites from unknown external users are blocked"
        )

    # Default (allow everything else)
    return _allow_invite(sender, room_id, event_id, payload)


# ============================================================
# FORWARD
# ============================================================

def forward_request(method, url, headers, body):
    """Proxy a JSON request to *url* and relay the upstream response.

    Only the Authorization header is passed on. When the upstream call
    fails, FAIL_OPEN decides between an empty 200 and a 502.

    Args:
        method: HTTP method to use upstream.
        url: full upstream URL.
        headers: incoming request headers.
        body: JSON-serialisable request body.
    """
    proxy_headers = {"Content-Type": "application/json"}
    if "Authorization" in headers:
        proxy_headers["Authorization"] = headers["Authorization"]

    try:
        res = requests.request(
            method=method,
            url=url,
            headers=proxy_headers,
            json=body,
            timeout=config.http_timeout
        )
    except Exception as e:
        # Deliberately broad: this is the best-effort proxy boundary.
        logger.error(f"proxy error: {e}")

        if config.fail_open:
            logger.warning("FAIL_OPEN -> allowing request")
            return Response(status=200)

        return Response(status=502)

    # Relay the upstream content type; without it the body would be served
    # with Flask's default HTML content type.
    return Response(
        res.content,
        res.status_code,
        content_type=res.headers.get("Content-Type", "application/json")
    )