import numpy as np
import matplotlib.pyplot as plt
from scipy import fft
import scipy.io as sio

# Problem 2

# System parameters used throughout the script
N = 256  # number of carriers
L = 20  # length of the cyclic prefix
T = 1e-6  # [seconds]
Noise_variance = 0.3  # noise variance after the DFT at the receiver

# Read the two MATLAB files first, then pull out the stored arrays.
pnSeq = sio.loadmat('pn_sequence.mat')
dataSymb = sio.loadmat('data_symbols.mat')

# loadmat returns a dict keyed by variable name; each array is flattened
# to 1-D for the processing that follows.
pn_sequence = pnSeq['pn_sequence'].flatten()
data_symbols = dataSymb['data_symbols'].flatten()

# Get the indices of the carriers where we transmit data symbols.
# Each frequency f is mapped to DFT bin (f*T) mod N, so negative frequencies
# wrap around to the upper half of the bin range. The product f*T is a
# floating-point number, so it is rounded to the nearest integer before the
# cast: plain truncation would select the wrong bin whenever the product
# lands slightly below the intended integer.
frequencies = np.array([30, 80, 110, -20, -70, -100])*1e6
index_carriers = np.rint(frequencies*T).astype(int) % N
print('index_carriers = ', index_carriers)

no_useful_carriers = index_carriers.size

# Generate the training sequence (same symbol [1+1j] on all carriers)
training_sequence = np.ones(N) + 1j*np.ones(N)

# See if we need to do some zero padding so that the Tx data symbols fill
# a whole number of OFDM blocks (no_useful_carriers symbols per block)
no_ofdm_blocks = int(np.ceil(data_symbols.size/no_useful_carriers))
zero_padding = no_ofdm_blocks*no_useful_carriers - data_symbols.size
print('zero_padding = ', zero_padding)
# NOTE: the builtin `complex` is used as dtype; the `np.complex` alias was
# deprecated in NumPy 1.20 and removed in NumPy 1.24.
data_padded = np.concatenate((data_symbols, np.zeros(zero_padding, dtype=complex)))

# Create the matrix with the data symbols: one column per OFDM block,
# filled column by column (Fortran order)
B = np.reshape(data_padded, (no_useful_carriers, no_ofdm_blocks), order='F')

# We need one extra OFDM block (column 0) for the training sequence
A = np.zeros((N, no_ofdm_blocks + 1), dtype=complex)
# Insert the training sequence
A[:, 0] = training_sequence
# Insert the data symbols on the selected carriers; all other carriers of
# the data blocks stay at zero
A[index_carriers, 1:] = B

# Do the IFFT (one transform per column / OFDM block)
a0 = fft.ifft(A, N, axis=0)
# Add the CP: the last L samples of each block are copied in front of it
a = np.concatenate((a0[N-L:N, :], a0), axis=0)

# Serialize the Tx samples block after block
tx_samples = a.flatten(order='F')

# Add the P/N sequence in front of the OFDM samples
tx_samples = np.concatenate((pn_sequence, tx_samples))

# Sample-level impulse response of the ISI channel
h = np.array([3.2, 1.7, 1.3, 1, 0.8, 0.45, 0.3, 0.15, 0.05])

# Pass the transmitted samples through the channel (full linear convolution)
samples = np.convolve(tx_samples, h, mode='full')

# Noise_variance is specified after the receiver DFT. That DFT is not a
# unitary transformation, so the sample-level variance is scaled down by N.
sigma2 = Noise_variance/N
sigma = np.sqrt(sigma2)

# Draw the real part first and the imaginary part second, then scale so the
# complex noise has total variance sigma2
noise_real = np.random.randn(samples.size)
noise_imag = np.random.randn(samples.size)
noise = (sigma/np.sqrt(2)) * (noise_real + 1j * noise_imag)

# Received samples = channel output + noise
rx_samples = samples + noise

# Locate the P/N sequence: convolving with its conjugated, time-reversed
# copy correlates the received samples against it; the strongest peak
# marks where the sequence starts.
matched_filter = np.conj(np.flip(pn_sequence))
R = np.convolve(rx_samples, matched_filter, mode='valid')
index_start_pn = np.argmax(np.abs(R))
print('index_start_pn = ', index_start_pn)

# Skip everything up to and including the P/N sequence
first_ofdm_sample = index_start_pn + pn_sequence.size
rx_samples = rx_samples[first_ofdm_sample:]

# Keep only whole OFDM blocks (N carriers + L prefix samples each);
# whatever is left over at the end is discarded
block_length = N + L
no_rx_ofdm_blocks = rx_samples.size // block_length
print('no_rx_ofdm_blocks = ', no_rx_ofdm_blocks)
rx_samples = rx_samples[:no_rx_ofdm_blocks*block_length]

# Arrange the samples with one OFDM block per column
rx_samples = np.reshape(rx_samples, (block_length, no_rx_ofdm_blocks), order='F')

# Drop the cyclic prefix, i.e. the first L rows of every column
rx_samples = rx_samples[L:L + N, :]
# Back to the frequency domain, one FFT per block
rx_fft = fft.fft(rx_samples, n=N, axis=0)

# Channel estimate in the frequency domain (lambdas): the first received
# block carries the known training sequence, so divide it out per carrier
lambdas = rx_fft[:, 0] / training_sequence

# Reference values: N-point FFT of the channel impulse response
th_lamdas = fft.fft(h, N)

# Overlay the magnitudes of both curves to check that they agree
l1 = plt.plot(np.abs(lambdas), '-*b')[0]
l2 = plt.plot(np.abs(th_lamdas), '-+g')[0]
plt.grid()
plt.title('The lambdas')
plt.xlabel('Frequency index')
plt.ylabel('abs(lambdas)')
plot_lines = [l1, l2]
legend_labels = ["Estimated (LS)", "Theoretical"]
legend = plt.legend(plot_lines, legend_labels, loc=1)
plt.gca().add_artist(legend)
plt.show()

# Equalize: keep the channel estimates of the data carriers only and divide
# every data block (columns 1 onward) by them, carrier by carrier
lambdas = lambdas[index_carriers]
eq_symbols = rx_fft[index_carriers, 1:] / lambdas[:, np.newaxis]

# Back to a serial stream, block after block
eq_symbols = eq_symbols.flatten(order='F')

# Keep only as many symbols as were originally loaded
eq_symbols = eq_symbols[:data_symbols.size]

# Scatter plot of the equalized symbols
plt.scatter(eq_symbols.real, eq_symbols.imag, marker='*', color='b')
plt.title('Equalized symbols at Rx')
plt.xlabel('Real')
plt.ylabel('Imag')
plt.grid()
plt.show()

# Decide each symbol from the signs of its real and imaginary parts
# (+1 when strictly positive, -1 otherwise), then count the mismatches
# against the loaded data symbols to get the symbol error rate
dec_real = np.where(eq_symbols.real > 0, 1, -1)
dec_imag = np.where(eq_symbols.imag > 0, 1, -1)
dec_symbols = dec_real + 1j*dec_imag
symbol_errors = np.sum(dec_symbols != data_symbols)
SER = symbol_errors/data_symbols.size
print('SER = ', SER)
