import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio

# Problem 3

# Read the measured data from the MATLAB .mat files. In each file, the
# variable of interest is stored under the same name as the file.
rx = sio.loadmat('received_OFDM_symbols.mat')
ts = sio.loadmat('training_symbols.mat')
lmda = sio.loadmat('lambdas.mat')

# Matrix of received OFDM symbols, one column per OFDM symbol: column 0 is
# used below as the training column and columns 1: as data.
received_OFDM_symbols = rx['received_OFDM_symbols']

# The known training symbols and the per-frequency-index channel
# coefficients, both reduced to 1-D arrays.
training_symbols = ts['training_symbols'].flatten()
lambdas = lmda['lambdas'].flatten()


# Noise-variance estimate from the training column. Column 0 of the received
# matrix was produced by the known training symbols, so subtracting the
# noise-free prediction (training_symbols * lambdas) leaves only the noise.
# Assuming i.i.d. noise components, the sample mean of |residual|^2 is an
# estimate of the (complex) noise variance.
noise_residual = received_OFDM_symbols[:, 0] - training_symbols * lambdas
noise_var = np.mean(np.abs(noise_residual) ** 2)
print('noise_var = ', noise_var)


# Fallback: if the noise variance cannot be estimated, a value of 0.01 may
# be used for noise_var in the computations that follow.

# Least-squares (LS) estimate of the transmitted data symbols, taken from the
# data columns (index 1 onward) of the received matrix. The channel matrix D
# is diagonal, so the three forms below give the same result:
#   estim1 - divide each row (frequency index) by its lambda; no matrix
#            is formed or factorized, and this is the version used later;
#   estim2 - multiply by the pseudo-inverse of D;
#   estim3 - solve the linear least-squares problem D X = Y directly.
D = np.diag(lambdas)
data_LS_estim1 = received_OFDM_symbols[:, 1:] / lambdas.reshape(-1, 1)
data_LS_estim2 = np.linalg.pinv(D) @ received_OFDM_symbols[:, 1:]
data_LS_estim3 = np.linalg.lstsq(D, received_OFDM_symbols[:, 1:], rcond=None)[0]


# Scatter-plot the LS estimates in the complex plane (symbols taken column
# by column); the result should look like a noisy 4-QAM constellation.
suf_data = data_LS_estim1.flatten('F')
plt.scatter(x=suf_data.real, y=suf_data.imag, marker='*', color='b')
plt.title('The constellation of the estimated data symbols (LS)')
plt.xlabel('Real')
plt.ylabel('Imaginary')
plt.grid()
plt.show()


# LMMSE estimate of the transmitted data symbols:
#     s_hat = K_sy @ inv(K_y) @ y,  K_sy = D^H,  K_y = D D^H + noise_var * I
# NOTE(review): K_sy = D^H and the D D^H term assume zero-mean data symbols
# with identity covariance (unit energy) - confirm for this dataset.
Kay = D.conj().T
Ky = D @ Kay + noise_var * np.eye(lambdas.size)
# Solve K_y X = Y rather than forming inv(K_y) explicitly.
tmp_matrix = np.linalg.lstsq(Ky, received_OFDM_symbols[:, 1:], rcond=None)[0]
data_LMMSE_estim = Kay @ tmp_matrix


# Scatter-plot the LMMSE estimates in the complex plane (symbols taken
# column by column); the result should look like a noisy 4-QAM constellation.
suf_data = data_LMMSE_estim.flatten('F')
plt.scatter(x=suf_data.real, y=suf_data.imag, marker='*', color='b')
plt.title('The constellation of the estimated data symbols (LMMSE)')
plt.xlabel('Real')
plt.ylabel('Imaginary')
plt.grid()
plt.show()


# Plot the squared magnitude |lambdas|^2 of the channel coefficient at each
# frequency index. Comparing these values with noise_var shows where the LS
# and LMMSE estimates can be expected to differ.
# Fix: previously np.abs(lambdas) was plotted (not squared), which
# contradicted both this comment and the plot title.
plt.plot(np.abs(lambdas) ** 2, '-*b')
plt.xlabel('Frequency index')
plt.ylabel('Magnitude squared')
plt.title('|lambdas|^2')
plt.grid()
plt.show()


# Compare the LS and LMMSE estimators through the mean absolute difference
# between corresponding symbol estimates (elements taken in column-major
# order, as in the constellation plots). The explanation of the result
# follows below.
abs_diff = np.abs(data_LS_estim1 - data_LMMSE_estim)
average_difference_estimated_symbols = abs_diff.flatten('F').mean()
print('average_difference_estimated_symbols = ', average_difference_estimated_symbols)

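# The two methods produce very similar estimates. Because D is diagonal, the
# estimates decouple per frequency index k: LS gives y_k / lambda_k, while
# LMMSE gives conj(lambda_k) * y_k / (|lambda_k|^2 + noise_var). The diagonal
# elements of D @ D.conj().T (i.e., np.abs(lambdas) ** 2) are much larger
# than the noise variance, so noise_var is negligible in the LMMSE
# denominator and the LMMSE formula reduces to the LS one.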
