Improve yahoo error handling and add tests.

Chris Berkhout 2021-07-18 18:09:42 +02:00
parent cdd78f0445
commit 46db6e9a6f
5 changed files with 486 additions and 22 deletions
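For context, the new error paths below rely on four exception types from pricehist.exceptions. The following is a minimal sketch of that module, reconstructed from how the exceptions are raised in yahoo.py and asserted on in the tests; it is not part of this commit, and the real module's hierarchy and wording may differ.

# Hypothetical sketch of pricehist/exceptions.py (an assumption, not this commit's code).
# Message texts are inferred from the substrings the tests assert on.

class SourceError(Exception):
    """Base class for errors raised by price sources."""

class RequestError(SourceError):
    def __init__(self, message):
        super().__init__(f"The request to the source failed: {message}")

class BadResponse(SourceError):
    def __init__(self, message):
        super().__init__(f"The source gave a bad response: {message}")

class ResponseParsingError(SourceError):
    def __init__(self, message):
        super().__init__(
            f"An error occurred while parsing data from the source: {message}"
        )

class InvalidPair(SourceError, ValueError):
    def __init__(self, base, quote, source, message=None):
        pair = "/".join([p for p in [base, quote] if p])
        suffix = f" {message}" if message else ""
        super().__init__(f"Invalid pair '{pair}' for source '{source.id()}'.{suffix}")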

pricehist/sources/yahoo.py

@@ -7,7 +7,7 @@ from decimal import Decimal
import requests
from pricehist import __version__
from pricehist import __version__, exceptions
from pricehist.price import Price
from .basesource import BaseSource
@@ -30,7 +30,10 @@ class Yahoo(BaseSource):
return "https://finance.yahoo.com/"
def start(self):
return "1970-01-01"
# The "Download historical data in Yahoo Finance" page says
# "Historical prices usually don't go back earlier than 1970", but
# several do. Examples going back to 1962-01-02 include ED and IBM.
return "1962-01-02"
def types(self):
return ["adjclose", "open", "high", "low", "close", "mid"]
@@ -55,7 +58,7 @@ class Yahoo(BaseSource):
return (
"Find the symbol of interest on https://finance.yahoo.com/ and use "
"that as the PAIR in your pricehist command. Prices for each symbol "
"are given in its native currency."
"are quoted in its native currency."
)
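# Usage sketch (CLI syntax assumed from pricehist's README, not shown in this diff):
#   pricehist fetch yahoo TSLA -s 2021-01-04 -e 2021-01-08
# where TSLA is the PAIR and the quote currency comes from the symbol itself.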
def symbols(self):
@@ -63,10 +66,12 @@ class Yahoo(BaseSource):
return []
def fetch(self, series):
# TODO fail if quote isn't empty - yahoo symbols don't have a slash
spark, history = self._data(series)
if series.quote:
raise exceptions.InvalidPair(
series.base, series.quote, self, "Don't specify the quote currency."
)
output_quote = spark["spark"]["result"][0]["response"][0]["meta"]["currency"]
quote, history = self._data(series)
prices = [
Price(row["date"], amount)
@@ -74,15 +79,13 @@ class Yahoo(BaseSource):
if (amount := self._amount(row, series.type))
]
return dataclasses.replace(series, quote=output_quote, prices=prices)
return dataclasses.replace(series, quote=quote, prices=prices)
def _amount(self, row, type):
if type != "mid" and row[type] != "null":
return Decimal(row[type])
elif type == "mid" and row["high"] != "null" and row["low"] != "null":
if type == "mid" and row["high"] != "null" and row["low"] != "null":
return sum([Decimal(row["high"]), Decimal(row["low"])]) / 2
else:
return None
return Decimal(row[type])
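# Worked example for the "mid" branch above, using the TSLA fixture below:
# on 2021-01-04 the high is 744.489990 and the low is 717.190002, so
# mid = (Decimal("744.489990") + Decimal("717.190002")) / 2 = Decimal("730.839996"),
# the value the tests assert for that day.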
def _data(self, series) -> (str, csv.DictReader):
base_url = "https://query1.finance.yahoo.com/v7/finance"
@@ -97,10 +100,32 @@ class Yahoo(BaseSource):
"includeTimestamps": "false",
"includePrePost": "false",
}
spark_response = self.log_curl(
requests.get(spark_url, params=spark_params, headers=headers)
)
spark = json.loads(spark_response.content)
try:
spark_response = self.log_curl(
requests.get(spark_url, params=spark_params, headers=headers)
)
except Exception as e:
raise exceptions.RequestError(str(e)) from e
code = spark_response.status_code
text = spark_response.text
if code == 404 and "No data found for spark symbols" in text:
raise exceptions.InvalidPair(
series.base, series.quote, self, "Symbol not found."
)
try:
spark_response.raise_for_status()
except Exception as e:
raise exceptions.BadResponse(str(e)) from e
try:
spark = json.loads(spark_response.content)
quote = spark["spark"]["result"][0]["response"][0]["meta"]["currency"]
except Exception as e:
raise exceptions.ResponseParsingError(
"The spark data couldn't be parsed. "
) from e
start_ts = int(
datetime.strptime(series.start, "%Y-%m-%d")
@@ -123,11 +148,45 @@ class Yahoo(BaseSource):
"events": "history",
"includeAdjustedClose": "true",
}
history_response = self.log_curl(
requests.get(history_url, params=history_params, headers=headers)
)
history_lines = history_response.content.decode("utf-8").splitlines()
history_lines[0] = history_lines[0].lower().replace(" ", "")
history = csv.DictReader(history_lines, delimiter=",")
return (spark, history)
try:
history_response = self.log_curl(
requests.get(history_url, params=history_params, headers=headers)
)
except Exception as e:
raise exceptions.RequestError(str(e)) from e
code = history_response.status_code
text = history_response.text
if code == 404 and "No data found, symbol may be delisted" in text:
raise exceptions.InvalidPair(
series.base, series.quote, self, "Symbol not found."
)
if code == 400 and "Data doesn't exist" in text:
raise exceptions.BadResponse(
"No data for the given interval. Try requesting a larger interval."
)
elif code == 404 and "Timestamp data missing" in text:
raise exceptions.BadResponse(
"Data missing. The given interval may be for a gap in the data "
"such as a weekend or holiday. Try requesting a larger interval."
)
try:
history_response.raise_for_status()
except Exception as e:
raise exceptions.BadResponse(str(e)) from e
try:
history_lines = history_response.content.decode("utf-8").splitlines()
history_lines[0] = history_lines[0].lower().replace(" ", "")
history = csv.DictReader(history_lines, delimiter=",")
except Exception as e:
raise exceptions.ResponseParsingError(str(e)) from e
if history_lines[0] != "date,open,high,low,close,adjclose,volume":
raise exceptions.ResponseParsingError("Unexpected CSV format")
return (quote, history)
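To show the intended effect of these exceptions, here is a hypothetical caller that maps them to user-facing messages. The function and its wiring are illustrative assumptions, not pricehist's actual CLI; only the Series constructor and exception names are taken from this commit.

import sys

from pricehist import exceptions
from pricehist.series import Series
from pricehist.sources.yahoo import Yahoo

def fetch_or_exit(symbol, start, end, type="adjclose"):
    # Leave the quote empty: Yahoo symbols imply their own currency, and
    # fetch() now raises InvalidPair if a quote is given explicitly.
    try:
        return Yahoo().fetch(Series(symbol, "", type, start, end))
    except exceptions.InvalidPair as e:
        sys.exit(f"Invalid symbol: {e}")
    except (
        exceptions.RequestError,
        exceptions.BadResponse,
        exceptions.ResponseParsingError,
    ) as e:
        sys.exit(f"Couldn't fetch prices: {e}")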

tests/pricehist/sources/test_yahoo.py

@@ -0,0 +1,311 @@
import logging
import os
from datetime import datetime, timezone
from decimal import Decimal
from pathlib import Path
import pytest
import requests
import responses
from pricehist import exceptions
from pricehist.price import Price
from pricehist.series import Series
from pricehist.sources.yahoo import Yahoo
def timestamp(date):
return int(
datetime.strptime(date, "%Y-%m-%d").replace(tzinfo=timezone.utc).timestamp()
)
@pytest.fixture
def src():
return Yahoo()
@pytest.fixture
def type(src):
return src.types()[0]
@pytest.fixture
def requests_mock():
with responses.RequestsMock() as mock:
yield mock
spark_url = "https://query1.finance.yahoo.com/v7/finance/spark"
def history_url(base):
return f"https://query1.finance.yahoo.com/v7/finance/download/{base}"
@pytest.fixture
def spark_ok(requests_mock):
json = (Path(os.path.splitext(__file__)[0]) / "tsla-spark.json").read_text()
requests_mock.add(responses.GET, spark_url, body=json, status=200)
yield requests_mock
@pytest.fixture
def recent_ok(requests_mock):
body = (Path(os.path.splitext(__file__)[0]) / "tsla-recent.csv").read_text()
requests_mock.add(responses.GET, history_url("TSLA"), body=body, status=200)
yield requests_mock
@pytest.fixture
def long_ok(requests_mock):
body = (Path(os.path.splitext(__file__)[0]) / "ibm-long-partial.csv").read_text()
requests_mock.add(responses.GET, history_url("IBM"), body=body, status=200)
yield requests_mock
def test_normalizesymbol(src):
assert src.normalizesymbol("tsla") == "TSLA"
def test_metadata(src):
assert isinstance(src.id(), str)
assert len(src.id()) > 0
assert isinstance(src.name(), str)
assert len(src.name()) > 0
assert isinstance(src.description(), str)
assert len(src.description()) > 0
assert isinstance(src.source_url(), str)
assert src.source_url().startswith("http")
assert datetime.strptime(src.start(), "%Y-%m-%d")
assert isinstance(src.types(), list)
assert len(src.types()) > 0
assert isinstance(src.types()[0], str)
assert len(src.types()[0]) > 0
assert isinstance(src.notes(), str)
def test_symbols(src, caplog):
with caplog.at_level(logging.INFO):
symbols = src.symbols()
assert symbols == []
assert any(["Find the symbol of interest on" in r.message for r in caplog.records])
def test_fetch_known(src, type, spark_ok, recent_ok):
series = src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08"))
spark_req = recent_ok.calls[0].request
hist_req = recent_ok.calls[1].request
assert spark_req.params["symbols"] == "TSLA"
assert hist_req.params["events"] == "history"
assert hist_req.params["includeAdjustedClose"] == "true"
assert (series.base, series.quote) == ("TSLA", "USD")
assert len(series.prices) == 5
def test_fetch_requests_and_receives_correct_times(src, type, spark_ok, recent_ok):
series = src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08"))
hist_req = recent_ok.calls[1].request
assert hist_req.params["period1"] == str(timestamp("2021-01-04"))
assert hist_req.params["period2"] == str(timestamp("2021-01-09"))  # rounded up one day
assert hist_req.params["interval"] == "1d"
assert series.prices[0] == Price("2021-01-04", Decimal("729.770020"))
assert series.prices[-1] == Price("2021-01-08", Decimal("880.020020"))
def test_fetch_requests_logged(src, type, spark_ok, recent_ok, caplog):
with caplog.at_level(logging.DEBUG):
src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08"))
logged_requests = 0
for r in caplog.records:
if r.levelname == "DEBUG" and " curl " in r.message:
logged_requests += 1
assert logged_requests == 2
def test_fetch_types_all_available(src, spark_ok, recent_ok):
adj = src.fetch(Series("TSLA", "", "adjclose", "2021-01-04", "2021-01-08"))
opn = src.fetch(Series("TSLA", "", "open", "2021-01-04", "2021-01-08"))
hgh = src.fetch(Series("TSLA", "", "high", "2021-01-04", "2021-01-08"))
low = src.fetch(Series("TSLA", "", "low", "2021-01-04", "2021-01-08"))
cls = src.fetch(Series("TSLA", "", "close", "2021-01-04", "2021-01-08"))
mid = src.fetch(Series("TSLA", "", "mid", "2021-01-04", "2021-01-08"))
assert adj.prices[0].amount == Decimal("729.770020")
assert opn.prices[0].amount == Decimal("719.460022")
assert hgh.prices[0].amount == Decimal("744.489990")
assert low.prices[0].amount == Decimal("717.190002")
assert cls.prices[0].amount == Decimal("729.770020")
assert mid.prices[0].amount == Decimal("730.839996")
def test_fetch_type_mid_is_mean_of_low_and_high(src, spark_ok, recent_ok):
mid = src.fetch(Series("TSLA", "", "mid", "2021-01-04", "2021-01-08")).prices
hgh = src.fetch(Series("TSLA", "", "high", "2021-01-04", "2021-01-08")).prices
low = src.fetch(Series("TSLA", "", "low", "2021-01-04", "2021-01-08")).prices
assert all(
[
mid[i].amount == (sum([low[i].amount, hgh[i].amount]) / 2)
for i in range(0, 5)
]
)
def test_fetch_from_before_start(src, type, spark_ok, long_ok):
series = src.fetch(Series("IBM", "", type, "1900-01-01", "2021-01-08"))
assert series.prices[0] == Price("1962-01-02", Decimal("1.837710"))
assert series.prices[-1] == Price("2021-01-08", Decimal("125.433624"))
assert len(series.prices) > 9
def test_fetch_to_future(src, type, spark_ok, recent_ok):
series = src.fetch(Series("TSLA", "", type, "2021-01-04", "2100-01-08"))
assert len(series.prices) > 0
def test_fetch_no_data_in_past(src, type, spark_ok, requests_mock):
requests_mock.add(
responses.GET,
history_url("TSLA"),
status=400,
body=(
"400 Bad Request: Data doesn't exist for "
"startDate = 1262304000, endDate = 1262995200"
),
)
with pytest.raises(exceptions.BadResponse) as e:
src.fetch(Series("TSLA", "", type, "2010-01-04", "2010-01-08"))
assert "No data for the given interval" in str(e.value)
def test_fetch_no_data_in_future(src, type, spark_ok, requests_mock):
requests_mock.add(
responses.GET,
history_url("TSLA"),
status=400,
body=(
"400 Bad Request: Data doesn't exist for "
"startDate = 1893715200, endDate = 1894147200"
),
)
with pytest.raises(exceptions.BadResponse) as e:
src.fetch(Series("TSLA", "", type, "2030-01-04", "2030-01-08"))
assert "No data for the given interval" in str(e.value)
def test_fetch_no_data_on_weekend(src, type, spark_ok, requests_mock):
requests_mock.add(
responses.GET,
history_url("TSLA"),
status=404,
body="404 Not Found: Timestamp data missing.",
)
with pytest.raises(exceptions.BadResponse) as e:
src.fetch(Series("TSLA", "", type, "2021-01-09", "2021-01-10"))
assert "may be for a gap in the data" in str(e.value)
def test_fetch_bad_sym(src, type, requests_mock):
requests_mock.add(
responses.GET,
spark_url,
status=404,
body="""{
"spark": {
"result": null,
"error": {
"code": "Not Found",
"description": "No data found for spark symbols"
}
}
}""",
)
with pytest.raises(exceptions.InvalidPair) as e:
src.fetch(Series("NOTABASE", "", type, "2021-01-04", "2021-01-08"))
assert "Symbol not found" in str(e.value)
def test_fetch_bad_sym_history(src, type, spark_ok, requests_mock):
# In practice the spark and history requests should succeed or fail together.
# This extra test ensures that a failure of the history part is handled
# correctly even if the spark part succeeds.
requests_mock.add(
responses.GET,
history_url("NOTABASE"),
status=404,
body="404 Not Found: No data found, symbol may be delisted",
)
with pytest.raises(exceptions.InvalidPair) as e:
src.fetch(Series("NOTABASE", "", type, "2021-01-04", "2021-01-08"))
assert "Symbol not found" in str(e.value)
def test_fetch_giving_quote(src, type):
with pytest.raises(exceptions.InvalidPair) as e:
src.fetch(Series("TSLA", "USD", type, "2021-01-04", "2021-01-08"))
assert "quote currency" in str(e.value)
def test_fetch_spark_network_issue(src, type, requests_mock):
body = requests.exceptions.ConnectionError("Network issue")
requests_mock.add(responses.GET, spark_url, body=body)
with pytest.raises(exceptions.RequestError) as e:
src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08"))
assert "Network issue" in str(e.value)
def test_fetch_spark_bad_status(src, type, requests_mock):
requests_mock.add(responses.GET, spark_url, status=500, body="Some other reason")
with pytest.raises(exceptions.BadResponse) as e:
src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08"))
assert "Internal Server Error" in str(e.value)
def test_fetch_spark_parsing_error(src, type, requests_mock):
requests_mock.add(responses.GET, spark_url, body="NOT JSON")
with pytest.raises(exceptions.ResponseParsingError) as e:
src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08"))
assert "spark data couldn't be parsed" in str(e.value)
def test_fetch_spark_unexpected_json(src, type, requests_mock):
requests_mock.add(responses.GET, spark_url, body='{"notdata": []}')
with pytest.raises(exceptions.ResponseParsingError) as e:
src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08"))
assert "spark data couldn't be parsed" in str(e.value)
def test_fetch_history_network_issue(src, type, spark_ok, requests_mock):
body = requests.exceptions.ConnectionError("Network issue")
requests_mock.add(responses.GET, history_url("TSLA"), body=body)
with pytest.raises(exceptions.RequestError) as e:
src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08"))
assert "Network issue" in str(e.value)
def test_fetch_history_bad_status(src, type, spark_ok, requests_mock):
requests_mock.add(
responses.GET, history_url("TSLA"), status=500, body="Some other reason"
)
with pytest.raises(exceptions.BadResponse) as e:
src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08"))
assert "Internal Server Error" in str(e.value)
def test_fetch_history_parsing_error(src, type, spark_ok, requests_mock):
requests_mock.add(responses.GET, history_url("TSLA"), body="")
with pytest.raises(exceptions.ResponseParsingError) as e:
src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08"))
assert "error occurred while parsing data from the source" in str(e.value)
def test_fetch_history_unexpected_csv_format(src, type, spark_ok, requests_mock):
requests_mock.add(responses.GET, history_url("TSLA"), body="BAD HEADER\nBAD DATA")
with pytest.raises(exceptions.ResponseParsingError) as e:
src.fetch(Series("TSLA", "", type, "2021-01-04", "2021-01-08"))
assert "Unexpected CSV format" in str(e.value)
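The fixtures loaded above are resolved relative to the test module itself: os.path.splitext(__file__)[0] drops the ".py" suffix, so files are read from a sibling directory named after the test file. A small demonstration, with the repo layout assumed:

import os
from pathlib import Path

# Hypothetical location of this test module; the layout is an assumption.
test_file = "tests/pricehist/sources/test_yahoo.py"
fixture_dir = Path(os.path.splitext(test_file)[0])
print(fixture_dir / "tsla-spark.json")
# tests/pricehist/sources/test_yahoo/tsla-spark.json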

tests/pricehist/sources/test_yahoo/ibm-long-partial.csv

@@ -0,0 +1,11 @@
Date,Open,High,Low,Close,Adj Close,Volume
1962-01-02,7.713333,7.713333,7.626667,7.626667,1.837710,390000
1962-01-03,7.626667,7.693333,7.626667,7.693333,1.853774,292500
1962-01-04,7.693333,7.693333,7.613333,7.616667,1.835299,262500
1962-01-05,7.606667,7.606667,7.453333,7.466667,1.799155,367500
1962-01-08,7.460000,7.460000,7.266667,7.326667,1.765422,547500
2021-01-04,125.849998,125.919998,123.040001,123.940002,120.954201,5179200
2021-01-05,125.010002,126.680000,124.610001,126.139999,123.101204,6114600
2021-01-06,126.900002,131.880005,126.720001,129.289993,126.175316,7956700
2021-01-07,130.039993,130.460007,128.259995,128.990005,125.882545,4507400
2021-01-08,128.570007,129.320007,126.980003,128.529999,125.433624,4676200

tests/pricehist/sources/test_yahoo/tsla-recent.csv

@@ -0,0 +1,6 @@
Date,Open,High,Low,Close,Adj Close,Volume
2021-01-04,719.460022,744.489990,717.190002,729.770020,729.770020,48638200
2021-01-05,723.659973,740.840027,719.200012,735.109985,735.109985,32245200
2021-01-06,758.489990,774.000000,749.099976,755.979980,755.979980,44700000
2021-01-07,777.630005,816.989990,775.200012,816.039978,816.039978,51498900
2021-01-08,856.000000,884.489990,838.390015,880.020020,880.020020,75055500

tests/pricehist/sources/test_yahoo/tsla-spark.json

@@ -0,0 +1,77 @@
{
"spark": {
"result": [
{
"symbol": "TSLA",
"response": [
{
"meta": {
"currency": "USD",
"symbol": "TSLA",
"exchangeName": "NMS",
"instrumentType": "EQUITY",
"firstTradeDate": 1277818200,
"regularMarketTime": 1626465603,
"gmtoffset": -14400,
"timezone": "EDT",
"exchangeTimezoneName": "America/New_York",
"regularMarketPrice": 644.22,
"chartPreviousClose": 650.6,
"priceHint": 2,
"currentTradingPeriod": {
"pre": {
"timezone": "EDT",
"start": 1626422400,
"end": 1626442200,
"gmtoffset": -14400
},
"regular": {
"timezone": "EDT",
"start": 1626442200,
"end": 1626465600,
"gmtoffset": -14400
},
"post": {
"timezone": "EDT",
"start": 1626465600,
"end": 1626480000,
"gmtoffset": -14400
}
},
"dataGranularity": "1d",
"range": "1d",
"validRanges": [
"1d",
"5d",
"1mo",
"3mo",
"6mo",
"1y",
"2y",
"5y",
"10y",
"ytd",
"max"
]
},
"timestamp": [
1626442200,
1626465603
],
"indicators": {
"quote": [
{
"close": [
644.22,
644.22
]
}
]
}
}
]
}
],
"error": null
}
}
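This payload is where fetch() now gets the output quote currency: _data() reads it via the key path spark.result[0].response[0].meta.currency. A minimal sketch of that lookup, assuming the file is saved as tsla-spark.json in the current directory:

import json
from pathlib import Path

spark = json.loads(Path("tsla-spark.json").read_text())
currency = spark["spark"]["result"][0]["response"][0]["meta"]["currency"]
print(currency)  # USD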