kin: BATON-BIZ-001-backend_dev

This commit is contained in:
Gros Frumos 2026-03-21 12:42:13 +02:00
parent e266b6506e
commit ea06309a6e
5 changed files with 160 additions and 1 deletions

View file

@ -1,6 +1,11 @@
from __future__ import annotations
import base64
import hashlib
import hmac
import json
import secrets
import time
from typing import Optional
from fastapi import Depends, Header, HTTPException, Request
@ -8,6 +13,12 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from backend import config, db
# JWT secret: stable across restarts if JWT_SECRET env var is set; random per-process otherwise
_JWT_SECRET: str = config.JWT_SECRET or secrets.token_hex(32)
_JWT_HEADER_B64: str = (
base64.urlsafe_b64encode(b'{"alg":"HS256","typ":"JWT"}').rstrip(b"=").decode()
)
_bearer = HTTPBearer(auto_error=False)
_RATE_LIMIT = 5
@ -65,3 +76,74 @@ async def rate_limit_auth_register(request: Request) -> None:
count = await db.rate_limit_increment(key, _AUTH_REGISTER_RATE_WINDOW)
if count > _AUTH_REGISTER_RATE_LIMIT:
raise HTTPException(status_code=429, detail="Too Many Requests")
_AUTH_LOGIN_RATE_LIMIT = 5
_AUTH_LOGIN_RATE_WINDOW = 900 # 15 minutes
def _b64url_encode(data: bytes) -> str:
return base64.urlsafe_b64encode(data).rstrip(b"=").decode()
def _b64url_decode(s: str) -> bytes:
padding = 4 - len(s) % 4
if padding != 4:
s += "=" * padding
return base64.urlsafe_b64decode(s)
def create_auth_token(reg_id: int, login: str) -> str:
"""Create a signed HS256 JWT for an approved registration."""
now = int(time.time())
payload = {
"sub": str(reg_id),
"login": login,
"iat": now,
"exp": now + config.JWT_TOKEN_EXPIRE_SECONDS,
}
payload_b64 = _b64url_encode(json.dumps(payload, separators=(",", ":")).encode())
signing_input = f"{_JWT_HEADER_B64}.{payload_b64}"
sig = hmac.new(
_JWT_SECRET.encode(), signing_input.encode(), hashlib.sha256
).digest()
return f"{signing_input}.{_b64url_encode(sig)}"
def _verify_jwt_token(token: str) -> dict:
"""Verify token signature and expiry. Returns payload dict on success."""
parts = token.split(".")
if len(parts) != 3:
raise ValueError("Invalid token format")
header_b64, payload_b64, sig_b64 = parts
signing_input = f"{header_b64}.{payload_b64}"
expected_sig = hmac.new(
_JWT_SECRET.encode(), signing_input.encode(), hashlib.sha256
).digest()
actual_sig = _b64url_decode(sig_b64)
if not hmac.compare_digest(expected_sig, actual_sig):
raise ValueError("Invalid signature")
payload = json.loads(_b64url_decode(payload_b64))
if payload.get("exp", 0) < time.time():
raise ValueError("Token expired")
return payload
async def rate_limit_auth_login(request: Request) -> None:
key = f"login:{_get_client_ip(request)}"
count = await db.rate_limit_increment(key, _AUTH_LOGIN_RATE_WINDOW)
if count > _AUTH_LOGIN_RATE_LIMIT:
raise HTTPException(status_code=429, detail="Too Many Requests")
async def verify_auth_token(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(_bearer),
) -> dict:
"""Dependency for protected endpoints — verifies Bearer JWT, returns payload."""
if credentials is None:
raise HTTPException(status_code=401, detail="Unauthorized")
try:
payload = _verify_jwt_token(credentials.credentials)
except Exception:
raise HTTPException(status_code=401, detail="Unauthorized")
return payload