import numpy as np
import ofdmc as ofdmc
import warnings
from scipy import signal
import matplotlib.pyplot as plt
from fractions import Fraction
import utilPDC
import utilMDC
import utilOFDM
from scipy import fft
from scipy.signal import correlate as corr
import usrpc
import threading

import utilSoapySDR

import SoapySDR
from SoapySDR import *  # SOAPY_SDR_ constants


def runUSRPSinglePC(tx_signal, mode_of_operation):
    """
    USRP transmission and reception.

    The SDR setup is chosen from the number of detected devices and the mode:

    - no device: a ValueError is raised;
    - one device, or mode 'TX' / 'RX': the first detected device is used.
      'TX' transmits for usrpc.txDuration and then ends the program via
      sys.exit(0); 'RX' receives for usrpc.rxDuration; any other mode runs
      TX and RX on that same device (usrpTxRx);
    - two or more devices and any other mode: the first serial transmits and
      the second receives (dualUsrpTxRx).

    :param tx_signal: the samples to transmit (numpy array)
    :param mode_of_operation: 'TX', 'RX', or any other value for combined TX/RX
    :return: the received samples selected by analyze_received_data
    :raises ValueError: if no SDR device is detected
    """
    # sys.exit() instead of the builtin exit(), which the site module only
    # provides for interactive use.
    import sys

    #
    # 1) Configuration of the driver for the SDR
    #

    # Disable logging for operation
    SoapySDR.setLogLevel(SOAPY_SDR_FATAL)

    # Get available devices
    devices = utilSoapySDR.get_available_devices(usrpc.driver)
    available_serials = np.array([device['serial'] for device in devices])
    print(f"Available devices found: {available_serials}")

    #
    # 2) Select mode of operation
    #

    # Use 4 copies of the tx_signal to extract the data
    lengthDataOut = 4 * tx_signal.size

    # No device found -> abort (the channel-simulator fallback is disabled)
    if available_serials.size == 0:
        raise ValueError('No SDR boards detected. Please check if the boards are properly connected and try again.')

    # One device found, or a TX-only / RX-only mode -> a single device suffices
    elif available_serials.size == 1 or mode_of_operation in ('TX', 'RX'):
        print(f"Using {usrpc.driver} device with serial {available_serials[0]}.")

        # Initialize the USRP device
        sdr = utilSoapySDR.sdr_create_device(usrpc.driver, available_serials[0])

        print("[SDR] Device initialized.")

        # Configure the USRP device (both directions, channel 0)
        utilSoapySDR.sdr_setup_tx(sdr, 0, usrpc)
        utilSoapySDR.sdr_setup_rx(sdr, 0, usrpc)

        print("[SDR] Configuration complete.")
        print("[SDR] Mode:", mode_of_operation)

        if mode_of_operation == 'TX':
            # TX only -> another device receives the signal; stop after transmission
            usrpTx(sdr, tx_signal, usrpc.txDuration)
            sys.exit(0)
        elif mode_of_operation == 'RX':
            # RX only -> capture usrpc.rxDuration seconds into a complex64 buffer
            noSamples = int(usrpc.rxDuration / usrpc.Ts)
            rx_samples_buff = np.zeros((noSamples,), dtype=np.complex64)

            rx_signal = usrpRx(sdr, rx_samples_buff)
        else:
            # Use the same device for TX and RX (loopback, by default)
            rx_signal = usrpTxRx(sdr, tx_signal)

    # Two or more devices -> use one for TX and one for RX
    else:
        print(f"Multiple {usrpc.driver} devices found.")
        print(f"First device (with serial no.) {available_serials[0]} will be used for TX and {available_serials[1]} for RX.")

        # Initialize the USRP devices
        sdr_tx = utilSoapySDR.sdr_create_device(usrpc.driver, available_serials[0])
        sdr_rx = utilSoapySDR.sdr_create_device(usrpc.driver, available_serials[1])

        print("[SDR] Devices initialized.")

        # Configure the USRP devices
        utilSoapySDR.sdr_setup_tx(sdr_tx, 0, usrpc)
        utilSoapySDR.sdr_setup_rx(sdr_rx, 0, usrpc)

        print("[SDR] Configuration complete.")

        # Transmit and receive concurrently on the two devices
        rx_signal = dualUsrpTxRx(sdr_tx, sdr_rx, tx_signal)

    # Trim the capture to the useful samples (and plot diagnostics if verbose)
    rx_symbols = analyze_received_data(rx_signal, lengthDataOut, tx_signal.size)

    return rx_symbols


def dualUsrpTxRx(sdr_tx, sdr_rx, samplesIn):
    """
    USRP transmitter and receiver on two separate devices.

    - Runs reception (usrpRx on sdr_rx) and transmission (usrpTx on sdr_tx)
      concurrently in two worker threads; the receiver is submitted first.
    - The transmitter sends repeated copies of samplesIn for at least
      1.2 * usrpc.txRxDuration seconds, so the receive window is covered.
    - The receiver fills a complex64 buffer of int(usrpc.txRxDuration / usrpc.Ts)
      samples, which is returned.

    Each device schedules its stream from its own hardware clock, so the two
    streams are only approximately aligned.

    Any exception raised by the transmit or receive side is re-raised here
    (after both workers have finished) instead of being lost in a thread and
    returning an untouched, all-zero buffer.

    :param sdr_tx: the device configured for TX
    :param sdr_rx: the device configured for RX
    :param samplesIn: the samples to transmit
    :return: the buffer with the samples received on sdr_rx
    """
    from concurrent.futures import ThreadPoolExecutor

    # RX setup: buffer that the receive worker fills in place
    noSamples = int(usrpc.txRxDuration / usrpc.Ts)
    rx_samples_buff = np.zeros((noSamples,), dtype=np.complex64)

    # Leaving the `with` block waits for both workers; result() re-raises
    # any exception from the corresponding worker.
    with ThreadPoolExecutor(max_workers=2) as pool:
        rx_future = pool.submit(usrpRx, sdr_rx, rx_samples_buff)
        # TX sends for a bit longer to make sure the whole RX window is covered
        tx_future = pool.submit(usrpTx, sdr_tx, samplesIn, 1.2 * usrpc.txRxDuration)

        tx_future.result()
        rx_future.result()

    # Return the samples from the second USRP
    return rx_samples_buff


def usrpTx(sdr, samplesIn, duration):
    """
    USRP transmitter.

    Tiles samplesIn so that the burst lasts at least `duration` seconds at the
    sample period usrpc.Ts (and contains at least 10 copies). The burst is then
    scheduled on `sdr` 0.5 s after the device's current hardware time.

    :param sdr: the SoapySDR device, configured for TX
    :param samplesIn: the input samples (must be non-empty)
    :param duration: the duration to do TX in seconds
    :return: None
    :raises ValueError: if samplesIn is empty (no repeat count can be derived)
    """
    if samplesIn.size == 0:
        raise ValueError("usrpTx: samplesIn must contain at least one sample.")

    print(f'\nTx over USRP, duration of {duration} [seconds].\n')
    # Copies needed to fill `duration` seconds (min 10 times); kept as an int so
    # the printed count and np.tile agree.
    noRepData = int(max(10, np.ceil(duration / (usrpc.Ts * samplesIn.size))))
    print(f'Repeating the data {noRepData} times.\n')

    dataToTx = np.tile(samplesIn, noRepData).astype(np.complex64)

    # Hardware time is in nanoseconds: schedule the transmission 0.5 s ahead
    tx_time_ns = int(sdr.getHardwareTime() + 0.5e9)

    # Execute only the transmission
    utilSoapySDR.sdr_transmit(sdr, tx_time_ns, dataToTx)


def usrpRx(sdr, rx_samples_buff):
    """
    Receive samples from a single USRP into a caller-provided buffer.

    The capture is scheduled 0.5 s (0.5e9 ns of hardware time) after the
    device's current hardware time, and requests exactly rx_samples_buff.size
    samples.

    :param sdr: the SoapySDR device used for reception
    :param rx_samples_buff: pre-allocated buffer; utilSoapySDR.sdr_receive is
        expected to fill it in place (callers such as dualUsrpTxRx rely on this)
    :return: the same buffer object that was passed in
    """
    # Scheduled start: now + 0.5 s, the same lead time usrpTx uses
    start_ns = int(sdr.getHardwareTime() + 0.5e9)
    buffer_len = int(rx_samples_buff.size)

    utilSoapySDR.sdr_receive(sdr, start_ns, rx_samples_buff, buffer_len)

    print("rx_buffs size: ", rx_samples_buff.size)
    return rx_samples_buff


def usrpTxRx(sdr, samplesIn):
    """
    USRP transmitter and receiver on a single device (loopback).

    Unlike dualUsrpTxRx, both streams are scheduled against the same device
    clock, from a single hardware-time reading. Their relative timing is
    therefore fixed: RX starts 0.1 s before TX so that no samples of the burst
    are missed.

    Any exception raised by the transmit or receive call is re-raised here,
    after both workers have finished.

    :param sdr: the SoapySDR device, configured for both TX and RX
    :param samplesIn: the input samples (must be non-empty)
    :return: rx_samples_buff: complex64 vector of int(usrpc.txRxDuration / usrpc.Ts)
        received samples
    :raises ValueError: if samplesIn is empty (no repeat count can be derived)
    """
    from concurrent.futures import ThreadPoolExecutor

    if samplesIn.size == 0:
        raise ValueError("usrpTxRx: samplesIn must contain at least one sample.")

    print(f'\nTx/Rx over USRP, duration of {usrpc.txRxDuration} [seconds].\n')
    # Repeat the data so TX lasts usrpc.txRxDuration (min 10 times); int so the
    # printed count and np.tile agree.
    noRepData = int(max(10, np.ceil(usrpc.txRxDuration / (usrpc.Ts * samplesIn.size))))
    print(f'Repeating the data {noRepData} times.\n')

    dataToTx = np.tile(samplesIn, noRepData).astype(np.complex64)
    print("data to tx has size:", dataToTx.size)

    ##### SDR Instrumentation #####

    # This will contain the received samples (used for analysis)
    noSamplesRx = int(usrpc.txRxDuration / usrpc.Ts)
    rx_samples_buff = np.zeros((noSamplesRx,), dtype=np.complex64)

    # One hardware-time reading (ns) for both streams, so the TX/RX offset is
    # exactly 0.1 s: RX at +0.4 s, TX at +0.5 s.
    now_ns = sdr.getHardwareTime()
    time_tx = int(now_ns + 0.5e9)
    time_rx = int(now_ns + 0.4e9)

    # Run TX and RX concurrently; leaving the `with` block waits for both, and
    # result() re-raises a worker's exception instead of losing it.
    with ThreadPoolExecutor(max_workers=2) as pool:
        tx_future = pool.submit(utilSoapySDR.sdr_transmit, sdr, time_tx, dataToTx)
        rx_future = pool.submit(utilSoapySDR.sdr_receive, sdr, time_rx, rx_samples_buff, noSamplesRx)

        print("[TRX] Waiting for threads to finish...")

        tx_future.result()
        rx_future.result()

    print("[TRX] complete!")
    print("[TRX] Samples: ", rx_samples_buff.size)

    # return the samples buffer
    return rx_samples_buff


def analyze_received_data(dataRx, lengthDataOut, tx_signal_size):
    """
    Select the useful part of the received samples (optionally plotting diagnostics).

    The selection window ends half a TX frame (tx_signal_size // 2 samples)
    before the end of dataRx, to skip the end-of-burst transient. It extends at
    most lengthDataOut samples back from there. If dataRx is shorter than half
    a frame, the selection is empty rather than a wrapped-around slice.

    When ofdmc.verbose is set, this also shows an RX constellation scatter, the
    magnitude trace with the selected window, and the RX spectrum.

    :param dataRx: the received complex samples (numpy array)
    :param lengthDataOut: maximum number of samples to keep
    :param tx_signal_size: number of samples in one copy of the transmitted signal
    :return: the selected slice of dataRx
    """
    # scatter plot to check that no ADC saturation happened
    if ofdmc.verbose:
        plt.scatter(dataRx.real, dataRx.imag, marker='*', color='b')
        plt.grid()
        plt.title('RX Constellation')
        plt.show()

    # Drop half a frame at the end to avoid the transients. Clamp at 0 so a
    # short capture gives an empty selection instead of a negative slice end.
    stopData = max(0, dataRx.size - int(tx_signal_size // 2))
    # max() as well, since dataRx may be shorter than lengthDataOut
    startData = max(0, stopData - lengthDataOut)

    # Truncated version of the received samples: the last copies before stopData
    rx_symbols = dataRx[startData:stopData]

    if ofdmc.verbose:
        # plot the magnitude of the Rx data
        plt.grid()
        plt.plot(np.arange(dataRx.size), np.abs(dataRx))
        plt.title('Rx data: vertical bars show selected data')
        plt.xlabel('Sample Index')
        plt.ylabel('Magnitude')
        plt.axvline(x=startData, color='r')
        plt.axvline(x=stopData, color='r')
        plt.show()

        # Plot the spectrum of the Rx data. fftfreq yields exactly N bins that
        # match fftshift ordering; a float-step arange can produce N+1 points.
        N = dataRx.size
        freqLineo = np.fft.fftshift(np.fft.fftfreq(N, d=usrpc.Ts))
        plt.plot(freqLineo, np.fft.fftshift(np.abs(np.fft.fft(dataRx, N))))
        plt.grid()
        plt.xlabel('f [Hz]')
        plt.ylabel('Magnitude')
        plt.title('OFDM Rx Spectrum')
        plt.show()

    print(f"[TRX] Processing {rx_symbols.size} the samples...")

    return rx_symbols
