from __future__ import annotations import base64 import hashlib import hmac import json import secrets import time from typing import Optional from fastapi import Depends, Header, HTTPException, Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from backend import config, db # JWT secret: stable across restarts if JWT_SECRET env var is set; random per-process otherwise _JWT_SECRET: str = config.JWT_SECRET or secrets.token_hex(32) _JWT_HEADER_B64: str = ( base64.urlsafe_b64encode(b'{"alg":"HS256","typ":"JWT"}').rstrip(b"=").decode() ) _bearer = HTTPBearer(auto_error=False) _RATE_LIMIT = 5 _RATE_WINDOW = 600 # 10 minutes _SIGNAL_RATE_LIMIT = 10 _SIGNAL_RATE_WINDOW = 60 # 1 minute _AUTH_REGISTER_RATE_LIMIT = 3 _AUTH_REGISTER_RATE_WINDOW = 600 # 10 minutes def _get_client_ip(request: Request) -> str: return ( request.headers.get("X-Real-IP") or request.headers.get("X-Forwarded-For", "").split(",")[0].strip() or (request.client.host if request.client else "unknown") ) async def check_ip_not_blocked(request: Request) -> None: ip = _get_client_ip(request) if await db.is_ip_blocked(ip): raise HTTPException(status_code=403, detail="Доступ запрещён") async def verify_webhook_secret( x_telegram_bot_api_secret_token: str = Header(default=""), ) -> None: if not secrets.compare_digest( x_telegram_bot_api_secret_token, config.WEBHOOK_SECRET ): raise HTTPException(status_code=403, detail="Forbidden") async def verify_admin_token( credentials: Optional[HTTPAuthorizationCredentials] = Depends(_bearer), ) -> None: if credentials is None or not secrets.compare_digest( credentials.credentials, config.ADMIN_TOKEN ): raise HTTPException(status_code=401, detail="Unauthorized") async def rate_limit_register(request: Request) -> None: key = f"reg:{_get_client_ip(request)}" count = await db.rate_limit_increment(key, _RATE_WINDOW) if count > _RATE_LIMIT: raise HTTPException(status_code=429, detail="Too Many Requests") async def rate_limit_signal(request: Request) -> None: key = f"sig:{_get_client_ip(request)}" count = await db.rate_limit_increment(key, _SIGNAL_RATE_WINDOW) if count > _SIGNAL_RATE_LIMIT: raise HTTPException(status_code=429, detail="Too Many Requests") async def rate_limit_auth_register(request: Request) -> None: key = f"authreg:{_get_client_ip(request)}" count = await db.rate_limit_increment(key, _AUTH_REGISTER_RATE_WINDOW) if count > _AUTH_REGISTER_RATE_LIMIT: raise HTTPException(status_code=429, detail="Too Many Requests") _AUTH_LOGIN_RATE_LIMIT = 5 _AUTH_LOGIN_RATE_WINDOW = 900 # 15 minutes def _b64url_encode(data: bytes) -> str: return base64.urlsafe_b64encode(data).rstrip(b"=").decode() def _b64url_decode(s: str) -> bytes: padding = 4 - len(s) % 4 if padding != 4: s += "=" * padding return base64.urlsafe_b64decode(s) def create_auth_token(reg_id: int, login: str) -> str: """Create a signed HS256 JWT for an approved registration.""" now = int(time.time()) payload = { "sub": str(reg_id), "login": login, "iat": now, "exp": now + config.JWT_TOKEN_EXPIRE_SECONDS, } payload_b64 = _b64url_encode(json.dumps(payload, separators=(",", ":")).encode()) signing_input = f"{_JWT_HEADER_B64}.{payload_b64}" sig = hmac.new( _JWT_SECRET.encode(), signing_input.encode(), hashlib.sha256 ).digest() return f"{signing_input}.{_b64url_encode(sig)}" def _verify_jwt_token(token: str) -> dict: """Verify token signature and expiry. Returns payload dict on success.""" parts = token.split(".") if len(parts) != 3: raise ValueError("Invalid token format") header_b64, payload_b64, sig_b64 = parts signing_input = f"{header_b64}.{payload_b64}" expected_sig = hmac.new( _JWT_SECRET.encode(), signing_input.encode(), hashlib.sha256 ).digest() actual_sig = _b64url_decode(sig_b64) if not hmac.compare_digest(expected_sig, actual_sig): raise ValueError("Invalid signature") payload = json.loads(_b64url_decode(payload_b64)) if payload.get("exp", 0) < time.time(): raise ValueError("Token expired") return payload async def rate_limit_auth_login(request: Request) -> None: key = f"login:{_get_client_ip(request)}" count = await db.rate_limit_increment(key, _AUTH_LOGIN_RATE_WINDOW) if count > _AUTH_LOGIN_RATE_LIMIT: raise HTTPException(status_code=429, detail="Too Many Requests") async def verify_auth_token( credentials: Optional[HTTPAuthorizationCredentials] = Depends(_bearer), ) -> dict: """Dependency for protected endpoints — verifies Bearer JWT, returns payload.""" if credentials is None: raise HTTPException(status_code=401, detail="Unauthorized") try: payload = _verify_jwt_token(credentials.credentials) except Exception: raise HTTPException(status_code=401, detail="Unauthorized") return payload