baton/tests/test_db.py
2026-03-20 20:44:00 +02:00

248 lines
7.5 KiB
Python

"""
Tests for backend/db.py.
Uses a temporary file-based SQLite DB so all connections opened by
_get_conn() share the same database file (in-memory DBs are isolated
per-connection and cannot be shared across calls).
"""
from __future__ import annotations
import os
# backend.config presumably reads these variables when it is imported, so they
# must be in place before the `from backend import config, db` line below.
# setdefault() keeps any value already exported by the developer or CI.
for _env_key, _env_default in (
    ("BOT_TOKEN", "test-bot-token"),
    ("CHAT_ID", "-1001234567890"),
    ("WEBHOOK_SECRET", "test-webhook-secret"),
    ("WEBHOOK_URL", "https://example.com/api/webhook/telegram"),
    ("FRONTEND_ORIGIN", "http://localhost:3000"),
):
    os.environ.setdefault(_env_key, _env_default)
del _env_key, _env_default
import aiosqlite
def _safe_aiosqlite_await(self):
if not self._thread._started.is_set():
self._thread.start()
return self._connect().__await__()
# Install the guarded __await__ on aiosqlite's Connection class. Because this
# assigns a class attribute, it affects every Connection in the test process.
setattr(aiosqlite.core.Connection, "__await__", _safe_aiosqlite_await)
import tempfile
import pytest
from backend import config, db
def _tmpdb():
    """Create a fresh temporary database file and point ``config.DB_PATH`` at it.

    The file is created empty by ``tempfile.mkstemp``. This avoids the
    name race of the deprecated ``tempfile.mktemp``, which only returned a
    name that another process could claim first. SQLite treats a
    zero-length file as an empty database, so ``init_db()`` can use it
    directly. Callers remove it (plus any ``-wal``/``-shm`` side files)
    with ``_cleanup``.

    Returns:
        Absolute path of the new file. As a side effect,
        ``config.DB_PATH`` is set to the same value.
    """
    fd, path = tempfile.mkstemp(suffix=".db")
    # Only the path is needed; SQLite opens the file itself.
    os.close(fd)
    config.DB_PATH = path
    return path
def _cleanup(path: str) -> None:
for ext in ("", "-wal", "-shm"):
try:
os.unlink(path + ext)
except FileNotFoundError:
pass
# ---------------------------------------------------------------------------
# init_db — schema / pragma verification
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_init_db_creates_tables():
    """After init_db(), sqlite_master lists users, signals and telegram_batches."""
    db_file = _tmpdb()
    try:
        await db.init_db()
        # Read the schema through an independent connection to the same file.
        async with aiosqlite.connect(db_file) as raw:
            async with raw.execute(
                "SELECT name FROM sqlite_master WHERE type='table'"
            ) as cursor:
                found = {record[0] for record in await cursor.fetchall()}
        assert {"users", "signals", "telegram_batches"} <= found
    finally:
        _cleanup(db_file)
@pytest.mark.asyncio
async def test_init_db_wal_mode():
    """init_db() leaves the database file in WAL journal mode.

    WAL mode is persistent in the database file, so a fresh raw
    connection can observe it.
    """
    db_file = _tmpdb()
    try:
        await db.init_db()
        async with aiosqlite.connect(db_file) as raw:
            async with raw.execute("PRAGMA journal_mode") as cursor:
                mode = (await cursor.fetchone())[0]
        assert mode == "wal"
    finally:
        _cleanup(db_file)
@pytest.mark.asyncio
async def test_init_db_busy_timeout():
    """PRAGMA busy_timeout = 5000 on a connection opened by _get_conn().

    busy_timeout is a per-connection setting and is not stored in the
    database file. A raw ``aiosqlite.connect(path)`` reports 5000 no matter
    what the backend does, because Python's ``sqlite3.connect`` defaults to
    ``timeout=5.0`` seconds. The check therefore goes through the backend's
    ``_get_conn()`` helper instead.
    """
    path = _tmpdb()
    try:
        await db.init_db()
        conn = await db._get_conn()
        try:
            async with conn.execute("PRAGMA busy_timeout") as cur:
                row = await cur.fetchone()
        finally:
            # Close even if the query fails, so the connection's worker thread
            # and file handle are released before _cleanup() unlinks the file.
            await conn.close()
        assert row[0] == 5000
    finally:
        _cleanup(path)
@pytest.mark.asyncio
async def test_init_db_synchronous():
    """PRAGMA synchronous = 1 (NORMAL) on each connection opened by _get_conn().

    The PRAGMA is per-connection (not file-level), so it must be verified
    on a connection created by _get_conn() rather than a raw
    aiosqlite.connect().
    """
    path = _tmpdb()
    try:
        await db.init_db()
        conn = await db._get_conn()
        try:
            async with conn.execute("PRAGMA synchronous") as cur:
                row = await cur.fetchone()
        finally:
            # Always release the connection, even if the query raises;
            # otherwise its worker thread and file handle leak.
            await conn.close()
        # 1 == NORMAL
        assert row[0] == 1
    finally:
        _cleanup(path)
# ---------------------------------------------------------------------------
# register_user
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_register_user_returns_id():
    """register_user() echoes the uuid and assigns a positive integer user_id."""
    db_file = _tmpdb()
    try:
        await db.init_db()
        registered = await db.register_user(uuid="uuid-001", name="Alice")
        new_id = registered["user_id"]
        assert isinstance(new_id, int) and new_id > 0
        assert registered["uuid"] == "uuid-001"
    finally:
        _cleanup(db_file)
@pytest.mark.asyncio
async def test_register_user_idempotent():
    """Registering the same uuid twice yields the same user_id both times."""
    db_file = _tmpdb()
    try:
        await db.init_db()
        first, second = [
            await db.register_user(uuid="uuid-002", name="Bob") for _ in range(2)
        ]
        assert second["user_id"] == first["user_id"]
    finally:
        _cleanup(db_file)
# ---------------------------------------------------------------------------
# get_user_name
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_user_name_returns_name():
    """get_user_name() returns the name a user was registered with."""
    db_file = _tmpdb()
    try:
        await db.init_db()
        await db.register_user(uuid="uuid-003", name="Charlie")
        assert await db.get_user_name("uuid-003") == "Charlie"
    finally:
        _cleanup(db_file)
@pytest.mark.asyncio
async def test_get_user_name_unknown_returns_none():
    """get_user_name() yields None for a uuid that was never registered."""
    db_file = _tmpdb()
    try:
        await db.init_db()
        assert await db.get_user_name("nonexistent-uuid") is None
    finally:
        _cleanup(db_file)
# ---------------------------------------------------------------------------
# save_signal
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_save_signal_returns_id():
    """save_signal() with full coordinates returns a positive integer id."""
    db_file = _tmpdb()
    try:
        await db.init_db()
        await db.register_user(uuid="uuid-004", name="Dana")
        signal_kwargs = {
            "user_uuid": "uuid-004",
            "timestamp": 1742478000000,
            "lat": 55.7558,
            "lon": 37.6173,
            "accuracy": 15.0,
        }
        new_id = await db.save_signal(**signal_kwargs)
        assert isinstance(new_id, int)
        assert new_id > 0
    finally:
        _cleanup(db_file)
@pytest.mark.asyncio
async def test_save_signal_without_geo():
    """save_signal with lat/lon/accuracy all None stores NULL in those columns."""
    path = _tmpdb()
    try:
        await db.init_db()
        await db.register_user(uuid="uuid-005", name="Eve")
        signal_id = await db.save_signal(
            user_uuid="uuid-005",
            timestamp=1742478000000,
            lat=None,
            lon=None,
            accuracy=None,
        )
        assert isinstance(signal_id, int)
        assert signal_id > 0
        # Read the row back through a separate connection to verify the NULLs.
        async with aiosqlite.connect(path) as conn:
            conn.row_factory = aiosqlite.Row
            async with conn.execute(
                "SELECT lat, lon, accuracy FROM signals WHERE id = ?", (signal_id,)
            ) as cur:
                row = await cur.fetchone()
        # If the returned id matches no stored row, fail with a clear message
        # instead of a TypeError when indexing None below.
        assert row is not None, f"no signals row with id={signal_id}"
        assert row["lat"] is None
        assert row["lon"] is None
        assert row["accuracy"] is None
    finally:
        _cleanup(path)
@pytest.mark.asyncio
async def test_save_signal_increments_id():
    """A later save_signal() call yields a strictly larger id than an earlier one."""
    db_file = _tmpdb()
    try:
        await db.init_db()
        await db.register_user(uuid="uuid-006", name="Frank")
        ids = []
        for ts in (1742478000001, 1742478000002):
            ids.append(await db.save_signal("uuid-006", ts, None, None, None))
        earlier, later = ids
        assert later > earlier
    finally:
        _cleanup(db_file)