249 lines
7.5 KiB
Python
249 lines
7.5 KiB
Python
|
|
"""
|
||
|
|
Tests for backend/db.py.
|
||
|
|
|
||
|
|
Uses a temporary file-based SQLite DB so all connections opened by
|
||
|
|
_get_conn() share the same database file (in-memory DBs are isolated
|
||
|
|
per-connection and cannot be shared across calls).
|
||
|
|
"""
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import os
|
||
|
|
|
||
|
|
# Test values for required settings, installed before `backend` is imported
# further down (presumably backend.config reads them at import time — confirm).
# setdefault keeps any value already exported in the real environment.
_TEST_ENV_DEFAULTS = {
    "BOT_TOKEN": "test-bot-token",
    "CHAT_ID": "-1001234567890",
    "WEBHOOK_SECRET": "test-webhook-secret",
    "WEBHOOK_URL": "https://example.com/api/webhook/telegram",
    "FRONTEND_ORIGIN": "http://localhost:3000",
}
for _env_key, _env_value in _TEST_ENV_DEFAULTS.items():
    os.environ.setdefault(_env_key, _env_value)
|
||
|
|
|
||
|
|
import aiosqlite
|
||
|
|
|
||
|
|
def _safe_aiosqlite_await(self):
    """Replacement for ``aiosqlite.core.Connection.__await__`` that tolerates re-awaiting.

    Starts the connection's worker thread only if it has not been started yet,
    then delegates to the connection's ``_connect()`` coroutine. Presumably this
    exists so that awaiting the same Connection more than once does not call
    ``Thread.start()`` twice (which raises ``RuntimeError``) — TODO confirm
    against the installed aiosqlite version.

    Args:
        self: The aiosqlite ``Connection`` being awaited. This function is
            assigned onto the class, so it receives the instance.

    Returns:
        The iterator produced by ``self._connect().__await__()``.
    """
    # Thread.ident is the documented "has this thread been started?" signal:
    # it is None until start() has run. The previous check read the
    # CPython-private Thread._started Event, which is not a stable API.
    if self._thread.ident is None:
        self._thread.start()
    return self._connect().__await__()
|
||
|
|
|
||
|
|
# Install _safe_aiosqlite_await as aiosqlite's Connection.__await__. This
# patches the class itself at import time, so it affects every aiosqlite
# Connection in the test process, not only those created in this module.
# It depends on aiosqlite internals (Connection._thread, Connection._connect),
# so it may need revisiting when aiosqlite is upgraded.
aiosqlite.core.Connection.__await__ = _safe_aiosqlite_await  # type: ignore[method-assign]
|
||
|
|
|
||
|
|
import tempfile
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from backend import config, db
|
||
|
|
|
||
|
|
|
||
|
|
def _tmpdb():
    """Create a fresh temp database file and point config.DB_PATH at it.

    Uses ``tempfile.mkstemp``, which atomically creates the file, instead of
    the deprecated and race-prone ``tempfile.mktemp``. SQLite treats a
    zero-length file as an empty database, so ``init_db()`` can use it as-is.

    Returns:
        The path of the new database file; pass it to ``_cleanup`` afterwards.
    """
    fd, path = tempfile.mkstemp(suffix=".db")
    # Only the path is needed. Close the descriptor so no handle leaks and the
    # file can be unlinked later (Windows refuses to delete open files).
    os.close(fd)
    config.DB_PATH = path
    return path
|
||
|
|
|
||
|
|
|
||
|
|
def _cleanup(path: str) -> None:
    """Delete the SQLite file at *path* plus its ``-wal``/``-shm`` sidecars.

    Files that do not exist are skipped, so calling this twice is harmless.
    """
    leftovers = [f"{path}{suffix}" for suffix in ("", "-wal", "-shm")]
    for leftover in leftovers:
        try:
            os.unlink(leftover)
        except FileNotFoundError:
            continue
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# init_db — schema / pragma verification
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_init_db_creates_tables():
    """After init_db(), sqlite_master lists users, signals and telegram_batches."""
    db_file = _tmpdb()
    try:
        await db.init_db()
        # Inspect the schema through an independent connection to the same file.
        async with aiosqlite.connect(db_file) as verify_conn:
            async with verify_conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table'"
            ) as cursor:
                found = {name for (name,) in await cursor.fetchall()}
        for expected in ("users", "signals", "telegram_batches"):
            assert expected in found
    finally:
        _cleanup(db_file)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_init_db_wal_mode():
    """init_db() leaves the database file in WAL journal mode."""
    db_file = _tmpdb()
    try:
        await db.init_db()
        # WAL mode is persisted in the database file, so a plain connection
        # opened afterwards still reports it.
        async with aiosqlite.connect(db_file) as verify_conn:
            async with verify_conn.execute("PRAGMA journal_mode") as cursor:
                (mode,) = await cursor.fetchone()
        assert mode == "wal"
    finally:
        _cleanup(db_file)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_init_db_busy_timeout():
    """PRAGMA busy_timeout = 5000 on a connection opened by _get_conn().

    busy_timeout is a per-connection setting. Checking it on a raw
    aiosqlite.connect() would prove nothing: Python's sqlite3.connect()
    defaults to timeout=5.0 s, i.e. busy_timeout=5000 ms, so that check could
    never fail. Verify the connection the application actually uses instead.
    """
    path = _tmpdb()
    try:
        await db.init_db()
        conn = await db._get_conn()
        # Close even if the query fails, so the worker thread and file handle
        # don't outlive the test and _cleanup can unlink the files.
        try:
            async with conn.execute("PRAGMA busy_timeout") as cur:
                row = await cur.fetchone()
        finally:
            await conn.close()
        assert row is not None
        assert row[0] == 5000
    finally:
        _cleanup(path)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_init_db_synchronous():
    """PRAGMA synchronous = 1 (NORMAL) on each connection opened by _get_conn().

    The PRAGMA is per-connection (not file-level), so we must verify it via
    a connection created by _get_conn() rather than a raw aiosqlite.connect().
    """
    path = _tmpdb()
    try:
        await db.init_db()
        conn = await db._get_conn()
        # Close even if the query fails, so the worker thread and file handle
        # don't outlive the test and _cleanup can unlink the files.
        try:
            async with conn.execute("PRAGMA synchronous") as cur:
                row = await cur.fetchone()
        finally:
            await conn.close()
        assert row is not None
        # 1 == NORMAL
        assert row[0] == 1
    finally:
        _cleanup(path)
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# register_user
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_register_user_returns_id():
    """A new registration yields a positive int user_id and echoes the uuid back."""
    db_file = _tmpdb()
    try:
        await db.init_db()
        registered = await db.register_user(uuid="uuid-001", name="Alice")
        user_id = registered["user_id"]
        assert isinstance(user_id, int)
        assert user_id > 0
        assert registered["uuid"] == "uuid-001"
    finally:
        _cleanup(db_file)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_register_user_idempotent():
    """Registering the same uuid twice resolves to one and the same user_id."""
    db_file = _tmpdb()
    try:
        await db.init_db()
        same_user = {"uuid": "uuid-002", "name": "Bob"}
        first = await db.register_user(**same_user)
        second = await db.register_user(**same_user)
        assert first["user_id"] == second["user_id"]
    finally:
        _cleanup(db_file)
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# get_user_name
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_get_user_name_returns_name():
    """get_user_name looks up the name stored by register_user."""
    db_file = _tmpdb()
    try:
        await db.init_db()
        user_uuid = "uuid-003"
        await db.register_user(uuid=user_uuid, name="Charlie")
        assert await db.get_user_name(user_uuid) == "Charlie"
    finally:
        _cleanup(db_file)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_get_user_name_unknown_returns_none():
    """Looking up a uuid that was never registered yields None."""
    db_file = _tmpdb()
    try:
        await db.init_db()
        assert await db.get_user_name("nonexistent-uuid") is None
    finally:
        _cleanup(db_file)
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# save_signal
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_save_signal_returns_id():
    """Saving a signal with coordinates yields a positive integer id."""
    db_file = _tmpdb()
    try:
        await db.init_db()
        user_uuid = "uuid-004"
        await db.register_user(uuid=user_uuid, name="Dana")
        new_id = await db.save_signal(
            user_uuid=user_uuid,
            timestamp=1742478000000,
            lat=55.7558,
            lon=37.6173,
            accuracy=15.0,
        )
        assert isinstance(new_id, int) and new_id > 0
    finally:
        _cleanup(db_file)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_save_signal_without_geo():
    """save_signal with lat/lon/accuracy=None stores NULL in all three columns."""
    path = _tmpdb()
    try:
        await db.init_db()
        await db.register_user(uuid="uuid-005", name="Eve")
        signal_id = await db.save_signal(
            user_uuid="uuid-005",
            timestamp=1742478000000,
            lat=None,
            lon=None,
            accuracy=None,
        )
        assert isinstance(signal_id, int)
        assert signal_id > 0

        # Read the stored row back through an independent connection.
        async with aiosqlite.connect(path) as conn:
            conn.row_factory = aiosqlite.Row
            async with conn.execute(
                "SELECT lat, lon, accuracy FROM signals WHERE id = ?", (signal_id,)
            ) as cur:
                row = await cur.fetchone()
        # Fail with a clear assertion (not a TypeError) if the id isn't stored.
        assert row is not None
        assert row["lat"] is None
        assert row["lon"] is None
        assert row["accuracy"] is None
    finally:
        _cleanup(path)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
async def test_save_signal_increments_id():
    """Each successive save_signal call returns a strictly larger id."""
    path = _tmpdb()
    try:
        await db.init_db()
        await db.register_user(uuid="uuid-006", name="Frank")
        # Keyword arguments match the other save_signal tests. SQLite's dynamic
        # typing would otherwise let a mis-ordered positional call pass silently.
        id1 = await db.save_signal(
            user_uuid="uuid-006", timestamp=1742478000001, lat=None, lon=None, accuracy=None
        )
        id2 = await db.save_signal(
            user_uuid="uuid-006", timestamp=1742478000002, lat=None, lon=None, accuracy=None
        )
        assert id2 > id1
    finally:
        _cleanup(path)
|