""" Tests for BATON-SEC-006: Персистентное хранение rate-limit счётчиков. Acceptance criteria: 1. Счётчики сохраняются между пересозданием экземпляра приложения (симуляция рестарта). 2. TTL-очистка корректно сбрасывает устаревшие записи после истечения окна. 3. Превышение лимита возвращает HTTP 429. 4. X-Real-IP и X-Forwarded-For корректно парсятся для подсчёта. UUID note: All UUIDs below satisfy the v4 pattern validated since BATON-SEC-005. """ from __future__ import annotations import os os.environ.setdefault("BOT_TOKEN", "test-bot-token") os.environ.setdefault("CHAT_ID", "-1001234567890") os.environ.setdefault("WEBHOOK_SECRET", "test-webhook-secret") os.environ.setdefault("WEBHOOK_URL", "https://example.com/api/webhook/telegram") os.environ.setdefault("FRONTEND_ORIGIN", "http://localhost:3000") os.environ.setdefault("ADMIN_TOKEN", "test-admin-token") import tempfile import unittest.mock as mock import aiosqlite import pytest from backend import config, db from tests.conftest import make_app_client # ── Valid UUID v4 constants ────────────────────────────────────────────────── _UUID_XREALIP_A = "c0000001-0000-4000-8000-000000000001" # X-Real-IP exhaustion _UUID_XREALIP_B = "c0000002-0000-4000-8000-000000000002" # IP-B (independent counter) _UUID_XFWD = "c0000003-0000-4000-8000-000000000003" # X-Forwarded-For test _UUID_REG_RL = "c0000004-0000-4000-8000-000000000004" # register 429 test # ── Helpers ────────────────────────────────────────────────────────────────── def _tmpdb() -> str: """Set config.DB_PATH to a fresh temp file and return the path.""" path = tempfile.mktemp(suffix=".db") config.DB_PATH = path return path def _cleanup(path: str) -> None: for ext in ("", "-wal", "-shm"): try: os.unlink(path + ext) except FileNotFoundError: pass # ── Criterion 1: Persistence across restart ─────────────────────────────────── @pytest.mark.asyncio async def test_rate_limits_table_created_by_init_db(): """init_db() creates the rate_limits table in SQLite.""" path = _tmpdb() try: await db.init_db() async with aiosqlite.connect(path) as conn: async with conn.execute( "SELECT name FROM sqlite_master WHERE type='table' AND name='rate_limits'" ) as cur: row = await cur.fetchone() assert row is not None, "rate_limits table not found after init_db()" finally: _cleanup(path) @pytest.mark.asyncio async def test_rate_limit_counter_persists_after_db_reinit(): """Counter survives re-initialization of the DB (simulates app restart). Before: in-memory app.state.rate_counters was lost on restart. After: SQLite-backed rate_limits table persists across init_db() calls. """ path = _tmpdb() try: await db.init_db() c1 = await db.rate_limit_increment("persist:test", 600) c2 = await db.rate_limit_increment("persist:test", 600) c3 = await db.rate_limit_increment("persist:test", 600) assert c3 == 3, f"Expected 3 after 3 increments, got {c3}" # Simulate restart: re-initialize DB against the same file await db.init_db() # Counter must continue from 3, not reset to 0 c4 = await db.rate_limit_increment("persist:test", 600) assert c4 == 4, ( f"Expected 4 after reinit + 1 more increment (counter must persist), got {c4}" ) finally: _cleanup(path) @pytest.mark.asyncio async def test_rate_limit_increment_returns_sequential_counts(): """rate_limit_increment returns 1, 2, 3 on successive calls within window.""" path = _tmpdb() try: await db.init_db() c1 = await db.rate_limit_increment("seq:test", 600) c2 = await db.rate_limit_increment("seq:test", 600) c3 = await db.rate_limit_increment("seq:test", 600) assert (c1, c2, c3) == (1, 2, 3), f"Expected (1,2,3), got ({c1},{c2},{c3})" finally: _cleanup(path) # ── Criterion 2: TTL cleanup resets stale entries ──────────────────────────── @pytest.mark.asyncio async def test_rate_limit_ttl_resets_counter_after_window_expires(): """Counter resets to 1 when the time window has expired (TTL cleanup). time.time() is mocked — no real sleep required. """ path = _tmpdb() try: await db.init_db() with mock.patch("backend.db.time") as mock_time: mock_time.time.return_value = 1000.0 # window_start = t0 c1 = await db.rate_limit_increment("ttl:test", 10) c2 = await db.rate_limit_increment("ttl:test", 10) c3 = await db.rate_limit_increment("ttl:test", 10) assert c3 == 3 # Jump 11 seconds ahead (window = 10s → expired) mock_time.time.return_value = 1011.0 c4 = await db.rate_limit_increment("ttl:test", 10) assert c4 == 1, ( f"Expected counter reset to 1 after window expired, got {c4}" ) finally: _cleanup(path) @pytest.mark.asyncio async def test_rate_limit_ttl_does_not_reset_within_window(): """Counter is NOT reset when the window has NOT expired yet.""" path = _tmpdb() try: await db.init_db() with mock.patch("backend.db.time") as mock_time: mock_time.time.return_value = 1000.0 await db.rate_limit_increment("ttl:within", 10) await db.rate_limit_increment("ttl:within", 10) c3 = await db.rate_limit_increment("ttl:within", 10) assert c3 == 3 # Only 5 seconds passed (window = 10s, still active) mock_time.time.return_value = 1005.0 c4 = await db.rate_limit_increment("ttl:within", 10) assert c4 == 4, ( f"Expected 4 (counter continues inside window), got {c4}" ) finally: _cleanup(path) @pytest.mark.asyncio async def test_rate_limit_ttl_boundary_exactly_at_window_end(): """Counter resets when elapsed time equals exactly the window duration.""" path = _tmpdb() try: await db.init_db() with mock.patch("backend.db.time") as mock_time: mock_time.time.return_value = 1000.0 await db.rate_limit_increment("ttl:boundary", 10) await db.rate_limit_increment("ttl:boundary", 10) # Exactly at window boundary (elapsed == window → stale) mock_time.time.return_value = 1010.0 c = await db.rate_limit_increment("ttl:boundary", 10) assert c == 1, ( f"Expected reset at exact window boundary (elapsed == window), got {c}" ) finally: _cleanup(path) # ── Criterion 3: HTTP 429 when rate limit exceeded ──────────────────────────── @pytest.mark.asyncio async def test_register_returns_429_after_rate_limit_exceeded(): """POST /api/register returns 429 on the 6th request from the same IP. Register limit = 5 requests per 600s window. """ async with make_app_client() as client: ip_hdrs = {"X-Real-IP": "192.0.2.10"} statuses = [] for _ in range(6): r = await client.post( "/api/register", json={"uuid": _UUID_REG_RL, "name": "RateLimitTest"}, headers=ip_hdrs, ) statuses.append(r.status_code) assert statuses[-1] == 429, ( f"Expected 429 on 6th register request, got statuses: {statuses}" ) @pytest.mark.asyncio async def test_register_first_5_requests_are_allowed(): """First 5 POST /api/register requests from the same IP must all return 200.""" async with make_app_client() as client: ip_hdrs = {"X-Real-IP": "192.0.2.11"} statuses = [] for _ in range(5): r = await client.post( "/api/register", json={"uuid": _UUID_REG_RL, "name": "RateLimitTest"}, headers=ip_hdrs, ) statuses.append(r.status_code) assert all(s == 200 for s in statuses), ( f"Expected all 5 register requests to return 200, got: {statuses}" ) # ── Criterion 4: X-Real-IP and X-Forwarded-For for rate counting ────────────── @pytest.mark.asyncio async def test_x_real_ip_header_is_used_for_rate_counting(): """Rate counter keys are derived from X-Real-IP: two requests sharing the same X-Real-IP share the same counter and collectively hit the 429 limit. """ async with make_app_client() as client: await client.post( "/api/register", json={"uuid": _UUID_XREALIP_A, "name": "RealIPUser"} ) ip_hdrs = {"X-Real-IP": "203.0.113.10"} payload = {"user_id": _UUID_XREALIP_A, "timestamp": 1742478000000} statuses = [] for _ in range(11): r = await client.post("/api/signal", json=payload, headers=ip_hdrs) statuses.append(r.status_code) assert statuses[-1] == 429, ( f"Expected 429 on 11th signal with same X-Real-IP, got: {statuses}" ) @pytest.mark.asyncio async def test_x_forwarded_for_header_is_used_for_rate_counting(): """Rate counter keys are derived from X-Forwarded-For (first IP) when X-Real-IP is absent: requests sharing the same forwarded IP hit the limit. """ async with make_app_client() as client: await client.post( "/api/register", json={"uuid": _UUID_XFWD, "name": "FwdUser"} ) # Chain: first IP is the original client (only that one is used) fwd_hdrs = {"X-Forwarded-For": "198.51.100.5, 10.0.0.1, 172.16.0.1"} payload = {"user_id": _UUID_XFWD, "timestamp": 1742478000000} statuses = [] for _ in range(11): r = await client.post("/api/signal", json=payload, headers=fwd_hdrs) statuses.append(r.status_code) assert statuses[-1] == 429, ( f"Expected 429 on 11th request with same X-Forwarded-For first IP, got: {statuses}" ) @pytest.mark.asyncio async def test_different_x_real_ip_values_have_independent_counters(): """Exhausting the rate limit for IP-A must not block IP-B. Verifies that rate-limit keys are truly per-IP. """ async with make_app_client() as client: r_a = await client.post( "/api/register", json={"uuid": _UUID_XREALIP_A, "name": "IPA"} ) r_b = await client.post( "/api/register", json={"uuid": _UUID_XREALIP_B, "name": "IPB"} ) api_key_a = r_a.json()["api_key"] api_key_b = r_b.json()["api_key"] # Exhaust limit for IP-A (with valid auth so requests reach the rate limiter) for _ in range(11): await client.post( "/api/signal", json={"user_id": _UUID_XREALIP_A, "timestamp": 1742478000000}, headers={ "X-Real-IP": "198.51.100.100", "Authorization": f"Bearer {api_key_a}", }, ) # IP-B has its own independent counter — must not be blocked r = await client.post( "/api/signal", json={"user_id": _UUID_XREALIP_B, "timestamp": 1742478000000}, headers={ "X-Real-IP": "198.51.100.200", "Authorization": f"Bearer {api_key_b}", }, ) assert r.status_code == 200, ( f"IP-B was incorrectly blocked after IP-A exhausted its counter: {r.status_code}" )