standardize the load and store methods in journals

This commit is contained in:
Jonathan Wren 2022-09-30 17:25:14 -07:00
parent c6a564e722
commit 07d1db6b00
5 changed files with 14 additions and 30 deletions

View file

@ -153,11 +153,12 @@ class Journal:
return "\n".join([str(e) for e in self.entries])
def _load(self, filename):
raise NotImplementedError
with open(filename, "rb") as f:
return f.read()
@classmethod
def _store(filename, text):
raise NotImplementedError
def _store(self, filename, text):
with open(filename, "w", encoding="utf-8") as f:
f.write(text)
def _parse(self, journal_txt):
"""Parses a journal that's stored in a string and returns a list of entries"""
@ -379,13 +380,7 @@ class Journal:
class PlainJournal(Journal):
def _load(self, filename):
with open(filename, "r", encoding="utf-8") as f:
return f.read()
def _store(self, filename, text):
with open(filename, "w", encoding="utf-8") as f:
f.write(text)
pass
class LegacyJournal(Journal):
@ -393,10 +388,6 @@ class LegacyJournal(Journal):
standard. Main difference here is that in 1.x, timestamps were not cuddled
by square brackets. You'll not be able to save these journals anymore."""
def _load(self, filename):
with open(filename, "rb") as f:
return f.read()
def _parse(self, journal_txt):
"""Parses a journal that's stored in a string and returns a list of entries"""
# Entries start with a line that looks like 'date title' - let's figure out how

View file

@ -21,7 +21,7 @@ class BaseEncryption(ABC):
def encrypt(self, text: str) -> str:
return self._encrypt(text)
def decrypt(self, text: str) -> str:
def decrypt(self, text: bytes) -> str:
if (result := self._decrypt(text)) is None:
raise JrnlException(
Message(MsgText.DecryptionFailedGeneric, MsgStyle.ERROR)
@ -39,7 +39,7 @@ class BaseEncryption(ABC):
pass
@abstractmethod
def _decrypt(self, text: bytes | str) -> str | None:
def _decrypt(self, text: bytes) -> str | None:
"""
This is needed because self.decrypt might need
to perform actions (e.g. prompt for password)

View file

@ -30,12 +30,12 @@ class BasePasswordEncryption(BaseEncryption):
def password(self, value):
self._password = value
def encrypt(self, text):
def encrypt(self, text: str) -> str:
if not self.password:
self.password = create_password(self._journal_name)
return self._encrypt(text)
def decrypt(self, text: str) -> str:
def decrypt(self, text: bytes) -> str:
if not self.password:
self._prompt_password()
while (result := self._decrypt(text)) is None:

View file

@ -47,12 +47,8 @@ class Jrnlv2Encryption(BasePasswordEncryption):
.decode(self._encoding)
)
def _decrypt(self, text: str) -> str | None:
def _decrypt(self, text: bytes) -> str | None:
try:
return (
Fernet(self._key)
.decrypt(text.encode(self._encoding))
.decode(self._encoding)
)
return Fernet(self._key).decrypt(text).decode(self._encoding)
except (InvalidToken, IndexError):
return None

View file

@ -10,8 +10,5 @@ class NoEncryption(BaseEncryption):
def _encrypt(self, text: str) -> str:
return text
def _decrypt(self, text: bytes | str) -> str:
result = text
if isinstance(result, bytes):
result = result.decode(self._encoding)
return result
def _decrypt(self, text: bytes) -> str:
return text.decode(self._encoding)