From c912b676b4c28d35cd9f779f94a541cabf8ed3f3 Mon Sep 17 00:00:00 2001 From: Chris Berkhout Date: Thu, 29 Jul 2021 16:22:06 +0200 Subject: [PATCH] Use a common exceptions handler when interacting with a source. --- src/pricehist/exceptions.py | 15 ++++++++++++++ src/pricehist/fetch.py | 7 +------ src/pricehist/sources/basesource.py | 32 +++++++++++++++++++---------- 3 files changed, 37 insertions(+), 17 deletions(-) diff --git a/src/pricehist/exceptions.py b/src/pricehist/exceptions.py index d7fb6f4..5ac7aa7 100644 --- a/src/pricehist/exceptions.py +++ b/src/pricehist/exceptions.py @@ -1,3 +1,18 @@ +import logging +import sys +from contextlib import contextmanager + + +@contextmanager +def handler(): + try: + yield + except SourceError as e: + logging.debug("Critical exception encountered", exc_info=e) + logging.critical(str(e)) + sys.exit(1) + + class SourceError(Exception): """Base exception for errors rased by sources""" diff --git a/src/pricehist/fetch.py b/src/pricehist/fetch.py index e1cc341..71bcdf3 100644 --- a/src/pricehist/fetch.py +++ b/src/pricehist/fetch.py @@ -1,5 +1,4 @@ import logging -import sys from datetime import date, datetime, timedelta from pricehist import exceptions @@ -12,12 +11,8 @@ def fetch(series, source, output, invert: bool, quantize: int, fmt) -> str: f"source start date of {source.start()}." ) - try: + with exceptions.handler(): series = source.fetch(series) - except exceptions.SourceError as e: - logging.debug("Critical exception encountered", exc_info=e) - logging.critical(str(e)) - sys.exit(1) if len(series.prices) == 0: logging.warn(f"No data found for the interval [{series.start}--{series.end}].") diff --git a/src/pricehist/sources/basesource.py b/src/pricehist/sources/basesource.py index 3794522..01fc6cd 100644 --- a/src/pricehist/sources/basesource.py +++ b/src/pricehist/sources/basesource.py @@ -1,9 +1,11 @@ import logging +import sys from abc import ABC, abstractmethod from textwrap import TextWrapper import curlify +from pricehist import exceptions from pricehist.series import Series @@ -56,13 +58,18 @@ class BaseSource(ABC): return response def format_symbols(self) -> str: - symbols = self.symbols() + with exceptions.handler(): + symbols = self.symbols() + width = max([len(sym) for sym, desc in symbols] + [0]) lines = [sym.ljust(width + 4) + desc + "\n" for sym, desc in symbols] return "".join(lines) def format_search(self, query) -> str: - if (symbols := self.search(query)) is None: + with exceptions.handler(): + symbols = self.search(query) + + if symbols is None: logging.error(f"Symbol search is not possible for the {self.id()} source.") exit(1) elif symbols == []: @@ -75,15 +82,18 @@ class BaseSource(ABC): def format_info(self, total_width=80) -> str: k_width = 11 - parts = [ - self._fmt_field("ID", self.id(), k_width, total_width), - self._fmt_field("Name", self.name(), k_width, total_width), - self._fmt_field("Description", self.description(), k_width, total_width), - self._fmt_field("URL", self.source_url(), k_width, total_width, False), - self._fmt_field("Start", self.start(), k_width, total_width), - self._fmt_field("Types", ", ".join(self.types()), k_width, total_width), - self._fmt_field("Notes", self.notes(), k_width, total_width), - ] + with exceptions.handler(): + parts = [ + self._fmt_field("ID", self.id(), k_width, total_width), + self._fmt_field("Name", self.name(), k_width, total_width), + self._fmt_field( + "Description", self.description(), k_width, total_width + ), + self._fmt_field("URL", self.source_url(), k_width, total_width, False), + self._fmt_field("Start", self.start(), k_width, total_width), + self._fmt_field("Types", ", ".join(self.types()), k_width, total_width), + self._fmt_field("Notes", self.notes(), k_width, total_width), + ] return "\n".join(filter(None, parts)) def _fmt_field(self, key, value, key_width, total_width, force=True):