import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio
from scipy.signal import correlate as corr

# Problem 3

# Read the two input files: the known PA sequence and the received samples
pas = sio.loadmat('pa_sequence.mat')
rxs = sio.loadmat('rx_samples.mat')

# Pull the arrays out of the loaded dictionaries as flat 1-D vectors
pa_sequence = pas['pa_sequence'].flatten()
rx_samples = rxs['rx_samples'].flatten()

# System parameters
PAPB = 10  # PA sequences per symbol

# Coarse synchronisation: locate where the first complete pa_sequence starts
# inside the received samples. With the provided data the expected answer is
# sample number 208.

# Two PA lengths of data is the minimum needed to find a full pa_sequence
rx_corr = rx_samples[:2*pa_sequence.size]
# 'valid' mode only keeps the lags where pa_sequence fully overlaps rx_corr
R = corr(rx_corr, pa_sequence, mode='valid')

# Magnitude of the correlation, used both for the plot and the peak search
corr_magnitude = np.abs(R)

plt.figure()
plt.plot(corr_magnitude, marker='o')
plt.title('The result of the correlation')
plt.grid()
plt.show()

# The lag of the strongest peak is the start of the first full pa_sequence
ind = np.argmax(corr_magnitude)
print('ind = ', ind)

# Find the beginning of the first full symbol within the received samples.
# Starting from the first full pa_sequence (index `ind`), consecutive
# PA-length chunks of the received samples are projected onto the
# pa_sequence. A phase jump of roughly pi between two consecutive inner
# products marks a symbol boundary. With the provided data the first
# complete symbol should start at sample number 1744.

# Loop invariants: length of one PA and its complex conjugate
pa_len = pa_sequence.size
pa_conj = np.conj(pa_sequence)

found = False
lastInnerProduct = np.dot(rx_samples[ind:ind+pa_len], pa_conj)

# Use the logical `not`: the bitwise `~` on a Python bool gives a non-zero
# integer (~False == -1, ~True == -2), which is always truthy.
while not found:

    ind = ind + pa_len  # advance by the length of a PA
    if ind + pa_len > rx_samples.size:
        # Ran out of data: fail with a clear message rather than letting
        # np.dot raise an obscure shape-mismatch error on a short chunk.
        raise ValueError('no phase transition of about pi found in rx_samples')
    y = rx_samples[ind:ind+pa_len]  # get the corresponding data

    innerProduct = np.dot(y, pa_conj)

    # For a safer estimation, we reduce the decision zone to [3*pi/4, 5*pi/4]
    # Note that angle() returns values in (-pi, pi], hence the abs()
    phase_jump = np.angle(innerProduct * np.conj(lastInnerProduct))
    found = bool(np.abs(phase_jump) > 3*np.pi/4)
    lastInnerProduct = innerProduct

# `ind` now points at the first chunk after the phase transition
index_change_phase = ind
print('index_change_phase = ', index_change_phase)
# Reduce modulo the symbol length to get the start of the first full symbol
index_start_first_symbol = int(ind % (PAPB*pa_len))
print('index_start_first_symbol = ', index_start_first_symbol)


# Extract from the received samples the portion corresponding to full
# symbols and process it accordingly in order to generate the sufficient
# statistics about the transmitted data symbols. Plot the result. If your
# code works properly, you should obtain a rotated (and noisy) BPSK
# constellation.

# Number of samples in one symbol (PAPB repetitions of the pa_sequence)
symbol_len = PAPB*pa_sequence.size

# Trim the received samples to have full symbols
rx_samples = rx_samples[index_start_first_symbol:]
# Slice by an explicit count rather than by a negative remainder: `[:-0]`
# would select nothing when the data ends exactly on a symbol boundary.
n_full_symbols = rx_samples.size // symbol_len
if n_full_symbols == 0:
    raise ValueError('rx_samples does not contain a complete symbol')
rx_samples = rx_samples[:n_full_symbols*symbol_len]

# Apply the matched filter (time-reversed, conjugated spreading code)
code = np.tile(pa_sequence, PAPB)
rx_data_mf = np.convolve(rx_samples, np.conj(np.flip(code)))
# Remove the tails: drop code.size-1 samples at each end of the full
# convolution. The end index is written explicitly so that the slice stays
# correct even when code.size == 1 (where `-code.size+1` would be 0).
rx_data_mf = rx_data_mf[code.size-1:rx_samples.size]
# Sample the MF output once per symbol
rx_dec = rx_data_mf[::code.size]
print('received_number_of_symbols = ', rx_dec.size)

# Scatter plot of the symbol-rate samples taken at the matched-filter output
plt.scatter(np.real(rx_dec), np.imag(rx_dec), marker='*')
plt.title('The received constellation at the output of the matched filter')
plt.xlabel('Real')
plt.ylabel('Imaginary')
plt.grid()
plt.show()

# The first transmitted symbol is known to be +1, so the angle of the first
# decision variable is an estimate of the phase introduced by the channel.
# With the provided data this should be about 2.07 radians.
first_decision_variable = rx_dec[0]
phase_shift = np.angle(first_decision_variable)
print('phase_shift = ', phase_shift)

# Undo the channel rotation by multiplying every sample with exp(-j*phase).
# The corrected samples are then used to take decisions and to compute the
# symbol error rate (SER), which should be 0 if everything works correctly.
derotation = np.exp(-1j*phase_shift)
rx_dec_corrected = derotation*rx_dec

# Visual check of the constellation after the phase correction
plt.scatter(np.real(rx_dec_corrected), np.imag(rx_dec_corrected), marker='*')
plt.title('The received constellation after the phase correction')
plt.xlabel('Real')
plt.ylabel('Imaginary')
plt.grid()
plt.show()

# Load the reference (transmitted) data symbols as a flat 1-D vector
txds = sio.loadmat('tx_data_symbols.mat')
tx_data_symbols = txds['tx_data_symbols'].flatten()

# Decision on the sign of the real part: negative -> -1, otherwise -> +1
decodedSymbols = np.where(np.real(rx_dec_corrected) < 0, -1, 1)

# SER = number of wrong decisions divided by the number of transmitted symbols
symbol_errors = decodedSymbols != tx_data_symbols
SER = np.sum(symbol_errors)/tx_data_symbols.size
print('SER = ', SER)
