From d464b1e607d24aba247fe961293a534ed29e4648 Mon Sep 17 00:00:00 2001 From: Dima Gerasimov Date: Mon, 3 Apr 2023 22:30:58 +0100 Subject: [PATCH] core: implement more methods for ZipPath and better support for get_files --- my/core/common.py | 13 +++++++++++-- my/core/kompress.py | 10 ++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/my/core/common.py b/my/core/common.py index 7adfd7a..090c564 100644 --- a/my/core/common.py +++ b/my/core/common.py @@ -161,6 +161,11 @@ from .logging import setup_logger, LazyLogger Paths = Union[Sequence[PathIsh], PathIsh] +def _is_zippath(p: Path) -> bool: + # weak type check here, don't want to depend on .kompress module in get_files + return type(p).__name__ == 'ZipPath' + + DEFAULT_GLOB = '*' def get_files( pp: Paths, @@ -183,7 +188,7 @@ def get_files( return () # early return to prevent warnings etc sources = [Path(pp)] else: - sources = [Path(p) for p in pp] + sources = [p if isinstance(p, Path) else Path(p) for p in pp] def caller() -> str: import traceback @@ -192,6 +197,10 @@ def get_files( paths: List[Path] = [] for src in sources: + if _is_zippath(src): + paths.append(src) + continue + if src.parts[0] == '~': src = src.expanduser() # note: glob handled first, because e.g. on Windows asterisk makes is_dir unhappy @@ -226,7 +235,7 @@ def get_files( if guess_compression: from .kompress import CPath, is_compressed - paths = [CPath(p) if is_compressed(p) else p for p in paths] + paths = [CPath(p) if is_compressed(p) and not _is_zippath(p) else p for p in paths] return tuple(paths) diff --git a/my/core/kompress.py b/my/core/kompress.py index a44b9d1..5ba32d3 100644 --- a/my/core/kompress.py +++ b/my/core/kompress.py @@ -3,6 +3,7 @@ Various helpers for compression """ from __future__ import annotations +from functools import total_ordering from datetime import datetime import pathlib from pathlib import Path @@ -155,6 +156,7 @@ else: zipfile_Path = object +@total_ordering class ZipPath(zipfile_Path): # NOTE: is_dir/is_file might not behave as expected, the base class checks it only based on the slash in path @@ -175,6 +177,9 @@ class ZipPath(zipfile_Path): def absolute(self) -> ZipPath: return ZipPath(self.filepath.absolute(), self.at) + def expanduser(self) -> ZipPath: + return ZipPath(self.filepath.expanduser(), self.at) + def exists(self) -> bool: if self.at == '': # special case, the base class returns False in this case for some reason @@ -224,6 +229,11 @@ class ZipPath(zipfile_Path): return False return (self.filepath, self.subpath) == (other.filepath, other.subpath) + def __lt__(self, other) -> bool: + if not isinstance(other, ZipPath): + return False + return (self.filepath, self.subpath) < (other.filepath, other.subpath) + def __hash__(self) -> int: return hash((self.filepath, self.subpath))