more cleanup

This commit is contained in:
Jonathan Wren 2022-09-25 14:57:15 -07:00
parent dbeb2c2128
commit 358e6c7f68
8 changed files with 39 additions and 26 deletions

View file

@ -380,7 +380,7 @@ class Journal:
class PlainJournal(Journal): class PlainJournal(Journal):
def _load(self, filename): def _load(self, filename):
with open(filename, "r", encoding="utf-8") as f: with open(filename, "r") as f:
return f.read() return f.read()
def _store(self, filename, text): def _store(self, filename, text):
@ -394,7 +394,7 @@ class LegacyJournal(Journal):
by square brackets. You'll not be able to save these journals anymore.""" by square brackets. You'll not be able to save these journals anymore."""
def _load(self, filename): def _load(self, filename):
with open(filename, "r", encoding="utf-8") as f: with open(filename, "rb") as f:
return f.read() return f.read()
def _parse(self, journal_txt): def _parse(self, journal_txt):

View file

@ -3,19 +3,31 @@
from abc import ABC from abc import ABC
from abc import abstractmethod from abc import abstractmethod
from jrnl.exception import JrnlException
from jrnl.messages import Message
from jrnl.messages import MsgStyle
from jrnl.messages import MsgText
class BaseEncryption(ABC): class BaseEncryption(ABC):
_encoding: str
_journal_name: str
def __init__(self, journal_name, config): def __init__(self, journal_name, config):
self._encoding = "utf-8"
self._journal_name = journal_name self._journal_name = journal_name
self._config = config self._config = config
@abstractmethod
def encrypt(self, text: str) -> str: def encrypt(self, text: str) -> str:
pass return self._encrypt(text)
@abstractmethod def decrypt(self, text: str) -> str:
def decrypt(self, text: str) -> str | None: if (result := self._decrypt(text)) is None:
pass raise JrnlException(
Message(MsgText.DecryptionFailedGeneric, MsgStyle.ERROR)
)
return result
@abstractmethod @abstractmethod
def _encrypt(self, text: str) -> str: def _encrypt(self, text: str) -> str:
@ -27,7 +39,7 @@ class BaseEncryption(ABC):
pass pass
@abstractmethod @abstractmethod
def _decrypt(self, text: str) -> str: def _decrypt(self, text: bytes) -> str | None:
""" """
This is needed because self.decrypt might need This is needed because self.decrypt might need
to perform actions (e.g. prompt for password) to perform actions (e.g. prompt for password)

View file

@ -8,17 +8,14 @@ from jrnl.prompt import prompt_password
class BasePasswordEncryption(BaseEncryption): class BasePasswordEncryption(BaseEncryption):
_attempts: int _attempts: int
_journal_name: str
_max_attempts: int _max_attempts: int
_password: str | None _password: str | None
_encoding: str
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._attempts = 0 self._attempts = 0
self._max_attempts = 3 self._max_attempts = 3
self._password = None self._password = None
self._encoding = "utf-8"
# Check keyring first for password. # Check keyring first for password.
# That way we'll have it. # That way we'll have it.
@ -38,7 +35,7 @@ class BasePasswordEncryption(BaseEncryption):
self.password = create_password(self._journal_name) self.password = create_password(self._journal_name)
return self._encrypt(text) return self._encrypt(text)
def decrypt(self, text): def decrypt(self, text: bytes) -> str:
if self.password is None: if self.password is None:
self._prompt_password() self._prompt_password()
while (result := self._decrypt(text)) is None: while (result := self._decrypt(text)) is None:

View file

@ -20,7 +20,8 @@ class Jrnlv1Encryption(BasePasswordEncryption):
def _decrypt(self, text: bytes) -> str | None: def _decrypt(self, text: bytes) -> str | None:
iv, cipher = text[:16], text[16:] iv, cipher = text[:16], text[16:]
decryption_key = hashlib.sha256(self._password.encode("utf-8")).digest() password = self._password or ""
decryption_key = hashlib.sha256(password.encode(self._encoding)).digest()
decryptor = Cipher( decryptor = Cipher(
algorithms.AES(decryption_key), modes.CBC(iv), default_backend() algorithms.AES(decryption_key), modes.CBC(iv), default_backend()
).decryptor() ).decryptor()
@ -29,10 +30,10 @@ class Jrnlv1Encryption(BasePasswordEncryption):
# self._password = password # self._password = password
if plain_padded[-1] in (" ", 32): if plain_padded[-1] in (" ", 32):
# Ancient versions of jrnl. Do not judge me. # Ancient versions of jrnl. Do not judge me.
return plain_padded.decode("utf-8").rstrip(" ") return plain_padded.decode(self._encoding).rstrip(" ")
else: else:
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder() unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
plain = unpadder.update(plain_padded) + unpadder.finalize() plain = unpadder.update(plain_padded) + unpadder.finalize()
return plain.decode("utf-8") return plain.decode(self._encoding)
except ValueError: except ValueError:
return None return None

View file

@ -49,12 +49,8 @@ class Jrnlv2Encryption(BasePasswordEncryption):
.decode(self._encoding) .decode(self._encoding)
) )
def _decrypt(self, text: str) -> str | None: def _decrypt(self, text: bytes) -> str | None:
try: try:
return ( return Fernet(self._key).decrypt(text).decode(self._encoding)
Fernet(self._key)
.decrypt(text.encode(self._encoding))
.decode(self._encoding)
)
except (InvalidToken, IndexError): except (InvalidToken, IndexError):
return None return None

View file

@ -1,13 +1,14 @@
# Copyright © 2012-2022 jrnl contributors # Copyright © 2012-2022 jrnl contributors
# License: https://www.gnu.org/licenses/gpl-3.0.html # License: https://www.gnu.org/licenses/gpl-3.0.html
from jrnl.encryption.BaseEncryption import BaseEncryption
class NoEncryption: class NoEncryption(BaseEncryption):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass super().__init__(*args, **kwargs)
def encrypt(self, text: str) -> str: def _encrypt(self, text: str) -> str:
return text return text
def decrypt(self, text: str) -> str: def _decrypt(self, text: str) -> str:
return text return text

View file

@ -93,6 +93,8 @@ class MsgText(Enum):
of journal can't be encrypted. Please fix your config file. of journal can't be encrypted. Please fix your config file.
""" """
DecryptionFailedGeneric = "The decryption of journal data failed."
KeyboardInterruptMsg = "Aborted by user" KeyboardInterruptMsg = "Aborted by user"
CantReadTemplate = """ CantReadTemplate = """

View file

@ -1,6 +1,7 @@
# Copyright © 2012-2022 jrnl contributors # Copyright © 2012-2022 jrnl contributors
# License: https://www.gnu.org/licenses/gpl-3.0.html # License: https://www.gnu.org/licenses/gpl-3.0.html
import logging
import os import os
from jrnl import Journal from jrnl import Journal
@ -131,7 +132,10 @@ def upgrade_jrnl(config_path):
old_journal = Journal.open_journal( old_journal = Journal.open_journal(
journal_name, scope_config(config, journal_name), legacy=True journal_name, scope_config(config, journal_name), legacy=True
) )
all_journals.append(Journal.PlainJournal.from_journal(old_journal))
logging.debug(f"Clearing encryption method for '{journal_name}' journal")
old_journal.encryption_method = None
all_journals.append(old_journal)
for journal_name, path in plain_journals.items(): for journal_name, path in plain_journals.items():
print_msg( print_msg(