diff --git a/tests/test_sec_006.py b/tests/test_sec_006.py new file mode 100644 index 0000000..8f4221d --- /dev/null +++ b/tests/test_sec_006.py @@ -0,0 +1,329 @@ +""" +Tests for BATON-SEC-006: Персистентное хранение rate-limit счётчиков. + +Acceptance criteria: +1. Счётчики сохраняются между пересозданием экземпляра приложения (симуляция рестарта). +2. TTL-очистка корректно сбрасывает устаревшие записи после истечения окна. +3. Превышение лимита возвращает HTTP 429. +4. X-Real-IP и X-Forwarded-For корректно парсятся для подсчёта. + +UUID note: All UUIDs below satisfy the v4 pattern validated since BATON-SEC-005. +""" +from __future__ import annotations + +import os + +os.environ.setdefault("BOT_TOKEN", "test-bot-token") +os.environ.setdefault("CHAT_ID", "-1001234567890") +os.environ.setdefault("WEBHOOK_SECRET", "test-webhook-secret") +os.environ.setdefault("WEBHOOK_URL", "https://example.com/api/webhook/telegram") +os.environ.setdefault("FRONTEND_ORIGIN", "http://localhost:3000") +os.environ.setdefault("ADMIN_TOKEN", "test-admin-token") + +import tempfile +import unittest.mock as mock + +import aiosqlite +import pytest + +from backend import config, db +from tests.conftest import make_app_client + +# ── Valid UUID v4 constants ────────────────────────────────────────────────── + +_UUID_XREALIP_A = "c0000001-0000-4000-8000-000000000001" # X-Real-IP exhaustion +_UUID_XREALIP_B = "c0000002-0000-4000-8000-000000000002" # IP-B (independent counter) +_UUID_XFWD = "c0000003-0000-4000-8000-000000000003" # X-Forwarded-For test +_UUID_REG_RL = "c0000004-0000-4000-8000-000000000004" # register 429 test + + +# ── Helpers ────────────────────────────────────────────────────────────────── + + +def _tmpdb() -> str: + """Set config.DB_PATH to a fresh temp file and return the path.""" + path = tempfile.mktemp(suffix=".db") + config.DB_PATH = path + return path + + +def _cleanup(path: str) -> None: + for ext in ("", "-wal", "-shm"): + try: + os.unlink(path + ext) + except FileNotFoundError: + pass + + +# ── Criterion 1: Persistence across restart ─────────────────────────────────── + + +@pytest.mark.asyncio +async def test_rate_limits_table_created_by_init_db(): + """init_db() creates the rate_limits table in SQLite.""" + path = _tmpdb() + try: + await db.init_db() + async with aiosqlite.connect(path) as conn: + async with conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='rate_limits'" + ) as cur: + row = await cur.fetchone() + assert row is not None, "rate_limits table not found after init_db()" + finally: + _cleanup(path) + + +@pytest.mark.asyncio +async def test_rate_limit_counter_persists_after_db_reinit(): + """Counter survives re-initialization of the DB (simulates app restart). + + Before: in-memory app.state.rate_counters was lost on restart. + After: SQLite-backed rate_limits table persists across init_db() calls. + """ + path = _tmpdb() + try: + await db.init_db() + + c1 = await db.rate_limit_increment("persist:test", 600) + c2 = await db.rate_limit_increment("persist:test", 600) + c3 = await db.rate_limit_increment("persist:test", 600) + assert c3 == 3, f"Expected 3 after 3 increments, got {c3}" + + # Simulate restart: re-initialize DB against the same file + await db.init_db() + + # Counter must continue from 3, not reset to 0 + c4 = await db.rate_limit_increment("persist:test", 600) + assert c4 == 4, ( + f"Expected 4 after reinit + 1 more increment (counter must persist), got {c4}" + ) + finally: + _cleanup(path) + + +@pytest.mark.asyncio +async def test_rate_limit_increment_returns_sequential_counts(): + """rate_limit_increment returns 1, 2, 3 on successive calls within window.""" + path = _tmpdb() + try: + await db.init_db() + c1 = await db.rate_limit_increment("seq:test", 600) + c2 = await db.rate_limit_increment("seq:test", 600) + c3 = await db.rate_limit_increment("seq:test", 600) + assert (c1, c2, c3) == (1, 2, 3), f"Expected (1,2,3), got ({c1},{c2},{c3})" + finally: + _cleanup(path) + + +# ── Criterion 2: TTL cleanup resets stale entries ──────────────────────────── + + +@pytest.mark.asyncio +async def test_rate_limit_ttl_resets_counter_after_window_expires(): + """Counter resets to 1 when the time window has expired (TTL cleanup). + + time.time() is mocked — no real sleep required. + """ + path = _tmpdb() + try: + await db.init_db() + + with mock.patch("backend.db.time") as mock_time: + mock_time.time.return_value = 1000.0 # window_start = t0 + + c1 = await db.rate_limit_increment("ttl:test", 10) + c2 = await db.rate_limit_increment("ttl:test", 10) + c3 = await db.rate_limit_increment("ttl:test", 10) + assert c3 == 3 + + # Jump 11 seconds ahead (window = 10s → expired) + mock_time.time.return_value = 1011.0 + + c4 = await db.rate_limit_increment("ttl:test", 10) + + assert c4 == 1, ( + f"Expected counter reset to 1 after window expired, got {c4}" + ) + finally: + _cleanup(path) + + +@pytest.mark.asyncio +async def test_rate_limit_ttl_does_not_reset_within_window(): + """Counter is NOT reset when the window has NOT expired yet.""" + path = _tmpdb() + try: + await db.init_db() + + with mock.patch("backend.db.time") as mock_time: + mock_time.time.return_value = 1000.0 + + await db.rate_limit_increment("ttl:within", 10) + await db.rate_limit_increment("ttl:within", 10) + c3 = await db.rate_limit_increment("ttl:within", 10) + assert c3 == 3 + + # Only 5 seconds passed (window = 10s, still active) + mock_time.time.return_value = 1005.0 + + c4 = await db.rate_limit_increment("ttl:within", 10) + + assert c4 == 4, ( + f"Expected 4 (counter continues inside window), got {c4}" + ) + finally: + _cleanup(path) + + +@pytest.mark.asyncio +async def test_rate_limit_ttl_boundary_exactly_at_window_end(): + """Counter resets when elapsed time equals exactly the window duration.""" + path = _tmpdb() + try: + await db.init_db() + + with mock.patch("backend.db.time") as mock_time: + mock_time.time.return_value = 1000.0 + + await db.rate_limit_increment("ttl:boundary", 10) + await db.rate_limit_increment("ttl:boundary", 10) + + # Exactly at window boundary (elapsed == window → stale) + mock_time.time.return_value = 1010.0 + + c = await db.rate_limit_increment("ttl:boundary", 10) + + assert c == 1, ( + f"Expected reset at exact window boundary (elapsed == window), got {c}" + ) + finally: + _cleanup(path) + + +# ── Criterion 3: HTTP 429 when rate limit exceeded ──────────────────────────── + + +@pytest.mark.asyncio +async def test_register_returns_429_after_rate_limit_exceeded(): + """POST /api/register returns 429 on the 6th request from the same IP. + + Register limit = 5 requests per 600s window. + """ + async with make_app_client() as client: + ip_hdrs = {"X-Real-IP": "192.0.2.10"} + statuses = [] + for _ in range(6): + r = await client.post( + "/api/register", + json={"uuid": _UUID_REG_RL, "name": "RateLimitTest"}, + headers=ip_hdrs, + ) + statuses.append(r.status_code) + + assert statuses[-1] == 429, ( + f"Expected 429 on 6th register request, got statuses: {statuses}" + ) + + +@pytest.mark.asyncio +async def test_register_first_5_requests_are_allowed(): + """First 5 POST /api/register requests from the same IP must all return 200.""" + async with make_app_client() as client: + ip_hdrs = {"X-Real-IP": "192.0.2.11"} + statuses = [] + for _ in range(5): + r = await client.post( + "/api/register", + json={"uuid": _UUID_REG_RL, "name": "RateLimitTest"}, + headers=ip_hdrs, + ) + statuses.append(r.status_code) + + assert all(s == 200 for s in statuses), ( + f"Expected all 5 register requests to return 200, got: {statuses}" + ) + + +# ── Criterion 4: X-Real-IP and X-Forwarded-For for rate counting ────────────── + + +@pytest.mark.asyncio +async def test_x_real_ip_header_is_used_for_rate_counting(): + """Rate counter keys are derived from X-Real-IP: two requests sharing + the same X-Real-IP share the same counter and collectively hit the 429 limit. + """ + async with make_app_client() as client: + await client.post( + "/api/register", json={"uuid": _UUID_XREALIP_A, "name": "RealIPUser"} + ) + + ip_hdrs = {"X-Real-IP": "203.0.113.10"} + payload = {"user_id": _UUID_XREALIP_A, "timestamp": 1742478000000} + + statuses = [] + for _ in range(11): + r = await client.post("/api/signal", json=payload, headers=ip_hdrs) + statuses.append(r.status_code) + + assert statuses[-1] == 429, ( + f"Expected 429 on 11th signal with same X-Real-IP, got: {statuses}" + ) + + +@pytest.mark.asyncio +async def test_x_forwarded_for_header_is_used_for_rate_counting(): + """Rate counter keys are derived from X-Forwarded-For (first IP) when + X-Real-IP is absent: requests sharing the same forwarded IP hit the limit. + """ + async with make_app_client() as client: + await client.post( + "/api/register", json={"uuid": _UUID_XFWD, "name": "FwdUser"} + ) + + # Chain: first IP is the original client (only that one is used) + fwd_hdrs = {"X-Forwarded-For": "198.51.100.5, 10.0.0.1, 172.16.0.1"} + payload = {"user_id": _UUID_XFWD, "timestamp": 1742478000000} + + statuses = [] + for _ in range(11): + r = await client.post("/api/signal", json=payload, headers=fwd_hdrs) + statuses.append(r.status_code) + + assert statuses[-1] == 429, ( + f"Expected 429 on 11th request with same X-Forwarded-For first IP, got: {statuses}" + ) + + +@pytest.mark.asyncio +async def test_different_x_real_ip_values_have_independent_counters(): + """Exhausting the rate limit for IP-A must not block IP-B. + + Verifies that rate-limit keys are truly per-IP. + """ + async with make_app_client() as client: + await client.post( + "/api/register", json={"uuid": _UUID_XREALIP_A, "name": "IPA"} + ) + await client.post( + "/api/register", json={"uuid": _UUID_XREALIP_B, "name": "IPB"} + ) + + # Exhaust limit for IP-A + for _ in range(11): + await client.post( + "/api/signal", + json={"user_id": _UUID_XREALIP_A, "timestamp": 1742478000000}, + headers={"X-Real-IP": "198.51.100.100"}, + ) + + # IP-B has its own independent counter — must not be blocked + r = await client.post( + "/api/signal", + json={"user_id": _UUID_XREALIP_B, "timestamp": 1742478000000}, + headers={"X-Real-IP": "198.51.100.200"}, + ) + + assert r.status_code == 200, ( + f"IP-B was incorrectly blocked after IP-A exhausted its counter: {r.status_code}" + )