Update app.py
This commit is contained in:
@@ -72,6 +72,7 @@ def log_event(event: str, **kwargs):
|
||||
details = " ".join(f"{k}={v}" for k, v in kwargs.items())
|
||||
logger.info(f"{base} {details}")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# INIT
|
||||
# ============================================================
|
||||
@@ -84,12 +85,32 @@ if missing:
|
||||
raise SystemExit(1)
|
||||
|
||||
# ============================================================
|
||||
# STATE
|
||||
# STATE + CACHE
|
||||
# ============================================================
|
||||
|
||||
KNOWN_EXTERNAL_USERS = {}
|
||||
RATE_LIMIT = defaultdict(list)
|
||||
|
||||
CACHE_FILE = "/app/cache/known_users.json"
|
||||
|
||||
def load_cache():
|
||||
global KNOWN_EXTERNAL_USERS
|
||||
if os.path.exists(CACHE_FILE):
|
||||
try:
|
||||
with open(CACHE_FILE, "r") as f:
|
||||
KNOWN_EXTERNAL_USERS = json.load(f)
|
||||
logger.info(f"Loaded cache with {len(KNOWN_EXTERNAL_USERS)} users")
|
||||
except:
|
||||
KNOWN_EXTERNAL_USERS = {}
|
||||
|
||||
def save_cache():
|
||||
try:
|
||||
os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)
|
||||
with open(CACHE_FILE, "w") as f:
|
||||
json.dump(KNOWN_EXTERNAL_USERS, f)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save cache: {e}")
|
||||
|
||||
# ============================================================
|
||||
# HELPERS
|
||||
# ============================================================
|
||||
@@ -107,13 +128,17 @@ def is_known_user(user_id):
|
||||
ts = KNOWN_EXTERNAL_USERS.get(user_id)
|
||||
if not ts:
|
||||
return False
|
||||
|
||||
if time.time() - ts > config.cache_ttl:
|
||||
del KNOWN_EXTERNAL_USERS[user_id]
|
||||
save_cache()
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def remember_user(user_id):
|
||||
KNOWN_EXTERNAL_USERS[user_id] = time.time()
|
||||
save_cache()
|
||||
|
||||
def is_local_room(room_id):
|
||||
try:
|
||||
@@ -121,6 +146,27 @@ def is_local_room(room_id):
|
||||
except:
|
||||
return False
|
||||
|
||||
def get_role(user_id):
|
||||
return "admin" if user_id in config.admin_users else "user"
|
||||
|
||||
# ============================================================
|
||||
# RATE LIMIT
|
||||
# ============================================================
|
||||
|
||||
def is_rate_limited(domain):
|
||||
now = time.time()
|
||||
|
||||
RATE_LIMIT[domain] = [
|
||||
t for t in RATE_LIMIT[domain]
|
||||
if now - t < 60
|
||||
]
|
||||
|
||||
if len(RATE_LIMIT[domain]) >= config.rate_limit_per_minute:
|
||||
return True
|
||||
|
||||
RATE_LIMIT[domain].append(now)
|
||||
return False
|
||||
|
||||
# ============================================================
|
||||
# SEED
|
||||
# ============================================================
|
||||
@@ -155,32 +201,27 @@ def seed_known_users():
|
||||
if members_res.status_code != 200:
|
||||
continue
|
||||
|
||||
members = members_res.json().get("joined", {})
|
||||
|
||||
for user_id in members.keys():
|
||||
if is_external(user_id):
|
||||
remember_user(user_id)
|
||||
seeded += 1
|
||||
|
||||
logger.info(f"Seed refreshed: {seeded} users")
|
||||
logger.info(f"Seed refresh completed ({seeded} users)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Seed failed: {e}")
|
||||
|
||||
# ============================================================
|
||||
# PERIODIC REFRESH
|
||||
# ============================================================
|
||||
|
||||
def periodic_seed():
|
||||
while True:
|
||||
seed_known_users()
|
||||
time.sleep(300)
|
||||
|
||||
# ============================================================
|
||||
# FALLBACK CHECK
|
||||
# FALLBACK
|
||||
# ============================================================
|
||||
|
||||
def is_user_in_local_rooms(user_id: str) -> bool:
|
||||
def is_user_in_local_rooms(user_id):
|
||||
if not config.admin_token:
|
||||
return False
|
||||
|
||||
@@ -193,9 +234,6 @@ def is_user_in_local_rooms(user_id: str) -> bool:
|
||||
timeout=5
|
||||
)
|
||||
|
||||
if rooms_res.status_code != 200:
|
||||
return False
|
||||
|
||||
for room_id in rooms_res.json().get("joined_rooms", []):
|
||||
if not is_local_room(room_id):
|
||||
continue
|
||||
@@ -206,11 +244,7 @@ def is_user_in_local_rooms(user_id: str) -> bool:
|
||||
timeout=5
|
||||
)
|
||||
|
||||
if members_res.status_code != 200:
|
||||
continue
|
||||
|
||||
members = members_res.json().get("joined", {})
|
||||
if user_id in members:
|
||||
if user_id in members_res.json().get("joined", {}):
|
||||
return True
|
||||
|
||||
except:
|
||||
@@ -219,25 +253,34 @@ def is_user_in_local_rooms(user_id: str) -> bool:
|
||||
return False
|
||||
|
||||
# ============================================================
|
||||
# 🔥 HYBRID DM DETECTION
|
||||
# DM DETECTION
|
||||
# ============================================================
|
||||
|
||||
def is_likely_dm_create(payload):
|
||||
if payload.get("is_direct") is True:
|
||||
return True
|
||||
|
||||
invite = payload.get("invite", [])
|
||||
|
||||
if isinstance(invite, list) and len(invite) == 1:
|
||||
return True
|
||||
|
||||
if payload.get("visibility") == "private" and invite:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def is_likely_dm_event(event):
|
||||
content = event.get("content", {})
|
||||
|
||||
# 1️⃣ offizielles Flag
|
||||
if content.get("is_direct") is True:
|
||||
return True
|
||||
|
||||
# 2️⃣ Heuristik
|
||||
unsigned = event.get("unsigned", {})
|
||||
state = unsigned.get("invite_room_state", [])
|
||||
|
||||
member_events = [
|
||||
e for e in state if e.get("type") == "m.room.member"
|
||||
]
|
||||
|
||||
return len(member_events) <= 2
|
||||
members = [e for e in state if e.get("type") == "m.room.member"]
|
||||
return len(members) <= 2
|
||||
|
||||
# ============================================================
|
||||
# ROUTES
|
||||
@@ -247,6 +290,48 @@ def is_likely_dm_event(event):
|
||||
def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.route('/_matrix/client/v3/createRoom', methods=['POST'])
|
||||
def create_room():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
user_id = request.headers.get("Authorization", "unknown")
|
||||
|
||||
domain = extract_domain(user_id)
|
||||
role = get_role(user_id)
|
||||
|
||||
is_dm = is_likely_dm_create(payload)
|
||||
|
||||
allowed = (
|
||||
user_id in config.admin_users
|
||||
or config.allow_room_creation
|
||||
or is_dm
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
log_event(
|
||||
"create_room_blocked",
|
||||
actor=user_id,
|
||||
domain=domain,
|
||||
role=role,
|
||||
reason="room_creation_disabled"
|
||||
)
|
||||
return Response(json.dumps({"errcode": "M_FORBIDDEN"}), status=403)
|
||||
|
||||
if is_dm:
|
||||
log_event(
|
||||
"create_room_allowed",
|
||||
actor=user_id,
|
||||
domain=domain,
|
||||
role=role,
|
||||
room_type="dm"
|
||||
)
|
||||
|
||||
return forward_request(
|
||||
"POST",
|
||||
f"{config.tuwunel_url}/_matrix/client/v3/createRoom",
|
||||
request.headers,
|
||||
payload
|
||||
)
|
||||
|
||||
@app.route('/_matrix/federation/v2/invite/<room_id>/<event_id>', methods=['PUT'])
|
||||
def invite(room_id, event_id):
|
||||
payload = request.get_json(force=True)
|
||||
@@ -259,7 +344,9 @@ def invite(room_id, event_id):
|
||||
|
||||
domain = extract_domain(sender)
|
||||
|
||||
# whitelist
|
||||
if is_rate_limited(domain):
|
||||
return Response(status=429)
|
||||
|
||||
if domain in config.domain_whitelist:
|
||||
remember_user(sender)
|
||||
return forward_request(
|
||||
@@ -273,7 +360,6 @@ def invite(room_id, event_id):
|
||||
|
||||
if not is_known_user(sender):
|
||||
|
||||
# 🔥 fallback check
|
||||
if is_user_in_local_rooms(sender):
|
||||
remember_user(sender)
|
||||
else:
|
||||
@@ -301,6 +387,7 @@ def invite(room_id, event_id):
|
||||
def forward_request(method, url, headers, body):
|
||||
try:
|
||||
proxy_headers = {"Content-Type": "application/json"}
|
||||
|
||||
if "Authorization" in headers:
|
||||
proxy_headers["Authorization"] = headers["Authorization"]
|
||||
|
||||
@@ -316,6 +403,11 @@ def forward_request(method, url, headers, body):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"proxy error: {e}")
|
||||
|
||||
if config.fail_open:
|
||||
logger.warning("FAIL_OPEN → allowing request")
|
||||
return Response(status=200)
|
||||
|
||||
return Response(status=502)
|
||||
|
||||
# ============================================================
|
||||
@@ -323,11 +415,9 @@ def forward_request(method, url, headers, body):
|
||||
# ============================================================
|
||||
|
||||
if __name__ == '__main__':
|
||||
load_cache()
|
||||
seed_known_users()
|
||||
|
||||
threading.Thread(
|
||||
target=periodic_seed,
|
||||
daemon=True
|
||||
).start()
|
||||
threading.Thread(target=periodic_seed, daemon=True).start()
|
||||
|
||||
app.run(host='0.0.0.0', port=5000)
|
||||
|
||||
Reference in New Issue
Block a user