core.sqlite: add helper SqliteTool to get table schemas

This commit is contained in:
Dima Gerasimov 2024-12-29 15:06:49 +00:00 committed by karlicoss
parent f1d23c5e96
commit 54df429f61
2 changed files with 45 additions and 2 deletions

View file

@ -134,3 +134,46 @@ def select(cols: tuple[str, str, str, str, str, str, str, str], rest: str, *, db
def select(cols, rest, *, db): def select(cols, rest, *, db):
# db arg is last cause that results in nicer code formatting.. # db arg is last cause that results in nicer code formatting..
return db.execute('SELECT ' + ','.join(cols) + ' ' + rest) return db.execute('SELECT ' + ','.join(cols) + ' ' + rest)
class SqliteTool:
    """
    Helper for inspecting an sqlite database: names of its tables and their column types.
    """

    def __init__(self, connection: sqlite3.Connection) -> None:
        self.connection = connection

    def _get_sqlite_master(self) -> dict[str, str]:
        """
        Returns map from object name to object type, as listed in sqlite_master
        """
        res: dict[str, str] = {}
        for name, type_ in self.connection.execute('SELECT name, type FROM sqlite_master'):
            assert type_ in {'table', 'index', 'view', 'trigger'}, (name, type_)  # just in case
            res[name] = type_
        return res

    def get_table_names(self) -> list[str]:
        """
        Returns names of the sqlite_master entries that are tables (indexes, views and triggers are skipped)
        """
        return [name for name, type_ in self._get_sqlite_master().items() if type_ == 'table']

    def get_table_schema(self, name: str) -> dict[str, str]:
        """
        Returns map from column name to column type

        NOTE: Sometimes this doesn't work if the db has some extensions (e.g. happens for facebook apps)
        In this case you might still be able to use get_table_names
        """
        # PRAGMA arguments can't be bound as parameters, so the name has to be embedded in the statement.
        # Quote it as a standard SQL identifier (embedded double quotes are doubled),
        # so names containing backticks or quotes don't break the statement.
        quoted = '"' + name.replace('"', '""') + '"'
        schema: dict[str, str] = {}
        for row in self.connection.execute(f'PRAGMA table_info({quoted})'):
            # table_info rows are (cid, name, type, notnull, dflt_value, pk)
            col = row[1]
            type_ = row[2]
            # hmm, somewhere between 3.34.1 and 3.37.2, sqlite started normalising type names to uppercase
            # let's do this just in case since python < 3.10 are using the old version
            # e.g. it could have returned 'blob' and that would confuse blob check (see _check_allowed_blobs)
            schema[col] = type_.upper()
        return schema

    def get_table_schemas(self) -> dict[str, dict[str, str]]:
        """
        Returns map from table name to its schema (see get_table_schema), for every table in the db
        """
        return {name: self.get_table_schema(name) for name in self.get_table_names()}

View file

@ -15,7 +15,7 @@ from my.core import LazyLogger, Paths, Res, datetime_aware, get_files, make_conf
from my.core.common import unique_everseen from my.core.common import unique_everseen
from my.core.compat import assert_never from my.core.compat import assert_never
from my.core.error import echain from my.core.error import echain
from my.core.sqlite import sqlite_connection from my.core.sqlite import sqlite_connection, SqliteTool
from my.config import fbmessenger as user_config # isort: skip from my.config import fbmessenger as user_config # isort: skip
@ -86,8 +86,8 @@ def _entities() -> Iterator[Res[Entity]]:
for idx, path in enumerate(paths): for idx, path in enumerate(paths):
logger.info(f'processing [{idx:>{width}}/{total:>{width}}] {path}') logger.info(f'processing [{idx:>{width}}/{total:>{width}}] {path}')
with sqlite_connection(path, immutable=True, row_factory='row') as db: with sqlite_connection(path, immutable=True, row_factory='row') as db:
use_msys = "logging_events_v2" in SqliteTool(db).get_table_names()
try: try:
use_msys = len(list(db.execute('SELECT * FROM sqlite_master WHERE name = "logging_events_v2"'))) > 0
if use_msys: if use_msys:
yield from _process_db_msys(db) yield from _process_db_msys(db)
else: else: