From 46db6e9a6faf56499e7677cc02645cf4d4e03ba2 Mon Sep 17 00:00:00 2001 From: Chris Berkhout Date: Sun, 18 Jul 2021 18:09:42 +0200 Subject: [PATCH] Improve yahoo error handling and add tests. --- src/pricehist/sources/yahoo.py | 103 ++++-- tests/pricehist/sources/test_yahoo.py | 311 ++++++++++++++++++ .../sources/test_yahoo/ibm-long-partial.csv | 11 + .../sources/test_yahoo/tsla-recent.csv | 6 + .../sources/test_yahoo/tsla-spark.json | 77 +++++ 5 files changed, 486 insertions(+), 22 deletions(-) create mode 100644 tests/pricehist/sources/test_yahoo.py create mode 100644 tests/pricehist/sources/test_yahoo/ibm-long-partial.csv create mode 100644 tests/pricehist/sources/test_yahoo/tsla-recent.csv create mode 100644 tests/pricehist/sources/test_yahoo/tsla-spark.json diff --git a/src/pricehist/sources/yahoo.py b/src/pricehist/sources/yahoo.py index 0a61cb5..ae3179a 100644 --- a/src/pricehist/sources/yahoo.py +++ b/src/pricehist/sources/yahoo.py @@ -7,7 +7,7 @@ from decimal import Decimal import requests -from pricehist import __version__ +from pricehist import __version__, exceptions from pricehist.price import Price from .basesource import BaseSource @@ -30,7 +30,10 @@ class Yahoo(BaseSource): return "https://finance.yahoo.com/" def start(self): - return "1970-01-01" + # The "Download historical data in Yahoo Finance" page says + # "Historical prices usually don't go back earlier than 1970", but + # several do. Examples going back to 1962-01-02 include ED and IBM. + return "1962-01-02" def types(self): return ["adjclose", "open", "high", "low", "close", "mid"] @@ -55,7 +58,7 @@ class Yahoo(BaseSource): return ( "Find the symbol of interest on https://finance.yahoo.com/ and use " "that as the PAIR in your pricehist command. Prices for each symbol " - "are given in its native currency." + "are quoted in its native currency." ) def symbols(self): @@ -63,10 +66,12 @@ class Yahoo(BaseSource): return [] def fetch(self, series): - # TODO fail if quote isn't empty - yahoo symbols don't have a slash - spark, history = self._data(series) + if series.quote: + raise exceptions.InvalidPair( + series.base, series.quote, self, "Don't specify the quote currency." + ) - output_quote = spark["spark"]["result"][0]["response"][0]["meta"]["currency"] + quote, history = self._data(series) prices = [ Price(row["date"], amount) @@ -74,15 +79,13 @@ class Yahoo(BaseSource): if (amount := self._amount(row, series.type)) ] - return dataclasses.replace(series, quote=output_quote, prices=prices) + return dataclasses.replace(series, quote=quote, prices=prices) def _amount(self, row, type): - if type != "mid" and row[type] != "null": - return Decimal(row[type]) - elif type == "mid" and row["high"] != "null" and row["low"] != "null": + if type == "mid" and row["high"] != "null" and row["low"] != "null": return sum([Decimal(row["high"]), Decimal(row["low"])]) / 2 else: - return None + return Decimal(row[type]) def _data(self, series) -> (dict, csv.DictReader): base_url = "https://query1.finance.yahoo.com/v7/finance" @@ -97,10 +100,32 @@ class Yahoo(BaseSource): "includeTimestamps": "false", "includePrePost": "false", } - spark_response = self.log_curl( - requests.get(spark_url, params=spark_params, headers=headers) - ) - spark = json.loads(spark_response.content) + try: + spark_response = self.log_curl( + requests.get(spark_url, params=spark_params, headers=headers) + ) + except Exception as e: + raise exceptions.RequestError(str(e)) from e + + code = spark_response.status_code + text = spark_response.text + if code == 404 and "No data found for spark symbols" in text: + raise exceptions.InvalidPair( + series.base, series.quote, self, "Symbol not found." + ) + + try: + spark_response.raise_for_status() + except Exception as e: + raise exceptions.BadResponse(str(e)) from e + + try: + spark = json.loads(spark_response.content) + quote = spark["spark"]["result"][0]["response"][0]["meta"]["currency"] + except Exception as e: + raise exceptions.ResponseParsingError( + "The spark data couldn't be parsed. " + ) from e start_ts = int( datetime.strptime(series.start, "%Y-%m-%d") @@ -123,11 +148,45 @@ class Yahoo(BaseSource): "events": "history", "includeAdjustedClose": "true", } - history_response = self.log_curl( - requests.get(history_url, params=history_params, headers=headers) - ) - history_lines = history_response.content.decode("utf-8").splitlines() - history_lines[0] = history_lines[0].lower().replace(" ", "") - history = csv.DictReader(history_lines, delimiter=",") - return (spark, history) + try: + history_response = self.log_curl( + requests.get(history_url, params=history_params, headers=headers) + ) + except Exception as e: + raise exceptions.RequestError(str(e)) from e + + code = history_response.status_code + text = history_response.text + + if code == 404 and "No data found, symbol may be delisted" in text: + raise exceptions.InvalidPair( + series.base, series.quote, self, "Symbol not found." + ) + if code == 400 and "Data doesn't exist" in text: + raise exceptions.BadResponse( + "No data for the given interval. Try requesting a larger interval." + ) + + elif code == 404 and "Timestamp data missing" in text: + raise exceptions.BadResponse( + "Data missing. The given interval may be for a gap in the data " + "such as a weekend or holiday. Try requesting a larger interval." + ) + + try: + history_response.raise_for_status() + except Exception as e: + raise exceptions.BadResponse(str(e)) from e + + try: + history_lines = history_response.content.decode("utf-8").splitlines() + history_lines[0] = history_lines[0].lower().replace(" ", "") + history = csv.DictReader(history_lines, delimiter=",") + except Exception as e: + raise exceptions.ResponseParsingError(str(e)) from e + + if history_lines[0] != "date,open,high,low,close,adjclose,volume": + raise exceptions.ResponseParsingError("Unexpected CSV format") + + return (quote, history) diff --git a/tests/pricehist/sources/test_yahoo.py b/tests/pricehist/sources/test_yahoo.py new file mode 100644 index 0000000..322e9a5 --- /dev/null +++ b/tests/pricehist/sources/test_yahoo.py @@ -0,0 +1,311 @@ +import logging +import os +from datetime import datetime, timezone +from decimal import Decimal +from pathlib import Path + +import pytest +import requests +import responses + +from pricehist import exceptions +from pricehist.price import Price +from pricehist.series import Series +from pricehist.sources.yahoo import Yahoo + + +def timestamp(date): + return int( + datetime.strptime(date, "%Y-%m-%d").replace(tzinfo=timezone.utc).timestamp() + ) + + +@pytest.fixture +def src(): + return Yahoo() + + +@pytest.fixture +def type(src): + return src.types()[0] + + +@pytest.fixture +def requests_mock(): + with responses.RequestsMock() as mock: + yield mock + + +spark_url = "https://query1.finance.yahoo.com/v7/finance/spark" + + +def history_url(base): + return f"https://query1.finance.yahoo.com/v7/finance/download/{base}" + + +@pytest.fixture +def spark_ok(requests_mock): + json = (Path(os.path.splitext(__file__)[0]) / "tsla-spark.json").read_text() + requests_mock.add(responses.GET, spark_url, body=json, status=200) + yield requests_mock + + +@pytest.fixture +def recent_ok(requests_mock): + json = (Path(os.path.splitext(__file__)[0]) / "tsla-recent.csv").read_text() + requests_mock.add(responses.GET, history_url("TSLA"), body=json, status=200) + yield requests_mock + + +@pytest.fixture +def long_ok(requests_mock): + json = (Path(os.path.splitext(__file__)[0]) / "ibm-long-partial.csv").read_text() + requests_mock.add(responses.GET, history_url("IBM"), body=json, status=200) + yield requests_mock + + +def test_normalizesymbol(src): + assert src.normalizesymbol("tsla") == "TSLA" + + +def test_metadata(src): + assert isinstance(src.id(), str) + assert len(src.id()) > 0 + + assert isinstance(src.name(), str) + assert len(src.name()) > 0 + + assert isinstance(src.description(), str) + assert len(src.description()) > 0 + + assert isinstance(src.source_url(), str) + assert src.source_url().startswith("http") + + assert datetime.strptime(src.start(), "%Y-%m-%d") + + assert isinstance(src.types(), list) + assert len(src.types()) > 0 + assert isinstance(src.types()[0], str) + assert len(src.types()[0]) > 0 + + assert isinstance(src.notes(), str) + + +def test_symbols(src, caplog): + with caplog.at_level(logging.INFO): + symbols = src.symbols() + assert symbols == [] + assert any(["Find the symbol of interest on" in r.message for r in caplog.records]) + + +def test_fetch_known(src, type, spark_ok, recent_ok): + series = src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08")) + spark_req = recent_ok.calls[0].request + hist_req = recent_ok.calls[1].request + assert spark_req.params["symbols"] == "TSLA" + assert hist_req.params["events"] == "history" + assert hist_req.params["includeAdjustedClose"] == "true" + assert (series.base, series.quote) == ("TSLA", "USD") + assert len(series.prices) == 5 + + +def test_fetch_requests_and_receives_correct_times(src, type, spark_ok, recent_ok): + series = src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08")) + hist_req = recent_ok.calls[1].request + assert hist_req.params["period1"] == str(timestamp("2021-01-04")) + assert hist_req.params["period2"] == str(timestamp("2021-01-09")) # rounded up one + assert hist_req.params["interval"] == "1d" + assert series.prices[0] == Price("2021-01-04", Decimal("729.770020")) + assert series.prices[-1] == Price("2021-01-08", Decimal("880.020020")) + + +def test_fetch_requests_logged(src, type, spark_ok, recent_ok, caplog): + with caplog.at_level(logging.DEBUG): + src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08")) + logged_requests = 0 + for r in caplog.records: + if r.levelname == "DEBUG" and " curl " in r.message: + logged_requests += 1 + assert logged_requests == 2 + + +def test_fetch_types_all_available(src, spark_ok, recent_ok): + adj = src.fetch(Series("TSLA", "", "adjclose", "2021-01-04", "2021-01-08")) + opn = src.fetch(Series("TSLA", "", "open", "2021-01-04", "2021-01-08")) + hgh = src.fetch(Series("TSLA", "", "high", "2021-01-04", "2021-01-08")) + low = src.fetch(Series("TSLA", "", "low", "2021-01-04", "2021-01-08")) + cls = src.fetch(Series("TSLA", "", "close", "2021-01-04", "2021-01-08")) + mid = src.fetch(Series("TSLA", "", "mid", "2021-01-04", "2021-01-08")) + assert adj.prices[0].amount == Decimal("729.770020") + assert opn.prices[0].amount == Decimal("719.460022") + assert hgh.prices[0].amount == Decimal("744.489990") + assert low.prices[0].amount == Decimal("717.190002") + assert cls.prices[0].amount == Decimal("729.770020") + assert mid.prices[0].amount == Decimal("730.839996") + + +def test_fetch_type_mid_is_mean_of_low_and_high(src, spark_ok, recent_ok): + mid = src.fetch(Series("TSLA", "", "mid", "2021-01-04", "2021-01-08")).prices + hgh = src.fetch(Series("TSLA", "", "high", "2021-01-04", "2021-01-08")).prices + low = src.fetch(Series("TSLA", "", "low", "2021-01-04", "2021-01-08")).prices + assert all( + [ + mid[i].amount == (sum([low[i].amount, hgh[i].amount]) / 2) + for i in range(0, 5) + ] + ) + + +def test_fetch_from_before_start(src, type, spark_ok, long_ok): + series = src.fetch(Series("IBM", "", type, "1900-01-01", "2021-01-08")) + assert series.prices[0] == Price("1962-01-02", Decimal("1.837710")) + assert series.prices[-1] == Price("2021-01-08", Decimal("125.433624")) + assert len(series.prices) > 9 + + +def test_fetch_to_future(src, type, spark_ok, recent_ok): + series = src.fetch(Series("TSLA", "", type, "2021-01-04", "2100-01-08")) + assert len(series.prices) > 0 + + +def test_fetch_no_data_in_past(src, type, spark_ok, requests_mock): + requests_mock.add( + responses.GET, + history_url("TSLA"), + status=400, + body=( + "400 Bad Request: Data doesn't exist for " + "startDate = 1262304000, endDate = 1262995200" + ), + ) + with pytest.raises(exceptions.BadResponse) as e: + src.fetch(Series("TSLA", "", type, "2010-01-04", "2010-01-08")) + assert "No data for the given interval" in str(e.value) + + +def test_fetch_no_data_in_future(src, type, spark_ok, requests_mock): + requests_mock.add( + responses.GET, + history_url("TSLA"), + status=400, + body=( + "400 Bad Request: Data doesn't exist for " + "startDate = 1893715200, endDate = 1894147200" + ), + ) + with pytest.raises(exceptions.BadResponse) as e: + src.fetch(Series("TSLA", "", type, "2030-01-04", "2030-01-08")) + assert "No data for the given interval" in str(e.value) + + +def test_fetch_no_data_on_weekend(src, type, spark_ok, requests_mock): + requests_mock.add( + responses.GET, + history_url("TSLA"), + status=404, + body="404 Not Found: Timestamp data missing.", + ) + with pytest.raises(exceptions.BadResponse) as e: + src.fetch(Series("TSLA", "", type, "2021-01-09", "2021-01-10")) + assert "may be for a gap in the data" in str(e.value) + + +def test_fetch_bad_sym(src, type, requests_mock): + requests_mock.add( + responses.GET, + spark_url, + status=404, + body="""{ + "spark": { + "result": null, + "error": { + "code": "Not Found", + "description": "No data found for spark symbols" + } + } + }""", + ) + with pytest.raises(exceptions.InvalidPair) as e: + src.fetch(Series("NOTABASE", "", type, "2021-01-04", "2021-01-08")) + assert "Symbol not found" in str(e.value) + + +def test_fetch_bad_sym_history(src, type, spark_ok, requests_mock): + # In practice the spark history requests should succeed or fail together. + # This extra test ensures that a failure of the the history part is handled + # correctly even if the spark part succeeds. + requests_mock.add( + responses.GET, + history_url("NOTABASE"), + status=404, + body="404 Not Found: No data found, symbol may be delisted", + ) + with pytest.raises(exceptions.InvalidPair) as e: + src.fetch(Series("NOTABASE", "", type, "2021-01-04", "2021-01-08")) + assert "Symbol not found" in str(e.value) + + +def test_fetch_giving_quote(src, type): + with pytest.raises(exceptions.InvalidPair) as e: + src.fetch(Series("TSLA", "USD", type, "2021-01-04", "2021-01-08")) + assert "quote currency" in str(e.value) + + +def test_fetch_spark_network_issue(src, type, requests_mock): + body = requests.exceptions.ConnectionError("Network issue") + requests_mock.add(responses.GET, spark_url, body=body) + with pytest.raises(exceptions.RequestError) as e: + src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08")) + assert "Network issue" in str(e.value) + + +def test_fetch_spark_bad_status(src, type, requests_mock): + requests_mock.add(responses.GET, spark_url, status=500, body="Some other reason") + with pytest.raises(exceptions.BadResponse) as e: + src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08")) + assert "Internal Server Error" in str(e.value) + + +def test_fetch_spark_parsing_error(src, type, requests_mock): + requests_mock.add(responses.GET, spark_url, body="NOT JSON") + with pytest.raises(exceptions.ResponseParsingError) as e: + src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08")) + assert "spark data couldn't be parsed" in str(e.value) + + +def test_fetch_spark_unexpected_json(src, type, requests_mock): + requests_mock.add(responses.GET, spark_url, body='{"notdata": []}') + with pytest.raises(exceptions.ResponseParsingError) as e: + src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08")) + assert "spark data couldn't be parsed" in str(e.value) + + +def test_fetch_history_network_issue(src, type, spark_ok, requests_mock): + body = requests.exceptions.ConnectionError("Network issue") + requests_mock.add(responses.GET, history_url("TSLA"), body=body) + with pytest.raises(exceptions.RequestError) as e: + src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08")) + assert "Network issue" in str(e.value) + + +def test_fetch_history_bad_status(src, type, spark_ok, requests_mock): + requests_mock.add( + responses.GET, history_url("TSLA"), status=500, body="Some other reason" + ) + with pytest.raises(exceptions.BadResponse) as e: + src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08")) + assert "Internal Server Error" in str(e.value) + + +def test_fetch_history_parsing_error(src, type, spark_ok, requests_mock): + requests_mock.add(responses.GET, history_url("TSLA"), body="") + with pytest.raises(exceptions.ResponseParsingError) as e: + src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08")) + assert "error occurred while parsing data from the source" in str(e.value) + + +def test_fetch_history_unexpected_csv_format(src, type, spark_ok, requests_mock): + requests_mock.add(responses.GET, history_url("TSLA"), body="BAD HEADER\nBAD DATA") + with pytest.raises(exceptions.ResponseParsingError) as e: + src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08")) + assert "Unexpected CSV format" in str(e.value) diff --git a/tests/pricehist/sources/test_yahoo/ibm-long-partial.csv b/tests/pricehist/sources/test_yahoo/ibm-long-partial.csv new file mode 100644 index 0000000..98149ad --- /dev/null +++ b/tests/pricehist/sources/test_yahoo/ibm-long-partial.csv @@ -0,0 +1,11 @@ +Date,Open,High,Low,Close,Adj Close,Volume +1962-01-02,7.713333,7.713333,7.626667,7.626667,1.837710,390000 +1962-01-03,7.626667,7.693333,7.626667,7.693333,1.853774,292500 +1962-01-04,7.693333,7.693333,7.613333,7.616667,1.835299,262500 +1962-01-05,7.606667,7.606667,7.453333,7.466667,1.799155,367500 +1962-01-08,7.460000,7.460000,7.266667,7.326667,1.765422,547500 +2021-01-04,125.849998,125.919998,123.040001,123.940002,120.954201,5179200 +2021-01-05,125.010002,126.680000,124.610001,126.139999,123.101204,6114600 +2021-01-06,126.900002,131.880005,126.720001,129.289993,126.175316,7956700 +2021-01-07,130.039993,130.460007,128.259995,128.990005,125.882545,4507400 +2021-01-08,128.570007,129.320007,126.980003,128.529999,125.433624,4676200 diff --git a/tests/pricehist/sources/test_yahoo/tsla-recent.csv b/tests/pricehist/sources/test_yahoo/tsla-recent.csv new file mode 100644 index 0000000..48b5692 --- /dev/null +++ b/tests/pricehist/sources/test_yahoo/tsla-recent.csv @@ -0,0 +1,6 @@ +Date,Open,High,Low,Close,Adj Close,Volume +2021-01-04,719.460022,744.489990,717.190002,729.770020,729.770020,48638200 +2021-01-05,723.659973,740.840027,719.200012,735.109985,735.109985,32245200 +2021-01-06,758.489990,774.000000,749.099976,755.979980,755.979980,44700000 +2021-01-07,777.630005,816.989990,775.200012,816.039978,816.039978,51498900 +2021-01-08,856.000000,884.489990,838.390015,880.020020,880.020020,75055500 \ No newline at end of file diff --git a/tests/pricehist/sources/test_yahoo/tsla-spark.json b/tests/pricehist/sources/test_yahoo/tsla-spark.json new file mode 100644 index 0000000..53e7585 --- /dev/null +++ b/tests/pricehist/sources/test_yahoo/tsla-spark.json @@ -0,0 +1,77 @@ +{ + "spark": { + "result": [ + { + "symbol": "TSLA", + "response": [ + { + "meta": { + "currency": "USD", + "symbol": "TSLA", + "exchangeName": "NMS", + "instrumentType": "EQUITY", + "firstTradeDate": 1277818200, + "regularMarketTime": 1626465603, + "gmtoffset": -14400, + "timezone": "EDT", + "exchangeTimezoneName": "America/New_York", + "regularMarketPrice": 644.22, + "chartPreviousClose": 650.6, + "priceHint": 2, + "currentTradingPeriod": { + "pre": { + "timezone": "EDT", + "start": 1626422400, + "end": 1626442200, + "gmtoffset": -14400 + }, + "regular": { + "timezone": "EDT", + "start": 1626442200, + "end": 1626465600, + "gmtoffset": -14400 + }, + "post": { + "timezone": "EDT", + "start": 1626465600, + "end": 1626480000, + "gmtoffset": -14400 + } + }, + "dataGranularity": "1d", + "range": "1d", + "validRanges": [ + "1d", + "5d", + "1mo", + "3mo", + "6mo", + "1y", + "2y", + "5y", + "10y", + "ytd", + "max" + ] + }, + "timestamp": [ + 1626442200, + 1626465603 + ], + "indicators": { + "quote": [ + { + "close": [ + 644.22, + 644.22 + ] + } + ] + } + } + ] + } + ], + "error": null + } +}