diff --git a/src/pricehist/cli.py b/src/pricehist/cli.py index feb97ca..f5a25fe 100644 --- a/src/pricehist/cli.py +++ b/src/pricehist/cli.py @@ -1,4 +1,5 @@ import argparse +import dataclasses import logging import shutil from datetime import datetime, timedelta @@ -6,6 +7,8 @@ from textwrap import TextWrapper from pricehist import __version__, outputs, sources from pricehist.format import Format +from pricehist.price import Price +from pricehist.series import Series def cli(args=None): @@ -95,21 +98,23 @@ def cmd_fetch(args): f"source start date of {source.start()}." ) - prices = source.fetch(args.pair, type, start, args.end) + base, quote = args.pair.split("/") + series = source.fetch(Series(base, quote, type, start, args.end)) + + if args.invert: + series = dataclasses.replace( + series, + base=series.quote, + quote=series.base, + prices=[Price(date=p.date, amount=(1 / p.amount)) for p in series.prices], + ) if args.renamebase or args.renamequote: - prices = [ - p._replace( - base=(args.renamebase or p.base), - quote=(args.renamequote or p.quote), - ) - for p in prices - ] - if args.invert: - prices = [ - p._replace(base=p.quote, quote=p.base, amount=(1 / p.amount)) - for p in prices - ] + series = dataclasses.replace( + series, + base=(args.renamebase or base), + quote=(args.renamequote or quote), + ) default = Format() @@ -128,7 +133,7 @@ def cmd_fetch(args): decimal_places=if_not_none(args.quantize, default.decimal_places), ) - print(output.format(prices, source, type, fmt=fmt), end="") + print(output.format(series, source, fmt=fmt), end="") def build_parser(): diff --git a/src/pricehist/outputs/beancount.py b/src/pricehist/outputs/beancount.py index f8ee615..e378a4f 100644 --- a/src/pricehist/outputs/beancount.py +++ b/src/pricehist/outputs/beancount.py @@ -2,9 +2,9 @@ from pricehist.format import Format class Beancount: - def format(self, prices, source=None, type=None, fmt=Format()): + def format(self, series, source=None, fmt=Format()): lines = [] - for price in prices: + for price in series.prices: amount_parts = f"{fmt.quantize(price.amount):,}".split(".") amount_parts[0] = amount_parts[0].replace(",", fmt.thousands) @@ -12,13 +12,13 @@ class Beancount: qa_parts = [amount] if fmt.symbol == "right": - qa_parts = qa_parts + [price.quote] + qa_parts = qa_parts + [series.quote] else: - qa_parts = qa_parts + [" ", price.quote] + qa_parts = qa_parts + [" ", series.quote] quote_amount = "".join(qa_parts) date = str(price.date).replace("-", fmt.datesep) - lines.append(f"{date} price {price.base} {quote_amount}") + lines.append(f"{date} price {series.base} {quote_amount}") return "\n".join(lines) + "\n" diff --git a/src/pricehist/outputs/csv.py b/src/pricehist/outputs/csv.py index 0a31354..5f27615 100644 --- a/src/pricehist/outputs/csv.py +++ b/src/pricehist/outputs/csv.py @@ -2,13 +2,15 @@ from pricehist.format import Format class CSV: - def format(self, prices, source=None, type=None, fmt=Format()): + def format(self, series, source=None, fmt=Format()): lines = ["date,base,quote,amount,source,type"] - for price in prices: + for price in series.prices: date = str(price.date).replace("-", fmt.datesep) amount_parts = f"{fmt.quantize(price.amount):,}".split(".") amount_parts[0] = amount_parts[0].replace(",", fmt.thousands) amount = fmt.decimal.join(amount_parts) - line = ",".join([date, price.base, price.quote, amount, source.id(), type]) + line = ",".join( + [date, series.base, series.quote, amount, source.id(), series.type] + ) lines.append(line) return "\n".join(lines) + "\n" diff --git a/src/pricehist/outputs/gnucashsql.py b/src/pricehist/outputs/gnucashsql.py index 4c0d5f9..9b3439b 100644 --- a/src/pricehist/outputs/gnucashsql.py +++ b/src/pricehist/outputs/gnucashsql.py @@ -7,18 +7,18 @@ from pricehist.format import Format class GnuCashSQL: - def format(self, prices, source=None, type=None, fmt=Format()): + def format(self, series, source=None, fmt=Format()): src = f"pricehist:{source.id()}" values_parts = [] - for price in prices: + for price in series.prices: date = f"{price.date} {fmt.time}" amount = fmt.quantize(price.amount) m = hashlib.sha256() m.update( - "".join([date, price.base, price.quote, src, type, str(amount)]).encode( - "utf-8" - ) + "".join( + [date, series.base, series.quote, src, series.type, str(amount)] + ).encode("utf-8") ) guid = m.hexdigest()[0:32] value_num = str(amount).replace(".", "") @@ -27,10 +27,10 @@ class GnuCashSQL: "(" f"'{guid}', " f"'{date}', " - f"'{price.base}', " - f"'{price.quote}', " + f"'{series.base}', " + f"'{series.quote}', " f"'{src}', " - f"'{type}', " + f"'{series.type}', " f"{value_num}, " f"{value_denom}" ")" @@ -41,8 +41,8 @@ class GnuCashSQL: sql = read_text("pricehist.resources", "gnucash.sql").format( version=__version__, timestamp=datetime.utcnow().isoformat() + "Z", - base=price.base, - quote=price.quote, + base=series.base, + quote=series.quote, values=values, ) diff --git a/src/pricehist/outputs/ledger.py b/src/pricehist/outputs/ledger.py index 1e56218..b41fd2a 100644 --- a/src/pricehist/outputs/ledger.py +++ b/src/pricehist/outputs/ledger.py @@ -2,9 +2,9 @@ from pricehist.format import Format class Ledger: - def format(self, prices, source=None, type=None, fmt=Format()): + def format(self, series, source=None, fmt=Format()): lines = [] - for price in prices: + for price in series.prices: date = str(price.date).replace("-", fmt.datesep) amount_parts = f"{fmt.quantize(price.amount):,}".split(".") @@ -13,16 +13,16 @@ class Ledger: qa_parts = [amount] if fmt.symbol == "left": - qa_parts = [price.quote] + qa_parts + qa_parts = [series.quote] + qa_parts elif fmt.symbol == "leftspace": - qa_parts = [price.quote, " "] + qa_parts + qa_parts = [series.quote, " "] + qa_parts elif fmt.symbol == "right": - qa_parts = qa_parts + [price.quote] + qa_parts = qa_parts + [series.quote] else: - qa_parts = qa_parts + [" ", price.quote] + qa_parts = qa_parts + [" ", series.quote] quote_amount = "".join(qa_parts) - lines.append(f"P {date} {fmt.time} {price.base} {quote_amount}") + lines.append(f"P {date} {fmt.time} {series.base} {quote_amount}") return "\n".join(lines) + "\n" # TODO support additional details of the format: diff --git a/src/pricehist/price.py b/src/pricehist/price.py index dafc2bc..d8c8a8a 100644 --- a/src/pricehist/price.py +++ b/src/pricehist/price.py @@ -4,7 +4,5 @@ from decimal import Decimal @dataclass(frozen=True) class Price: - base: str - quote: str date: str amount: Decimal diff --git a/src/pricehist/series.py b/src/pricehist/series.py new file mode 100644 index 0000000..20e8f67 --- /dev/null +++ b/src/pricehist/series.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass, field + +from pricehist.price import Price + + +@dataclass(frozen=True) +class Series: + base: str + quote: str + type: str + start: str + end: str + prices: list[Price] = field(default_factory=list) diff --git a/src/pricehist/sources/coindesk.py b/src/pricehist/sources/coindesk.py index 21f51de..e5b8a1e 100644 --- a/src/pricehist/sources/coindesk.py +++ b/src/pricehist/sources/coindesk.py @@ -1,3 +1,4 @@ +import dataclasses import json from decimal import Decimal @@ -41,19 +42,17 @@ class CoinDesk: ) return symbols - def fetch(self, pair, type, start, end): - base, quote = pair.split("/") - + def fetch(self, series): url = "https://api.coindesk.com/v1/bpi/historical/close.json" params = { - "currency": quote, - "start": start, - "end": end, + "currency": series.quote, + "start": series.start, + "end": series.end, } response = requests.get(url, params=params) data = json.loads(response.content) prices = [] for (d, v) in data["bpi"].items(): - prices.append(Price(base, quote, d, Decimal(str(v)))) + prices.append(Price(d, Decimal(str(v)))) - return prices + return dataclasses.replace(series, prices=prices) diff --git a/src/pricehist/sources/coinmarketcap.py b/src/pricehist/sources/coinmarketcap.py index 3318425..c446453 100644 --- a/src/pricehist/sources/coinmarketcap.py +++ b/src/pricehist/sources/coinmarketcap.py @@ -1,3 +1,4 @@ +import dataclasses import json from datetime import datetime from decimal import Decimal @@ -8,32 +9,25 @@ from pricehist.price import Price class CoinMarketCap: - @staticmethod - def id(): + def id(self): return "coinmarketcap" - @staticmethod - def name(): + def name(self): return "CoinMarketCap" - @staticmethod - def description(): + def description(self): return "The world's most-referenced price-tracking website for cryptoassets" - @staticmethod - def source_url(): + def source_url(self): return "https://coinmarketcap.com/" - @staticmethod - def start(): + def start(self): return "2013-04-28" - @staticmethod - def types(): + def types(self): return ["mid", "open", "high", "low", "close"] - @staticmethod - def notes(): + def notes(self): return ( "This source makes unoffical use of endpoints that power CoinMarketCap's " "public web interface. The price data comes from a public equivalent of " @@ -52,36 +46,36 @@ class CoinMarketCap: rows = [i.ljust(id_width + 4) + d for i, d in zip(ids, descriptions)] return rows - def fetch(self, pair, type, start, end): - base, quote = pair.split("/") - + def fetch(self, series): url = "https://web-api.coinmarketcap.com/v1/cryptocurrency/ohlcv/historical" params = {} - if base.startswith("id=") or quote.startswith("id="): + if series.base.startswith("id=") or series.quote.startswith("id="): symbols = {} for i in self._symbol_data(): symbols[str(i["id"])] = i["symbol"] or i["code"] - if base.startswith("id="): - params["id"] = base[3:] - output_base = symbols[base[3:]] + if series.base.startswith("id="): + params["id"] = series.base[3:] + output_base = symbols[series.base[3:]] else: - params["symbol"] = base - output_base = base + params["symbol"] = series.base + output_base = series.base - if quote.startswith("id="): - params["convert_id"] = quote[3:] - quote_key = quote[3:] - output_quote = symbols[quote[3:]] + if series.quote.startswith("id="): + params["convert_id"] = series.quote[3:] + quote_key = series.quote[3:] + output_quote = symbols[series.quote[3:]] else: - params["convert"] = quote - quote_key = quote - output_quote = quote + params["convert"] = series.quote + quote_key = series.quote + output_quote = series.quote - params["time_start"] = int(datetime.strptime(start, "%Y-%m-%d").timestamp()) + params["time_start"] = int( + datetime.strptime(series.start, "%Y-%m-%d").timestamp() + ) params["time_end"] = ( - int(datetime.strptime(end, "%Y-%m-%d").timestamp()) + 24 * 60 * 60 + int(datetime.strptime(series.end, "%Y-%m-%d").timestamp()) + 24 * 60 * 60 ) # round up to include the last day response = requests.get(url, params=params) @@ -90,10 +84,12 @@ class CoinMarketCap: prices = [] for item in data["data"]["quotes"]: d = item["time_open"][0:10] - amount = self._amount(item["quote"][quote_key], type) - prices.append(Price(output_base, output_quote, d, amount)) + amount = self._amount(item["quote"][quote_key], series.type) + prices.append(Price(d, amount)) - return prices + return dataclasses.replace( + series, base=output_base, quote=output_quote, prices=prices + ) def _symbol_data(self): fiat_url = "https://web-api.coinmarketcap.com/v1/fiat/map?include_metals=true" diff --git a/src/pricehist/sources/ecb.py b/src/pricehist/sources/ecb.py index 97e4260..a515022 100644 --- a/src/pricehist/sources/ecb.py +++ b/src/pricehist/sources/ecb.py @@ -1,3 +1,4 @@ +import dataclasses from datetime import datetime, timedelta from decimal import Decimal @@ -39,11 +40,9 @@ class ECB: pairs = [f"EUR/{c} Euro against {iso[c].name}" for c in currencies] return pairs - def fetch(self, pair, type, start, end): - base, quote = pair.split("/") - + def fetch(self, series): almost_90_days_ago = str(datetime.now().date() - timedelta(days=85)) - data = self._raw_data(start < almost_90_days_ago) + data = self._raw_data(series.start < almost_90_days_ago) root = etree.fromstring(data) all_rows = [] @@ -51,14 +50,14 @@ class ECB: date = day.attrib["time"] # TODO what if it's not found for that day? # (some quotes aren't in the earliest data) - for row in day.cssselect(f"[currency='{quote}']"): + for row in day.cssselect(f"[currency='{series.quote}']"): rate = Decimal(row.attrib["rate"]) all_rows.insert(0, (date, rate)) selected = [ - Price(base, quote, d, r) for d, r in all_rows if d >= start and d <= end + Price(d, r) for d, r in all_rows if d >= series.start and d <= series.end ] - return selected + return dataclasses.replace(series, prices=selected) def _raw_data(self, more_than_90_days=False): url_base = "https://www.ecb.europa.eu/stats/eurofxref"