import numpy as np
import matplotlib.pyplot as plt
from scipy.io import wavfile
from scipy.signal import stft, istft
from PIL import Image, ImageFilter


def scale_image(img, scaling_type):
    """Warp the rows of an image onto a logarithmic ('log') or mel-like ('MEL') axis.

    Source row ``y`` is moved to output row ``row_map[y]``, where ``row_map``
    is a non-decreasing warp of ``0..h-1`` onto itself. Where the warp spreads
    rows apart, the source row is repeated to fill the gap. Where it squeezes
    several source rows onto one output row, the last of them wins.

    Parameters
    ----------
    img : numpy.ndarray
        Image whose first axis is the row (vertical) axis, e.g. ``(h, w)``.
    scaling_type : str
        ``'log'`` or ``'MEL'``.

    Returns
    -------
    numpy.ndarray
        New array with the same shape and dtype as ``img``.

    Raises
    ------
    ValueError
        If ``scaling_type`` is neither ``'log'`` nor ``'MEL'``.
    """
    h = img.shape[0]

    rows = np.linspace(1, h, h)  # start at 1 to avoid log(0)
    if scaling_type == 'log':
        warped = np.log(rows + 1)
    elif scaling_type == 'MEL':
        warped = 2595 * np.log(rows / 700 + 1)
    else:
        raise ValueError(f"unknown scaling type {scaling_type!r}; please use 'log' or 'MEL'")

    if h < 2:
        # Nothing to warp; this also avoids a 0/0 in the normalisation below.
        return img.copy()

    # Normalise to [0, h-1] and truncate to integer row positions. The first
    # entry is exactly 0, the last exactly h-1 and the sequence never
    # decreases, so the spans written below tile every output row and no
    # gap-filling pass is needed.
    warped = (warped - warped.min()) / (warped.max() - warped.min())
    row_map = (warped * (h - 1)).astype(np.int32)

    output_img = np.zeros_like(img)
    # Use h as the end bound of the last span so source row h-1 is written too.
    bounds = np.append(row_map, h)
    for y in range(h):
        start, stop = bounds[y], bounds[y + 1]
        # A zero-width span still claims its own row (later rows overwrite it).
        output_img[start:max(stop, start + 1)] = img[y]

    return output_img


def apply_image_mask_stereo(wav_file, img_file, output_file, time_range=(0, 2), freq_range=(1000, 5000), alpha = 1, blur_radius=0, scaling_type = 'linear'):
    """Carve a greyscale image into the spectrogram of each channel of a WAV file.

    The STFT bins of each channel inside ``time_range`` x ``freq_range`` are
    multiplied by ``1 - alpha * image``, so white pixels attenuate and black
    pixels keep the original content. For each channel a before/after plot of
    the region is shown. The result is written to ``output_file`` as 16-bit
    stereo, normalised to full scale (a mono input is duplicated to both
    channels).

    Parameters
    ----------
    wav_file, img_file, output_file : str
        Input WAV, mask image and output WAV paths.
    time_range : tuple(float, float)
        Region start/end in seconds.
    freq_range : tuple(float, float)
        Region low/high frequency in Hz.
    alpha : float
        Strength of the mask (1 = fully remove bins under white pixels).
    blur_radius : float
        Gaussian blur radius applied to the image first; 0 disables it.
    scaling_type : str
        ``'log'`` or ``'MEL'`` warps the image vertically; anything else
        uses it linearly.
    """
    nperseg, noverlap = 2048, 1536  # must match between stft and istft

    sample_rate, audio = wavfile.read(wav_file)
    if audio.ndim == 2:
        left, right = audio[:, 0], audio[:, 1]
    else:
        left = right = audio

    img = Image.open(img_file).convert('L')
    if blur_radius > 0:
        img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))

    def show_region(position, title, region_values):
        """Draw one log-magnitude panel of the region of interest."""
        plt.subplot(2, 1, position)
        plt.title(title)
        plt.imshow(np.log1p(np.abs(region_values)), aspect='auto', origin='lower',
                   extent=[time_range[0], time_range[1], freq_range[0], freq_range[1]], cmap='turbo')
        plt.colorbar(label='Log Magnitude')
        plt.xlabel('Time [s]')
        plt.ylabel('Frequency [Hz]')

    def process_channel(channel):
        """Return ``channel`` resynthesised after applying the image mask to its STFT."""
        f, t, Zxx = stft(channel, fs=sample_rate, nperseg=nperseg, noverlap=noverlap)

        time_mask = (t >= time_range[0]) & (t <= time_range[1])
        freq_mask = (f >= freq_range[0]) & (f <= freq_range[1])
        region = np.ix_(freq_mask, time_mask)

        # PIL sizes are (width, height) = (time bins, frequency bins).
        img_resized = img.resize((int(time_mask.sum()), int(freq_mask.sum())))
        img_array = np.array(img_resized) / 255.0
        if scaling_type in ('log', 'MEL'):
            img_array = scale_image(img_array, scaling_type)
        # Image row 0 is the top; STFT row 0 is the lowest frequency.
        img_array = np.flipud(img_array)

        plt.figure(figsize=(12, 6))
        show_region(1, "Original STFT (Region of Interest)", Zxx[region])

        # The image is inverted here, so white deletes frequencies and black keeps them.
        # To delete with black and keep with white instead, use:
        #     Zxx[region] *= alpha * img_array
        Zxx[region] *= (1 - alpha * img_array)

        show_region(2, "STFT with Image Mask (Region of Interest)", Zxx[region])
        plt.show()

        _, reconstructed_audio = istft(Zxx, fs=sample_rate, nperseg=nperseg, noverlap=noverlap)
        return reconstructed_audio

    left_audio = process_channel(left)
    right_audio = process_channel(right)

    min_length = min(len(left_audio), len(right_audio))
    stereo_audio = np.vstack((left_audio[:min_length], right_audio[:min_length])).T
    peak = np.max(np.abs(stereo_audio))
    if peak > 0:  # a silent result would otherwise become 0/0 -> NaN before the int16 cast
        stereo_audio = stereo_audio / peak * 32767
    stereo_audio = np.int16(stereo_audio)

    wavfile.write(output_file, sample_rate, stereo_audio)
    print(f"Audio with image mask saved to {output_file}")


def apply_image_mask_hidden_mono(wav_file, img_file, output_file, time_range=(0, 2), freq_range=(1000, 5000), alpha= 1, blur_radius=0,scaling_type = 'linear'):
    """Embed an image that shows up when the stereo output is summed to mono.

    Each channel is filtered through a mask that is zero outside
    ``time_range`` x ``freq_range`` and ``alpha * image`` inside it (white
    passes, black removes). The filtered right channel is subtracted from the
    left channel and the filtered left channel from the right channel, so the
    mono mix is ``(L + R) / 2 - (L_edited + R_edited) / 2``. Left, right and
    mono spectrograms of the result are plotted. The output is written as
    16-bit stereo, normalised to full scale.

    Parameters
    ----------
    wav_file, img_file, output_file : str
        Input WAV, mask image and output WAV paths.
    time_range : tuple(float, float)
        Region start/end in seconds.
    freq_range : tuple(float, float)
        Region low/high frequency in Hz.
    alpha : float
        Gain applied to the image mask.
    blur_radius : float
        Gaussian blur radius applied to the image first; 0 disables it.
    scaling_type : str
        ``'log'`` or ``'MEL'`` warps the image vertically and uses a log
        frequency axis in the plots; anything else uses it linearly.
    """
    nperseg, noverlap = 2048, 1536  # must match between stft and istft
    log_axis = scaling_type in ('log', 'MEL')

    sample_rate, audio = wavfile.read(wav_file)
    if audio.ndim == 2:
        left, right = audio[:, 0], audio[:, 1]
    else:
        left = right = audio

    img = Image.open(img_file).convert('L')
    if blur_radius > 0:
        img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))

    def region_masks(f, t):
        """Return boolean (freq_mask, time_mask) selecting the region of interest."""
        time_mask = (t >= time_range[0]) & (t <= time_range[1])
        freq_mask = (f >= freq_range[0]) & (f <= freq_range[1])
        return freq_mask, time_mask

    def process_channel(channel):
        """Return ``channel`` keeping only the region, weighted by ``alpha * image``."""
        f, t, Zxx = stft(channel, fs=sample_rate, nperseg=nperseg, noverlap=noverlap)
        freq_mask, time_mask = region_masks(f, t)

        # PIL sizes are (width, height) = (time bins, frequency bins).
        img_resized = img.resize((int(time_mask.sum()), int(freq_mask.sum())))
        img_array = np.array(img_resized) / 255.0
        if log_axis:
            img_array = scale_image(img_array, scaling_type)
        # Image row 0 is the top; STFT row 0 is the lowest frequency.
        img_array = np.flipud(img_array)

        mask = np.zeros_like(Zxx, dtype=float)
        mask[np.ix_(freq_mask, time_mask)] = alpha * img_array
        Zxx *= mask

        _, reconstructed_audio = istft(Zxx, fs=sample_rate, nperseg=nperseg, noverlap=noverlap)
        return reconstructed_audio

    left_edited = process_channel(left)
    right_edited = process_channel(right)

    # The resynthesised channels need not have exactly the input length, so
    # trim every signal to the shortest one before combining them.
    min_length = min(len(left), len(right), len(right_edited), len(left_edited))
    # Cross-subtract: each output channel loses the *other* channel's image band.
    combined_left = left[:min_length] - right_edited[:min_length]
    combined_right = right[:min_length] - left_edited[:min_length]

    def plot_panel(position, title, signal):
        """Draw one log-magnitude STFT panel of ``signal`` over the region of interest."""
        f, t, Zxx = stft(signal, fs=sample_rate, nperseg=nperseg, noverlap=noverlap)
        freq_mask, time_mask = region_masks(f, t)
        plt.subplot(3, 1, position)
        plt.title(title)
        plt.imshow(np.log1p(np.abs(Zxx[np.ix_(freq_mask, time_mask)])), aspect='auto', origin='lower',
                   extent=[time_range[0], time_range[1], freq_range[0], freq_range[1]], cmap='turbo')
        plt.colorbar(label='Log Magnitude')
        plt.xlabel('Time [s]')
        plt.ylabel('Frequency [Hz]')
        if log_axis:
            plt.gca().set_yscale('log')

    plt.figure(figsize=(12, 8))
    plot_panel(1, "Left Channel STFT", combined_left)
    plot_panel(2, "Right Channel STFT", combined_right)
    plot_panel(3, "Mono (Combined) STFT", (combined_right + combined_left) / 2)
    plt.tight_layout()
    plt.show()

    stereo_audio = np.vstack((combined_left, combined_right)).T
    peak = np.max(np.abs(stereo_audio))
    if peak > 0:  # a silent result would otherwise become 0/0 -> NaN before the int16 cast
        stereo_audio = stereo_audio / peak * 32767
    stereo_audio = np.int16(stereo_audio)

    wavfile.write(output_file, sample_rate, stereo_audio)
    print(f"Audio with image mask saved to {output_file}")


def analyze_file(wav_file, time_range, freq_range, log_scale=False):
    """Plot STFT log-magnitudes of the left, right and mono (L+R)/2 mix of a WAV file.

    Parameters
    ----------
    wav_file : str
        Path of the WAV file to analyse (a mono file is shown in all three panels).
    time_range : tuple(float, float)
        Region start/end in seconds.
    freq_range : tuple(float, float)
        Region low/high frequency in Hz.
    log_scale : bool
        Use a logarithmic frequency axis.
    """
    sample_rate, audio = wavfile.read(wav_file)
    if audio.ndim == 2:
        left, right = audio[:, 0], audio[:, 1]
    else:
        left = right = audio

    # Mix in floating point: adding two int16 arrays directly wraps around on
    # overflow, which corrupted the mono panel for loud material.
    mono = (right.astype(np.float64) + left.astype(np.float64)) / 2

    def plot_panel(position, title, signal):
        """Draw one log-magnitude STFT panel of ``signal`` over the region of interest."""
        f, t, Zxx = stft(signal, fs=sample_rate, nperseg=2048, noverlap=1536)
        time_mask = (t >= time_range[0]) & (t <= time_range[1])
        freq_mask = (f >= freq_range[0]) & (f <= freq_range[1])
        plt.subplot(3, 1, position)
        plt.title(title)
        plt.imshow(np.log1p(np.abs(Zxx[np.ix_(freq_mask, time_mask)])),
                   aspect='auto', origin='lower',
                   extent=[time_range[0], time_range[1], freq_range[0], freq_range[1]], cmap='turbo')
        plt.colorbar(label='Log Magnitude')
        plt.xlabel('Time [s]')
        plt.ylabel('Frequency [Hz]')
        if log_scale:
            plt.gca().set_yscale('log')

    plt.figure(figsize=(12, 6))
    plot_panel(1, "Left Channel STFT", left)
    plot_panel(2, "Right Channel STFT", right)
    plot_panel(3, "Mono (Combined) STFT", mono)
    plt.tight_layout()
    plt.show()



# The hidden-mono function subtracts the image created on the right channel from the left channel,
# and the left-channel image from the right channel.
# If the source wasn't mono before, the changes to the stereo file won't be easily audible, and the
# image won't be visible either, except when you look at the file summed to mono.
# .png and .jpg are tested and work well, but use black-and-white images for the best results.
if __name__ == "__main__":
    # Simplest use case:
    apply_image_mask_hidden_mono('YourSong.wav', 'YourImage.jpg', 'output_hidden_mono.wav')
    analyze_file('output_hidden_mono.wav', time_range=(0, 12), freq_range=(10, 10000))

    # More advanced use case: blur is added to reduce noise (but makes the image less sharp)
    # and the frequency axis is scaled logarithmically. time_range is in s, freq_range in Hz.
    apply_image_mask_hidden_mono('YourSong.wav', 'YourImage.jpg', 'output_hidden_mono_advanced.wav', time_range=(4, 8), freq_range=(1000, 8000), alpha=0.9, blur_radius=5, scaling_type='log')
    # Analyse the file produced by the call above.
    analyze_file('output_hidden_mono_advanced.wav', time_range=(0, 12), freq_range=(10, 10000), log_scale=True)

    # The stereo function simply carves the image into the spectrum of each channel;
    # the artifacts are easily audible here.
    apply_image_mask_stereo('YourSong.wav', 'YourImage.jpg', 'output_stereo.wav')
    analyze_file('output_stereo.wav', time_range=(0, 20), freq_range=(20, 10000))


# Currently only .wav files are supported. More formats may or may not be added in the
# foreseeable or unforeseeable future, mainly depending on your interest.

Bugs and improvements can be posted here or under the video.

Leave a Reply

Your email address will not be published. Required fields are marked *