Do proper SQL escaping and warn about bad values.

This commit is contained in:
Chris Berkhout 2021-05-31 12:19:57 +02:00
parent 71ed878c2a
commit 229ea109ef
3 changed files with 68 additions and 22 deletions

View file

@ -41,23 +41,23 @@ def _cov_description(
start_uncovered = (a1 - r1).days
end_uncovered = (r2 - a2).days
def plural(n):
def s(n):
return "" if n == 1 else "s"
if start_uncovered == 0 and end_uncovered > 0:
return (
f"starting as requested and ending {end_uncovered} "
f"day{plural(end_uncovered)} earlier than requested"
f"day{s(end_uncovered)} earlier than requested"
)
elif start_uncovered > 0 and end_uncovered == 0:
return (
f"starting {start_uncovered} day{plural(start_uncovered)} later "
f"starting {start_uncovered} day{s(start_uncovered)} later "
"than requested and ending as requested"
)
elif start_uncovered > 0 and end_uncovered > 0:
return (
f"starting {start_uncovered} day{plural(start_uncovered)} later "
f"and ending {end_uncovered} day{plural(end_uncovered)} earlier "
f"starting {start_uncovered} day{s(start_uncovered)} later "
f"and ending {end_uncovered} day{s(end_uncovered)} earlier "
"than requested"
)
else:

View file

@ -1,4 +1,6 @@
import hashlib
import logging
import re
from datetime import datetime
from importlib.resources import read_text
@ -14,6 +16,16 @@ class GnuCashSQL(BaseOutput):
quote = fmt.quote or series.quote
src = f"pricehist:{source.id()}"
self._warn_about_backslashes(
{
"time": fmt.time,
"base": base,
"quote": quote,
"source": src,
"price type": series.type,
}
)
values_parts = []
for price in series.prices:
date = f"{fmt.format_date(price.date)} {fmt.time}"
@ -31,18 +43,23 @@ class GnuCashSQL(BaseOutput):
).encode("utf-8")
)
guid = m.hexdigest()[0:32]
value_num, value_denom = self._fractional(price.amount)
v = (
"("
f"'{guid}', "
f"'{date}', "
f"'{base}', "
f"'{quote}', "
f"'{src}', "
f"'{series.type}', "
f"{value_num}, "
f"{value_denom}"
")"
+ ", ".join(
[
self._sql_str(guid),
self._sql_str(date),
self._sql_str(base),
self._sql_str(quote),
self._sql_str(src),
self._sql_str(series.type),
str(value_num),
str(value_denom),
]
)
+ ")"
)
values_parts.append(v)
values = ",\n".join(values_parts)
@ -50,14 +67,43 @@ class GnuCashSQL(BaseOutput):
sql = read_text("pricehist.resources", "gnucash.sql").format(
version=__version__,
timestamp=datetime.utcnow().isoformat() + "Z",
base=base,
quote=quote,
base=self._sql_str(base),
quote=self._sql_str(quote),
values=values,
)
return sql
def _fractional(num):
num = str(num).replace(".", "")
denom = 10 ** len(f"{num}.".split(".")[1])
return (num, denom)
def _warn_about_backslashes(self, fields):
hits = [name for name, value in fields.items() if "\\" in value]
if hits:
logging.warn(
f"Before running this SQL, check the formatting of the "
f"{self._english_join(hits)} strings. "
f"SQLite treats backslahes in strings as plain characters, but "
f"MariaDB/MySQL and PostgreSQL may interpret them as escape "
f"codes."
)
def _english_join(self, strings):
if len(strings) == 0:
return ""
elif len(strings) == 1:
return str(strings[0])
else:
return f"{', '.join(strings[0:-1])} and {strings[-1]}"
def _sql_str(self, s):
# Documentation regarding SQL string literals
# - https://www.sqlite.org/lang_expr.html#literal_values_constants_
# - https://mariadb.com/kb/en/string-literals/
# - https://dev.mysql.com/doc/refman/8.0/en/string-literals.html
# - https://www.postgresql.org/docs/devel/sql-syntax-lexical.html
escaped = re.sub("'", "''", s)
quoted = f"'{escaped}'"
return quoted
def _fractional(self, number):
numerator = str(number).replace(".", "")
denom = 10 ** len(f"{number}.".split(".")[1])
return (numerator, denom)

View file

@ -5,8 +5,8 @@ BEGIN;
-- The GnuCash database must already have entries for the relevant commodities.
-- These statements fail and later changes are skipped if that isn't the case.
CREATE TEMPORARY TABLE guids (mnemonic TEXT NOT NULL, guid TEXT NOT NULL);
INSERT INTO guids VALUES ('{base}', (SELECT guid FROM commodities WHERE mnemonic = '{base}' LIMIT 1));
INSERT INTO guids VALUES ('{quote}', (SELECT guid FROM commodities WHERE mnemonic = '{quote}' LIMIT 1));
INSERT INTO guids VALUES ({base}, (SELECT guid FROM commodities WHERE mnemonic = {base} LIMIT 1));
INSERT INTO guids VALUES ({quote}, (SELECT guid FROM commodities WHERE mnemonic = {quote} LIMIT 1));
-- Create a staging table for the new price data.
-- Doing this via a SELECT ensures the correct date type across databases.