misc fixes
- convert import_source to a decorator which wraps the function call in a try block - fix protocol class when not TYPE_CHECKING - add id properties to Protocols, remove attributes since protocol expects them to be settable but NT is read-only - use id to merge comments - remove type: ignore's from reddit config and just store as 'Any'
This commit is contained in:
parent
33f7f48ec5
commit
4492e00250
5 changed files with 118 additions and 71 deletions
|
@ -1,26 +1,52 @@
|
||||||
from typing import Any, Iterator, TypeVar, Callable, Optional, Iterable
|
"""
|
||||||
|
Decorator to gracefully handle importing a data source, or warning
|
||||||
|
and yielding nothing (or a default) when its not available
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Iterator, TypeVar, Callable, Optional, Iterable, Any
|
||||||
from my.core.warnings import warn
|
from my.core.warnings import warn
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
T = TypeVar("T")
|
# The factory function may produce something that has data
|
||||||
|
# similar to the shared model, but not exactly, so not
|
||||||
|
# making this a TypeVar, is just to make reading the
|
||||||
|
# type signature below a bit easier...
|
||||||
|
T = Any
|
||||||
|
|
||||||
# this is probably more generic and results in less code, but is not mypy-friendly
|
# https://mypy.readthedocs.io/en/latest/generics.html?highlight=decorators#decorator-factories
|
||||||
def import_source(factory: Callable[[], Any], default: Any) -> Any:
|
FactoryF = TypeVar("FactoryF", bound=Callable[..., Iterator[T]])
|
||||||
try:
|
|
||||||
res = factory()
|
_DEFUALT_ITR = ()
|
||||||
return res
|
|
||||||
except ModuleNotFoundError: # presumable means the user hasn't installed the module
|
|
||||||
warn(f"Module {factory.__qualname__} could not be imported, or isn't configured propertly")
|
|
||||||
return default
|
|
||||||
|
|
||||||
|
|
||||||
# For an example of this, see the reddit.all file
|
# tried to use decorator module but it really doesn't work well
|
||||||
def import_source_iter(factory: Callable[[], Iterator[T]], default: Optional[Iterable[T]] = None) -> Iterator[T]:
|
# with types and kw-arguments... :/
|
||||||
if default is None:
|
def import_source(
|
||||||
default = []
|
module_name: Optional[str] = None,
|
||||||
try:
|
default: Optional[Iterable[T]] = _DEFUALT_ITR
|
||||||
res = factory()
|
) -> Callable[..., Callable[..., Iterator[T]]]:
|
||||||
yield from res
|
"""
|
||||||
except ModuleNotFoundError: # presumable means the user hasn't installed the module
|
doesn't really play well with types, but is used to catch
|
||||||
warn(f"Module {factory.__qualname__} could not be imported, or isn't configured propertly")
|
ModuleNotFoundError's for when modules aren't installed in
|
||||||
yield from default
|
all.py files, so the types don't particularly matter
|
||||||
|
|
||||||
|
this is meant to be used to wrap some function which imports
|
||||||
|
and then yields an iterator of objects
|
||||||
|
|
||||||
|
If the user doesn't have that module installed, it returns
|
||||||
|
nothing and warns instead
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(factory_func: FactoryF) -> Callable[..., Iterator[T]]:
|
||||||
|
@wraps(factory_func)
|
||||||
|
def wrapper(*args, **kwargs) -> Iterator[T]:
|
||||||
|
try:
|
||||||
|
res = factory_func(**kwargs)
|
||||||
|
yield from res
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
# TODO: check if module_name is disabled and don't send warning
|
||||||
|
warn(f"Module {factory_func.__qualname__} could not be imported, or isn't configured propertly")
|
||||||
|
yield from default
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import Iterator, Any, Callable, TypeVar
|
from typing import Iterator
|
||||||
from my.core.common import Stats
|
from my.core.common import Stats
|
||||||
from my.core.source import import_source_iter as imp
|
from my.core.source import import_source
|
||||||
|
|
||||||
from .common import Save, Upvote, Comment, Submission, _merge_comments
|
from .common import Save, Upvote, Comment, Submission, _merge_comments
|
||||||
|
|
||||||
|
@ -8,22 +8,34 @@ from .common import Save, Upvote, Comment, Submission, _merge_comments
|
||||||
# reddit just feels like that much of a complicated source and
|
# reddit just feels like that much of a complicated source and
|
||||||
# data acquired by different methods isn't the same
|
# data acquired by different methods isn't the same
|
||||||
|
|
||||||
### import helpers
|
### 'safe importers' -- falls back to empty data if the module couldn't be found
|
||||||
|
rexport_src = import_source(module_name="my.reddit.rexport")
|
||||||
# this import error is caught in import_source_iter, if rexport isn't installed
|
pushshift_src = import_source(module_name="my.reddit.pushshift")
|
||||||
def _rexport_import() -> Any:
|
|
||||||
from . import rexport as source
|
|
||||||
return source
|
|
||||||
|
|
||||||
|
@rexport_src
|
||||||
def _rexport_comments() -> Iterator[Comment]:
|
def _rexport_comments() -> Iterator[Comment]:
|
||||||
yield from imp(lambda: _rexport_import().comments())
|
from . import rexport
|
||||||
|
yield from rexport.comments()
|
||||||
|
|
||||||
def _pushshift_import() -> Any:
|
@rexport_src
|
||||||
from . import pushshift as source
|
def _rexport_submissions() -> Iterator[Submission]:
|
||||||
return source
|
from . import rexport
|
||||||
|
yield from rexport.submissions()
|
||||||
|
|
||||||
|
@rexport_src
|
||||||
|
def _rexport_saved() -> Iterator[Save]:
|
||||||
|
from . import rexport
|
||||||
|
yield from rexport.saved()
|
||||||
|
|
||||||
|
@rexport_src
|
||||||
|
def _rexport_upvoted() -> Iterator[Upvote]:
|
||||||
|
from . import rexport
|
||||||
|
yield from rexport.upvoted()
|
||||||
|
|
||||||
|
@pushshift_src
|
||||||
def _pushshift_comments() -> Iterator[Comment]:
|
def _pushshift_comments() -> Iterator[Comment]:
|
||||||
yield from imp(lambda: _pushshift_import().comments())
|
from .pushshift import comments as pcomments
|
||||||
|
yield from pcomments()
|
||||||
|
|
||||||
# Merged functions
|
# Merged functions
|
||||||
|
|
||||||
|
@ -33,13 +45,17 @@ def comments() -> Iterator[Comment]:
|
||||||
|
|
||||||
def submissions() -> Iterator[Submission]:
|
def submissions() -> Iterator[Submission]:
|
||||||
# TODO: merge gdpr here
|
# TODO: merge gdpr here
|
||||||
yield from imp(lambda: _rexport_import().submissions())
|
yield from _rexport_submissions()
|
||||||
|
|
||||||
|
@rexport_src
|
||||||
def saved() -> Iterator[Save]:
|
def saved() -> Iterator[Save]:
|
||||||
yield from imp(lambda: _rexport_import().saved())
|
from .rexport import saved
|
||||||
|
yield from saved()
|
||||||
|
|
||||||
|
@rexport_src
|
||||||
def upvoted() -> Iterator[Upvote]:
|
def upvoted() -> Iterator[Upvote]:
|
||||||
yield from imp(lambda: _rexport_import().upvoted())
|
from .rexport import upvoted
|
||||||
|
yield from upvoted()
|
||||||
|
|
||||||
def stats() -> Stats:
|
def stats() -> Stats:
|
||||||
from my.core import stat
|
from my.core import stat
|
||||||
|
|
|
@ -1,28 +1,28 @@
|
||||||
"""
|
"""
|
||||||
This defines Protocol classes, which make sure that each different
|
This defines Protocol classes, which make sure that each different
|
||||||
type of Comment/Save have a standard interface
|
type of shared models have a standardized interface
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, Any, Set, Iterator
|
from typing import Dict, Any, Set, Iterator, TYPE_CHECKING
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from datetime import datetime
|
|
||||||
|
from my.core.common import datetime_aware
|
||||||
|
|
||||||
Json = Dict[str, Any]
|
Json = Dict[str, Any]
|
||||||
|
|
||||||
try:
|
if TYPE_CHECKING:
|
||||||
from typing import Protocol
|
try:
|
||||||
except ImportError:
|
from typing import Protocol
|
||||||
# hmm -- does this need to be installed on 3.6 or is it already here?
|
except ImportError:
|
||||||
from typing_extensions import Protocol # type: ignore[misc]
|
# requirement of mypy
|
||||||
|
from typing_extensions import Protocol # type: ignore[misc]
|
||||||
|
else:
|
||||||
|
Protocol = object
|
||||||
|
|
||||||
# Note: doesn't include GDPR Save's since they don't have the same metadata
|
# Note: doesn't include GDPR Save's since they don't have the same metadata
|
||||||
class Save(Protocol):
|
class Save(Protocol):
|
||||||
created: datetime
|
|
||||||
title: str
|
|
||||||
raw: Json
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sid(self) -> str: ...
|
def id(self) -> str: ...
|
||||||
@property
|
@property
|
||||||
def url(self) -> str: ...
|
def url(self) -> str: ...
|
||||||
@property
|
@property
|
||||||
|
@ -32,9 +32,10 @@ class Save(Protocol):
|
||||||
|
|
||||||
# Note: doesn't include GDPR Upvote's since they don't have the same metadata
|
# Note: doesn't include GDPR Upvote's since they don't have the same metadata
|
||||||
class Upvote(Protocol):
|
class Upvote(Protocol):
|
||||||
raw: Json
|
|
||||||
@property
|
@property
|
||||||
def created(self) -> datetime: ...
|
def id(self) -> str: ...
|
||||||
|
@property
|
||||||
|
def created(self) -> datetime_aware: ...
|
||||||
@property
|
@property
|
||||||
def url(self) -> str: ...
|
def url(self) -> str: ...
|
||||||
@property
|
@property
|
||||||
|
@ -43,22 +44,24 @@ class Upvote(Protocol):
|
||||||
def title(self) -> str: ...
|
def title(self) -> str: ...
|
||||||
|
|
||||||
|
|
||||||
# From rexport, pushshift and the reddit gdpr export
|
# From rexport, pushshift and the reddit GDPR export
|
||||||
class Comment(Protocol):
|
class Comment(Protocol):
|
||||||
raw: Json
|
|
||||||
@property
|
@property
|
||||||
def created(self) -> datetime: ...
|
def id(self) -> str: ...
|
||||||
|
@property
|
||||||
|
def created(self) -> datetime_aware: ...
|
||||||
@property
|
@property
|
||||||
def url(self) -> str: ...
|
def url(self) -> str: ...
|
||||||
@property
|
@property
|
||||||
def text(self) -> str: ...
|
def text(self) -> str: ...
|
||||||
|
|
||||||
|
|
||||||
# From rexport and the gdpr export
|
# From rexport and the GDPR export
|
||||||
class Submission(Protocol):
|
class Submission(Protocol):
|
||||||
raw: Json
|
|
||||||
@property
|
@property
|
||||||
def created(self) -> datetime: ...
|
def id(self) -> str: ...
|
||||||
|
@property
|
||||||
|
def created(self) -> datetime_aware: ...
|
||||||
@property
|
@property
|
||||||
def url(self) -> str: ...
|
def url(self) -> str: ...
|
||||||
@property
|
@property
|
||||||
|
@ -70,14 +73,14 @@ class Submission(Protocol):
|
||||||
def _merge_comments(*sources: Iterator[Comment]) -> Iterator[Comment]:
|
def _merge_comments(*sources: Iterator[Comment]) -> Iterator[Comment]:
|
||||||
#from .rexport import logger
|
#from .rexport import logger
|
||||||
#ignored = 0
|
#ignored = 0
|
||||||
emitted: Set[int] = set()
|
emitted: Set[str] = set()
|
||||||
for e in chain(*sources):
|
for e in chain(*sources):
|
||||||
key = int(e.raw["created_utc"])
|
uid = e.id
|
||||||
if key in emitted:
|
if uid in emitted:
|
||||||
#ignored += 1
|
#ignored += 1
|
||||||
#logger.info('ignoring %s: %s', key, e)
|
#logger.info('ignoring %s: %s', uid, e)
|
||||||
continue
|
continue
|
||||||
yield e
|
yield e
|
||||||
emitted.add(key)
|
emitted.add(uid)
|
||||||
#logger.info(f"Ignored {ignored} comments...")
|
#logger.info(f"Ignored {ignored} comments...")
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,6 @@ config = make_config(pushshift_config)
|
||||||
from my.core import get_files
|
from my.core import get_files
|
||||||
from typing import Sequence, Iterator
|
from typing import Sequence, Iterator
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from .common import Comment
|
|
||||||
|
|
||||||
from pushshift_comment_export.dal import read_file, PComment
|
from pushshift_comment_export.dal import read_file, PComment
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ REQUIRES = [
|
||||||
|
|
||||||
from my.core.common import Paths
|
from my.core.common import Paths
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from my.config import reddit as uconfig
|
from my.config import reddit as uconfig
|
||||||
|
|
||||||
|
@ -18,21 +19,23 @@ from my.config import reddit as uconfig
|
||||||
# in the migration
|
# in the migration
|
||||||
|
|
||||||
# need to check before we subclass
|
# need to check before we subclass
|
||||||
|
conf: Any
|
||||||
|
|
||||||
if hasattr(uconfig, "rexport"):
|
if hasattr(uconfig, "rexport"):
|
||||||
# sigh... backwards compatability
|
conf = uconfig.rexport
|
||||||
uconfig = uconfig.rexport # type: ignore[attr-defined,misc,assignment]
|
|
||||||
else:
|
else:
|
||||||
from my.core.warnings import high
|
from my.core.warnings import high
|
||||||
high(f"""DEPRECATED! Please modify your reddit config to look like:
|
high("""DEPRECATED! Please modify your reddit config to look like:
|
||||||
|
|
||||||
class reddit:
|
class reddit:
|
||||||
class rexport:
|
class rexport:
|
||||||
export_path: Paths = '/path/to/rexport/data'
|
export_path: Paths = '/path/to/rexport/data'
|
||||||
""")
|
""")
|
||||||
|
conf = uconfig
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class reddit(uconfig):
|
class reddit(conf):
|
||||||
'''
|
'''
|
||||||
Uses [[https://github.com/karlicoss/rexport][rexport]] output.
|
Uses [[https://github.com/karlicoss/rexport][rexport]] output.
|
||||||
'''
|
'''
|
||||||
|
@ -81,7 +84,7 @@ def inputs() -> Sequence[Path]:
|
||||||
return get_files(config.export_path)
|
return get_files(config.export_path)
|
||||||
|
|
||||||
|
|
||||||
Sid = dal.Sid # str
|
Uid = dal.Sid # str
|
||||||
Save = dal.Save
|
Save = dal.Save
|
||||||
Comment = dal.Comment
|
Comment = dal.Comment
|
||||||
Submission = dal.Submission
|
Submission = dal.Submission
|
||||||
|
@ -161,7 +164,7 @@ def _get_bdate(bfile: Path) -> datetime:
|
||||||
return bdt
|
return bdt
|
||||||
|
|
||||||
|
|
||||||
def _get_state(bfile: Path) -> Dict[Sid, SaveWithDt]:
|
def _get_state(bfile: Path) -> Dict[Uid, SaveWithDt]:
|
||||||
logger.debug('handling %s', bfile)
|
logger.debug('handling %s', bfile)
|
||||||
|
|
||||||
bdt = _get_bdate(bfile)
|
bdt = _get_bdate(bfile)
|
||||||
|
@ -178,11 +181,11 @@ def _get_state(bfile: Path) -> Dict[Sid, SaveWithDt]:
|
||||||
def _get_events(backups: Sequence[Path], parallel: bool=True) -> Iterator[Event]:
|
def _get_events(backups: Sequence[Path], parallel: bool=True) -> Iterator[Event]:
|
||||||
# todo cachew: let it transform return type? so you don't have to write a wrapper for lists?
|
# todo cachew: let it transform return type? so you don't have to write a wrapper for lists?
|
||||||
|
|
||||||
prev_saves: Mapping[Sid, SaveWithDt] = {}
|
prev_saves: Mapping[Uid, SaveWithDt] = {}
|
||||||
# TODO suppress first batch??
|
# TODO suppress first batch??
|
||||||
# TODO for initial batch, treat event time as creation time
|
# TODO for initial batch, treat event time as creation time
|
||||||
|
|
||||||
states: Iterable[Mapping[Sid, SaveWithDt]]
|
states: Iterable[Mapping[Uid, SaveWithDt]]
|
||||||
if parallel:
|
if parallel:
|
||||||
with Pool() as p:
|
with Pool() as p:
|
||||||
states = p.map(_get_state, backups)
|
states = p.map(_get_state, backups)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue