diff --git a/stegano/lsbset/lsbset.py b/stegano/lsbset/lsbset.py index 038abc1..d22fcbc 100644 --- a/stegano/lsbset/lsbset.py +++ b/stegano/lsbset/lsbset.py @@ -46,13 +46,14 @@ def hide(input_image_file, message, generator, auto_convert_rgb=False): assert message_length != 0, "message length is zero" img = Image.open(input_image_file) - if img.mode != 'RGB': - print('The mode of the image is not RGB. Mode is {}'.format(img.mode)) + + if img.mode not in ['RGB', 'RGBA']: if not auto_convert_rgb: + print('The mode of the image is not RGB. Mode is {}'.\ + format(img.mode)) answer = input('Convert the image to RGB ? [Y / n]\n') or 'Y' if answer.lower() == 'n': raise Exception('Not a RGB image.') - img = img.convert('RGB') img_list = list(img.getdata()) @@ -72,7 +73,7 @@ def hide(input_image_file, message, generator, auto_convert_rgb=False): while index + 3 <= len_message_bits : generated_number = next(generator) - (r, g, b) = img_list[generated_number] + r, g, b, *a = img_list[generated_number] # Change the Least Significant Bit of each colour component. r = tools.setlsb(r, message_bits[index]) @@ -80,7 +81,10 @@ def hide(input_image_file, message, generator, auto_convert_rgb=False): b = tools.setlsb(b, message_bits[index+2]) # Save the new pixel - img_list[generated_number] = (r, g , b) + if img.mode == 'RGBA': + img_list[generated_number] = (r, g , b, a[0]) + else: + img_list[generated_number] = (r, g , b) index += 3 diff --git a/tests/test_lsbset.py b/tests/test_lsbset.py index ae8aa2b..d5475a1 100644 --- a/tests/test_lsbset.py +++ b/tests/test_lsbset.py @@ -53,6 +53,19 @@ class TestLSBSet(unittest.TestCase): self.assertEqual(message, clear_message) + def test_with_transparent_png(self): + messages_to_hide = ["a", "foo", "Hello World!", ":Python:"] + + for message in messages_to_hide: + secret = lsbset.hide("./tests/sample-files/transparent.png", + message, generators.eratosthenes()) + secret.save("./image.png") + + clear_message = lsbset.reveal("./image.png", + generators.eratosthenes()) + + self.assertEqual(message, clear_message) + def test_with_too_long_message(self): with open("./tests/sample-files/lorem_ipsum.txt") as f: message = f.read()