import numpy as np
from numpy import linalg as LA
from scipy.signal import upfirdn
from scipy.signal import fftconvolve
import matplotlib.pyplot as plt
import warnings


def sol_qamMap(M):
    """
    SOL_QAMMAP Builds the square QAM constellation of alphabet size M.
    C = SOL_QAMMAP(M) returns a length-M complex vector whose points lie on a
    sqrt(M) x sqrt(M) grid with odd integer coordinates
    -(sqrt(M)-1), ..., -1, 1, ..., sqrt(M)-1 on each axis. M is expected to be
    of the form 2^(2m). Points are listed column by column of the grid: the
    leftmost column first, each column read from the top (largest imaginary
    part) down to the bottom.

    :param M: Alphabet size, of the form 2^(2m).
    :return: 1-D complex array of length M with the constellation points.
    :raises ValueError: if log2(sqrt(M)) is not an integer.
    """

    side = np.sqrt(M)
    bits_per_axis = np.log2(side)

    # sqrt(M) has to be a power of two, i.e. M = 2^(2m)
    if bits_per_axis != np.fix(bits_per_axis):
        raise ValueError('sol_qamMap:invalidSize',
                         'M must be in the form of M = 2^(2K), where K is a positive integer.')

    # Odd amplitude levels along one axis, in increasing order
    levels = np.arange(-(side - 1), side, 2)

    # Real parts vary along columns, imaginary parts decrease down the rows
    real_part, imag_part = np.meshgrid(levels, levels[::-1])
    grid = real_part + 1j * imag_part

    # Column-major read-out: stack the grid's columns one after the other
    return grid.ravel(order='F')


def sol_pskMap(M):
    """
    SOL_PSKMAP Builds the M-ary phase shift keying constellation.
    C = SOL_PSKMAP(M) returns the M unit-magnitude points exp(j*2*pi*k/M),
    k = 0, ..., M-1, equally spaced on the unit circle and starting at 1.
    M is expected to be an integer power of two.

    :param M: Alphabet size, of the form 2^m.
    :return: 1-D complex array of length M with the constellation points.
    :raises ValueError: if log2(M) is not an integer.
    """

    exponent = np.log2(M)

    # M has to be a power of two
    if exponent != np.fix(exponent):
        raise ValueError('sol_pskMap:invalidSize', 'M must be in the form of M = 2^K, where K is a positive integer.')

    phases = 2 * np.pi * np.arange(M) / M
    return np.exp(1j * phases)


def sol_encoder(x, mapping):
    """
    SOL_ENCODER Maps a vector of M-ary integers to constellation points
    Y = SOL_ENCODER(X, MAPPING) outputs a vector of (possibly complex)
    symbols from the constellation specified as second parameter,
    corresponding to the integer valued symbols of X. Output Y has the same
    shape as X. If the length of vector MAPPING is M, then the message
    symbols of X must be integers between 0 and M-1.

    :param x: 1-D array (or 0-d scalar) of integer-valued symbol indices.
    :param mapping: The constellation; element k is the point for symbol k.
    :return: mapping[x], the constellation points selected by x.
    :raises TypeError: if x has more than one dimension.
    :raises ValueError: if some element of x is not an integer in [0, M-1].
    """

    x = np.asarray(x)
    mapping = np.asarray(mapping)

    # Verify you do not have a matrix
    if x.ndim > 1:
        raise TypeError('sol_encoder:wrongInputValues', 'X should be either a row or column vector, not a matrix!')

    # The alphabet size is the number of constellation points, NOT the number
    # of input symbols (using x.size here rejected valid inputs and let
    # out-of-range indices through to an IndexError).
    M = mapping.size

    # Verify that x contains only INTEGER elements between 0 and M - 1
    if not np.all((x >= 0) & (x < M) & (x == np.fix(x))):
        raise ValueError('sol_encoder:wrongInputValues',
                         'Elements of input X must be integers in the range [0, %d]' % (M - 1))

    # Validated integral values may still be float-typed; index with ints
    return mapping[x.astype(int)]


def sol_decoder(y, mapping):
    """
    SOL_DECODER Minimum distance slicer
    Z = SOL_DECODER(Y, MAPPING) maps every received sample in Y to the index
    of the constellation point of MAPPING nearest to it in Euclidean distance.
    Y holds the matched-filter outputs of the receiver. MAPPING is the
    constellation; it is flattened before use. Ties are resolved in favour of
    the lowest index.

    :param y: Array of received samples with at most one dimension.
    :param mapping: The constellation (array).
    :return: 1-D integer array; entry i is the index in 0..M-1 (M = mapping.size)
             of the constellation point closest to the i-th sample of y.
    :raises TypeError: if y has more than one dimension.
    """

    # Only vectors are accepted
    if y.ndim > 1:
        raise TypeError('sol_decoder:inputNotVector', 'Input Y must be a vector')

    # Broadcast samples (one per column) against points (one per row) to get
    # the full table of |y_i - c_k| distances.
    samples = y.reshape(1, -1)
    points = mapping.reshape(-1, 1)
    distance_table = np.abs(samples - points)

    # For each sample (column), pick the row of the nearest point
    return distance_table.argmin(axis=0)


def sol_rcosdesign(beta, span, sps):
    """
    SOL_RCOSDESIGN returns the coefficients of a root-raised-cosine pulse.

    The pulse is sampled at t = k / sps for k = -span*sps/2, ..., span*sps/2,
    i.e. span*sps + 1 samples centred on t = 0, with t measured in symbol
    periods. The closed-form expression follows Bixio's PDC book.

    :param beta: The roll-off factor. Should be real valued and in the interval [0, 1].
    :param span: The number of symbols to which the pulse is truncated.
    :param sps: The number of samples per symbol.
    :return: The coefficients of the root-raised-cosine pulse. The pulse is normalized such that it has unit norm.
    :raises ValueError: if span * sps is odd.
    """

    # Verify that span * sps is even
    if (span * sps) % 2 != 0:
        raise ValueError('sol_rcosdesign:wrongInputValues', 'Product of SPS and SPAN must be even.')

    # compute the time line (in units of the symbol period)
    M = span * sps
    no_samples = np.round(M / 2)
    t = np.arange(-no_samples, no_samples + 1, 1) / sps

    # `np.float` was removed in NumPy 1.24 (AttributeError); the builtin
    # `float` selects the same float64 dtype.
    y = np.zeros(t.shape, dtype=float)

    if beta == 0:
        # with zero roll-off the pulse is a sinc
        y = np.sinc(t)
    else:
        # Special cases: the generic formula's denominator 1 - (4*beta*t)^2
        # vanishes at |t| = 1/(4*beta). Compare with a tolerance because
        # t = k / sps is not exact in floating point.
        mask = np.isclose(np.abs(t), 0.25 / beta)

        y[mask] = (beta / np.sqrt(2)) * (
                (1 + 2 / np.pi) * np.sin(np.pi / (4 * beta)) + (1 - 2 / np.pi) * np.cos(np.pi / (4 * beta)))

        # then the non-special case
        t_reg = t[~mask]
        scaling_factor = (4 * beta) / np.pi
        nominator = np.cos((1 + beta) * np.pi * t_reg) + (((1 - beta) * np.pi) / (4 * beta)) * np.sinc(
            (1 - beta) * t_reg)
        denominator = 1 - np.square(4 * beta * t_reg)
        y[~mask] = scaling_factor * (nominator / denominator)

    # normalize to have unit norm
    y = y / LA.norm(y)

    return y


def sol_symbols2samples(y, h, USF):
    """
    SOL_SYMBOLS2SAMPLES Produces the samples of the modulated signal
    Z = SOL_SYMBOLS2SAMPLES(Y, H, USF) upsamples the symbol vector Y by the
    integer factor USF (each symbol followed by USF-1 zeros, including after
    the last symbol) and convolves the result with the sampled pulse H,
    yielding the samples of the modulated pulse train.

    :param y: Vector of symbols modulating the pulse.
    :param h: The sampled pulse.
    :param USF: The upsampling factor, i.e., the number of samples per symbol.
    :return: Full linear convolution of the upsampled symbols with h.
    """

    # One unit sample followed by USF - 1 zeros. Taking the Kronecker product
    # with this pattern gives every symbol its own USF-sample slot, and (unlike
    # upfirdn) keeps the trailing USF - 1 zeros after the final symbol.
    slot_pattern = np.r_[1.0, np.zeros(USF - 1)]
    upsampled = np.kron(y, slot_pattern)

    # Pulse shaping (scipy's fftconvolve would be a faster alternative)
    return np.convolve(upsampled, h)


def sol_sufficientStatistics(r, h, USF):
    """
    SOL_SUFFICIENTSTATISTICS Processes the output of the channel to generate
    sufficient statistics about the transmitted symbols.

    [X,Y] = SOL_SUFFICIENTSTATISTICS(R, H, USF) produces sufficient statistics
    about the transmitted symbols, given the signal received in vector
    R, the impulse response H of the basic pulse (transmitting filter),
    and the integer USF (upsampling factor), which is the number of
    samples per symbol. It also produces Y, which is the matched filter
    output before downsampling. Y is truncated in such a way that the first
    components of X and Y are identical. Y can be used to plot the eye diagram.

    :param r: The signal received.
    :param h: The impulse response H of the basic pulse (transmitting filter).
    :param USF: The upsampling factor, which is the number of samples per symbol.
    :return: Tuple (x, y): the sufficient statistics (every USF-th sample of y)
             and the truncated matched filter output before downsampling.
    """

    # Matched filter (conjugated, time-reversed pulse), followed by
    # downsampling by a factor USF
    n_taps = np.size(h)
    h_matched = np.conj(np.flip(h))

    # the matched filter output before downsampling
    y = np.convolve(r, h_matched)
    # y = fftconvolve(r, h_matched)  # faster in principle for long inputs

    # The matched filter output is the correlation between r and h. The first
    # useful sample is the n_taps-th result, and the last n_taps-1 results are
    # unused. An explicit stop index is required: the slice y[k:-(n_taps-1)]
    # becomes y[0:0] (empty) when n_taps == 1.
    y = y[n_taps - 1:y.size - (n_taps - 1)]

    x = y[::USF]

    return x, y


def sol_bi2de_old(b):
    """
    Convert binary vectors to decimal numbers (legacy, MSB-first variant).

    D = SOL_BI2DE_OLD(B) converts the binary matrix B row-wise: each row is one
    binary number and D holds one decimal value per row. A 1-D binary vector is
    converted to a single value. The orientation is fixed to Left-MSB: the first
    element of a row is the most significant bit. Unlike SOL_BI2DE, this legacy
    version takes no MSBFLAG argument.

    :param b: Array (or nested list) of 0/1 values, one binary number per row.
    :return: D: decimal values as float64 (array for 2-D input, scalar for 1-D input).
    :raises ValueError: if B contains values other than 0 and 1.
    """

    # Accept any array-like (a plain list would otherwise fail the check below
    # with a misleading "non-binary" error).
    b = np.asarray(b)

    # Test that the matrix contains only ones and zeros
    if not np.logical_or(b == 0, b == 1).all():
        raise ValueError('sol_bi2de:nonBinaryValuesInInput', 'Binary matrix B can only contain values 1 and 0')

    # Bits run along the last axis, so both a 1-D vector and a row-per-number
    # matrix work (b.shape[1] crashed on 1-D input).
    ncols = b.shape[-1]
    ptwo = np.power(2, np.arange(ncols - 1, -1, -1).astype(float))  # 2^(ncols-1), ..., 2, 1

    # Multiply and sum
    return np.dot(b, ptwo)


def sol_bi2de(*args):
    """
    Convert binary vectors to decimal numbers.

    D = SOL_BI2DE(B) converts a binary vector B to a decimal value D. When B is a matrix,
    the conversion is performed row-wise and the output D is a vector of decimal values.
    The default orientation of the binary input is Right-MSB: the first element in B represents
    the least significant bit.

    In addition to the input matrix, an optional parameter MSBFLAG can be given:

    D = SOL_BI2DE(B, MSBFLAG) uses MSBFLAG to determine the input orientation.
    MSBFLAG has two possible values, 'right-msb' and 'left-msb' (case-insensitive).
    Giving a 'right-msb' MSBFLAG does not change the function's default behavior.
    Giving a 'left-msb' MSBFLAG flips the input orientation such that the MSB is on the left.

    :param args: B, MSBFLAG
    :return: D: decimal values as float64 (array for 2-D B, scalar for 1-D B).
    :raises ValueError: if B contains values other than 0 and 1, or MSBFLAG is unsupported.
    """

    # Accept any array-like (a plain list would otherwise fail the check below
    # with a misleading "non-binary" error).
    b = np.asarray(args[0])

    # Set defaults
    msbflag = args[1] if len(args) >= 2 else 'right-msb'

    # Test that the matrix contains only ones and zeros
    if not np.logical_or(b == 0, b == 1).all():
        raise ValueError('sol_bi2de:nonBinaryValuesInInput', 'Binary matrix B can only contain values 1 and 0')

    # Bits run along the last axis, so both a 1-D vector and a row-per-number
    # matrix work (b.shape[1] crashed on the documented 1-D vector input).
    ncols = b.shape[-1]
    ptwo = np.power(2, np.arange(0, ncols).astype(float))  # 1, 2, ..., 2^(ncols-1)

    # Flip the weights if MSB is left
    if msbflag.lower() == 'left-msb':
        ptwo = np.flip(ptwo)
    elif msbflag.lower() == 'right-msb':
        pass
    else:
        raise ValueError('sol_bi2de:wrongMSBflag', 'Unsupported value for MSBFLAG')

    # Multiply and sum
    return np.dot(b, ptwo)


def sol_de2bi_old(d):
    """
    Convert decimal numbers to binary numbers (legacy, MSB-first variant).

    B = SOL_DE2BI_OLD(D) converts the non-negative integers in D to binary. If D is a
    matrix it is first converted to a vector (column-wise). Every element is written
    with the same number of bits N, the smallest count that represents the largest
    element (at least 1). The orientation is fixed to Left-MSB: within each group the
    first bit is the most significant. Unlike SOL_DE2BI, this legacy version takes no
    MSBFLAG or N arguments.

    The result is returned flattened row by row: a 1-D vector of length D.size * N in
    which consecutive groups of N entries hold the bits of consecutive elements of D.

    :param d: Array-like of non-negative integer values (integer- or float-typed).
    :return: B: 1-D integer array of 0/1 values.
    :raises ValueError: if D contains negative, non-integer, NaN or infinite values.
    """

    # Convert d to a vector (column-wise) if necessary
    d = np.asarray(d).flatten('F')

    # Make sure input is nonnegative decimal
    if np.any((d < 0) | (d != np.fix(d)) | np.isnan(d) | np.isinf(d)):
        raise ValueError('sol_de2bi:invalidInput', 'Input must contain only finite positive integers.')

    # Bitwise ufuncs reject float arrays, but integral floats (e.g. the output of
    # sol_bi2de) pass the check above; convert once validation has succeeded.
    d = d.astype(np.int64)

    # Number of bits: exact via int.bit_length (floor(1 + log2(max)) can be off for
    # large values and warns on 0); at least 1, and empty input yields no bits.
    nmax = max(1, int(d.max()).bit_length()) if d.size else 1

    # Bitshift approach: shifting d right by s and keeping the lowest bit extracts
    # bit s. E.g. d = 11: 11 >> [3, 2, 1, 0] = [1, 2, 5, 11], and & 1 gives
    # [1, 0, 1, 1] (MSB first).
    shifts = np.arange(nmax - 1, -1, -1)
    bits = (d[:, np.newaxis] >> shifts[np.newaxis, :]) & 1

    # flatten row-wise (each row has the binary representation of the corresponding integer)
    return bits.flatten('C')


def sol_de2bi(*args):
    """
    Convert decimal numbers to binary numbers.

    B = SOL_DE2BI(D) converts a vector D of non-negative integers from base 10 to binary.
    If D is a matrix rather than a row or column vector, it is first converted to a vector
    (column-wise). Every element is written with N bits; by default N is the smallest count
    that represents the largest element (at least 1). The default orientation is Right-MSB,
    i.e., the first bit of each element is the least significant bit.

    The result is returned flattened row by row: a 1-D vector of length D.size * N in which
    consecutive groups of N entries hold the bits of consecutive elements of D.

    In addition to the input vector D, two optional parameters can be given:

    B = SOL_DE2BI(D,MSBFLAG) uses MSBFLAG to determine the output orientation.
    MSBFLAG has two possible values, 'right-msb' and 'left-msb' (case-insensitive).
    Giving a 'right-msb' MSBFLAG does not change the function's default behavior.
    Giving a 'left-msb' MSBFLAG flips the output orientation to display the MSB to the left.

    B = SOL_DE2BI(D,MSBFLAG,N) uses N to define how many binary digits are output per element.
    The number of bits must be large enough to represent the largest number in D.

    :param args: D,MSBFLAG,N
    :return: B: 1-D integer array of 0/1 values.
    :raises ValueError: if D contains negative, non-integer, NaN or infinite values, if N is
                        too small, or if MSBFLAG is unsupported.
    """

    # Convert d to a vector (column-wise) if necessary
    d = np.asarray(args[0]).flatten('F')

    msbflag = args[1] if len(args) >= 2 else 'right-msb'

    # Validate before d is used to compute the bit count
    if np.any((d < 0) | (d != np.fix(d)) | np.isnan(d) | np.isinf(d)):
        raise ValueError('sol_de2bi:invalidInput', 'Input must contain only finite positive integers.')

    # Bitwise ufuncs reject float arrays, but integral floats (e.g. the output of
    # sol_bi2de) pass the check above; convert once validation has succeeded.
    d = d.astype(np.int64)

    # Minimum number of bits: exact via int.bit_length (floor(1 + log2(max)) can be
    # off for large values and warns on 0); at least 1, and empty input is allowed.
    nmax = max(1, int(d.max()).bit_length()) if d.size else 1
    if len(args) < 3:
        n = nmax
    elif nmax > args[2]:
        raise ValueError('sol_de2bi:insufficientNbits',
                         'Specified number of bits is too small to represent some of the decimal inputs,'
                         '\n at least %d bits are required' % nmax)
    else:
        n = args[2]

    shifts = np.arange(0, n, 1).astype(int)  # for right MSB (LSB first)

    if msbflag.lower() == 'left-msb':
        shifts = np.flip(shifts)
    elif msbflag.lower() == 'right-msb':
        pass
    else:
        raise ValueError('sol_de2bi:wrongMSBflag', 'Unsupported value for MSBFLAG')

    # Bitshift approach: shifting d right by s and keeping the lowest bit extracts
    # bit s. E.g. d = 11: 11 >> [0, 1, 2, 3] = [11, 5, 2, 1], and & 1 gives
    # [1, 1, 0, 1] (LSB first).
    bits = (d[:, np.newaxis] >> shifts[np.newaxis, :]) & 1

    # flatten row-wise (each row has the binary representation of the corresponding integer)
    return bits.flatten('C')


def my_eyediagram(y, Fs, T):
    """
    MY_EYEDIAGRAM Generates an eye diagram.
    MY_EYEDIAGRAM(Y, FS, T) plots the eye diagram corresponding to the
    sampled matched filter output Y, assuming the sampling frequency is
    Fs and the symbol duration is T

    :param y: The sampled matched filter output.
    :param Fs: The sampling frequency.
    :param T: The symbol duration.
    :return: Plots the eye diagram.
    """
