From 4b7e59d78d2a0c7ded154fd0c1195b26638fc657 Mon Sep 17 00:00:00 2001 From: Gros Frumos Date: Sat, 21 Mar 2026 07:56:44 +0200 Subject: [PATCH] kin: BATON-SEC-006-backend_dev --- backend/db.py | 36 ++++++++++++++++++++++++++++++++++++ backend/main.py | 1 - backend/middleware.py | 23 ++++------------------- 3 files changed, 40 insertions(+), 20 deletions(-) diff --git a/backend/db.py b/backend/db.py index e0aca18..bb1df49 100644 --- a/backend/db.py +++ b/backend/db.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time from contextlib import asynccontextmanager from typing import AsyncGenerator, Optional import aiosqlite @@ -59,6 +60,12 @@ async def init_db() -> None: ON signals(created_at); CREATE INDEX IF NOT EXISTS idx_batches_status ON telegram_batches(status); + + CREATE TABLE IF NOT EXISTS rate_limits ( + ip TEXT NOT NULL PRIMARY KEY, + count INTEGER NOT NULL DEFAULT 0, + window_start REAL NOT NULL DEFAULT 0 + ); """) # Migrations for existing databases (silently ignore if columns already exist) for stmt in [ @@ -228,6 +235,35 @@ async def admin_delete_user(user_id: int) -> bool: return changed +async def rate_limit_increment(key: str, window: float) -> int: + """Increment rate-limit counter for key within window. Returns current count. + + Cleans up the stale record for this key before incrementing (TTL by window_start). + """ + now = time.time() + async with _get_conn() as conn: + # TTL cleanup: remove stale record for this key if window has expired + await conn.execute( + "DELETE FROM rate_limits WHERE ip = ? AND ? - window_start >= ?", + (key, now, window), + ) + # Upsert: insert new record or increment existing + await conn.execute( + """ + INSERT INTO rate_limits (ip, count, window_start) + VALUES (?, 1, ?) + ON CONFLICT(ip) DO UPDATE SET count = count + 1 + """, + (key, now), + ) + await conn.commit() + async with conn.execute( + "SELECT count FROM rate_limits WHERE ip = ?", (key,) + ) as cur: + row = await cur.fetchone() + return row["count"] if row else 1 + + async def save_telegram_batch( message_text: str, signals_count: int, diff --git a/backend/main.py b/backend/main.py index 7c267d8..ed8ab90 100644 --- a/backend/main.py +++ b/backend/main.py @@ -59,7 +59,6 @@ async def _keep_alive_loop(app_url: str) -> None: @asynccontextmanager async def lifespan(app: FastAPI): # Startup - app.state.rate_counters = {} await db.init_db() logger.info("Database initialized") diff --git a/backend/middleware.py b/backend/middleware.py index 1a3aa39..b91b83e 100644 --- a/backend/middleware.py +++ b/backend/middleware.py @@ -1,13 +1,12 @@ from __future__ import annotations import secrets -import time from typing import Optional from fastapi import Depends, Header, HTTPException, Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -from backend import config +from backend import config, db _bearer = HTTPBearer(auto_error=False) @@ -45,28 +44,14 @@ async def verify_admin_token( async def rate_limit_register(request: Request) -> None: - counters = request.app.state.rate_counters - client_ip = _get_client_ip(request) - now = time.time() - count, window_start = counters.get(client_ip, (0, now)) - if now - window_start >= _RATE_WINDOW: - count = 0 - window_start = now - count += 1 - counters[client_ip] = (count, window_start) + key = f"reg:{_get_client_ip(request)}" + count = await db.rate_limit_increment(key, _RATE_WINDOW) if count > _RATE_LIMIT: raise HTTPException(status_code=429, detail="Too Many Requests") async def rate_limit_signal(request: Request) -> None: - counters = request.app.state.rate_counters key = f"sig:{_get_client_ip(request)}" - now = time.time() - count, window_start = counters.get(key, (0, now)) - if now - window_start >= _SIGNAL_RATE_WINDOW: - count = 0 - window_start = now - count += 1 - counters[key] = (count, window_start) + count = await db.rate_limit_increment(key, _SIGNAL_RATE_WINDOW) if count > _SIGNAL_RATE_LIMIT: raise HTTPException(status_code=429, detail="Too Many Requests")