import json
import time
import os
import logging
import threading
from collections import defaultdict
from datetime import datetime, timezone
from threading import Lock

import requests
from flask import Flask, request, Response

app = Flask(__name__)

# ============================================================
# CONFIG
# ============================================================


class Config:
    """Runtime settings for the interceptor, read from environment variables.

    ``TUWUNEL_URL`` and ``LOCAL_DOMAIN`` are mandatory (see ``validate``);
    every other setting has a default.  Domain-like values are normalized to
    lower case without a trailing dot so that they compare equal to domains
    extracted from Matrix identifiers elsewhere in this file.
    """

    REQUIRED_KEYS = ["TUWUNEL_URL", "LOCAL_DOMAIN"]

    # Spellings accepted as "true" for boolean environment variables.
    TRUE_VALUES = ("1", "true", "yes", "on")

    def __init__(self):
        # A trailing slash would yield "//_matrix/..." when paths are appended.
        self.tuwunel_url = self._strip_trailing_slash(os.getenv("TUWUNEL_URL"))
        self.local_domain = self._normalize_domain(os.getenv("LOCAL_DOMAIN"))

        self.admin_users = self._parse_list_env("ADMIN_USERS", [])
        self.admin_token = os.getenv("ADMIN_TOKEN", "")

        self.domain_whitelist = {
            self._normalize_domain(domain)
            for domain in self._parse_list_env("DOMAIN_WHITELIST", [])
        }

        self.block_external_dms = self._parse_bool("BLOCK_EXTERNAL_DMS", True)
        self.allow_room_creation = self._parse_bool("ALLOW_ROOM_CREATION", False)

        self.cache_ttl = self._parse_int("CACHE_TTL_SECONDS", 604800)
        self.rate_limit_per_minute = self._parse_int("RATE_LIMIT_PER_MINUTE", 20)
        self.http_timeout = self._parse_int("HTTP_TIMEOUT", 5)

        self.fail_open = self._parse_bool("FAIL_OPEN", True)
        self.debug = self._parse_bool("DEBUG", False)

    def validate(self):
        """Return the names of required environment variables that are unset or empty."""
        return [k for k in self.REQUIRED_KEYS if not os.getenv(k)]

    @staticmethod
    def _strip_trailing_slash(url):
        """Return *url* without trailing slashes; falsy values pass through unchanged."""
        return url.rstrip("/") if url else url

    @staticmethod
    def _normalize_domain(domain):
        """Lower-case *domain* and drop a trailing dot; falsy values pass through unchanged."""
        return domain.strip().lower().rstrip(".") if domain else domain

    def _parse_list_env(self, key, default=None):
        """Return the comma-separated variable *key* as a list of non-empty, stripped items."""
        raw = os.getenv(key)
        if not raw:
            return default or []
        return [x.strip() for x in raw.split(",") if x.strip()]

    def _parse_bool(self, key, default):
        """Return *default* when *key* is unset, else whether its value is a true spelling."""
        val = os.getenv(key)
        if val is None:
            return default
        return val.lower() in self.TRUE_VALUES

    def _parse_int(self, key, default):
        """Return variable *key* as an int, or *default* when it is unset or blank.

        Raises:
            ValueError: if the value is present but not an integer; the message
                names the offending variable.
        """
        raw = os.getenv(key)
        if raw is None or not raw.strip():
            return default
        try:
            return int(raw)
        except ValueError:
            raise ValueError(f"{key} must be an integer, got {raw!r}") from None


# ============================================================
# LOGGING
# ============================================================

# Accept the same spellings as Config._parse_bool so DEBUG=1 also raises verbosity.
log_level = (
    logging.DEBUG
    if os.getenv("DEBUG", "false").lower() in Config.TRUE_VALUES
    else logging.INFO
)

logging.basicConfig(level=log_level, format="%(message)s")

logging.getLogger("werkzeug").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("gunicorn.access").setLevel(logging.WARNING)

logger = logging.getLogger("matrix-interceptor")
logger.propagate = False
# With propagation disabled the logger needs a handler of its own; otherwise
# records fall through to logging.lastResort, which only emits WARNING and
# above, and every logger.info() call in this file would be dropped.
if not logger.handlers:
    _log_handler = logging.StreamHandler()
    _log_handler.setFormatter(logging.Formatter("%(message)s"))
    logger.addHandler(_log_handler)


def now_iso():
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def log_event(event: str, **kwargs):
    """Log one line of the form ``<timestamp> EVENT=<event> k=v ...`` at INFO level."""
    base = f"{now_iso()} EVENT={event}"
    details = " ".join(f"{k}={v}" for k, v in kwargs.items())
    logger.info(f"{base} {details}")


# ============================================================
# INIT
# ============================================================

config = Config()

missing = config.validate()
if missing:
    logger.error(f"Missing env vars: {missing}")
    raise SystemExit(1)

# ============================================================
# STATE + CACHE
# ============================================================

# user_id -> unix timestamp of the last time the user was remembered.
KNOWN_EXTERNAL_USERS = {}
# rate-limit key -> timestamps of recent requests.
RATE_LIMIT = defaultdict(list)
METRICS = defaultdict(int)

METRICS_LOCK = Lock()
CACHE_LOCK = Lock()

CACHE_FILE = "/app/cache/known_users.json"
# True when KNOWN_EXTERNAL_USERS has changes not yet written to CACHE_FILE.
CACHE_DIRTY = False

# ============================================================
# CACHE HANDLING
# ============================================================


def load_cache():
    """Populate KNOWN_EXTERNAL_USERS from CACHE_FILE, creating the file if absent.

    An unreadable or malformed file results in an empty cache and a logged
    error.  Entries whose value is not a number are discarded because the
    expiry check subtracts the stored value from ``time.time()``.
    """
    global KNOWN_EXTERNAL_USERS

    os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)

    if not os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, "w", encoding="utf-8") as f:
            json.dump({}, f)
        logger.info("Initialized empty cache file")
        return

    try:
        with open(CACHE_FILE, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except (OSError, ValueError) as e:
        logger.error(f"Failed to load cache, starting empty: {e}")
        KNOWN_EXTERNAL_USERS = {}
        return

    if not isinstance(loaded, dict):
        logger.error("Cache file does not contain a JSON object, starting empty")
        KNOWN_EXTERNAL_USERS = {}
        return

    KNOWN_EXTERNAL_USERS = {
        user_id: ts
        for user_id, ts in loaded.items()
        if isinstance(ts, (int, float)) and not isinstance(ts, bool)
    }
    logger.info(f"Loaded cache with {len(KNOWN_EXTERNAL_USERS)} users")


def save_cache():
    """Write KNOWN_EXTERNAL_USERS to CACHE_FILE if it has unsaved changes.

    Failures are logged rather than raised so that callers (request handlers
    and the periodic saver thread) keep running; the dirty flag is restored
    so a later call retries.
    """
    global CACHE_DIRTY

    if not CACHE_DIRTY:
        return

    try:
        with CACHE_LOCK:
            # Clear the flag before snapshotting: an update that lands while
            # the file is being written sets it again and is saved next time.
            CACHE_DIRTY = False
            # Serialize a copy so json.dump never iterates a dict that
            # another thread is resizing.
            snapshot = dict(KNOWN_EXTERNAL_USERS)

            os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)

            # Write to a sibling file and rename over the target so a crash
            # mid-write cannot leave a truncated cache behind.
            tmp_path = f"{CACHE_FILE}.tmp"
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(snapshot, f)
            os.replace(tmp_path, CACHE_FILE)

    except Exception as e:
        CACHE_DIRTY = True
        logger.error(f"Failed to save cache: {e}")


def periodic_cache_save():
    """Flush the cache every 30 seconds, forever; intended to run in a daemon thread."""
    while True:
        save_cache()
        time.sleep(30)


# ============================================================
# HELPERS
# ============================================================


def _local_domain():
    """Return the configured local domain normalized like ``extract_domain`` output."""
    return (config.local_domain or "").lower().rstrip(".")


def extract_domain(user_id):
    """Return the server-name part of a Matrix user id, or ``"unknown"``.

    Everything after the first ``:`` is kept, so a port or an IPv6 literal in
    the server name is not truncated.  The result is lower-cased with any
    trailing dot removed.  Values without a ``:`` (or that are not strings)
    yield ``"unknown"``.
    """
    try:
        _, sep, server = user_id.partition(":")
    except (AttributeError, TypeError):
        return "unknown"
    if not sep:
        return "unknown"
    return server.lower().rstrip(".")


def is_external(user_id):
    """Return True when *user_id* does not belong to the configured local domain."""
    return extract_domain(user_id) != _local_domain()


def is_known_user(user_id):
    """Return True when *user_id* is cached and its entry has not expired.

    Expired (or non-numeric) entries are removed and the cache is marked
    dirty.
    """
    global CACHE_DIRTY

    ts = KNOWN_EXTERNAL_USERS.get(user_id)
    if not ts:
        return False

    expired = (
        not isinstance(ts, (int, float))
        or time.time() - ts > config.cache_ttl
    )
    if expired:
        # pop() instead of del: another thread may have evicted it already.
        KNOWN_EXTERNAL_USERS.pop(user_id, None)
        CACHE_DIRTY = True
        return False

    return True


def remember_user(user_id):
    """Record *user_id* as known as of now and persist the cache immediately."""
    global CACHE_DIRTY
    KNOWN_EXTERNAL_USERS[user_id] = time.time()
    CACHE_DIRTY = True
    save_cache()


def is_local_room(room_id):
    """Return True when the server-name part of *room_id* is the local domain."""
    try:
        _, sep, server = room_id.partition(":")
    except (AttributeError, TypeError):
        return False
    return bool(sep) and server.lower().rstrip(".") == _local_domain()


def get_role(user_id):
    """Return ``"admin"`` for users listed in the admin configuration, else ``"user"``."""
    return "admin" if user_id in config.admin_users else "user"


# ============================================================
# RATE LIMIT (FIXED)
# ============================================================

# Length of the sliding window, in seconds.
RATE_LIMIT_WINDOW_SECONDS = 60
# Once more keys than this are tracked, idle keys are purged on the next call.
RATE_LIMIT_MAX_TRACKED_KEYS = 10000


def is_rate_limited(domain, sender):
    """Return True when *sender* already used its per-minute quota.

    A request that is not limited is recorded against the sender.  Keys come
    from remote input, so idle keys are purged once the table grows large to
    keep memory bounded.
    """
    key = f"{domain}:{sender}"
    now = time.time()

    if len(RATE_LIMIT) > RATE_LIMIT_MAX_TRACKED_KEYS:
        # Timestamps are appended in order, so the last one is the newest.
        for stale_key, stamps in list(RATE_LIMIT.items()):
            if not stamps or now - stamps[-1] >= RATE_LIMIT_WINDOW_SECONDS:
                RATE_LIMIT.pop(stale_key, None)

    recent = [t for t in RATE_LIMIT[key] if now - t < RATE_LIMIT_WINDOW_SECONDS]
    RATE_LIMIT[key] = recent

    if len(recent) >= config.rate_limit_per_minute:
        return True

    recent.append(now)
    return False


# ============================================================
# SEED
# ============================================================


def seed_known_users():
    """Remember every external member of the local rooms the admin account has joined.

    Does nothing without an admin token.  Members are collected first and the
    cache file is written once at the end (also after a partial failure),
    rather than once per member.
    """
    global CACHE_DIRTY

    if not config.admin_token:
        return

    headers = {"Authorization": f"Bearer {config.admin_token}"}
    seeded = set()

    try:
        rooms_res = requests.get(
            f"{config.tuwunel_url}/_matrix/client/v3/joined_rooms",
            headers=headers,
            timeout=10
        )

        if rooms_res.status_code != 200:
            logger.warning(
                f"Seed skipped: joined_rooms returned {rooms_res.status_code}"
            )
            return

        for room_id in rooms_res.json().get("joined_rooms", []):
            if not is_local_room(room_id):
                continue

            members_res = requests.get(
                f"{config.tuwunel_url}/_matrix/client/v3/rooms/{room_id}/joined_members",
                headers=headers,
                timeout=10
            )

            if members_res.status_code != 200:
                continue

            for user_id in members_res.json().get("joined", {}):
                if is_external(user_id):
                    KNOWN_EXTERNAL_USERS[user_id] = time.time()
                    seeded.add(user_id)

        logger.info(f"Seed refresh completed ({len(seeded)} users)")

    except Exception as e:
        logger.error(f"Seed failed: {e}")

    finally:
        if seeded:
            CACHE_DIRTY = True
            save_cache()


def periodic_seed():
    """Re-run the seed every 300 seconds, forever; intended to run in a daemon thread."""
    while True:
        seed_known_users()
        time.sleep(300)


# ============================================================
# FALLBACK
# ============================================================


def is_user_in_local_rooms(user_id):
    """Return True when *user_id* is a joined member of any local room the admin account is in.

    Queries the homeserver live; returns False without an admin token, on a
    non-200 answer, or on any error (which is logged).
    """
    if not config.admin_token:
        return False

    try:
        headers = {"Authorization": f"Bearer {config.admin_token}"}

        rooms_res = requests.get(
            f"{config.tuwunel_url}/_matrix/client/v3/joined_rooms",
            headers=headers,
            timeout=5
        )
        if rooms_res.status_code != 200:
            return False

        for room_id in rooms_res.json().get("joined_rooms", []):
            if not is_local_room(room_id):
                continue

            members_res = requests.get(
                f"{config.tuwunel_url}/_matrix/client/v3/rooms/{room_id}/joined_members",
                headers=headers,
                timeout=5
            )
            if members_res.status_code != 200:
                continue

            if user_id in members_res.json().get("joined", {}):
                return True

    except Exception as e:
        logger.error(f"Fallback check failed: {e}")

    return False


# ============================================================
# DM DETECTION
# ============================================================


def is_likely_dm_create(payload):
    """Guess whether a createRoom request body describes a direct message.

    True when ``is_direct`` is set, when exactly one user is invited, or when
    the room is explicitly private and invites anyone.  Non-dict payloads are
    never treated as DMs.
    """
    if not isinstance(payload, dict):
        return False

    if payload.get("is_direct") is True:
        return True

    invite = payload.get("invite", [])

    if isinstance(invite, list) and len(invite) == 1:
        return True

    if payload.get("visibility") == "private" and invite:
        return True

    return False


def is_likely_dm_event(event):
    """Guess whether an invite event is for a direct message.

    True when the event content sets ``is_direct``, or when the stripped room
    state under ``unsigned.invite_room_state`` lists at most two member
    events (an absent or empty state therefore also counts as a DM).
    """
    if not isinstance(event, dict):
        return False

    content = event.get("content") or {}
    if isinstance(content, dict) and content.get("is_direct") is True:
        return True

    # NOTE(review): the federation v2 invite body appears to carry
    # invite_room_state next to "event" rather than inside event.unsigned —
    # confirm against the spec; if so this list is always empty here.
    unsigned = event.get("unsigned") or {}
    state = unsigned.get("invite_room_state", []) if isinstance(unsigned, dict) else []
    if not isinstance(state, list):
        state = []

    members = [
        e for e in state
        if isinstance(e, dict) and e.get("type") == "m.room.member"
    ]
    return len(members) <= 2


# ============================================================
# ROUTES
# ============================================================


@app.route("/healthz")
def health():
    """Liveness endpoint reporting cache size and a snapshot of the counters."""
    with METRICS_LOCK:
        metrics_snapshot = dict(METRICS)
    return {
        "status": "ok",
        "known_users": len(KNOWN_EXTERNAL_USERS),
        "metrics": metrics_snapshot
    }


@app.route("/metrics")
def metrics():
    """Return a snapshot of the counters as JSON."""
    with METRICS_LOCK:
        return dict(METRICS)
def _matrix_error(status, errcode, error):
    """Build a JSON error response with the given HTTP status, ``errcode`` and ``error``."""
    return Response(
        json.dumps({"errcode": errcode, "error": error}),
        status=status,
        mimetype="application/json"
    )


def _resolve_user_id(auth_header):
    """Ask the homeserver which user the request's access token belongs to.

    Returns the ``user_id`` from the whoami endpoint, or ``"unknown"`` when
    there is no header, the lookup fails, or the answer is not usable.  The
    token itself is never returned, so it cannot end up in log lines.
    """
    if not auth_header:
        return "unknown"

    try:
        res = requests.get(
            f"{config.tuwunel_url}/_matrix/client/v3/account/whoami",
            headers={"Authorization": auth_header},
            timeout=config.http_timeout
        )
        if res.status_code == 200:
            body = res.json()
            if isinstance(body, dict) and isinstance(body.get("user_id"), str):
                return body["user_id"]
    except (requests.RequestException, ValueError) as e:
        logger.error(f"whoami lookup failed: {e}")

    return "unknown"


def _path_segment(value):
    """Escape *value* for use as a single path segment of an upstream URL.

    Only characters that would change the URL's structure (such as ``?``,
    ``#``, ``/`` and ``%``) are percent-encoded.  Matrix sigils and ``:`` are
    left as they are so identifiers without such characters are forwarded
    unchanged.
    """
    from urllib.parse import quote

    return quote(value, safe="!$&'()*+,:;=@~")


@app.route('/_matrix/client/v3/createRoom', methods=['POST'])
def create_room():
    """Gate room creation, then proxy the request to the homeserver.

    Creation is allowed for configured admins, when room creation is enabled
    globally, or when the request looks like a direct message; everything
    else is answered with 403 ``M_FORBIDDEN``.
    """
    payload = request.get_json(silent=True)
    if payload is None:
        payload = {}
    elif not isinstance(payload, dict):
        return _matrix_error(400, "M_BAD_JSON", "Request body must be a JSON object")

    # The Authorization header holds an access token, not a user id; resolve
    # it so the admin check can match and the token is never logged.
    user_id = _resolve_user_id(request.headers.get("Authorization"))
    domain = extract_domain(user_id)
    role = get_role(user_id)

    is_dm = is_likely_dm_create(payload)

    allowed = (
        user_id in config.admin_users
        or config.allow_room_creation
        or is_dm
    )

    if not allowed:
        with METRICS_LOCK:
            METRICS["create_room_blocked"] += 1

        log_event(
            "create_room_blocked",
            actor=user_id,
            domain=domain,
            role=role,
            reason="room_creation_disabled"
        )

        return _matrix_error(403, "M_FORBIDDEN", "Room creation is disabled")

    if is_dm:
        with METRICS_LOCK:
            METRICS["create_room_allowed"] += 1

        log_event(
            "create_room_allowed",
            actor=user_id,
            domain=domain,
            role=role,
            room_type="dm"
        )

    return forward_request(
        "POST",
        f"{config.tuwunel_url}/_matrix/client/v3/createRoom",
        request.headers,
        payload
    )


@app.route('/_matrix/federation/v2/invite/<room_id>/<event_id>', methods=['PUT'])
def invite(room_id, event_id):
    """Filter incoming federation invites, then proxy accepted ones.

    Rate-limited senders get 429.  An invite is rejected with 403 only when
    all of the following hold: the sender's domain is not whitelisted,
    blocking of external DMs is enabled, the invite looks like a DM, the
    sender is external, is not in the known-user cache, and is not a member
    of a local room.  Every accepted sender is remembered.
    """
    payload = request.get_json(force=True)
    if not isinstance(payload, dict):
        return _matrix_error(400, "M_BAD_JSON", "Request body must be a JSON object")

    event = payload.get("event", {})
    if not isinstance(event, dict):
        return _matrix_error(400, "M_BAD_JSON", "event must be a JSON object")

    sender = event.get("sender", "")
    if not sender or not isinstance(sender, str):
        return _matrix_error(400, "M_BAD_JSON", "event.sender is required")

    domain = extract_domain(sender)

    if is_rate_limited(domain, sender):
        return _matrix_error(429, "M_LIMIT_EXCEEDED", "Too many requests")

    # Evaluated left to right, so the live homeserver lookup at the end only
    # happens when every cheaper check has failed to clear the sender.
    blocked = (
        domain not in config.domain_whitelist
        and config.block_external_dms
        and is_likely_dm_event(event)
        and is_external(sender)
        and not is_known_user(sender)
        and not is_user_in_local_rooms(sender)
    )

    if blocked:
        with METRICS_LOCK:
            METRICS["invite_blocked"] += 1

        log_event(
            "invite_blocked",
            actor=sender,
            domain=domain,
            reason="unknown_external_user"
        )

        return _matrix_error(403, "M_FORBIDDEN", "Invites from unknown external users are blocked")

    remember_user(sender)

    with METRICS_LOCK:
        METRICS["invite_allowed"] += 1

    upstream_url = (
        f"{config.tuwunel_url}/_matrix/federation/v2/invite/"
        f"{_path_segment(room_id)}/{_path_segment(event_id)}"
    )
    return forward_request("PUT", upstream_url, request.headers, payload)


# ============================================================
# FORWARD
# ============================================================


def forward_request(method, url, headers, body):
    """Send *body* as JSON to *url* and relay the upstream answer.

    Only the ``Authorization`` header is passed on.  The upstream status,
    body and ``Content-Type`` are returned to the caller.  When the upstream
    call fails, the result is 200 with an empty body if fail-open is
    configured, otherwise 502.
    """
    try:
        proxy_headers = {"Content-Type": "application/json"}

        if "Authorization" in headers:
            proxy_headers["Authorization"] = headers["Authorization"]

        res = requests.request(
            method=method,
            url=url,
            headers=proxy_headers,
            json=body,
            timeout=config.http_timeout
        )

        # Relay the upstream content type; without it Flask would label the
        # JSON body as text/html.
        return Response(
            res.content,
            status=res.status_code,
            content_type=res.headers.get("Content-Type", "application/json")
        )

    except Exception as e:
        logger.error(f"proxy error: {e}")

        if config.fail_open:
            # NOTE(review): this reports success for a request that never
            # reached the homeserver — confirm that is the intended meaning
            # of FAIL_OPEN.
            logger.warning("FAIL_OPEN → allowing request")
            return Response(status=200)

        return Response(status=502)


# ============================================================
# START
# ============================================================

if __name__ == '__main__':
    # NOTE(review): cache loading, seeding and both background threads only
    # start when this file is run directly — confirm how they are started
    # when the app is served by a WSGI server instead.
    load_cache()
    seed_known_users()

    threading.Thread(target=periodic_seed, daemon=True).start()
    threading.Thread(target=periodic_cache_save, daemon=True).start()

    app.run(host='0.0.0.0', port=5000)