import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio

# Problem 3

# Number of carriers, i.e. the number of symbols in one OFDM block
N = 256

# Read both MATLAB files first, then pull the arrays of interest out of the
# returned dictionaries and flatten each one into a 1-D vector
data, ofdm_symb = (
    sio.loadmat(mat_file)
    for mat_file in ('data_symbols.mat', 'ofdm_rx_symbols.mat')
)
data_symbols = data['data_symbols'].flatten()
ofdm_rx_symbols = ofdm_symb['ofdm_rx_symbols'].flatten()

# Scatter all received symbols in the complex plane to visualise the rotation
plt.scatter(np.real(ofdm_rx_symbols), np.imag(ofdm_rx_symbols),
            marker='*', color='b')
plt.title('Rotated RX symbols')
plt.xlabel('Real')
plt.ylabel('Imag')
plt.grid()
plt.show()

# Arrange the received stream as a matrix with N rows, filled column by
# column ('F' order), so that each column holds one OFDM block
rx_fft = ofdm_rx_symbols.reshape((N, -1), order='F')

# Scatter only the first OFDM block to check that its rotation is indeed
# small enough
first_block = rx_fft[:, 0]
plt.scatter(np.real(first_block), np.imag(first_block), marker='*', color='b')
plt.title('First OFDM block')
plt.xlabel('Real')
plt.ylabel('Imag')
plt.grid()
plt.show()

# Implement the decoding mechanism: decision-directed phase tracking, one
# OFDM block (column of rx_fft) at a time.
#
# Before a block is decoded, the rotation estimated on the previous block is
# removed from the whole matrix. Because the corrected matrix is carried over
# to the next iteration, the corrections accumulate, and each block only has
# to absorb the residual rotation relative to the block before it.
#
# Results:
#   decoded_symbols  - hard decisions (+-1 +-1j), same shape as rx_fft
#   rot_all_columns  - residual rotation estimated on each block [radians]

# Use the builtin `complex` as dtype: the `np.complex` alias was deprecated in
# NumPy 1.20 and removed in 1.24, where it raises AttributeError.
decoded_symbols = np.zeros(rx_fft.shape, dtype=complex)
rot_col = 0  # no correction is applied before the first block
no_columns = rx_fft.shape[1]
rot_all_columns = np.zeros(no_columns)

for k in range(no_columns):

    # Remove the rotation estimated on the previous block. The whole matrix
    # is rotated (already-decoded columns included, which is harmless) so
    # that all following columns inherit the correction.
    rx_fft = np.exp(-1j*rot_col)*rx_fft

    # Get the current column
    col = rx_fft[:, k]

    # Hard decision per symbol: the sign of the real and imaginary parts
    # selects one of the four points +-1 +-1j (a zero component maps to -1)
    dec_symbols_col = (2*(col.real > 0) - 1) + 1j*(2*(col.imag > 0) - 1)

    # Compute the average rotation over all symbols in the block, measured as
    # the phase of each received symbol relative to its decided point
    rot_col = np.mean(np.angle(col/dec_symbols_col))

    rot_all_columns[k] = rot_col
    decoded_symbols[:, k] = dec_symbols_col

# Show the rotation estimated on each OFDM block, one marker per block, to
# check whether it stays constant from block to block
plt.plot(rot_all_columns, '*')
plt.title('Estimated rotation over all the OFDM blocks')
plt.xlabel('OFDM block index')
plt.ylabel('Angle [radians]')
plt.grid()
plt.show()

# Read the decisions back out column by column ('F' order), i.e. block after
# block, and compare them with the reference data symbols. Only the first
# data_symbols.size decisions take part in the comparison.
ds = decoded_symbols.flatten(order='F')
n_reference = data_symbols.size
symbol_errors = ds[:n_reference] != data_symbols

# Symbol error rate = number of mismatching symbols / number of reference symbols
SER = np.sum(symbol_errors) / n_reference
print('SER = ', SER)
