misc fixes

- convert import_source to a decorator which
  wraps the function call in a try block
- fix protocol class when not TYPE_CHECKING
- add id properties to Protocols, remove attributes
  since protocol expects them to be settable but
  NT is read-only
- use id to merge comments
- remove type: ignore's from reddit config
  and just store as 'Any'
This commit is contained in:
Sean Breckenridge 2021-10-28 20:29:50 -07:00
parent 33f7f48ec5
commit 4492e00250
5 changed files with 118 additions and 71 deletions

View file

@ -1,26 +1,52 @@
from typing import Any, Iterator, TypeVar, Callable, Optional, Iterable """
Decorator to gracefully handle importing a data source, or warning
and yielding nothing (or a default) when it's not available
"""
from typing import Any, Iterator, TypeVar, Callable, Optional, Iterable, Any
from my.core.warnings import warn from my.core.warnings import warn
from functools import wraps
T = TypeVar("T") # The factory function may produce something that has data
# similar to the shared model, but not exactly, so not
# making this a TypeVar, is just to make reading the
# type signature below a bit easier...
T = Any
# this is probably more generic and results in less code, but is not mypy-friendly # https://mypy.readthedocs.io/en/latest/generics.html?highlight=decorators#decorator-factories
def import_source(factory: Callable[[], Any], default: Any) -> Any: FactoryF = TypeVar("FactoryF", bound=Callable[..., Iterator[T]])
_DEFUALT_ITR = ()
# tried to use decorator module but it really doesn't work well
# with types and kw-arguments... :/
def import_source(
module_name: Optional[str] = None,
default: Optional[Iterable[T]] = _DEFUALT_ITR
) -> Callable[..., Callable[..., Iterator[T]]]:
"""
doesn't really play well with types, but is used to catch
ModuleNotFoundError's for when modules aren't installed in
all.py files, so the types don't particularly matter
this is meant to be used to wrap some function which imports
and then yields an iterator of objects
If the user doesn't have that module installed, it returns
nothing and warns instead
"""
def decorator(factory_func: FactoryF) -> Callable[..., Iterator[T]]:
@wraps(factory_func)
def wrapper(*args, **kwargs) -> Iterator[T]:
try: try:
res = factory() res = factory_func(**kwargs)
return res
except ModuleNotFoundError: # presumably means the user hasn't installed the module
warn(f"Module {factory.__qualname__} could not be imported, or isn't configured properly")
return default
# For an example of this, see the reddit.all file
def import_source_iter(factory: Callable[[], Iterator[T]], default: Optional[Iterable[T]] = None) -> Iterator[T]:
if default is None:
default = []
try:
res = factory()
yield from res yield from res
except ModuleNotFoundError: # presumably means the user hasn't installed the module except ModuleNotFoundError:
warn(f"Module {factory.__qualname__} could not be imported, or isn't configured properly") # TODO: check if module_name is disabled and don't send warning
warn(f"Module {factory_func.__qualname__} could not be imported, or isn't configured properly")
yield from default yield from default
return wrapper
return decorator

View file

@ -1,6 +1,6 @@
from typing import Iterator, Any, Callable, TypeVar from typing import Iterator
from my.core.common import Stats from my.core.common import Stats
from my.core.source import import_source_iter as imp from my.core.source import import_source
from .common import Save, Upvote, Comment, Submission, _merge_comments from .common import Save, Upvote, Comment, Submission, _merge_comments
@ -8,22 +8,34 @@ from .common import Save, Upvote, Comment, Submission, _merge_comments
# reddit just feels like that much of a complicated source and # reddit just feels like that much of a complicated source and
# data acquired by different methods isn't the same # data acquired by different methods isn't the same
### import helpers ### 'safe importers' -- falls back to empty data if the module couldn't be found
rexport_src = import_source(module_name="my.reddit.rexport")
# this import error is caught in import_source_iter, if rexport isn't installed pushshift_src = import_source(module_name="my.reddit.pushshift")
def _rexport_import() -> Any:
from . import rexport as source
return source
@rexport_src
def _rexport_comments() -> Iterator[Comment]: def _rexport_comments() -> Iterator[Comment]:
yield from imp(lambda: _rexport_import().comments()) from . import rexport
yield from rexport.comments()
def _pushshift_import() -> Any: @rexport_src
from . import pushshift as source def _rexport_submissions() -> Iterator[Submission]:
return source from . import rexport
yield from rexport.submissions()
@rexport_src
def _rexport_saved() -> Iterator[Save]:
from . import rexport
yield from rexport.saved()
@rexport_src
def _rexport_upvoted() -> Iterator[Upvote]:
from . import rexport
yield from rexport.upvoted()
@pushshift_src
def _pushshift_comments() -> Iterator[Comment]: def _pushshift_comments() -> Iterator[Comment]:
yield from imp(lambda: _pushshift_import().comments()) from .pushshift import comments as pcomments
yield from pcomments()
# Merged functions # Merged functions
@ -33,13 +45,17 @@ def comments() -> Iterator[Comment]:
def submissions() -> Iterator[Submission]: def submissions() -> Iterator[Submission]:
# TODO: merge gdpr here # TODO: merge gdpr here
yield from imp(lambda: _rexport_import().submissions()) yield from _rexport_submissions()
@rexport_src
def saved() -> Iterator[Save]: def saved() -> Iterator[Save]:
yield from imp(lambda: _rexport_import().saved()) from .rexport import saved
yield from saved()
@rexport_src
def upvoted() -> Iterator[Upvote]: def upvoted() -> Iterator[Upvote]:
yield from imp(lambda: _rexport_import().upvoted()) from .rexport import upvoted
yield from upvoted()
def stats() -> Stats: def stats() -> Stats:
from my.core import stat from my.core import stat

View file

@ -1,28 +1,28 @@
""" """
This defines Protocol classes, which make sure that each different This defines Protocol classes, which make sure that each different
type of Comment/Save have a standard interface type of shared models have a standardized interface
""" """
from typing import Dict, Any, Set, Iterator from typing import Dict, Any, Set, Iterator, TYPE_CHECKING
from itertools import chain from itertools import chain
from datetime import datetime
from my.core.common import datetime_aware
Json = Dict[str, Any] Json = Dict[str, Any]
try: if TYPE_CHECKING:
try:
from typing import Protocol from typing import Protocol
except ImportError: except ImportError:
# hmm -- does this need to be installed on 3.6 or is it already here? # requirement of mypy
from typing_extensions import Protocol # type: ignore[misc] from typing_extensions import Protocol # type: ignore[misc]
else:
Protocol = object
# Note: doesn't include GDPR Save's since they don't have the same metadata # Note: doesn't include GDPR Save's since they don't have the same metadata
class Save(Protocol): class Save(Protocol):
created: datetime
title: str
raw: Json
@property @property
def sid(self) -> str: ... def id(self) -> str: ...
@property @property
def url(self) -> str: ... def url(self) -> str: ...
@property @property
@ -32,9 +32,10 @@ class Save(Protocol):
# Note: doesn't include GDPR Upvote's since they don't have the same metadata # Note: doesn't include GDPR Upvote's since they don't have the same metadata
class Upvote(Protocol): class Upvote(Protocol):
raw: Json
@property @property
def created(self) -> datetime: ... def id(self) -> str: ...
@property
def created(self) -> datetime_aware: ...
@property @property
def url(self) -> str: ... def url(self) -> str: ...
@property @property
@ -43,22 +44,24 @@ class Upvote(Protocol):
def title(self) -> str: ... def title(self) -> str: ...
# From rexport, pushshift and the reddit gdpr export # From rexport, pushshift and the reddit GDPR export
class Comment(Protocol): class Comment(Protocol):
raw: Json
@property @property
def created(self) -> datetime: ... def id(self) -> str: ...
@property
def created(self) -> datetime_aware: ...
@property @property
def url(self) -> str: ... def url(self) -> str: ...
@property @property
def text(self) -> str: ... def text(self) -> str: ...
# From rexport and the gdpr export # From rexport and the GDPR export
class Submission(Protocol): class Submission(Protocol):
raw: Json
@property @property
def created(self) -> datetime: ... def id(self) -> str: ...
@property
def created(self) -> datetime_aware: ...
@property @property
def url(self) -> str: ... def url(self) -> str: ...
@property @property
@ -70,14 +73,14 @@ class Submission(Protocol):
def _merge_comments(*sources: Iterator[Comment]) -> Iterator[Comment]: def _merge_comments(*sources: Iterator[Comment]) -> Iterator[Comment]:
#from .rexport import logger #from .rexport import logger
#ignored = 0 #ignored = 0
emitted: Set[int] = set() emitted: Set[str] = set()
for e in chain(*sources): for e in chain(*sources):
key = int(e.raw["created_utc"]) uid = e.id
if key in emitted: if uid in emitted:
#ignored += 1 #ignored += 1
#logger.info('ignoring %s: %s', key, e) #logger.info('ignoring %s: %s', uid, e)
continue continue
yield e yield e
emitted.add(key) emitted.add(uid)
#logger.info(f"Ignored {ignored} comments...") #logger.info(f"Ignored {ignored} comments...")

View file

@ -23,7 +23,6 @@ config = make_config(pushshift_config)
from my.core import get_files from my.core import get_files
from typing import Sequence, Iterator from typing import Sequence, Iterator
from pathlib import Path from pathlib import Path
from .common import Comment
from pushshift_comment_export.dal import read_file, PComment from pushshift_comment_export.dal import read_file, PComment

View file

@ -7,6 +7,7 @@ REQUIRES = [
from my.core.common import Paths from my.core.common import Paths
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any
from my.config import reddit as uconfig from my.config import reddit as uconfig
@ -18,21 +19,23 @@ from my.config import reddit as uconfig
# in the migration # in the migration
# need to check before we subclass # need to check before we subclass
conf: Any
if hasattr(uconfig, "rexport"): if hasattr(uconfig, "rexport"):
# sigh... backwards compatibility conf = uconfig.rexport
uconfig = uconfig.rexport # type: ignore[attr-defined,misc,assignment]
else: else:
from my.core.warnings import high from my.core.warnings import high
high(f"""DEPRECATED! Please modify your reddit config to look like: high("""DEPRECATED! Please modify your reddit config to look like:
class reddit: class reddit:
class rexport: class rexport:
export_path: Paths = '/path/to/rexport/data' export_path: Paths = '/path/to/rexport/data'
""") """)
conf = uconfig
@dataclass @dataclass
class reddit(uconfig): class reddit(conf):
''' '''
Uses [[https://github.com/karlicoss/rexport][rexport]] output. Uses [[https://github.com/karlicoss/rexport][rexport]] output.
''' '''
@ -81,7 +84,7 @@ def inputs() -> Sequence[Path]:
return get_files(config.export_path) return get_files(config.export_path)
Sid = dal.Sid # str Uid = dal.Sid # str
Save = dal.Save Save = dal.Save
Comment = dal.Comment Comment = dal.Comment
Submission = dal.Submission Submission = dal.Submission
@ -161,7 +164,7 @@ def _get_bdate(bfile: Path) -> datetime:
return bdt return bdt
def _get_state(bfile: Path) -> Dict[Sid, SaveWithDt]: def _get_state(bfile: Path) -> Dict[Uid, SaveWithDt]:
logger.debug('handling %s', bfile) logger.debug('handling %s', bfile)
bdt = _get_bdate(bfile) bdt = _get_bdate(bfile)
@ -178,11 +181,11 @@ def _get_state(bfile: Path) -> Dict[Sid, SaveWithDt]:
def _get_events(backups: Sequence[Path], parallel: bool=True) -> Iterator[Event]: def _get_events(backups: Sequence[Path], parallel: bool=True) -> Iterator[Event]:
# todo cachew: let it transform return type? so you don't have to write a wrapper for lists? # todo cachew: let it transform return type? so you don't have to write a wrapper for lists?
prev_saves: Mapping[Sid, SaveWithDt] = {} prev_saves: Mapping[Uid, SaveWithDt] = {}
# TODO suppress first batch?? # TODO suppress first batch??
# TODO for initial batch, treat event time as creation time # TODO for initial batch, treat event time as creation time
states: Iterable[Mapping[Sid, SaveWithDt]] states: Iterable[Mapping[Uid, SaveWithDt]]
if parallel: if parallel:
with Pool() as p: with Pool() as p:
states = p.map(_get_state, backups) states = p.map(_get_state, backups)