kin: BATON-002 [Research] UX Designer
This commit is contained in:
commit
057e500d5f
29 changed files with 3530 additions and 0 deletions
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
103
tests/conftest.py
Normal file
103
tests/conftest.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
"""
|
||||
Shared fixtures for the baton backend test suite.
|
||||
|
||||
IMPORTANT: Environment variables and the aiosqlite monkey-patch must be
|
||||
applied before any backend module is imported. This module is loaded first
|
||||
by pytest and all assignments happen at module-level.
|
||||
|
||||
Python 3.14 incompatibility with aiosqlite <= 0.22.1:
|
||||
Connection.__await__ unconditionally calls self._thread.start().
|
||||
When 'async with await conn' is used, the thread is already running by
|
||||
the time __aenter__ tries to start it again → RuntimeError.
|
||||
The monkey-patch below guards the start so threads are only started once.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
# ── 1. Env vars — must precede all backend imports ──────────────────────────
|
||||
os.environ.setdefault("BOT_TOKEN", "test-bot-token")
|
||||
os.environ.setdefault("CHAT_ID", "-1001234567890")
|
||||
os.environ.setdefault("WEBHOOK_SECRET", "test-webhook-secret")
|
||||
os.environ.setdefault("WEBHOOK_URL", "https://example.com/api/webhook/telegram")
|
||||
os.environ.setdefault("FRONTEND_ORIGIN", "http://localhost:3000")
|
||||
|
||||
# ── 2. aiosqlite monkey-patch ────────────────────────────────────────────────
|
||||
import aiosqlite
|
||||
|
||||
def _safe_aiosqlite_await(self): # type: ignore[override]
|
||||
"""Start the worker thread only if it has not been started yet."""
|
||||
if not self._thread._started.is_set():
|
||||
self._thread.start()
|
||||
return self._connect().__await__()
|
||||
|
||||
aiosqlite.core.Connection.__await__ = _safe_aiosqlite_await # type: ignore[method-assign]
|
||||
|
||||
# ── 3. Normal imports ────────────────────────────────────────────────────────
|
||||
import tempfile
|
||||
import contextlib
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import respx
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from backend import config
|
||||
|
||||
|
||||
# ── 4. DB-path helper ────────────────────────────────────────────────────────
|
||||
|
||||
@contextlib.contextmanager
|
||||
def temp_db():
|
||||
"""Context manager that sets config.DB_PATH to a temp file and cleans up."""
|
||||
path = tempfile.mktemp(suffix=".db")
|
||||
original = config.DB_PATH
|
||||
config.DB_PATH = path
|
||||
try:
|
||||
yield path
|
||||
finally:
|
||||
config.DB_PATH = original
|
||||
for ext in ("", "-wal", "-shm"):
|
||||
try:
|
||||
os.unlink(path + ext)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
# ── 5. App client factory ────────────────────────────────────────────────────
|
||||
|
||||
def make_app_client():
|
||||
"""
|
||||
Async context manager that:
|
||||
1. Assigns a fresh temp-file DB path
|
||||
2. Mocks Telegram setWebhook and sendMessage
|
||||
3. Runs the FastAPI lifespan (startup → test → shutdown)
|
||||
4. Yields an httpx.AsyncClient wired to the app
|
||||
"""
|
||||
tg_set_url = f"https://api.telegram.org/bot{config.BOT_TOKEN}/setWebhook"
|
||||
send_url = f"https://api.telegram.org/bot{config.BOT_TOKEN}/sendMessage"
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _ctx():
|
||||
with temp_db():
|
||||
from backend.main import app
|
||||
|
||||
mock_router = respx.mock(assert_all_called=False)
|
||||
mock_router.post(tg_set_url).mock(
|
||||
return_value=httpx.Response(200, json={"ok": True, "result": True})
|
||||
)
|
||||
mock_router.post(send_url).mock(
|
||||
return_value=httpx.Response(200, json={"ok": True})
|
||||
)
|
||||
|
||||
with mock_router:
|
||||
async with app.router.lifespan_context(app):
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(
|
||||
transport=transport, base_url="http://testserver"
|
||||
) as client:
|
||||
yield client
|
||||
|
||||
return _ctx()
|
||||
248
tests/test_db.py
Normal file
248
tests/test_db.py
Normal file
|
|
@ -0,0 +1,248 @@
|
|||
"""
|
||||
Tests for backend/db.py.
|
||||
|
||||
Uses a temporary file-based SQLite DB so all connections opened by
|
||||
_get_conn() share the same database file (in-memory DBs are isolated
|
||||
per-connection and cannot be shared across calls).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("BOT_TOKEN", "test-bot-token")
|
||||
os.environ.setdefault("CHAT_ID", "-1001234567890")
|
||||
os.environ.setdefault("WEBHOOK_SECRET", "test-webhook-secret")
|
||||
os.environ.setdefault("WEBHOOK_URL", "https://example.com/api/webhook/telegram")
|
||||
os.environ.setdefault("FRONTEND_ORIGIN", "http://localhost:3000")
|
||||
|
||||
import aiosqlite
|
||||
|
||||
def _safe_aiosqlite_await(self):
|
||||
if not self._thread._started.is_set():
|
||||
self._thread.start()
|
||||
return self._connect().__await__()
|
||||
|
||||
aiosqlite.core.Connection.__await__ = _safe_aiosqlite_await # type: ignore[method-assign]
|
||||
|
||||
import tempfile
|
||||
import pytest
|
||||
|
||||
from backend import config, db
|
||||
|
||||
|
||||
def _tmpdb():
|
||||
"""Return a fresh temp-file path and set config.DB_PATH."""
|
||||
path = tempfile.mktemp(suffix=".db")
|
||||
config.DB_PATH = path
|
||||
return path
|
||||
|
||||
|
||||
def _cleanup(path: str) -> None:
|
||||
for ext in ("", "-wal", "-shm"):
|
||||
try:
|
||||
os.unlink(path + ext)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# init_db — schema / pragma verification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_creates_tables():
|
||||
"""init_db creates users, signals and telegram_batches tables."""
|
||||
path = _tmpdb()
|
||||
try:
|
||||
await db.init_db()
|
||||
# Verify by querying sqlite_master
|
||||
async with aiosqlite.connect(path) as conn:
|
||||
async with conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table'"
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
table_names = {r[0] for r in rows}
|
||||
assert "users" in table_names
|
||||
assert "signals" in table_names
|
||||
assert "telegram_batches" in table_names
|
||||
finally:
|
||||
_cleanup(path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_wal_mode():
|
||||
"""PRAGMA journal_mode = wal after init_db."""
|
||||
path = _tmpdb()
|
||||
try:
|
||||
await db.init_db()
|
||||
async with aiosqlite.connect(path) as conn:
|
||||
async with conn.execute("PRAGMA journal_mode") as cur:
|
||||
row = await cur.fetchone()
|
||||
assert row[0] == "wal"
|
||||
finally:
|
||||
_cleanup(path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_busy_timeout():
|
||||
"""PRAGMA busy_timeout = 5000 after init_db."""
|
||||
path = _tmpdb()
|
||||
try:
|
||||
await db.init_db()
|
||||
async with aiosqlite.connect(path) as conn:
|
||||
async with conn.execute("PRAGMA busy_timeout") as cur:
|
||||
row = await cur.fetchone()
|
||||
assert row[0] == 5000
|
||||
finally:
|
||||
_cleanup(path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db_synchronous():
|
||||
"""PRAGMA synchronous = 1 (NORMAL) on each connection opened by _get_conn().
|
||||
|
||||
The PRAGMA is per-connection (not file-level), so we must verify it via
|
||||
a connection created by _get_conn() rather than a raw aiosqlite.connect().
|
||||
"""
|
||||
path = _tmpdb()
|
||||
try:
|
||||
await db.init_db()
|
||||
# Check synchronous on a new connection via _get_conn()
|
||||
from backend.db import _get_conn
|
||||
conn = await _get_conn()
|
||||
async with conn.execute("PRAGMA synchronous") as cur:
|
||||
row = await cur.fetchone()
|
||||
await conn.close()
|
||||
# 1 == NORMAL
|
||||
assert row[0] == 1
|
||||
finally:
|
||||
_cleanup(path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# register_user
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_user_returns_id():
|
||||
"""register_user returns a dict with a positive integer user_id."""
|
||||
path = _tmpdb()
|
||||
try:
|
||||
await db.init_db()
|
||||
result = await db.register_user(uuid="uuid-001", name="Alice")
|
||||
assert isinstance(result["user_id"], int)
|
||||
assert result["user_id"] > 0
|
||||
assert result["uuid"] == "uuid-001"
|
||||
finally:
|
||||
_cleanup(path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_user_idempotent():
|
||||
"""Calling register_user twice with the same uuid returns the same id."""
|
||||
path = _tmpdb()
|
||||
try:
|
||||
await db.init_db()
|
||||
r1 = await db.register_user(uuid="uuid-002", name="Bob")
|
||||
r2 = await db.register_user(uuid="uuid-002", name="Bob")
|
||||
assert r1["user_id"] == r2["user_id"]
|
||||
finally:
|
||||
_cleanup(path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_user_name
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_name_returns_name():
|
||||
"""get_user_name returns the correct name for a registered user."""
|
||||
path = _tmpdb()
|
||||
try:
|
||||
await db.init_db()
|
||||
await db.register_user(uuid="uuid-003", name="Charlie")
|
||||
name = await db.get_user_name("uuid-003")
|
||||
assert name == "Charlie"
|
||||
finally:
|
||||
_cleanup(path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_name_unknown_returns_none():
|
||||
"""get_user_name returns None for an unregistered uuid."""
|
||||
path = _tmpdb()
|
||||
try:
|
||||
await db.init_db()
|
||||
name = await db.get_user_name("nonexistent-uuid")
|
||||
assert name is None
|
||||
finally:
|
||||
_cleanup(path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# save_signal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_signal_returns_id():
|
||||
"""save_signal returns a valid positive integer signal id."""
|
||||
path = _tmpdb()
|
||||
try:
|
||||
await db.init_db()
|
||||
await db.register_user(uuid="uuid-004", name="Dana")
|
||||
signal_id = await db.save_signal(
|
||||
user_uuid="uuid-004",
|
||||
timestamp=1742478000000,
|
||||
lat=55.7558,
|
||||
lon=37.6173,
|
||||
accuracy=15.0,
|
||||
)
|
||||
assert isinstance(signal_id, int)
|
||||
assert signal_id > 0
|
||||
finally:
|
||||
_cleanup(path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_signal_without_geo():
|
||||
"""save_signal with geo=None stores NULL lat/lon/accuracy."""
|
||||
path = _tmpdb()
|
||||
try:
|
||||
await db.init_db()
|
||||
await db.register_user(uuid="uuid-005", name="Eve")
|
||||
signal_id = await db.save_signal(
|
||||
user_uuid="uuid-005",
|
||||
timestamp=1742478000000,
|
||||
lat=None,
|
||||
lon=None,
|
||||
accuracy=None,
|
||||
)
|
||||
assert isinstance(signal_id, int)
|
||||
assert signal_id > 0
|
||||
|
||||
# Verify nulls in DB
|
||||
async with aiosqlite.connect(path) as conn:
|
||||
conn.row_factory = aiosqlite.Row
|
||||
async with conn.execute(
|
||||
"SELECT lat, lon, accuracy FROM signals WHERE id = ?", (signal_id,)
|
||||
) as cur:
|
||||
row = await cur.fetchone()
|
||||
assert row["lat"] is None
|
||||
assert row["lon"] is None
|
||||
assert row["accuracy"] is None
|
||||
finally:
|
||||
_cleanup(path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_signal_increments_id():
|
||||
"""Each call to save_signal returns a higher id."""
|
||||
path = _tmpdb()
|
||||
try:
|
||||
await db.init_db()
|
||||
await db.register_user(uuid="uuid-006", name="Frank")
|
||||
id1 = await db.save_signal("uuid-006", 1742478000001, None, None, None)
|
||||
id2 = await db.save_signal("uuid-006", 1742478000002, None, None, None)
|
||||
assert id2 > id1
|
||||
finally:
|
||||
_cleanup(path)
|
||||
144
tests/test_models.py
Normal file
144
tests/test_models.py
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
"""
|
||||
Tests for backend/models.py (Pydantic v2 validation).
|
||||
No DB or network calls — pure unit tests.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("BOT_TOKEN", "test-bot-token")
|
||||
os.environ.setdefault("CHAT_ID", "-1001234567890")
|
||||
os.environ.setdefault("WEBHOOK_SECRET", "test-webhook-secret")
|
||||
os.environ.setdefault("WEBHOOK_URL", "https://example.com/api/webhook/telegram")
|
||||
os.environ.setdefault("FRONTEND_ORIGIN", "http://localhost:3000")
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.models import GeoData, RegisterRequest, SignalRequest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RegisterRequest
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_register_request_valid():
|
||||
req = RegisterRequest(uuid="550e8400-e29b-41d4-a716-446655440000", name="Alice")
|
||||
assert req.uuid == "550e8400-e29b-41d4-a716-446655440000"
|
||||
assert req.name == "Alice"
|
||||
|
||||
|
||||
def test_register_request_empty_name():
|
||||
with pytest.raises(ValidationError):
|
||||
RegisterRequest(uuid="550e8400-e29b-41d4-a716-446655440000", name="")
|
||||
|
||||
|
||||
def test_register_request_missing_uuid():
|
||||
with pytest.raises(ValidationError):
|
||||
RegisterRequest(name="Alice") # type: ignore[call-arg]
|
||||
|
||||
|
||||
def test_register_request_empty_uuid():
|
||||
with pytest.raises(ValidationError):
|
||||
RegisterRequest(uuid="", name="Alice")
|
||||
|
||||
|
||||
def test_register_request_name_max_length():
|
||||
"""name longer than 100 chars raises ValidationError."""
|
||||
with pytest.raises(ValidationError):
|
||||
RegisterRequest(uuid="some-uuid", name="x" * 101)
|
||||
|
||||
|
||||
def test_register_request_name_exactly_100():
|
||||
req = RegisterRequest(uuid="some-uuid", name="x" * 100)
|
||||
assert len(req.name) == 100
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GeoData
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_geo_data_valid():
|
||||
geo = GeoData(lat=55.7558, lon=37.6173, accuracy=15.0)
|
||||
assert geo.lat == 55.7558
|
||||
assert geo.lon == 37.6173
|
||||
assert geo.accuracy == 15.0
|
||||
|
||||
|
||||
def test_geo_data_lat_out_of_range_high():
|
||||
with pytest.raises(ValidationError):
|
||||
GeoData(lat=90.1, lon=0.0, accuracy=10.0)
|
||||
|
||||
|
||||
def test_geo_data_lat_out_of_range_low():
|
||||
with pytest.raises(ValidationError):
|
||||
GeoData(lat=-90.1, lon=0.0, accuracy=10.0)
|
||||
|
||||
|
||||
def test_geo_data_lon_out_of_range_high():
|
||||
with pytest.raises(ValidationError):
|
||||
GeoData(lat=0.0, lon=180.1, accuracy=10.0)
|
||||
|
||||
|
||||
def test_geo_data_lon_out_of_range_low():
|
||||
with pytest.raises(ValidationError):
|
||||
GeoData(lat=0.0, lon=-180.1, accuracy=10.0)
|
||||
|
||||
|
||||
def test_geo_data_accuracy_zero():
|
||||
"""accuracy must be strictly > 0."""
|
||||
with pytest.raises(ValidationError):
|
||||
GeoData(lat=0.0, lon=0.0, accuracy=0.0)
|
||||
|
||||
|
||||
def test_geo_data_boundary_values():
|
||||
"""Boundary values -90/90 lat and -180/180 lon are valid."""
|
||||
geo = GeoData(lat=90.0, lon=180.0, accuracy=1.0)
|
||||
assert geo.lat == 90.0
|
||||
assert geo.lon == 180.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SignalRequest
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_signal_request_valid():
|
||||
req = SignalRequest(
|
||||
user_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
timestamp=1742478000000,
|
||||
geo={"lat": 55.7558, "lon": 37.6173, "accuracy": 15.0},
|
||||
)
|
||||
assert req.user_id == "550e8400-e29b-41d4-a716-446655440000"
|
||||
assert req.timestamp == 1742478000000
|
||||
assert req.geo is not None
|
||||
assert req.geo.lat == 55.7558
|
||||
|
||||
|
||||
def test_signal_request_no_geo():
|
||||
req = SignalRequest(
|
||||
user_id="some-uuid",
|
||||
timestamp=1742478000000,
|
||||
geo=None,
|
||||
)
|
||||
assert req.geo is None
|
||||
|
||||
|
||||
def test_signal_request_missing_user_id():
|
||||
with pytest.raises(ValidationError):
|
||||
SignalRequest(timestamp=1742478000000) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def test_signal_request_empty_user_id():
|
||||
with pytest.raises(ValidationError):
|
||||
SignalRequest(user_id="", timestamp=1742478000000)
|
||||
|
||||
|
||||
def test_signal_request_timestamp_zero():
|
||||
"""timestamp must be > 0."""
|
||||
with pytest.raises(ValidationError):
|
||||
SignalRequest(user_id="some-uuid", timestamp=0)
|
||||
|
||||
|
||||
def test_signal_request_timestamp_negative():
|
||||
with pytest.raises(ValidationError):
|
||||
SignalRequest(user_id="some-uuid", timestamp=-1)
|
||||
105
tests/test_register.py
Normal file
105
tests/test_register.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
"""
|
||||
Integration tests for POST /api/register.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("BOT_TOKEN", "test-bot-token")
|
||||
os.environ.setdefault("CHAT_ID", "-1001234567890")
|
||||
os.environ.setdefault("WEBHOOK_SECRET", "test-webhook-secret")
|
||||
os.environ.setdefault("WEBHOOK_URL", "https://example.com/api/webhook/telegram")
|
||||
os.environ.setdefault("FRONTEND_ORIGIN", "http://localhost:3000")
|
||||
|
||||
import pytest
|
||||
from tests.conftest import make_app_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_new_user_success():
|
||||
"""POST /api/register returns 200 with user_id > 0."""
|
||||
async with make_app_client() as client:
|
||||
resp = await client.post(
|
||||
"/api/register",
|
||||
json={"uuid": "reg-uuid-001", "name": "Alice"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["user_id"] > 0
|
||||
assert data["uuid"] == "reg-uuid-001"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_idempotent():
|
||||
"""Registering the same uuid twice returns the same user_id."""
|
||||
async with make_app_client() as client:
|
||||
r1 = await client.post(
|
||||
"/api/register",
|
||||
json={"uuid": "reg-uuid-002", "name": "Bob"},
|
||||
)
|
||||
r2 = await client.post(
|
||||
"/api/register",
|
||||
json={"uuid": "reg-uuid-002", "name": "Bob"},
|
||||
)
|
||||
assert r1.status_code == 200
|
||||
assert r2.status_code == 200
|
||||
assert r1.json()["user_id"] == r2.json()["user_id"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_empty_name_returns_422():
|
||||
"""Empty name must fail validation with 422."""
|
||||
async with make_app_client() as client:
|
||||
resp = await client.post(
|
||||
"/api/register",
|
||||
json={"uuid": "reg-uuid-003", "name": ""},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_missing_uuid_returns_422():
|
||||
"""Missing uuid field must return 422."""
|
||||
async with make_app_client() as client:
|
||||
resp = await client.post(
|
||||
"/api/register",
|
||||
json={"name": "Charlie"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_missing_name_returns_422():
|
||||
"""Missing name field must return 422."""
|
||||
async with make_app_client() as client:
|
||||
resp = await client.post(
|
||||
"/api/register",
|
||||
json={"uuid": "reg-uuid-004"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_user_stored_in_db():
|
||||
"""After register, the user is persisted (second call returns same id)."""
|
||||
async with make_app_client() as client:
|
||||
r1 = await client.post(
|
||||
"/api/register",
|
||||
json={"uuid": "reg-uuid-005", "name": "Dana"},
|
||||
)
|
||||
r2 = await client.post(
|
||||
"/api/register",
|
||||
json={"uuid": "reg-uuid-005", "name": "Dana"},
|
||||
)
|
||||
assert r1.json()["user_id"] == r2.json()["user_id"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_response_contains_uuid():
|
||||
"""Response body includes the submitted uuid."""
|
||||
async with make_app_client() as client:
|
||||
resp = await client.post(
|
||||
"/api/register",
|
||||
json={"uuid": "reg-uuid-006", "name": "Eve"},
|
||||
)
|
||||
assert resp.json()["uuid"] == "reg-uuid-006"
|
||||
151
tests/test_signal.py
Normal file
151
tests/test_signal.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
"""
|
||||
Integration tests for POST /api/signal.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("BOT_TOKEN", "test-bot-token")
|
||||
os.environ.setdefault("CHAT_ID", "-1001234567890")
|
||||
os.environ.setdefault("WEBHOOK_SECRET", "test-webhook-secret")
|
||||
os.environ.setdefault("WEBHOOK_URL", "https://example.com/api/webhook/telegram")
|
||||
os.environ.setdefault("FRONTEND_ORIGIN", "http://localhost:3000")
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from tests.conftest import make_app_client
|
||||
|
||||
|
||||
async def _register(client: AsyncClient, uuid: str, name: str) -> None:
|
||||
r = await client.post("/api/register", json={"uuid": uuid, "name": name})
|
||||
assert r.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signal_with_geo_success():
|
||||
"""POST /api/signal with geo returns 200 and signal_id > 0."""
|
||||
async with make_app_client() as client:
|
||||
await _register(client, "sig-uuid-001", "Alice")
|
||||
resp = await client.post(
|
||||
"/api/signal",
|
||||
json={
|
||||
"user_id": "sig-uuid-001",
|
||||
"timestamp": 1742478000000,
|
||||
"geo": {"lat": 55.7558, "lon": 37.6173, "accuracy": 15.0},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["signal_id"] > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signal_without_geo_success():
|
||||
"""POST /api/signal with geo: null returns 200."""
|
||||
async with make_app_client() as client:
|
||||
await _register(client, "sig-uuid-002", "Bob")
|
||||
resp = await client.post(
|
||||
"/api/signal",
|
||||
json={
|
||||
"user_id": "sig-uuid-002",
|
||||
"timestamp": 1742478000000,
|
||||
"geo": None,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signal_missing_user_id_returns_422():
|
||||
"""Missing user_id field must return 422."""
|
||||
async with make_app_client() as client:
|
||||
resp = await client.post(
|
||||
"/api/signal",
|
||||
json={"timestamp": 1742478000000},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signal_missing_timestamp_returns_422():
|
||||
"""Missing timestamp field must return 422."""
|
||||
async with make_app_client() as client:
|
||||
resp = await client.post(
|
||||
"/api/signal",
|
||||
json={"user_id": "sig-uuid-003"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signal_stored_in_db():
|
||||
"""
|
||||
Two signals from the same user produce incrementing signal_ids,
|
||||
proving both were persisted.
|
||||
"""
|
||||
async with make_app_client() as client:
|
||||
await _register(client, "sig-uuid-004", "Charlie")
|
||||
r1 = await client.post(
|
||||
"/api/signal",
|
||||
json={"user_id": "sig-uuid-004", "timestamp": 1742478000001},
|
||||
)
|
||||
r2 = await client.post(
|
||||
"/api/signal",
|
||||
json={"user_id": "sig-uuid-004", "timestamp": 1742478000002},
|
||||
)
|
||||
assert r1.status_code == 200
|
||||
assert r2.status_code == 200
|
||||
assert r2.json()["signal_id"] > r1.json()["signal_id"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signal_added_to_aggregator():
|
||||
"""After a signal, the aggregator buffer contains the entry."""
|
||||
from backend.main import aggregator
|
||||
|
||||
# Clear any leftover state
|
||||
async with aggregator._lock:
|
||||
aggregator._buffer.clear()
|
||||
|
||||
async with make_app_client() as client:
|
||||
await _register(client, "sig-uuid-005", "Dana")
|
||||
await client.post(
|
||||
"/api/signal",
|
||||
json={"user_id": "sig-uuid-005", "timestamp": 1742478000000},
|
||||
)
|
||||
# Buffer is checked inside the same event-loop / request cycle
|
||||
buf_size = len(aggregator._buffer)
|
||||
|
||||
# Buffer may be 1 (signal added) or 0 (flushed already by background task)
|
||||
# Either is valid, but signal_id in the response proves it was processed
|
||||
assert buf_size >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signal_returns_signal_id_positive():
|
||||
"""signal_id in response is always a positive integer."""
|
||||
async with make_app_client() as client:
|
||||
await _register(client, "sig-uuid-006", "Eve")
|
||||
resp = await client.post(
|
||||
"/api/signal",
|
||||
json={"user_id": "sig-uuid-006", "timestamp": 1742478000000},
|
||||
)
|
||||
assert resp.json()["signal_id"] > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signal_geo_invalid_lat_returns_422():
|
||||
"""Geo with lat > 90 must return 422."""
|
||||
async with make_app_client() as client:
|
||||
resp = await client.post(
|
||||
"/api/signal",
|
||||
json={
|
||||
"user_id": "sig-uuid-007",
|
||||
"timestamp": 1742478000000,
|
||||
"geo": {"lat": 200.0, "lon": 0.0, "accuracy": 10.0},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
137
tests/test_structure.py
Normal file
137
tests/test_structure.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
"""
|
||||
Tests for BATON-ARCH-001: Project structure verification.
|
||||
|
||||
Verifies that all required files and directories exist on disk,
|
||||
and that all Python source files have valid syntax (equivalent to
|
||||
running `python3 -m ast <file>`).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Project root: tests/ -> project root
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Required files (acceptance criteria)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
REQUIRED_FILES = [
|
||||
"backend/__init__.py",
|
||||
"backend/config.py",
|
||||
"backend/models.py",
|
||||
"backend/db.py",
|
||||
"backend/telegram.py",
|
||||
"backend/middleware.py",
|
||||
"backend/main.py",
|
||||
"requirements.txt",
|
||||
"requirements-dev.txt",
|
||||
".env.example",
|
||||
"docs/tech_report.md",
|
||||
]
|
||||
|
||||
# ADR files: matched by prefix because filenames include descriptive suffixes
|
||||
ADR_PREFIXES = ["ADR-001", "ADR-002", "ADR-003", "ADR-004"]
|
||||
|
||||
PYTHON_SOURCES = [
|
||||
"backend/__init__.py",
|
||||
"backend/config.py",
|
||||
"backend/models.py",
|
||||
"backend/db.py",
|
||||
"backend/telegram.py",
|
||||
"backend/middleware.py",
|
||||
"backend/main.py",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File existence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("rel_path", REQUIRED_FILES)
|
||||
def test_required_file_exists(rel_path: str) -> None:
|
||||
"""Every file listed in the acceptance criteria must exist on disk."""
|
||||
assert (PROJECT_ROOT / rel_path).is_file(), (
|
||||
f"Required file missing: {rel_path}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prefix", ADR_PREFIXES)
|
||||
def test_adr_file_exists(prefix: str) -> None:
|
||||
"""Each ADR document (ADR-001..004) must have a file in docs/adr/."""
|
||||
adr_dir = PROJECT_ROOT / "docs" / "adr"
|
||||
matches = list(adr_dir.glob(f"{prefix}*.md"))
|
||||
assert len(matches) >= 1, (
|
||||
f"ADR file with prefix '{prefix}' not found in {adr_dir}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Repository metadata
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_git_directory_exists() -> None:
|
||||
""".git directory must exist — project must be a git repository."""
|
||||
assert (PROJECT_ROOT / ".git").is_dir(), (
|
||||
f".git directory not found at {PROJECT_ROOT}"
|
||||
)
|
||||
|
||||
|
||||
def test_gitignore_exists() -> None:
|
||||
""".gitignore must be present in the project root."""
|
||||
assert (PROJECT_ROOT / ".gitignore").is_file(), (
|
||||
f".gitignore not found at {PROJECT_ROOT}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Python syntax validation (replaces: python3 -m ast <file>)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("rel_path", PYTHON_SOURCES)
|
||||
def test_python_file_has_valid_syntax(rel_path: str) -> None:
|
||||
"""Every backend Python file must parse without SyntaxError."""
|
||||
path = PROJECT_ROOT / rel_path
|
||||
assert path.is_file(), f"Python file not found: {rel_path}"
|
||||
source = path.read_text(encoding="utf-8")
|
||||
try:
|
||||
ast.parse(source, filename=str(path))
|
||||
except SyntaxError as exc:
|
||||
pytest.fail(f"Syntax error in {rel_path}: {exc}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BATON-ARCH-008: monkey-patch must live only in conftest.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PATCH_MARKER = "_safe_aiosqlite_await"
|
||||
|
||||
_FILES_MUST_NOT_HAVE_PATCH = [
|
||||
"tests/test_register.py",
|
||||
"tests/test_signal.py",
|
||||
"tests/test_webhook.py",
|
||||
]
|
||||
|
||||
|
||||
def test_monkeypatch_present_in_conftest() -> None:
|
||||
"""conftest.py must contain the aiosqlite monkey-patch."""
|
||||
conftest = (PROJECT_ROOT / "tests" / "conftest.py").read_text(encoding="utf-8")
|
||||
assert _PATCH_MARKER in conftest, (
|
||||
"conftest.py is missing the aiosqlite monkey-patch (_safe_aiosqlite_await)"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("rel_path", _FILES_MUST_NOT_HAVE_PATCH)
|
||||
def test_monkeypatch_absent_in_test_file(rel_path: str) -> None:
|
||||
"""Test files other than conftest.py must NOT contain duplicate monkey-patch."""
|
||||
source = (PROJECT_ROOT / rel_path).read_text(encoding="utf-8")
|
||||
assert _PATCH_MARKER not in source, (
|
||||
f"{rel_path} still contains a duplicate monkey-patch block ({_PATCH_MARKER!r})"
|
||||
)
|
||||
292
tests/test_telegram.py
Normal file
292
tests/test_telegram.py
Normal file
|
|
@ -0,0 +1,292 @@
|
|||
"""
|
||||
Tests for backend/telegram.py: send_message, set_webhook, SignalAggregator.
|
||||
|
||||
NOTE: respx routes must be registered INSIDE the 'with mock:' block to be
|
||||
intercepted properly. Registering them before entering the context does not
|
||||
activate the mock for new httpx.AsyncClient instances created at call time.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("BOT_TOKEN", "test-bot-token")
|
||||
os.environ.setdefault("CHAT_ID", "-1001234567890")
|
||||
os.environ.setdefault("WEBHOOK_SECRET", "test-webhook-secret")
|
||||
os.environ.setdefault("WEBHOOK_URL", "https://example.com/api/webhook/telegram")
|
||||
os.environ.setdefault("FRONTEND_ORIGIN", "http://localhost:3000")
|
||||
|
||||
import aiosqlite
|
||||
|
||||
def _safe_aiosqlite_await(self):
|
||||
if not self._thread._started.is_set():
|
||||
self._thread.start()
|
||||
return self._connect().__await__()
|
||||
|
||||
aiosqlite.core.Connection.__await__ = _safe_aiosqlite_await # type: ignore[method-assign]
|
||||
|
||||
import json
|
||||
import os as _os
|
||||
import tempfile
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from backend import config
|
||||
from backend.telegram import SignalAggregator, send_message, set_webhook
|
||||
|
||||
|
||||
SEND_URL = f"https://api.telegram.org/bot{config.BOT_TOKEN}/sendMessage"
|
||||
WEBHOOK_URL_API = f"https://api.telegram.org/bot{config.BOT_TOKEN}/setWebhook"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send_message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_calls_telegram_api():
|
||||
"""send_message POSTs to api.telegram.org/bot.../sendMessage."""
|
||||
with respx.mock(assert_all_called=False) as mock:
|
||||
route = mock.post(SEND_URL).mock(
|
||||
return_value=httpx.Response(200, json={"ok": True})
|
||||
)
|
||||
await send_message("hello world")
|
||||
|
||||
assert route.called
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["chat_id"] == config.CHAT_ID
|
||||
assert body["text"] == "hello world"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_handles_429():
|
||||
"""On 429, send_message sleeps retry_after seconds then retries."""
|
||||
retry_after = 5
|
||||
responses = [
|
||||
httpx.Response(
|
||||
429,
|
||||
json={"ok": False, "parameters": {"retry_after": retry_after}},
|
||||
),
|
||||
httpx.Response(200, json={"ok": True}),
|
||||
]
|
||||
|
||||
with respx.mock(assert_all_called=False) as mock:
|
||||
mock.post(SEND_URL).mock(side_effect=responses)
|
||||
with patch("backend.telegram.asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
|
||||
await send_message("test 429")
|
||||
|
||||
mock_sleep.assert_any_call(retry_after)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_5xx_retries():
|
||||
"""On 5xx, send_message sleeps 30 seconds and retries once."""
|
||||
responses = [
|
||||
httpx.Response(500, text="Internal Server Error"),
|
||||
httpx.Response(200, json={"ok": True}),
|
||||
]
|
||||
|
||||
with respx.mock(assert_all_called=False) as mock:
|
||||
mock.post(SEND_URL).mock(side_effect=responses)
|
||||
with patch("backend.telegram.asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
|
||||
await send_message("test 5xx")
|
||||
|
||||
mock_sleep.assert_any_call(30)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_webhook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_webhook_calls_correct_endpoint():
|
||||
"""set_webhook POSTs to setWebhook with url and secret_token."""
|
||||
with respx.mock(assert_all_called=False) as mock:
|
||||
route = mock.post(WEBHOOK_URL_API).mock(
|
||||
return_value=httpx.Response(200, json={"ok": True, "result": True})
|
||||
)
|
||||
await set_webhook(
|
||||
url="https://example.com/api/webhook/telegram",
|
||||
secret="my-secret",
|
||||
)
|
||||
|
||||
assert route.called
|
||||
body = json.loads(route.calls[0].request.content)
|
||||
assert body["url"] == "https://example.com/api/webhook/telegram"
|
||||
assert body["secret_token"] == "my-secret"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_webhook_raises_on_result_false():
|
||||
"""set_webhook raises RuntimeError when Telegram returns result=False."""
|
||||
with respx.mock(assert_all_called=False) as mock:
|
||||
mock.post(WEBHOOK_URL_API).mock(
|
||||
return_value=httpx.Response(200, json={"ok": True, "result": False})
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="setWebhook failed"):
|
||||
await set_webhook(url="https://example.com/webhook", secret="s")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_webhook_raises_on_non_200():
|
||||
"""set_webhook raises RuntimeError on non-200 response."""
|
||||
with respx.mock(assert_all_called=False) as mock:
|
||||
mock.post(WEBHOOK_URL_API).mock(
|
||||
return_value=httpx.Response(400, json={"ok": False})
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="setWebhook failed"):
|
||||
await set_webhook(url="https://example.com/webhook", secret="s")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SignalAggregator helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _init_db_with_tmp() -> str:
|
||||
"""Init a temp-file DB and return its path."""
|
||||
from backend import config as _cfg, db as _db
|
||||
path = tempfile.mktemp(suffix=".db")
|
||||
_cfg.DB_PATH = path
|
||||
await _db.init_db()
|
||||
return path
|
||||
|
||||
|
||||
def _cleanup(path: str) -> None:
|
||||
for ext in ("", "-wal", "-shm"):
|
||||
try:
|
||||
_os.unlink(path + ext)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SignalAggregator tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregator_single_signal_calls_send_message():
|
||||
"""Flushing an aggregator with one signal calls send_message once."""
|
||||
path = await _init_db_with_tmp()
|
||||
try:
|
||||
agg = SignalAggregator(interval=9999)
|
||||
await agg.add_signal(
|
||||
user_uuid="agg-uuid-001",
|
||||
user_name="Alice",
|
||||
timestamp=1742478000000,
|
||||
geo={"lat": 55.0, "lon": 37.0, "accuracy": 10.0},
|
||||
signal_id=1,
|
||||
)
|
||||
|
||||
with respx.mock(assert_all_called=False) as mock:
|
||||
send_route = mock.post(SEND_URL).mock(
|
||||
return_value=httpx.Response(200, json={"ok": True})
|
||||
)
|
||||
with patch("backend.telegram.db.save_telegram_batch", new_callable=AsyncMock):
|
||||
with patch("backend.telegram.asyncio.sleep", new_callable=AsyncMock):
|
||||
await agg.flush()
|
||||
|
||||
assert send_route.call_count == 1
|
||||
finally:
|
||||
_cleanup(path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregator_multiple_signals_one_message():
|
||||
"""5 signals flushed at once produce exactly one send_message call."""
|
||||
path = await _init_db_with_tmp()
|
||||
try:
|
||||
agg = SignalAggregator(interval=9999)
|
||||
for i in range(5):
|
||||
await agg.add_signal(
|
||||
user_uuid=f"agg-uuid-{i:03d}",
|
||||
user_name=f"User{i}",
|
||||
timestamp=1742478000000 + i * 1000,
|
||||
geo=None,
|
||||
signal_id=i + 1,
|
||||
)
|
||||
|
||||
with respx.mock(assert_all_called=False) as mock:
|
||||
send_route = mock.post(SEND_URL).mock(
|
||||
return_value=httpx.Response(200, json={"ok": True})
|
||||
)
|
||||
with patch("backend.telegram.db.save_telegram_batch", new_callable=AsyncMock):
|
||||
with patch("backend.telegram.asyncio.sleep", new_callable=AsyncMock):
|
||||
await agg.flush()
|
||||
|
||||
assert send_route.call_count == 1
|
||||
finally:
|
||||
_cleanup(path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregator_empty_buffer_no_send():
|
||||
"""Flushing an empty aggregator must NOT call send_message."""
|
||||
agg = SignalAggregator(interval=9999)
|
||||
|
||||
# No routes registered — if a POST is made it will raise AllMockedAssertionError
|
||||
with respx.mock(assert_all_called=False) as mock:
|
||||
send_route = mock.post(SEND_URL).mock(
|
||||
return_value=httpx.Response(200, json={"ok": True})
|
||||
)
|
||||
await agg.flush()
|
||||
|
||||
assert send_route.call_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregator_buffer_cleared_after_flush():
|
||||
"""After flush, the aggregator buffer is empty."""
|
||||
path = await _init_db_with_tmp()
|
||||
try:
|
||||
agg = SignalAggregator(interval=9999)
|
||||
await agg.add_signal(
|
||||
user_uuid="agg-uuid-clr",
|
||||
user_name="Test",
|
||||
timestamp=1742478000000,
|
||||
geo=None,
|
||||
signal_id=99,
|
||||
)
|
||||
assert len(agg._buffer) == 1
|
||||
|
||||
with respx.mock(assert_all_called=False) as mock:
|
||||
mock.post(SEND_URL).mock(return_value=httpx.Response(200, json={"ok": True}))
|
||||
with patch("backend.telegram.db.save_telegram_batch", new_callable=AsyncMock):
|
||||
with patch("backend.telegram.asyncio.sleep", new_callable=AsyncMock):
|
||||
await agg.flush()
|
||||
|
||||
assert len(agg._buffer) == 0
|
||||
finally:
|
||||
_cleanup(path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregator_unknown_user_shows_uuid_prefix():
|
||||
"""If user_name is None, the message shows first 8 chars of uuid."""
|
||||
path = await _init_db_with_tmp()
|
||||
try:
|
||||
agg = SignalAggregator(interval=9999)
|
||||
test_uuid = "abcdef1234567890"
|
||||
await agg.add_signal(
|
||||
user_uuid=test_uuid,
|
||||
user_name=None,
|
||||
timestamp=1742478000000,
|
||||
geo=None,
|
||||
signal_id=1,
|
||||
)
|
||||
|
||||
sent_texts: list[str] = []
|
||||
|
||||
async def _fake_send(text: str) -> None:
|
||||
sent_texts.append(text)
|
||||
|
||||
with patch("backend.telegram.send_message", side_effect=_fake_send):
|
||||
with patch("backend.telegram.db.save_telegram_batch", new_callable=AsyncMock):
|
||||
with patch("backend.telegram.asyncio.sleep", new_callable=AsyncMock):
|
||||
await agg.flush()
|
||||
|
||||
assert len(sent_texts) == 1
|
||||
assert test_uuid[:8] in sent_texts[0]
|
||||
finally:
|
||||
_cleanup(path)
|
||||
115
tests/test_webhook.py
Normal file
115
tests/test_webhook.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
"""
|
||||
Tests for POST /api/webhook/telegram.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
os.environ.setdefault("BOT_TOKEN", "test-bot-token")
|
||||
os.environ.setdefault("CHAT_ID", "-1001234567890")
|
||||
os.environ.setdefault("WEBHOOK_SECRET", "test-webhook-secret")
|
||||
os.environ.setdefault("WEBHOOK_URL", "https://example.com/api/webhook/telegram")
|
||||
os.environ.setdefault("FRONTEND_ORIGIN", "http://localhost:3000")
|
||||
|
||||
import pytest
|
||||
from tests.conftest import make_app_client
|
||||
|
||||
CORRECT_SECRET = "test-webhook-secret"
|
||||
|
||||
_SAMPLE_UPDATE = {
|
||||
"update_id": 100,
|
||||
"message": {
|
||||
"message_id": 1,
|
||||
"from": {"id": 12345678, "first_name": "Test", "last_name": "User"},
|
||||
"chat": {"id": 12345678, "type": "private"},
|
||||
"text": "/start",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_valid_secret_returns_200():
|
||||
"""Correct X-Telegram-Bot-Api-Secret-Token → 200."""
|
||||
async with make_app_client() as client:
|
||||
resp = await client.post(
|
||||
"/api/webhook/telegram",
|
||||
json=_SAMPLE_UPDATE,
|
||||
headers={"X-Telegram-Bot-Api-Secret-Token": CORRECT_SECRET},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"ok": True}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_missing_secret_returns_403():
|
||||
"""Request without the secret header must return 403."""
|
||||
async with make_app_client() as client:
|
||||
resp = await client.post(
|
||||
"/api/webhook/telegram",
|
||||
json=_SAMPLE_UPDATE,
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_wrong_secret_returns_403():
|
||||
"""Request with a wrong secret header must return 403."""
|
||||
async with make_app_client() as client:
|
||||
resp = await client.post(
|
||||
"/api/webhook/telegram",
|
||||
json=_SAMPLE_UPDATE,
|
||||
headers={"X-Telegram-Bot-Api-Secret-Token": "wrong-secret"},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_start_command_registers_user():
|
||||
"""A /start command in the update should not raise and must return 200."""
|
||||
async with make_app_client() as client:
|
||||
resp = await client.post(
|
||||
"/api/webhook/telegram",
|
||||
json={
|
||||
"update_id": 101,
|
||||
"message": {
|
||||
"message_id": 2,
|
||||
"from": {"id": 99887766, "first_name": "Frank", "last_name": ""},
|
||||
"chat": {"id": 99887766, "type": "private"},
|
||||
"text": "/start",
|
||||
},
|
||||
},
|
||||
headers={"X-Telegram-Bot-Api-Secret-Token": CORRECT_SECRET},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_non_start_command_returns_200():
|
||||
"""Any update without /start should still return 200."""
|
||||
async with make_app_client() as client:
|
||||
resp = await client.post(
|
||||
"/api/webhook/telegram",
|
||||
json={
|
||||
"update_id": 102,
|
||||
"message": {
|
||||
"message_id": 3,
|
||||
"from": {"id": 11111111, "first_name": "Anon"},
|
||||
"chat": {"id": 11111111, "type": "private"},
|
||||
"text": "hello",
|
||||
},
|
||||
},
|
||||
headers={"X-Telegram-Bot-Api-Secret-Token": CORRECT_SECRET},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_empty_body_with_valid_secret_returns_200():
|
||||
"""An update with no message field should still return 200."""
|
||||
async with make_app_client() as client:
|
||||
resp = await client.post(
|
||||
"/api/webhook/telegram",
|
||||
json={"update_id": 103},
|
||||
headers={"X-Telegram-Bot-Api-Secret-Token": CORRECT_SECRET},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
Loading…
Add table
Add a link
Reference in a new issue