import numpy as np
import matplotlib.pyplot as plt
import my_utilOFDM
import utilPDC
import warnings

# TEST_OFDM Script to test the OFDM transmitter / channel / receiver

# System parameters
num_carriers = 256          # total number of carriers
num_zeros = 5               # number of unused carriers (at each end of the frequency spectrum)
prefix_length = 25          # length of the cyclic prefix
num_ofdm_symbols = 10_000   # number of OFDM symbols per frame (1 will be used to transmit the preamble and the rest is for data)
M_preamble = 4              # we use 4-QAM for the preamble
M_data = 4                  # we use 4-QAM for the data

SNR_dB = 20  # in dB

# Derived parameters
num_useful_carriers = num_carriers - 2 * num_zeros

# Fail fast on configurations the rest of the script cannot handle: with no
# useful carriers nothing can be transmitted, and with fewer than two OFDM
# symbols there is no room for data after the preamble (the SER would be 0/0).
if num_useful_carriers <= 0:
    raise ValueError(f'num_carriers ({num_carriers}) must be larger than '
                     f'2 * num_zeros ({2 * num_zeros})')
if num_ofdm_symbols < 2:
    raise ValueError(f'num_ofdm_symbols ({num_ofdm_symbols}) must be >= 2: '
                     f'one preamble symbol plus at least one data symbol')

# Preamble and data constellation
constel_preamble = utilPDC.sol_qamMap(M_preamble)
constel_data = utilPDC.sol_qamMap(M_data)

# Transmitter
# Draw random symbol indices and map them onto their QAM constellations.
# The preamble fills one OFDM symbol (one index per useful carrier); the data
# fill the remaining num_ofdm_symbols - 1 OFDM symbols. The preamble is drawn
# first so the random sequence matches the order of the frame.
preamble = np.random.randint(0, M_preamble, size=num_useful_carriers)
preamble_symbols = utilPDC.sol_encoder(preamble, constel_preamble)

num_data_indices = num_useful_carriers * (num_ofdm_symbols - 1)
data = np.random.randint(0, M_data, size=num_data_indices)
data_symbols = utilPDC.sol_encoder(data, constel_data)

# Spectral mask: ones on the useful carriers, padded with num_zeros zeros on
# each side (guard carriers at both ends of the spectrum), then reordered with
# fftshift because the positive frequencies come first!
useful_band = np.ones(num_useful_carriers)
psd_mask = np.fft.fftshift(np.pad(useful_band, num_zeros))

# Generate OFDM signal to be transmitted
tx_signal = my_utilOFDM.ofdm_tx_frame(num_carriers, prefix_length, preamble_symbols, psd_mask, data_symbols)
# Average power per transmitted sample, mean(|x|^2). This equals
# var(tx_signal) only when the signal has zero mean.
E_tx = np.mean(np.abs(tx_signal) ** 2)

# Channel

# The CP must satisfy CP >= L-1, where L is the channel length
# (L-1 is the length of the channel history), so this is the longest
# impulse response the cyclic prefix can absorb without ISI.
channel_length = prefix_length + 1

# Channel model used for the test:
#   'multipath'      - multipath channel described in class (default)
#   'random_complex' - random complex impulse response of length channel_length
#   'random_real'    - random real impulse response of length channel_length
#   'awgn'           - non frequency selective channel (no ISI): only WGN is added
channel_model = 'multipath'
normalize_h = False  # set to True to scale h to unit energy (||h|| = 1)

if channel_model == 'multipath':
    amplitudes = np.array([1.3290, 0.6754, 0.6516, -0.2432])
    delays = np.array([0, 0.5, 1.5, 2])
    lengthTailsSinc = 10  # length of the tails for the sinc (reconstruction filter)
    h = my_utilOFDM.create_multipath_channel_filter(amplitudes, delays, lengthTailsSinc)
elif channel_model == 'random_complex':
    h = (1/np.sqrt(2)) * (np.random.randn(channel_length) + 1j * np.random.randn(channel_length))
elif channel_model == 'random_real':
    h = np.random.randn(channel_length)
elif channel_model == 'awgn':
    # A single unit tap. np.concatenate takes ONE sequence of arrays, so the
    # pieces must be wrapped in a tuple (passing them as two arguments would
    # make the second one the `axis` argument and raise).
    h = np.concatenate((np.ones(1), np.zeros(channel_length - 1)))
else:
    raise ValueError(f'Unknown channel_model: {channel_model!r}')

if normalize_h:
    h = h / np.linalg.norm(h)

if h.size > prefix_length + 1:
    warnings.warn('test_ofdm:impulseResponseTooLong: '
                  'The channel impulse response is larger than the cyclic prefix length => it will create ISI.', RuntimeWarning)

# convolve tx_signal with channel impulse response (np.convolve's default
# 'full' mode returns tx_signal.size + h.size - 1 samples)
rx_signal = np.convolve(tx_signal, h)

# add complex AWGN; the SNR is defined relative to the transmitted power E_tx
sigma2 = E_tx / (10 ** (SNR_dB / 10))
sigma = np.sqrt(sigma2)
# the noise variance sigma2 is split equally between real and imaginary parts
noise = (sigma / np.sqrt(2)) * (np.random.randn(rx_signal.size) + 1j * np.random.randn(rx_signal.size))

rx_signal_noisy = rx_signal + noise

# Receiver
Rf = my_utilOFDM.ofdm_rx_frame(rx_signal_noisy, num_carriers, psd_mask, prefix_length)

# Channel coefficients (lambdas) estimation

# h known: the coefficients are computed directly from the true impulse response
lambda1 = my_utilOFDM.channel_est1(num_carriers, psd_mask, h)

# Channel equalization, demodulation and determining SER
# The channel response is assumed to remain constant during the whole frame.
# One-tap equalizer: every carrier (row of Rf) is multiplied by the reciprocal
# of its channel coefficient. Broadcasting the reciprocal as a column applies
# it to every OFDM symbol (column of Rf) without materializing a tiled
# carriers x symbols copy of 1/lambda1.
eq_signal1 = Rf * (1 / lambda1)[:, np.newaxis]

# Read the equalized grid column by column (one OFDM symbol after another).
suf_statistics1 = eq_signal1.flatten('F')
# The first preamble_symbols.size statistics belong to the preamble; the rest
# are the data statistics compared against `data`.
suf_data1 = suf_statistics1[preamble_symbols.size:]
estim_data1 = utilPDC.sol_decoder(suf_data1, constel_data)
SER1 = np.mean(estim_data1 != data)  # fraction of wrongly decoded symbols
print('SER = ', SER1)

# Debugging information: scatter the equalized data statistics on top of the
# reference constellation.
suf_data1 = suf_statistics1[preamble_symbols.size:]

# The full data set can contain millions of points (one per useful carrier per
# data OFDM symbol), which makes plt.scatter very slow. A random subset shows
# the same clusters. Random (rather than strided) selection avoids repeatedly
# picking the same few carriers in every OFDM symbol. A separate generator is
# used so the global np.random stream is left untouched.
max_points_to_plot = 20000
if suf_data1.size > max_points_to_plot:
    plot_idx = np.random.default_rng().choice(suf_data1.size, size=max_points_to_plot, replace=False)
    suf_plot1 = suf_data1[plot_idx]
else:
    suf_plot1 = suf_data1

plt.scatter(suf_plot1.real, suf_plot1.imag, marker='*', color='b', label='Received constellation')
plt.scatter(constel_data.real, constel_data.imag, marker='+', color='r', label='Transmitted constellation')
plt.ylabel('Imaginary')
plt.xlabel('Real')
plt.title('Received constellation after OFDM demodulation and equalization')
plt.grid()
plt.legend(loc=1)
plt.show()
