import numpy as np
import matplotlib.pyplot as plt


def tfplot(s, fs, name, plottitle):
    """
    Displays a figure window with two subplots.  Above, the signal S is plotted in time domain;
    below, the magnitude of its spectrum is plotted in frequency domain.

    The spectrum is computed with an FFT zero-padded to the next power of two >= len(s),
    scaled by the sampling period Ts = 1/fs (so it approximates the continuous-time
    Fourier transform), and shifted so that the frequency axis runs from -fs/2 to fs/2.

    :param s: signal to be plotted (sequence of samples; assumed 1-D)
    :param fs: sampling frequency in Hz; must be a positive scalar
    :param name: NAME is the "name" of the signal, e.g., if NAME is 's', then the labels on the y-axes will be
        's(t)' and '|s_F(f)|', respectively.
    :param plottitle: the title that will appear above the two plots.
    :raises TypeError: if fs is not a scalar (e.g. a list, tuple or array)
    :raises ValueError: if fs is not positive or s contains no samples
    :return: None
    """
    # Validate inputs up front so failures are explicit rather than surfacing
    # as division-by-zero warnings or obscure FFT errors further down.
    if np.ndim(fs) != 0:
        raise TypeError('Fs must be scalar')
    if not fs > 0:
        raise ValueError('Fs must be positive')
    s = np.asarray(s)
    n_samples = s.size
    if n_samples == 0:
        raise ValueError('s must contain at least one sample')

    # Time axis: sample k is taken at t = k * Ts
    Ts = 1 / fs
    t = np.linspace(0, (n_samples - 1) * Ts, n_samples)

    # Zero-pad to the next power of two for computational efficiency.
    # Integer bit arithmetic avoids floating-point rounding in ceil(log2(n)).
    nfft = 1 << (n_samples - 1).bit_length()

    # Compute the FFT, scale by Ts to approximate the continuous-time FT, and
    # use fftshift to move the negative frequencies to the left.
    s_f = Ts * np.fft.fftshift(np.fft.fft(s, nfft))
    # Frequency axis matching the shifted bins: -fs/2, ..., fs/2 - fs/nfft
    # (also correct for nfft == 1, where the only bin is f = 0).
    f = np.fft.fftshift(np.fft.fftfreq(nfft, d=Ts))

    fig, axs = plt.subplots(2, 1, constrained_layout=True)
    fig.suptitle(plottitle, fontsize=16)

    axs[0].plot(t, s)
    axs[0].set(xlabel='t [s]', ylabel=name + '(t)')

    axs[1].plot(f, np.abs(s_f))
    axs[1].set(xlabel='f [Hz]', ylabel='|' + name + '_F(f)|')

    plt.show()


def tfplotReImPhase(s, fs, name, plottitle):
    """
    Displays a figure window with four subplots.
    In the two plots above, the signal S is plotted in time domain (real and imaginary parts);
    below, its spectrum is plotted in frequency domain (magnitude and phase).

    The spectrum is computed with an FFT zero-padded to the next power of two >= len(s),
    scaled by the sampling period Ts = 1/fs (so it approximates the continuous-time
    Fourier transform), and shifted so that the frequency axis runs from -fs/2 to fs/2.

    :param s: signal to be plotted (real or complex sequence of samples; assumed 1-D)
    :param fs: sampling frequency in Hz; must be a positive scalar
    :param name: NAME is the "name" of the signal, e.g., if NAME is 's', then the labels on the y-axes will be
        'Re[s(t)]', 'Im[s(t)]', '|s_F(f)|' and 'Angle [s_F(f)]', respectively.
    :param plottitle: the title that will appear above the four plots.
    :raises TypeError: if fs is not a scalar (e.g. a list, tuple or array)
    :raises ValueError: if fs is not positive or s contains no samples
    :return: None
    """
    # Validate inputs up front so failures are explicit rather than surfacing
    # as division-by-zero warnings or obscure FFT errors further down.
    if np.ndim(fs) != 0:
        raise TypeError('Fs must be scalar')
    if not fs > 0:
        raise ValueError('Fs must be positive')
    # asarray lets plain Python lists work too (they have no .real/.imag).
    s = np.asarray(s)
    n_samples = s.size
    if n_samples == 0:
        raise ValueError('s must contain at least one sample')

    # Time axis: sample k is taken at t = k * Ts
    Ts = 1 / fs
    t = np.linspace(0, (n_samples - 1) * Ts, n_samples)

    # Zero-pad to the next power of two for computational efficiency.
    # Integer bit arithmetic avoids floating-point rounding in ceil(log2(n)).
    nfft = 1 << (n_samples - 1).bit_length()

    # Compute the FFT, scale by Ts to approximate the continuous-time FT, and
    # use fftshift to move the negative frequencies to the left.
    s_f = Ts * np.fft.fftshift(np.fft.fft(s, nfft))
    # Frequency axis matching the shifted bins: -fs/2, ..., fs/2 - fs/nfft
    # (also correct for nfft == 1, where the only bin is f = 0).
    f = np.fft.fftshift(np.fft.fftfreq(nfft, d=Ts))

    fig, axs = plt.subplots(2, 2, constrained_layout=True)
    fig.suptitle(plottitle, fontsize=16)

    axs[0, 0].plot(t, np.real(s))
    axs[0, 0].set(xlabel='t [s]', ylabel='Re[' + name + '(t)]')

    axs[0, 1].plot(t, np.imag(s))
    axs[0, 1].set(xlabel='t [s]', ylabel='Im[' + name + '(t)]')

    axs[1, 0].plot(f, np.abs(s_f))
    axs[1, 0].set(xlabel='f [Hz]', ylabel='|' + name + '_F(f)|')

    # np.angle returns the phase in radians
    axs[1, 1].plot(f, np.angle(s_f))
    axs[1, 1].set(xlabel='f [Hz]', ylabel='Angle [' + name + '_F(f)]')

    plt.show()
