From 54df429f614a5e5d0617dcd196bf8566608e987c Mon Sep 17 00:00:00 2001
From: Dima Gerasimov
Date: Sun, 29 Dec 2024 15:06:49 +0000
Subject: [PATCH] core.sqlite: add helper SqliteTool to get table schemas

---
Review notes (ignored when the patch is applied):
 - get_table_names selects rows with type = 'table' directly, instead of
   filtering the name-keyed dict from _get_sqlite_master, where an object
   sharing a table's name could overwrite the table entry.
 - get_table_schema doubles backticks embedded in the table name before
   quoting it.
 - NOTE(review): in my/fbmessenger/android.py the use_msys lookup now runs
   before `try:`; the matching except clause is outside this hunk, so
   confirm that errors raised by the lookup do not need that handling.
 - The "index" blob hashes and the commit id in the first line predate
   these revisions.

 my/core/sqlite.py         | 59 +++++++++++++++++++++++++++++++++++++++
 my/fbmessenger/android.py |  4 ++--
 2 files changed, 61 insertions(+), 2 deletions(-)

diff --git a/my/core/sqlite.py b/my/core/sqlite.py
index aa41ab3..6167d2e 100644
--- a/my/core/sqlite.py
+++ b/my/core/sqlite.py
@@ -134,3 +134,62 @@ def select(cols: tuple[str, str, str, str, str, str, str, str], rest: str, *, db
 def select(cols, rest, *, db):
     # db arg is last cause that results in nicer code formatting..
     return db.execute('SELECT ' + ','.join(cols) + ' ' + rest)
+
+
+class SqliteTool:
+    """
+    Helper for inspecting the schema (tables and their columns) of an open sqlite connection
+    """
+
+    def __init__(self, connection: sqlite3.Connection) -> None:
+        self.connection = connection
+
+    def _get_sqlite_master(self) -> dict[str, str]:
+        """
+        Returns map from object name to object type, as listed in sqlite_master
+
+        NOTE: the result is keyed by name alone, so if two objects share a name
+        (e.g. a trigger named after its table), only the last listed one is kept
+        """
+        res = {}
+        for c in self.connection.execute('SELECT name, type FROM sqlite_master'):
+            [name, type_] = c
+            assert type_ in {'table', 'index', 'view', 'trigger'}, (name, type_)  # just in case
+            res[name] = type_
+        return res
+
+    def get_table_names(self) -> list[str]:
+        """
+        Returns names of all objects listed in sqlite_master with type 'table'
+        """
+        # query tables directly instead of filtering _get_sqlite_master():
+        # that map is keyed by name, so a same-named non-table object could hide a table
+        query = "SELECT name FROM sqlite_master WHERE type = 'table'"
+        return [row[0] for row in self.connection.execute(query)]
+
+    def get_table_schema(self, name: str) -> dict[str, str]:
+        """
+        Returns map from column name to column type
+
+        NOTE: Sometimes this doesn't work if the db has some extensions (e.g. happens for facebook apps)
+        In this case you might still be able to use get_table_names
+        """
+        # the table name is spliced into the PRAGMA text, so quote it by hand;
+        # a backtick inside the name is escaped by doubling it
+        quoted = '`' + name.replace('`', '``') + '`'
+        schema: dict[str, str] = {}
+        for row in self.connection.execute(f'PRAGMA table_info({quoted})'):
+            col = row[1]
+            type_ = row[2]
+            # hmm, somewhere between 3.34.1 and 3.37.2, sqlite started normalising type names to uppercase
+            # let's do this just in case since python < 3.10 are using the old version
+            # e.g. it could have returned 'blob' and that would confuse blob check (see _check_allowed_blobs)
+            type_ = type_.upper()
+            schema[col] = type_
+        return schema
+
+    def get_table_schemas(self) -> dict[str, dict[str, str]]:
+        """
+        Returns map from table name to its schema (see get_table_schema)
+        """
+        return {name: self.get_table_schema(name) for name in self.get_table_names()}
diff --git a/my/fbmessenger/android.py b/my/fbmessenger/android.py
index db4cc54..f6fdb82 100644
--- a/my/fbmessenger/android.py
+++ b/my/fbmessenger/android.py
@@ -15,7 +15,7 @@ from my.core import LazyLogger, Paths, Res, datetime_aware, get_files, make_conf
 from my.core.common import unique_everseen
 from my.core.compat import assert_never
 from my.core.error import echain
-from my.core.sqlite import sqlite_connection
+from my.core.sqlite import sqlite_connection, SqliteTool
 
 from my.config import fbmessenger as user_config  # isort: skip
 
@@ -86,8 +86,8 @@ def _entities() -> Iterator[Res[Entity]]:
     for idx, path in enumerate(paths):
         logger.info(f'processing [{idx:>{width}}/{total:>{width}}] {path}')
         with sqlite_connection(path, immutable=True, row_factory='row') as db:
+            use_msys = "logging_events_v2" in SqliteTool(db).get_table_names()
             try:
-                use_msys = len(list(db.execute('SELECT * FROM sqlite_master WHERE name = "logging_events_v2"'))) > 0
                 if use_msys:
                     yield from _process_db_msys(db)
                 else: