Merge branch 'BATON-SEC-006-backend_dev'
This commit is contained in:
commit
2d7b99618c
3 changed files with 40 additions and 20 deletions
|
|
@ -1,5 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import AsyncGenerator, Optional
|
from typing import AsyncGenerator, Optional
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
@ -59,6 +60,12 @@ async def init_db() -> None:
|
||||||
ON signals(created_at);
|
ON signals(created_at);
|
||||||
CREATE INDEX IF NOT EXISTS idx_batches_status
|
CREATE INDEX IF NOT EXISTS idx_batches_status
|
||||||
ON telegram_batches(status);
|
ON telegram_batches(status);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS rate_limits (
|
||||||
|
ip TEXT NOT NULL PRIMARY KEY,
|
||||||
|
count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
window_start REAL NOT NULL DEFAULT 0
|
||||||
|
);
|
||||||
""")
|
""")
|
||||||
# Migrations for existing databases (silently ignore if columns already exist)
|
# Migrations for existing databases (silently ignore if columns already exist)
|
||||||
for stmt in [
|
for stmt in [
|
||||||
|
|
@ -228,6 +235,35 @@ async def admin_delete_user(user_id: int) -> bool:
|
||||||
return changed
|
return changed
|
||||||
|
|
||||||
|
|
||||||
|
async def rate_limit_increment(key: str, window: float) -> int:
|
||||||
|
"""Increment rate-limit counter for key within window. Returns current count.
|
||||||
|
|
||||||
|
Cleans up the stale record for this key before incrementing (TTL by window_start).
|
||||||
|
"""
|
||||||
|
now = time.time()
|
||||||
|
async with _get_conn() as conn:
|
||||||
|
# TTL cleanup: remove stale record for this key if window has expired
|
||||||
|
await conn.execute(
|
||||||
|
"DELETE FROM rate_limits WHERE ip = ? AND ? - window_start >= ?",
|
||||||
|
(key, now, window),
|
||||||
|
)
|
||||||
|
# Upsert: insert new record or increment existing
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO rate_limits (ip, count, window_start)
|
||||||
|
VALUES (?, 1, ?)
|
||||||
|
ON CONFLICT(ip) DO UPDATE SET count = count + 1
|
||||||
|
""",
|
||||||
|
(key, now),
|
||||||
|
)
|
||||||
|
await conn.commit()
|
||||||
|
async with conn.execute(
|
||||||
|
"SELECT count FROM rate_limits WHERE ip = ?", (key,)
|
||||||
|
) as cur:
|
||||||
|
row = await cur.fetchone()
|
||||||
|
return row["count"] if row else 1
|
||||||
|
|
||||||
|
|
||||||
async def save_telegram_batch(
|
async def save_telegram_batch(
|
||||||
message_text: str,
|
message_text: str,
|
||||||
signals_count: int,
|
signals_count: int,
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,6 @@ async def _keep_alive_loop(app_url: str) -> None:
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Startup
|
# Startup
|
||||||
app.state.rate_counters = {}
|
|
||||||
await db.init_db()
|
await db.init_db()
|
||||||
logger.info("Database initialized")
|
logger.info("Database initialized")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,12 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import secrets
|
import secrets
|
||||||
import time
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import Depends, Header, HTTPException, Request
|
from fastapi import Depends, Header, HTTPException, Request
|
||||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
|
|
||||||
from backend import config
|
from backend import config, db
|
||||||
|
|
||||||
_bearer = HTTPBearer(auto_error=False)
|
_bearer = HTTPBearer(auto_error=False)
|
||||||
|
|
||||||
|
|
@ -45,28 +44,14 @@ async def verify_admin_token(
|
||||||
|
|
||||||
|
|
||||||
async def rate_limit_register(request: Request) -> None:
|
async def rate_limit_register(request: Request) -> None:
|
||||||
counters = request.app.state.rate_counters
|
key = f"reg:{_get_client_ip(request)}"
|
||||||
client_ip = _get_client_ip(request)
|
count = await db.rate_limit_increment(key, _RATE_WINDOW)
|
||||||
now = time.time()
|
|
||||||
count, window_start = counters.get(client_ip, (0, now))
|
|
||||||
if now - window_start >= _RATE_WINDOW:
|
|
||||||
count = 0
|
|
||||||
window_start = now
|
|
||||||
count += 1
|
|
||||||
counters[client_ip] = (count, window_start)
|
|
||||||
if count > _RATE_LIMIT:
|
if count > _RATE_LIMIT:
|
||||||
raise HTTPException(status_code=429, detail="Too Many Requests")
|
raise HTTPException(status_code=429, detail="Too Many Requests")
|
||||||
|
|
||||||
|
|
||||||
async def rate_limit_signal(request: Request) -> None:
|
async def rate_limit_signal(request: Request) -> None:
|
||||||
counters = request.app.state.rate_counters
|
|
||||||
key = f"sig:{_get_client_ip(request)}"
|
key = f"sig:{_get_client_ip(request)}"
|
||||||
now = time.time()
|
count = await db.rate_limit_increment(key, _SIGNAL_RATE_WINDOW)
|
||||||
count, window_start = counters.get(key, (0, now))
|
|
||||||
if now - window_start >= _SIGNAL_RATE_WINDOW:
|
|
||||||
count = 0
|
|
||||||
window_start = now
|
|
||||||
count += 1
|
|
||||||
counters[key] = (count, window_start)
|
|
||||||
if count > _SIGNAL_RATE_LIMIT:
|
if count > _SIGNAL_RATE_LIMIT:
|
||||||
raise HTTPException(status_code=429, detail="Too Many Requests")
|
raise HTTPException(status_code=429, detail="Too Many Requests")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue