import numpy as np
import utilPDC
import sys
import matplotlib.pyplot as plt
import my_utilMIMO

"""MIMO_PERFORMANCE simulates the performance of a MIMO communication system at different SNRs"""

# Simulation parameters
antennas = 2  # passed to the MIMO channel model; also the row count of the transmit matrix
modtype = 'qam'  # modulation family: 'qam' or 'psk'
M = 4  # constellation size: 4, 16, 64

# SNR points to simulate (plotted on an axis labelled in dB)
Es_sigma2 = np.arange(1, 21)

trials_per_snr = 1_000  # how many times a new channel is generated per SNR
symbols_per_trial = 10_000  # symbols sent while the channel stays constant

# Equalizers to evaluate (enable each one once it has been implemented)
test_zf = True  # zero-forcing equalizer
test_lmmse = True  # LMMSE equalizer

# Main Simulation Loop

# Nothing to do when both equalizers are disabled.
if not test_zf and not test_lmmse:
    print('Nothing to simulate\n')
    sys.exit()

# Per-SNR accumulators (one entry per element of Es_sigma2). They collect raw
# sums during the simulation and are normalized afterwards.
if test_zf:
    ser_zf = np.zeros(Es_sigma2.size)
    mse_zf = np.zeros(Es_sigma2.size)

if test_lmmse:
    ser_lmmse = np.zeros(Es_sigma2.size)
    mse_lmmse = np.zeros(Es_sigma2.size)

# Reseed NumPy's global generator without a fixed seed so that every run
# draws different random data.
np.random.seed()

# 1.3.2 - Get the mapping
if modtype.lower() == 'qam':
    mapping = utilPDC.sol_qamMap(M)
elif modtype.lower() == 'psk':
    mapping = utilPDC.sol_pskMap(M)
else:
    # Build one readable message; passing a format template and its argument
    # as separate exception args would leave the '%s' uninterpolated.
    raise ValueError(f'Unsupported modulation type {modtype}')

# Normalize constellation to unit average energy
mapping = mapping / np.sqrt(np.mean(np.abs(mapping)**2))

# Monte-Carlo sweep over the SNR points: every trial draws fresh random
# symbols and a new channel, equalizes the received block, and accumulates
# the symbol-error count and the error measure returned by noise_norm.

# Total number of symbols simulated per SNR point; used to turn the
# accumulated sums into per-symbol averages.
symbols_per_snr = trials_per_snr * symbols_per_trial

for k, SNR in enumerate(Es_sigma2):
    print(f'SNR = {SNR}, ', end='')

    # Third argument of the LMMSE equalizer; constant for a given SNR.
    # NOTE(review): presumably the noise variance for unit-energy symbols with
    # SNR given in dB -- confirm against my_utilMIMO.lmmse_equalizer.
    noise_var = 10**(-0.1 * SNR)

    for iteration in range(trials_per_snr):
        tx_digits = np.random.randint(M, size=symbols_per_trial)
        tx_symbols = utilPDC.sol_encoder(tx_digits, mapping)

        # Arrange the symbol stream column-wise into `antennas` rows.
        tx_symbols = np.reshape(tx_symbols, (antennas, -1), 'F')

        rx_symbols, H = my_utilMIMO.mimo_channel(tx_symbols, SNR, antennas)

        if test_zf:  # Zero-forcing
            suff_stat = my_utilMIMO.zeroForcing_equalizer(rx_symbols, H)

            mse_zf[k] += my_utilMIMO.noise_norm(tx_symbols, suff_stat)

            # Undo the column-wise reshape so the decisions line up with tx_digits
            suff_stat = suff_stat.flatten('F')
            rx_digits = utilPDC.sol_decoder(suff_stat, mapping)

            ser_zf[k] += np.sum(rx_digits != tx_digits)

        if test_lmmse:  # LMMSE
            suff_stat = my_utilMIMO.lmmse_equalizer(rx_symbols, H, noise_var)

            mse_lmmse[k] += my_utilMIMO.noise_norm(tx_symbols, suff_stat)

            # Undo the column-wise reshape so the decisions line up with tx_digits
            suff_stat = suff_stat.flatten('F')
            rx_digits = utilPDC.sol_decoder(suff_stat, mapping)

            ser_lmmse[k] += np.sum(rx_digits != tx_digits)

    if test_zf:
        mse_zf[k] /= symbols_per_snr
        ser_zf[k] /= symbols_per_snr

        # No newline yet: the LMMSE summary (if enabled) continues this line.
        print(f'SER (ZF)   = {ser_zf[k]}, MSE (ZF)   = {mse_zf[k]}', end='')

    if test_lmmse:
        mse_lmmse[k] /= symbols_per_snr
        ser_lmmse[k] /= symbols_per_snr

        print(f'\t SER (LMMSE) = {ser_lmmse[k]}, MSE (LMMSE) = {mse_lmmse[k]}')
    else:
        # Only ZF was reported; end the line so successive SNR reports do not
        # run together.
        print()

# Plot the results: SER on a log axis (left) and MSE in dB (right)
fig, axs = plt.subplots(1, 2, constrained_layout=True)

# Collect the enabled equalizers once as (legend label, SER curve, MSE curve).
curves = []
if test_zf:
    curves.append(('Zero Forcing', ser_zf, mse_zf))
if test_lmmse:
    curves.append(('MMSE', ser_lmmse, mse_lmmse))

# Left panel: symbol error rate
ser_ax = axs[0]
for curve_label, ser_curve, _ in curves:
    ser_ax.semilogy(Es_sigma2, ser_curve, label=curve_label)
ser_ax.grid()
ser_ax.set_xlabel('Channel SNR [dB]')
ser_ax.set_ylabel('Symbol Error Rate')
ser_ax.legend()

# Right panel: mean squared error converted to dB
mse_ax = axs[1]
for curve_label, _, mse_curve in curves:
    mse_ax.plot(Es_sigma2, 10 * np.log10(mse_curve), label=curve_label)
mse_ax.grid()
mse_ax.set_xlabel('Channel SNR [dB]')
mse_ax.set_ylabel('MSE [dB]')
mse_ax.legend()

plt.show()
