import numpy as np
from typing import Union, Optional, List
import scipy.io as sio
from scipy.signal import correlate as corr
import matplotlib.pyplot as plt
import math
import gpsc as gpsc


class LCMException(Exception):
    """
    Base class of all errors raised by this GPS decoding module.

    Derives from ``Exception`` (not ``BaseException``) so that ordinary ``except Exception`` handlers
    catch it; ``BaseException`` is reserved for interpreter-level events such as ``KeyboardInterrupt``.
    """


class NoStartSubframe(LCMException):
    """The start of a subframe could not be located in the received bit stream."""


class NotEnoughSubframes(LCMException):
    """Too few subframes are available (not raised by the code in this file as visible here)."""


class ParityCheckFailed(LCMException):
    """The parity of at least one GPS word did not match (raised by establishParity)."""


class NotBinaryInput(LCMException, ValueError):
    """A bit vector/matrix contained values other than 0 and 1 (raised by bi2de)."""


class NotIntegralNumberOfSubframes(LCMException, ValueError):
    """A bit sequence does not contain a whole number of subframes."""


class WrongMSBFlag(LCMException, ValueError):
    """An unsupported MSB orientation flag was given to bi2de."""


def bi2de(b: np.array, msbflag: str = 'right-msb') -> Union[np.ndarray, int]:
    """
    Converts binary vectors to decimal numbers.

    The default orientation of the binary input is Right-MSB:
    the first element in B represents the least significant bit.

    :param b: binary vector or matrix (any array-like is accepted). If it is a matrix, the conversion is
        performed row-wise and the output D is a vector of decimal values. The caller's array is never
        modified.
    :param msbflag: has two possible values, 'right-msb' and 'left-msb' (case-insensitive).
        Giving a 'right-msb' MSBFLAG does not change the function's default behavior.
        Giving a 'left-msb' MSBFLAG flips the input orientation such that the MSB is on the left.

    :raise NotBinaryInput: if B contains values other than 0 and 1
    :raise WrongMSBFlag: if msbflag is neither 'right-msb' nor 'left-msb'

    :return: a scalar if the result has a single element (e.g. B is a vector or a one-row matrix),
        otherwise a vector with one decimal value per row of B
    """

    b = np.asarray(b)

    # Test that the matrix contains only ones and zeros
    if np.any(b[b != 0] != 1):
        raise NotBinaryInput('bi2de: Binary matrix B can only contain values 1 and 0')

    # Each row of b is the binary representation of one decimal value. A vector (or scalar) is promoted
    # to a single-row matrix through a *view*; ndarray.resize would instead reshape the caller's array
    # in place.
    b = np.atleast_2d(b)

    ncols = b.shape[1]
    ptwo = 2 ** np.arange(ncols)

    # Flip the input if MSB is left
    flag = msbflag.lower()
    if flag == 'left-msb':
        b = np.fliplr(b)
    elif flag != 'right-msb':
        raise WrongMSBFlag(f'bi2de: Unsupported value {msbflag} for msbflag')

    res = b.dot(ptwo)
    # if we have only one row, the result type is a scalar, not np.ndarray
    if res.size == 1:
        res = res[0]

    return res


class Ephemeris:
    """
    The ephemeris data has the following structure:

    Telemetry and HOW (Hand-over word):
        TOW1 - TOW3   Time of the week for subframes 1 to 3 [s]
        wn             Week number
        alert          Alert bit
        antispoof      Antispoof bit

    Clock correction parameters:
        IODC           Issue index of clock data
        t_oc           Clock data reference time
        T_GD           Group delay differential
        a_f0           0th-order correction term
        a_f1           1st-order correction term
        a_f2           2nd-order correction term

    Orbit Parameters:
        IODE           Issue index of ephemeris data
        t_oe           Ephemeris reference time
        M_0            Mean anomaly at reference time
        delta_n        Correction to the computed mean motion
        e              Orbit ellipse eccentricity
        sqrt_a         Square-root of the semi-major axis
        i_0            Inclination angle at reference time
        w              Argument of perigee
        Omega_0        Longitude of the ascending node at reference time

    Orbit correction parameters:
        Omegadot       Rate of change of the right ascension
        idot           Rate of change of the inclination angle
        C_us           Amplitudes of harmonic correction terms for the
        C_uc           computed argument of latitude
        C_rs           Amplitudes of harmonic correction terms for the
        C_rc           computed orbit radius
        C_is           Amplitudes of harmonic correction terms for the
        C_ic           computed inclination angle

    For your convenience, the following auxiliary parameter is also added:
        t_tr           Transmission time (earliest TOW minus 6)

    All fields start as None and are filled in by the decoding or loading code.
    """

    def __init__(self):

        self.TOW1: Optional[int] = None
        self.TOW2: Optional[int] = None
        self.TOW3: Optional[int] = None
        self.wn: Optional[int] = None
        self.alert: Optional[int] = None
        self.antispoof: Optional[int] = None

        self.IODC: Optional[int] = None
        self.t_oc: Optional[int] = None
        self.T_GD: Optional[float] = None
        self.a_f0: Optional[float] = None
        self.a_f1: Optional[float] = None
        self.a_f2: Optional[float] = None

        self.IODE: Optional[int] = None
        self.t_oe: Optional[int] = None
        self.M_0: Optional[float] = None
        self.delta_n: Optional[float] = None
        self.e: Optional[float] = None
        self.sqrt_a: Optional[float] = None
        self.i_0: Optional[float] = None
        self.w: Optional[float] = None
        self.Omega_0: Optional[float] = None

        self.Omegadot: Optional[float] = None
        self.idot: Optional[float] = None
        self.C_us: Optional[float] = None
        self.C_uc: Optional[float] = None
        self.C_rs: Optional[float] = None
        self.C_rc: Optional[float] = None
        self.C_is: Optional[float] = None
        self.C_ic: Optional[float] = None

        self.t_tr: Optional[int] = None

    def __eq__(self, other):
        """
        Compare two ephemerides field by field.

        Every attribute of ``self`` is looked up on ``other``. Numeric fields must agree to a relative
        precision of about 15 digits. Fields whose values cannot be subtracted (e.g. ``None`` for a field
        that was never filled) must be exactly equal. Every mismatch and every field missing from
        ``other`` is reported on stdout and makes the result ``False``.
        """

        equal = True
        for attr, expected in self.__dict__.items():
            try:
                actual = getattr(other, attr)
            except AttributeError:
                # other does not have this attribute: report it and keep checking the other fields
                print(f'   ** Non-existent field {attr}, that should be {expected}')
                equal = False
                continue

            try:
                # compare the corresponding fields with approx. 15-digits precision
                differs = abs(expected - actual) > 1e-15 * abs(expected)
            except TypeError:
                # non-numeric values (e.g. None): fall back to exact equality
                differs = expected != actual

            if differs:
                print(f'   ** Differing field: {attr} should be {expected} '
                      f'and it is {actual}')
                equal = False
        return equal

    def save(self, filename: str) -> None:
        """
        Converts the ephemeris into a dict and writes it to a .mat file (as the struct 'ephemeris').
        """

        sio.savemat(filename, {'ephemeris': self.__dict__})

    def load(self, filename: str) -> None:
        """
        Reads the ephemeris data from a .mat file.

        Every field of the 'ephemeris' struct stored in the file becomes an attribute of this object.
        """

        aux = sio.loadmat(filename)
        dtypes = aux['ephemeris'].dtype
        dtypes = list(dtypes.fields.keys())
        array = aux['ephemeris'].flatten()[0]

        for ind, field in enumerate(dtypes):
            self.__setattr__(field, array[ind].flatten()[0])


def read_unsigned(p: np.ndarray, sf: int, a: int, n: int, s: int) -> float:
    """
    Read an unsigned, scaled field from one subframe of a page.

    :param p: Page (P is a 300x5 matrix of bits {0,1}, each column is a subframe)
    :param sf: Subframe ID, count from 1 (selects column sf - 1 of p)
    :param a: Starting bit, count from 1
    :param n: Number of bits
    :param s: Exponent of scaling factor (factor = 2^s)

    :return: m * 2^s, where m is the unsigned integer encoded MSB-first by the n bits
    """

    start = a - 1                          # 1-based bit position -> 0-based row index
    field = np.empty(n)                    # private float copy of the bits
    field[:] = p[start:start + n, sf - 1]
    mantissa = bi2de(field, 'left-msb')
    return mantissa * 2 ** s


def read_signed(p: np.ndarray, sf: int, a: int, n: int, s: int) -> float:
    """
    Read a two's-complement, scaled field from one subframe of a page.

    :param p: Page (P is a 300x5 matrix of bits {0,1}, each column is a subframe)
    :param sf: Subframe ID, count from 1 (selects column sf - 1 of p)
    :param a: Starting bit, count from 1
    :param n: Number of bits (the first one is the sign bit)
    :param s: Exponent of scaling factor (factor = 2^s)

    :return: m * 2^s, where m is the signed integer encoded in two's complement by the n bits
    """

    start = a - 1
    field = np.empty(n)
    field[:] = p[start:start + n, sf - 1]

    scale = 2 ** s
    if field[0] != 0:
        # Sign bit set: the magnitude is the bitwise complement plus one
        return -1 * (bi2de(1 - field, 'left-msb') + 1) * scale
    return bi2de(field, 'left-msb') * scale


def read_2part_unsigned(p: np.ndarray, sf: int, am: int, nm: int, al: int, nl: int, s: int) -> float:
    """
    Read an unsigned, scaled field that is split over two locations of one subframe.

    The MSB chunk and the LSB chunk are concatenated (MSB chunk first) and decoded MSB-first.

    :param p: Page (P is a 300x5 matrix of bits {0,1}, each column is a subframe)
    :param sf: Subframe ID, count from 1
    :param am: Starting bit (MSB part), count from 1
    :param nm: Number of bits (MSB part)
    :param al: Starting bit (LSB part), count from 1
    :param nl: Number of bits (LSB part)
    :param s: Exponent of scaling factor (factor = 2^s)

    :return: m * 2^s, where m is the mantissa represented by the two chunks of bits (expressed in base 10)
    """

    column = p[:, sf - 1]
    msb_part = column[am - 1:am - 1 + nm]
    lsb_part = column[al - 1:al - 1 + nl]

    # Join both chunks into a fresh float vector (MSB chunk first)
    joined = np.concatenate((msb_part, lsb_part)).astype(float)

    return bi2de(joined, 'left-msb') * 2 ** s


def read_2part_signed(p: np.ndarray, sf: int, am: int, nm: int, al: int, nl: int, s: int) -> float:
    """
    Read a two's-complement, scaled field that is split over two locations of one subframe.

    The MSB chunk and the LSB chunk are concatenated (MSB chunk first); the first bit of the result is
    the sign bit.

    :param p: page (P is a 300x5 matrix of bits {0,1}, each column is a subframe)
    :param sf: Subframe ID, count from 1
    :param am: Starting bit (MSB part), count from 1
    :param nm: Number of bits (MSB part)
    :param al: Starting bit (LSB part), count from 1
    :param nl: Number of bits (LSB part)
    :param s: Exponent of scaling factor (factor = 2^s)

    :return: m * 2^s, where m is the signed mantissa represented by the two chunks of bits
    """

    column = p[:, sf - 1]
    word = np.concatenate((column[am - 1:am - 1 + nm],
                           column[al - 1:al - 1 + nl])).astype(float)

    scale = 2 ** s
    if word[0] != 0:
        # Negative value: magnitude is (bitwise complement) + 1
        return -1 * (bi2de(1 - word, 'left-msb') + 1) * scale
    # Positive value: plain unsigned decoding
    return bi2de(word, 'left-msb') * scale


def loadEphemeris(filename: str) -> Ephemeris:
    """
    Build an Ephemeris from a .mat file.

    Every field of the 'ephemeris' struct stored in the file becomes an attribute of the returned object
    (the parsing is done by Ephemeris.load).

    :param filename: the file which stores the ephemeris

    :return: the corresponding Ephemeris structure
    """

    result = Ephemeris()
    result.load(filename)
    return result


def readEphemeris(page: np.ndarray) -> Ephemeris:
    """
    Extracts the ephemeris data of subframes 1 to 3 from the given page by taking the correct bits and performing
    the appropriate conversion (scaling, 2-complement, etc.).

    :param page: matrix of zeros and ones with 300 rows, where column sf - 1 holds the bits of subframe sf;
        columns 0 to 2 must therefore represent the bits of subframes 1 to 3, respectively.

    :return: an Ephemeris filled with the decoded HOW, clock and orbit parameters, the auxiliary transmission
        time t_tr, and an extra attribute ``sfid`` (subframe ID read from the first column).
    """

    ephdata = Ephemeris()

    # Extract parameters from page according to the GPS standard.
    # HOW word = subframe bits 31-60: TOW count (17 bits), alert flag (bit 18 of the HOW, i.e. subframe bit 48),
    # anti-spoof flag (bit 19 of the HOW, i.e. subframe bit 49), subframe ID (subframe bits 50-52).
    ephdata.sfid = int(read_unsigned(page, 1, 50, 3, 0))
    ephdata.TOW1 = int(read_unsigned(page, 1, 31, 17, 0) * 6)
    ephdata.TOW2 = int(read_unsigned(page, 2, 31, 17, 0) * 6)
    ephdata.TOW3 = int(read_unsigned(page, 3, 31, 17, 0) * 6)
    ephdata.wn = int(read_unsigned(page, 1, 61, 10, 0))
    ephdata.alert = int(read_unsigned(page, 1, 48, 1, 0))
    ephdata.antispoof = int(read_unsigned(page, 1, 49, 1, 0))

    ephdata.IODC = int(read_2part_unsigned(page, 1, 83, 2, 211, 8, 0))
    ephdata.t_oc = int(read_unsigned(page, 1, 219, 16, 4))
    ephdata.T_GD = read_signed(page, 1, 197, 8, -31)
    ephdata.a_f0 = read_signed(page, 1, 271, 22, -31)
    ephdata.a_f1 = read_signed(page, 1, 249, 16, -43)
    ephdata.a_f2 = read_signed(page, 1, 241, 8, -55)

    ephdata.IODE = int(read_unsigned(page, 2, 61, 8, 0))
    ephdata.t_oe = int(read_unsigned(page, 2, 271, 16, 4))
    ephdata.M_0 = read_2part_signed(page, 2, 107, 8, 121, 24, -31) * gpsc.pi_gps
    ephdata.delta_n = read_signed(page, 2, 91, 16, -43) * gpsc.pi_gps
    ephdata.e = read_2part_unsigned(page, 2, 167, 8, 181, 24, -33)
    ephdata.sqrt_a = read_2part_unsigned(page, 2, 227, 8, 241, 24, -19)
    ephdata.i_0 = read_2part_signed(page, 3, 137, 8, 151, 24, -31) * gpsc.pi_gps
    ephdata.w = read_2part_signed(page, 3, 197, 8, 211, 24, -31) * gpsc.pi_gps
    ephdata.Omega_0 = read_2part_signed(page, 3, 77, 8, 91, 24, -31) * gpsc.pi_gps

    ephdata.Omegadot = read_signed(page, 3, 241, 24, -43) * gpsc.pi_gps
    ephdata.idot = read_signed(page, 3, 279, 14, -43) * gpsc.pi_gps
    ephdata.C_us = read_signed(page, 2, 211, 16, -29)
    ephdata.C_uc = read_signed(page, 2, 151, 16, -29)
    ephdata.C_rs = read_signed(page, 2, 69, 16, -5)
    ephdata.C_rc = read_signed(page, 3, 181, 16, -5)
    ephdata.C_is = read_signed(page, 3, 121, 16, -29)
    ephdata.C_ic = read_signed(page, 3, 61, 16, -29)

    # Compute the minimum of the TOW times. Since the TOW field of a subframe
    # indicates the GPS time of the start of the _next_ subframe, we subtract
    # 6 (seconds) from it, since that is the transmission duration of a subframe.
    mintow = min(ephdata.TOW1, ephdata.TOW2, ephdata.TOW3)

    ephdata.t_tr = int(mintow - 6)

    return ephdata


def getSubframes(bits: np.array, min_subframes: int = 8) -> (np.array, List[int], int):
    """
    Extracts the subframes from the bit stream transmitted by a satellite.

    :param bits: given bits of a satellite
    :param min_subframes: minimum number of complete subframes that must be found (default 8, the original
        requirement of this function)
    :return: (subframes, subframesIDs, idx), where subframes is the matrix with subframes bits in columns,
        subframesIDs is the list of IDs corresponding to each column and
        idx is the index into vector BITS where the first complete subframe starts, as returned by
        removeExcessBits (the first column is not necessarily the subframe with ID=1)
    :raise RuntimeError: if fewer than min_subframes subframes are found
    """

    # Find the first bit of a subframe and remove the bits that precede it and
    # the bits at the end that don't make a full subframe. It also returns the
    # position (in bits) of the first retained bit within the input "bits".
    s_pre, idx = removeExcessBits(bits)

    # Make sure that the parity is fulfilled and flip the bits as needed.
    # The function implements Table 20-XIV.
    s = establishParity(s_pre)

    # Order the bits into a matrix that has as its columns the subframes.
    # Return as well the IDs of the subframes.
    subframes, subframes_IDs = bits2subframes(s)

    # Nicolae: the value 8 should be changed. We can work fine with only 5 subframes. See my notes.
    # The threshold is configurable through min_subframes; the default keeps the value 8.
    nsf = subframes.shape[1]
    if nsf < min_subframes:
        raise RuntimeError(f'At least {min_subframes} subframes are required '
                           f'(got only {nsf})')

    return subframes, subframes_IDs, idx


def removeExcessBits(bits: np.array) -> (np.array, int):
    """
    Extracts the complete subframes in the bit sequence obtained from a satellite.

    Detects the start of the first subframe in the row vector BITS (values in {-1,+1}) by correlating
    with the GPS preamble (stored in gpsc.preamble as a sequence of 1s and 0s).
    If the negative of the preamble is found all bits are inverted.
    Incomplete subframes at the beginning and at the end of
    BITS are removed and the sequence of bits is converted into a
    sequence of {0,1} values ({1,-1} <-> {0,1}).

    A single match of the 8-bit preamble is not enough, because the same pattern can occur inside the
    data. A position is accepted as a subframe start only if the preamble is found there with the same
    polarity at that position and at every gpsc.bpsf bits after it, as far as the sequence reaches.
    Correlation values are rounded before being compared.

    :param bits: received bits, values in {-1, +1} (any shape; it is flattened)

    :raise NoStartSubframe: if fails to find the start of a subframe

    :return: (s, idx), where s is the resulting converted sequence of bits (integers 0/1) and idx is the
        0-based index into bits where the first subframe starts
    """

    bpsf = gpsc.bpsf
    x = np.asarray(bits, dtype=float).ravel()

    # Preamble in antipodal form, using the same mapping {0,1} -> {+1,-1} that is inverted at the end.
    pre = 1 - 2 * np.asarray(gpsc.preamble, dtype=float).ravel()
    npre = pre.size

    if x.size < npre:
        raise NoStartSubframe('removeExcessBits: the bit sequence is shorter than the preamble')

    # c[k] = sum_j x[k + j] * pre[j], for every k where a whole preamble fits.
    c = np.round(np.correlate(x, pre, mode='valid'))

    # The first subframe must start within the first bpsf positions; require a perfect (possibly
    # inverted) match that repeats with the same sign every bpsf bits.
    for k in range(min(bpsf, c.size)):
        if abs(c[k]) == npre and np.all(c[k::bpsf] == c[k]):
            ind = k
            break
    else:
        raise NoStartSubframe('removeExcessBits: could not find the start of a subframe')

    # Keep only the complete subframes starting at ind
    nsf = (x.size - ind) // bpsf
    s = x[ind:ind + nsf * bpsf]

    # A negative correlation peak means the whole sequence is inverted
    if c[ind] < 0:
        s = -s

    # {+1, -1} -> {0, 1}
    s = ((1 - s) / 2).astype(int)

    return s, ind


def establishParity(ws: np.array) -> np.array:
    """
    Checks the parity of the GPS words.
    In particular, it checks the parity of each word in the sequence of subframes WS
    (parity algorithm of Table 20-XIV of the GPS standard).

    The last two bits of each subframe are 0 by convention, so for the first word of WS the bits
    D29* and D30* of the (absent) previous word are taken to be 0.
    Both the input WS and the output S are binary (0 and 1) vectors containing full subframes.

    :param ws: binary vector whose length is a multiple of gpsc.bpsf (a whole number of subframes)

    :raise RuntimeError: if WS does not contain a whole number of subframes
    :raise ParityCheckFailed: if the parity of at least one word does not match

    :return: a possibly modified copy S of ws (same shape); the modification consists in flipping the first
        24 bits of words for which the last bit of the previous word is 1. WS itself is left unchanged.
    """

    # Parity check matrix (cf. GPS standard)
    H = np.array([
        [1, 0, 1, 0, 1, 0],     # multiplies d1
        [1, 1, 0, 1, 0, 0],     # multiplies d2
        [1, 1, 1, 0, 1, 1],     # ...
        [0, 1, 1, 1, 0, 0],
        [1, 0, 1, 1, 1, 1],
        [1, 1, 0, 1, 1, 1],
        [0, 1, 1, 0, 1, 0],
        [0, 0, 1, 1, 0, 1],
        [0, 0, 0, 1, 1, 1],
        [1, 0, 0, 0, 1, 1],
        [1, 1, 0, 0, 0, 1],
        [1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 1],
        [1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 0],
        [1, 0, 0, 1, 1, 0],
        [1, 1, 0, 0, 1, 0],
        [0, 1, 1, 0, 0, 1],
        [1, 0, 1, 1, 0, 0],
        [0, 1, 0, 1, 1, 0],
        [0, 0, 1, 0, 1, 1],
        [1, 0, 0, 1, 0, 1],
        [0, 1, 0, 0, 1, 1],     # multiplies d24
        [1, 0, 1, 0, 0, 1],     # multiplies D29*
        [0, 1, 0, 1, 1, 0]      # multiplies D30*
        ])

    # Other needed parameters
    BPW: int = 30    # Bits per word
    DBPW: int = 24   # Databits per word

    # Check parameter validity
    nsf = np.size(ws) / gpsc.bpsf
    if nsf != np.floor(nsf):
        raise RuntimeError('WS must contain a multiple of subframes')
    nsf = int(nsf)

    # Total number of words
    nw = nsf * gpsc.wpsf

    # Flag to indicate whether parity check passed.
    passed = True

    # Private copy of the input, arranged as one BPW-bit word per row
    words = np.array(ws).reshape(nw, BPW)

    # D29* and D30*: last two transmitted bits of the previous word (0 before the first word)
    d29_prev, d30_prev = 0, 0
    for word in words:  # each row is a view into `words`, so the flips below are kept
        # Transmitted data bits are d_i XOR D30*: restore the source bits when D30* is 1
        if d30_prev == 1:
            word[:DBPW] = 1 - word[:DBPW]

        # Recompute D25..D30 from d1..d24, D29*, D30* and compare with the received parity bits
        parity = np.append(word[:DBPW], [d29_prev, d30_prev]).dot(H) % 2
        if np.any(parity != word[DBPW:]):
            passed = False

        # D29 and D30 are never complemented; they drive the decoding of the next word
        d29_prev, d30_prev = word[BPW - 2], word[BPW - 1]

    # Report the result of the check
    if not passed:
        raise ParityCheckFailed('establishParity:unsuccessful. Parity check failed.')
    else:
        print('Parity check passed')

    return words.reshape(np.shape(ws))


def bits2subframes(bits: np.array) -> (np.array, List):
    """
    Returns a matrix containing the subframes in the desired order.

    Provided that BITS is sufficiently long, SUBFRAMES contains the subframes with
    id 1,2,3, in that order. (It might contain additional subframes
    and the first column is not necessarily subframe 1.)
    These are the subframes that we use to obtain the ephemerides.

    :param bits: must be a row vector containing a concatenation of subframes (0 and 1 elements)
        that have already been checked for parity.

    :raise NotIntegralNumberOfSubframes: if the length of BITS is not a multiple of gpsc.bpsf

    :return: (subframes, subframes_IDs), where
        SUBFRAMES is the matrix having as its columns the subframes extracted from bits, in transmission order;
        SUBFRAMES_IDS is the list of subframes IDs (one int per column)
    """

    bpsf = gpsc.bpsf
    bits = np.asarray(bits).ravel()

    if bits.size % bpsf != 0:
        raise NotIntegralNumberOfSubframes(
            f'bits2subframes: {bits.size} bits is not a multiple of {bpsf} bits per subframe')
    nsf = bits.size // bpsf

    # Column-major fill: column j holds bits j*bpsf ... (j+1)*bpsf - 1
    subframes = np.reshape(bits, (bpsf, nsf), order='F')

    # The subframe ID is stored in bits 50-52 of each subframe (the field readEphemeris reads as sfid)
    subframes_IDs = [int(read_unsigned(subframes, j + 1, 50, 3, 0)) for j in range(nsf)]

    return subframes, subframes_IDs


def computePseudorange(taus: np.ndarray, tau_ref: int, idx_firstBitSubframe: int) -> float:
    """
    Compute pseudorange for one GPS satellite.

    NOTE(review): not implemented yet. The previous body silently returned None despite the ``float``
    annotation. It now fails loudly so that callers do not propagate a bogus value.

    :param taus: vector of indices into the received samples where each decoded
        bit begins. The length of TAUS is equal to the number of decoded bits.
    :param tau_ref: (scalar) index into the vector of received samples where the
        position of the receiver is to be computed.
    :param idx_firstBitSubframe: index into the sequence of bits where the first
        subframe starts for this satellite.

    :raise NotImplementedError: always, until the computation is written

    :return: computed pseudorange for the satellite (in meters).
    """

    # tau_ref should be bigger than 0 (the reference time should be after turning on the receiver)
    # tau_ref is considered to be lower than the start of the first complete subframe

    raise NotImplementedError('computePseudorange: pseudorange computation is not implemented')

