""" Tests for backend/db.py. Uses a temporary file-based SQLite DB so all connections opened by _get_conn() share the same database file (in-memory DBs are isolated per-connection and cannot be shared across calls). """ from __future__ import annotations import os os.environ.setdefault("BOT_TOKEN", "test-bot-token") os.environ.setdefault("CHAT_ID", "-1001234567890") os.environ.setdefault("WEBHOOK_SECRET", "test-webhook-secret") os.environ.setdefault("WEBHOOK_URL", "https://example.com/api/webhook/telegram") os.environ.setdefault("FRONTEND_ORIGIN", "http://localhost:3000") import aiosqlite def _safe_aiosqlite_await(self): if not self._thread._started.is_set(): self._thread.start() return self._connect().__await__() aiosqlite.core.Connection.__await__ = _safe_aiosqlite_await # type: ignore[method-assign] import tempfile import pytest from backend import config, db def _tmpdb(): """Return a fresh temp-file path and set config.DB_PATH.""" path = tempfile.mktemp(suffix=".db") config.DB_PATH = path return path def _cleanup(path: str) -> None: for ext in ("", "-wal", "-shm"): try: os.unlink(path + ext) except FileNotFoundError: pass # --------------------------------------------------------------------------- # init_db — schema / pragma verification # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_init_db_creates_tables(): """init_db creates users, signals and telegram_batches tables.""" path = _tmpdb() try: await db.init_db() # Verify by querying sqlite_master async with aiosqlite.connect(path) as conn: async with conn.execute( "SELECT name FROM sqlite_master WHERE type='table'" ) as cur: rows = await cur.fetchall() table_names = {r[0] for r in rows} assert "users" in table_names assert "signals" in table_names assert "telegram_batches" in table_names finally: _cleanup(path) @pytest.mark.asyncio async def test_init_db_wal_mode(): """PRAGMA journal_mode = wal after init_db.""" path = _tmpdb() try: await db.init_db() async with aiosqlite.connect(path) as conn: async with conn.execute("PRAGMA journal_mode") as cur: row = await cur.fetchone() assert row[0] == "wal" finally: _cleanup(path) @pytest.mark.asyncio async def test_init_db_busy_timeout(): """PRAGMA busy_timeout = 5000 after init_db.""" path = _tmpdb() try: await db.init_db() async with aiosqlite.connect(path) as conn: async with conn.execute("PRAGMA busy_timeout") as cur: row = await cur.fetchone() assert row[0] == 5000 finally: _cleanup(path) @pytest.mark.asyncio async def test_init_db_synchronous(): """PRAGMA synchronous = 1 (NORMAL) on each connection opened by _get_conn(). The PRAGMA is per-connection (not file-level), so we must verify it via a connection created by _get_conn() rather than a raw aiosqlite.connect(). """ path = _tmpdb() try: await db.init_db() # Check synchronous on a new connection via _get_conn() from backend.db import _get_conn async with _get_conn() as conn: async with conn.execute("PRAGMA synchronous") as cur: row = await cur.fetchone() # 1 == NORMAL assert row[0] == 1 finally: _cleanup(path) # --------------------------------------------------------------------------- # register_user # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_register_user_returns_id(): """register_user returns a dict with a positive integer user_id.""" path = _tmpdb() try: await db.init_db() result = await db.register_user(uuid="uuid-001", name="Alice") assert isinstance(result["user_id"], int) assert result["user_id"] > 0 assert result["uuid"] == "uuid-001" finally: _cleanup(path) @pytest.mark.asyncio async def test_register_user_idempotent(): """Calling register_user twice with the same uuid returns the same id.""" path = _tmpdb() try: await db.init_db() r1 = await db.register_user(uuid="uuid-002", name="Bob") r2 = await db.register_user(uuid="uuid-002", name="Bob") assert r1["user_id"] == r2["user_id"] finally: _cleanup(path) # --------------------------------------------------------------------------- # get_user_name # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_get_user_name_returns_name(): """get_user_name returns the correct name for a registered user.""" path = _tmpdb() try: await db.init_db() await db.register_user(uuid="uuid-003", name="Charlie") name = await db.get_user_name("uuid-003") assert name == "Charlie" finally: _cleanup(path) @pytest.mark.asyncio async def test_get_user_name_unknown_returns_none(): """get_user_name returns None for an unregistered uuid.""" path = _tmpdb() try: await db.init_db() name = await db.get_user_name("nonexistent-uuid") assert name is None finally: _cleanup(path) # --------------------------------------------------------------------------- # save_signal # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_save_signal_returns_id(): """save_signal returns a valid positive integer signal id.""" path = _tmpdb() try: await db.init_db() await db.register_user(uuid="uuid-004", name="Dana") signal_id = await db.save_signal( user_uuid="uuid-004", timestamp=1742478000000, lat=55.7558, lon=37.6173, accuracy=15.0, ) assert isinstance(signal_id, int) assert signal_id > 0 finally: _cleanup(path) @pytest.mark.asyncio async def test_save_signal_without_geo(): """save_signal with geo=None stores NULL lat/lon/accuracy.""" path = _tmpdb() try: await db.init_db() await db.register_user(uuid="uuid-005", name="Eve") signal_id = await db.save_signal( user_uuid="uuid-005", timestamp=1742478000000, lat=None, lon=None, accuracy=None, ) assert isinstance(signal_id, int) assert signal_id > 0 # Verify nulls in DB async with aiosqlite.connect(path) as conn: conn.row_factory = aiosqlite.Row async with conn.execute( "SELECT lat, lon, accuracy FROM signals WHERE id = ?", (signal_id,) ) as cur: row = await cur.fetchone() assert row["lat"] is None assert row["lon"] is None assert row["accuracy"] is None finally: _cleanup(path) @pytest.mark.asyncio async def test_save_signal_increments_id(): """Each call to save_signal returns a higher id.""" path = _tmpdb() try: await db.init_db() await db.register_user(uuid="uuid-006", name="Frank") id1 = await db.save_signal("uuid-006", 1742478000001, None, None, None) id2 = await db.save_signal("uuid-006", 1742478000002, None, None, None) assert id2 > id1 finally: _cleanup(path)