Do proper SQL escaping and warn about bad values.
This commit is contained in:
parent
71ed878c2a
commit
229ea109ef
3 changed files with 68 additions and 22 deletions
|
@ -41,23 +41,23 @@ def _cov_description(
|
|||
start_uncovered = (a1 - r1).days
|
||||
end_uncovered = (r2 - a2).days
|
||||
|
||||
def plural(n):
|
||||
def s(n):
|
||||
return "" if n == 1 else "s"
|
||||
|
||||
if start_uncovered == 0 and end_uncovered > 0:
|
||||
return (
|
||||
f"starting as requested and ending {end_uncovered} "
|
||||
f"day{plural(end_uncovered)} earlier than requested"
|
||||
f"day{s(end_uncovered)} earlier than requested"
|
||||
)
|
||||
elif start_uncovered > 0 and end_uncovered == 0:
|
||||
return (
|
||||
f"starting {start_uncovered} day{plural(start_uncovered)} later "
|
||||
f"starting {start_uncovered} day{s(start_uncovered)} later "
|
||||
"than requested and ending as requested"
|
||||
)
|
||||
elif start_uncovered > 0 and end_uncovered > 0:
|
||||
return (
|
||||
f"starting {start_uncovered} day{plural(start_uncovered)} later "
|
||||
f"and ending {end_uncovered} day{plural(end_uncovered)} earlier "
|
||||
f"starting {start_uncovered} day{s(start_uncovered)} later "
|
||||
f"and ending {end_uncovered} day{s(end_uncovered)} earlier "
|
||||
"than requested"
|
||||
)
|
||||
else:
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from importlib.resources import read_text
|
||||
|
||||
|
@ -14,6 +16,16 @@ class GnuCashSQL(BaseOutput):
|
|||
quote = fmt.quote or series.quote
|
||||
src = f"pricehist:{source.id()}"
|
||||
|
||||
self._warn_about_backslashes(
|
||||
{
|
||||
"time": fmt.time,
|
||||
"base": base,
|
||||
"quote": quote,
|
||||
"source": src,
|
||||
"price type": series.type,
|
||||
}
|
||||
)
|
||||
|
||||
values_parts = []
|
||||
for price in series.prices:
|
||||
date = f"{fmt.format_date(price.date)} {fmt.time}"
|
||||
|
@ -31,18 +43,23 @@ class GnuCashSQL(BaseOutput):
|
|||
).encode("utf-8")
|
||||
)
|
||||
guid = m.hexdigest()[0:32]
|
||||
|
||||
value_num, value_denom = self._fractional(price.amount)
|
||||
v = (
|
||||
"("
|
||||
f"'{guid}', "
|
||||
f"'{date}', "
|
||||
f"'{base}', "
|
||||
f"'{quote}', "
|
||||
f"'{src}', "
|
||||
f"'{series.type}', "
|
||||
f"{value_num}, "
|
||||
f"{value_denom}"
|
||||
")"
|
||||
+ ", ".join(
|
||||
[
|
||||
self._sql_str(guid),
|
||||
self._sql_str(date),
|
||||
self._sql_str(base),
|
||||
self._sql_str(quote),
|
||||
self._sql_str(src),
|
||||
self._sql_str(series.type),
|
||||
str(value_num),
|
||||
str(value_denom),
|
||||
]
|
||||
)
|
||||
+ ")"
|
||||
)
|
||||
values_parts.append(v)
|
||||
values = ",\n".join(values_parts)
|
||||
|
@ -50,14 +67,43 @@ class GnuCashSQL(BaseOutput):
|
|||
sql = read_text("pricehist.resources", "gnucash.sql").format(
|
||||
version=__version__,
|
||||
timestamp=datetime.utcnow().isoformat() + "Z",
|
||||
base=base,
|
||||
quote=quote,
|
||||
base=self._sql_str(base),
|
||||
quote=self._sql_str(quote),
|
||||
values=values,
|
||||
)
|
||||
|
||||
return sql
|
||||
|
||||
def _fractional(num):
|
||||
num = str(num).replace(".", "")
|
||||
denom = 10 ** len(f"{num}.".split(".")[1])
|
||||
return (num, denom)
|
||||
def _warn_about_backslashes(self, fields):
|
||||
hits = [name for name, value in fields.items() if "\\" in value]
|
||||
if hits:
|
||||
logging.warn(
|
||||
f"Before running this SQL, check the formatting of the "
|
||||
f"{self._english_join(hits)} strings. "
|
||||
f"SQLite treats backslahes in strings as plain characters, but "
|
||||
f"MariaDB/MySQL and PostgreSQL may interpret them as escape "
|
||||
f"codes."
|
||||
)
|
||||
|
||||
def _english_join(self, strings):
|
||||
if len(strings) == 0:
|
||||
return ""
|
||||
elif len(strings) == 1:
|
||||
return str(strings[0])
|
||||
else:
|
||||
return f"{', '.join(strings[0:-1])} and {strings[-1]}"
|
||||
|
||||
def _sql_str(self, s):
|
||||
# Documentation regarding SQL string literals
|
||||
# - https://www.sqlite.org/lang_expr.html#literal_values_constants_
|
||||
# - https://mariadb.com/kb/en/string-literals/
|
||||
# - https://dev.mysql.com/doc/refman/8.0/en/string-literals.html
|
||||
# - https://www.postgresql.org/docs/devel/sql-syntax-lexical.html
|
||||
escaped = re.sub("'", "''", s)
|
||||
quoted = f"'{escaped}'"
|
||||
return quoted
|
||||
|
||||
def _fractional(self, number):
|
||||
numerator = str(number).replace(".", "")
|
||||
denom = 10 ** len(f"{number}.".split(".")[1])
|
||||
return (numerator, denom)
|
||||
|
|
|
@ -5,8 +5,8 @@ BEGIN;
|
|||
-- The GnuCash database must already have entries for the relevant commodities.
|
||||
-- These statements fail and later changes are skipped if that isn't the case.
|
||||
CREATE TEMPORARY TABLE guids (mnemonic TEXT NOT NULL, guid TEXT NOT NULL);
|
||||
INSERT INTO guids VALUES ('{base}', (SELECT guid FROM commodities WHERE mnemonic = '{base}' LIMIT 1));
|
||||
INSERT INTO guids VALUES ('{quote}', (SELECT guid FROM commodities WHERE mnemonic = '{quote}' LIMIT 1));
|
||||
INSERT INTO guids VALUES ({base}, (SELECT guid FROM commodities WHERE mnemonic = {base} LIMIT 1));
|
||||
INSERT INTO guids VALUES ({quote}, (SELECT guid FROM commodities WHERE mnemonic = {quote} LIMIT 1));
|
||||
|
||||
-- Create a staging table for the new price data.
|
||||
-- Doing this via a SELECT ensures the correct date type across databases.
|
||||
|
|
Loading…
Add table
Reference in a new issue