import numpy as np


def noise_norm(txSymbols, sufficientStatistics):
    """
    Returns the accumulated squared error between the decision statistics and
    the transmitted symbols.

    The residual SUFFICIENTSTATISTICS - TXSYMBOLS is formed element-wise, and the
    squared magnitudes of all its entries are summed into one real number
    (the squared Frobenius norm when the inputs are matrices).

    :param txSymbols: the transmitted symbols.
    :param sufficientStatistics: the equalizer output to compare against TXSYMBOLS.
    :return: sum of |sufficientStatistics - txSymbols|**2 over all entries.
    """
    squared_error = np.abs(sufficientStatistics - txSymbols) ** 2
    return np.sum(squared_error)


def mimo_channel(*args):
    """
    Passes TX symbols through a random flat-fading MIMO channel and adds AWGN.

    If X is a matrix, the same channel matrix is applied to each column of X.
    So, Y will be a matrix with RXANTENNAS rows and the same number of columns as X.
    This is useful for simulating a channel that remains constant over more than
    one symbol period. A 1-D X is treated as a single column vector and gives a
    1-D Y of length RXANTENNAS.

    The channel matrix is generated randomly each time the function is called.
    It consists of i.i.d. circularly-symmetric zero-mean unit-variance complex Gaussian entries.

    The noise variance sigma2 is chosen so that 10*log10(Es / sigma2) = SNR, where
    Es is the average energy of the entries of X (measured from X itself).

    :param args[0]: the TX symbols (column vector or matrix). Its number of rows
        determines the number of TX antennas.
    :param args[1]: the signal to noise ratio in dB (must be a scalar).
    :param args[2]: the number of RX antennas, a positive integer (default = 2).
    :return: (Y, H). Y contains the received symbols and has RXANTENNAS rows;
        H is the RXANTENNAS x TXANTENNAS channel matrix.
    :raises ValueError: if RXANTENNAS is not a positive integer or SNR is not a scalar.
    """
    rxAntennas = args[2] if len(args) >= 3 else 2
    x = np.asarray(args[0])
    SNR = args[1]

    # Boolean ``or`` (not bitwise ``|``): ``|`` binds tighter than ``<=``, which
    # would turn the whole test into ``(... | rxAntennas) <= 0``.
    if (np.ndim(rxAntennas) != 0 or not np.isfinite(rxAntennas)
            or np.floor(rxAntennas) != rxAntennas or rxAntennas <= 0):
        raise ValueError('mimoChannel:invalidInput: RXANTENNAS must be a positive integer')
    # Integral floats such as 2.0 are accepted; randn needs a true int.
    rxAntennas = int(rxAntennas)

    if np.ndim(SNR) != 0:
        raise ValueError('mimoChannel:invalidInput: SNR must be a scalar')

    txAntennas = x.shape[0]

    # Channel matrix: real and imaginary parts each have variance 1/2,
    # so every entry has unit variance.
    H = np.random.randn(rxAntennas, txAntennas) / np.sqrt(2) + \
        1j * np.random.randn(rxAntennas, txAntennas) / np.sqrt(2)

    # Noiseless received signal.
    y = np.matmul(H, x)

    # Average symbol energy; should be 1 if the constellation has been normalized.
    Es = np.mean(np.abs(x.flatten('F'))**2)
    # Noise variance sigma2 such that 10*log10(Es/sigma2) = SNR (dB).
    sigma2 = Es / (10 ** (SNR / 10))
    sigma = np.sqrt(sigma2)
    # Circularly-symmetric complex Gaussian noise with total variance sigma2,
    # shaped like y so both vector and matrix inputs are supported.
    noise = (sigma / np.sqrt(2)) * np.random.randn(*y.shape) + \
            1j * (sigma / np.sqrt(2)) * np.random.randn(*y.shape)

    y_noise = y + noise

    return y_noise, H


def lmmse_equalizer(rx_symbols, H, sigma2):
    """
    Implements a linear MMSE equalizer for the MIMO channel. It computes the
    sufficient statistic for decision according to the linear MMSE rule

        SUFF_STAT = (H^H H + sigma2 I)^{-1} H^H RX_SYMBOLS,

    which assumes uncorrelated, unit-energy TX symbols (a normalized
    constellation) and noise of variance sigma2.

    :param rx_symbols: is either a column vector or a matrix, and contains the received symbols.
    If RX_SYMBOLS is a matrix, the function works on each column separately.
    :param H: is the channel matrix (RX antennas x TX antennas). It has to have the same
    number of rows as the number of rows of RX_SYMBOLS.
    :param sigma2: The noise variance.
    :return: SUFF_STAT: is the equalized signal, used for decision. It has one row per
    TX antenna (H.shape[1]) and the same number of columns as RX_SYMBOLS, so it has
    the same size as RX_SYMBOLS when H is square.
    :raises numpy.linalg.LinAlgError: if H^H H + sigma2 I is singular (only possible
    when sigma2 == 0 and H has linearly dependent columns).
    """
    rx_symbols = np.asarray(rx_symbols)
    H = np.asarray(H)
    H_herm = H.conj().T
    regularized_gram = H_herm @ H + sigma2 * np.eye(H.shape[1])
    # Solve the linear system instead of forming an explicit inverse:
    # cheaper and numerically better conditioned. Works column-wise for matrices.
    suff_stat = np.linalg.solve(regularized_gram, H_herm @ rx_symbols)

    return suff_stat


def zeroForcing_equalizer(rx_symbols, H):
    """
    Implements a zero-forcing equalizer for MIMO channel. It computes the
    sufficient statistic for decision by inverting the channel matrix H.

    The Moore-Penrose pseudo-inverse is used, which equals inv(H) for a square
    invertible H and gives the least-squares zero-forcing solution
    (H^H H)^{-1} H^H RX_SYMBOLS when H is tall (more RX than TX antennas).

    :param rx_symbols: is either a column vector or a matrix, and contains the received symbols.
    If RX_SYMBOLS is a matrix, the function works on each column separately.
    :param H: is the channel matrix (RX antennas x TX antennas). It has to have the same
    number of rows as the number of rows of RX_SYMBOLS.
    :return: SUFF_STAT: is the equalized signal, used for decision. It has one row per
    TX antenna (H.shape[1]) and the same number of columns as RX_SYMBOLS, so it has
    the same size as RX_SYMBOLS when H is square.
    """
    H_pinv = np.linalg.pinv(np.asarray(H))
    suff_stat = H_pinv @ np.asarray(rx_symbols)

    return suff_stat
