import matplotlib.pyplot as plt
from scipy.stats import maxwell, norm
from MD import *

# Step 1.1
# Equilibrate the sample in the NVT ensemble
#############################################################

N, L, pos, vel = read_pos_vel('sampleT94.4.dat')

Q = 10.             # thermal inertia parameter
T = 0.7807          # requested temperature
nsteps = 900        # number of steps
dt = 0.0046         # integration step

# Run MD simulation in NVT ensemble
output = run_NVT(pos, vel, L, nsteps, N, dt, T, Q)

# Write the equilibrated structure into a file. The thermostat state
# (xi, lns) is stored too, so a later run can continue from it.
dump_pos_vel('sampleNVT94.4.dat', output['pos'], output['vel'], N, L,
             output['xi'], output['lns'])


# Plot E_NH and E_kin + E_pot against the step. Both curves are labelled;
# without a legend the two lines cannot be told apart.
# NOTE(review): assumes output['EnKin'] and output['EnPot'] are numpy arrays;
# if run_NVT returned plain lists, '+' would concatenate them instead of
# summing element-wise. Confirm against MD.run_NVT.
plt.plot(output['nsteps'], output['EnNH'], label='$E_{NH}$')
plt.plot(output['nsteps'], output['EnKin'] + output['EnPot'],
         label='$E_{kin} + E_{pot}$')
plt.xlabel('step')
plt.ylabel('energy')
plt.legend()
plt.show()


# Step 1.2
# Study the fluctuations
#############################################################
# Disabled by default: set RUN_STEP_1_2 = True to run it. It continues from
# the equilibrated configuration and thermostat state stored in
# 'sampleNVT94.4.dat'. Keeping this as real (guarded) code rather than a
# triple-quoted string means it is parsed and linted with the rest of the file.
RUN_STEP_1_2 = False

if RUN_STEP_1_2:
    Q = 10.             # thermal inertia parameter
    T = 0.7807          # requested temperature
    nsteps = 3000       # number of steps
    dt = 0.0046         # integration step
    Nbins = 300         # number of bins of the velocity histograms

    # For this file read_pos_vel also yields the saved thermostat state
    # (xi, lns), so the run below resumes from it.
    N, L, pos, vel, xi, lns = read_pos_vel('sampleNVT94.4.dat')

    # Run MD simulation in NVT ensemble
    output = run_NVT(pos, vel, L, nsteps, N, dt, T, Q, xi, lns, Nbins=Nbins)

    # Plot temperature distribution
    # NOTE(review): T = 2/3 * EnKin assumes EnKin is the kinetic energy per
    # particle (k_B = 1); if it is the total, it must also be divided by N.
    plt.hist(output['EnKin'] / 3 * 2, bins=np.linspace(0, 10, 499), fc='none',
             histtype='step', density=True)
    plt.xlabel('T')
    plt.ylabel('p(T)')
    plt.show()

    # Plot the distribution of one velocity component against the Gaussian
    # with variance T (m = k_B = 1).
    v = np.linspace(-10, 10, Nbins)
    plt.plot(v, np.mean(output['pv'], axis=0), label='MD')
    plt.plot(v, norm.pdf(v, 0, np.sqrt(T)), label='Gaussian')
    plt.xlabel('v')
    plt.ylabel('p(v)')
    plt.legend()
    plt.show()

    # Plot the speed distribution against the Maxwell distribution.
    # NOTE(review): this was labelled "p(v**2)", but the reference curve is
    # scipy's Maxwell *speed* pdf, so output['pv2'] is presumably a histogram
    # of |v|. Confirm against MD.run_NVT.
    v = np.linspace(0, 10, Nbins)
    plt.plot(v, np.mean(output['pv2'], axis=0), label='MD')
    plt.plot(v, maxwell.pdf(v, 0, np.sqrt(T)), label='Maxwell')
    plt.xlabel('|v|')
    plt.ylabel('p(|v|)')
    plt.legend()
    plt.show()
