From 4492e0025047bff0010442d981537ab6ac13113d Mon Sep 17 00:00:00 2001 From: Sean Breckenridge Date: Thu, 28 Oct 2021 20:29:50 -0700 Subject: [PATCH] misc fixes - convert import_source to a decorator which wraps the function call in a try block - fix protocol class when not TYPE_CHECKING - add id properties to Protocols, remove attributes since protocol expects them to be settable but NT is read-only - use id to merge comments - remove type: ignore's from reddit config and just store as 'Any' --- my/core/source.py | 66 +++++++++++++++++++++++++++++------------- my/reddit/all.py | 48 ++++++++++++++++++++---------- my/reddit/common.py | 55 ++++++++++++++++++----------------- my/reddit/pushshift.py | 1 - my/reddit/rexport.py | 19 +++++++----- 5 files changed, 118 insertions(+), 71 deletions(-) diff --git a/my/core/source.py b/my/core/source.py index 6f62754..bcf4965 100644 --- a/my/core/source.py +++ b/my/core/source.py @@ -1,26 +1,52 @@ -from typing import Any, Iterator, TypeVar, Callable, Optional, Iterable +""" +Decorator to gracefully handle importing a data source, or warning +and yielding nothing (or a default) when its not available +""" + +from typing import Any, Iterator, TypeVar, Callable, Optional, Iterable, Any from my.core.warnings import warn +from functools import wraps -T = TypeVar("T") +# The factory function may produce something that has data +# similar to the shared model, but not exactly, so not +# making this a TypeVar, is just to make reading the +# type signature below a bit easier... 
+T = Any -# this is probably more generic and results in less code, but is not mypy-friendly -def import_source(factory: Callable[[], Any], default: Any) -> Any: - try: - res = factory() - return res - except ModuleNotFoundError: # presumable means the user hasn't installed the module - warn(f"Module {factory.__qualname__} could not be imported, or isn't configured propertly") - return default +# https://mypy.readthedocs.io/en/latest/generics.html?highlight=decorators#decorator-factories +FactoryF = TypeVar("FactoryF", bound=Callable[..., Iterator[T]]) + +_DEFAULT_ITR = () -# For an example of this, see the reddit.all file -def import_source_iter(factory: Callable[[], Iterator[T]], default: Optional[Iterable[T]] = None) -> Iterator[T]: - if default is None: - default = [] - try: - res = factory() - yield from res - except ModuleNotFoundError: # presumable means the user hasn't installed the module - warn(f"Module {factory.__qualname__} could not be imported, or isn't configured propertly") - yield from default +# tried to use decorator module but it really doesn't work well +# with types and kw-arguments... 
:/ +def import_source( + module_name: Optional[str] = None, + default: Optional[Iterable[T]] = _DEFAULT_ITR +) -> Callable[..., Callable[..., Iterator[T]]]: + """ + doesn't really play well with types, but is used to catch + ModuleNotFoundError's for when modules aren't installed in + all.py files, so the types don't particularly matter + + this is meant to be used to wrap some function which imports + and then yields an iterator of objects + + If the user doesn't have that module installed, it returns + nothing and warns instead + """ + + def decorator(factory_func: FactoryF) -> Callable[..., Iterator[T]]: + @wraps(factory_func) + def wrapper(*args, **kwargs) -> Iterator[T]: + try: + res = factory_func(*args, **kwargs) + yield from res + except ModuleNotFoundError: + # TODO: check if module_name is disabled and don't send warning + warn(f"Module {factory_func.__qualname__} could not be imported, or isn't configured propertly") + yield from (() if default is None else default) + return wrapper + return decorator diff --git a/my/reddit/all.py b/my/reddit/all.py index d4cf137..a668081 100644 --- a/my/reddit/all.py +++ b/my/reddit/all.py @@ -1,6 +1,6 @@ -from typing import Iterator, Any, Callable, TypeVar +from typing import Iterator from my.core.common import Stats -from my.core.source import import_source_iter as imp +from my.core.source import import_source from .common import Save, Upvote, Comment, Submission, _merge_comments @@ -8,22 +8,34 @@ from .common import Save, Upvote, Comment, Submission, _merge_comments # reddit just feels like that much of a complicated source and # data acquired by different methods isn't the same -### import helpers - -# this import error is caught in import_source_iter, if rexport isn't installed -def _rexport_import() -> Any: - from . 
import rexport as source - return source +### 'safe importers' -- falls back to empty data if the module couldn't be found +rexport_src = import_source(module_name="my.reddit.rexport") +pushshift_src = import_source(module_name="my.reddit.pushshift") +@rexport_src def _rexport_comments() -> Iterator[Comment]: - yield from imp(lambda: _rexport_import().comments()) + from . import rexport + yield from rexport.comments() -def _pushshift_import() -> Any: - from . import pushshift as source - return source +@rexport_src +def _rexport_submissions() -> Iterator[Submission]: + from . import rexport + yield from rexport.submissions() +@rexport_src +def _rexport_saved() -> Iterator[Save]: + from . import rexport + yield from rexport.saved() + +@rexport_src +def _rexport_upvoted() -> Iterator[Upvote]: + from . import rexport + yield from rexport.upvoted() + +@pushshift_src def _pushshift_comments() -> Iterator[Comment]: - yield from imp(lambda: _pushshift_import().comments()) + from .pushshift import comments as pcomments + yield from pcomments() # Merged functions @@ -33,13 +45,17 @@ def comments() -> Iterator[Comment]: def submissions() -> Iterator[Submission]: # TODO: merge gdpr here - yield from imp(lambda: _rexport_import().submissions()) + yield from _rexport_submissions() +@rexport_src def saved() -> Iterator[Save]: - yield from imp(lambda: _rexport_import().saved()) + from .rexport import saved + yield from saved() +@rexport_src def upvoted() -> Iterator[Upvote]: - yield from imp(lambda: _rexport_import().upvoted()) + from .rexport import upvoted + yield from upvoted() def stats() -> Stats: from my.core import stat diff --git a/my/reddit/common.py b/my/reddit/common.py index 767d3a1..915455e 100644 --- a/my/reddit/common.py +++ b/my/reddit/common.py @@ -1,28 +1,28 @@ """ This defines Protocol classes, which make sure that each different -type of Comment/Save have a standard interface +type of shared models have a standardized interface """ -from typing import Dict, 
Any, Set, Iterator +from typing import Dict, Any, Set, Iterator, TYPE_CHECKING from itertools import chain -from datetime import datetime + +from my.core.common import datetime_aware Json = Dict[str, Any] -try: - from typing import Protocol -except ImportError: - # hmm -- does this need to be installed on 3.6 or is it already here? - from typing_extensions import Protocol # type: ignore[misc] +if TYPE_CHECKING: + try: + from typing import Protocol + except ImportError: + # requirement of mypy + from typing_extensions import Protocol # type: ignore[misc] +else: + Protocol = object # Note: doesn't include GDPR Save's since they don't have the same metadata class Save(Protocol): - created: datetime - title: str - raw: Json - @property - def sid(self) -> str: ... + def id(self) -> str: ... @property def url(self) -> str: ... @property @@ -32,9 +32,10 @@ class Save(Protocol): # Note: doesn't include GDPR Upvote's since they don't have the same metadata class Upvote(Protocol): - raw: Json @property - def created(self) -> datetime: ... + def id(self) -> str: ... + @property + def created(self) -> datetime_aware: ... @property def url(self) -> str: ... @property @@ -43,22 +44,24 @@ class Upvote(Protocol): def title(self) -> str: ... -# From rexport, pushshift and the reddit gdpr export +# From rexport, pushshift and the reddit GDPR export class Comment(Protocol): - raw: Json @property - def created(self) -> datetime: ... + def id(self) -> str: ... + @property + def created(self) -> datetime_aware: ... @property def url(self) -> str: ... @property def text(self) -> str: ... -# From rexport and the gdpr export +# From rexport and the GDPR export class Submission(Protocol): - raw: Json @property - def created(self) -> datetime: ... + def id(self) -> str: ... + @property + def created(self) -> datetime_aware: ... @property def url(self) -> str: ... 
@property @@ -70,14 +73,14 @@ class Submission(Protocol): def _merge_comments(*sources: Iterator[Comment]) -> Iterator[Comment]: #from .rexport import logger #ignored = 0 - emitted: Set[int] = set() + emitted: Set[str] = set() for e in chain(*sources): - key = int(e.raw["created_utc"]) - if key in emitted: + uid = e.id + if uid in emitted: #ignored += 1 - #logger.info('ignoring %s: %s', key, e) + #logger.info('ignoring %s: %s', uid, e) continue yield e - emitted.add(key) + emitted.add(uid) #logger.info(f"Ignored {ignored} comments...") diff --git a/my/reddit/pushshift.py b/my/reddit/pushshift.py index df1fd1e..b7a8277 100644 --- a/my/reddit/pushshift.py +++ b/my/reddit/pushshift.py @@ -23,7 +23,6 @@ config = make_config(pushshift_config) from my.core import get_files from typing import Sequence, Iterator from pathlib import Path -from .common import Comment from pushshift_comment_export.dal import read_file, PComment diff --git a/my/reddit/rexport.py b/my/reddit/rexport.py index 32b1f6f..8cc3173 100755 --- a/my/reddit/rexport.py +++ b/my/reddit/rexport.py @@ -7,6 +7,7 @@ REQUIRES = [ from my.core.common import Paths from dataclasses import dataclass +from typing import Any from my.config import reddit as uconfig @@ -18,21 +19,23 @@ from my.config import reddit as uconfig # in the migration # need to check before we subclass +conf: Any + if hasattr(uconfig, "rexport"): - # sigh... backwards compatability - uconfig = uconfig.rexport # type: ignore[attr-defined,misc,assignment] + conf = uconfig.rexport else: from my.core.warnings import high - high(f"""DEPRECATED! Please modify your reddit config to look like: + high("""DEPRECATED! Please modify your reddit config to look like: class reddit: class rexport: export_path: Paths = '/path/to/rexport/data' """) + conf = uconfig @dataclass -class reddit(uconfig): +class reddit(conf): ''' Uses [[https://github.com/karlicoss/rexport][rexport]] output. 
''' @@ -81,7 +84,7 @@ def inputs() -> Sequence[Path]: return get_files(config.export_path) -Sid = dal.Sid # str +Uid = dal.Sid # str Save = dal.Save Comment = dal.Comment Submission = dal.Submission @@ -161,7 +164,7 @@ def _get_bdate(bfile: Path) -> datetime: return bdt -def _get_state(bfile: Path) -> Dict[Sid, SaveWithDt]: +def _get_state(bfile: Path) -> Dict[Uid, SaveWithDt]: logger.debug('handling %s', bfile) bdt = _get_bdate(bfile) @@ -178,11 +181,11 @@ def _get_state(bfile: Path) -> Dict[Sid, SaveWithDt]: def _get_events(backups: Sequence[Path], parallel: bool=True) -> Iterator[Event]: # todo cachew: let it transform return type? so you don't have to write a wrapper for lists? - prev_saves: Mapping[Sid, SaveWithDt] = {} + prev_saves: Mapping[Uid, SaveWithDt] = {} # TODO suppress first batch?? # TODO for initial batch, treat event time as creation time - states: Iterable[Mapping[Sid, SaveWithDt]] + states: Iterable[Mapping[Uid, SaveWithDt]] if parallel: with Pool() as p: states = p.map(_get_state, backups)