diff --git a/backend/db.py b/backend/db.py index 4a5ea4b..e52a95f 100644 --- a/backend/db.py +++ b/backend/db.py @@ -1,22 +1,27 @@ from __future__ import annotations -from typing import Optional +from contextlib import asynccontextmanager +from typing import AsyncGenerator, Optional import aiosqlite from backend import config -async def _get_conn() -> aiosqlite.Connection: +@asynccontextmanager +async def _get_conn() -> AsyncGenerator[aiosqlite.Connection, None]: conn = await aiosqlite.connect(config.DB_PATH) await conn.execute("PRAGMA journal_mode=WAL") await conn.execute("PRAGMA busy_timeout=5000") await conn.execute("PRAGMA synchronous=NORMAL") conn.row_factory = aiosqlite.Row - return conn + try: + yield conn + finally: + await conn.close() async def init_db() -> None: - async with await _get_conn() as conn: + async with _get_conn() as conn: await conn.executescript(""" CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -57,7 +62,7 @@ async def init_db() -> None: async def register_user(uuid: str, name: str) -> dict: - async with await _get_conn() as conn: + async with _get_conn() as conn: await conn.execute( "INSERT OR IGNORE INTO users (uuid, name) VALUES (?, ?)", (uuid, name), @@ -77,7 +82,7 @@ async def save_signal( lon: Optional[float], accuracy: Optional[float], ) -> int: - async with await _get_conn() as conn: + async with _get_conn() as conn: async with conn.execute( """ INSERT INTO signals (user_uuid, timestamp, lat, lon, accuracy) @@ -91,7 +96,7 @@ async def save_signal( async def get_user_name(uuid: str) -> Optional[str]: - async with await _get_conn() as conn: + async with _get_conn() as conn: async with conn.execute( "SELECT name FROM users WHERE uuid = ?", (uuid,) ) as cur: @@ -104,7 +109,7 @@ async def save_telegram_batch( signals_count: int, signal_ids: list[int], ) -> int: - async with await _get_conn() as conn: + async with _get_conn() as conn: async with conn.execute( """ INSERT INTO telegram_batches (message_text, sent_at, signals_count, status) diff --git a/tests/test_db.py b/tests/test_db.py index 6a9aabd..e823fc4 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -109,10 +109,9 @@ async def test_init_db_synchronous(): await db.init_db() # Check synchronous on a new connection via _get_conn() from backend.db import _get_conn - conn = await _get_conn() - async with conn.execute("PRAGMA synchronous") as cur: - row = await cur.fetchone() - await conn.close() + async with _get_conn() as conn: + async with conn.execute("PRAGMA synchronous") as cur: + row = await cur.fetchone() # 1 == NORMAL assert row[0] == 1 finally: