import numpy as np
from numpy import linalg as LA
from scipy.signal import upfirdn
from scipy.signal import fftconvolve
import matplotlib.pyplot as plt
import warnings


def sol_qamMap(M):
    """
    Build the square M-QAM constellation on the odd-integer grid.

    Real and imaginary parts take the values -(sqrt(M)-1), ..., -1, 1, ..., sqrt(M)-1.
    The points are listed column by column: the first point is the top-left
    corner (smallest real part, largest imaginary part), the imaginary part
    decreases inside a column, and columns are visited by increasing real part.

    :param M: constellation size, of the form M = 2^(2K) with K a positive integer
    :return: 1-D complex array containing the M constellation points
    :raises ValueError: if M is not of the form 2^(2K) with K >= 1
    """

    # The explicit lower bound is required: for M = 0 the log2 test compares
    # -inf with -inf and would let the value through, and M = 1 means K = 0.
    if M < 4 or np.log2(np.sqrt(M)) != np.fix(np.log2(np.sqrt(M))):
        raise ValueError('sol_qamMap:invalidSize', 'M must be in the form of M = 2^(2K), where K is a positive integer.')

    # Odd-integer amplitude levels used on both axes
    side = int(round(np.sqrt(M)))
    levels = np.arange(-(side - 1), side, 2)

    # Real part grows from left to right, imaginary part decreases from top to bottom
    real_part, imag_part = np.meshgrid(levels, np.flip(levels))
    grid = real_part + 1j * imag_part

    # Stack the columns on top of each other to obtain a vector
    return grid.flatten('F')


def sol_pskMap(M):
    """
    Build the M-PSK constellation on the unit circle.

    Point k (k = 0, ..., M-1) equals exp(j*2*pi*k/M), so the first point is 1
    and the phase grows counter-clockwise with the index.

    :param M: constellation size, of the form M = 2^K with K a positive integer
    :return: 1-D complex array containing the M constellation points
    :raises ValueError: if M is not of the form 2^K with K >= 1
    """

    # The explicit lower bound is required: for M = 0 the log2 test compares
    # -inf with -inf and would let the value through, and M = 1 means K = 0.
    if M < 2 or np.log2(M) != np.fix(np.log2(M)):
        raise ValueError('sol_pskMap:invalidSize', 'M must be in the form of M = 2^K, where K is a positive integer.')

    return np.exp(1j * 2 * np.pi * np.arange(M) / M)


def sol_encoder(x, mapping):
    """
    Map a vector of symbol indices onto constellation points.

    :param x: 1-D array of integer-valued indices in the range [0, M - 1],
              where M is the number of points in MAPPING
    :param mapping: array of constellation points, indexed by the symbol index
    :return: array y with y[i] = mapping[x[i]]
    :raises TypeError: if x has more than one dimension
    :raises ValueError: if an element of x is not an integer in [0, M - 1]
    """

    # Verify you do not have a matrix
    if x.ndim > 1:
        raise TypeError('sol_encoder:wrongInputValues', 'X should be either a row or column vector, not a matrix!')

    # The admissible index range is fixed by the constellation size,
    # not by the number of symbols to encode
    mapping = np.asarray(mapping)
    M = mapping.size

    # Verify that x contains only INTEGER elements between 0 and M - 1
    if not np.all((x >= 0) & (x < M) & (x == np.fix(x))):
        raise ValueError('sol_encoder:wrongInputValues', 'Elements of input X must be integers in the range [0, %d]' % (M - 1))

    return mapping[x.astype(int)]


def sol_decoder(y, mapping):
    """
    Minimum-distance decoder.

    Every observation in Y is compared with every point of the constellation
    MAPPING, and the index of the closest point (Euclidean distance) is kept.

    :param y: vector of observations
    :param mapping: constellation points
    :return: vector z where z[i] is the index in MAPPING of the point closest to y[i]
    :raises TypeError: if y has more than one dimension
    """

    # Make sure y is a vector
    if y.ndim > 1:
        raise TypeError('sol_decoder:inputNotVector', 'Input Y must be a vector')

    # Broadcasting a row of observations against a column of constellation
    # points yields the table of all pairwise differences:
    # one row per constellation point, one column per observation
    observations_row = y.reshape(1, -1)
    constellation_col = mapping.reshape(-1, 1)
    dist_table = np.abs(observations_row - constellation_col)

    # For each observation (column), pick the row holding the smallest distance
    return dist_table.argmin(axis=0)


def sol_rcosdesign(beta, span, sps):
    """
    Design a root-raised-cosine pulse with unit norm.

    The pulse is sampled at SPS samples per symbol over SPAN symbol periods,
    centred on t = 0, which gives SPAN * SPS + 1 samples.

    :param beta: roll-off factor; beta = 0 yields a sinc pulse
    :param span: length of the pulse in symbol periods
    :param sps: number of samples per symbol period
    :return: 1-D real array with the samples of the pulse, scaled to unit Euclidean norm
    :raises ValueError: if span * sps is odd
    """

    # Verify that span * sps is even
    if (span * sps) % 2 != 0:
        raise ValueError('sol_rcosdesign:wrongInputValues', 'Product of SPS and SPAN must be even.')

    # Time axis in units of the symbol period, symmetric around zero
    half_len = (span * sps) // 2
    t = np.arange(-half_len, half_len + 1) / sps

    if beta == 0:
        y = np.sinc(t)
    else:
        # The builtin float is used as dtype: the np.float alias no longer exists in recent NumPy
        y = np.zeros(t.shape, dtype=float)

        # The closed-form expression has a zero denominator at |t| = 1 / (4 * beta).
        # Exact equality tests are unreliable on floats, hence isclose.
        singular = np.isclose(np.abs(t), 0.25 / beta)
        regular = ~singular

        # Limit value of the pulse at the singular instants
        y[singular] = (beta / np.sqrt(2)) * ((1 + 2/np.pi)*np.sin(np.pi/(4*beta)) + (1 - 2/np.pi)*np.cos(np.pi/(4*beta)))

        # General expression everywhere else; np.sinc takes care of t = 0
        t_reg = t[regular]
        scaling_factor = (4 * beta) / np.pi
        numerator = np.cos((1 + beta) * np.pi * t_reg) + (((1 - beta) * np.pi) / (4 * beta)) * np.sinc((1 - beta) * t_reg)
        denominator = 1 - np.square(4 * beta * t_reg)
        y[regular] = scaling_factor * (numerator / denominator)

    # normalize to have unit norm
    return y / LA.norm(y)


def sol_symbols2samples(y, h, USF):
    """
    Turn a sequence of symbols into a sampled waveform.

    The symbols are upsampled by the factor USF (USF - 1 zeros are inserted
    after every symbol, including the last one) and the result is filtered
    with the shaping pulse H.

    :param y: vector of symbols
    :param h: samples of the shaping pulse
    :param USF: upsampling factor (number of samples per symbol)
    :return: full convolution of the upsampled symbols with h
    """

    # A single 1 followed by USF - 1 zeros: the Kronecker product with this
    # pattern spreads the symbols apart and also appends the trailing zeros
    zero_stuffing = np.zeros(USF - 1)
    impulse = np.hstack((np.ones(1), zero_stuffing))
    upsampled = np.kron(y, impulse)

    # Shape the impulse train with the pulse h
    return np.convolve(upsampled, h)


def sol_sufficientStatistics(r, h, USF):
    """
    Compute the matched-filter outputs of a received signal.

    The received samples are filtered with the filter matched to H (the
    conjugated, time-reversed pulse) and the output is then downsampled by USF.

    :param r: vector of received samples
    :param h: samples of the shaping pulse used at the transmitter
    :param USF: upsampling factor (number of samples per symbol)
    :return: tuple (x, y) where y is the matched-filter output with the
             len(h) - 1 leading and trailing correlation results removed,
             and x contains every USF-th sample of y
    """

    # For a real and symmetric pulse the matched filter equals h itself
    h_matched = np.conj(np.flip(h))

    # Matched filter output before downsampling
    mf_full = np.convolve(r, h_matched)
    # mf_full = fftconvolve(r, h_matched)  # use this one from scipy, it is faster in principle

    # The matched filter output is the correlation between the received signal
    # and h. The first useful sample is the len(h)-th correlation result, and
    # the last len(h) - 1 results are unused. The end index is written
    # explicitly because a negative stop of -(len(h) - 1) turns into 0 for a
    # single-tap pulse and would produce an empty slice.
    edge = np.size(h) - 1
    y = mf_full[edge:mf_full.size - edge]

    # Keep one sample per symbol
    x = y[::USF]

    return x, y


def sol_bi2de_old(b):
    """
    Convert binary vectors to decimal numbers (MSB-first version).

    D = SOL_BI2DE_OLD(B) converts a binary vector B to a decimal value D. When
    B is a matrix, the conversion is performed row-wise and the output D is a
    vector of decimal values. The orientation of the binary input is
    Right-LSB: the first element in B represents the most significant bit.

    :param b: binary vector, or matrix with one binary word per row
    :return: decimal value (vector input) or vector of decimal values (matrix input), as floats
    :raises ValueError: if b contains values other than 0 and 1
    """

    b = np.atleast_1d(b)

    # Test that the input contains only ones and zeros
    if not np.logical_or(b == 0, b == 1).all():
        raise ValueError('sol_bi2de:nonBinaryValuesInInput', 'Binary matrix B can only contain values 1 and 0')

    # The word length is the size of the last axis, which covers both a
    # single vector and a matrix with one word per row
    n_bits = b.shape[-1]
    weights = np.power(2, np.arange(n_bits - 1, -1, -1).astype(float))  # decreasing powers of 2

    # Multiply and sum
    return np.dot(b, weights)


def sol_bi2de(*args):
    """
    Convert binary vectors to decimal numbers.

    D = SOL_BI2DE(B) converts a binary vector B to a decimal value D. When B is a matrix,
    the conversion is performed row-wise and the output D is a vector of decimal values.
    The default orientation of the binary input is Right-MSB: the first element in B represents
    the least significant bit.

    In addition to the input matrix, an optional parameter MSBFLAG can be given:

    D = SOL_BI2DE(B, MSBFLAG) uses MSBFLAG to determine the input orientation.
    MSBFLAG has two possible values, 'right-msb' and 'left-msb' (case-insensitive).
    Giving a 'right-msb' MSBFLAG does not change the function's default behavior.
    Giving a 'left-msb' MSBFLAG flips the input orientation such that the MSB is on the left.

    :param args: B, MSBFLAG
    :return: D: decimal value (vector input) or vector of decimal values (matrix input), as floats
    :raises ValueError: if B contains values other than 0 and 1, or if MSBFLAG is not supported
    """

    b = np.atleast_1d(args[0])

    # Set defaults
    msbflag = args[1] if len(args) > 1 else 'right-msb'

    # Test that the input contains only ones and zeros
    if not np.logical_or(b == 0, b == 1).all():
        raise ValueError('sol_bi2de:nonBinaryValuesInInput', 'Binary matrix B can only contain values 1 and 0')

    # The word length is the size of the last axis, which covers both a
    # single vector and a matrix with one word per row
    n_bits = b.shape[-1]
    weights = np.power(2, np.arange(0, n_bits).astype(float))  # increasing powers of 2

    # Reverse the weights (equivalent to flipping the input) if the MSB is on the left
    orientation = msbflag.lower()
    if orientation == 'left-msb':
        weights = np.flip(weights)
    elif orientation != 'right-msb':
        raise ValueError('sol_bi2de:wrongMSBflag', 'Unsupported value for MSBFLAG')

    # Multiply and sum
    return np.dot(b, weights)


def sol_de2bi_old(d):
    """
    Convert decimal numbers to binary digits (MSB-first version).

    B = SOL_DE2BI_OLD(D) converts the non-negative integers in D to binary.
    Every element of D is expanded into the same number of bits, namely the
    number of bits needed for the largest element (at least one). The
    orientation is Right-LSB, i.e., the first bit of every word is the most
    significant one. If D is a matrix, it is first converted to a vector
    (column-wise).

    The words are returned concatenated in a single 1-D vector (word of D[0]
    first), not as a matrix with one row per element.

    :param d: non-negative integer values; floats holding integer values are accepted
    :return: 1-D integer array of zeros and ones
    :raises ValueError: if d contains negative, fractional, NaN or infinite values
    """

    # Convert d to a vector if necessary
    d = np.asarray(d).flatten('F')

    # Make sure input is nonnegative decimal
    if np.any((d < 0) | (d != np.fix(d)) | np.isnan(d) | np.isinf(d)):
        raise ValueError('sol_de2bi:invalidInput', 'Input must contain only finite positive integers.')

    # Bitwise operations are only defined on integer types; the values were
    # just verified to be integers, so the cast is lossless
    if not np.issubdtype(d.dtype, np.integer):
        d = d.astype(np.int64)

    # Number of bits to use: enough for the largest value, and at least one.
    # bit_length is exact and, unlike log2, well defined for zero.
    largest = int(d.max()) if d.size else 0
    n_bits = max(1, largest.bit_length())

    # bitand approach
    # ---------------
    # Suppose that d=11 is the number that we want to convert into binary.
    # With 4 bits, the bitwise AND of [d d d d] with [8 4 2 1] is [8 0 2 1].
    # The nonzero positions are the ones where the bit is 1, hence the
    # binary representation 1 0 1 1 (MSB first).
    shifts = np.arange(n_bits - 1, -1, -1)  # decreasing: MSB comes first
    masks = 2 ** shifts

    # One row per element of d, one column per bit
    bits = (np.bitwise_and(d[:, np.newaxis], masks[np.newaxis, :]) != 0).astype(int)

    # Flatten row-wise: the words follow each other in the order of d
    return bits.flatten('C')


def sol_de2bi(*args):
    """
    Convert decimal numbers to binary digits.

    B = SOL_DE2BI(D) converts the non-negative integers in D from base 10 to binary.
    Every element of D is expanded into the same number of bits. The default orientation
    of the binary output is Right-MSB, i.e., the first bit of a word is the least significant one.
    If D is a matrix rather than a row or column vector, it is first converted to a vector (column-wise).

    The words are returned concatenated in a single 1-D vector (word of D[0] first),
    not as a matrix with one row per element.

    In addition to the input vector D, two optional parameters can be given:

    B = SOL_DE2BI(D,MSBFLAG) uses MSBFLAG to determine the output orientation.
    MSBFLAG has two possible values, 'right-msb' and 'left-msb' (case-insensitive).
    Giving a 'right-msb' MSBFLAG does not change the function's default behavior.
    Giving a 'left-msb' MSBFLAG flips the output orientation to put the MSB first in every word.

    B = SOL_DE2BI(D,MSBFLAG,N) uses N to define how many binary digits are output per element.
    The number of bits must be large enough to represent the largest number in D. When N is
    omitted, the smallest sufficient number of bits (at least one) is used.

    :param args: D,MSBFLAG,N (floats holding integer values are accepted in D)
    :return: B: 1-D integer array of zeros and ones
    :raises ValueError: if D contains negative, fractional, NaN or infinite values,
                        if N is too small, or if MSBFLAG is not supported
    """

    # Convert d to a vector if necessary
    d = np.asarray(args[0]).flatten('F')

    msbflag = args[1] if len(args) > 1 else 'right-msb'

    # Make sure input is nonnegative decimal. This is done first so that the
    # bit count below is only ever computed from valid data.
    if np.any((d < 0) | (d != np.fix(d)) | np.isnan(d) | np.isinf(d)):
        raise ValueError('sol_de2bi:invalidInput', 'Input must contain only finite positive integers.')

    # Bitwise operations are only defined on integer types; the values were
    # just verified to be integers, so the cast is lossless
    if not np.issubdtype(d.dtype, np.integer):
        d = d.astype(np.int64)

    # Minimum number of bits: enough for the largest value, and at least one.
    # bit_length is exact and, unlike log2, well defined for zero.
    largest = int(d.max()) if d.size else 0
    nmax = max(1, largest.bit_length())

    if len(args) < 3:
        n = nmax
    elif nmax > args[2]:
        raise ValueError('sol_de2bi:insufficientNbits',
                         'Specified number of bits is too small to represent some of the decimal inputs,'
                         '\n at least %d bits are required' % nmax)
    else:
        n = args[2]

    # Bit positions in output order; increasing means LSB first (right MSB)
    shifts = np.arange(0, n, 1).astype(int)

    orientation = msbflag.lower()
    if orientation == 'left-msb':
        shifts = np.flip(shifts)
    elif orientation != 'right-msb':
        raise ValueError('sol_de2bi:wrongMSBflag', 'Unsupported value for MSBFLAG')

    # bitand approach
    # ---------------
    # Suppose that d=11 is the number that we want to convert into binary.
    # With 4 bits, the bitwise AND of [d d d d] with [1 2 4 8] is [1 2 0 8].
    # The nonzero positions are the ones where the bit is 1, hence the
    # binary representation 1 1 0 1 (LSB first).
    masks = 2 ** shifts

    # One row per element of d, one column per bit
    bits = (np.bitwise_and(d[:, np.newaxis], masks[np.newaxis, :]) != 0).astype(int)

    # Flatten row-wise: the words follow each other in the order of d
    return bits.flatten('C')


def sol_eyediagram(y, Fs, T):
    """
    Generate an eye diagram.

    SOL_EYEDIAGRAM(Y, FS, T) plots the eye diagram corresponding to the sampled
    matched filter output Y, assuming the sampling frequency is FS and the symbol
    duration is T. The real part of Y is cut into consecutive traces spanning two
    symbol durations each, and all traces are drawn on top of each other on the
    current matplotlib axes. Samples that do not fill a complete trace are dropped.
    The figure is not shown; call plt.show() afterwards if needed.

    :param y: sampled matched filter output
    :param Fs: sampling frequency, positive scalar
    :param T: symbol duration, positive scalar
    :return: None
    :raises ValueError: if Fs or T is not a positive scalar
    """

    # Zero is rejected as well: Fs = 0 makes the sampling step 1/Fs undefined
    # and T = 0 yields traces without any sample
    if np.ndim(Fs) != 0 or Fs <= 0:
        raise ValueError('Fs must be a positive scalar (sampling frequency)')

    if np.ndim(T) != 0 or T <= 0:
        raise ValueError('T must be a positive scalar (symbol duration)')

    # cast y to real, as it should be
    y = np.asarray(y).real

    # Time axis of one trace: two symbol durations centred on zero
    time = np.arange(0, 2*T, 1/Fs) - T

    # remove extra samples (if any)
    y = y[0: y.size - (y.size % time.size)]

    # One trace per column, filled column by column
    y = np.reshape(y, (time.size, y.size // time.size), order='F')

    plt.plot(time, y)
    plt.ylabel('Amplitude')
    plt.xlabel('Time [s]')
    plt.grid()
    # plt.show()


def channel(samples, SNR_dB, Es, max_delay):
    """
    Simulate an AWGN channel with a random propagation delay.

    A delay drawn uniformly from {0, ..., max_delay} is applied by prepending that
    many zeros to SAMPLES. Complex Gaussian noise is then added to every sample
    (including the leading zeros), with a total variance sigma2 chosen such that
    10*log10(Es / sigma2) = SNR_dB, split equally between real and imaginary parts.

    The random numbers come from the global NumPy generator (np.random), and the
    SNR measured on the generated noise is printed to standard output.

    :param samples: transmitted samples; anything that is not a vector is flattened
                    column-wise after a warning
    :param SNR_dB: target signal-to-noise ratio in dB, real scalar
    :param Es: symbol energy, real non-negative scalar (Es = 0 gives a noiseless channel)
    :param max_delay: largest possible delay, in samples
    :return: tuple (rx_signal, delay): the complex channel output and the applied delay
    :raises ValueError: if SNR_dB or Es is not a real scalar, or if Es is negative
    """

    if samples.ndim > 1:
        warnings.warn('pdc:channel --> TX_SIGNAL is not a vector, reshaping it!')
        samples = samples.flatten('F')

    if (not np.isscalar(SNR_dB)) or (not np.isreal(SNR_dB)):
        raise ValueError('pdc:channel --> SNR must be a real scalar')

    if (not np.isscalar(Es)) or (not np.isreal(Es)) or (Es < 0):
        raise ValueError('pdc:channel --> Symbol energy (Es) must be a real positive scalar')

    # add the delay
    delay = np.random.randint(max_delay + 1)
    samples = np.concatenate((np.zeros(delay), samples), axis=None)

    # find the noise variance sigma2 so that 10*log_10(Es/sigma2) = SNR_dB
    sigma2 = Es / (10 ** (SNR_dB / 10))
    sigma = np.sqrt(sigma2)

    # create the sample-level noise vector; each of the real and imaginary
    # parts carries half of the variance
    n_samples = np.size(samples)
    noise = (sigma / np.sqrt(2)) * np.random.randn(n_samples) + 1j * (sigma / np.sqrt(2)) * np.random.randn(n_samples)

    # sanity check: the SNR measured on the noise realization should be close to SNR_dB
    actual_SNR = 10 * np.log10(Es / np.var(noise))
    print('The obtained SNR is: %s [dB]' % actual_SNR)

    # create the channel output
    rx_signal = samples + noise

    return rx_signal, delay


def sol_estimateTau(rx_signal, preamble):
    """
    Estimate the channel delay from a known preamble.

    TAU_ESTIM = SOL_ESTIMATETAU(RX_SIGNAL, PREAMBLE) computes the inner products
    between the received signal (RX_SIGNAL) and the time-shifted copies of the
    known-to-receiver signal (PREAMBLE), and returns the time shift that
    maximizes the magnitude of the inner product.

    The magnitude of the full correlation is also plotted, and the function
    calls plt.show() before it continues.

    :param rx_signal: vector of received samples
    :param preamble: vector of preamble samples known to the receiver
    :return: estimated delay, in samples
    :raises ValueError: if rx_signal or preamble is a scalar
    """

    if np.isscalar(rx_signal) or np.isscalar(preamble):
        raise ValueError('pdc:estimateTau --> RX_SIGNAL and SYNC_CODE must be vectors')

    # The correlation is obtained as a convolution with the time-reversed and
    # conjugated preamble. The conjugate is what turns each output sample into
    # an inner product; it has no effect on a real preamble.
    IP = np.convolve(rx_signal, np.conj(np.flip(preamble)))

    # plot the result to check if there is a clear maximum
    plt.stem(np.abs(IP))
    plt.ylabel('abs(correlation)')
    plt.xlabel('Index')
    plt.title('Correlation of rx_signal with the preamble')
    plt.grid()
    plt.show()

    # remove the ramp-up of the convolution, so that index 0 corresponds to a delay of zero
    IP = IP[np.size(preamble) - 1:]

    tau_estim = np.argmax(np.abs(IP))

    return tau_estim
