initial pass through to rework encryption into separate module

This commit is contained in:
Jonathan Wren 2022-09-24 05:49:55 -07:00
parent 057f31407a
commit b3a662fd9f
12 changed files with 224 additions and 176 deletions

View file

@ -1,153 +1,15 @@
# Copyright © 2012-2022 jrnl contributors
# License: https://www.gnu.org/licenses/gpl-3.0.html
import base64
import hashlib
import logging
import os
from typing import Callable
from typing import Optional
from cryptography.fernet import Fernet
from cryptography.fernet import InvalidToken
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers import algorithms
from cryptography.hazmat.primitives.ciphers import modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from jrnl.exception import JrnlException
from jrnl.Journal import Journal
from jrnl.Journal import LegacyJournal
from jrnl.messages import Message
from jrnl.messages import MsgStyle
from jrnl.messages import MsgText
from jrnl.output import print_msg
from jrnl.prompt import create_password
def make_key(password):
password = password.encode("utf-8")
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
# Salt is hard-coded
salt=b"\xf2\xd5q\x0e\xc1\x8d.\xde\xdc\x8e6t\x89\x04\xce\xf8",
iterations=100_000,
backend=default_backend(),
)
key = kdf.derive(password)
return base64.urlsafe_b64encode(key)
def decrypt_content(
decrypt_func: Callable[[str], Optional[str]],
keychain: str = None,
max_attempts: int = 3,
) -> str:
def get_pw():
return print_msg(
Message(MsgText.Password, MsgStyle.PROMPT), get_input=True, hide_input=True
)
pwd_from_keychain = keychain and get_keychain(keychain)
password = pwd_from_keychain or get_pw()
result = decrypt_func(password)
# Password is bad:
if result is None and pwd_from_keychain:
set_keychain(keychain, None)
attempt = 1
while result is None and attempt < max_attempts:
print_msg(Message(MsgText.WrongPasswordTryAgain, MsgStyle.WARNING))
password = get_pw()
result = decrypt_func(password)
attempt += 1
if result is None:
raise JrnlException(Message(MsgText.PasswordMaxTriesExceeded, MsgStyle.ERROR))
return result
class EncryptedJournal(Journal):
def __init__(self, name="default", **kwargs):
super().__init__(name, **kwargs)
self.config["encrypt"] = True
self.password = None
def open(self, filename=None):
"""Opens the journal file defined in the config and parses it into a list of Entries.
Entries have the form (date, title, body)."""
filename = filename or self.config["journal"]
dirname = os.path.dirname(filename)
if not os.path.exists(filename):
if not os.path.isdir(dirname):
os.makedirs(dirname)
print_msg(
Message(
MsgText.DirectoryCreated,
MsgStyle.NORMAL,
{"directory_name": dirname},
)
)
self.create_file(filename)
self.password = create_password(self.name)
print_msg(
Message(
MsgText.JournalCreated,
MsgStyle.NORMAL,
{"journal_name": self.name, "filename": filename},
)
)
text = self._load(filename)
self.entries = self._parse(text)
self.sort()
logging.debug("opened %s with %d entries", self.__class__.__name__, len(self))
return self
def _load(self, filename):
"""Loads an encrypted journal from a file and tries to decrypt it.
If password is not provided, will look for password in the keychain
and otherwise ask the user to enter a password up to three times.
If the password is provided but wrong (or corrupt), this will simply
return None."""
with open(filename, "rb") as f:
journal_encrypted = f.read()
def decrypt_journal(password):
key = make_key(password)
try:
plain = Fernet(key).decrypt(journal_encrypted).decode("utf-8")
self.password = password
return plain
except (InvalidToken, IndexError):
return None
if self.password:
return decrypt_journal(self.password)
return decrypt_content(keychain=self.name, decrypt_func=decrypt_journal)
def _store(self, filename, text):
key = make_key(self.password)
journal = Fernet(key).encrypt(text.encode("utf-8"))
with open(filename, "wb") as f:
f.write(journal)
@classmethod
def from_journal(cls, other: Journal):
new_journal = super().from_journal(other)
new_journal.password = (
other.password
if hasattr(other, "password")
else create_password(other.name)
)
return new_journal
class LegacyEncryptedJournal(LegacyJournal):
@ -185,33 +47,3 @@ class LegacyEncryptedJournal(LegacyJournal):
if self.password:
return decrypt_journal(self.password)
return decrypt_content(keychain=self.name, decrypt_func=decrypt_journal)
def get_keychain(journal_name):
import keyring
try:
return keyring.get_password("jrnl", journal_name)
except keyring.errors.KeyringError as e:
if not isinstance(e, keyring.errors.NoKeyringError):
print_msg(Message(MsgText.KeyringRetrievalFailure, MsgStyle.ERROR))
return ""
def set_keychain(journal_name, password):
import keyring
if password is None:
try:
keyring.delete_password("jrnl", journal_name)
except keyring.errors.KeyringError:
pass
else:
try:
keyring.set_password("jrnl", journal_name, password)
except keyring.errors.KeyringError as e:
if isinstance(e, keyring.errors.NoKeyringError):
msg = Message(MsgText.KeyringBackendNotFound, MsgStyle.WARNING)
else:
msg = Message(MsgText.KeyringRetrievalFailure, MsgStyle.ERROR)
print_msg(msg)

View file

@ -5,9 +5,12 @@ import datetime
import logging
import os
import re
from types import ModuleType
from jrnl import Entry
from jrnl import time
from jrnl.encryption import BaseEncryption
from jrnl.encryption import determine_encryption_type
from jrnl.messages import Message
from jrnl.messages import MsgStyle
from jrnl.messages import MsgText
@ -46,6 +49,7 @@ class Journal:
self.search_tags = None # Store tags we're highlighting
self.name = name
self.entries = []
self._encryption_method = None
def __len__(self):
"""Returns the number of entries"""
@ -77,6 +81,23 @@ class Journal:
self.entries = list(frozenset(self.entries) | frozenset(imported_entries))
self.sort()
def _get_encryption_method(self):
self._encryption_method = determine_encryption_type(self.config["encrypt"])(
self.config
)
def _decrypt(self, text):
if not self._encryption_method:
self._get_encryption_method()
return self._encryption_method.decrypt(text)
def _encrypt(self, text):
if not self._encryption_method:
self._get_encryption_method()
return self._encryption_method.encrypt(text)
def open(self, filename=None):
"""Opens the journal file defined in the config and parses it into a list of Entries.
Entries have the form (date, title, body)."""
@ -105,6 +126,7 @@ class Journal:
)
text = self._load(filename)
text = self._decrypt(text)
self.entries = self._parse(text)
self.sort()
logging.debug("opened %s with %d entries", self.__class__.__name__, len(self))
@ -114,6 +136,7 @@ class Journal:
"""Dumps the journal into the config file, overwriting it"""
filename = filename or self.config["journal"]
text = self._to_text()
text = self._encrypt(text)
self._store(filename, text)
def validate_parsing(self):
@ -344,7 +367,7 @@ class Journal:
def editable_str(self):
"""Turns the journal into a string of entries that can be edited
manually and later be parsed with eslf.parse_editable_str."""
manually and later be parsed with self.parse_editable_str."""
return "\n".join([str(e) for e in self.entries])
def parse_editable_str(self, edited):
@ -465,8 +488,6 @@ def open_journal(journal_name, config, legacy=False):
return FolderJournal.Folder(journal_name, **config).open()
return PlainJournal(journal_name, **config).open()
from jrnl import EncryptedJournal
if legacy:
return EncryptedJournal.LegacyEncryptedJournal(journal_name, **config).open()
return EncryptedJournal.EncryptedJournal(journal_name, **config).open()
return PlainJournal(journal_name, **config).open()

View file

@ -78,8 +78,8 @@ def postconfig_encrypt(args, config, original_config, **kwargs):
Encrypt a journal in place, or optionally to a new file
"""
from jrnl.config import update_config
from jrnl.EncryptedJournal import EncryptedJournal
from jrnl.install import save_config
from jrnl.Journal import PlainJournal
from jrnl.Journal import open_journal
# Open the journal
@ -97,7 +97,7 @@ def postconfig_encrypt(args, config, original_config, **kwargs):
)
)
new_journal = EncryptedJournal.from_journal(journal)
new_journal = PlainJournal.from_journal(journal)
# If journal is encrypted, create new password
if journal.config["encrypt"] is True:

View file

@ -0,0 +1,25 @@
# Copyright © 2012-2022 jrnl contributors
# License: https://www.gnu.org/licenses/gpl-3.0.html
from abc import ABC
from abc import abstractmethod
class BaseEncryption(ABC):
def __init__(self, config):
self._config = config
@abstractmethod
def encrypt(self, text: str) -> str:
pass
@abstractmethod
def decrypt(self, text: str) -> str | None:
pass
@abstractmethod
def _decrypt(self, text: str) -> str:
"""
This is needed because self.decrypt needs
to get a password on decryption failures
"""
pass

View file

@ -0,0 +1,7 @@
# Copyright © 2012-2022 jrnl contributors
# License: https://www.gnu.org/licenses/gpl-3.0.html
from .BaseEncryption import BaseEncryption
class BaseKeyEncryption(BaseEncryption):
pass

View file

@ -0,0 +1,52 @@
# Copyright © 2012-2022 jrnl contributors
# License: https://www.gnu.org/licenses/gpl-3.0.html
from jrnl.encryption.BaseEncryption import BaseEncryption
from jrnl.exception import JrnlException
from jrnl.keyring import get_keyring_password
from jrnl.messages import Message
from jrnl.messages import MsgStyle
from jrnl.messages import MsgText
from jrnl.output import print_msg
class BasePasswordEncryption(BaseEncryption):
_attempts: int
_journal_name: str
_max_attempts: int
_password: str | None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._attempts = 0
self._max_attempts = 3
self._password = None
# Check keyring first to be ready for decryption
get_keyring_password(self._config["journal"])
# Prompt for password if keyring didn't work
if self._password is None:
self._prompt_password()
def decrypt(self, text: str) -> str:
encoded_text = text.encode(self._encoding)
while (result := self._decrypt(encoded_text)) is None:
self._prompt_password()
return result
def _prompt_password(self):
if self._attempts >= self._max_attempts:
raise JrnlException(
Message(MsgText.PasswordMaxTriesExceeded, MsgStyle.ERROR)
)
if self._attempts > 0:
print_msg(Message(MsgText.WrongPasswordTryAgain, MsgStyle.WARNING))
self._attempts += 1
self._password = print_msg(
Message(MsgText.Password, MsgStyle.PROMPT),
get_input=True,
hide_input=True,
)

View file

@ -0,0 +1,7 @@
# Copyright © 2012-2022 jrnl contributors
# License: https://www.gnu.org/licenses/gpl-3.0.html
from jrnl.encryption.BaseEncryption import BaseEncryption
class Jrnlv1Encryption(BaseEncryption):
pass

View file

@ -0,0 +1,46 @@
# Copyright © 2012-2022 jrnl contributors
# License: https://www.gnu.org/licenses/gpl-3.0.html
import base64
from cryptography.fernet import Fernet
from cryptography.fernet import InvalidToken
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from .BasePasswordEncryption import BasePasswordEncryption
class Jrnlv2Encryption(BasePasswordEncryption):
_salt: bytes
_encoding: str
_key: bytes
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# Salt is hard-coded
self._salt = b"\xf2\xd5q\x0e\xc1\x8d.\xde\xdc\x8e6t\x89\x04\xce\xf8"
self._encoding = "utf-8"
self._make_key()
def _make_key(self) -> None:
password = self._password.encode(self._encoding)
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=self._salt,
iterations=100_000,
backend=default_backend(),
)
key = kdf.derive(password)
self._key = base64.urlsafe_b64encode(key)
def encrypt(self, text: str) -> bytes:
return Fernet(self._key).encrypt(text.encode(self._encoding))
def _decrypt(self, text: bytes) -> str | None:
try:
return Fernet(self._key).decrypt((text)).decode(self._encoding)
except (InvalidToken, IndexError):
return None

View file

@ -0,0 +1,13 @@
# Copyright © 2012-2022 jrnl contributors
# License: https://www.gnu.org/licenses/gpl-3.0.html
class NoEncryption:
def __init__(self, *args, **kwargs):
pass
def encrypt(self, text: str) -> str:
return text
def decrypt(self, text: str) -> str:
return text

View file

@ -0,0 +1,18 @@
# Copyright © 2012-2022 jrnl contributors
# License: https://www.gnu.org/licenses/gpl-3.0.html
from jrnl.encryption.NoEncryption import NoEncryption
def determine_encryption_type(config):
encryption_method = NoEncryption
if config is True:
# Default encryption method
from jrnl.encryption.Jrnlv2Encryption import Jrnlv2Encryption
encryption_method = Jrnlv2Encryption
elif config == "jrnlv1":
from jrnl.encryption.Jrnlv1Encryption import Jrnlv1Encryption
encryption_method = Jrnlv1Encryption
return encryption_method

28
jrnl/keyring.py Normal file
View file

@ -0,0 +1,28 @@
# Copyright © 2012-2022 jrnl contributors
# License: https://www.gnu.org/licenses/gpl-3.0.html
import keyring
from jrnl.messages import Message
from jrnl.messages import MsgStyle
from jrnl.messages import MsgText
from jrnl.output import print_msg
def get_keyring_password(journal_name: str = "default") -> str | None:
try:
return keyring.get_password("jrnl", journal_name)
except keyring.errors.KeyringError as e:
if not isinstance(e, keyring.errors.NoKeyringError):
print_msg(Message(MsgText.KeyringRetrievalFailure, MsgStyle.ERROR))
return None
def set_keyring_password(password, journal_name: str = "default"):
try:
return keyring.set_password("jrnl", journal_name, password)
except keyring.errors.KeyringError as e:
if isinstance(e, keyring.errors.NoKeyringError):
msg = Message(MsgText.KeyringBackendNotFound, MsgStyle.WARNING)
else:
msg = Message(MsgText.KeyringRetrievalFailure, MsgStyle.ERROR)
print_msg(msg)

View file

@ -8,7 +8,6 @@ from jrnl import __version__
from jrnl.config import is_config_json
from jrnl.config import load_config
from jrnl.config import scope_config
from jrnl.EncryptedJournal import EncryptedJournal
from jrnl.exception import JrnlException
from jrnl.messages import Message
from jrnl.messages import MsgStyle
@ -132,7 +131,7 @@ def upgrade_jrnl(config_path):
old_journal = Journal.open_journal(
journal_name, scope_config(config, journal_name), legacy=True
)
all_journals.append(EncryptedJournal.from_journal(old_journal))
all_journals.append(PlainJournal.from_journal(old_journal))
for journal_name, path in plain_journals.items():
print_msg(