import numpy as np
import matplotlib.pyplot as plt
from scipy import fft

# Problem 2

# ---- System parameters ------------------------------------------------------
# Total number of OFDM sub-carriers; also used as the (I)FFT length.
num_carriers = 512
# Number of samples in the cyclic prefix prepended to every OFDM block.
prefix_length = 25
# Number of OFDM blocks: the first block carries the training sequence and the
# remaining blocks carry the data symbols.
num_ofdm_blocks = 1000
# Target signal-to-noise ratio, in dB.
SNR = 20


# ---- Training and data symbols ----------------------------------------------
# Both sequences are drawn uniformly from the unit-energy 4-QAM alphabet
# {(+-1 +- 1j) / sqrt(2)}. For each sequence the real parts are drawn before the
# imaginary parts, and the training symbols are drawn before the data symbols.

def _random_4qam_parts(count):
    """Return two independent arrays of `count` random +-1 integers (real, imag)."""
    in_phase = 2 * np.random.randint(2, size=count) - 1
    quadrature = 2 * np.random.randint(2, size=count) - 1
    return in_phase, quadrature


training_real, training_imag = _random_4qam_parts(num_carriers)
training_sequence = (1/np.sqrt(2)) * (training_real + 1j * training_imag)

data_real, data_imag = _random_4qam_parts(num_carriers * (num_ofdm_blocks - 1))
data = (1/np.sqrt(2)) * (data_real + 1j * data_imag)

# ---- OFDM transmitter -------------------------------------------------------
# Arrange the symbols in a (num_carriers x num_ofdm_blocks) grid: column 0 holds
# the training sequence and the data fills the remaining columns in
# column-major ('F') order.
# The builtin `complex` is used because the `np.complex` alias no longer exists
# in current NumPy releases.
A = np.zeros((num_carriers, num_ofdm_blocks), dtype=complex)
A[:, 0] = training_sequence
A[:, 1:] = np.reshape(data, (num_carriers, -1), 'F')

# Each column is one OFDM block: take its IFFT along the carrier axis.
a0 = fft.ifft(A, num_carriers, axis=0)
# Prepend the cyclic prefix, i.e. a copy of the last `prefix_length` samples of
# each block. An explicit start index (not `-prefix_length`) keeps
# prefix_length == 0 working.
a = np.concatenate((a0[num_carriers - prefix_length:, :], a0), axis=0)
# Serialize the blocks one after another (column-major) into the sample stream.
tx_signal = a.flatten('F')


# ---- Channel and AWGN -------------------------------------------------------
# Discrete-time symbol-level channel impulse response, including the filters at
# the transmitter and the receiver.
h = np.array([0.87, 0.62, -0.45, 0.34, -0.12])

# Linear convolution with the channel; the result is len(h) - 1 samples longer
# than tx_signal.
rx_signal = np.convolve(tx_signal, h)

# The noise is scaled so that the SNR (in dB) holds with respect to the average
# power of the *transmitted* signal: sigma**2 = E_tx / 10**(SNR/10).
E_tx = np.mean(np.square(np.abs(tx_signal)))
sigma = np.sqrt(E_tx) / (10 ** (SNR / 20))

# Complex Gaussian noise with variance sigma**2 / 2 per real dimension; the
# real-part samples are drawn before the imaginary-part samples.
num_rx_samples = rx_signal.size
noise_i = np.random.randn(num_rx_samples)
noise_q = np.random.randn(num_rx_samples)
noise = (sigma / np.sqrt(2)) * (noise_i + 1j * noise_q)
rx_signal_noisy = rx_signal + noise


# ---- OFDM receiver ----------------------------------------------------------
# Samples per OFDM block, cyclic prefix included.
block_length = num_carriers + prefix_length

# Remove the tail: the channel convolution appends extra samples, so keep only
# the whole number of OFDM blocks that fit in the received stream (integer floor
# division).
num_ofdm_blocks_rx = rx_signal_noisy.size // block_length
rx_signal_noisy = rx_signal_noisy[:num_ofdm_blocks_rx * block_length]

# Remove the cyclic prefix. First lay the stream out one OFDM block per column
# (column-major, matching the transmitter's serialization), then drop the first
# `prefix_length` rows.
rx_withCP = np.reshape(rx_signal_noisy, (block_length, num_ofdm_blocks_rx), 'F')
rx_noCP = rx_withCP[prefix_length:, :]

# Go to the frequency domain: FFT of each block along the carrier axis.
Rf = fft.fft(rx_noCP, num_carriers, axis=0)


# ---- Channel estimation -----------------------------------------------------
# LS (least-squares) estimate of the per-carrier channel gains (the lambdas).
# It uses the training block (column 0): divide the received training carriers
# by the known training symbols.
lambda_estimated_LS = Rf[:, 0] / training_sequence

# Theoretical lambdas: num_carriers-point DFT of the (zero-padded) impulse response.
lambda_theoretical = fft.fft(h, num_carriers)

# Compare the magnitudes of the theoretical and estimated lambdas.
# Each line carries its own label, and plt.legend creates the legend exactly once.
plt.plot(np.abs(lambda_theoretical), '-*b', label='Theoretical')
plt.plot(np.abs(lambda_estimated_LS), '-+g', label='LS Estimate')
plt.xlabel('Frequency index')
plt.ylabel('Magnitude')
plt.title('The lambdas')
plt.grid()
plt.legend(loc=1)
plt.show()

# ---- Equalization -----------------------------------------------------------
# Zero-forcing equalization with the theoretical lambdas: multiply every carrier
# of every received block by 1/lambda_k. Broadcasting the reciprocal over the
# block (column) axis avoids building a tiled copy. The result also follows Rf's
# actual shape rather than assuming num_ofdm_blocks columns.
eq_signal = Rf * (1 / lambda_theoretical)[:, np.newaxis]
# Serialize the equalized symbols block by block (column-major).
suf_statistics = eq_signal.flatten('F')


# ---- Constellation plot -----------------------------------------------------
# Drop the equalized training block (the first training_sequence.size symbols)
# and keep only the data symbols.
suf_data = suf_statistics[training_sequence.size:]
plt.scatter(suf_data.real, suf_data.imag, marker='*', color='b', label='Received constellation')
# The transmitted data only takes the few distinct 4-QAM values. Plot each value
# once instead of overplotting every transmitted symbol.
tx_points = np.unique(data)
plt.scatter(tx_points.real, tx_points.imag, marker='+', color='r', label='Transmitted constellation')
plt.ylabel('Imaginary')
plt.xlabel('Real')
plt.title('Received constellation after OFDM demodulation and equalization')
plt.grid()
plt.legend()
plt.show()


# ---- Demodulation and symbol error rate --------------------------------------
# Quadrant decision: the sign of each real/imaginary part selects the 4-QAM point.
estimated_data = (1/np.sqrt(2))*((2*(suf_data.real > 0) - 1) + 1j*(2*(suf_data.imag > 0) - 1))

# Compute the symbol error rate. If everything was correct, there should be no
# symbol errors. A symbol is wrong when its real or imaginary sign differs from
# the transmitted one. Comparing signs avoids depending on exact floating-point
# equality between two separately constructed constellation arrays.
symbol_errors = ((estimated_data.real > 0) != (data.real > 0)) | ((estimated_data.imag > 0) != (data.imag > 0))
SER = np.sum(symbol_errors) / data.size
print('SER = ', SER)
