diff --git a/stegano/wav/wav.py b/stegano/wav/wav.py index b804944..029257f 100644 --- a/stegano/wav/wav.py +++ b/stegano/wav/wav.py @@ -23,11 +23,16 @@ __date__ = "$Date: 2010/10/01 $" __revision__ = "$Date: 2017/02/06 $" __license__ = "GPLv3" -from typing import IO, Union - import wave +from typing import IO, Union +from stegano import tools -def hide(input_file: Union[str, IO[bytes]], message: str, output_file: Union[str, IO[bytes]]): +def hide( + input_file: Union[str, IO[bytes]], + message: str, + output_file: Union[str, IO[bytes]], + encoding: str = "UTF-8" +): """ Hide a message (string) in a .wav audio file. @@ -36,7 +41,6 @@ def hide(input_file: Union[str, IO[bytes]], message: str, output_file: Union[str """ message_length = len(message) assert message_length != 0, "message message_length is zero" - # TODO messages in audio files could likely be much longer in most cases assert message_length < 255, "message is too long" output = wave.open(output_file, "wb") @@ -44,20 +48,27 @@ def hide(input_file: Union[str, IO[bytes]], message: str, output_file: Union[str nchannels, sampwidth, framerate, nframes, comptype, _ = input.getparams() assert comptype == "NONE", "only uncompressed files are supported" - nsamples = nchannels * nframes + nsamples = nframes * nchannels - # TODO get message bits and check length + message_bits = f"{message_length:08b}" + "".join(tools.a2bits_list(message, encoding)) + assert len(message_bits) <= nsamples, "message is too long" output.setnchannels(nchannels) output.setsampwidth(sampwidth) output.setframerate(framerate) - # TODO encode message length + message - frames = input.readframes(nframes * nchannels) + frames = bytearray(input.readframes(nsamples)) + for i in range(nsamples): + if i < len(message_bits): + if message_bits[i] == "0": + frames[i] = frames[i] & ~1 + else: + frames[i] = frames[i] | 1 + output.writeframes(frames) -def reveal(input_file: Union[str, IO[bytes]]): +def reveal(input_file: Union[str, IO[bytes]], encoding: str = "UTF-8"): """ Find a message in an image. @@ -65,7 +76,26 @@ def reveal(input_file: Union[str, IO[bytes]]): The first eight bits are used for message_length of the string. """ message = "" + encoding_len = tools.ENCODINGS[encoding] with wave.open(input_file, "rb") as input: - pass - # TODO + nchannels, _, _, nframes, comptype, _ = input.getparams() + assert comptype == "NONE", "only uncompressed files are supported" + + nsamples = nframes * nchannels + frames = bytearray(input.readframes(nsamples)) + + # Read first 8 bits for message length + length_bits = "" + for i in range(8): + length_bits += str(frames[i] & 1) + message_length = int(length_bits, 2) + + # Read message bits + message_bits = "" + for i in range(8, 8 + message_length * encoding_len): + message_bits += str(frames[i] & 1) + + # Convert bits to string + chars = [chr(int(message_bits[i:i+encoding_len], 2)) for i in range(0, len(message_bits), encoding_len)] + message = "".join(chars) return message