import numpy as np
import matplotlib.pyplot as plt
import utilMDC
import utilPDC
from scipy import special

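# Basic script that simulates a full communication chain (bits -> QAM symbols
# -> pulse-shaped samples -> AWGN channel -> sufficient statistics -> decisions)
# and compares the simulated symbol/bit error probabilities with closed-form values.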

# Simulation parameters
nsymbols = 10**6  # number of transmitted symbols; use a smaller value while debugging
M = 4  # constellation size: 4-ary QAM
bitsPerSymb = int(np.log2(M))  # bits carried by each symbol
nbits = int(bitsPerSymb * nsymbols)  # total number of transmitted bits
beta = 0.22  # first argument of sol_rcosdesign (presumably the roll-off factor)
SPS = 4  # passed to the pulse-shaping / sufficient-statistics helpers (presumably samples per symbol)
SPAN = 10  # second argument of sol_rcosdesign (presumably pulse span in symbols)
SNR_dB = 10  # target SNR in dB

# create the bits (0/1 values)
bits = np.random.randint(2, size=nbits)

# create the symbols (sol_bi2de, sol_qamMap, sol_encoder)
# Each row of the (nsymbols, bitsPerSymb) reshape holds the bits of one symbol;
# sol_bi2de presumably converts every row into an M-ary symbol index.
MaryData = utilPDC.sol_bi2de(np.reshape(bits, (nsymbols, bitsPerSymb)))
# encode
# Named `constellation` (not `map`) so the builtin map() is not shadowed.
constellation = utilPDC.sol_qamMap(M)
symbols = utilPDC.sol_encoder(MaryData, constellation)

# map symbols to samples (rcosdesign + sol_symbols2samples)
h = utilPDC.sol_rcosdesign(beta, SPAN, SPS)
# the pulse is already normalized
samples = utilPDC.sol_symbols2samples(symbols, h, SPS)

# add AWGN noise (complex valued! Determine the noise variance from the energy of the symbols and SNR_dB)
# ========= pedestrian method ======
# find the signal energy per symbol
# (np.var subtracts the mean, so this is the mean symbol energy only for a
# zero-mean constellation)
Es = np.var(symbols)
# find the noise variance sigma2 so that 10*log_10(Es/sigma2) = SNR_dB
sigma2 = Es / (10 ** (SNR_dB / 10))
sigma = np.sqrt(sigma2)
# create the sample-level noise vector: complex Gaussian with variance
# sigma2/2 in each of the real and imaginary parts, so np.var(noise) ~= sigma2
nSamples = np.size(samples)
noiseStd = sigma / np.sqrt(2)
noise = noiseStd * np.random.randn(nSamples) + 1j * noiseStd * np.random.randn(nSamples)
actual_SNR = 10 * np.log10(np.var(symbols) / np.var(noise))
# test that we get back SNR_dB
print('The obtained SNR is: %s [dB]' % actual_SNR)
# create the channel output
received = samples + noise
# ========= end of pedestrian method ======

# generate the sufficient statistics (sol_sufficientStatistics)
suffStat, mfOutput = utilPDC.sol_sufficientStatistics(received, h, SPS)

# plot the received symbols constellation
# NOTE(review): scattering all nsymbols points is slow for large nsymbols.
plt.scatter(suffStat.real, suffStat.imag, marker='*')
plt.ylabel('Imaginary')
plt.xlabel('Real')
plt.title('The received symbols constellation')
plt.grid()
plt.show()

# decode (sol_decoder)
MaryData_decoded = utilPDC.sol_decoder(suffStat, constellation)

# compute the symbol-error probability from simulation
# Both index arrays are flattened so that a row/column-vector shape mismatch
# cannot broadcast into an nsymbols x nsymbols comparison.
symbolErrorCount = np.count_nonzero(np.ravel(MaryData) != np.ravel(MaryData_decoded))
symbolErrorProbability_simulated = symbolErrorCount / np.size(MaryData)
print('symbolErrorProbability_simulated = %s' % symbolErrorProbability_simulated)

# compute the symbol-error probability from the formula (valid for M = 4)
# For square 4-QAM each real dimension has half-distance sqrt(Es/2) to the
# decision boundary and noise std sqrt(sigma2/2) (assumes the unit-energy
# pulse leaves the noise variance of suffStat at sigma2), so the per-dimension
# error probability is q = Q(sqrt(Es/sigma2)).
snr_sqrt = np.sqrt(Es / sigma2)
# Q(x) = 0.5*erfc(x/sqrt(2)); erfc avoids the cancellation in
# 0.5 - 0.5*erf(...), which rounds q to 0 at high SNR.
q = 0.5 * special.erfc(snr_sqrt / np.sqrt(2))
# A symbol is correct only if both dimensions are: 1 - (1 - q)**2 = 2q - q**2
symbolErrorProbability_formula = 2 * q - q ** 2
print('symbolErrorProbability_formula = %s' % symbolErrorProbability_formula)


# compute the bit-error probability from simulation (sol_de2bi)
# Flatten the decoded bits so they line up with the flat `bits` vector even if
# sol_de2bi returns one row of bitsPerSymb bits per symbol (the inverse of the
# row-major reshape above); a no-op if it is already flat.
bits_decoded = np.ravel(utilPDC.sol_de2bi(MaryData_decoded))
bitErrorCount = np.count_nonzero(bits_decoded != bits)
bitErrorProbability_simulated = bitErrorCount / np.size(bits)
print('bitErrorProbability_simulated = %s' % bitErrorProbability_simulated)

# compute the bit-error probability from the formula
# With Gray-mapped 4-QAM each bit depends on one dimension only, so BER = q
# (assumes sol_qamMap produces a Gray mapping -- TODO confirm).
bitErrorProbability_formula = q
print('bitErrorProbability_formula = %s' % bitErrorProbability_formula)
