import numpy as np
import ofdmc as ofdmc
import warnings
from scipy import signal
import matplotlib.pyplot as plt
from fractions import Fraction
import utilPDC
import utilMDC
import utilOFDM
from scipy import fft
from scipy.signal import correlate as corr
import usrpc


def channelSimulator(tx_symbols, snr=30, cfo=284, clockOffset=3):
    """
    Simulates the effective channel seen between two USRP boards.

    The TX signal is repeated 10 times and impaired with a carrier frequency offset, a fixed
    3-tap multipath channel, additive white Gaussian noise and a fixed sampling-clock offset.
    A fixed number of samples is dropped from the beginning of the sequence and a window of
    length 4*length(TX_SIGNAL) is cut from the tail of the result.

    :param tx_symbols: symbols that form the TX signal (symbol rate OFDMC.T)
    :param snr: signal to noise ratio [dB] (default 30dB); ``float('inf')`` disables the noise.
    :param cfo: carrier frequency offset [Hz] (default 284 Hz, can be any value between -1kHz to 1kHz).
    :param clockOffset: sampling clock offset (in samples) between the transmitter and receiver boards
        (default 3 samples, can be any value from 0 to 9 samples, where we assume an upsampling factor
        of 10 with respect to OFDMC.T).
    :return: symbols that form the RX signal (symbol rate OFDMC.T).
    :raises ValueError: if SNR, CFO or CLOCKOFFSET is not a valid scalar, or if TX_SIGNAL is too
        short for the requested output window to fit in the simulated sequence.
    """

    resampleFactor = 10  # for introducing the fixed clock offset with granularity of 1/resampleFactor
    jitterValue = 0  # clock jitter. Keep to 0 for now.
    clockMismatchPPM = 0  # [ppm]: clock drift between Tx/Rx (min value 22, for implementation issues)
    droppedSamples = 1234  # we drop some samples from the beginning of the sequence
    noRepData = 10  # number of times the TX signal is repeated

    tx_symbols = np.asarray(tx_symbols)
    if tx_symbols.ndim > 1:
        warnings.warn('ofdm:channelSimulator: TX_SIGNAL is not a vector, reshaping it!\n', RuntimeWarning)
        tx_symbols = tx_symbols.flatten('F')

    # NOTE: the checks below use short-circuiting `or` on purpose. The bitwise `|` binds
    # tighter than the comparison operators, which silently changes what is being tested.
    if isinstance(snr, (list, tuple)) or np.size(snr) != 1:
        raise ValueError('ofdm:channelSimulator', 'SNR is not scalar.')

    if (isinstance(cfo, (list, tuple)) or np.size(cfo) != 1 or np.iscomplexobj(cfo)
            or not -1e3 <= cfo <= 1e3):
        raise ValueError('ofdm:channelSimulator', 'CFO must be a real scalar between -1kHz to 1kHz')

    if (isinstance(clockOffset, (list, tuple)) or np.size(clockOffset) != 1
            or np.iscomplexobj(clockOffset) or clockOffset % 1 != 0
            or not 0 <= clockOffset < resampleFactor):
        raise ValueError('ofdm:channelSimulator',
                         f'CLOCKOFFSET must be an integer scalar between 0 and {resampleFactor - 1}')

    # Normalize the scalars: the clock offset is used as an array index below, so an
    # integral float such as 3.0 must become a true integer.
    snr = float(np.asarray(snr).item())
    cfo = float(np.asarray(cfo).item())
    clockOffset = int(np.asarray(clockOffset).item())

    tooShort = ValueError('ofdm:channelSimulator',
                          'TX_SIGNAL is too short: the output window does not fit in the simulated sequence')

    lengthDataOut = 4 * tx_symbols.size

    # repeat the data several times
    dataTx = np.tile(tx_symbols, noRepData)

    # add CFO
    dataTx = dataTx * np.exp(1j * 2 * np.pi * cfo * np.arange(dataTx.size) * ofdmc.T)

    # multipath channel
    h = np.array([0.646648866062013, 0.413731206902691, 0.192749918053592])

    if h.size > ofdmc.cpLength + 1:
        warnings.warn('ofdm:impulseResponseTooLong: '
                      'The channel impulse response is larger than the cyclic prefix length => '
                      'it will create ISI\n', RuntimeWarning)

    # convolve the signal with channel impulse response
    dataTx = np.convolve(dataTx, h)

    # Add noise
    if snr < float('inf'):
        Es = np.var(dataTx)
        # find the noise variance sigma2 so that 10*log_10(Es/sigma2) = SNR_dB
        sigma2 = Es / (10 ** (snr / 10))
        sigma = np.sqrt(sigma2)
        # create the sample-level noise vector (half of the power on each of I and Q)
        noise = (sigma / np.sqrt(2)) * np.random.randn(np.size(dataTx)) + \
                1j * (sigma / np.sqrt(2)) * np.random.randn(np.size(dataTx))
        actual_SNR = 10 * np.log10(np.var(dataTx) / np.var(noise))
        print(f'The obtained SNR is: {actual_SNR} [dB]\n')  # test that we get back SNR
        # create the channel output
        dataTx = dataTx + noise

    # drop some samples
    dataTx = dataTx[droppedSamples:]
    if dataTx.size == 0:
        raise tooShort

    # resample and introduce some (fixed) clock offset / jitter
    if clockOffset > 0:
        puResampled = signal.resample(dataTx, resampleFactor * len(dataTx))

        print('channelSimulator: we are offsetting the sampling clock...\n')

        # pick every resampleFactor-th sample, starting at the requested offset
        samplingMoments = np.arange(clockOffset, puResampled.size - resampleFactor, resampleFactor)
        # introduce some jitter (all zeros while jitterValue == 0)
        jitterV = np.random.randint(-jitterValue, jitterValue + 1, size=samplingMoments.size)

        puReRe = puResampled[samplingMoments + jitterV]
    else:
        puReRe = dataTx

    if clockMismatchPPM > 0:
        # approximate (1 + ppm*1e-6) by a ratio of integers usable by the polyphase resampler
        tmpRat = Fraction.from_float(1 + clockMismatchPPM * 1e-6).limit_denominator(50000)
        upFactor = tmpRat.numerator
        downFactor = tmpRat.denominator
        print('The nominator/denominator for the clock mismatch:', upFactor, downFactor, '\n')
        puReRe = signal.resample_poly(puReRe, upFactor, downFactor)

    dataRx = puReRe

    # select the data we output: a window of lengthDataOut samples ending half a TX
    # signal before the end of the sequence
    stopData = dataRx.size - int(np.floor(tx_symbols.size / 2))
    startData = stopData - lengthDataOut
    if startData < 0:
        # a negative start would wrap around and silently return the wrong samples
        raise tooShort
    dataRx1 = dataRx[startData:stopData]

    if ofdmc.verbose:
        # plot the magnitude of the Rx data
        plt.grid()
        plt.plot(np.arange(dataRx.size), abs(dataRx))
        plt.title('Rx data: vertical bars show selected data')
        plt.xlabel('Sample Index')
        plt.ylabel('Magnitude')
        plt.axvline(x=startData, color='r')
        plt.axvline(x=stopData, color='r')
        plt.show()

        # plot the spectrum of the Rx data; fftfreq guarantees exactly one frequency
        # point per FFT bin (np.arange with a float step may return one point too many)
        nSamples = dataRx.size
        freqLineo = np.fft.fftshift(np.fft.fftfreq(nSamples, d=ofdmc.T))
        plt.plot(freqLineo, np.fft.fftshift(np.abs(np.fft.fft(dataRx, nSamples))))
        plt.grid()
        plt.xlabel('f [Hz]')
        plt.ylabel('Magnitude')
        plt.title('OFDM Rx Spectrum')
        plt.show()

    rx_symbols = dataRx1

    return rx_symbols


def transmitter(data_bits):
    """
    Converts data bits to TX signal, using the parameters defined in OFDMC.
    To this end, the following steps are taken:

    1. Data bits are mapped to constellation symbols.

    2. Using the function OFDM_TX_FRAME_WITH_PILOTS data symbols are
    converted to time-domain OFDM symbols with pilot symbols on the
    selected carriers (and desired PSD mask).
    The resulting time-domain signal is scaled to be normalized
    between -1 and +1, so as to avoid clipping by the USRP board.

    3. The samples of time-domain data signal are prepended with a
    PREAMBLE that helps the receiver to detect the start of
    transmission and correct the CFO. (The PREAMBLE is normalized between -1 and +1.)

    We also insert some "zero"s before data signal. The receiver
    uses these zeros to estimate the noise variance (needed for MMSE
    estimation of channel coefficients).

    :param data_bits: array-like of 0/1 values. If it holds fewer than OFDMC.nDataBits bits it is
        padded with random bits; if it holds more, it is truncated (with a warning).
    :return: the TX signal: OFDMC.preamble, followed by OFDMC.nZeros zeros, followed by the
        peak-normalized OFDM frame.
    :raises ValueError: if DATA_BITS contains values other than 0 and 1, or if
        OFDMC.constellationType is neither 'qam' nor 'psk'.
    """

    # accept any array-like (e.g. a plain list of bits)
    data_bits = np.asarray(data_bits)

    if data_bits.ndim > 1:
        warnings.warn('ofdm:reshaping: DATA_BITS is not a vector, reshaping it', RuntimeWarning)
        data_bits = data_bits.flatten()

    if not np.logical_or(data_bits == 0, data_bits == 1).all():
        raise ValueError('ofdm:invalidInput', 'DATA_BITS does not contain bits')

    if data_bits.size < ofdmc.nDataBits:
        # pad with random bits (rather than zeros) up to the frame capacity
        randomPaddingBits = np.random.randint(2, size=ofdmc.nDataBits - data_bits.size)
        data_bits = np.concatenate((data_bits, randomPaddingBits))
    elif data_bits.size > ofdmc.nDataBits:
        warnings.warn(f'ofdmc:inputTooLong: DATA_BITS contains more than {ofdmc.nDataBits} bits. '
                      f'Truncating it!', RuntimeWarning)
        data_bits = data_bits[:ofdmc.nDataBits]

    # 1. Modulate the data bits to constellation symbols
    bitsPerSymbol = int(np.log2(ofdmc.M))
    tx_bits = np.reshape(data_bits, (-1, bitsPerSymbol))
    decimal_symbols = utilPDC.sol_bi2de(tx_bits)

    # get the mapping
    constellation = ofdmc.constellationType.lower()
    if constellation == 'qam':
        mapping = utilPDC.sol_qamMap(ofdmc.M)
    elif constellation == 'psk':
        mapping = utilPDC.sol_pskMap(ofdmc.M)
    else:
        raise ValueError('ofdm:invalidModulation',
                         f'Unsupported modulation type {ofdmc.constellationType}')

    # normalize constellation to unit average energy - not really necessary
    mapping = mapping / np.sqrt(np.mean(np.abs(mapping)**2))

    # modulator
    data_symbols = utilPDC.sol_encoder(decimal_symbols, mapping)

    # 2. Convert data symbols to time-domain OFDM symbols
    tx_signal = ofdm_tx_frame_with_pilots(ofdmc.nCarriers, ofdmc.psdMask, ofdmc.cpLength, ofdmc.trainingSymbols,
                                          ofdmc.pilotIndices, ofdmc.pilotSymbols, data_symbols)

    # make sure the signal is scaled between +1 and -1
    # (an all-zero frame is left untouched instead of being turned into NaNs)
    peak = np.max(np.abs(tx_signal))
    if peak > 0:
        tx_signal = tx_signal / peak

    # 3. Prepend the preamble for symbol synchronization and CFO estimation.
    # Insert zeros for noise estimation.
    tx_signal = np.concatenate((ofdmc.preamble, np.zeros(ofdmc.nZeros, dtype=complex), tx_signal))

    # NOTE. ofdmc.preamble is normalized such that to fit between -1 and +1

    if ofdmc.verbose:
        # plot the frequency and time domain TX signals
        utilMDC.tfplotReImPhase(tx_signal, 1/ofdmc.T, 's', 'TX Signal')

    return tx_signal


def receiver(rx_signal):
    """
    Estimates the sent bits from RX signal.
    The receiver proceeds through the following steps:
    1. The first step is to find the start of data symbols and carrier frequency offset (CFO),
    by using the synchronization sequence (PREAMBLE defined in OFDMC); and correct the CFO.

    2. Next, the time-domain OFDM symbols are converted to frequency domain using OFDM_RX_FRAME function.

    3. Channel coefficients are estimated (in frequency domain) using CHANNEL_EST LS estimator.
    Then, the symbols are equalized.

    4. The residual rotation in OFDM symbols (due to CFO) is corrected with the help of pilot carriers.

    5. After removing the training sequence and the pilot carriers,
    the data bits are estimated using a minimum-distance decoder.

    :param rx_signal: The samples of the received signal.
    :return: the estimated bits, as returned by utilPDC.sol_de2bi for the decoded symbols.
    :raises ValueError: if no complete OFDM frame fits inside RX_SIGNAL at the estimated position,
        or if OFDMC.constellationType is neither 'qam' nor 'psk'.
    """

    rx_signal = np.asarray(rx_signal)

    # 1. Find the starting time of data, and estimate the CFO
    tau_estim, cfo_estim = estimateTauAndCFO(rx_signal, ofdmc.preamble, ofdmc.T, ofdmc.rangeCFO)

    # 1.1. Correct for CFO
    rx_signal = rx_signal * np.exp(-1j * 2 * np.pi * cfo_estim * np.arange(rx_signal.size) * ofdmc.T)

    # 1.2. Skip the synchro and zero sequence
    tau_estim = tau_estim + ofdmc.preamble.size + ofdmc.nZeros

    # 1.3. Compute the OFDM data length (including the training OFDM symbols)
    ofdmDataLength = ofdmc.nOFDMSymbols * (ofdmc.nCarriers + ofdmc.cpLength)

    # We might have several peaks in xcorr() because we repeat the data.
    # To be sure we do not get out of the Rx data, we back up.
    # The slice below ends (exclusively) at tau_estim + ofdmDataLength, so the frame
    # fits only if that value does not exceed rx_signal.size.
    if tau_estim + ofdmDataLength > rx_signal.size:
        tau_estim = tau_estim - ofdmc.txSignalLength
        if ofdmc.verbose:
            print('We have backed up to the previous peak of xcorr()...\n')

    # A negative start would wrap around and silently select the wrong samples.
    if tau_estim < 0 or tau_estim + ofdmDataLength > rx_signal.size:
        raise ValueError('ofdm:receiver', 'RX_SIGNAL does not contain a complete OFDM frame')

    # 1.4. Extract OFDM symbols
    rx_data = rx_signal[tau_estim:tau_estim + ofdmDataLength]

    if ofdmc.verbose:
        utilMDC.tfplotReImPhase(rx_data, 1/ofdmc.T, 'r', 'RX Signal (Data Part)')

    # NOTE: the ofdmc.nZeros samples right before tau_estim carry only noise. The LS channel
    # estimator used below does not need the noise variance, so it is not estimated here.

    # 2. Convert OFDM symbols to frequency domain
    Y = utilOFDM.ofdm_rx_frame(rx_data, ofdmc.nCarriers, ofdmc.psdMask, ofdmc.cpLength)

    # 3. Estimate the channel and equalize the symbols
    # 3.1. Channel estimation
    lambda_channel = utilOFDM.channel_estLS(Y, ofdmc.trainingSymbols)

    if ofdmc.verbose:
        # plot the estimated channel response on the full carrier grid
        lambda_aug = np.zeros(ofdmc.nCarriers, dtype=complex)
        lambda_aug[ofdmc.psdMask > 0] = lambda_channel
        carrierLine = np.arange(-ofdmc.nCarriers/2, ofdmc.nCarriers/2)
        plt.plot(carrierLine, np.abs(np.fft.fftshift(lambda_aug)), 's-', label='Magnitude')
        plt.grid()
        plt.plot(carrierLine, np.angle(np.fft.fftshift(lambda_aug)), '*-', label='Phase')
        plt.legend()
        plt.title('Estimated channel (LS)')
        plt.xlabel('Carrier index')
        plt.ylabel('Magnitude and Phase')
        plt.show()

    # 3.2. Equalization: one complex gain per carrier (row), applied to every OFDM symbol (column)
    Y_eq = Y * (1 / lambda_channel)[:, np.newaxis]

    # 4. Correct the rotation due to the residual CFO
    # Put the active carriers back on the full carrier grid first.
    Y_aug = np.zeros((ofdmc.nCarriers, ofdmc.nOFDMSymbols), dtype=complex)
    Y_aug[ofdmc.psdMask > 0, :] = Y_eq

    Y_aug = correctOFDMSymbolRotation(Y_aug, ofdmc.pilotIndices, ofdmc.pilotSymbols)

    # 5. Estimate the data bits
    # 5.1. Remove the training sequence
    Y_aug = Y_aug[:, 1:]
    # 5.2. Remove the pilot symbols and masked carriers
    Y_data = Y_aug[ofdmc.dataMask, :]
    # 5.3. Serialize the Rx data symbols
    rx_data_symbols = Y_data.flatten('F')

    if ofdmc.verbose:
        plt.scatter(rx_data_symbols.real, rx_data_symbols.imag, marker='*', color='b')
        plt.grid()
        plt.title('RX Constellation')
        plt.show()

    # Get the mapping
    constellation = ofdmc.constellationType.lower()
    if constellation == 'qam':
        mapping = utilPDC.sol_qamMap(ofdmc.M)
    elif constellation == 'psk':
        mapping = utilPDC.sol_pskMap(ofdmc.M)
    else:
        raise ValueError('ofdm:invalidModulation',
                         f'Unsupported modulation type {ofdmc.constellationType}')

    # Normalize constellation to unit average energy - not really necessary
    mapping = mapping / np.sqrt(np.mean(np.abs(mapping)**2))

    est_symbols = utilPDC.sol_decoder(rx_data_symbols, mapping)
    est_bits = utilPDC.sol_de2bi(est_symbols)

    return est_bits


def ofdm_tx_frame_with_pilots(num_carriers, psd_mask, prefix_length,
                                  training_symbols, pilot_indices, pilot_symbols, data_symbols):
    """
    Generates an OFDM frame.

    The first OFDM block carries the training symbols on every active carrier. Each subsequent
    block carries the pilot symbols on the pilot carriers and data symbols on the remaining active
    carriers. Every block is converted to the time domain with ``np.fft.ifft`` (numpy's default
    1/N scaling), prefixed with its cyclic prefix, and the blocks are concatenated.

    :param num_carriers: number of carriers per OFDM block (power of 2)
    :param psd_mask: A {0,1}-valued vector of length NUM_CARRIERS, used to turn off individual carriers.
    :param prefix_length: cyclic-prefix length (in number of samples).
    :param training_symbols: vector of symbols known to the receiver, used to estimate the channel.
        Its length is the number of ones in the PSD_MASK (one training symbol per non-off carrier).
    :param pilot_indices: indices of pilot symbols. NOTE(review): taken here as 0-based positions on
        the full NUM_CARRIERS grid (a boolean mask of that length is accepted too) - confirm against
        the OFDMC configuration.
    :param pilot_symbols: vector of pilot symbols (one per pilot index).
    :param data_symbols: vector of symbols to transmit (it will be padded with zeros if the number of data
        symbols is not a multiple of the number of useful carriers, where the number of useful carriers
         is the number of ones in the PSD_MASK minus the number of pilot symbols).
    :return: tx_symbols: A 1-D array containing the generated OFDM symbols, corresponding to one OFDM
        frame with the training symbols transmitted during the first OFDM block and the data transmitted
        in the subsequent OFDM blocks.
    :raises ValueError: if the argument sizes are inconsistent, a pilot sits on a switched-off
        carrier, or there is data to send but no carrier left to carry it.
    """

    active = np.asarray(psd_mask).ravel() > 0
    if active.size != num_carriers:
        raise ValueError('ofdm:invalidInput', 'PSD_MASK must have NUM_CARRIERS elements')

    if not 0 <= prefix_length <= num_carriers:
        raise ValueError('ofdm:invalidInput', 'PREFIX_LENGTH must be between 0 and NUM_CARRIERS')
    prefix_length = int(prefix_length)

    training = np.asarray(training_symbols, dtype=complex).ravel()
    if training.size != np.count_nonzero(active):
        raise ValueError('ofdm:invalidInput',
                         'TRAINING_SYMBOLS must have one element per active carrier of PSD_MASK')

    pilot_idx = np.asarray(pilot_indices).ravel()
    if pilot_idx.dtype == bool:
        pilot_idx = np.flatnonzero(pilot_idx)
    pilot_idx = pilot_idx.astype(int)
    pilots = np.asarray(pilot_symbols, dtype=complex).ravel()
    if pilots.size != pilot_idx.size:
        raise ValueError('ofdm:invalidInput', 'PILOT_INDICES and PILOT_SYMBOLS must have the same length')
    if not active[pilot_idx].all():
        raise ValueError('ofdm:invalidInput', 'Pilot carriers must be enabled in PSD_MASK')

    # carriers that are switched on and not reserved for pilots carry the data
    data_mask = active.copy()
    data_mask[pilot_idx] = False
    n_data_carriers = np.count_nonzero(data_mask)

    data = np.asarray(data_symbols, dtype=complex).ravel()
    if data.size > 0 and n_data_carriers == 0:
        raise ValueError('ofdm:invalidInput', 'No carrier is available for the data symbols')

    # number of data blocks (ceiling division), padding the last block with zeros
    n_blocks = -(-data.size // n_data_carriers) if n_data_carriers else 0
    padded = np.concatenate((data, np.zeros(n_blocks * n_data_carriers - data.size, dtype=complex)))

    # frequency-domain frame: one column per OFDM block, one row per carrier
    freq = np.zeros((num_carriers, n_blocks + 1), dtype=complex)
    freq[active, 0] = training
    # column-major fill: consecutive data symbols go on consecutive carriers of one block
    freq[data_mask, 1:] = padded.reshape((n_data_carriers, n_blocks), order='F')
    freq[pilot_idx, 1:] = pilots[:, np.newaxis]

    # time-domain blocks, each preceded by a copy of its last PREFIX_LENGTH samples
    time_blocks = np.fft.ifft(freq, axis=0)
    cyclic_prefix = time_blocks[num_carriers - prefix_length:, :]
    with_prefix = np.concatenate((cyclic_prefix, time_blocks), axis=0)

    # serialize block after block
    tx_symbols = with_prefix.flatten('F')

    return tx_symbols


def estimateTauAndCFO(rx_symbols, preamble, T, cfo_range):
    """
    Estimates the position of the preamble in the received signal and the carrier frequency
    offset, by correlating the received signal with the preamble.

    For every candidate CFO the preamble is rotated by that offset and cross-correlated with
    RX_SYMBOLS. The (candidate, lag) pair with the largest correlation magnitude gives the coarse
    estimate. The CFO is then refined by fitting a parabola through the correlation magnitudes
    of the best candidate and its two neighbours (at the estimated lag) and taking its vertex.

    :param rx_symbols: the vector of received symbols.
    :param preamble: the preamble sequence.
    :param T: the sampling period of data.
    :param cfo_range: vector of candidates CFO values, used to search for a coarse CFO estimate.
        The coarse estimate is then refined using a second order approximation.
    :return: (tau_estim, cfo_estim): the index in RX_SYMBOLS at which the preamble starts (the
        caller skips the preamble and the zeros to reach the data symbols) and the estimate of
        the carrier frequency offset.
    :raises ValueError: if PREAMBLE is empty or longer than RX_SYMBOLS, or CFO_RANGE is empty.
    """

    rx = np.asarray(rx_symbols, dtype=complex).ravel()
    ref = np.asarray(preamble, dtype=complex).ravel()
    # sorted, so that neighbours in the array are neighbours in frequency
    candidates = np.unique(np.asarray(cfo_range, dtype=float).ravel())

    if ref.size == 0 or rx.size < ref.size:
        raise ValueError('ofdm:estimateTauAndCFO', 'RX_SYMBOLS must be at least as long as PREAMBLE')
    if candidates.size == 0:
        raise ValueError('ofdm:estimateTauAndCFO', 'CFO_RANGE must contain at least one candidate')

    sample_index = np.arange(ref.size)

    def rotated_preamble(freq):
        # the preamble as it would look after a carrier frequency offset of `freq` Hz
        return ref * np.exp(1j * 2 * np.pi * freq * sample_index * T)

    # coarse search: best lag and peak magnitude for every candidate offset
    peak_lags = np.empty(candidates.size, dtype=int)
    peak_values = np.empty(candidates.size)
    for i, freq in enumerate(candidates):
        magnitude = np.abs(corr(rx, rotated_preamble(freq), mode='valid'))
        lag = int(np.argmax(magnitude))
        peak_lags[i] = lag
        peak_values[i] = magnitude[lag]

    best = int(np.argmax(peak_values))
    tau_estim = int(peak_lags[best])
    cfo_estim = float(candidates[best])

    # fine search: vertex of the parabola through the three points around the peak
    if 0 < best < candidates.size - 1:
        window = rx[tau_estim:tau_estim + ref.size]
        f_lo, f_mid, f_hi = candidates[best - 1:best + 2]
        y_lo, y_mid, y_hi = (np.abs(np.vdot(rotated_preamble(f), window)) for f in (f_lo, f_mid, f_hi))

        numerator = (f_mid - f_lo) ** 2 * (y_mid - y_hi) - (f_mid - f_hi) ** 2 * (y_mid - y_lo)
        denominator = (f_mid - f_lo) * (y_mid - y_hi) - (f_mid - f_hi) * (y_mid - y_lo)
        # a non-positive denominator means the three points do not form a peak: keep the coarse value
        if denominator > 0:
            refined = f_mid - 0.5 * numerator / denominator
            cfo_estim = float(min(max(refined, f_lo), f_hi))

    return tau_estim, cfo_estim


def correctOFDMSymbolRotation(Y, pilot_indices, pilot_symbols):
    """
    Uses the pilot symbols to estimate and remove the block-dependent phase rotation
    in the OFDM symbols (after the DFT).

    For every column, the received pilots are projected onto the known pilot symbols; the angle
    of that projection is the estimate of the common rotation, which is then undone on the whole
    column.

    :param Y: a matrix in which, except for the rotation, the columns are noisy OFDM blocks (after DFT).
        Each element of a column is expected to be rotated by the same phase.
    :param pilot_indices: indices of pilot symbols, i.e. the rows of Y that carry the pilots
        (a boolean row mask is accepted too).
    :param pilot_symbols: vector of pilot symbols.
    :return: Y_corr: the matrix Y corrected for the rotations.
    :raises ValueError: if Y is not a matrix or PILOT_INDICES and PILOT_SYMBOLS differ in length.
    """

    Y = np.asarray(Y)
    if Y.ndim != 2:
        raise ValueError('ofdm:invalidInput', 'Y must be a matrix with one OFDM block per column')

    pilot_rows = np.asarray(pilot_indices).ravel()
    if pilot_rows.dtype == bool:
        pilot_rows = np.flatnonzero(pilot_rows)
    pilot_rows = pilot_rows.astype(int)
    known_pilots = np.asarray(pilot_symbols, dtype=complex).ravel()
    if pilot_rows.size != known_pilots.size:
        raise ValueError('ofdm:invalidInput', 'PILOT_INDICES and PILOT_SYMBOLS must have the same length')

    # sum over the pilots of conj(known) * received, one value per column; summing before
    # taking the angle weights every pilot by its energy and averages out the noise
    projection = np.conj(known_pilots) @ Y[pilot_rows, :]
    rotation = np.angle(projection)

    # undo the estimated rotation of every column
    Y_corr = Y * np.exp(-1j * rotation)

    return Y_corr
