130 lines
5.1 KiB
Python
130 lines
5.1 KiB
Python
from .common import assert_subpackage; assert_subpackage(__name__)
|
|
|
|
|
|
from contextlib import contextmanager
|
|
from pathlib import Path
|
|
import shutil
|
|
import sqlite3
|
|
from tempfile import TemporaryDirectory
|
|
from typing import Tuple, Any, Iterator, Callable, Optional, Union
|
|
|
|
|
|
from .common import PathIsh, assert_never
|
|
from .compat import Literal
|
|
|
|
|
|
def sqlite_connect_immutable(db: PathIsh) -> sqlite3.Connection:
    """
    Open ``db`` through an SQLite URI that carries the ``immutable=1`` flag.

    The connection is returned as is; closing it is up to the caller.
    """
    uri = 'file:{}?immutable=1'.format(db)
    return sqlite3.connect(uri, uri=True)
|
|
|
|
|
|
def test_sqlite_connect_immutable(tmp_path: Path) -> None:
    """
    Check that a connection from sqlite_connect_immutable rejects writes,
    while a regular connection to the same file accepts them.
    """
    import pytest  # type: ignore

    db = str(tmp_path / 'db.sqlite')

    # NOTE: 'with conn' only wraps a transaction, it does not close the connection,
    # so every connection below is closed explicitly to avoid leaking handles.
    conn = sqlite3.connect(db)
    try:
        with conn:
            conn.execute('CREATE TABLE testtable (col)')
    finally:
        conn.close()

    conn = sqlite_connect_immutable(db)
    try:
        with pytest.raises(sqlite3.OperationalError, match='readonly database'):
            with conn:
                conn.execute('DROP TABLE testtable')
    finally:
        conn.close()

    # succeeds without immutable
    conn = sqlite3.connect(db)
    try:
        with conn:
            conn.execute('DROP TABLE testtable')
    finally:
        conn.close()
|
|
|
|
|
|
# Type of a custom row factory: called with the cursor and the raw row, returns the converted row
SqliteRowFactory = Callable[[sqlite3.Cursor, sqlite3.Row], Any]
|
|
|
|
def dict_factory(cursor, row):
    """
    Row factory that converts a result row into a ``dict`` keyed by column name.

    Names come from the first item of each ``cursor.description`` entry and are
    paired positionally with the values of ``row``.
    """
    names = (desc[0] for desc in cursor.description)
    return dict(zip(names, row))
|
|
|
|
|
|
# What sqlite_connection accepts as row_factory: a custom callable, or one of the shorthands
# 'row' (mapped to sqlite3.Row) / 'dict' (mapped to dict_factory)
Factory = Union[SqliteRowFactory, Literal['row', 'dict']]
|
|
|
|
@contextmanager
def sqlite_connection(db: PathIsh, *, immutable: bool=False, row_factory: Optional[Factory]=None) -> Iterator[sqlite3.Connection]:
    """
    Context manager that opens ``db`` via an SQLite URI and always closes the connection on exit.

    The yielded connection is additionally wrapped in its own ``with`` block, so the
    pending transaction is committed on normal exit and rolled back on an exception.

    :param db: path to the database file
    :param immutable: when set, ``db`` must already exist and is opened with ``?immutable=1``
    :param row_factory: ``None`` leaves ``row_factory`` unset; ``'row'`` selects ``sqlite3.Row``;
        ``'dict'`` selects ``dict_factory``; any callable is installed unchanged
    :raises AssertionError: if ``immutable`` is set and ``db`` does not exist
    """
    dbp = f'file:{db}'
    # https://www.sqlite.org/draft/uri.html#uriimmutable
    if immutable:
        # assert results in nicer error than sqlite3.OperationalError
        assert Path(db).exists(), db
        dbp = f'{dbp}?immutable=1'

    row_factory_: Any = None
    if row_factory is not None:
        if callable(row_factory):
            row_factory_ = row_factory
        elif row_factory == 'row':
            row_factory_ = sqlite3.Row
        elif row_factory == 'dict':
            row_factory_ = dict_factory
        else:
            # assert_never is expected to take the unhandled value (like typing.assert_never);
            # calling it without arguments would fail with an unrelated TypeError instead
            assert_never(row_factory)

    conn = sqlite3.connect(dbp, uri=True)
    try:
        conn.row_factory = row_factory_
        with conn:
            yield conn
    finally:
        # Connection context manager isn't actually closing the connection, it only manages the transaction
        conn.close()
|
|
|
|
|
|
# TODO come up with a better name?
|
|
# NOTE: this is tested by tests/sqlite.py::test_sqlite_read_with_wal
|
|
def sqlite_copy_and_open(db: PathIsh) -> sqlite3.Connection:
    """
    'Snapshots' database and opens by making a deep copy of it including journal/WAL files

    The database file and its ``<name>-*`` siblings (except ``-shm``) are copied into a
    temporary directory, the copy is opened, and its contents are transferred into a fresh
    ``:memory:`` connection via ``sqlite_backup``. The temporary copy is removed before returning.

    :param db: path to the database file
    :return: in-memory connection; the caller is responsible for closing it
    """
    # NOTE(review): presumably a shim around sqlite3.Connection.backup -- confirm in .compat
    from .compat import sqlite_backup

    dp = Path(db)
    # TODO make atomic/check mtimes or something
    dest = sqlite3.connect(':memory:')
    try:
        with TemporaryDirectory() as td:
            tdir = Path(td)
            # shm should be recreated from scratch -- safer not to copy perhaps
            tocopy = [dp] + [p for p in dp.parent.glob(dp.name + '-*') if not p.name.endswith('-shm')]
            for p in tocopy:
                shutil.copy(p, tdir / p.name)
            conn = sqlite3.connect(str(tdir / dp.name))
            try:
                with conn:
                    sqlite_backup(source=conn, dest=dest)
            finally:
                # 'with conn' only manages the transaction; close explicitly, even if the backup failed
                conn.close()
    except BaseException:
        # don't leak the in-memory connection when the snapshot couldn't be taken
        dest.close()
        raise
    return dest
|
|
|
|
|
|
# NOTE hmm, so this kinda works
|
|
# V = TypeVar('V', bound=Tuple[Any, ...])
|
|
# def select(cols: V, rest: str, *, db: sqlite3.Connection) -> Iterator[V]:
|
|
# but sadly when we pass columns (Tuple[str, ...]), it seems to bind this type to V?
|
|
# and then the return type ends up as Iterator[Tuple[str, ...]], which isn't desirable :(
|
|
# a bit annoying to have this copy-pasting, but hopefully not a big issue
|
|
|
|
from typing import overload
|
|
@overload
def select(cols: Tuple[str], rest: str, *, db: sqlite3.Connection) -> Iterator[Tuple[Any]]: ...
@overload
def select(cols: Tuple[str, str], rest: str, *, db: sqlite3.Connection) -> Iterator[Tuple[Any, Any]]: ...
@overload
def select(cols: Tuple[str, str, str], rest: str, *, db: sqlite3.Connection) -> Iterator[Tuple[Any, Any, Any]]: ...
@overload
def select(cols: Tuple[str, str, str, str], rest: str, *, db: sqlite3.Connection) -> Iterator[Tuple[Any, Any, Any, Any]]: ...
@overload
def select(cols: Tuple[str, str, str, str, str], rest: str, *, db: sqlite3.Connection) -> Iterator[Tuple[Any, Any, Any, Any, Any]]: ...
@overload
def select(cols: Tuple[str, str, str, str, str, str], rest: str, *, db: sqlite3.Connection) -> Iterator[Tuple[Any, Any, Any, Any, Any, Any]]: ...
@overload
def select(cols: Tuple[str, str, str, str, str, str, str], rest: str, *, db: sqlite3.Connection) -> Iterator[Tuple[Any, Any, Any, Any, Any, Any, Any]]: ...
@overload
def select(cols: Tuple[str, str, str, str, str, str, str, str], rest: str, *, db: sqlite3.Connection) -> Iterator[Tuple[Any, Any, Any, Any, Any, Any, Any, Any]]: ...

def select(cols, rest, *, db):
    """
    Run ``SELECT <cols> <rest>`` on ``db`` and return the result of ``db.execute``.

    ``cols`` are joined with commas and ``rest`` is appended verbatim after a space;
    neither is escaped. The overloads above give the rows a tuple type of the same
    arity as ``cols`` (up to 8 columns).
    """
    # db arg is last cause that results in nicer code formatting..
    columns = ','.join(cols)
    query = ' '.join(('SELECT ' + columns, rest))
    return db.execute(query)
|