import matplotlib.pyplot as plt
from scipy.stats import maxwell, norm
from MD import *

# Step 3.1
# Test N-H at several temperatures
#############################################################


# Parameters of the NVT runs; the values are handed unchanged to run_NVT.
nsteps = 2000   # steps of the sampling run (the preliminary run uses nsteps//2)
dt = 0.0046     # time step argument of run_NVT
Q = 10          # Q argument of run_NVT (presumably the thermostat mass -- confirm in MD)


# Starting sample shared by every temperature of the scan.
# NOTE: np, read_pos_vel and run_NVT are not imported explicitly in this file;
# they are expected to come from `from MD import *`.
N, L, pos, vel, xi, lns = read_pos_vel('sampleNVT94.4.dat')

# Velocity axis for p(v). It does not depend on T, so it is built once
# instead of once per temperature.
# NOTE(review): assumes output['pv'] has 499 bins spanning [-10, 10] -- confirm
# against the histogram built inside MD.
v = np.linspace(-10, 10, 499)

# NOTE(review): np.arange with a non-integer step can include a value close to
# the stop (here 1.6) because of floating-point rounding, so this scan may have
# four temperatures rather than three -- confirm the intended set (np.linspace
# or an explicit list would make it unambiguous).
for T in 0.78*np.arange(0.4,1.6,0.4):

    # Preliminary run of nsteps//2 steps at the target T, started from the
    # stored sample. xi and lns read from the file are not passed here, so
    # run_NVT falls back to whatever defaults it defines for them.
    # NOTE(review): every T reuses the same pos/vel objects; presumably
    # run_NVT does not modify them in place -- verify.
    output = run_NVT(pos, vel, L, nsteps//2, N, dt, T, Q)

    # Sampling run at the same T, continued from the end state of the
    # preliminary run (positions, velocities and both thermostat variables).
    output = run_NVT(output['pos'], output['vel'], L, nsteps, N, dt, T, Q,
                     output['xi'], output['lns'])

    # Average of output['pv'] over its first axis; computed once and reused
    # for both the plot and the data file.
    pv_mean = np.mean(output['pv'], axis=0)

    # Plot p(v) distribution
    plt.plot(v, pv_mean, label="T=%.2f" % T)

    # Write p(v) into file as two columns: v, mean p(v)
    np.savetxt('pv_T%.2f.dat' % T, np.column_stack((v, pv_mean)))


plt.legend()
plt.show()

# NOTE: everything between the triple quotes below is a bare string literal,
# so none of it is executed. It holds a disabled second experiment (Step 3.2).
# If re-enabled it would: redefine nsteps/dt/Q, re-read 'sampleNVT94.4.dat',
# run two consecutive run_NVT calls per temperature (nsteps_eq steps, then
# nsteps steps continued from the first run's end state), and plot the variance
# of output['EnKin']/3*2 against the square of its mean.
# NOTE(review): presumably output['EnKin']/3*2 is the instantaneous temperature
# -- confirm against MD before relying on the plot.
'''
# Step 3.2
# TEMPERATURE FLUCTUATIONS NVE vs NVT
#############################################################


nsteps_eq = 300
nsteps = 3000
dt = 0.0046
Q = 5.

N, L, pos, vel, xi, lns = read_pos_vel('sampleNVT94.4.dat')

varT=[]
Ts=[]

#N, L, pos, vel, xi, lns = read_pos_vel('sampleNVT94.4.dat')

for T in 0.78*np.arange(0.8,1.8,0.2):

    # Run NVT in equilibrated system
    output = run_NVT(pos, vel, L, nsteps_eq, N, dt, T, Q)

    # Change T 
    output = run_NVT(output['pos'],output['vel'], L, nsteps, N, dt, T, Q, output['xi'],output['lns'])
    
    # Temperature fluctuations 
    varT.append(np.var(output['EnKin']/3*2))
    Ts.append(np.mean(output['EnKin']/3*2)**2)
    
plt.plot(Ts,varT)
plt.show()
'''
