misc fixes

- convert import_source to a decorator which
  wraps the function call in a try block
- fix protocol class when not TYPE_CHECKING
- add id properties to Protocols, remove plain attributes
  since a Protocol expects attributes to be settable but
  a NamedTuple is read-only
- use id to merge comments
- remove type: ignore comments from the reddit config
  and just store the config as 'Any'
Sean Breckenridge 2021-10-28 20:29:50 -07:00
parent 33f7f48ec5
commit 4492e00250
5 changed files with 118 additions and 71 deletions
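To illustrate the decorator change from a consumer's point of view, a minimal sketch (the "only pushshift configured" scenario is illustrative; comments(), .id and .url come from the diffs below):

from my.reddit.all import comments

# If my.reddit.rexport can't be imported (e.g. rexport isn't installed),
# the decorated helper warns and yields nothing, so the merged comments()
# still iterates whatever sources are available.
for c in comments():
    print(c.id, c.url)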

View file

@@ -1,26 +1,52 @@
from typing import Any, Iterator, TypeVar, Callable, Optional, Iterable
"""
Decorator to gracefully handle importing a data source, or warning
and yielding nothing (or a default) when it's not available
"""
from typing import Any, Iterator, TypeVar, Callable, Optional, Iterable
from my.core.warnings import warn
from functools import wraps
T = TypeVar("T")
# The factory function may produce something that has data
# similar to the shared model, but not exactly; keeping this
# as Any instead of a TypeVar just makes the type signature
# below a bit easier to read...
T = Any
# this is probably more generic and results in less code, but is not mypy-friendly
def import_source(factory: Callable[[], Any], default: Any) -> Any:
    try:
        res = factory()
        return res
    except ModuleNotFoundError: # presumably means the user hasn't installed the module
        warn(f"Module {factory.__qualname__} could not be imported, or isn't configured properly")
        return default
# https://mypy.readthedocs.io/en/latest/generics.html?highlight=decorators#decorator-factories
FactoryF = TypeVar("FactoryF", bound=Callable[..., Iterator[T]])
_DEFAULT_ITR = ()
# For an example of this, see the reddit.all file
def import_source_iter(factory: Callable[[], Iterator[T]], default: Optional[Iterable[T]] = None) -> Iterator[T]:
    if default is None:
        default = []
    try:
        res = factory()
        yield from res
    except ModuleNotFoundError: # presumably means the user hasn't installed the module
        warn(f"Module {factory.__qualname__} could not be imported, or isn't configured properly")
        yield from default
# tried to use decorator module but it really doesn't work well
# with types and kw-arguments... :/
def import_source(
    module_name: Optional[str] = None,
    default: Optional[Iterable[T]] = _DEFAULT_ITR
) -> Callable[..., Callable[..., Iterator[T]]]:
"""
doesn't really play well with types, but is used to catch
ModuleNotFoundError's for when modules aren't installed in
all.py files, so the types don't particularly matter
this is meant to be used to wrap some function which imports
and then yields an iterator of objects
If the user doesn't have that module installed, it returns
nothing and warns instead
"""
def decorator(factory_func: FactoryF) -> Callable[..., Iterator[T]]:
@wraps(factory_func)
def wrapper(*args, **kwargs) -> Iterator[T]:
try:
res = factory_func(**kwargs)
yield from res
except ModuleNotFoundError:
# TODO: check if module_name is disabled and don't send warning
warn(f"Module {factory_func.__qualname__} could not be imported, or isn't configured propertly")
yield from default
return wrapper
return decorator
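As a usage sketch for the decorator above (hypothetical module and function names, not part of this commit): the wrapped function performs its import inside its own body, and the optional default controls what the wrapper yields if that import fails.

from typing import Iterator
from my.core.source import import_source

# hypothetical wrapper -- "my.some.source" stands in for any optional module
@import_source(module_name="my.some.source", default=())
def _some_items() -> Iterator[str]:
    from my.some.source import items  # may raise ModuleNotFoundError; caught by the wrapper
    yield from items()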

View file

@@ -1,6 +1,6 @@
from typing import Iterator, Any, Callable, TypeVar
from typing import Iterator
from my.core.common import Stats
from my.core.source import import_source_iter as imp
from my.core.source import import_source
from .common import Save, Upvote, Comment, Submission, _merge_comments
@@ -8,22 +8,34 @@ from .common import Save, Upvote, Comment, Submission, _merge_comments
# reddit just feels like that much of a complicated source and
# data acquired by different methods isn't the same
### import helpers
# this import error is caught in import_source_iter, if rexport isn't installed
def _rexport_import() -> Any:
    from . import rexport as source
    return source
### 'safe importers' -- falls back to empty data if the module couldn't be found
rexport_src = import_source(module_name="my.reddit.rexport")
pushshift_src = import_source(module_name="my.reddit.pushshift")
@rexport_src
def _rexport_comments() -> Iterator[Comment]:
    yield from imp(lambda: _rexport_import().comments())
    from . import rexport
    yield from rexport.comments()
def _pushshift_import() -> Any:
    from . import pushshift as source
    return source
@rexport_src
def _rexport_submissions() -> Iterator[Submission]:
    from . import rexport
    yield from rexport.submissions()
@rexport_src
def _rexport_saved() -> Iterator[Save]:
    from . import rexport
    yield from rexport.saved()
@rexport_src
def _rexport_upvoted() -> Iterator[Upvote]:
    from . import rexport
    yield from rexport.upvoted()
@pushshift_src
def _pushshift_comments() -> Iterator[Comment]:
    yield from imp(lambda: _pushshift_import().comments())
    from .pushshift import comments as pcomments
    yield from pcomments()
# Merged functions
@@ -33,13 +45,17 @@ def comments() -> Iterator[Comment]:
def submissions() -> Iterator[Submission]:
    # TODO: merge gdpr here
    yield from imp(lambda: _rexport_import().submissions())
    yield from _rexport_submissions()
@rexport_src
def saved() -> Iterator[Save]:
    yield from imp(lambda: _rexport_import().saved())
    from .rexport import saved
    yield from saved()
@rexport_src
def upvoted() -> Iterator[Upvote]:
    yield from imp(lambda: _rexport_import().upvoted())
    from .rexport import upvoted
    yield from upvoted()
def stats() -> Stats:
    from my.core import stat

View file

@@ -1,28 +1,28 @@
"""
This defines Protocol classes, which make sure that each different
type of Comment/Save have a standard interface
type of shared model has a standardized interface
"""
from typing import Dict, Any, Set, Iterator
from typing import Dict, Any, Set, Iterator, TYPE_CHECKING
from itertools import chain
from datetime import datetime
from my.core.common import datetime_aware
Json = Dict[str, Any]
try:
    from typing import Protocol
except ImportError:
    # hmm -- does this need to be installed on 3.6 or is it already here?
    from typing_extensions import Protocol # type: ignore[misc]
if TYPE_CHECKING:
    try:
        from typing import Protocol
    except ImportError:
        # requirement of mypy
        from typing_extensions import Protocol # type: ignore[misc]
else:
    Protocol = object
# Note: doesn't include GDPR Save's since they don't have the same metadata
class Save(Protocol):
    created: datetime
    title: str
    raw: Json
    @property
    def sid(self) -> str: ...
    def id(self) -> str: ...
    @property
    def url(self) -> str: ...
    @property
@@ -32,9 +32,10 @@ class Save(Protocol):
# Note: doesn't include GDPR Upvote's since they don't have the same metadata
class Upvote(Protocol):
    raw: Json
    @property
    def created(self) -> datetime: ...
    def id(self) -> str: ...
    @property
    def created(self) -> datetime_aware: ...
    @property
    def url(self) -> str: ...
    @property
@@ -43,22 +44,24 @@ class Upvote(Protocol):
    def title(self) -> str: ...
# From rexport, pushshift and the reddit gdpr export
# From rexport, pushshift and the reddit GDPR export
class Comment(Protocol):
    raw: Json
    @property
    def created(self) -> datetime: ...
    def id(self) -> str: ...
    @property
    def created(self) -> datetime_aware: ...
    @property
    def url(self) -> str: ...
    @property
    def text(self) -> str: ...
# From rexport and the gdpr export
# From rexport and the GDPR export
class Submission(Protocol):
    raw: Json
    @property
    def created(self) -> datetime: ...
    def id(self) -> str: ...
    @property
    def created(self) -> datetime_aware: ...
    @property
    def url(self) -> str: ...
    @property
@@ -70,14 +73,14 @@ class Submission(Protocol):
def _merge_comments(*sources: Iterator[Comment]) -> Iterator[Comment]:
    #from .rexport import logger
    #ignored = 0
    emitted: Set[int] = set()
    emitted: Set[str] = set()
    for e in chain(*sources):
        key = int(e.raw["created_utc"])
        if key in emitted:
        uid = e.id
        if uid in emitted:
            #ignored += 1
            #logger.info('ignoring %s: %s', key, e)
            #logger.info('ignoring %s: %s', uid, e)
            continue
        yield e
        emitted.add(key)
        emitted.add(uid)
    #logger.info(f"Ignored {ignored} comments...")
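Background for the id-property change (a minimal sketch, not part of the commit): a NamedTuple field is read-only, so mypy accepts it against a Protocol member declared as a read-only property, but would reject it against a plain, settable attribute.

from typing import NamedTuple, Protocol  # typing_extensions.Protocol on older Pythons

class HasId(Protocol):
    @property
    def id(self) -> str: ...  # read-only, like the Protocols above

class NTComment(NamedTuple):  # hypothetical NamedTuple-backed model
    id: str
    text: str

def dedupe_key(c: HasId) -> str:
    return c.id

dedupe_key(NTComment(id="t1_abc", text="hi"))  # ok: a read-only field satisfies a read-only property
# If HasId instead declared `id: str` (a settable attribute), mypy would flag
# this call, since NamedTuple fields can't be assigned to.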

View file

@@ -23,7 +23,6 @@ config = make_config(pushshift_config)
from my.core import get_files
from typing import Sequence, Iterator
from pathlib import Path
from .common import Comment
from pushshift_comment_export.dal import read_file, PComment

View file

@@ -7,6 +7,7 @@ REQUIRES = [
from my.core.common import Paths
from dataclasses import dataclass
from typing import Any
from my.config import reddit as uconfig
@@ -18,21 +19,23 @@ from my.config import reddit as uconfig
# in the migration
# need to check before we subclass
conf: Any
if hasattr(uconfig, "rexport"):
    # sigh... backwards compatibility
    uconfig = uconfig.rexport # type: ignore[attr-defined,misc,assignment]
    conf = uconfig.rexport
else:
    from my.core.warnings import high
    high(f"""DEPRECATED! Please modify your reddit config to look like:
    high("""DEPRECATED! Please modify your reddit config to look like:
class reddit:
    class rexport:
        export_path: Paths = '/path/to/rexport/data'
""")
    conf = uconfig
@dataclass
class reddit(uconfig):
class reddit(conf):
    '''
    Uses [[https://github.com/karlicoss/rexport][rexport]] output.
    '''
@@ -81,7 +84,7 @@ def inputs() -> Sequence[Path]:
    return get_files(config.export_path)
Sid = dal.Sid # str
Uid = dal.Sid # str
Save = dal.Save
Comment = dal.Comment
Submission = dal.Submission
@@ -161,7 +164,7 @@ def _get_bdate(bfile: Path) -> datetime:
    return bdt
def _get_state(bfile: Path) -> Dict[Sid, SaveWithDt]:
def _get_state(bfile: Path) -> Dict[Uid, SaveWithDt]:
    logger.debug('handling %s', bfile)
    bdt = _get_bdate(bfile)
@@ -178,11 +181,11 @@ def _get_state(bfile: Path) -> Dict[Sid, SaveWithDt]:
def _get_events(backups: Sequence[Path], parallel: bool=True) -> Iterator[Event]:
    # todo cachew: let it transform return type? so you don't have to write a wrapper for lists?
    prev_saves: Mapping[Sid, SaveWithDt] = {}
    prev_saves: Mapping[Uid, SaveWithDt] = {}
    # TODO suppress first batch??
    # TODO for initial batch, treat event time as creation time
    states: Iterable[Mapping[Sid, SaveWithDt]]
    states: Iterable[Mapping[Uid, SaveWithDt]]
    if parallel:
        with Pool() as p:
            states = p.map(_get_state, backups)