standardize the load and store methods in journals

This commit is contained in:
Jonathan Wren 2022-09-30 17:25:14 -07:00
parent c6a564e722
commit 07d1db6b00
5 changed files with 14 additions and 30 deletions

View file

@ -153,11 +153,12 @@ class Journal:
return "\n".join([str(e) for e in self.entries]) return "\n".join([str(e) for e in self.entries])
def _load(self, filename): def _load(self, filename):
raise NotImplementedError with open(filename, "rb") as f:
return f.read()
@classmethod def _store(self, filename, text):
def _store(filename, text): with open(filename, "w", encoding="utf-8") as f:
raise NotImplementedError f.write(text)
def _parse(self, journal_txt): def _parse(self, journal_txt):
"""Parses a journal that's stored in a string and returns a list of entries""" """Parses a journal that's stored in a string and returns a list of entries"""
@ -379,13 +380,7 @@ class Journal:
class PlainJournal(Journal): class PlainJournal(Journal):
def _load(self, filename): pass
with open(filename, "r", encoding="utf-8") as f:
return f.read()
def _store(self, filename, text):
with open(filename, "w", encoding="utf-8") as f:
f.write(text)
class LegacyJournal(Journal): class LegacyJournal(Journal):
@ -393,10 +388,6 @@ class LegacyJournal(Journal):
standard. Main difference here is that in 1.x, timestamps were not cuddled standard. Main difference here is that in 1.x, timestamps were not cuddled
by square brackets. You'll not be able to save these journals anymore.""" by square brackets. You'll not be able to save these journals anymore."""
def _load(self, filename):
with open(filename, "rb") as f:
return f.read()
def _parse(self, journal_txt): def _parse(self, journal_txt):
"""Parses a journal that's stored in a string and returns a list of entries""" """Parses a journal that's stored in a string and returns a list of entries"""
# Entries start with a line that looks like 'date title' - let's figure out how # Entries start with a line that looks like 'date title' - let's figure out how

View file

@ -21,7 +21,7 @@ class BaseEncryption(ABC):
def encrypt(self, text: str) -> str: def encrypt(self, text: str) -> str:
return self._encrypt(text) return self._encrypt(text)
def decrypt(self, text: str) -> str: def decrypt(self, text: bytes) -> str:
if (result := self._decrypt(text)) is None: if (result := self._decrypt(text)) is None:
raise JrnlException( raise JrnlException(
Message(MsgText.DecryptionFailedGeneric, MsgStyle.ERROR) Message(MsgText.DecryptionFailedGeneric, MsgStyle.ERROR)
@ -39,7 +39,7 @@ class BaseEncryption(ABC):
pass pass
@abstractmethod @abstractmethod
def _decrypt(self, text: bytes | str) -> str | None: def _decrypt(self, text: bytes) -> str | None:
""" """
This is needed because self.decrypt might need This is needed because self.decrypt might need
to perform actions (e.g. prompt for password) to perform actions (e.g. prompt for password)

View file

@ -30,12 +30,12 @@ class BasePasswordEncryption(BaseEncryption):
def password(self, value): def password(self, value):
self._password = value self._password = value
def encrypt(self, text): def encrypt(self, text: str) -> str:
if not self.password: if not self.password:
self.password = create_password(self._journal_name) self.password = create_password(self._journal_name)
return self._encrypt(text) return self._encrypt(text)
def decrypt(self, text: str) -> str: def decrypt(self, text: bytes) -> str:
if not self.password: if not self.password:
self._prompt_password() self._prompt_password()
while (result := self._decrypt(text)) is None: while (result := self._decrypt(text)) is None:

View file

@ -47,12 +47,8 @@ class Jrnlv2Encryption(BasePasswordEncryption):
.decode(self._encoding) .decode(self._encoding)
) )
def _decrypt(self, text: str) -> str | None: def _decrypt(self, text: bytes) -> str | None:
try: try:
return ( return Fernet(self._key).decrypt(text).decode(self._encoding)
Fernet(self._key)
.decrypt(text.encode(self._encoding))
.decode(self._encoding)
)
except (InvalidToken, IndexError): except (InvalidToken, IndexError):
return None return None

View file

@ -10,8 +10,5 @@ class NoEncryption(BaseEncryption):
def _encrypt(self, text: str) -> str: def _encrypt(self, text: str) -> str:
return text return text
def _decrypt(self, text: bytes | str) -> str: def _decrypt(self, text: bytes) -> str:
result = text return text.decode(self._encoding)
if isinstance(result, bytes):
result = result.decode(self._encoding)
return result