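"""Flask proxy that filters Matrix createRoom and federation invite requests.

Requests that pass the configured policy are forwarded to the upstream
homeserver whose base URL is given by the TUWUNEL_URL environment variable.
"""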
import json
|
|
import time
|
|
import os
|
|
import logging
|
|
import threading
|
|
from collections import defaultdict
|
|
from datetime import datetime, timezone
|
|
from threading import Lock
|
|
|
|
import requests
|
|
from flask import Flask, request, Response
|
|
|
|
# WSGI application object; the route handlers below register against it.
app = Flask(__name__)
|
|
|
|
# ============================================================
|
|
# CONFIG
|
|
# ============================================================
|
|
|
|
class Config:
    """Runtime settings read from environment variables.

    Domain-like values (LOCAL_DOMAIN, DOMAIN_WHITELIST) are normalized to
    lower case without a trailing dot so that they compare equal to domains
    extracted from user IDs, which are normalized the same way.
    """

    # Variables that must be set (non-empty) for the service to start.
    REQUIRED_KEYS = ["TUWUNEL_URL", "LOCAL_DOMAIN"]

    def __init__(self):
        # Drop a trailing slash so f"{tuwunel_url}/_matrix/..." never yields "//".
        raw_url = os.getenv("TUWUNEL_URL")
        self.tuwunel_url = raw_url.strip().rstrip("/") if raw_url else raw_url
        self.local_domain = self._normalize_domain(os.getenv("LOCAL_DOMAIN"))

        self.admin_users = self._parse_list_env("ADMIN_USERS", [])
        self.admin_token = os.getenv("ADMIN_TOKEN", "")

        self.domain_whitelist = {
            self._normalize_domain(entry)
            for entry in self._parse_list_env("DOMAIN_WHITELIST", [])
        }

        self.block_external_dms = self._parse_bool("BLOCK_EXTERNAL_DMS", True)
        self.allow_room_creation = self._parse_bool("ALLOW_ROOM_CREATION", False)

        self.cache_ttl = self._parse_int("CACHE_TTL_SECONDS", 604800)
        self.rate_limit_per_minute = self._parse_int("RATE_LIMIT_PER_MINUTE", 20)
        self.http_timeout = self._parse_int("HTTP_TIMEOUT", 5)

        self.fail_open = self._parse_bool("FAIL_OPEN", True)
        self.debug = self._parse_bool("DEBUG", False)

    def validate(self):
        """Return the names of required variables that are unset or empty."""
        return [k for k in self.REQUIRED_KEYS if not os.getenv(k)]

    @staticmethod
    def _normalize_domain(value):
        """Lower-case *value* and strip whitespace and a trailing dot.

        Falsy input (unset variable) is returned unchanged.
        """
        if not value:
            return value
        return value.strip().lower().rstrip(".")

    def _parse_list_env(self, key, default=None):
        """Split a comma-separated variable into a list of non-empty items."""
        raw = os.getenv(key)
        if not raw:
            # Copy so callers never share (and mutate) the default list.
            return list(default) if default else []
        return [x.strip() for x in raw.split(",") if x.strip()]

    def _parse_bool(self, key, default):
        """Interpret 1/true/yes/on (any case, surrounding spaces ignored) as True."""
        val = os.getenv(key)
        if val is None:
            return default
        return val.strip().lower() in ("1", "true", "yes", "on")

    def _parse_int(self, key, default):
        """Read an integer variable; unset or blank falls back to *default*.

        Raises:
            ValueError: if the variable is set to a non-integer value. The
                message names the offending variable.
        """
        raw = os.getenv(key)
        if raw is None or not raw.strip():
            return default
        try:
            return int(raw.strip())
        except ValueError:
            raise ValueError(f"{key} must be an integer, got {raw!r}") from None
|
|
|
|
# ============================================================
|
|
# LOGGING
|
|
# ============================================================
|
|
|
|
# Accept the same truthy spellings as Config._parse_bool so DEBUG=1 / yes / on
# enable debug logging as well as DEBUG=true.
log_level = (
    logging.DEBUG
    if os.getenv("DEBUG", "false").strip().lower() in ("1", "true", "yes", "on")
    else logging.INFO
)

logging.basicConfig(level=log_level, format="%(message)s")

# Quieten third-party loggers.
logging.getLogger("werkzeug").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("gunicorn.access").setLevel(logging.WARNING)

logger = logging.getLogger("matrix-interceptor")
logger.setLevel(log_level)
# Records must not bubble up to the root handlers (avoids duplicate lines),
# so this logger needs a handler of its own: without one, only the logging
# module's last-resort handler runs, and that drops everything below WARNING.
logger.propagate = False
if not logger.handlers:
    _own_handler = logging.StreamHandler()
    _own_handler.setFormatter(logging.Formatter("%(message)s"))
    logger.addHandler(_own_handler)
|
|
|
|
def now_iso():
    """Return the current UTC time as an ISO 8601 string."""
    current = datetime.now(timezone.utc)
    return current.isoformat()
|
|
|
|
def log_event(event: str, **kwargs):
    """Log one line of the form ``<timestamp> EVENT=<event> k=v k=v ...``.

    Keyword arguments are rendered as space-separated ``key=value`` pairs in
    the order given and emitted at INFO level.
    """
    prefix = f"{now_iso()} EVENT={event}"
    fields = [f"{name}={value}" for name, value in kwargs.items()]
    logger.info(prefix + " " + " ".join(fields))
|
|
|
|
# ============================================================
|
|
# INIT
|
|
# ============================================================
|
|
|
|
# Build the settings once at import time and refuse to start when a required
# environment variable is absent.
config = Config()
missing = config.validate()

if missing:
    logger.error(f"Missing env vars: {missing}")
    raise SystemExit(1)
|
|
|
|
# ============================================================
|
|
# STATE + CACHE
|
|
# ============================================================
|
|
|
|
# user ID -> time.time() value recorded when the user was last remembered.
KNOWN_EXTERNAL_USERS = dict()
# rate-limit key -> timestamps of recent requests for that key.
RATE_LIMIT = defaultdict(list)
# counter name -> number of occurrences.
METRICS = defaultdict(int)

METRICS_LOCK = Lock()
CACHE_LOCK = Lock()

# On-disk copy of KNOWN_EXTERNAL_USERS and the flag that marks unsaved changes.
CACHE_FILE = "/app/cache/known_users.json"
CACHE_DIRTY = False
|
|
|
|
# ============================================================
|
|
# CACHE HANDLING
|
|
# ============================================================
|
|
|
|
def load_cache():
    """Populate KNOWN_EXTERNAL_USERS from CACHE_FILE.

    Creates the cache directory if needed. When the file does not exist an
    empty one is written. When the file cannot be read, is not valid JSON, or
    does not hold a JSON object, the in-memory cache is left empty and a
    warning is logged instead of failing silently.
    """
    os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)

    if not os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, "w") as f:
            json.dump({}, f)
        logger.info("Initialized empty cache file")
        return

    loaded = {}
    ok = False
    try:
        with open(CACHE_FILE, "r") as f:
            data = json.load(f)
        if isinstance(data, dict):
            # Keep only numeric timestamps; is_known_user subtracts them
            # from time.time().
            loaded = {
                uid: ts
                for uid, ts in data.items()
                if isinstance(ts, (int, float)) and not isinstance(ts, bool)
            }
            ok = True
        else:
            logger.warning("Cache file does not contain a JSON object; starting empty")
    except Exception as e:
        logger.warning(f"Could not load cache, starting empty: {e}")

    # Update the shared dict in place, under the same lock save_cache uses.
    with CACHE_LOCK:
        KNOWN_EXTERNAL_USERS.clear()
        KNOWN_EXTERNAL_USERS.update(loaded)

    if ok:
        logger.info(f"Loaded cache with {len(loaded)} users")
|
|
|
|
def save_cache():
    """Write KNOWN_EXTERNAL_USERS to CACHE_FILE if it has unsaved changes.

    The data is written to a temporary sibling file and then moved over the
    real file with os.replace, so an interrupted write cannot leave a
    truncated cache behind. Failures are logged and the dirty flag is kept
    set so that a later call retries.
    """
    global CACHE_DIRTY

    if not CACHE_DIRTY:
        return

    tmp_path = f"{CACHE_FILE}.tmp"

    try:
        with CACHE_LOCK:
            # Clear the flag before snapshotting: a change that lands while
            # we are writing re-sets it and is picked up by the next save.
            CACHE_DIRTY = False
            # Serialize a copy rather than iterating the live dict, which
            # other threads may be modifying.
            snapshot = dict(KNOWN_EXTERNAL_USERS)

            os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)
            with open(tmp_path, "w") as f:
                json.dump(snapshot, f)
            os.replace(tmp_path, CACHE_FILE)
    except Exception as e:
        CACHE_DIRTY = True
        logger.error(f"Failed to save cache: {e}")
|
|
|
|
def periodic_cache_save():
    """Loop forever, calling save_cache() and then sleeping 30 seconds."""
    pause_seconds = 30
    while True:
        save_cache()
        time.sleep(pause_seconds)
|
|
|
|
# ============================================================
|
|
# HELPERS
|
|
# ============================================================
|
|
|
|
def extract_domain(user_id):
    """Return the normalized server name of a ``@localpart:server`` ID.

    Everything after the first colon is taken as the server name, so an
    explicit port (``example.com:8448``) is preserved rather than cut off.
    The result is lower-cased with any trailing dots removed.

    Returns:
        The server name, or ``"unknown"`` when *user_id* is not a string or
        contains no colon.
    """
    try:
        # maxsplit=1: only the first colon separates localpart from server.
        return user_id.split(":", 1)[1].lower().rstrip(".")
    except (AttributeError, IndexError, TypeError):
        return "unknown"
|
|
|
|
def is_external(user_id):
    """Return True when *user_id* does not belong to the local domain.

    extract_domain() lower-cases its result and strips trailing dots, so the
    configured local domain is normalized the same way before comparing;
    otherwise a LOCAL_DOMAIN written with capitals would make every local
    user look external.
    """
    local = (config.local_domain or "").strip().lower().rstrip(".")
    # With no local domain configured nothing can be local.
    return not local or extract_domain(user_id) != local
|
|
|
|
def is_known_user(user_id):
    """Return True if *user_id* is cached and its entry has not expired.

    An entry older than ``config.cache_ttl`` seconds is removed, the cache is
    marked dirty so the removal gets persisted, and False is returned.
    """
    global CACHE_DIRTY

    ts = KNOWN_EXTERNAL_USERS.get(user_id)
    if not ts:
        return False

    if time.time() - ts <= config.cache_ttl:
        return True

    # pop() with a default instead of ``del``: another thread may already
    # have removed the entry. The lock keeps the removal from interleaving
    # with save_cache().
    with CACHE_LOCK:
        KNOWN_EXTERNAL_USERS.pop(user_id, None)
    CACHE_DIRTY = True
    return False
|
|
|
|
def remember_user(user_id):
    """Record *user_id* as known as of now and persist the cache.

    The timestamp is refreshed on every call, so the TTL checked by
    is_known_user() restarts each time the user is remembered.
    """
    global CACHE_DIRTY

    # Mutate under CACHE_LOCK so the write cannot interleave with the
    # serialization done inside save_cache(). The lock is released before
    # save_cache() is called because that function acquires it itself.
    with CACHE_LOCK:
        KNOWN_EXTERNAL_USERS[user_id] = time.time()
    CACHE_DIRTY = True
    save_cache()
|
|
|
|
def is_local_room(room_id):
    """Return True when the server part of *room_id* is the local domain.

    The part after the first colon is compared case-insensitively and
    without trailing dots, matching how extract_domain() normalizes user
    IDs. IDs without a colon, or non-string input, yield False.
    """
    try:
        room_domain = room_id.split(":", 1)[1]
    except (AttributeError, IndexError, TypeError):
        return False

    local = (config.local_domain or "").strip().lower().rstrip(".")
    return bool(local) and room_domain.strip().lower().rstrip(".") == local
|
|
|
|
def get_role(user_id):
    """Return "admin" for IDs listed in config.admin_users, else "user"."""
    if user_id in config.admin_users:
        return "admin"
    return "user"
|
|
|
|
# ============================================================
|
|
# RATE LIMIT (FIXED)
|
|
# ============================================================
|
|
|
|
# Guards RATE_LIMIT: the check-then-append below must not interleave between
# request threads.
_RATE_LIMIT_LOCK = Lock()
# Once the table holds more keys than this, idle keys are swept out.
_RATE_LIMIT_MAX_KEYS = 10000


def is_rate_limited(domain, sender):
    """Return True if *sender* already used its per-minute allowance.

    Keeps a 60-second sliding window of request timestamps per
    ``domain:sender`` key. A request that is allowed is recorded; one that
    is rejected is not. Keys with no activity in the last minute are purged
    when the table grows large, so senders that are never seen again do not
    accumulate forever.
    """
    key = f"{domain}:{sender}"
    now = time.time()

    with _RATE_LIMIT_LOCK:
        if len(RATE_LIMIT) > _RATE_LIMIT_MAX_KEYS:
            # Timestamps are appended in order, so the last one is the newest.
            idle = [
                k for k, hits in RATE_LIMIT.items()
                if not hits or now - hits[-1] >= 60
            ]
            for k in idle:
                del RATE_LIMIT[k]

        recent = [t for t in RATE_LIMIT[key] if now - t < 60]
        limited = len(recent) >= config.rate_limit_per_minute
        if not limited:
            recent.append(now)
        RATE_LIMIT[key] = recent
        return limited
|
|
|
|
# ============================================================
|
|
# SEED
|
|
# ============================================================
|
|
|
|
def seed_known_users():
    """Remember every external user that shares a local room.

    Uses ``config.admin_token`` to list the rooms joined by that account and,
    for each room whose ID is local, its joined members. Each external member
    is passed to remember_user() once per run, even when present in several
    rooms, so the reported count is the number of distinct users and the
    cache is not rewritten repeatedly for the same user. Does nothing when
    no admin token is configured. Errors are logged, never raised.
    """
    if not config.admin_token:
        return

    headers = {"Authorization": f"Bearer {config.admin_token}"}
    seen = set()

    try:
        rooms_res = requests.get(
            f"{config.tuwunel_url}/_matrix/client/v3/joined_rooms",
            headers=headers,
            timeout=10,
        )

        if rooms_res.status_code != 200:
            logger.warning(f"Seed skipped: joined_rooms returned {rooms_res.status_code}")
            return

        for room_id in rooms_res.json().get("joined_rooms", []):
            if not is_local_room(room_id):
                continue

            members_res = requests.get(
                f"{config.tuwunel_url}/_matrix/client/v3/rooms/{room_id}/joined_members",
                headers=headers,
                timeout=10,
            )

            if members_res.status_code != 200:
                continue

            for user_id in members_res.json().get("joined", {}):
                if user_id in seen or not is_external(user_id):
                    continue
                # Remember immediately so progress survives a later failure.
                seen.add(user_id)
                remember_user(user_id)

        logger.info(f"Seed refresh completed ({len(seen)} users)")

    except Exception as e:
        logger.error(f"Seed failed: {e}")
|
|
|
|
def periodic_seed():
    """Loop forever, calling seed_known_users() and then sleeping 300 seconds."""
    pause_seconds = 300
    while True:
        seed_known_users()
        time.sleep(pause_seconds)
|
|
|
|
# ============================================================
|
|
# FALLBACK
|
|
# ============================================================
|
|
|
|
def is_user_in_local_rooms(user_id):
    """Ask the homeserver whether *user_id* is joined to any local room.

    Lists the rooms joined by the ``config.admin_token`` account and checks
    the joined members of each room whose ID is local. Returns True on the
    first match. Returns False when no token is configured, when no room
    matches, or when any request or parse step raises (the error is logged).
    """
    if not config.admin_token:
        return False

    try:
        auth = {"Authorization": f"Bearer {config.admin_token}"}

        listing = requests.get(
            f"{config.tuwunel_url}/_matrix/client/v3/joined_rooms",
            headers=auth,
            timeout=5
        )

        for room_id in listing.json().get("joined_rooms", []):
            if is_local_room(room_id):
                roster = requests.get(
                    f"{config.tuwunel_url}/_matrix/client/v3/rooms/{room_id}/joined_members",
                    headers=auth,
                    timeout=5
                )
                joined = roster.json().get("joined", {})
                if user_id in joined:
                    return True

    except Exception as e:
        logger.error(f"Fallback check failed: {e}")

    return False
|
|
|
|
# ============================================================
|
|
# DM DETECTION
|
|
# ============================================================
|
|
|
|
def is_likely_dm_create(payload):
    """Heuristically decide whether a createRoom body describes a DM.

    True when any of these hold:
      * ``is_direct`` is exactly ``True``;
      * ``invite`` is a list with exactly one entry;
      * ``visibility`` is ``"private"`` and ``invite`` is non-empty.
    """
    invitees = payload.get("invite", [])

    flagged_direct = payload.get("is_direct") is True
    single_invitee = isinstance(invitees, list) and len(invitees) == 1
    private_with_invites = payload.get("visibility") == "private" and bool(invitees)

    return flagged_direct or single_invitee or private_with_invites
|
|
|
|
def is_likely_dm_event(event):
    """Heuristically decide whether an invite event belongs to a DM.

    True when ``content.is_direct`` is exactly ``True``. Otherwise counts the
    ``m.room.member`` entries in ``unsigned.invite_room_state`` and returns
    True when there are at most two, which includes the case where that list
    is missing or empty.
    """
    # NOTE(review): with no invite_room_state under event["unsigned"] this
    # returns True for every invite. Confirm callers pass events that carry
    # the stripped state there.
    if event.get("content", {}).get("is_direct") is True:
        return True

    stripped_state = event.get("unsigned", {}).get("invite_room_state", [])
    member_count = sum(
        1 for entry in stripped_state if entry.get("type") == "m.room.member"
    )
    return member_count <= 2
|
|
|
|
# ============================================================
|
|
# ROUTES
|
|
# ============================================================
|
|
|
|
@app.route("/healthz")
def health():
    """Report liveness together with the cache size and a copy of the counters."""
    report = {"status": "ok"}
    report["known_users"] = len(KNOWN_EXTERNAL_USERS)
    report["metrics"] = dict(METRICS)
    return report
|
|
|
|
@app.route("/metrics")
def metrics():
    """Return a plain-dict copy of the METRICS counters."""
    return {name: count for name, count in METRICS.items()}
|
|
|
|
def _resolve_user_id(auth_header):
    """Return the Matrix user ID owning the access token, or "unknown".

    The Authorization header carries an access token, not a user ID, so the
    homeserver's whoami endpoint is asked who the token belongs to. Any
    failure (missing header, non-200 answer, network error) yields "unknown".
    """
    if not auth_header:
        return "unknown"
    try:
        res = requests.get(
            f"{config.tuwunel_url}/_matrix/client/v3/account/whoami",
            headers={"Authorization": auth_header},
            timeout=config.http_timeout,
        )
        if res.status_code == 200:
            return res.json().get("user_id") or "unknown"
    except Exception as e:
        logger.error(f"whoami lookup failed: {e}")
    return "unknown"


@app.route('/_matrix/client/v3/createRoom', methods=['POST'])
def create_room():
    """Gate room creation, then forward allowed requests upstream.

    A request is allowed when the caller is listed in config.admin_users,
    when config.allow_room_creation is on, or when the body looks like a DM
    (see is_likely_dm_create). Anything else is answered with 403
    M_FORBIDDEN. The caller is identified through the homeserver's whoami
    endpoint; the raw Authorization header is never compared against the
    admin list nor written to the log.
    """
    payload = request.get_json(silent=True) or {}
    if not isinstance(payload, dict):
        # The helpers below call .get() on the body; reject other JSON types
        # with a proper error instead of failing with a 500.
        return Response(
            json.dumps({"errcode": "M_BAD_JSON", "error": "Request body must be a JSON object"}),
            status=400,
            mimetype="application/json",
        )

    user_id = _resolve_user_id(request.headers.get("Authorization"))
    domain = extract_domain(user_id)
    role = get_role(user_id)

    is_dm = is_likely_dm_create(payload)

    allowed = (
        user_id in config.admin_users
        or config.allow_room_creation
        or is_dm
    )

    if not allowed:
        with METRICS_LOCK:
            METRICS["create_room_blocked"] += 1

        log_event(
            "create_room_blocked",
            actor=user_id,
            domain=domain,
            role=role,
            reason="room_creation_disabled"
        )
        return Response(
            json.dumps({"errcode": "M_FORBIDDEN", "error": "Room creation is disabled"}),
            status=403,
            mimetype="application/json",
        )

    # Count and log every allowed creation, not only DMs.
    with METRICS_LOCK:
        METRICS["create_room_allowed"] += 1

    log_event(
        "create_room_allowed",
        actor=user_id,
        domain=domain,
        role=role,
        room_type="dm" if is_dm else "room"
    )

    return forward_request(
        "POST",
        f"{config.tuwunel_url}/_matrix/client/v3/createRoom",
        request.headers,
        payload
    )
|
|
|
|
@app.route('/_matrix/federation/v2/invite/<room_id>/<event_id>', methods=['PUT'])
def invite(room_id, event_id):
    """Filter an incoming federation invite before passing it upstream.

    Order of checks:
      1. 400 when the event carries no sender.
      2. 429 when the sender exceeded its rate limit.
      3. Senders from a whitelisted domain are allowed.
      4. When external DMs are blocked and the invite looks like a DM from an
         external sender that is not known yet, the sender is allowed only
         if the homeserver reports it in a local room; otherwise 403.
      5. Everything else is allowed.

    "Allowed" means the sender is remembered, the ``invite_allowed`` counter
    is incremented and the request is forwarded.
    """
    payload = request.get_json(force=True)
    event = payload.get("event", {})
    sender = event.get("sender", "")

    if not sender:
        return Response(status=400)

    domain = extract_domain(sender)

    if is_rate_limited(domain, sender):
        return Response(status=429)

    def allow():
        # Shared tail of every accepting branch.
        remember_user(sender)
        with METRICS_LOCK:
            METRICS["invite_allowed"] += 1
        return forward_request(
            "PUT",
            f"{config.tuwunel_url}/_matrix/federation/v2/invite/{room_id}/{event_id}",
            request.headers,
            payload
        )

    if domain in config.domain_whitelist:
        return allow()

    is_dm = is_likely_dm_event(event)
    needs_vetting = config.block_external_dms and is_dm and is_external(sender)

    if needs_vetting and not is_known_user(sender):
        if is_user_in_local_rooms(sender):
            return allow()

        with METRICS_LOCK:
            METRICS["invite_blocked"] += 1

        log_event(
            "invite_blocked",
            actor=sender,
            domain=domain,
            reason="unknown_external_user"
        )
        return Response(status=403)

    return allow()
|
|
|
|
# ============================================================
|
|
# FORWARD
|
|
# ============================================================
|
|
|
|
def forward_request(method, url, headers, body):
    """Send *body* as JSON to *url* and relay the upstream answer.

    Only the Authorization header of the incoming request is passed on. The
    upstream status code, body and Content-Type are returned to the caller,
    so JSON answers are no longer relabelled with Flask's default
    ``text/html``.

    Args:
        method: HTTP method to use upstream.
        url: Full upstream URL.
        headers: Incoming request headers (mapping-like).
        body: JSON-serializable request body.

    Returns:
        A flask Response. When the upstream call raises, the result is an
        empty 200 if ``config.fail_open`` is set, otherwise an empty 502.
    """
    try:
        proxy_headers = {"Content-Type": "application/json"}

        if "Authorization" in headers:
            proxy_headers["Authorization"] = headers["Authorization"]

        res = requests.request(
            method=method,
            url=url,
            headers=proxy_headers,
            json=body,
            timeout=config.http_timeout
        )

        return Response(
            res.content,
            status=res.status_code,
            content_type=res.headers.get("Content-Type", "application/json"),
        )

    except Exception as e:
        logger.error(f"proxy error: {e}")

        if config.fail_open:
            # NOTE(review): the request never reached the homeserver here, yet
            # the caller is told 200 with an empty body. Confirm this is the
            # intended meaning of FAIL_OPEN.
            logger.warning("FAIL_OPEN → allowing request")
            return Response(status=200)

        return Response(status=502)
|
|
|
|
# ============================================================
|
|
# START
|
|
# ============================================================
|
|
|
|
if __name__ == '__main__':
    # Warm the in-memory cache from disk and from the homeserver first.
    load_cache()
    seed_known_users()

    # Background maintenance loops; daemon threads end with the process.
    for worker in (periodic_seed, periodic_cache_save):
        threading.Thread(target=worker, daemon=True).start()

    app.run(host='0.0.0.0', port=5000)
|