import numpy as np
from numpy import linalg as LA
from scipy import fft
from scipy.signal import upfirdn
from scipy.signal import fftconvolve
import matplotlib.pyplot as plt
import warnings


def ofdm_tx_frame(num_carriers, prefix_length, training_symbols, psd_mask, data_symbols):
    """
    Generates an OFDM frame.

    :param num_carriers: number of carriers per OFDM block (power of 2).
        Any positive integer-valued scalar is accepted.
    :param prefix_length: cyclic-prefix length (in number of samples). May be
        0 (no prefix). A prefix longer than NUM_CARRIERS is built by cyclically
        repeating the block.
    :param training_symbols: vector of symbols known to the receiver,
        used to estimate the channel. Its length is the number of ones
        in the PSD_MASK (one training symbol per non-off carrier).
    :param psd_mask: A {0,1}-valued vector of length NUM_CARRIERS, used
        to turn off individual carriers.
    :param data_symbols: vector of symbols to transmit (it will be padded
        with zeros if the number of data symbols is not a multiple of the
        number of useful carriers)

    :return: tx_symbols: A 1-D complex array containing the generated OFDM samples,
        corresponding to one OFDM frame with the training symbols transmitted
        during the first OFDM block and the data transmitted in the subsequent
        OFDM blocks. Each block is NUM_CARRIERS + PREFIX_LENGTH samples long.

    :raises ValueError: on invalid scalar parameters, a PSD_MASK of the wrong
        length / values / with no active carrier, or a wrong number of
        training symbols.
    """

    # Use boolean `or`, not `|`: `|` binds tighter than `<`/`!=`, which would
    # turn these tests into chained comparisons that never fire.
    if np.ndim(num_carriers) != 0 or num_carriers <= 0 or num_carriers % 1 != 0:
        raise ValueError('ofdm_tx_frame:dimensionMismatch', 'NUM_CARRIERS must be a positive scalar integer')

    if np.ndim(prefix_length) != 0 or prefix_length < 0 or prefix_length % 1 != 0:
        raise ValueError('ofdm_tx_frame:dimensionMismatch', 'PREFIX_LENGTH must be a non-negative scalar integer')

    # Integer-valued floats (e.g. 64.0) are accepted; indexing needs true ints.
    num_carriers = int(num_carriers)
    prefix_length = int(prefix_length)

    psd_mask = np.asarray(psd_mask)
    if psd_mask.ndim > 1 or psd_mask.size != num_carriers:
        raise ValueError('ofdm_tx_frame:dimensionMismatch',
                         'PSD_MASK must be a vector of length %d' % num_carriers)

    if not np.logical_or(psd_mask == 0, psd_mask == 1).all():
        raise ValueError('ofdm_tx_frame:invalidMask', 'PSD_MASK must be {0,1}-valued')

    psd_mask = psd_mask.reshape(-1) > 0  # 1-D boolean row selector

    # Determine the number of useful carriers
    num_useful_carriers = int(np.count_nonzero(psd_mask))
    if num_useful_carriers == 0:
        # Otherwise the block count below would divide by zero.
        raise ValueError('ofdm_tx_frame:invalidMask', 'PSD_MASK must enable at least one carrier')

    training_symbols = np.ravel(training_symbols)
    if training_symbols.size != num_useful_carriers:
        raise ValueError('ofdm_tx_frame:dimensionsMismatch',
                         'PREAMBLE must contain exactly %d symbols' % num_useful_carriers)

    data_symbols = np.ravel(data_symbols)
    num_data_symbols = data_symbols.size
    # ceil(num_data_symbols / num_useful_carriers) in exact integer arithmetic
    num_data_blocks = (num_data_symbols + num_useful_carriers - 1) // num_useful_carriers
    # number of zero data symbols to add to obtain an integer number of OFDM symbols
    num_symbols_padding = num_data_blocks * num_useful_carriers - num_data_symbols
    data_symbols = np.concatenate((data_symbols, np.zeros(num_symbols_padding, dtype=complex)))

    # One extra OFDM block carries the preamble (to aid channel estimation)
    num_ofdm_symbols = num_data_blocks + 1

    # Build matrix: the first column is for the preamble, and the following for the data
    B = np.zeros((num_useful_carriers, num_ofdm_symbols), dtype=complex)
    B[:, 0] = training_symbols
    B[:, 1:] = np.reshape(data_symbols, (num_useful_carriers, num_data_blocks), order='F')

    # Map B to A: zeros on masked-off carriers, rows of B on active carriers
    A = np.zeros((num_carriers, num_ofdm_symbols), dtype=complex)
    A[psd_mask, :] = B

    # IFFT of every column (one OFDM block per column)
    a0 = fft.ifft(A, num_carriers, axis=0)

    # Cyclic prefix = the last PREFIX_LENGTH samples of each block. Taking the row
    # indices modulo NUM_CARRIERS gives an empty prefix for prefix_length == 0 and
    # a proper cyclic extension when prefix_length > num_carriers.
    prefix_rows = np.arange(num_carriers - prefix_length, num_carriers) % num_carriers
    a = np.concatenate((a0[prefix_rows, :], a0), axis=0)

    # Serialize column by column, so that the output is a vector
    tx_symbols = a.flatten('F')

    return tx_symbols


def create_multipath_channel_filter(amplitudes, delays, L):
    """
    Return the sampled impulse response of a multipath channel.

    Both the shaping pulse and the reconstruction filter are taken to be
    sincs, so the overall response h_overall = p*h*q is a weighted sum of
    shifted sincs, one per path:

        h[i] = sum_l amplitudes[l] * sinc(t_i - delays[l]),  t_i = -L, ..., L

    :param amplitudes: vector of path strengths (same length as DELAYS)
    :param delays: vector of path delays, relative to the sampling period
    :param L: length of each sinc tail; samples are taken at
        np.arange(-L, L + 1), so h[i] corresponds to time t = -L + i
    :return: h: the samples of the filtered impulse response
    :raises ValueError: if AMPLITUDES and DELAYS differ in size
    """
    if amplitudes.size != delays.size:
        raise ValueError('create_multipath_channel_filter:wrongInputDimensions', 'AMPLITUDES and DELAYS must be vectors of the same length')

    sample_times = np.arange(-L, L + 1, 1)
    # offsets[i, l] = sample_times[i] - delays[l]  (rows: time, columns: path)
    offsets = sample_times[:, np.newaxis] - delays
    return np.dot(np.sinc(offsets), amplitudes)


def ofdm_rx_frame(rx_symbols, num_carriers, psd_mask, prefix_length):
    """
    Receiver for OFDM signals (without channel equalization).

    :param rx_symbols: vector of channel outputs. It is a filtered and noisy
        version of the transmitted symbols.
    :param num_carriers: number of carriers per OFDM block (power of 2). The
        same as in OFDM_TX_FRAME.
    :param psd_mask: A {0,1}-valued vector of length NUM_CARRIERS. The same
        as in OFDM_TX_FRAME.
    :param prefix_length: cyclic-prefix length (in number of samples). The
        same as in OFDM_TX_FRAME. May be 0.

    :return: RF: Matrix of TRAINING_SYMBOLS and DATA_SYMBOLS (see OFDM_TX_FRAME)
        as seen at the output of the equivalent "parallel channels" created by
        OFDM. To obtain the RF matrix, we start by rearranging the received
        symbols columnwise in a matrix that has (NUM_CARRIERS + PREFIX_LENGTH)
        rows. In doing so, we remove the tail of RX_SYMBOLS that don't fill an
        entire column. Then we remove the cyclic prefix and take the FFT.
        Finally, the symbols that correspond to unused carriers are removed.
        Shape: (number of ones in PSD_MASK, number of complete OFDM blocks).

    :raises ValueError: on invalid scalar parameters or an invalid PSD_MASK.
    """

    # Use boolean `or`, not `|`: `|` binds tighter than `<`/`!=`, which would
    # turn these tests into chained comparisons that never fire.
    if np.ndim(num_carriers) != 0 or num_carriers <= 0 or num_carriers % 1 != 0:
        raise ValueError('ofdm_rx_frame:dimensionMismatch', 'NUM_CARRIERS must be a positive scalar integer')

    if np.ndim(prefix_length) != 0 or prefix_length < 0 or prefix_length % 1 != 0:
        raise ValueError('ofdm_rx_frame:dimensionMismatch', 'PREFIX_LENGTH must be a non-negative scalar integer')

    num_carriers = int(num_carriers)
    prefix_length = int(prefix_length)

    psd_mask = np.asarray(psd_mask)
    if psd_mask.ndim > 1 or psd_mask.size != num_carriers:
        raise ValueError('ofdm_rx_frame:dimensionMismatch',
                         'PSD_MASK must be a vector of length %d' % num_carriers)

    if not np.logical_or(psd_mask == 0, psd_mask == 1).all():
        raise ValueError('ofdm_rx_frame:invalidMask', 'PSD_MASK must be {0,1}-valued')

    psd_mask = psd_mask.reshape(-1) > 0  # 1-D boolean row selector

    # Trim the received data to an integer number of OFDM blocks
    rx_symbols = np.ravel(rx_symbols)
    block_length = num_carriers + prefix_length
    num_ofdm_symbols = rx_symbols.size // block_length
    rx_symbols = rx_symbols[:num_ofdm_symbols * block_length]

    # One OFDM block (prefix + payload) per column
    rx_with_cp = np.reshape(rx_symbols, (block_length, num_ofdm_symbols), order='F')
    # Drop the rows corresponding to the cyclic prefix
    rx_no_cp = rx_with_cp[prefix_length:, :]

    # Go to the frequency domain (FFT of every column)
    Rf = fft.fft(rx_no_cp, num_carriers, axis=0)

    # Remove the off-carrier symbols
    return Rf[psd_mask, :]


def channel_est1(num_carriers, psd_mask, h):
    """
    Takes the channel impulse response h and returns the
    channel coefficients in the frequency domain.

    :param num_carriers: number of carriers per OFDM block (FFT/IFFT size)
    :param psd_mask: The PSD Mask used by the receiver.
        A {0,1}-valued vector of length NUM_CARRIERS used to
        turn off individual carriers if necessary (e.g., so as to
        avoid interference with other systems)
    :param h: Channel impulse response (flattened to 1-D; zero-padded or
        truncated to NUM_CARRIERS samples by the FFT).

    :return: LAMBDA_CHANNEL: 1-D vector containing channel coefficients in the
        frequency domain. The number of elements of LAMBDA equals the
        number of ones in PSD_MASK. (We do not care about channel gains
        on the carriers that are turned off.)

    :raises ValueError: on an invalid NUM_CARRIERS or PSD_MASK.
    """

    # Use boolean `or`, not `|`: `|` binds tighter than `<`/`!=`, which would
    # turn these tests into chained comparisons that never fire.
    if np.ndim(num_carriers) != 0 or num_carriers <= 0 or num_carriers % 1 != 0:
        raise ValueError('channel_est1:dimensionMismatch', 'NUM_CARRIERS must be a positive scalar integer')

    num_carriers = int(num_carriers)

    psd_mask = np.asarray(psd_mask)
    if psd_mask.ndim > 1 or psd_mask.size != num_carriers:
        raise ValueError('channel_est1:dimensionMismatch',
                         'PSD_MASK must be a vector of length %d' % num_carriers)

    if not np.logical_or(psd_mask == 0, psd_mask == 1).all():
        raise ValueError('channel_est1:invalidMask', 'PSD_MASK must be {0,1}-valued')

    psd_mask = psd_mask.reshape(-1) > 0  # 1-D boolean selector

    # N-point DFT of h; ravel so a column-vector h is transformed along its samples
    lambda_channel = fft.fft(np.ravel(h), num_carriers)
    # Keep only the active subcarriers
    return lambda_channel[psd_mask]


def channel_estLS(Rf, training_symbols):
    """
    Per-carrier least-squares channel estimate from the preamble block.

    No prior channel information is used: on each active carrier the channel
    coefficient is estimated as the received preamble value divided by the
    transmitted training symbol.

    :param Rf: Matrix of TRAINING_SYMBOLS and DATA_SYMBOLS, as returned by
        OFDM_RX_FRAME; its first column holds the received preamble.
    :param training_symbols: vector of symbols known to the receiver, used to
        estimate the channel (one per active carrier).

    :return: LAMBDA_CHANNEL: Vector containing the channel coefficients in the
        frequency domain.
    """
    received_preamble = Rf[:, 0]
    return received_preamble / training_symbols


def channel_estMMSE(Rf, N, psd_mask, training_symbols, Ka, delays, sigma2):
    """
    Estimate the channel coefficients in the frequency domain.
    The channel delays are known. The covariance matrix of channel
    amplitudes is also assumed to be known.

    Model (notation as in the lecture notes): the received preamble is
    y = S d + z with S = diag(training_symbols), d = A a, where
    A[k, l] = exp(-j 2 pi f_k delays[l] / N) and f_k is the signed frequency
    index of the k-th active carrier. The linear MMSE estimate is
    d_hat = K_d S^H (S K_d S^H + sigma2 I)^{-1} y, with K_d = A Ka A^H.

    :param Rf: Matrix of TRAINING_SYMBOLS and DATA_SYMBOLS, as returned by OFDM_RX_FRAME.
    :param N: number of carriers per OFDM block (power of 2; any positive integer works).
    :param psd_mask: A {0,1}-valued vector of length NUM_CARRIERS, used to turn off individual carriers.
    :param training_symbols: vector of symbols known to the receiver, used
        to estimate the channel. Its length is the number of ones in the
        PSD_MASK (one training symbol per non-off carrier).
    :param Ka: the covariance matrix of channel amplitudes (a scalar is
        accepted for a single path).
    :param delays: vector containing the delays for each path in the multipath channel.
        The delays are expressed in number of samples, i.e., tau_l/T_s.
    :param sigma2: noise variance.
    :return: LAMBDA_CHANNEL: 1-D vector containing the channel coefficients in the
        frequency domain. The number of elements in LAMBDA equals the number of ones in PSD_MASK.
    :raises ValueError: on an invalid N, PSD_MASK, DELAYS or training-symbol vector.
    """

    # Use boolean `or`, not `|`: `|` binds tighter than `<`/`!=`, which would
    # turn these tests into chained comparisons that never fire.
    if np.ndim(N) != 0 or N <= 0 or N % 1 != 0:
        raise ValueError('channel_estMMSE:dimensionMismatch', 'N must be a positive scalar integer')

    N = int(N)

    psd_mask = np.asarray(psd_mask)
    if psd_mask.ndim > 1 or psd_mask.size != N:
        raise ValueError('channel_estMMSE:dimensionMismatch', 'PSD_MASK must be a vector of length %d' % N)

    if not np.logical_or(psd_mask == 0, psd_mask == 1).all():
        raise ValueError('channel_estMMSE:invalidMask', 'PSD_MASK must be {0,1}-valued')

    delays = np.asarray(delays)
    if delays.ndim > 1:
        raise ValueError('channel_estMMSE:dimensionMismatch', 'DELAYS must be a vector')
    delays = delays.reshape(-1)

    psd_mask = psd_mask.reshape(-1) > 0  # 1-D boolean selector
    num_useful_carriers = int(np.count_nonzero(psd_mask))

    training_symbols = np.asarray(training_symbols)
    if training_symbols.ndim > 1 or training_symbols.size != num_useful_carriers:
        raise ValueError('channel_estMMSE:dimensionMismatch',
                         'PREAMBLE_SYMBOLS must be a vector of length %d' % num_useful_carriers)
    s = training_symbols.reshape(-1)  # diagonal of S

    y = Rf[:, 0]  # channel output due to training_symbols

    # Signed frequency index of each FFT bin: [0, 1, ..., ceil(N/2)-1, -floor(N/2), ..., -1].
    # Exact integers, and correct for odd N as well as even N.
    all_freq_indices = np.fft.ifftshift(np.arange(N) - N // 2)
    freq_indices = all_freq_indices[psd_mask]
    A = np.exp(-2j * np.pi / N * np.outer(freq_indices, delays))

    # Covariance of the frequency-domain channel d = A a
    Kd = A @ np.atleast_2d(Ka) @ A.conj().T
    # S is diagonal, so K S^H scales column j by conj(s_j) and S K S^H also scales row i by s_i.
    Kdy = Kd * s.conj()[np.newaxis, :]
    Ky = s[:, np.newaxis] * Kdy + sigma2 * np.eye(num_useful_carriers)

    # lstsq (not solve) so that a singular Ky, e.g. with sigma2 == 0, yields the
    # minimum-norm solution instead of raising.
    tmp = np.linalg.lstsq(Ky, y, rcond=None)[0]
    lambda_channel = Kdy @ tmp

    return lambda_channel
