import numpy as np
from numpy import linalg as LA
from scipy.signal import upfirdn
from scipy.signal import fftconvolve
import matplotlib.pyplot as plt
import warnings


def my_qamMap(M):
    """
    MY_QAMMAP Creates constellation for square QAM modulations
    C = MY_QAMMAP(M) outputs a 1xM vector with the constellation for the
    quadrature amplitude modulation of alphabet size M, where M is of
    the form 2^(2m) for some positive integer m. The signal
    constellation is a square constellation.

    The in-phase and quadrature amplitudes both take the odd integer values
    -(sqrt(M)-1), ..., -1, 1, ..., sqrt(M)-1. The grid is read column by
    column (like MATLAB's ``A(:)``): each run of sqrt(M) consecutive symbols
    shares one in-phase amplitude, with the quadrature amplitude decreasing.

    :param M: M is of the form 2^(2m) for some positive integer m.
    :return: 1-D complex array of length M with the constellation for the quadrature amplitude modulation of alphabet size M.
    :raises ValueError: If M is not of the form 2^(2m) for a positive integer m.
    """
    message = 'my_qamMap: M must be of the form 2^(2m) for a positive integer m.'
    if int(M) != M or M < 4:
        raise ValueError(message)
    M = int(M)
    exponent = M.bit_length() - 1
    # M must be an exact power of two with an even exponent (4, 16, 64, ...).
    if (1 << exponent) != M or exponent % 2 != 0:
        raise ValueError(message)

    side = 1 << (exponent // 2)  # sqrt(M) amplitude levels per axis
    levels = np.arange(-(side - 1), side, 2)
    in_phase, quadrature = np.meshgrid(levels, levels[::-1])
    c = (in_phase + 1j * quadrature).ravel(order='F')

    return c


def my_pskMap(M):
    """
    MY_PSKMAP Creates constellation for Phase Shift Keying modulation
    C = MY_PSKMAP(M) outputs a 1xM vector with the complex symbols of the
    PSK constellation of alphabet size M, where M is of the form 2^m
    for some positive integer m.

    Symbol k (k = 0, ..., M-1) is exp(2j*pi*k/M), i.e. M equally spaced
    points on the unit circle starting at 1 and proceeding counter-clockwise.

    :param M: M is of the form 2^m for some positive integer m.
    :return: 1-D complex array of length M with the complex symbols of the PSK constellation of alphabet size M.
    :raises ValueError: If M is not of the form 2^m for a positive integer m.
    """
    if int(M) != M or M < 2:
        raise ValueError('my_pskMap: M must be of the form 2^m for a positive integer m.')
    M = int(M)
    # A power of two has exactly one bit set.
    if M & (M - 1):
        raise ValueError('my_pskMap: M must be of the form 2^m for a positive integer m.')

    c = np.exp(2j * np.pi * np.arange(M) / M)

    return c


def my_encoder(x, mapping):
    """
    MY_ENCODER Maps a vector of M-ary integers to constellation points
    Y = MY_ENCODER(X, MAPPING) outputs a vector of (possibly complex)
    symbols from the constellation specified as second parameter,
    corresponding to the integer valued symbols of X.  Input X can be
    a row or column vector, and output Y has the same dimensions as X.
    If the length of vector MAPPING is M, then the message symbols of
    X must be integers between 0 and M-1.


    :param x: Input X can be a row or column vector containing the symbol values.
    :param mapping: The constellation.
    :return: a vector of (possibly complex) symbols from the constellation specified as second parameter, with the same shape as X.
    :raises ValueError: If X contains values that are not integers between 0 and M-1.
    """
    x = np.asarray(x)
    # MAPPING may be given as a row or column vector; flatten it to a lookup table.
    mapping = np.asarray(mapping).ravel()
    M = mapping.size

    if x.size and (np.any(x != np.round(x)) or np.any(x < 0) or np.any(x > M - 1)):
        raise ValueError('my_encoder: X must contain integers between 0 and M-1.')

    # Fancy indexing keeps the shape of X.
    y = mapping[x.astype(np.intp)]

    return y


def my_decoder(y, mapping):
    """
    MY_DECODER Minimum distance slicer
    Z = MY_DECODER(Y, MAPPING) demodulates vector Y by finding
    the element of the specified constellation that is closest to each
    element of input Y. Y contains the outputs of the matched filter of
    the receiver, and it can be a row or column vector. MAPPING specifies
    the constellation used, it can also be a row or column vector.
    Output Z has the same dimensions as input Y. The elements of Z are
    M-ary symbols, i.e., integers between 0 and [M=length(MAPPING)]-1.

    Ties between equally distant constellation points resolve to the
    smallest index.

    :param y: Y contains the outputs of the matched filter of the receiver, and it can be a row or column vector.
    :param mapping: The constellation.
    :return: The elements of Z are M-ary symbols, i.e., integers between 0 and [M=length(MAPPING)]-1.
    :raises ValueError: If MAPPING is empty.
    """
    y = np.asarray(y)
    mapping = np.asarray(mapping).ravel()
    if mapping.size == 0:
        raise ValueError('my_decoder: MAPPING must contain at least one point.')

    # Distance from every sample to every constellation point, on a trailing axis.
    distances = np.abs(y[..., np.newaxis] - mapping)
    z = np.argmin(distances, axis=-1)

    return z


def sol_rcosdesign(beta, span, sps):
    """
    SOL_RCOSDESIGN returns the coefficients of a root-raised-cosine pulse.

    The pulse is sampled at t = k / sps for k = -span*sps/2, ..., span*sps/2,
    giving span*sps + 1 symmetric taps centred on t = 0.

    :param beta: The roll-off factor. Must be real valued and in the interval [0, 1].
    :param span: The number of symbols to which the pulse is truncated.
    :param sps: The number of samples per symbol.
    :return: The coefficients of the root-raised-cosine pulse. The pulse is normalized such that it has unit norm.
    :raises ValueError: If beta is outside [0, 1] or span * sps is odd.
    """
    if not 0 <= beta <= 1:
        raise ValueError('sol_rcosdesign: BETA must be in the interval [0, 1].')

    # An even number of intervals gives an odd number of taps symmetric around t = 0.
    if (span * sps) % 2 != 0:
        raise ValueError('sol_rcosdesign: Product of SPS and SPAN must be even.')

    # compute the time line
    half_length = (span * sps) // 2
    t = np.arange(-half_length, half_length + 1) / sps

    if beta == 0:
        y = np.sinc(t)
    else:
        y = np.zeros(t.shape, dtype=float)

        # Implement the formula from Bixio's PDC book.
        # First take care of the special cases t = +-1/(4*beta), where the
        # denominator vanishes; exact float equality is unreliable there.
        mask = np.isclose(np.abs(t), 0.25 / beta)
        y[mask] = (beta / np.sqrt(2)) * ((1 + 2/np.pi)*np.sin(np.pi/(4*beta)) + (1 - 2/np.pi)*np.cos(np.pi/(4*beta)))

        # Then the non-special case. np.sinc absorbs the removable singularity at t = 0.
        scaling_factor = (4 * beta) / np.pi
        t_regular = t[~mask]
        nominator = np.cos((1 + beta) * np.pi * t_regular) + (((1 - beta) * np.pi) / (4 * beta)) * np.sinc((1 - beta) * t_regular)
        denominator = 1 - np.square(4 * beta * t_regular)
        y[~mask] = scaling_factor * (nominator / denominator)

    # normalize to have unit norm
    y = y / LA.norm(y)

    return y


def my_symbols2samples(y, h, USF):
    """
    MY_SYMBOLS2SAMPLES Produces the samples of the modulated signal
    Z = MY_SYMBOLS2SAMPLES(Y, H, USF) produces the samples of a
    modulated pulse train. The sampled pulse is given in vector H, and
    the symbols modulating the pulse are contained in vector Y. USF is
    the upsampling factor, i.e., the number of samples per symbol.

    The symbols are upsampled by inserting USF-1 zeros between them and then
    filtered with H. For N symbols the output has (N-1)*USF + len(H) samples.

    :param y: the symbols modulating the pulse are contained in vector Y.
    :param h: The sampled pulse is given in vector H.
    :param USF: The upsampling factor, i.e., the number of samples per symbol.
    :return: 1-D array with the samples of a modulated pulse train.
    """
    symbols = np.asarray(y).ravel()
    pulse = np.asarray(h).ravel()

    # upfirdn does not handle an empty input sequence cleanly; the pulse train is empty.
    if symbols.size == 0:
        return np.zeros(0, dtype=np.result_type(symbols, pulse))

    z = upfirdn(pulse, symbols, up=USF)

    return z


def my_sufficientStatistics(r, h, USF):
    """
    MY_SUFFICIENTSTATISTICS Processes the output of the channel to generate
    sufficient statistics about the transmitted symbols.

    [X,Y] = MY_SUFICIENTSTATISTICS(R, H, USF) produces sufficient statistics
    about the transmitted symbols, given the signal received in vector
    R, the impulse response H of the basic pulse (transmitting filter),
    and the integer USF (upsampling factor), which is the number of
    samples per symbol. It also produces also Y, which
    is the matched filter output before downsampling. Y is truncated
    in such a way that the first components of X and Y are identical.
    Y can be used to plot the eye diagram.

    The matched filter is conj(H) time-reversed. The len(H)-1 transient
    samples are dropped at both ends of its output, so Y[0] is the sample
    aligned with the first symbol and X = Y[::USF].

    :param r: The signal received.
    :param h: The impulse response H of the basic pulse (transmitting filter).
    :param USF: The upsampling factor, which is the number of samples per symbol.
    :return: The sufficient statistics X about the transmitted symbols and the matched filter output Y before downsampling.
    """
    received = np.asarray(r).ravel()
    pulse = np.asarray(h).ravel()
    pulse_length = pulse.size

    matched_output = fftconvolve(received, np.conj(pulse[::-1]))

    # Drop the filter transients at both ends, so that the first kept sample
    # is the peak response to the first transmitted symbol.
    y = matched_output[pulse_length - 1:matched_output.size - pulse_length + 1]
    x = y[::USF]

    return x, y


def my_bi2de(*args):
    """
    Convert binary vectors to decimal numbers.

    D = MY_BI2DE(B) converts a binary vector B to a decimal value D. When B is a matrix,
    the conversion is performed row-wise and the output D is a column vector of decimal values.
    The default orientation of the binary input is Right-MSB: the first element in B represents
    the least significant bit.

    In addition to the input matrix, an optional parameter MSBFLAG can be given:

    D = MY_BI2DE(B, MSBFLAG) uses MSBFLAG to determine the input orientation.
    MSBFLAG has two possible values, 'right-msb' and 'left-msb'.
    Giving a 'right-msb' MSBFLAG does not change the function's default behavior.
    Giving a 'left-msb' MSBFLAG flips the input orientation such that the MSB is on the left.

    A 1-D (or scalar) B is treated as a single binary word and yields a scalar.
    A 2-D B yields a 1-D integer array with one entry per row. Rows wider than
    63 bits overflow the int64 result.

    :param args: B, MSBFLAG
    :return: D: vector of decimal values
    :raises TypeError: If called with no arguments or more than two.
    :raises ValueError: If MSBFLAG is unknown, B has more than two dimensions, or B is not binary.
    """
    if not 1 <= len(args) <= 2:
        raise TypeError('my_bi2de: expected arguments (B) or (B, MSBFLAG).')
    b = np.asarray(args[0])
    msbflag = args[1] if len(args) == 2 else 'right-msb'

    if msbflag not in ('right-msb', 'left-msb'):
        raise ValueError("my_bi2de: MSBFLAG must be 'right-msb' or 'left-msb'.")
    if b.ndim > 2:
        raise ValueError('my_bi2de: B must be a vector or a matrix.')
    if np.any((b != 0) & (b != 1)):
        raise ValueError('my_bi2de: B must contain only 0s and 1s.')

    rows = np.atleast_2d(b).astype(np.int64)
    if msbflag == 'left-msb':
        # Reverse the columns so that column 0 is always the LSB.
        rows = rows[:, ::-1]

    weights = 2 ** np.arange(rows.shape[1], dtype=np.int64)
    d = rows @ weights

    return d[0] if b.ndim <= 1 else d


def my_de2bi(*args):
    """
    Convert decimal numbers to binary numbers.

    B = MY_DE2BI(D) converts a vector D of non-negative integers from base 10 to a binary matrix B.
    Each row of the binary matrix B corresponds to one element of D. The default orientation of the binary output
    is Right-MSB, i.e., the first element in a row of B represents the least significant bit.
    If D is a matrix rather than a row or column vector, the matrix is first converted to a vector (column-wise).

    In addition to the input vector D, two optional parameters can be given:

    B = MY_DE2BI(D,MSBFLAG) uses MSBFLAG to determine the output orientation.
    MSBFLAG has two possible values, 'right-msb' and 'left-msb'.
    Giving a 'right-msb' MSBFLAG does not change the function's default behavior.
    Giving a 'left-msb' MSBFLAG flips the output orientation to display the MSB to the left.

    B = MY_DE2BI(D,MSBFLAG,N) uses N to define how many binary digits (columns) are output.
    The number of bits must be large enough to represent the largest number in D.

    Without N, the number of columns is the minimum needed for max(D), and at least 1.

    :param args: D,MSBFLAG,N
    :return: B: integer matrix of shape (number of elements of D, N)
    :raises TypeError: If called with no arguments or more than three.
    :raises ValueError: If MSBFLAG is unknown, D has negative or non-integer entries, or N is invalid or too small.
    """
    if not 1 <= len(args) <= 3:
        raise TypeError('my_de2bi: expected arguments (D), (D, MSBFLAG) or (D, MSBFLAG, N).')
    # Column-wise flattening, as MATLAB's D(:).
    d = np.asarray(args[0]).ravel(order='F')
    msbflag = args[1] if len(args) >= 2 else 'right-msb'
    n = args[2] if len(args) == 3 else None

    if msbflag not in ('right-msb', 'left-msb'):
        raise ValueError("my_de2bi: MSBFLAG must be 'right-msb' or 'left-msb'.")
    if d.size and (np.any(d != np.round(d)) or np.any(d < 0)):
        raise ValueError('my_de2bi: D must contain non-negative integers.')

    values = d.astype(np.int64)
    needed = max(int(values.max()).bit_length() if values.size else 0, 1)
    if n is None:
        n = needed
    else:
        if int(n) != n:
            raise ValueError('my_de2bi: N must be an integer.')
        n = int(n)
        if n < needed:
            raise ValueError('my_de2bi: N is too small to represent the largest number in D.')

    # Bit k of each value in column k (LSB first). Shifts are capped at 63
    # because shifting an int64 by >= 64 is undefined; bit 63 of a
    # non-negative int64 is always 0, which is the correct value for those columns.
    shifts = np.minimum(np.arange(n), 63)
    b = (values[:, np.newaxis] >> shifts) & 1
    if msbflag == 'left-msb':
        b = b[:, ::-1]

    return b
