diff --git a/jrnl/Journal.py b/jrnl/Journal.py index 4a2dae30..68808e18 100644 --- a/jrnl/Journal.py +++ b/jrnl/Journal.py @@ -380,7 +380,7 @@ class Journal: class PlainJournal(Journal): def _load(self, filename): - with open(filename, "r", encoding="utf-8") as f: + with open(filename, "r") as f: return f.read() def _store(self, filename, text): @@ -394,7 +394,7 @@ class LegacyJournal(Journal): by square brackets. You'll not be able to save these journals anymore.""" def _load(self, filename): - with open(filename, "r", encoding="utf-8") as f: + with open(filename, "rb") as f: return f.read() def _parse(self, journal_txt): diff --git a/jrnl/encryption/BaseEncryption.py b/jrnl/encryption/BaseEncryption.py index 6fcaa2b1..0bbec442 100644 --- a/jrnl/encryption/BaseEncryption.py +++ b/jrnl/encryption/BaseEncryption.py @@ -3,19 +3,31 @@ from abc import ABC from abc import abstractmethod +from jrnl.exception import JrnlException +from jrnl.messages import Message +from jrnl.messages import MsgStyle +from jrnl.messages import MsgText + class BaseEncryption(ABC): + _encoding: str + _journal_name: str + def __init__(self, journal_name, config): + self._encoding = "utf-8" self._journal_name = journal_name self._config = config - @abstractmethod def encrypt(self, text: str) -> str: - pass + return self._encrypt(text) - @abstractmethod - def decrypt(self, text: str) -> str | None: - pass + def decrypt(self, text: str) -> str: + if (result := self._decrypt(text)) is None: + raise JrnlException( + Message(MsgText.DecryptionFailedGeneric, MsgStyle.ERROR) + ) + + return result @abstractmethod def _encrypt(self, text: str) -> str: @@ -27,7 +39,7 @@ class BaseEncryption(ABC): pass @abstractmethod - def _decrypt(self, text: str) -> str: + def _decrypt(self, text: bytes) -> str | None: """ This is needed because self.decrypt might need to perform actions (e.g. prompt for password) diff --git a/jrnl/encryption/BasePasswordEncryption.py b/jrnl/encryption/BasePasswordEncryption.py index bca59093..0f09284a 100644 --- a/jrnl/encryption/BasePasswordEncryption.py +++ b/jrnl/encryption/BasePasswordEncryption.py @@ -8,17 +8,14 @@ from jrnl.prompt import prompt_password class BasePasswordEncryption(BaseEncryption): _attempts: int - _journal_name: str _max_attempts: int _password: str | None - _encoding: str def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._attempts = 0 self._max_attempts = 3 self._password = None - self._encoding = "utf-8" # Check keyring first for password. # That way we'll have it. @@ -38,7 +35,7 @@ class BasePasswordEncryption(BaseEncryption): self.password = create_password(self._journal_name) return self._encrypt(text) - def decrypt(self, text): + def decrypt(self, text: bytes) -> str: if self.password is None: self._prompt_password() while (result := self._decrypt(text)) is None: diff --git a/jrnl/encryption/Jrnlv1Encryption.py b/jrnl/encryption/Jrnlv1Encryption.py index 3ab70b4a..14190230 100644 --- a/jrnl/encryption/Jrnlv1Encryption.py +++ b/jrnl/encryption/Jrnlv1Encryption.py @@ -20,7 +20,8 @@ class Jrnlv1Encryption(BasePasswordEncryption): def _decrypt(self, text: bytes) -> str | None: iv, cipher = text[:16], text[16:] - decryption_key = hashlib.sha256(self._password.encode("utf-8")).digest() + password = self._password or "" + decryption_key = hashlib.sha256(password.encode(self._encoding)).digest() decryptor = Cipher( algorithms.AES(decryption_key), modes.CBC(iv), default_backend() ).decryptor() @@ -29,10 +30,10 @@ class Jrnlv1Encryption(BasePasswordEncryption): # self._password = password if plain_padded[-1] in (" ", 32): # Ancient versions of jrnl. Do not judge me. - return plain_padded.decode("utf-8").rstrip(" ") + return plain_padded.decode(self._encoding).rstrip(" ") else: unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder() plain = unpadder.update(plain_padded) + unpadder.finalize() - return plain.decode("utf-8") + return plain.decode(self._encoding) except ValueError: return None diff --git a/jrnl/encryption/Jrnlv2Encryption.py b/jrnl/encryption/Jrnlv2Encryption.py index c27e6abe..11747841 100644 --- a/jrnl/encryption/Jrnlv2Encryption.py +++ b/jrnl/encryption/Jrnlv2Encryption.py @@ -49,12 +49,8 @@ class Jrnlv2Encryption(BasePasswordEncryption): .decode(self._encoding) ) - def _decrypt(self, text: str) -> str | None: + def _decrypt(self, text: bytes) -> str | None: try: - return ( - Fernet(self._key) - .decrypt(text.encode(self._encoding)) - .decode(self._encoding) - ) + return Fernet(self._key).decrypt(text).decode(self._encoding) except (InvalidToken, IndexError): return None diff --git a/jrnl/encryption/NoEncryption.py b/jrnl/encryption/NoEncryption.py index 9f199066..60475a7a 100644 --- a/jrnl/encryption/NoEncryption.py +++ b/jrnl/encryption/NoEncryption.py @@ -1,13 +1,14 @@ # Copyright © 2012-2022 jrnl contributors # License: https://www.gnu.org/licenses/gpl-3.0.html +from jrnl.encryption.BaseEncryption import BaseEncryption -class NoEncryption: +class NoEncryption(BaseEncryption): def __init__(self, *args, **kwargs): - pass + super().__init__(*args, **kwargs) - def encrypt(self, text: str) -> str: + def _encrypt(self, text: str) -> str: return text - def decrypt(self, text: str) -> str: + def _decrypt(self, text: str) -> str: return text diff --git a/jrnl/messages/MsgText.py b/jrnl/messages/MsgText.py index 0f509645..3695eada 100644 --- a/jrnl/messages/MsgText.py +++ b/jrnl/messages/MsgText.py @@ -93,6 +93,8 @@ class MsgText(Enum): of journal can't be encrypted. Please fix your config file. """ + DecryptionFailedGeneric = "The decryption of journal data failed." + KeyboardInterruptMsg = "Aborted by user" CantReadTemplate = """ diff --git a/jrnl/upgrade.py b/jrnl/upgrade.py index bb81266d..e9f1c136 100644 --- a/jrnl/upgrade.py +++ b/jrnl/upgrade.py @@ -1,6 +1,7 @@ # Copyright © 2012-2022 jrnl contributors # License: https://www.gnu.org/licenses/gpl-3.0.html +import logging import os from jrnl import Journal @@ -131,7 +132,10 @@ def upgrade_jrnl(config_path): old_journal = Journal.open_journal( journal_name, scope_config(config, journal_name), legacy=True ) - all_journals.append(Journal.PlainJournal.from_journal(old_journal)) + + logging.debug(f"Clearing encryption method for '{journal_name}' journal") + old_journal.encryption_method = None + all_journals.append(old_journal) for journal_name, path in plain_journals.items(): print_msg(