baton/backend/middleware.py
Gros Frumos 0562cb4e47 sec: server-side email domain check + IP block on violations
Only @tutlot.com emails allowed for registration (checked server-side,
invisible to frontend inspect). Wrong domain → scary message + IP
violation tracked. 5 violations → IP permanently blocked from login
and registration. Block screen with OK button on frontend.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-21 15:58:16 +02:00

155 lines
5 KiB
Python

from __future__ import annotations
import base64
import hashlib
import hmac
import json
import secrets
import time
from typing import Optional
from fastapi import Depends, Header, HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from backend import config, db
# JWT signing secret. Stable across restarts when config.JWT_SECRET is set;
# otherwise a fresh random hex secret is generated per process, so tokens
# signed before a restart (or by another process) will not verify.
_JWT_SECRET: str = config.JWT_SECRET or secrets.token_hex(32)
# Unpadded base64url encoding of the fixed JWT header, computed once at import.
_JWT_HEADER_B64: str = (
base64.urlsafe_b64encode(b'{"alg":"HS256","typ":"JWT"}').rstrip(b"=").decode()
)
# Bearer-credentials dependency created with auto_error=False; the verify_*
# dependencies below check for a None result and raise their own 401.
_bearer = HTTPBearer(auto_error=False)
# Limit/window used by rate_limit_register (counter key prefix "reg:").
_RATE_LIMIT = 5
_RATE_WINDOW = 600 # 10 minutes
# Limit/window used by rate_limit_signal (counter key prefix "sig:").
_SIGNAL_RATE_LIMIT = 10
_SIGNAL_RATE_WINDOW = 60 # 1 minute
# Limit/window used by rate_limit_auth_register (counter key prefix "authreg:").
_AUTH_REGISTER_RATE_LIMIT = 3
_AUTH_REGISTER_RATE_WINDOW = 600 # 10 minutes
def _get_client_ip(request: Request) -> str:
    """Return the best available client address for *request*.

    Sources are tried in order and the first non-empty one wins:
    the ``X-Real-IP`` header, the first comma-separated entry of
    ``X-Forwarded-For`` (whitespace-stripped), the peer address in
    ``request.client``, and finally the literal ``"unknown"``.

    NOTE(review): both headers are supplied by the caller unless a reverse
    proxy overwrites them -- confirm the deployment does so, since this value
    keys the rate limiters and the IP block check.
    """
    real_ip = request.headers.get("X-Real-IP")
    if real_ip:
        return real_ip

    forwarded_chain = request.headers.get("X-Forwarded-For", "")
    first_hop = forwarded_chain.split(",")[0].strip()
    if first_hop:
        return first_hop

    if request.client:
        return request.client.host
    return "unknown"
async def check_ip_not_blocked(request: Request) -> None:
    """Dependency that rejects requests coming from a blocked address.

    Resolves the client address with ``_get_client_ip`` and asks
    ``db.is_ip_blocked`` about it.

    Raises:
        HTTPException: 403 when the lookup reports the address as blocked.
    """
    blocked = await db.is_ip_blocked(_get_client_ip(request))
    if not blocked:
        return
    # The detail text is Russian for "Access denied".
    raise HTTPException(status_code=403, detail="Доступ запрещён")
async def verify_webhook_secret(
    x_telegram_bot_api_secret_token: str = Header(default=""),
) -> None:
    """Dependency that checks the webhook secret-token header.

    Args:
        x_telegram_bot_api_secret_token: Header value injected by FastAPI's
            ``Header``; defaults to an empty string when the header is absent.

    Raises:
        HTTPException: 403 when the value differs from
            ``config.WEBHOOK_SECRET``.

    NOTE(review): if ``config.WEBHOOK_SECRET`` is an empty string, a request
    without the header compares equal and is accepted -- confirm the secret
    is always configured in production.
    """
    # Compare as bytes: compare_digest raises TypeError for str arguments
    # containing non-ASCII characters, which would turn a malformed header
    # into an unhandled server error instead of a clean 403.
    supplied = x_telegram_bot_api_secret_token.encode("utf-8")
    expected = config.WEBHOOK_SECRET.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(status_code=403, detail="Forbidden")
async def verify_admin_token(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(_bearer),
) -> None:
    """Dependency that authorizes admin endpoints with a static bearer token.

    Args:
        credentials: Result of the ``_bearer`` dependency; ``None`` when it
            produced no credentials.

    Raises:
        HTTPException: 401 when credentials are missing or the token differs
            from ``config.ADMIN_TOKEN``.
    """
    if credentials is None:
        raise HTTPException(status_code=401, detail="Unauthorized")

    # Compare as bytes: compare_digest raises TypeError for str arguments
    # containing non-ASCII characters, which would surface as an unhandled
    # server error rather than the intended 401.
    supplied = credentials.credentials.encode("utf-8")
    expected = config.ADMIN_TOKEN.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Unauthorized")
async def rate_limit_register(request: Request) -> None:
    """Per-client rate limit stored under the ``reg:`` key prefix.

    Bumps the counter for the caller's address via
    ``db.rate_limit_increment`` using ``_RATE_WINDOW`` (presumably the
    window length in seconds -- confirm against the db helper).

    Raises:
        HTTPException: 429 when the returned count exceeds ``_RATE_LIMIT``.
    """
    client_ip = _get_client_ip(request)
    hits = await db.rate_limit_increment(f"reg:{client_ip}", _RATE_WINDOW)
    if hits > _RATE_LIMIT:
        raise HTTPException(status_code=429, detail="Too Many Requests")
async def rate_limit_signal(request: Request) -> None:
    """Per-client rate limit stored under the ``sig:`` key prefix.

    Bumps the counter for the caller's address via
    ``db.rate_limit_increment`` using ``_SIGNAL_RATE_WINDOW``.

    Raises:
        HTTPException: 429 when the returned count exceeds
            ``_SIGNAL_RATE_LIMIT``.
    """
    client_ip = _get_client_ip(request)
    hits = await db.rate_limit_increment(f"sig:{client_ip}", _SIGNAL_RATE_WINDOW)
    if hits > _SIGNAL_RATE_LIMIT:
        raise HTTPException(status_code=429, detail="Too Many Requests")
async def rate_limit_auth_register(request: Request) -> None:
    """Per-client rate limit stored under the ``authreg:`` key prefix.

    Bumps the counter for the caller's address via
    ``db.rate_limit_increment`` using ``_AUTH_REGISTER_RATE_WINDOW``.

    Raises:
        HTTPException: 429 when the returned count exceeds
            ``_AUTH_REGISTER_RATE_LIMIT``.
    """
    client_ip = _get_client_ip(request)
    hits = await db.rate_limit_increment(
        f"authreg:{client_ip}", _AUTH_REGISTER_RATE_WINDOW
    )
    if hits > _AUTH_REGISTER_RATE_LIMIT:
        raise HTTPException(status_code=429, detail="Too Many Requests")
# Limit/window used by rate_limit_auth_login (counter key prefix "login:").
_AUTH_LOGIN_RATE_LIMIT = 5
_AUTH_LOGIN_RATE_WINDOW = 900 # 15 minutes
def _b64url_encode(data: bytes) -> str:
    """Encode *data* as URL-safe base64 text with the ``=`` padding removed."""
    padded = base64.urlsafe_b64encode(data).decode()
    return padded.rstrip("=")
def _b64url_decode(s: str) -> bytes:
    """Decode URL-safe base64 text whose ``=`` padding may have been stripped."""
    # Pad up to the next multiple of four characters (adds nothing when the
    # length is already aligned).
    missing = -len(s) % 4
    return base64.urlsafe_b64decode(s + "=" * missing)
def create_auth_token(reg_id: int, login: str) -> str:
    """Build a compact HS256-signed JWT for a registration.

    Args:
        reg_id: Registration id; stored as a string in the ``sub`` claim.
        login: Login name; stored in the ``login`` claim.

    Returns:
        ``header.payload.signature`` where each part is unpadded base64url.
        ``iat`` is the current Unix time and ``exp`` is ``iat`` plus
        ``config.JWT_TOKEN_EXPIRE_SECONDS``.
    """
    issued_at = int(time.time())
    claims = {
        "sub": str(reg_id),
        "login": login,
        "iat": issued_at,
        "exp": issued_at + config.JWT_TOKEN_EXPIRE_SECONDS,
    }
    # Compact separators keep the serialized payload free of whitespace.
    claims_json = json.dumps(claims, separators=(",", ":"))
    unsigned = ".".join((_JWT_HEADER_B64, _b64url_encode(claims_json.encode())))

    # The MAC covers the exact "header.payload" text that is emitted.
    mac = hmac.new(_JWT_SECRET.encode(), unsigned.encode(), hashlib.sha256)
    return f"{unsigned}.{_b64url_encode(mac.digest())}"
def _verify_jwt_token(token: str) -> dict:
    """Check a compact JWT's HMAC-SHA256 signature and expiry.

    Returns:
        The decoded payload once the signature matches and ``exp`` is not in
        the past. A payload without ``exp`` is treated as expired.

    Raises:
        ValueError: when the token does not consist of three dot-separated
            parts, the signature does not match, or the token has expired.
            Errors raised while decoding base64 or JSON propagate unchanged.
    """
    segments = token.split(".")
    if len(segments) != 3:
        raise ValueError("Invalid token format")
    head_seg, body_seg, sig_seg = segments

    # Recompute the MAC over the exact "header.payload" text that was received.
    mac = hmac.new(
        _JWT_SECRET.encode(), f"{head_seg}.{body_seg}".encode(), hashlib.sha256
    )
    presented = _b64url_decode(sig_seg)
    if not hmac.compare_digest(mac.digest(), presented):
        raise ValueError("Invalid signature")

    # The payload is parsed only after the signature has been accepted.
    claims = json.loads(_b64url_decode(body_seg))
    if claims.get("exp", 0) < time.time():
        raise ValueError("Token expired")
    return claims
async def rate_limit_auth_login(request: Request) -> None:
    """Per-client rate limit stored under the ``login:`` key prefix.

    Bumps the counter for the caller's address via
    ``db.rate_limit_increment`` using ``_AUTH_LOGIN_RATE_WINDOW``.

    Raises:
        HTTPException: 429 when the returned count exceeds
            ``_AUTH_LOGIN_RATE_LIMIT``.
    """
    client_ip = _get_client_ip(request)
    hits = await db.rate_limit_increment(
        f"login:{client_ip}", _AUTH_LOGIN_RATE_WINDOW
    )
    if hits > _AUTH_LOGIN_RATE_LIMIT:
        raise HTTPException(status_code=429, detail="Too Many Requests")
async def verify_auth_token(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(_bearer),
) -> dict:
    """Dependency for protected endpoints: validate the bearer JWT.

    Returns:
        The payload dict produced by ``_verify_jwt_token``.

    Raises:
        HTTPException: 401 when no credentials were supplied or when token
            verification fails for any reason.
    """
    if credentials is None:
        raise HTTPException(status_code=401, detail="Unauthorized")
    try:
        return _verify_jwt_token(credentials.credentials)
    except Exception:
        # Any verification failure is reported as the same generic 401 so the
        # response does not reveal why the token was rejected.
        raise HTTPException(status_code=401, detail="Unauthorized")