"""Async SQLite data-access layer: users, location signals, and Telegram batches."""

from __future__ import annotations

import sqlite3
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional

import aiosqlite

from backend import config

# Columns added to ``users`` after the first schema version. ``init_db`` adds
# whichever of these an existing database is missing. These are trusted
# module constants, so interpolating them into the ALTER TABLE DDL is safe.
_USERS_MIGRATIONS: tuple[tuple[str, str], ...] = (
    ("is_blocked", "INTEGER NOT NULL DEFAULT 0"),
    ("password_hash", "TEXT DEFAULT NULL"),
    ("api_key_hash", "TEXT DEFAULT NULL"),
)

# Maximum number of signal ids bound in one ``IN (...)`` list. SQLite builds
# older than 3.32 cap a statement at 999 bound parameters; one slot is
# reserved for the batch id.
_IN_CHUNK_SIZE = 900


@asynccontextmanager
async def _get_conn() -> AsyncGenerator[aiosqlite.Connection, None]:
    """Yield a configured connection to ``config.DB_PATH``; always close it.

    The PRAGMAs run inside the ``try`` so that a failure while configuring
    the connection still closes it instead of leaking it.
    """
    conn = await aiosqlite.connect(config.DB_PATH)
    try:
        await conn.execute("PRAGMA journal_mode=WAL")
        await conn.execute("PRAGMA busy_timeout=5000")  # wait up to 5000 ms on a lock
        await conn.execute("PRAGMA synchronous=NORMAL")
        conn.row_factory = aiosqlite.Row  # rows are indexable by column name
        yield conn
    finally:
        await conn.close()


def _user_row_to_dict(row: aiosqlite.Row) -> dict:
    """Convert a ``users`` row selected as (id, uuid, name, is_blocked, created_at)."""
    return {
        "id": row["id"],
        "uuid": row["uuid"],
        "name": row["name"],
        "is_blocked": bool(row["is_blocked"]),
        "created_at": row["created_at"],
    }


async def init_db() -> None:
    """Create tables and indexes if absent, then add missing ``users`` columns."""
    async with _get_conn() as conn:
        await conn.executescript("""
            CREATE TABLE IF NOT EXISTS users (
                id            INTEGER PRIMARY KEY AUTOINCREMENT,
                uuid          TEXT UNIQUE NOT NULL,
                name          TEXT NOT NULL,
                is_blocked    INTEGER NOT NULL DEFAULT 0,
                password_hash TEXT DEFAULT NULL,
                api_key_hash  TEXT DEFAULT NULL,
                created_at    TEXT DEFAULT (datetime('now'))
            );

            CREATE TABLE IF NOT EXISTS signals (
                id                INTEGER PRIMARY KEY AUTOINCREMENT,
                user_uuid         TEXT NOT NULL REFERENCES users(uuid),
                timestamp         INTEGER NOT NULL,
                lat               REAL DEFAULT NULL,
                lon               REAL DEFAULT NULL,
                accuracy          REAL DEFAULT NULL,
                created_at        TEXT DEFAULT (datetime('now')),
                telegram_batch_id INTEGER DEFAULT NULL
            );

            CREATE TABLE IF NOT EXISTS telegram_batches (
                id            INTEGER PRIMARY KEY AUTOINCREMENT,
                message_text  TEXT DEFAULT NULL,
                sent_at       TEXT DEFAULT NULL,
                signals_count INTEGER DEFAULT 0,
                status        TEXT DEFAULT 'pending'
            );

            CREATE UNIQUE INDEX IF NOT EXISTS idx_users_uuid ON users(uuid);
            CREATE INDEX IF NOT EXISTS idx_signals_user_uuid ON signals(user_uuid);
            CREATE INDEX IF NOT EXISTS idx_signals_created_at ON signals(created_at);
            CREATE INDEX IF NOT EXISTS idx_batches_status ON telegram_batches(status);
        """)

        # Migrate databases created before these columns existed. Inspecting
        # the schema (instead of attempting ALTER and swallowing every
        # exception) lets genuine errors such as a locked DB surface.
        async with conn.execute("PRAGMA table_info(users)") as cur:
            existing_columns = {row["name"] for row in await cur.fetchall()}
        for column, definition in _USERS_MIGRATIONS:
            if column not in existing_columns:
                async with conn.execute(
                    f"ALTER TABLE users ADD COLUMN {column} {definition}"
                ):
                    pass
        await conn.commit()


async def register_user(
    uuid: str, name: str, api_key_hash: Optional[str] = None
) -> dict:
    """Insert a user if the UUID is new and return its id and UUID.

    When ``api_key_hash`` is given it is stored, overwriting the hash of an
    existing user with the same UUID (that user's name is left unchanged).
    Without it, an existing user is left untouched.

    Returns:
        ``{"user_id": ..., "uuid": ...}`` for the new or pre-existing user.
    """
    async with _get_conn() as conn:
        if api_key_hash is not None:
            await conn.execute(
                """
                INSERT INTO users (uuid, name, api_key_hash) VALUES (?, ?, ?)
                ON CONFLICT(uuid) DO UPDATE SET api_key_hash = excluded.api_key_hash
                """,
                (uuid, name, api_key_hash),
            )
        else:
            await conn.execute(
                "INSERT OR IGNORE INTO users (uuid, name) VALUES (?, ?)",
                (uuid, name),
            )
        await conn.commit()
        async with conn.execute(
            "SELECT id, uuid FROM users WHERE uuid = ?", (uuid,)
        ) as cur:
            row = await cur.fetchone()
    return {"user_id": row["id"], "uuid": row["uuid"]}


async def get_api_key_hash_by_uuid(uuid: str) -> Optional[str]:
    """Return the stored API-key hash for ``uuid``, or None if user/hash absent."""
    async with _get_conn() as conn:
        async with conn.execute(
            "SELECT api_key_hash FROM users WHERE uuid = ?", (uuid,)
        ) as cur:
            row = await cur.fetchone()
    return row["api_key_hash"] if row else None


async def save_signal(
    user_uuid: str,
    timestamp: int,
    lat: Optional[float],
    lon: Optional[float],
    accuracy: Optional[float],
) -> int:
    """Insert one signal row for ``user_uuid`` and return its new id."""
    async with _get_conn() as conn:
        async with conn.execute(
            """
            INSERT INTO signals (user_uuid, timestamp, lat, lon, accuracy)
            VALUES (?, ?, ?, ?, ?)
            """,
            (user_uuid, timestamp, lat, lon, accuracy),
        ) as cur:
            signal_id = cur.lastrowid
        await conn.commit()
    return signal_id


async def get_user_name(uuid: str) -> Optional[str]:
    """Return the name of the user with ``uuid``, or None if there is none."""
    async with _get_conn() as conn:
        async with conn.execute(
            "SELECT name FROM users WHERE uuid = ?", (uuid,)
        ) as cur:
            row = await cur.fetchone()
    return row["name"] if row else None


async def is_user_blocked(uuid: str) -> bool:
    """Return True if the user is blocked; unknown UUIDs count as not blocked."""
    async with _get_conn() as conn:
        async with conn.execute(
            "SELECT is_blocked FROM users WHERE uuid = ?", (uuid,)
        ) as cur:
            row = await cur.fetchone()
    return bool(row["is_blocked"]) if row else False


async def admin_list_users() -> list[dict]:
    """Return every user as a dict, ordered by id."""
    async with _get_conn() as conn:
        async with conn.execute(
            "SELECT id, uuid, name, is_blocked, created_at FROM users ORDER BY id"
        ) as cur:
            rows = await cur.fetchall()
    return [_user_row_to_dict(row) for row in rows]


async def admin_get_user_by_id(user_id: int) -> Optional[dict]:
    """Return the user with ``user_id`` as a dict, or None if it does not exist."""
    async with _get_conn() as conn:
        async with conn.execute(
            "SELECT id, uuid, name, is_blocked, created_at FROM users WHERE id = ?",
            (user_id,),
        ) as cur:
            row = await cur.fetchone()
    if row is None:
        return None
    return _user_row_to_dict(row)


async def admin_create_user(
    uuid: str, name: str, password_hash: Optional[str] = None
) -> Optional[dict]:
    """Create a user and return it as a dict.

    Returns None if the INSERT violates a constraint (e.g. the UUID already
    exists). Other database errors propagate to the caller.
    """
    async with _get_conn() as conn:
        try:
            async with conn.execute(
                "INSERT INTO users (uuid, name, password_hash) VALUES (?, ?, ?)",
                (uuid, name, password_hash),
            ) as cur:
                new_id = cur.lastrowid
        except sqlite3.IntegrityError:
            # Constraint violation, typically UNIQUE(uuid). Catching only this
            # keeps e.g. "database is locked" from being reported as a duplicate.
            return None
        await conn.commit()
        async with conn.execute(
            "SELECT id, uuid, name, is_blocked, created_at FROM users WHERE id = ?",
            (new_id,),
        ) as cur:
            row = await cur.fetchone()
    return _user_row_to_dict(row)


async def admin_set_password(user_id: int, password_hash: str) -> bool:
    """Store ``password_hash`` for ``user_id``; return True if a row was updated."""
    async with _get_conn() as conn:
        async with conn.execute(
            "UPDATE users SET password_hash = ? WHERE id = ?",
            (password_hash, user_id),
        ) as cur:
            changed = cur.rowcount > 0
        await conn.commit()
    return changed


async def admin_set_blocked(user_id: int, is_blocked: bool) -> bool:
    """Set the blocked flag for ``user_id``; return True if a row was updated."""
    async with _get_conn() as conn:
        async with conn.execute(
            "UPDATE users SET is_blocked = ? WHERE id = ?",
            (1 if is_blocked else 0, user_id),
        ) as cur:
            changed = cur.rowcount > 0
        await conn.commit()
    return changed


async def admin_delete_user(user_id: int) -> bool:
    """Delete a user and all of their signals; return True if the user existed."""
    async with _get_conn() as conn:
        # Delete signals first: the schema declares no ON DELETE CASCADE and
        # _get_conn never enables PRAGMA foreign_keys, so nothing removes them.
        async with conn.execute(
            "DELETE FROM signals WHERE user_uuid = (SELECT uuid FROM users WHERE id = ?)",
            (user_id,),
        ):
            pass
        async with conn.execute(
            "DELETE FROM users WHERE id = ?",
            (user_id,),
        ) as cur:
            changed = cur.rowcount > 0
        await conn.commit()
    return changed


async def save_telegram_batch(
    message_text: str,
    signals_count: int,
    signal_ids: list[int],
) -> int:
    """Record a sent Telegram batch and link ``signal_ids`` to it.

    The batch row is stored with ``status='sent'`` and ``sent_at`` set to
    SQLite's ``datetime('now')``. Signal ids are updated in chunks of
    ``_IN_CHUNK_SIZE`` so long lists stay under SQLite's bound-parameter
    limit; the insert and all updates are committed by a single commit.

    Returns:
        The id of the new ``telegram_batches`` row.
    """
    ids = list(signal_ids)
    async with _get_conn() as conn:
        async with conn.execute(
            """
            INSERT INTO telegram_batches (message_text, sent_at, signals_count, status)
            VALUES (?, datetime('now'), ?, 'sent')
            """,
            (message_text, signals_count),
        ) as cur:
            batch_id = cur.lastrowid
        for start in range(0, len(ids), _IN_CHUNK_SIZE):
            chunk = ids[start : start + _IN_CHUNK_SIZE]
            placeholders = ",".join("?" * len(chunk))
            async with conn.execute(
                f"UPDATE signals SET telegram_batch_id = ? WHERE id IN ({placeholders})",
                [batch_id, *chunk],
            ):
                pass
        await conn.commit()
    return batch_id