#!/usr/bin/env python3

"""
Bare minimum molecular dynamics code for modeling liquid argon in NVE and NVT
ensembles.

- velocity-Verlet integration (optional velocity rescaling in run_NVE when T
  is given; Nose-Hoover thermostat in run_NVT)
- sampling g(r) and S(k)
- sampling mean squared displacement and velocity autocorrelation function
- sampling velocity distribution in NVT
"""

# Module metadata.
__author__  = "Alexey Tal"
__contact__ = "alexey.a.tal@gmail.com"
__license__ = "GPLv3"
__date__    = "2021/05/18"

import numpy as np
import matplotlib.pyplot as plt
from numba import jit
import copy

def crystal(Ncells, lat_par, T=0.78):
    """Create initial positions and velocities for an FCC crystal.

    The simulation box is cubic with edge ``Ncells * lat_par``; callers that
    need the box length must compute it themselves.

    Parameters
    ----------
    Ncells : int
        Number of cubic unit cells along each box edge.
    lat_par : float
        Lattice parameter (edge length of one cubic unit cell).
    T : float, optional
        Temperature passed to ``create_init_Velocities``, which uses
        ``sqrt(T)`` as the standard deviation of every velocity component.

    Returns
    -------
    (initPositions, initVelocities) : tuple of ndarray
        Both arrays have shape ``(4*Ncells**3, 3)``.
    """
    # An FCC unit cell holds 4 atoms.
    N = 4 * Ncells**3

    initPositions = create_init_Positions(Ncells, lat_par, N)
    initVelocities = create_init_Velocities(N, T)

    return initPositions, initVelocities

def create_init_Positions(Ncells,lat_par,N):
    """Return the atom positions of an FCC lattice of ``Ncells**3`` cubic cells.

    Parameters
    ----------
    Ncells : int
        Unit cells along each edge of the cubic box.
    lat_par : float
        Edge length of one cubic unit cell.
    N : int
        Total atom count. It is not used here; the result always has
        ``4*Ncells**3`` rows.

    Returns
    -------
    ndarray of shape (4*Ncells**3, 3)
        Corner-site atoms (0,0,0) of every cell first, followed by the
        face-centred sublattices (1/2,1/2,0), (1/2,0,1/2) and (0,1/2,1/2),
        each in the same cell order.
    """
    half = lat_par/2

    # Cell origins along one axis; meshgrid then fixes the cell ordering.
    edge = np.linspace(0, lat_par*Ncells-lat_par, Ncells)
    n_cells = Ncells**3
    corners = np.column_stack(
        [axis.reshape(n_cells) for axis in np.meshgrid(edge, edge, edge)]
    )

    # Fractional FCC basis expressed as Cartesian offsets.
    basis = (
        (0.0, 0.0, 0.0),
        (half, half, 0.0),
        (half, 0.0, half),
        (0.0, half, half),
    )

    return np.vstack([corners + np.array(offset) for offset in basis])

def create_init_Velocities(N, T):
    """Draw Maxwell-Boltzmann velocities for ``N`` atoms at temperature ``T``.

    Each Cartesian component is sampled independently from a normal
    distribution with zero mean and standard deviation ``sqrt(T)``. The mean
    velocity is then subtracted, which removes the centre-of-mass drift for
    equal-mass atoms.

    Returns
    -------
    ndarray of shape (N, 3)
    """
    velocities = np.random.normal(0, np.sqrt(T), (N, 3))
    velocities -= velocities.mean(axis=0)
    return velocities




# The pair loop is serial (plain ``range``, no ``prange``), so the
# ``parallel=True`` option had nothing to parallelise and is omitted.
@jit(nopython=True)
def calculateForces(pos, L, N, Nbins):
    """Lennard-Jones forces, potential energy and pair-distance histogram.

    Reduced LJ units are used (epsilon = sigma = 1, unit mass), in a cubic
    periodic box of edge ``L`` with the minimum-image convention. Pair
    interactions are truncated at ``r_cutoff = 2.5``. The standard
    long-range tail correction for the truncated potential,
    ``N * (8/3) * pi * rho * (rc**-9 / 3 - rc**-3)`` with ``rho = N / L**3``,
    is added to the energy.

    Parameters
    ----------
    pos : ndarray, shape (N, 3)
        Particle positions.
    L : float
        Box edge length.
    N : int
        Number of particles.
    Nbins : int
        Number of histogram bins of width ``0.5*L/Nbins`` covering
        ``0 <= r < L/2``.

    Returns
    -------
    forces : ndarray, shape (N, 3)
    EnPot : float
        Total (not per-particle) potential energy, tail correction included.
    gofr : ndarray, shape (Nbins,)
        Raw pair counts per distance bin; each unordered pair counted once.
    """
    forces = np.zeros((N, 3))
    gofr = np.zeros(Nbins)
    EnPot = 0.0
    r_cutoff = 2.5
    r_cutoff2 = r_cutoff**2
    Lbin = 0.5 * L / Nbins
    half_L2 = (0.5 * L)**2

    # Loop over all unordered pairs
    for ii in range(N):
        for jj in range(ii + 1, N):

            # Separation vector
            d_x = pos[ii, 0] - pos[jj, 0]
            d_y = pos[ii, 1] - pos[jj, 1]
            d_z = pos[ii, 2] - pos[jj, 2]

            # Minimum-image periodic boundary conditions
            d_x -= L * np.rint(d_x / L)
            d_y -= L * np.rint(d_y / L)
            d_z -= L * np.rint(d_z / L)

            dr2 = d_x * d_x + d_y * d_y + d_z * d_z

            # Potential and force inside the cutoff
            if dr2 < r_cutoff2:
                inv_r2 = 1.0 / dr2
                inv_r6 = inv_r2 * inv_r2 * inv_r2
                EnPot += 4.0 * (inv_r6 * inv_r6 - inv_r6)

                # -(1/r) dV/dr = 4 (12 r^-14 - 6 r^-8); multiply by the
                # separation components to get the force projections.
                f_over_r = 4.0 * (12.0 * inv_r6 * inv_r6 - 6.0 * inv_r6) * inv_r2
                Fx = f_over_r * d_x
                Fy = f_over_r * d_y
                Fz = f_over_r * d_z

                # Newton's third law
                forces[ii, 0] += Fx
                forces[jj, 0] -= Fx
                forces[ii, 1] += Fy
                forces[jj, 1] -= Fy
                forces[ii, 2] += Fz
                forces[jj, 2] -= Fz

            # Pair-distance histogram for g(r)
            if dr2 < half_L2:
                bin_idx = int(np.sqrt(dr2) / Lbin)
                # Rounding can push r just below L/2 into index Nbins; numba
                # does not bounds-check, so guard the write explicitly.
                if bin_idx < Nbins:
                    gofr[bin_idx] += 1.0

    # Long-range tail correction: N particles, each with
    # (8/3)*pi*rho*(rc^-9/3 - rc^-3).
    rho = N / L**3
    EnPot += N * (8.0 / 3.0) * np.pi * rho * (
        1.0 / (3.0 * r_cutoff**9) - 1.0 / r_cutoff**3)

    return forces, EnPot, gofr




def run_NVE(pos_in, vel_in, L, nsteps, N, dt=0.0046, T=None, Nbins=499):
    """Integrate the equations of motion with the velocity-Verlet algorithm.

    Without ``T`` the run samples the NVE ensemble. If ``T`` is given and
    truthy, the velocities are rescaled after every step so that the kinetic
    energy equals ``3*N*T/2``. This is a simple velocity-rescaling thermostat.

    Parameters
    ----------
    pos_in, vel_in : array_like, shape (N, 3)
        Initial positions and velocities (copied, not modified).
    L : float
        Edge of the cubic periodic box.
    nsteps : int
        Number of integration steps.
    N : int
        Number of atoms.
    dt : float
        Time step.
    T : float or None
        Target temperature for velocity rescaling; ``None`` (or 0) disables it.
    Nbins : int
        Number of bins for the g(r) histogram.

    Returns
    -------
    dict
        'EnPot'/'EnKin': per-atom energies for each step. 'msd'/'vacf': mean
        squared displacement and velocity autocorrelation, measured relative
        to the state after the first step. 'gofr'/'sofk': results of
        ``calculate_gofr``/``calculate_sofk_FT``. 'pos'/'vel': final state,
        with positions wrapped into [0, L). 'nsteps': ``range(nsteps)``.
    """
    # Independent float copies of the initial state
    pos = np.array(pos_in, dtype=float)
    vel = np.array(vel_in, dtype=float)

    forces, _, _ = calculateForces(pos, L, N, Nbins)

    # Per-step observables
    EnKin = np.zeros(nsteps)
    EnPot = np.zeros(nsteps)
    gofr = np.zeros((nsteps, Nbins))
    msd = np.zeros(nsteps)
    vacf = np.zeros(nsteps)

    # Number of box lengths each coordinate has crossed, so that
    # pos + L*shift is the unwrapped trajectory used for the MSD.
    shift = np.zeros_like(pos)

    # Output (columns match the row format printed below)
    print("#{:^5}{:^13}{:^13}".format('Step', 'Kinetic', 'Potential'))
    print("#{:-^5}{:-^13}{:-^13}".format('', '', ''))

    for ii in range(nsteps):

        # Velocity Verlet: half kick, drift
        vel += 0.5 * forces * dt
        pos += vel * dt

        # Wrap into the box, recording crossings
        shift += pos // L
        pos = pos % L
        unwrapped = pos + L * shift

        if ii == 0:
            pos_init = unwrapped.copy()
            vel_init = vel.copy()

        # VACF: per-atom v(0).v(t)/3, averaged over atoms
        vacf[ii] = np.mean(np.sum(vel_init * vel, axis=1)) / 3.

        # MSD from unwrapped coordinates
        msd[ii] = np.mean(np.sum((pos_init - unwrapped)**2, axis=1))

        # New forces, then second half kick
        forces, EnPot[ii], gofr[ii, :] = calculateForces(pos, L, N, Nbins)
        vel += 0.5 * forces * dt

        EnKin[ii] = 0.5 * np.sum(vel * vel)

        # Optional velocity rescaling towards EnKin = 3*N*T/2
        if T:
            vel = vel * np.sqrt(N * 3 * T / (2 * EnKin[ii]))

        # Energies per atom
        EnKin[ii] /= N
        EnPot[ii] /= N

        print("%5.d %5.8f %5.8f " % (ii, EnKin[ii], EnPot[ii]))

    # Pass Nbins explicitly so non-default histogram sizes are handled.
    gofr = calculate_gofr(gofr, L, N, Nbins)
    sofk = calculate_sofk_FT(gofr, L, N, Nbins)

    return {'nsteps': range(nsteps),
            'pos': pos,
            'vel': vel,
            'EnPot': EnPot,
            'EnKin': EnKin,
            'gofr': gofr,
            'sofk': sofk,
            'msd': msd,
            'vacf': vacf}



def run_NVT(pos_in, vel_in, L, nsteps, N, dt=0.0046, T=0.78, Q=10, xi=0, lns=0, Nbins=499):
    """Run a constant-temperature simulation with a Nose-Hoover thermostat.

    Positions and velocities are advanced with a velocity-Verlet-like scheme
    that includes the friction term ``-xi * vel``. The friction coefficient
    ``xi`` is driven by ``(2*EnKin - 3*N*T) / Q``, and ``lns`` accumulates
    the time integral of ``xi``.

    Parameters
    ----------
    pos_in, vel_in : array_like, shape (N, 3)
        Initial positions and velocities (copied, not modified).
    L : float
        Edge of the cubic periodic box.
    nsteps : int
        Number of integration steps.
    N : int
        Number of atoms.
    dt : float
        Time step.
    T : float
        Target temperature.
    Q : float
        Thermostat "mass".
    xi, lns : float
        Initial thermostat state, e.g. from a previous run (restart).
    Nbins : int
        Number of bins for the g(r) and velocity histograms.

    Returns
    -------
    dict
        Per-step 'EnPot', 'EnKin', 'EnNH' (per atom), 'msd', 'vacf'.
        'pv': v_x histograms on [-10, 10], with bin edges in 'vx_edges'.
        'pv2': speed histograms on [0, 10], with bin edges in 'v'.
        'gofr'/'sofk': results of ``calculate_gofr``/``calculate_sofk_FT``.
        Final 'pos' (wrapped into [0, L)), 'vel', 'xi', 'lns', and
        'nsteps' = ``range(nsteps)``.
    """
    # Independent float copies of the initial state
    pos = np.array(pos_in, dtype=float)
    vel = np.array(vel_in, dtype=float)

    forces, _, _ = calculateForces(pos, L, N, Nbins)

    # Per-step observables
    EnKin = np.zeros(nsteps)
    EnPot = np.zeros(nsteps)
    EnNH = np.zeros(nsteps)
    gofr = np.zeros((nsteps, Nbins))
    pv = np.zeros((nsteps, Nbins))
    pv2 = np.zeros((nsteps, Nbins))
    msd = np.zeros(nsteps)
    vacf = np.zeros(nsteps)

    # Histogram bin edges (remain None if nsteps == 0)
    vx_edges = None
    v = None

    # Number of box lengths each coordinate has crossed, so that
    # pos + L*shift is the unwrapped trajectory used for the MSD.
    shift = np.zeros_like(pos)

    # Output (columns match the row format printed below)
    print("#{:^5}{:^13}{:^13}{:^13}".format('Step', 'Kinetic', 'Potential', 'Total NH'))
    print("#{:-^5}{:-^13}{:-^13}{:-^13}".format('', '', '', ''))

    for ii in range(nsteps):

        # Drift and first half kick, both including the NH friction term
        pos += vel * dt + 0.5 * (forces - xi * vel) * dt**2
        vel += 0.5 * (forces - xi * vel) * dt

        # Wrap into the box, recording crossings
        shift += pos // L
        pos = pos % L
        unwrapped = pos + L * shift

        if ii == 0:
            pos_init = unwrapped.copy()
            vel_init = vel.copy()

        # Calculate new forces
        forces, EnPot[ii], gofr[ii, :] = calculateForces(pos, L, N, Nbins)

        # VACF: per-atom v(0).v(t)/3, averaged over atoms
        vacf[ii] = np.mean(np.sum(vel_init * vel, axis=1)) / 3.

        # MSD from unwrapped coordinates
        msd[ii] = np.mean(np.sum((pos_init - unwrapped)**2, axis=1))

        # Nose-Hoover: first half update of xi, integrate lns
        EnKin[ii] = 0.5 * np.sum(vel * vel)
        dsumv2 = (2.0 * EnKin[ii] - 3. * N * T) / Q
        lns += xi * dt + 0.5 * dsumv2 * dt**2
        xi += 0.5 * dsumv2 * dt

        # Second half kick with the updated friction coefficient
        vel += 0.5 * (forces - xi * vel) * dt

        # Nose-Hoover: second half update of xi from the new kinetic energy
        EnKin[ii] = 0.5 * np.sum(vel * vel)
        dsumv2 = (2.0 * EnKin[ii] - 3. * N * T) / Q
        xi += 0.5 * dsumv2 * dt

        # Extended-system energy K + U + Q*xi^2/2 + 3*N*T*lns
        EnNH[ii] = EnKin[ii] + EnPot[ii] + (xi**2 * Q / 2. + 3. * N * T * lns)

        # Velocity-component and speed distributions
        pv[ii], vx_edges = np.histogram(vel[:, 0], bins=Nbins, range=(-10, 10),
                                        density=True)
        pv2[ii], v = np.histogram(np.sqrt(np.sum(vel**2, axis=1)),
                                  bins=Nbins, density=True, range=(0, 10))

        # Energies per atom
        EnKin[ii] /= N
        EnPot[ii] /= N
        EnNH[ii] /= N

        print("%5.d %5.8f %5.8f %5.8f" % (ii, EnKin[ii], EnPot[ii], EnNH[ii]))

    # g(r) and S(k) from the sampled histograms, as in run_NVE
    gofr = calculate_gofr(gofr, L, N, Nbins)
    sofk = calculate_sofk_FT(gofr, L, N, Nbins)

    return {'nsteps': range(nsteps),
            'pos': pos,
            'vel': vel,
            'EnPot': EnPot,
            'EnKin': EnKin,
            'EnNH': EnNH,
            'msd': msd,
            'vacf': vacf,
            'pv': pv,
            'pv2': pv2,
            'v': v,
            'vx_edges': vx_edges,
            'gofr': gofr,
            'sofk': sofk,
            'xi': xi,
            'lns': lns,
            }


def calculate_gofr(gofr, L, N, Nbins=None):
    """Normalise accumulated pair-distance histograms into g(r).

    Parameters
    ----------
    gofr : array_like, shape (nsamples, Nbins)
        Raw pair counts per sample. The bins have width ``0.5*L/Nbins`` and
        cover ``0 <= r < L/2``, with each unordered pair counted once, as
        filled in by ``calculateForces``.
    L : float
        Box edge length.
    N : int
        Number of particles.
    Nbins : int, optional
        Number of bins. By default it is taken from ``gofr.shape[1]``.

    Returns
    -------
    dict
        ``'g'``: g(r) averaged over samples; ``'r'``: bin-centre radii.

    Raises
    ------
    ValueError
        If ``Nbins`` is given and does not match the histogram width.
    """
    counts = np.asarray(gofr, dtype=float)
    width = counts.shape[1]
    if Nbins is None:
        Nbins = width
    elif Nbins != width:
        raise ValueError(
            "Nbins=%d does not match histogram width %d" % (Nbins, width))

    Lbin = 0.5 * L / Nbins
    V = L**3
    r = Lbin * (np.arange(Nbins) + 0.5)

    # Ideal-gas pair count per shell is N(N-1)/2 * 4*pi*r^2*dr / V.
    gofr_mean = 2 * V / (N * (N - 1)) * counts.mean(axis=0) / (4 * np.pi * r**2 * Lbin)
    return {'g': gofr_mean, 'r': r}


def calculate_sofk_FT(gofr, L, N, Nbins=None):
    """Structure factor S(k) as the Fourier transform of g(r).

    Evaluates ``S(k) = 1 + 4*pi*rho * integral r (g(r)-1) sin(k r) / k dr``
    with the trapezoidal rule on the (uniform) ``r`` grid. The k grid is
    ``k = 2*pi*r``.

    Parameters
    ----------
    gofr : dict
        ``{'g': g(r), 'r': radii}``, e.g. the output of ``calculate_gofr``.
    L : float
        Box edge length (used for the density ``N / L**3``).
    N : int
        Number of particles.
    Nbins : int, optional
        Expected grid size. By default it is taken from ``len(gofr['r'])``.

    Returns
    -------
    dict
        ``'s'``: S(k); ``'k'``: wave numbers.

    Raises
    ------
    ValueError
        If ``Nbins`` is given and does not match ``len(gofr['r'])``.
    """
    r = np.asarray(gofr['r'], dtype=float)
    g = np.asarray(gofr['g'], dtype=float)
    if Nbins is not None and Nbins != len(r):
        raise ValueError(
            "Nbins=%d does not match len(gofr['r'])=%d" % (Nbins, len(r)))

    rho = N / L**3
    dr = r[1] - r[0]
    k = r * 2 * np.pi

    # Rows: k values; columns: r values. Integrand r[g(r)-1]sin(k r)/k.
    integrand = r * (g - 1) * np.sin(np.outer(k, r)) / k[:, None]

    # np.trapz was renamed np.trapezoid in NumPy 2.0.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    sofk = 1 + 4 * np.pi * rho * trapezoid(integrand, dx=dr, axis=1)

    return {'s': sofk, 'k': k}



def dump_pos_vel(fname, pos, vel, N, L, xi=0, lns=0):
    """Write positions, velocities and optional thermostat state to a text file.

    File layout:

    * header ``% T N L L L [xi lns]``. The Nose-Hoover state is appended
      whenever either ``xi`` or ``lns`` is non-zero, so a restart does not
      lose it when only one of them happens to be zero;
    * one position line per atom;
    * one velocity line per atom, prefixed with ``%``. Loading the file with
      ``comments="%"`` therefore returns only the positions.

    The layout matches what ``read_pos_vel`` parses.
    """
    with open(fname,'w') as f:
        if xi or lns:
            f.write("%% T %d %E %E %E %E %E\n"%(N, L, L, L, xi, lns))
        else:
            f.write("%% T %d %E %E %E \n"%(N, L, L, L))
        np.savetxt(f, pos, fmt=['    %.15E','%.15E','%.15E'])
        np.savetxt(f, vel, fmt=['%%  % .15E','% .15E','% .15E'])



def read_pos_vel(fname):
    """Read a configuration written by ``dump_pos_vel``.

    Returns
    -------
    tuple
        ``(N, L, pos, vel, xi, lns)`` when the header carries the Nose-Hoover
        state, otherwise ``(N, L, pos, vel)``. ``pos`` and ``vel`` are always
        2-D arrays of shape ``(N, 3)``, including when ``N == 1``.
    """
    # Header: "% T N L L L [xi lns]"; close the file right after reading it.
    with open(fname) as f:
        parameters = f.readline().split()
    N = int(parameters[2])
    L = float(parameters[3])

    # Velocity lines start with '%', so treating '%' as a comment marker
    # leaves only the position lines.
    pos = np.loadtxt(fname, comments="%", ndmin=2)
    # Velocities follow the header and the N position lines; column 0 is
    # the '%' marker.
    vel = np.loadtxt(fname, skiprows=N + 1, usecols=[1, 2, 3], ndmin=2)

    if len(parameters) > 6:
        xi = float(parameters[6])
        lns = float(parameters[7])
        return N, L, pos, vel, xi, lns
    else:
        return N, L, pos, vel
