core.sqlite: add helper SqliteTool to get table schemas
This commit is contained in:
parent
f1d23c5e96
commit
54df429f61
2 changed files with 45 additions and 2 deletions
|
@ -134,3 +134,46 @@ def select(cols: tuple[str, str, str, str, str, str, str, str], rest: str, *, db
|
|||
def select(cols, rest, *, db):
|
||||
# db arg is last cause that results in nicer code formatting..
|
||||
return db.execute('SELECT ' + ','.join(cols) + ' ' + rest)
|
||||
|
||||
|
||||
class SqliteTool:
|
||||
def __init__(self, connection: sqlite3.Connection) -> None:
|
||||
self.connection = connection
|
||||
|
||||
def _get_sqlite_master(self) -> dict[str, str]:
|
||||
res = {}
|
||||
for c in self.connection.execute('SELECT name, type FROM sqlite_master'):
|
||||
[name, type_] = c
|
||||
assert type_ in {'table', 'index', 'view', 'trigger'}, (name, type_) # just in case
|
||||
res[name] = type_
|
||||
return res
|
||||
|
||||
def get_table_names(self) -> list[str]:
|
||||
master = self._get_sqlite_master()
|
||||
res = []
|
||||
for name, type_ in master.items():
|
||||
if type_ != 'table':
|
||||
continue
|
||||
res.append(name)
|
||||
return res
|
||||
|
||||
def get_table_schema(self, name: str) -> dict[str, str]:
|
||||
"""
|
||||
Returns map from column name to column type
|
||||
|
||||
NOTE: Sometimes this doesn't work if the db has some extensions (e.g. happens for facebook apps)
|
||||
In this case you might still be able to use get_table_names
|
||||
"""
|
||||
schema: dict[str, str] = {}
|
||||
for row in self.connection.execute(f'PRAGMA table_info(`{name}`)'):
|
||||
col = row[1]
|
||||
type_ = row[2]
|
||||
# hmm, somewhere between 3.34.1 and 3.37.2, sqlite started normalising type names to uppercase
|
||||
# let's do this just in case since python < 3.10 are using the old version
|
||||
# e.g. it could have returned 'blob' and that would confuse blob check (see _check_allowed_blobs)
|
||||
type_ = type_.upper()
|
||||
schema[col] = type_
|
||||
return schema
|
||||
|
||||
def get_table_schemas(self) -> dict[str, dict[str, str]]:
|
||||
return {name: self.get_table_schema(name) for name in self.get_table_names()}
|
||||
|
|
|
@ -15,7 +15,7 @@ from my.core import LazyLogger, Paths, Res, datetime_aware, get_files, make_conf
|
|||
from my.core.common import unique_everseen
|
||||
from my.core.compat import assert_never
|
||||
from my.core.error import echain
|
||||
from my.core.sqlite import sqlite_connection
|
||||
from my.core.sqlite import sqlite_connection, SqliteTool
|
||||
|
||||
from my.config import fbmessenger as user_config # isort: skip
|
||||
|
||||
|
@ -86,8 +86,8 @@ def _entities() -> Iterator[Res[Entity]]:
|
|||
for idx, path in enumerate(paths):
|
||||
logger.info(f'processing [{idx:>{width}}/{total:>{width}}] {path}')
|
||||
with sqlite_connection(path, immutable=True, row_factory='row') as db:
|
||||
use_msys = "logging_events_v2" in SqliteTool(db).get_table_names()
|
||||
try:
|
||||
use_msys = len(list(db.execute('SELECT * FROM sqlite_master WHERE name = "logging_events_v2"'))) > 0
|
||||
if use_msys:
|
||||
yield from _process_db_msys(db)
|
||||
else:
|
||||
|
|
Loading…
Add table
Reference in a new issue