diff --git a/app.py b/app.py index b2483f6..7c6ebfd 100644 --- a/app.py +++ b/app.py @@ -72,6 +72,7 @@ def log_event(event: str, **kwargs): details = " ".join(f"{k}={v}" for k, v in kwargs.items()) logger.info(f"{base} {details}") + # ============================================================ # INIT # ============================================================ @@ -84,12 +85,32 @@ if missing: raise SystemExit(1) # ============================================================ -# STATE +# STATE + CACHE # ============================================================ KNOWN_EXTERNAL_USERS = {} RATE_LIMIT = defaultdict(list) +CACHE_FILE = "/app/cache/known_users.json" + +def load_cache(): + global KNOWN_EXTERNAL_USERS + if os.path.exists(CACHE_FILE): + try: + with open(CACHE_FILE, "r") as f: + KNOWN_EXTERNAL_USERS = json.load(f) + logger.info(f"Loaded cache with {len(KNOWN_EXTERNAL_USERS)} users") + except: + KNOWN_EXTERNAL_USERS = {} + +def save_cache(): + try: + os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True) + with open(CACHE_FILE, "w") as f: + json.dump(KNOWN_EXTERNAL_USERS, f) + except Exception as e: + logger.error(f"Failed to save cache: {e}") + # ============================================================ # HELPERS # ============================================================ @@ -107,13 +128,17 @@ def is_known_user(user_id): ts = KNOWN_EXTERNAL_USERS.get(user_id) if not ts: return False + if time.time() - ts > config.cache_ttl: del KNOWN_EXTERNAL_USERS[user_id] + save_cache() return False + return True def remember_user(user_id): KNOWN_EXTERNAL_USERS[user_id] = time.time() + save_cache() def is_local_room(room_id): try: @@ -121,6 +146,27 @@ def is_local_room(room_id): except: return False +def get_role(user_id): + return "admin" if user_id in config.admin_users else "user" + +# ============================================================ +# RATE LIMIT +# ============================================================ + +def is_rate_limited(domain): + now = time.time() + + RATE_LIMIT[domain] = [ + t for t in RATE_LIMIT[domain] + if now - t < 60 + ] + + if len(RATE_LIMIT[domain]) >= config.rate_limit_per_minute: + return True + + RATE_LIMIT[domain].append(now) + return False + # ============================================================ # SEED # ============================================================ @@ -155,32 +201,27 @@ def seed_known_users(): if members_res.status_code != 200: continue - members = members_res.json().get("joined", {}) for user_id in members.keys(): if is_external(user_id): remember_user(user_id) seeded += 1 - logger.info(f"Seed refreshed: {seeded} users") + logger.info(f"Seed refresh completed ({seeded} users)") except Exception as e: logger.error(f"Seed failed: {e}") -# ============================================================ -# PERIODIC REFRESH -# ============================================================ - def periodic_seed(): while True: seed_known_users() time.sleep(300) # ============================================================ -# FALLBACK CHECK +# FALLBACK # ============================================================ -def is_user_in_local_rooms(user_id: str) -> bool: +def is_user_in_local_rooms(user_id): if not config.admin_token: return False @@ -193,9 +234,6 @@ def is_user_in_local_rooms(user_id: str) -> bool: timeout=5 ) - if rooms_res.status_code != 200: - return False - for room_id in rooms_res.json().get("joined_rooms", []): if not is_local_room(room_id): continue @@ -206,11 +244,7 @@ def is_user_in_local_rooms(user_id: str) -> bool: timeout=5 ) - if members_res.status_code != 200: - continue - - members = members_res.json().get("joined", {}) - if user_id in members: + if user_id in members_res.json().get("joined", {}): return True except: @@ -219,25 +253,34 @@ def is_user_in_local_rooms(user_id: str) -> bool: return False # ============================================================ -# 🔥 HYBRID DM DETECTION +# DM DETECTION # ============================================================ +def is_likely_dm_create(payload): + if payload.get("is_direct") is True: + return True + + invite = payload.get("invite", []) + + if isinstance(invite, list) and len(invite) == 1: + return True + + if payload.get("visibility") == "private" and invite: + return True + + return False + def is_likely_dm_event(event): content = event.get("content", {}) - # 1️⃣ offizielles Flag if content.get("is_direct") is True: return True - # 2️⃣ Heuristik unsigned = event.get("unsigned", {}) state = unsigned.get("invite_room_state", []) - member_events = [ - e for e in state if e.get("type") == "m.room.member" - ] - - return len(member_events) <= 2 + members = [e for e in state if e.get("type") == "m.room.member"] + return len(members) <= 2 # ============================================================ # ROUTES @@ -247,6 +290,48 @@ def is_likely_dm_event(event): def health(): return {"status": "ok"} +@app.route('/_matrix/client/v3/createRoom', methods=['POST']) +def create_room(): + payload = request.get_json(silent=True) or {} + user_id = request.headers.get("Authorization", "unknown") + + domain = extract_domain(user_id) + role = get_role(user_id) + + is_dm = is_likely_dm_create(payload) + + allowed = ( + user_id in config.admin_users + or config.allow_room_creation + or is_dm + ) + + if not allowed: + log_event( + "create_room_blocked", + actor=user_id, + domain=domain, + role=role, + reason="room_creation_disabled" + ) + return Response(json.dumps({"errcode": "M_FORBIDDEN"}), status=403) + + if is_dm: + log_event( + "create_room_allowed", + actor=user_id, + domain=domain, + role=role, + room_type="dm" + ) + + return forward_request( + "POST", + f"{config.tuwunel_url}/_matrix/client/v3/createRoom", + request.headers, + payload + ) + @app.route('/_matrix/federation/v2/invite//', methods=['PUT']) def invite(room_id, event_id): payload = request.get_json(force=True) @@ -259,7 +344,9 @@ def invite(room_id, event_id): domain = extract_domain(sender) - # whitelist + if is_rate_limited(domain): + return Response(status=429) + if domain in config.domain_whitelist: remember_user(sender) return forward_request( @@ -273,7 +360,6 @@ def invite(room_id, event_id): if not is_known_user(sender): - # 🔥 fallback check if is_user_in_local_rooms(sender): remember_user(sender) else: @@ -301,6 +387,7 @@ def invite(room_id, event_id): def forward_request(method, url, headers, body): try: proxy_headers = {"Content-Type": "application/json"} + if "Authorization" in headers: proxy_headers["Authorization"] = headers["Authorization"] @@ -316,6 +403,11 @@ def forward_request(method, url, headers, body): except Exception as e: logger.error(f"proxy error: {e}") + + if config.fail_open: + logger.warning("FAIL_OPEN → allowing request") + return Response(status=200) + return Response(status=502) # ============================================================ @@ -323,11 +415,9 @@ def forward_request(method, url, headers, body): # ============================================================ if __name__ == '__main__': + load_cache() seed_known_users() - threading.Thread( - target=periodic_seed, - daemon=True - ).start() + threading.Thread(target=periodic_seed, daemon=True).start() app.run(host='0.0.0.0', port=5000)