baton/tests/test_sec_006.py

338 lines
12 KiB
Python
Raw Permalink Normal View History

"""
Tests for BATON-SEC-006: persistent storage of rate-limit counters.

Acceptance criteria:
1. Counters survive re-creation of the application instance (simulated restart).
2. TTL cleanup correctly resets stale entries once the window has expired.
3. Exceeding the limit returns HTTP 429.
4. X-Real-IP and X-Forwarded-For are parsed correctly for counting.

UUID note: All UUIDs below satisfy the v4 pattern validated since BATON-SEC-005.
"""
from __future__ import annotations
import os
# Populate the settings the backend expects BEFORE the `backend` imports further
# down (presumably backend.config reads them at import time — confirm).
# setdefault() keeps any value already provided by the real environment.
os.environ.setdefault("BOT_TOKEN", "test-bot-token")
os.environ.setdefault("CHAT_ID", "-1001234567890")
os.environ.setdefault("WEBHOOK_SECRET", "test-webhook-secret")
os.environ.setdefault("WEBHOOK_URL", "https://example.com/api/webhook/telegram")
os.environ.setdefault("FRONTEND_ORIGIN", "http://localhost:3000")
os.environ.setdefault("ADMIN_TOKEN", "test-admin-token")
import tempfile
import unittest.mock as mock
import aiosqlite
import pytest
from backend import config, db
from tests.conftest import make_app_client
# ── Valid UUID v4 constants ──────────────────────────────────────────────────
_UUID_XREALIP_A = "c0000001-0000-4000-8000-000000000001" # X-Real-IP exhaustion
_UUID_XREALIP_B = "c0000002-0000-4000-8000-000000000002" # IP-B (independent counter)
_UUID_XFWD = "c0000003-0000-4000-8000-000000000003" # X-Forwarded-For test
_UUID_REG_RL = "c0000004-0000-4000-8000-000000000004" # register 429 test
# ── Helpers ──────────────────────────────────────────────────────────────────
def _tmpdb() -> str:
    """Point ``config.DB_PATH`` at a fresh temporary SQLite file and return its path.

    The file is created atomically with ``tempfile.mkstemp`` instead of the
    deprecated, race-prone ``tempfile.mktemp``. That older call only returned an
    unused name, which another process could claim before SQLite opened it.
    The file is left empty, and SQLite treats a zero-length file as a new,
    empty database. The tests in this module remove it (plus any
    ``-wal``/``-shm`` side files) with ``_cleanup``.

    Note: ``config.DB_PATH`` is reassigned globally and is not restored here.
    """
    fd, path = tempfile.mkstemp(suffix=".db")
    # Only the path is needed; SQLite opens the file itself.
    os.close(fd)
    config.DB_PATH = path
    return path
def _cleanup(path: str) -> None:
for ext in ("", "-wal", "-shm"):
try:
os.unlink(path + ext)
except FileNotFoundError:
pass
# ── Criterion 1: Persistence across restart ───────────────────────────────────
@pytest.mark.asyncio
async def test_rate_limits_table_created_by_init_db():
    """After init_db(), sqlite_master must list a table named ``rate_limits``."""
    db_file = _tmpdb()
    try:
        await db.init_db()
        async with aiosqlite.connect(db_file) as conn:
            cursor = await conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='rate_limits'"
            )
            found = await cursor.fetchone()
            await cursor.close()
        assert found is not None, "rate_limits table not found after init_db()"
    finally:
        _cleanup(db_file)
@pytest.mark.asyncio
async def test_rate_limit_counter_persists_after_db_reinit():
    """Counter survives re-initialization of the DB (simulates app restart).

    Before: in-memory app.state.rate_counters was lost on restart.
    After: SQLite-backed rate_limits table persists across init_db() calls.

    The "restart" here re-runs init_db() in the same process against the same
    database file, so it checks that the counter lives on disk rather than in
    a process-wide cache that init_db() would reset.
    """
    path = _tmpdb()
    try:
        await db.init_db()
        c1 = await db.rate_limit_increment("persist:test", 600)
        c2 = await db.rate_limit_increment("persist:test", 600)
        c3 = await db.rate_limit_increment("persist:test", 600)
        # Pin the whole pre-restart sequence, not just the last value.
        assert (c1, c2, c3) == (1, 2, 3), (
            f"Expected (1,2,3) before reinit, got ({c1},{c2},{c3})"
        )
        # Simulate restart: re-initialize DB against the same file
        await db.init_db()
        # Counter must continue from 3, not reset to 0
        c4 = await db.rate_limit_increment("persist:test", 600)
        assert c4 == 4, (
            f"Expected 4 after reinit + 1 more increment (counter must persist), got {c4}"
        )
    finally:
        _cleanup(path)
@pytest.mark.asyncio
async def test_rate_limit_increment_returns_sequential_counts():
    """Three increments of one key inside its window yield 1, 2 and 3."""
    db_file = _tmpdb()
    try:
        await db.init_db()
        observed = []
        for _ in range(3):
            observed.append(await db.rate_limit_increment("seq:test", 600))
        c1, c2, c3 = observed
        assert (c1, c2, c3) == (1, 2, 3), f"Expected (1,2,3), got ({c1},{c2},{c3})"
    finally:
        _cleanup(db_file)
# ── Criterion 2: TTL cleanup resets stale entries ────────────────────────────
@pytest.mark.asyncio
async def test_rate_limit_ttl_resets_counter_after_window_expires():
    """Counter resets to 1 when the time window has expired (TTL cleanup).

    ``backend.db.time`` is patched, so the test controls ``time.time()``
    directly — no real sleep required.
    """
    path = _tmpdb()
    try:
        await db.init_db()
        with mock.patch("backend.db.time") as mock_time:
            mock_time.time.return_value = 1000.0  # window_start = t0
            c1 = await db.rate_limit_increment("ttl:test", 10)
            c2 = await db.rate_limit_increment("ttl:test", 10)
            c3 = await db.rate_limit_increment("ttl:test", 10)
            # Pin the whole pre-expiry sequence, not just the last value.
            assert (c1, c2, c3) == (1, 2, 3), (
                f"Expected (1,2,3) before window expired, got ({c1},{c2},{c3})"
            )
            # Jump 11 seconds ahead (window = 10s → expired)
            mock_time.time.return_value = 1011.0
            c4 = await db.rate_limit_increment("ttl:test", 10)
            assert c4 == 1, (
                f"Expected counter reset to 1 after window expired, got {c4}"
            )
    finally:
        _cleanup(path)
@pytest.mark.asyncio
async def test_rate_limit_ttl_does_not_reset_within_window():
    """An increment made before the window ends keeps counting instead of resetting."""
    db_file = _tmpdb()
    key, window = "ttl:within", 10
    try:
        await db.init_db()
        with mock.patch("backend.db.time") as fake_time:
            fake_time.time.return_value = 1000.0
            for _ in range(2):
                await db.rate_limit_increment(key, window)
            third = await db.rate_limit_increment(key, window)
            assert third == 3
            # 5 s later the 10 s window is still open, so counting continues.
            fake_time.time.return_value = 1005.0
            fourth = await db.rate_limit_increment(key, window)
            assert fourth == 4, (
                f"Expected 4 (counter continues inside window), got {fourth}"
            )
    finally:
        _cleanup(db_file)
@pytest.mark.asyncio
async def test_rate_limit_ttl_boundary_exactly_at_window_end():
    """Elapsed time equal to the window length already counts as expired."""
    db_file = _tmpdb()
    try:
        await db.init_db()
        with mock.patch("backend.db.time") as fake_time:
            start = 1000.0
            fake_time.time.return_value = start
            for _ in range(2):
                await db.rate_limit_increment("ttl:boundary", 10)
            # Advance by exactly one window: elapsed == window must be stale.
            fake_time.time.return_value = start + 10
            count = await db.rate_limit_increment("ttl:boundary", 10)
            assert count == 1, (
                f"Expected reset at exact window boundary (elapsed == window), got {count}"
            )
    finally:
        _cleanup(db_file)
# ── Criterion 3: HTTP 429 when rate limit exceeded ────────────────────────────
@pytest.mark.asyncio
async def test_register_returns_429_after_rate_limit_exceeded():
    """The 6th POST /api/register from one IP must be rejected with HTTP 429.

    Register limit = 5 requests per 600s window.
    """
    body = {"uuid": _UUID_REG_RL, "name": "RateLimitTest"}
    headers = {"X-Real-IP": "192.0.2.10"}
    async with make_app_client() as client:
        # Sequential on purpose: each await finishes before the next request.
        statuses = [
            (await client.post("/api/register", json=body, headers=headers)).status_code
            for _ in range(6)
        ]
        assert statuses[-1] == 429, (
            f"Expected 429 on 6th register request, got statuses: {statuses}"
        )
@pytest.mark.asyncio
async def test_register_first_5_requests_are_allowed():
    """Each of the first 5 POST /api/register calls from one IP must return 200."""
    request_kwargs = {
        "json": {"uuid": _UUID_REG_RL, "name": "RateLimitTest"},
        "headers": {"X-Real-IP": "192.0.2.11"},
    }
    async with make_app_client() as client:
        statuses = []
        for _ in range(5):
            response = await client.post("/api/register", **request_kwargs)
            statuses.append(response.status_code)
        # The list always holds 5 entries, so "only 200 seen" == "all are 200".
        assert set(statuses) == {200}, (
            f"Expected all 5 register requests to return 200, got: {statuses}"
        )
# ── Criterion 4: X-Real-IP and X-Forwarded-For for rate counting ──────────────
@pytest.mark.asyncio
async def test_x_real_ip_header_is_used_for_rate_counting():
    """Rate counter keys are derived from X-Real-IP.

    Eleven signals sharing the same X-Real-IP share one counter, and the 11th
    is rejected with 429. Each request carries the API key returned by
    /api/register. Without it, authentication could reject the request before
    the rate limiter runs, and the 429 assertion would then depend on the
    order of auth vs. limiting.
    """
    async with make_app_client() as client:
        reg = await client.post(
            "/api/register", json={"uuid": _UUID_XREALIP_A, "name": "RealIPUser"}
        )
        assert reg.status_code == 200, f"Registration failed: {reg.status_code}"
        api_key = reg.json()["api_key"]
        ip_hdrs = {
            "X-Real-IP": "203.0.113.10",
            "Authorization": f"Bearer {api_key}",
        }
        payload = {"user_id": _UUID_XREALIP_A, "timestamp": 1742478000000}
        statuses = []
        for _ in range(11):
            r = await client.post("/api/signal", json=payload, headers=ip_hdrs)
            statuses.append(r.status_code)
        assert statuses[-1] == 429, (
            f"Expected 429 on 11th signal with same X-Real-IP, got: {statuses}"
        )
@pytest.mark.asyncio
async def test_x_forwarded_for_header_is_used_for_rate_counting():
    """Rate counter keys are derived from X-Forwarded-For (first IP) when
    X-Real-IP is absent: requests sharing the same forwarded IP hit the limit.

    Each request carries the API key returned by /api/register. Without it,
    authentication could reject the request before the rate limiter runs.
    """
    async with make_app_client() as client:
        reg = await client.post(
            "/api/register", json={"uuid": _UUID_XFWD, "name": "FwdUser"}
        )
        assert reg.status_code == 200, f"Registration failed: {reg.status_code}"
        api_key = reg.json()["api_key"]
        # Chain: first IP is the original client (only that one is used)
        fwd_hdrs = {
            "X-Forwarded-For": "198.51.100.5, 10.0.0.1, 172.16.0.1",
            "Authorization": f"Bearer {api_key}",
        }
        payload = {"user_id": _UUID_XFWD, "timestamp": 1742478000000}
        statuses = []
        for _ in range(11):
            r = await client.post("/api/signal", json=payload, headers=fwd_hdrs)
            statuses.append(r.status_code)
        assert statuses[-1] == 429, (
            f"Expected 429 on 11th request with same X-Forwarded-For first IP, got: {statuses}"
        )
@pytest.mark.asyncio
async def test_different_x_real_ip_values_have_independent_counters():
    """Exhausting the rate limit for IP-A must not block IP-B.

    Verifies that rate-limit keys are truly per-IP. IP-A's last request is
    first checked for 429. Otherwise a limiter that never fires at all would
    make the IP-B assertion pass without proving anything.
    """
    async with make_app_client() as client:
        r_a = await client.post(
            "/api/register", json={"uuid": _UUID_XREALIP_A, "name": "IPA"}
        )
        r_b = await client.post(
            "/api/register", json={"uuid": _UUID_XREALIP_B, "name": "IPB"}
        )
        assert (r_a.status_code, r_b.status_code) == (200, 200), (
            f"Registration failed: IP-A {r_a.status_code}, IP-B {r_b.status_code}"
        )
        api_key_a = r_a.json()["api_key"]
        api_key_b = r_b.json()["api_key"]
        # Exhaust limit for IP-A (with valid auth so requests reach the rate limiter)
        statuses_a = []
        for _ in range(11):
            resp = await client.post(
                "/api/signal",
                json={"user_id": _UUID_XREALIP_A, "timestamp": 1742478000000},
                headers={
                    "X-Real-IP": "198.51.100.100",
                    "Authorization": f"Bearer {api_key_a}",
                },
            )
            statuses_a.append(resp.status_code)
        assert statuses_a[-1] == 429, (
            f"Expected IP-A to be rate-limited on its 11th signal, got: {statuses_a}"
        )
        # IP-B has its own independent counter — must not be blocked
        r = await client.post(
            "/api/signal",
            json={"user_id": _UUID_XREALIP_B, "timestamp": 1742478000000},
            headers={
                "X-Real-IP": "198.51.100.200",
                "Authorization": f"Bearer {api_key_b}",
            },
        )
        assert r.status_code == 200, (
            f"IP-B was incorrectly blocked after IP-A exhausted its counter: {r.status_code}"
        )