kin: BATON-SEC-006-backend_dev

This commit is contained in:
Gros Frumos 2026-03-21 07:56:44 +02:00
parent 8629f3e40b
commit ee966dd148
3 changed files with 40 additions and 20 deletions

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import time
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional from typing import AsyncGenerator, Optional
import aiosqlite import aiosqlite
@ -59,6 +60,12 @@ async def init_db() -> None:
ON signals(created_at); ON signals(created_at);
CREATE INDEX IF NOT EXISTS idx_batches_status CREATE INDEX IF NOT EXISTS idx_batches_status
ON telegram_batches(status); ON telegram_batches(status);
CREATE TABLE IF NOT EXISTS rate_limits (
ip TEXT NOT NULL PRIMARY KEY,
count INTEGER NOT NULL DEFAULT 0,
window_start REAL NOT NULL DEFAULT 0
);
""") """)
# Migrations for existing databases (silently ignore if columns already exist) # Migrations for existing databases (silently ignore if columns already exist)
for stmt in [ for stmt in [
@ -228,6 +235,35 @@ async def admin_delete_user(user_id: int) -> bool:
return changed return changed
async def rate_limit_increment(key: str, window: float) -> int:
"""Increment rate-limit counter for key within window. Returns current count.
Cleans up the stale record for this key before incrementing (TTL by window_start).
"""
now = time.time()
async with _get_conn() as conn:
# TTL cleanup: remove stale record for this key if window has expired
await conn.execute(
"DELETE FROM rate_limits WHERE ip = ? AND ? - window_start >= ?",
(key, now, window),
)
# Upsert: insert new record or increment existing
await conn.execute(
"""
INSERT INTO rate_limits (ip, count, window_start)
VALUES (?, 1, ?)
ON CONFLICT(ip) DO UPDATE SET count = count + 1
""",
(key, now),
)
await conn.commit()
async with conn.execute(
"SELECT count FROM rate_limits WHERE ip = ?", (key,)
) as cur:
row = await cur.fetchone()
return row["count"] if row else 1
async def save_telegram_batch( async def save_telegram_batch(
message_text: str, message_text: str,
signals_count: int, signals_count: int,

View file

@ -59,7 +59,6 @@ async def _keep_alive_loop(app_url: str) -> None:
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# Startup # Startup
app.state.rate_counters = {}
await db.init_db() await db.init_db()
logger.info("Database initialized") logger.info("Database initialized")

View file

@ -1,13 +1,12 @@
from __future__ import annotations from __future__ import annotations
import secrets import secrets
import time
from typing import Optional from typing import Optional
from fastapi import Depends, Header, HTTPException, Request from fastapi import Depends, Header, HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from backend import config from backend import config, db
_bearer = HTTPBearer(auto_error=False) _bearer = HTTPBearer(auto_error=False)
@ -45,28 +44,14 @@ async def verify_admin_token(
async def rate_limit_register(request: Request) -> None: async def rate_limit_register(request: Request) -> None:
counters = request.app.state.rate_counters key = f"reg:{_get_client_ip(request)}"
client_ip = _get_client_ip(request) count = await db.rate_limit_increment(key, _RATE_WINDOW)
now = time.time()
count, window_start = counters.get(client_ip, (0, now))
if now - window_start >= _RATE_WINDOW:
count = 0
window_start = now
count += 1
counters[client_ip] = (count, window_start)
if count > _RATE_LIMIT: if count > _RATE_LIMIT:
raise HTTPException(status_code=429, detail="Too Many Requests") raise HTTPException(status_code=429, detail="Too Many Requests")
async def rate_limit_signal(request: Request) -> None: async def rate_limit_signal(request: Request) -> None:
counters = request.app.state.rate_counters
key = f"sig:{_get_client_ip(request)}" key = f"sig:{_get_client_ip(request)}"
now = time.time() count = await db.rate_limit_increment(key, _SIGNAL_RATE_WINDOW)
count, window_start = counters.get(key, (0, now))
if now - window_start >= _SIGNAL_RATE_WINDOW:
count = 0
window_start = now
count += 1
counters[key] = (count, window_start)
if count > _SIGNAL_RATE_LIMIT: if count > _SIGNAL_RATE_LIMIT:
raise HTTPException(status_code=429, detail="Too Many Requests") raise HTTPException(status_code=429, detail="Too Many Requests")