import numpy as np
import warnings
import scipy.io
import scipy.io.wavfile as siow
from scipy import signal
import matplotlib.pyplot as plt

import utilMDC

def ammod(m, K, A, fc, fs):
    """
    Perform amplitude modulation of a message signal.

    The modulated signal is computed sample by sample as
        s(t) = A * (1 + K*m(t)) * cos(2*pi*fc*t)
    where the time axis starts at 0 and advances by 1/fs per sample.

    A warning is issued when |K * m(t)| exceeds 1 for some sample; the
    modulated signal is still computed and returned in that case.

    :param m: message signal (array-like, e.g. a list or a NumPy array)
    :param K: factor applied to the message, as in the formula above
    :param A: amplitude factor, as in the formula above
    :param fc: the desired carrier frequency
    :param fs: the sampling frequency of the message signal
    :return: the amplitude modulated signal of the message signal m(t)
    """
    # Accept plain sequences as well as arrays.
    m = np.asarray(m)

    # Warn if the constant K is too large. The magnitude of K is what
    # matters: a negative K over-modulates just as a positive one does.
    if np.abs(K) * np.amax(np.abs(m)) > 1:
        warnings.warn("K must be such that |K * m(t)| <= 1")

    # Time vector: one instant per message sample, spaced by 1/fs
    t = np.arange(m.size) / fs

    # Compute the modulated signal
    return A * (1 + K * m) * np.cos(2 * np.pi * fc * t)


def amdemod(s, fc, fs):
    """
    Demodulate an AM signal with the moving average method.

    The signal is rectified (absolute value), smoothed with a moving
    average spanning about 10 carrier periods, made zero-mean and finally
    scaled so that its largest magnitude is 1.

    Because only fully overlapping windows are kept, the output is shorter
    than the input by (window length - 1) samples.

    :param s: the AM signal (1-D array-like)
    :param fc: the carrier frequency
    :param fs: the sampling frequency
    :return: the demodulated signal, normalized to have values between -1
             and 1 (all zeros when the recovered envelope is constant)
    :raises ValueError: if s contains no samples
    """
    # Take absolute value (rectification)
    s_abs = np.abs(np.asarray(s))
    if s_abs.size == 0:
        raise ValueError("s must contain at least one sample")

    # Moving average over 10 carrier periods. Keep the window within
    # [1, len(s)]: a zero-length window is meaningless, and a window longer
    # than the signal would make np.convolve silently swap its operands.
    k = int(np.round(10 * fs / fc))
    k = min(max(k, 1), s_abs.size)

    s_env = np.convolve(s_abs, np.ones(k) / k, mode='valid')
    s_zm = s_env - np.mean(s_env)  # making it zero-mean

    # Convert to be between -1 and 1; a flat envelope carries no message
    # and would otherwise produce 0/0 = NaN.
    peak = np.amax(np.abs(s_zm))
    if peak == 0:
        return np.zeros_like(s_zm)

    return s_zm / peak


def amdemod_withFilter(s, fc, fs, show_plots=True):
    """
    Demodulate an AM signal with the filtering method.

    The signal is rectified (absolute value) and low-pass filtered with a
    zero-phase Butterworth filter whose cutoff is half the carrier
    frequency. The first and last 25 samples are dropped to remove the
    filter transients, the result is made zero-mean and scaled so that its
    largest magnitude is 1.

    :param s: the AM signal (1-D array-like with more than 50 samples)
    :param fc: the carrier frequency
    :param fs: the sampling frequency
    :param show_plots: when True (the default), plot the impulse response
        and the frequency response of the filter; this calls plt.show().
        Pass False to demodulate without any plotting.
    :return: the demodulated signal, normalized to have values between -1
             and 1 (all zeros when the recovered envelope is constant)
    :raises ValueError: if s is too short to survive the transient trimming
    """

    N_ORDER = 2  # 8 for a better filtering, the demodulated signal will look better
    CUTOFF = fc / 2  # a good choice for AM will be around 5 KHz (AM radio is generally limited to that)
    N_TRANSIENT = 25  # samples dropped at each end to remove filter transients

    # Take absolute value (rectification)
    s_abs = np.abs(np.asarray(s))
    if s_abs.size <= 2 * N_TRANSIENT:
        raise ValueError(
            "s must contain more than %d samples" % (2 * N_TRANSIENT))

    # Filter with butterworth filter at half the carrier frequency
    w = CUTOFF / (fs / 2)  # Normalize the frequency to the Nyquist frequency
    b, a = signal.butter(N_ORDER, w, 'low')
    s_filt = signal.filtfilt(b, a, s_abs)

    if show_plots:
        # Visualize the impulse response
        impulseResponseFilter = signal.filtfilt(
            b, a, np.concatenate((np.ones(1), np.zeros(1000))))
        utilMDC.tfplot(impulseResponseFilter, fs, 'h', 'Butterworth Filter')

        # Visualize the frequency response. The filter is a digital design,
        # so freqz (not the analog freqs) must be used; with fs given, the
        # frequency axis is in Hz. The DC bin is skipped because 0 Hz
        # cannot be placed on a logarithmic axis.
        f1, h1 = signal.freqz(b, a, fs=fs)
        plt.semilogx(f1[1:], 20 * np.log10(np.abs(h1[1:])))
        plt.title('Butterworth filter frequency response')
        plt.xlabel('Frequency [Hz]')
        plt.ylabel('Amplitude [dB]')
        plt.margins(0, 0.1)
        plt.grid(which='both', axis='both')
        plt.axvline(CUTOFF, color='green')  # cutoff frequency
        plt.show()

    # Drop the first/last samples to remove the transients from the filter
    s_filt = s_filt[N_TRANSIENT:-N_TRANSIENT]

    # Remove the mean
    s_zm = s_filt - np.mean(s_filt)

    # Convert to be between -1 and 1. The peak is taken over every sample
    # so that no sample can exceed the range; a flat envelope carries no
    # message and would otherwise produce 0/0 = NaN.
    peak = np.amax(np.abs(s_zm))
    if peak == 0:
        return np.zeros_like(s_zm)

    return s_zm / peak


# =======================================
# AM round-trip test
# Modulates a test tone, demodulates it again, plots every stage and writes
# the original and the recovered message to WAV files.
# The alternative values in the trailing comments give a more audible
# signal (440 Hz tone).

# Signal parameters
Finfo = 10       # message signal frequency [Hz] (alternative: 440)
Fc = 300         # carrier frequency [Hz] (alternative: 40e3)
A = 1            # modulation constant A of ammod
K = 1            # modulation constant K of ammod
Fs = 4000        # sampling frequency [Hz] (alternative: 1000e3)
d = 1            # signal duration [s]
DownSample = 1   # decimation applied when writing the WAV files (alternative: 30)

# Time axis covering [0, d] with a step of 1/Fs
t = np.linspace(0, d, int(d * Fs + 1))

# Message tone and its AM-modulated version
m = 0.5 * np.cos(2 * np.pi * Finfo * t)
s = ammod(m, K, A, Fc, Fs)

# Show the message first, then the modulated signal
for _sig, _label, _title in (
    (m, 'm_{am}', 'Message signal'),
    (s, 's_{am}', 'AM modulated signal'),
):
    utilMDC.tfplot(_sig, Fs, _label, _title)

# Recover the message and show it as well
m_est = amdemod(s, Fc, Fs)
# m_est = amdemod_withFilter(s, Fc, Fs)
utilMDC.tfplot(m_est, Fs, 'm_{am} (est)', 'Recovered message signal')

# Announce and save each signal (playback is left disabled)
for _status, _wav_path, _samples in (
    ("Playing original signal\n", "originalSignal.wav", m),
    ("Playing recovered signal\n", "recoveredSignal.wav", m_est),
):
    print(_status)
    siow.write(_wav_path, Fs // DownSample, _samples[::DownSample])
    # playsound.playsound(_wav_path, block=False)

# =======================================
# Demodulate the recorded AM signal and save it as an audio file

# Downsampling factor as in the Hw statement
downsamplingFactor = 10

# Load the samples, fc and fs from file
data = scipy.io.loadmat('am_data.mat')
# The values come back from loadmat as arrays, so extract the single element
# explicitly: calling int()/float() on an array with ndim > 0 is deprecated
# since NumPy 1.25.
fs = int(data["fs"].item())
fc = float(data["fc"].item())
m_am = data["am_signal"].flatten()  # flatten the array

# Demodulate with the filtering method.
# Alternative (moving average method): m = amdemod(m_am, fc, fs)
# Only one of the two is needed; running both just discards the first result.
m = amdemod_withFilter(m_am, fc, fs)

# Save the downsampled track (playback is left disabled)
track_file = "track.wav"
siow.write(track_file, fs // downsamplingFactor, m[::downsamplingFactor])
# playsound.playsound(track_file, block=False)
