# Tuwunel-Interceptor/app.py
# Last modified: 2026-05-05 20:47:44 +02:00
# 489 lines, 13 KiB, Python

import json
import time
import os
import logging
import threading
from collections import defaultdict
from datetime import datetime, timezone
from threading import Lock
import requests
from flask import Flask, request, Response
# Flask application object; the @app.route handlers in this file register on it.
app = Flask(__name__)
# ============================================================
# CONFIG
# ============================================================
class Config:
    """Interceptor settings, read once from environment variables.

    ``TUWUNEL_URL`` and ``LOCAL_DOMAIN`` are mandatory (see ``validate``);
    every other variable falls back to the default given in ``__init__``.
    """

    REQUIRED_KEYS = ["TUWUNEL_URL", "LOCAL_DOMAIN"]

    def __init__(self):
        upstream = os.getenv("TUWUNEL_URL")
        # Drop trailing slashes so f"{tuwunel_url}/_matrix/..." never
        # produces a double slash in upstream URLs.
        self.tuwunel_url = upstream.rstrip("/") if upstream else upstream

        local = os.getenv("LOCAL_DOMAIN")
        # extract_domain() returns lowercase names without a trailing dot;
        # store the local domain in the same form so comparisons can match.
        self.local_domain = self._normalize_domain(local) if local else local

        self.admin_users = self._parse_list_env("ADMIN_USERS", [])
        self.admin_token = os.getenv("ADMIN_TOKEN", "")
        # Same normalisation as above: a whitelist entry written with
        # capitals could otherwise never equal an extracted domain.
        self.domain_whitelist = {
            self._normalize_domain(entry)
            for entry in self._parse_list_env("DOMAIN_WHITELIST", [])
        }
        self.block_external_dms = self._parse_bool("BLOCK_EXTERNAL_DMS", True)
        self.allow_room_creation = self._parse_bool("ALLOW_ROOM_CREATION", False)
        self.cache_ttl = int(os.getenv("CACHE_TTL_SECONDS", 604800))
        self.rate_limit_per_minute = int(os.getenv("RATE_LIMIT_PER_MINUTE", 20))
        self.http_timeout = int(os.getenv("HTTP_TIMEOUT", 5))
        self.fail_open = self._parse_bool("FAIL_OPEN", True)
        self.debug = self._parse_bool("DEBUG", False)

    def validate(self):
        """Return the names in ``REQUIRED_KEYS`` that are unset or empty."""
        return [k for k in self.REQUIRED_KEYS if not os.getenv(k)]

    @staticmethod
    def _normalize_domain(domain):
        """Return *domain* trimmed, lowercased and without a trailing dot."""
        return domain.strip().lower().rstrip(".")

    def _parse_list_env(self, key, default=None):
        """Split a comma-separated variable into a list of non-empty items.

        Returns ``default`` (or an empty list) when the variable is unset
        or empty.
        """
        raw = os.getenv(key)
        if not raw:
            return default or []
        return [x.strip() for x in raw.split(",") if x.strip()]

    def _parse_bool(self, key, default):
        """Read a boolean variable; ``1``/``true``/``yes``/``on`` mean True.

        Matching is case-insensitive. Returns ``default`` when the variable
        is unset; any other value is False.
        """
        val = os.getenv(key)
        if val is None:
            return default
        return val.lower() in ("1", "true", "yes", "on")
# ============================================================
# LOGGING
# ============================================================
# Accept the same truthy spellings as Config._parse_bool. Previously only the
# literal "true" raised the level, so DEBUG=1 switched debug_log() on while
# the logger still discarded everything below INFO.
log_level = (
    logging.DEBUG
    if os.getenv("DEBUG", "false").lower() in ("1", "true", "yes", "on")
    else logging.INFO
)
# The format is the bare message: log_event() puts its own timestamp in front.
logging.basicConfig(
    level=log_level,
    format="%(message)s"
)
# Raise the thresholds of third-party loggers to cut their output.
logging.getLogger("werkzeug").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("gunicorn.access").setLevel(logging.WARNING)
logger = logging.getLogger("matrix-interceptor")
def now_iso():
    """Return the current UTC time as an ISO 8601 string with a UTC offset."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
def log_event(event: str, **kwargs):
    """Log one INFO line of the form ``<utc timestamp> EVENT=<event> k=v ...``.

    Each keyword argument is rendered as ``key=value``, separated by single
    spaces, in the order given.
    """
    prefix = f"{now_iso()} EVENT={event}"
    fields = [f"{key}={value}" for key, value in kwargs.items()]
    logger.info(prefix + " " + " ".join(fields))
def debug_log(title, data):
    """Log *data* as JSON under *title* at DEBUG level when debug mode is on.

    Values that JSON cannot encode are passed through ``str``.
    """
    if not config.debug:
        return
    encoded = json.dumps(data, default=str)
    logger.debug(f"{title}: {encoded}")
# ============================================================
# INIT
# ============================================================
# Settings are built once at import time; the helpers in this file read `config`.
config = Config()
# Names of required environment variables that are unset or empty.
missing = config.validate()
if missing:
    # Refuse to start half-configured: report what is absent, exit with status 1.
    logger.error(f"Missing env vars: {missing}")
    raise SystemExit(1)
# ============================================================
# STATE + CACHE
# ============================================================
# user_id -> time.time() value recorded when the user was last remembered.
KNOWN_EXTERNAL_USERS = {}
# domain -> timestamps of that domain's requests within the rate-limit window.
RATE_LIMIT = defaultdict(list)
# Counter name -> count; writers hold METRICS_LOCK while incrementing.
METRICS = defaultdict(int)
METRICS_LOCK = Lock()
# On-disk copy of KNOWN_EXTERNAL_USERS (JSON).
CACHE_FILE = "/app/cache/known_users.json"
# True while the in-memory cache has changes not yet written to CACHE_FILE.
CACHE_DIRTY = False
# ============================================================
# CACHE HANDLING
# ============================================================
def load_cache():
    """Load KNOWN_EXTERNAL_USERS from CACHE_FILE, creating the file if absent.

    The file must hold a JSON object mapping user IDs to numeric timestamps.
    Anything else (unreadable file, invalid JSON, non-object top level) is
    logged and replaced by an empty in-memory cache; individual entries whose
    value is not a number are dropped, because is_known_user() does
    arithmetic on the stored value.
    """
    global KNOWN_EXTERNAL_USERS
    if os.path.exists(CACHE_FILE):
        try:
            with open(CACHE_FILE, "r") as f:
                loaded = json.load(f)
            if not isinstance(loaded, dict):
                raise ValueError(
                    f"expected a JSON object, got {type(loaded).__name__}"
                )
            KNOWN_EXTERNAL_USERS = {
                user_id: ts
                for user_id, ts in loaded.items()
                # bool is a subclass of int but is never a valid timestamp.
                if isinstance(ts, (int, float)) and not isinstance(ts, bool)
            }
            logger.info(f"Loaded cache with {len(KNOWN_EXTERNAL_USERS)} users")
        except Exception as e:
            logger.error(f"Cache load failed: {e}")
            KNOWN_EXTERNAL_USERS = {}
    else:
        # First run: create an empty cache file. A read-only or missing
        # volume must not abort start-up; the cache then lives in memory only.
        try:
            os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)
            with open(CACHE_FILE, "w") as f:
                json.dump({}, f)
            logger.info("Initialized empty cache")
        except OSError as e:
            logger.error(f"Cache init failed: {e}")
def save_cache():
    """Write KNOWN_EXTERNAL_USERS to CACHE_FILE if there are unsaved changes.

    The data is written to a temporary file in the same directory and then
    moved over CACHE_FILE with os.replace(), so an interrupted write cannot
    leave a truncated cache behind. Failures are logged, never raised, and
    leave the dirty flag set so that a later call retries.
    """
    global CACHE_DIRTY
    if not CACHE_DIRTY:
        return
    # Clear the flag before taking the snapshot: an update that lands while
    # the file is being written sets it again and is saved by the next call.
    CACHE_DIRTY = False
    # Per-process, per-thread name so concurrent callers never share a temp file.
    tmp_path = f"{CACHE_FILE}.{os.getpid()}.{threading.get_ident()}.tmp"
    try:
        os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)
        # Serialise a copy; other threads may add or remove entries meanwhile.
        snapshot = dict(KNOWN_EXTERNAL_USERS)
        with open(tmp_path, "w") as f:
            json.dump(snapshot, f)
        os.replace(tmp_path, CACHE_FILE)
    except Exception as e:
        CACHE_DIRTY = True
        logger.error(f"Failed to save cache: {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            # Nothing to clean up (the temp file was never created).
            pass
def periodic_cache_save():
    """Call save_cache() every 30 seconds, forever (background-thread target)."""
    interval_seconds = 30
    while True:
        save_cache()
        time.sleep(interval_seconds)
# ============================================================
# HELPERS
# ============================================================
def extract_domain(user_id):
    """Return the text between the first and second ``:`` of *user_id*.

    The result is lowercased and has trailing dots removed. Returns
    ``"unknown"`` when *user_id* has no ``:`` or is not a string.
    """
    try:
        pieces = user_id.split(":")
        host = pieces[1]
        return host.lower().rstrip(".")
    except Exception:
        return "unknown"
def is_external(user_id):
    """Return True when *user_id* does not belong to the local domain.

    extract_domain() yields a lowercase name without a trailing dot, so the
    configured local domain is brought into the same form before comparing;
    otherwise a LOCAL_DOMAIN written with capitals would make every local
    user look external. IDs without a parsable domain count as external.
    """
    local = (config.local_domain or "").lower().rstrip(".")
    return extract_domain(user_id) != local
def is_known_user(user_id):
    """Return True if *user_id* is cached and its entry has not expired.

    An entry older than ``config.cache_ttl`` seconds, or one whose stored
    value is not a number, is evicted and the cache is marked dirty so the
    eviction reaches disk on the next save.
    """
    global CACHE_DIRTY
    ts = KNOWN_EXTERNAL_USERS.get(user_id)
    if not ts:
        return False
    try:
        expired = time.time() - ts > config.cache_ttl
    except TypeError:
        # Corrupt entry (e.g. a string from a hand-edited cache file).
        expired = True
    if expired:
        # pop() with a default: another thread may already have evicted it.
        KNOWN_EXTERNAL_USERS.pop(user_id, None)
        CACHE_DIRTY = True
        return False
    return True
def remember_user(user_id):
    """Record *user_id* as known as of now and persist the cache at once."""
    global CACHE_DIRTY
    stamp = time.time()
    KNOWN_EXTERNAL_USERS[user_id] = stamp
    CACHE_DIRTY = True
    # Write immediately instead of waiting for the periodic saver.
    save_cache()
def is_local_room(room_id):
    """Return True when the domain part of *room_id* is the local domain.

    The domain part is the text between the first and second ``:``. Both
    sides are compared lowercased and without a trailing dot, matching what
    extract_domain() does for user IDs. Returns False when *room_id* has no
    ``:`` or is not a string.
    """
    try:
        room_domain = room_id.split(":")[1]
    except Exception:
        return False
    local = config.local_domain or ""
    return room_domain.lower().rstrip(".") == local.lower().rstrip(".")
def get_role(user_id):
    """Return ``"admin"`` for users listed in ``config.admin_users``, else ``"user"``."""
    if user_id in config.admin_users:
        return "admin"
    return "user"
# ============================================================
# RATE LIMIT
# ============================================================
def is_rate_limited(domain):
    """Return True when *domain* has used up its quota for the last 60 seconds.

    Timestamps older than 60 seconds are discarded first. A call that is
    let through is recorded; a call that is rejected is not.
    """
    now = time.time()
    recent = [stamp for stamp in RATE_LIMIT[domain] if now - stamp < 60]
    RATE_LIMIT[domain] = recent
    if len(recent) >= config.rate_limit_per_minute:
        return True
    recent.append(now)
    return False
# ============================================================
# SEED
# ============================================================
def seed_known_users():
    """Remember every external member of the local rooms the admin has joined.

    Does nothing without ``config.admin_token``. Lists the token owner's
    joined rooms, and for each room on the local domain fetches the joined
    members and passes the external ones to remember_user(). A failure for
    one room is logged and skipped so that it does not abort the remaining
    rooms; a failure to list the rooms ends the run.
    """
    if not config.admin_token:
        return
    from urllib.parse import quote

    headers = {"Authorization": f"Bearer {config.admin_token}"}
    seeded = 0
    try:
        rooms_res = requests.get(
            f"{config.tuwunel_url}/_matrix/client/v3/joined_rooms",
            headers=headers,
            timeout=10
        )
        if rooms_res.status_code != 200:
            logger.error(
                f"Seed failed: joined_rooms returned HTTP {rooms_res.status_code}"
            )
            return
        room_ids = rooms_res.json().get("joined_rooms", [])
    except Exception as e:
        logger.error(f"Seed failed: {e}")
        return

    for room_id in room_ids:
        if not is_local_room(room_id):
            continue
        try:
            # Room IDs contain "!" and ":"; percent-encode them for the path.
            encoded_room = quote(room_id, safe="")
            members_res = requests.get(
                f"{config.tuwunel_url}/_matrix/client/v3/rooms/{encoded_room}/joined_members",
                headers=headers,
                timeout=10
            )
            if members_res.status_code != 200:
                continue
            members = members_res.json().get("joined", {})
        except Exception as e:
            logger.error(f"Seed failed for room {room_id}: {e}")
            continue
        for user_id in members:
            if is_external(user_id):
                remember_user(user_id)
                seeded += 1
    logger.info(f"Seed refresh completed ({seeded} users)")
def periodic_seed():
    """Run seed_known_users() every 300 seconds, forever (background-thread target)."""
    interval_seconds = 300
    while True:
        seed_known_users()
        time.sleep(interval_seconds)
# ============================================================
# FALLBACK
# ============================================================
def is_user_in_local_rooms(user_id):
    """Return True if *user_id* is a joined member of a local room.

    Only the rooms joined by the owner of ``config.admin_token`` are checked;
    without a token the answer is False. A failure for one room is logged and
    the remaining rooms are still checked; a failure to list the rooms gives
    False.
    """
    if not config.admin_token:
        return False
    from urllib.parse import quote

    headers = {"Authorization": f"Bearer {config.admin_token}"}
    try:
        rooms_res = requests.get(
            f"{config.tuwunel_url}/_matrix/client/v3/joined_rooms",
            headers=headers,
            timeout=5
        )
        if rooms_res.status_code != 200:
            logger.error(
                f"Fallback check failed: joined_rooms returned HTTP {rooms_res.status_code}"
            )
            return False
        room_ids = rooms_res.json().get("joined_rooms", [])
    except Exception as e:
        logger.error(f"Fallback check failed: {e}")
        return False

    for room_id in room_ids:
        if not is_local_room(room_id):
            continue
        try:
            # Room IDs contain "!" and ":"; percent-encode them for the path.
            encoded_room = quote(room_id, safe="")
            members_res = requests.get(
                f"{config.tuwunel_url}/_matrix/client/v3/rooms/{encoded_room}/joined_members",
                headers=headers,
                timeout=5
            )
            if members_res.status_code != 200:
                continue
            if user_id in members_res.json().get("joined", {}):
                return True
        except Exception as e:
            logger.error(f"Fallback check failed: {e}")
    return False
# ============================================================
# DM DETECTION
# ============================================================
def is_likely_dm_create(payload):
    """Heuristically decide whether a createRoom body describes a direct chat.

    True when any of these holds:
      * ``is_direct`` is exactly ``True``;
      * ``invite`` is a list with exactly one entry;
      * ``visibility`` is ``"private"`` and ``invite`` is non-empty.

    A body that is not a JSON object (for example a list or ``None``) is
    never a DM; it used to raise AttributeError.
    """
    if not isinstance(payload, dict):
        return False
    if payload.get("is_direct") is True:
        return True
    invite = payload.get("invite", [])
    if isinstance(invite, list) and len(invite) == 1:
        return True
    return payload.get("visibility") == "private" and bool(invite)
def is_likely_dm_event(event):
    """Heuristically decide whether an invite event is for a direct chat.

    True when ``content.is_direct`` is exactly ``True``, or when
    ``unsigned.invite_room_state`` holds at most two ``m.room.member``
    events. An invite without any room state therefore counts as a DM.

    The event comes from a remote server, so fields of the wrong type
    (null ``content``/``unsigned``, a non-list state, non-object state
    entries) are treated as absent rather than raising.
    """
    if not isinstance(event, dict):
        event = {}
    content = event.get("content")
    if isinstance(content, dict) and content.get("is_direct") is True:
        return True
    unsigned = event.get("unsigned")
    state = unsigned.get("invite_room_state") if isinstance(unsigned, dict) else None
    if not isinstance(state, list):
        state = []
    members = [
        entry for entry in state
        if isinstance(entry, dict) and entry.get("type") == "m.room.member"
    ]
    return len(members) <= 2
# ============================================================
# ROUTES
# ============================================================
@app.route("/healthz")
def health():
    """Health endpoint: status, number of cached users and a metrics copy."""
    known_count = len(KNOWN_EXTERNAL_USERS)
    counters = dict(METRICS)
    return {"status": "ok", "known_users": known_count, "metrics": counters}
@app.route("/metrics")
def metrics():
    """Return a plain-dict copy of the METRICS counters."""
    counters = dict(METRICS)
    return counters
def _resolve_user_id(auth_header):
    """Return the Matrix user ID that owns the token in *auth_header*.

    Asks the homeserver's ``/account/whoami`` endpoint. Returns ``"unknown"``
    when there is no header, the lookup fails, or the reply has no user ID.
    """
    if not auth_header:
        return "unknown"
    try:
        res = requests.get(
            f"{config.tuwunel_url}/_matrix/client/v3/account/whoami",
            headers={"Authorization": auth_header},
            timeout=config.http_timeout
        )
        if res.status_code == 200:
            user_id = res.json().get("user_id")
            if isinstance(user_id, str) and user_id:
                return user_id
    except Exception as e:
        logger.error(f"whoami lookup failed: {e}")
    return "unknown"


@app.route('/_matrix/client/v3/createRoom', methods=['POST'])
def create_room():
    """Gate room creation, then forward allowed requests to the homeserver.

    A request is allowed when the caller is listed in ``config.admin_users``,
    when ``config.allow_room_creation`` is set, or when the body looks like a
    direct chat. Everything else is answered with 403 ``M_FORBIDDEN``.
    """
    payload = request.get_json(silent=True) or {}
    # The Authorization header carries an access token, not a user ID. It was
    # previously used as the user ID directly, so the admin check could never
    # match and the token itself was written to the logs as "actor".
    user_id = _resolve_user_id(request.headers.get("Authorization"))
    domain = extract_domain(user_id)
    role = get_role(user_id)
    # A body that is not a JSON object cannot describe a DM.
    is_dm = isinstance(payload, dict) and is_likely_dm_create(payload)
    allowed = (
        user_id in config.admin_users
        or config.allow_room_creation
        or is_dm
    )
    if not allowed:
        with METRICS_LOCK:
            METRICS["create_room_blocked"] += 1
        log_event(
            "create_room_blocked",
            actor=user_id,
            domain=domain,
            role=role,
            reason="room_creation_disabled"
        )
        return Response(
            json.dumps({"errcode": "M_FORBIDDEN"}),
            status=403,
            mimetype="application/json"
        )
    if is_dm:
        with METRICS_LOCK:
            METRICS["create_room_allowed"] += 1
        log_event(
            "create_room_allowed",
            actor=user_id,
            domain=domain,
            role=role,
            room_type="dm"
        )
    return forward_request(
        "POST",
        f"{config.tuwunel_url}/_matrix/client/v3/createRoom",
        request.headers,
        payload
    )
def _forward_invite(room_id, event_id, payload, sender):
    """Forward an accepted invite upstream; remember *sender* on a 2xx reply.

    The sender is recorded only after the homeserver has answered with a
    success status, so an invite that the homeserver rejects does not turn
    its sender into a known user. Note that forward_request() itself answers
    200 on a proxy error when ``config.fail_open`` is set.
    """
    response = forward_request(
        "PUT",
        f"{config.tuwunel_url}/_matrix/federation/v2/invite/{room_id}/{event_id}",
        request.headers,
        payload
    )
    if 200 <= response.status_code < 300:
        remember_user(sender)
    return response


@app.route('/_matrix/federation/v2/invite/<room_id>/<event_id>', methods=['PUT'])
def invite(room_id, event_id):
    """Filter incoming federation invites before they reach the homeserver.

    Order of checks:
      1. 400 if the body is not an object or has no string ``event.sender``.
      2. 429 if the sender's domain exceeded its rate limit.
      3. Whitelisted domains are forwarded unconditionally.
      4. With ``config.block_external_dms`` set, a DM-like invite from an
         external sender is forwarded only if the sender is already known or
         shares a local room; otherwise it is answered with 403.
      5. Everything else is counted as allowed and forwarded.
    """
    payload = request.get_json(force=True)
    # The body is remote input: reject shapes the code below cannot read
    # instead of failing with an AttributeError (HTTP 500).
    if not isinstance(payload, dict):
        return Response(status=400)
    event = payload.get("event", {})
    if not isinstance(event, dict):
        return Response(status=400)
    sender = event.get("sender", "")
    if not sender or not isinstance(sender, str):
        return Response(status=400)

    domain = extract_domain(sender)
    if is_rate_limited(domain):
        return Response(status=429)

    if domain in config.domain_whitelist:
        return _forward_invite(room_id, event_id, payload, sender)

    is_dm = is_likely_dm_event(event)
    if config.block_external_dms and is_dm and is_external(sender):
        if not is_known_user(sender):
            # Not cached: ask the homeserver whether the sender already
            # shares a local room before rejecting.
            if is_user_in_local_rooms(sender):
                return _forward_invite(room_id, event_id, payload, sender)
            with METRICS_LOCK:
                METRICS["invite_blocked"] += 1
            log_event(
                "invite_blocked",
                actor=sender,
                domain=domain,
                reason="unknown_external_user"
            )
            return Response(status=403)

    with METRICS_LOCK:
        METRICS["invite_allowed"] += 1
    return _forward_invite(room_id, event_id, payload, sender)
# ============================================================
# FORWARD
# ============================================================
def forward_request(method, url, headers, body):
    """Send *body* as JSON to *url* and relay the upstream reply.

    Only the ``Authorization`` header is copied from *headers*. The reply
    keeps the upstream status, body and Content-Type; without the latter a
    JSON body would be relayed under the framework's default content type.

    If the upstream call raises, the answer is an empty 200 when
    ``config.fail_open`` is set and an empty 502 otherwise.
    """
    try:
        proxy_headers = {"Content-Type": "application/json"}
        if "Authorization" in headers:
            proxy_headers["Authorization"] = headers["Authorization"]
        res = requests.request(
            method=method,
            url=url,
            headers=proxy_headers,
            json=body,
            timeout=config.http_timeout
        )
        return Response(
            res.content,
            status=res.status_code,
            content_type=res.headers.get("Content-Type", "application/json")
        )
    except Exception as e:
        logger.error(f"proxy error: {e}")
        if config.fail_open:
            logger.warning("FAIL_OPEN → allowing request")
            return Response(status=200)
        return Response(status=502)
# ============================================================
# START
# ============================================================
if __name__ == '__main__':
    # NOTE(review): this start-up sequence runs only when the file is executed
    # directly. If the app is served by importing `app` from another process
    # (the logging setup mentions gunicorn), the cache load, the seeding and
    # both background threads would presumably not start — confirm.
    load_cache()
    seed_known_users()
    # Daemon threads: they do not keep the process alive on shutdown.
    for worker in (periodic_seed, periodic_cache_save):
        threading.Thread(target=worker, daemon=True).start()
    app.run(host='0.0.0.0', port=5000)