fix: исправить RuntimeError в aiosqlite — _get_conn как async context manager
`async with await _get_conn()` запускал тред дважды: первый раз внутри `_get_conn` через `await aiosqlite.connect()`, второй раз в `__aenter__` через `await self`. Преобразован в `@asynccontextmanager` с `yield` и `finally: conn.close()`. Все вызывающие места обновлены. Тест `test_init_db_synchronous` обновлён под новый API. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
ebb6e404e5
commit
284529dabe
2 changed files with 16 additions and 12 deletions
|
|
@ -1,22 +1,27 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Optional
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncGenerator, Optional
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
||||||
from backend import config
|
from backend import config
|
||||||
|
|
||||||
|
|
||||||
async def _get_conn() -> aiosqlite.Connection:
|
@asynccontextmanager
|
||||||
|
async def _get_conn() -> AsyncGenerator[aiosqlite.Connection, None]:
|
||||||
conn = await aiosqlite.connect(config.DB_PATH)
|
conn = await aiosqlite.connect(config.DB_PATH)
|
||||||
await conn.execute("PRAGMA journal_mode=WAL")
|
await conn.execute("PRAGMA journal_mode=WAL")
|
||||||
await conn.execute("PRAGMA busy_timeout=5000")
|
await conn.execute("PRAGMA busy_timeout=5000")
|
||||||
await conn.execute("PRAGMA synchronous=NORMAL")
|
await conn.execute("PRAGMA synchronous=NORMAL")
|
||||||
conn.row_factory = aiosqlite.Row
|
conn.row_factory = aiosqlite.Row
|
||||||
return conn
|
try:
|
||||||
|
yield conn
|
||||||
|
finally:
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
|
||||||
async def init_db() -> None:
|
async def init_db() -> None:
|
||||||
async with await _get_conn() as conn:
|
async with _get_conn() as conn:
|
||||||
await conn.executescript("""
|
await conn.executescript("""
|
||||||
CREATE TABLE IF NOT EXISTS users (
|
CREATE TABLE IF NOT EXISTS users (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
|
@ -57,7 +62,7 @@ async def init_db() -> None:
|
||||||
|
|
||||||
|
|
||||||
async def register_user(uuid: str, name: str) -> dict:
|
async def register_user(uuid: str, name: str) -> dict:
|
||||||
async with await _get_conn() as conn:
|
async with _get_conn() as conn:
|
||||||
await conn.execute(
|
await conn.execute(
|
||||||
"INSERT OR IGNORE INTO users (uuid, name) VALUES (?, ?)",
|
"INSERT OR IGNORE INTO users (uuid, name) VALUES (?, ?)",
|
||||||
(uuid, name),
|
(uuid, name),
|
||||||
|
|
@ -77,7 +82,7 @@ async def save_signal(
|
||||||
lon: Optional[float],
|
lon: Optional[float],
|
||||||
accuracy: Optional[float],
|
accuracy: Optional[float],
|
||||||
) -> int:
|
) -> int:
|
||||||
async with await _get_conn() as conn:
|
async with _get_conn() as conn:
|
||||||
async with conn.execute(
|
async with conn.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO signals (user_uuid, timestamp, lat, lon, accuracy)
|
INSERT INTO signals (user_uuid, timestamp, lat, lon, accuracy)
|
||||||
|
|
@ -91,7 +96,7 @@ async def save_signal(
|
||||||
|
|
||||||
|
|
||||||
async def get_user_name(uuid: str) -> Optional[str]:
|
async def get_user_name(uuid: str) -> Optional[str]:
|
||||||
async with await _get_conn() as conn:
|
async with _get_conn() as conn:
|
||||||
async with conn.execute(
|
async with conn.execute(
|
||||||
"SELECT name FROM users WHERE uuid = ?", (uuid,)
|
"SELECT name FROM users WHERE uuid = ?", (uuid,)
|
||||||
) as cur:
|
) as cur:
|
||||||
|
|
@ -104,7 +109,7 @@ async def save_telegram_batch(
|
||||||
signals_count: int,
|
signals_count: int,
|
||||||
signal_ids: list[int],
|
signal_ids: list[int],
|
||||||
) -> int:
|
) -> int:
|
||||||
async with await _get_conn() as conn:
|
async with _get_conn() as conn:
|
||||||
async with conn.execute(
|
async with conn.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO telegram_batches (message_text, sent_at, signals_count, status)
|
INSERT INTO telegram_batches (message_text, sent_at, signals_count, status)
|
||||||
|
|
|
||||||
|
|
@ -109,10 +109,9 @@ async def test_init_db_synchronous():
|
||||||
await db.init_db()
|
await db.init_db()
|
||||||
# Check synchronous on a new connection via _get_conn()
|
# Check synchronous on a new connection via _get_conn()
|
||||||
from backend.db import _get_conn
|
from backend.db import _get_conn
|
||||||
conn = await _get_conn()
|
async with _get_conn() as conn:
|
||||||
async with conn.execute("PRAGMA synchronous") as cur:
|
async with conn.execute("PRAGMA synchronous") as cur:
|
||||||
row = await cur.fetchone()
|
row = await cur.fetchone()
|
||||||
await conn.close()
|
|
||||||
# 1 == NORMAL
|
# 1 == NORMAL
|
||||||
assert row[0] == 1
|
assert row[0] == 1
|
||||||
finally:
|
finally:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue