baton/backend/db.py

376 lines
13 KiB
Python
Raw Permalink Normal View History

2026-03-20 20:44:00 +02:00
from __future__ import annotations
2026-03-21 07:56:44 +02:00
import time
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional
2026-03-20 20:44:00 +02:00
import aiosqlite
from backend import config
@asynccontextmanager
async def _get_conn() -> AsyncGenerator[aiosqlite.Connection, None]:
    """Open a configured SQLite connection to ``config.DB_PATH``; close it on exit.

    Every call opens a fresh connection, applies the per-connection PRAGMAs
    (WAL journal mode, 5000 ms busy timeout, ``synchronous=NORMAL``) and sets
    ``aiosqlite.Row`` as the row factory so columns can be read by name.

    Yields:
        The open connection. It is closed when the ``async with`` block exits,
        normally or via an exception. Closing does not commit, so callers must
        call ``commit()`` themselves for writes to persist.
    """
    conn = await aiosqlite.connect(config.DB_PATH)
    try:
        # Configure inside the try so that a failing PRAGMA (e.g. a locked or
        # unreadable database file) still closes the connection instead of
        # leaking it.
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute("PRAGMA busy_timeout=5000")
        await conn.execute("PRAGMA synchronous=NORMAL")
        conn.row_factory = aiosqlite.Row
        yield conn
    finally:
        await conn.close()
2026-03-20 20:44:00 +02:00
async def init_db() -> None:
    """Create all tables and indexes if missing, then migrate older ``users`` tables.

    Safe to run on every startup. Every CREATE uses ``IF NOT EXISTS``, and the
    ``users`` column migrations run only for columns that are actually absent.
    """
    async with _get_conn() as conn:
        # NOTE(review): idx_users_uuid, idx_registrations_email and
        # idx_registrations_login duplicate the implicit indexes SQLite already
        # creates for the UNIQUE columns; kept so the schema stays unchanged.
        await conn.executescript("""
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                uuid TEXT UNIQUE NOT NULL,
                name TEXT NOT NULL,
                is_blocked INTEGER NOT NULL DEFAULT 0,
                password_hash TEXT DEFAULT NULL,
                api_key_hash TEXT DEFAULT NULL,
                created_at TEXT DEFAULT (datetime('now'))
            );
            CREATE TABLE IF NOT EXISTS signals (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_uuid TEXT NOT NULL REFERENCES users(uuid),
                timestamp INTEGER NOT NULL,
                lat REAL DEFAULT NULL,
                lon REAL DEFAULT NULL,
                accuracy REAL DEFAULT NULL,
                created_at TEXT DEFAULT (datetime('now')),
                telegram_batch_id INTEGER DEFAULT NULL
            );
            CREATE TABLE IF NOT EXISTS telegram_batches (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                message_text TEXT DEFAULT NULL,
                sent_at TEXT DEFAULT NULL,
                signals_count INTEGER DEFAULT 0,
                status TEXT DEFAULT 'pending'
            );
            CREATE UNIQUE INDEX IF NOT EXISTS idx_users_uuid
                ON users(uuid);
            CREATE INDEX IF NOT EXISTS idx_signals_user_uuid
                ON signals(user_uuid);
            CREATE INDEX IF NOT EXISTS idx_signals_created_at
                ON signals(created_at);
            CREATE INDEX IF NOT EXISTS idx_batches_status
                ON telegram_batches(status);
            CREATE TABLE IF NOT EXISTS rate_limits (
                ip TEXT NOT NULL PRIMARY KEY,
                count INTEGER NOT NULL DEFAULT 0,
                window_start REAL NOT NULL DEFAULT 0
            );
            CREATE TABLE IF NOT EXISTS registrations (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                email TEXT UNIQUE NOT NULL,
                login TEXT UNIQUE NOT NULL,
                password_hash TEXT NOT NULL,
                status TEXT NOT NULL DEFAULT 'pending',
                push_subscription TEXT DEFAULT NULL,
                created_at TEXT DEFAULT (datetime('now'))
            );
            CREATE INDEX IF NOT EXISTS idx_registrations_status
                ON registrations(status);
            CREATE INDEX IF NOT EXISTS idx_registrations_email
                ON registrations(email);
            CREATE INDEX IF NOT EXISTS idx_registrations_login
                ON registrations(login);
        """)
        # Migrations for databases created before these columns existed.
        # CREATE TABLE IF NOT EXISTS never alters an existing table, so add
        # missing columns explicitly. Inspecting the live schema (instead of
        # swallowing every ALTER TABLE error) lets genuine failures such as a
        # locked database or I/O error surface rather than being hidden.
        async with conn.execute("PRAGMA table_info(users)") as cur:
            existing_columns = {row["name"] for row in await cur.fetchall()}
        migrations = (
            ("is_blocked", "ALTER TABLE users ADD COLUMN is_blocked INTEGER NOT NULL DEFAULT 0"),
            ("password_hash", "ALTER TABLE users ADD COLUMN password_hash TEXT DEFAULT NULL"),
            ("api_key_hash", "ALTER TABLE users ADD COLUMN api_key_hash TEXT DEFAULT NULL"),
        )
        for column, ddl in migrations:
            if column not in existing_columns:
                await conn.execute(ddl)
        await conn.commit()
2026-03-21 08:12:01 +02:00
async def register_user(uuid: str, name: str, api_key_hash: Optional[str] = None) -> dict:
    """Ensure a ``users`` row exists for *uuid* and return its identifiers.

    With *api_key_hash*, the row is inserted; if *uuid* already exists, only its
    ``api_key_hash`` is overwritten and the stored name is kept. Without it, an
    existing row is left untouched.

    Returns:
        ``{"user_id": <users.id>, "uuid": <users.uuid>}`` for the row.
    """
    if api_key_hash is None:
        upsert_sql = "INSERT OR IGNORE INTO users (uuid, name) VALUES (?, ?)"
        upsert_args: tuple = (uuid, name)
    else:
        upsert_sql = """
                INSERT INTO users (uuid, name, api_key_hash) VALUES (?, ?, ?)
                ON CONFLICT(uuid) DO UPDATE SET api_key_hash = excluded.api_key_hash
                """
        upsert_args = (uuid, name, api_key_hash)

    async with _get_conn() as db:
        await db.execute(upsert_sql, upsert_args)
        await db.commit()
        async with db.execute(
            "SELECT id, uuid FROM users WHERE uuid = ?", (uuid,)
        ) as lookup:
            user_row = await lookup.fetchone()
    return {"user_id": user_row["id"], "uuid": user_row["uuid"]}
2026-03-21 08:12:01 +02:00
async def get_api_key_hash_by_uuid(uuid: str) -> Optional[str]:
    """Return the stored ``api_key_hash`` for *uuid*.

    Returns None when no user has that uuid, or when the column is NULL.
    """
    async with _get_conn() as db:
        async with db.execute(
            "SELECT api_key_hash FROM users WHERE uuid = ?", (uuid,)
        ) as lookup:
            found = await lookup.fetchone()
    if found is None:
        return None
    return found["api_key_hash"]
2026-03-20 20:44:00 +02:00
async def save_signal(
    user_uuid: str,
    timestamp: int,
    lat: Optional[float],
    lon: Optional[float],
    accuracy: Optional[float],
) -> int:
    """Insert one ``signals`` row for *user_uuid* and return its new id.

    *lat*, *lon* and *accuracy* may be None and are stored as NULL.
    """
    insert_sql = """
            INSERT INTO signals (user_uuid, timestamp, lat, lon, accuracy)
            VALUES (?, ?, ?, ?, ?)
            """
    row_values = (user_uuid, timestamp, lat, lon, accuracy)
    async with _get_conn() as db:
        async with db.execute(insert_sql, row_values) as inserted:
            new_signal_id = inserted.lastrowid
        await db.commit()
    return new_signal_id
async def get_user_name(uuid: str) -> Optional[str]:
    """Return the display name stored for *uuid*, or None if no such user exists."""
    async with _get_conn() as db:
        async with db.execute(
            "SELECT name FROM users WHERE uuid = ?", (uuid,)
        ) as lookup:
            found = await lookup.fetchone()
    if found is None:
        return None
    return found["name"]
2026-03-20 23:39:28 +02:00
async def is_user_blocked(uuid: str) -> bool:
    """Return True if the user with *uuid* is flagged as blocked.

    Unknown uuids are reported as not blocked (False).
    """
    async with _get_conn() as db:
        async with db.execute(
            "SELECT is_blocked FROM users WHERE uuid = ?", (uuid,)
        ) as lookup:
            found = await lookup.fetchone()
    if found is None:
        return False
    return bool(found["is_blocked"])
async def admin_list_users() -> list[dict]:
    """Return all users ordered by id.

    Each dict has keys id, uuid, name, is_blocked (as bool) and created_at.
    """
    async with _get_conn() as db:
        async with db.execute(
            "SELECT id, uuid, name, is_blocked, created_at FROM users ORDER BY id"
        ) as listing:
            user_rows = await listing.fetchall()

    def _as_dict(r) -> dict:
        # Convert one sqlite row into the public user dict shape.
        return {
            "id": r["id"],
            "uuid": r["uuid"],
            "name": r["name"],
            "is_blocked": bool(r["is_blocked"]),
            "created_at": r["created_at"],
        }

    return list(map(_as_dict, user_rows))
async def admin_get_user_by_id(user_id: int) -> Optional[dict]:
    """Return the user with primary key *user_id*, or None if it does not exist.

    The dict has keys id, uuid, name, is_blocked (as bool) and created_at.
    """
    async with _get_conn() as db:
        async with db.execute(
            "SELECT id, uuid, name, is_blocked, created_at FROM users WHERE id = ?",
            (user_id,),
        ) as lookup:
            found = await lookup.fetchone()
    if found is None:
        return None
    user = {key: found[key] for key in ("id", "uuid", "name", "is_blocked", "created_at")}
    user["is_blocked"] = bool(user["is_blocked"])
    return user
async def admin_create_user(
    uuid: str, name: str, password_hash: Optional[str] = None
) -> Optional[dict]:
    """Create a user and return it, or None if the UUID already exists.

    Args:
        uuid: Unique external identifier for the user.
        name: Display name (NOT NULL column).
        password_hash: Optional password hash; stored as NULL when omitted.

    Returns:
        A dict with id, uuid, name, is_blocked (bool) and created_at for the new
        row. Returns None when the INSERT violates a constraint
        (``IntegrityError``); the expected case is a duplicate ``uuid``.

    Raises:
        Other database errors (e.g. ``aiosqlite.OperationalError`` for a locked
        database) propagate. Previously they were misreported as "UUID already
        exists".
    """
    async with _get_conn() as conn:
        try:
            async with conn.execute(
                "INSERT INTO users (uuid, name, password_hash) VALUES (?, ?, ?)",
                (uuid, name, password_hash),
            ) as cur:
                new_id = cur.lastrowid
        except aiosqlite.IntegrityError:
            return None  # UNIQUE constraint violation — UUID already exists
        await conn.commit()
        async with conn.execute(
            "SELECT id, uuid, name, is_blocked, created_at FROM users WHERE id = ?",
            (new_id,),
        ) as cur:
            row = await cur.fetchone()
        return {
            "id": row["id"],
            "uuid": row["uuid"],
            "name": row["name"],
            "is_blocked": bool(row["is_blocked"]),
            "created_at": row["created_at"],
        }
async def admin_set_password(user_id: int, password_hash: str) -> bool:
    """Overwrite the password hash of user *user_id*.

    Returns True if a row with that id was updated, False if none exists.
    """
    async with _get_conn() as db:
        async with db.execute(
            "UPDATE users SET password_hash = ? WHERE id = ?",
            (password_hash, user_id),
        ) as update:
            affected = update.rowcount
        await db.commit()
    return affected > 0
async def admin_set_blocked(user_id: int, is_blocked: bool) -> bool:
    """Set or clear the blocked flag of user *user_id* (stored as 1/0).

    Returns True if a row with that id was updated, False if none exists.
    """
    flag = int(bool(is_blocked))
    async with _get_conn() as db:
        async with db.execute(
            "UPDATE users SET is_blocked = ? WHERE id = ?",
            (flag, user_id),
        ) as update:
            affected = update.rowcount
        await db.commit()
    return affected > 0
async def admin_delete_user(user_id: int) -> bool:
    """Delete user *user_id* together with all of their signals.

    Both deletes are committed together at the end.

    Returns:
        True if the user row existed and was deleted, otherwise False.
    """
    purge_signals_sql = (
        "DELETE FROM signals WHERE user_uuid = (SELECT uuid FROM users WHERE id = ?)"
    )
    async with _get_conn() as db:
        # Signals must go first: their user_uuid is resolved through the users
        # row, and no ON DELETE CASCADE is declared on signals.user_uuid.
        async with db.execute(purge_signals_sql, (user_id,)):
            pass
        async with db.execute("DELETE FROM users WHERE id = ?", (user_id,)) as removal:
            removed = removal.rowcount
        await db.commit()
    return removed > 0
2026-03-21 07:56:44 +02:00
async def rate_limit_increment(key: str, window: float) -> int:
    """Increment the rate-limit counter for *key* and return the current count.

    A record whose ``window_start`` is at least *window* seconds old is deleted
    first, so the first call in a fresh window starts again at 1.

    Args:
        key: Counter key, stored in the ``rate_limits.ip`` column.
        window: Window length in seconds, compared against ``time.time()``.

    Returns:
        The counter value produced by this call.
    """
    now = time.time()
    async with _get_conn() as conn:
        # TTL cleanup: remove the stale record for this key if its window expired.
        # NOTE(review): stale rows for keys that never return are not removed
        # here, so the table can grow with the number of distinct keys.
        await conn.execute(
            "DELETE FROM rate_limits WHERE ip = ? AND ? - window_start >= ?",
            (key, now, window),
        )
        # Upsert: insert a new record or increment the existing one.
        await conn.execute(
            """
            INSERT INTO rate_limits (ip, count, window_start)
            VALUES (?, 1, ?)
            ON CONFLICT(ip) DO UPDATE SET count = count + 1
            """,
            (key, now),
        )
        # Read before committing: the write transaction opened by the
        # statements above is still held, so no other writer can change the
        # counter in between, and the value returned is exactly this call's.
        async with conn.execute(
            "SELECT count FROM rate_limits WHERE ip = ?", (key,)
        ) as cur:
            row = await cur.fetchone()
        await conn.commit()
    return row["count"] if row else 1
2026-03-21 09:19:50 +02:00
async def create_registration(
    email: str,
    login: str,
    password_hash: str,
    push_subscription: Optional[str] = None,
) -> int:
    """Store a new registration request and return its ``registrations.id``.

    ``status`` and ``created_at`` take the table defaults.

    Raises:
        aiosqlite.IntegrityError: on a constraint violation, e.g. when *email*
            or *login* is already taken (both columns are UNIQUE).
    """
    insert_sql = """
            INSERT INTO registrations (email, login, password_hash, push_subscription)
            VALUES (?, ?, ?, ?)
            """
    row_values = (email, login, password_hash, push_subscription)
    async with _get_conn() as db:
        async with db.execute(insert_sql, row_values) as inserted:
            new_reg_id = inserted.lastrowid
        await db.commit()
    return new_reg_id  # type: ignore[return-value]
async def get_registration(reg_id: int) -> Optional[dict]:
    """Return registration *reg_id* as a dict, or None if it does not exist.

    The password hash is deliberately not selected. The dict has keys id,
    email, login, status, push_subscription and created_at.
    """
    columns = ("id", "email", "login", "status", "push_subscription", "created_at")
    async with _get_conn() as db:
        async with db.execute(
            "SELECT id, email, login, status, push_subscription, created_at FROM registrations WHERE id = ?",
            (reg_id,),
        ) as lookup:
            found = await lookup.fetchone()
    if found is None:
        return None
    return {column: found[column] for column in columns}
async def update_registration_status(reg_id: int, status: str) -> bool:
    """Set the ``status`` of registration *reg_id*.

    Returns True if a row with that id was updated, False if none exists.
    """
    async with _get_conn() as db:
        async with db.execute(
            "UPDATE registrations SET status = ? WHERE id = ?",
            (status, reg_id),
        ) as update:
            affected = update.rowcount
        await db.commit()
    return affected > 0
2026-03-20 20:44:00 +02:00
async def save_telegram_batch(
    message_text: str,
    signals_count: int,
    signal_ids: list[int],
) -> int:
    """Record a sent Telegram batch and link the given signals to it.

    Inserts a ``telegram_batches`` row with status ``'sent'`` and ``sent_at``
    set to the current time, then sets ``telegram_batch_id`` on every signal in
    *signal_ids*. Everything is committed in a single transaction.

    Args:
        message_text: Text of the message that was sent.
        signals_count: Value stored in ``signals_count`` (taken as given, not
            derived from *signal_ids*).
        signal_ids: ids of the ``signals`` rows to attach; may be empty.

    Returns:
        The new ``telegram_batches.id``.
    """
    # SQLite builds older than 3.32 cap a statement at 999 bound parameters
    # (SQLITE_MAX_VARIABLE_NUMBER). One slot is used for the batch id, so large
    # id lists are updated in chunks instead of one oversized IN (...).
    max_ids_per_update = 998
    async with _get_conn() as conn:
        async with conn.execute(
            """
            INSERT INTO telegram_batches (message_text, sent_at, signals_count, status)
            VALUES (?, datetime('now'), ?, 'sent')
            """,
            (message_text, signals_count),
        ) as cur:
            batch_id = cur.lastrowid
        for start in range(0, len(signal_ids), max_ids_per_update):
            chunk = signal_ids[start:start + max_ids_per_update]
            placeholders = ",".join("?" * len(chunk))
            await conn.execute(
                f"UPDATE signals SET telegram_batch_id = ? WHERE id IN ({placeholders})",
                [batch_id, *chunk],
            )
        await conn.commit()
    return batch_id