import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio
from scipy.signal import correlate as corr

# Problem 2

# Read the input vectors from their .mat files. Each stored variable is
# flattened so that the rest of the script works with 1-D arrays.
prs = sio.loadmat('preamble.mat')
rxs = sio.loadmat('rx_samples.mat')

preamble = prs['preamble'].flatten()
rx_samples = rxs['rx_samples'].flatten()

# Parameters
SPS = 5  # Samples per symbol

# Locate the start of the preamble within the received samples by
# correlating them against the pulse-shaped preamble. If your code works
# properly, you should find that the preamble starts at the sample
# number 211.

# Pulse-shape the preamble: every preamble symbol is held for SPS samples
# (outer product with a vector of ones, read out row by row).
preamble_ps = np.outer(preamble, np.ones(SPS)).ravel()
# Correlate; 'valid' keeps only the lags where the whole pulse-shaped
# preamble overlaps the received samples.
R = corr(rx_samples, preamble_ps, mode='valid')

# Show the magnitude of the correlation
plt.figure()
plt.plot(np.abs(R), marker='o')
plt.title('The result of the correlation')
plt.grid()
plt.show()

# The preamble starts where the correlation magnitude peaks
ind = np.abs(R).argmax()
print('ind = ', ind)

# Extract from the received samples the portion corresponding to the 1000
# transmitted BPSK data symbols and process it accordingly in order to
# generate the sufficient statistics about the transmitted data symbols.
# Plot the result. If your code works properly, you should obtain a rotated
# (and noisy) BPSK constellation.

N = preamble.size
txds = sio.loadmat('tx_data_symbols.mat')
tx_data_symbols = txds['tx_data_symbols'].flatten()
L = tx_data_symbols.size
# Extract the useful samples corresponding to the transmitted data; they
# start right after the SPS*N samples occupied by the preamble.
rx_data_samples = rx_samples[ind + SPS*N:ind + SPS*(N+L)]
# Slicing past the end of rx_samples truncates silently, so fail here with
# a clear message rather than obscurely further down.
if rx_data_samples.size != SPS*L:
    raise ValueError(
        'expected {} data samples after the preamble, got {}'.format(
            SPS*L, rx_data_samples.size))

# Implement the matched filter (MF) for the rectangular pulse of SPS
# samples. mode='valid' keeps only the outputs for which the pulse fully
# overlaps the data, i.e. it drops the SPS-1 transient samples at each end.
# (Trimming a 'full' convolution with [SPS-1:-SPS+1] gives the same result
# for SPS > 1 but an empty array for SPS = 1, because -SPS+1 is then 0.)
rx_data_samples_mf = np.convolve(rx_data_samples, np.ones(SPS), mode='valid')
# Sample the MF output once per symbol
suffStat = rx_data_samples_mf[::SPS]
# Check the length
print('length_suffStat = ', suffStat.size)

# Plot the constellation in its own figure, so that it never lands on a
# previous figure when plt.show() does not block (interactive sessions).
plt.figure()
plt.scatter(suffStat.real, suffStat.imag, marker='*')
plt.ylabel('Imaginary')
plt.xlabel('Real')
plt.title('The received constellation at the output of the matched filter')
plt.grid()
plt.show()

# Perform differential decoding. Assume that the first symbol is either a
# +1 or a -1 and then decide on the following symbols depending on the
# phase difference between two consecutive values of the sampled output of
# the matched filter.

# Phase difference between each pair of consecutive sufficient statistics
phase = np.angle(suffStat[:-1] * np.conj(suffStat[1:]))

# Plot the phase difference to see what we have got. A new figure is opened
# so the plot never lands on a previous figure when plt.show() does not
# block (interactive sessions).
plt.figure()
plt.plot(phase, '*')
plt.title('Phase difference')
plt.grid()
plt.show()

# Assume that the first symbol was a +1 and then perform differential
# decoding. A phase jump larger than pi/2 in magnitude means the symbol
# changed sign with respect to the previous one, otherwise it is unchanged,
# so symbol k is the running product of the first k sign changes.
flips = np.where(np.abs(phase[:L-1]) > np.pi/2, -1.0, 1.0)
decodedSymbols = np.ones(L)
decodedSymbols[1:] = np.cumprod(flips)

# With the transmitted data symbols available, the symbol error rate (SER)
# can now be computed. If your code works properly, this should be either
# 0, if your assumption about the first symbol was correct, or 1, if your
# assumption about the first symbol was wrong.

print('first_transmitted_symbol = ', tx_data_symbols[0])
print('first_decoded_symbol = ', decodedSymbols[0])
# Fraction of positions where the decoded symbol differs from the sent one
symbol_errors = decodedSymbols != tx_data_symbols
SER = np.mean(symbol_errors)
print('SER = ', SER)
