import signal
import sys
import numpy as np
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib import rcParams


# Integration scheme: 'euler' (semi-implicit Euler) or 'verlet' (position
# Verlet).  Can be overridden by the first command-line argument.
method = 'verlet'

# Non-zero: also write the animation to a GIF file (second command-line
# argument).
save = 0

if len(sys.argv) > 1:
    method = sys.argv[1]
if len(sys.argv) > 2:
    save = int(sys.argv[2])
# update() treats every value other than 'euler' as Verlet, so reject typos
# instead of silently running the wrong integrator.
if method not in ('euler', 'verlet'):
    sys.exit(f"unknown method {method!r}; expected 'euler' or 'verlet'")

# Coordinate conventions used throughout this script:
#   pos             : fractional coordinates
#   x, xeq, xall, xk: absolute coordinates

# ---- Unit cell -------------------------------------------------------------
a0 = 1.0                  # lattice constant of the unit cell
pos_list0 = [0, 0.5]      # fractional coordinates of the atoms in the cell
m_list0 = [1.0, 1.0]      # masses of those atoms
n0 = len(pos_list0)       # number of atoms per unit cell

# ---- Force constants between atom sites (see Uij for the bond energy) -------
k2_list0 = [[0, 1], [1, 0]]           # harmonic   (s**2) coefficients
k3_list0 = [[0, 0], [0, 0]]           # cubic      (s**3) coefficients
k4_list0 = [[0, 1.0e6], [1.0e6, 0]]   # quartic    (s**4) coefficients
# Equilibrium bond length used by dx(); with the default cell this equals the
# nearest-neighbour spacing 0.5 * a0 (keep the two in sync if either changes).
l0 = 0.5

# Largest initial displacement applied to the two end atoms.
dxmax = 2.0e-2

# ---- Supercell ---------------------------------------------------------------
nrepeat = 25              # number of unit cells in the supercell
n = nrepeat * n0          # total number of atoms

# ---- Time integration / animation -------------------------------------------
tstep = 0.01              # MD time step
nstep = 2_000             # number of MD steps (animation frames)
tsleep = 0                # animation frame interval [ms]
plotinterval = 5          # redraw the plots every plotinterval steps

# Supercell: nrepeat copies of the unit cell, n atoms in total.
asl = a0 * nrepeat              # supercell length (absolute units)
_cell = np.arange(n) // n0      # unit-cell index of each atom
_site = np.arange(n) % n0       # site index of each atom inside its unit cell
# Fractional coordinate with respect to the supercell.  Fix: the cell offset
# is added in fractional units.  The former `isl * a0 + pos_list0[i0]` mixed
# absolute and fractional units, so xeq = asl * pos_list came out wrong
# whenever a0 != 1.
pos_list = (_cell + np.asarray(pos_list0, dtype=float)[_site]) / nrepeat
# Per-atom masses and pairwise force constants copied from the unit-cell
# tables: entry [i][j] is the unit-cell value for sites (i % n0, j % n0).
m_list = np.asarray(m_list0, dtype=float)[_site]
_pair = np.ix_(_site, _site)
k2_list = np.asarray(k2_list0, dtype=float)[_pair]
k3_list = np.asarray(k3_list0, dtype=float)[_pair]
k4_list = np.asarray(k4_list0, dtype=float)[_pair]
del _cell, _site, _pair


# Equilibrium positions in absolute coordinates.
xeq = pos_list * asl
# Fixed walls: one nearest-neighbour spacing beyond each end of the chain.
dx01 = xeq[1] - xeq[0]
x_l_fixed = xeq[0] - dx01
x_r_fixed = xeq[-1] + dx01

# Initial positions: the two end atoms are pulled outward by dxmax.
xini = xeq.copy()
xini[0] -= dxmax
xini[-1] += dxmax
# Initial velocities: everything at rest.
v = np.zeros_like(xini)

def get_x(i, xall):
    """Return the absolute position of atom ``i`` for the configuration ``xall``.

    The virtual indices -1 and n denote the neighbours beyond the chain ends;
    they map to the fixed walls ``x_l_fixed`` and ``x_r_fixed``.  Any other
    index is looked up in ``xall``.
    """
    if i == -1:
        return x_l_fixed  # left fixed end
    if i == n:
        return x_r_fixed  # right fixed end
    # For a periodic chain, replace the two returns above with:
    #     if i == -1: return xall[n - 1] - asl
    #     if i == n:  return xall[0] + asl
    # Use the `xall` argument here, not the module-level `x`.  The global
    # holds the current MD state, not the configuration being evaluated.
    return xall[i]

def dx(i, j, xall):
    """Signed deviation of the i-j separation from the equilibrium length l0.

    Returns ``x_j - x_i - l0`` when i < j, ``x_j - x_i + l0`` when i > j, and
    0.0 when i == j.  Positions come from get_x, so -1 and n address the walls.
    """
    if i == j:
        return 0.0
    separation = get_x(j, xall) - get_x(i, xall)
    return separation - l0 if i < j else separation + l0
   
# Interatomic force
def fij(i, j, xall):
    """Force on atom ``i`` from its bond to nearest neighbour ``j`` (+ = +x).

    The bond energy is U(s) = k2 s^2/2 + k3 s^3/6 + k4 s^4/24, where
    s = dx(lo, hi) = x_hi - x_lo - l0 is the bond stretch (lo/hi are the
    smaller/larger index).  The lower-index atom feels +dU/ds and the
    higher-index atom -dU/ds.  This keeps the force the exact negative
    gradient of Uk, including the cubic term.  The previous code got the
    cubic term's sign wrong for i > j.

    Returns 0.0 when i and j are not nearest neighbours.
    """
    if abs(j - i) != 1:
        return 0.0

    lo, hi = min(i, j), max(i, j)
    # Map the virtual indices -1 and n onto real rows/columns of the
    # force-constant tables.  Index by (lo, hi) so the force uses the same
    # table entry as the bond energy.
    ik = n - 1 if lo == -1 else (0 if lo == n else lo)
    jk = n - 1 if hi == -1 else (0 if hi == n else hi)

    s = dx(lo, hi, xall)                 # bond stretch, order-independent
    dU_ds = k2_list[ik][jk] * s \
        + 0.5 * k3_list[ik][jk] * s**2 \
        + 1/6 * k4_list[ik][jk] * s**3
    return dU_ds if i == lo else -dU_ds

# Potential energy
def Uij(i, j, xall):
    """Potential energy of the bond between nearest neighbours ``i`` and ``j``.

    U = k2 s^2/2 + k3 s^3/6 + k4 s^4/24 with the bond stretch
    s = dx(lo, hi) = x_hi - x_lo - l0 (lo/hi = smaller/larger index).
    The result therefore does not depend on argument order:
    Uij(i, j) == Uij(j, i).

    Returns 0.0 when i and j are not nearest neighbours.
    """
    if abs(j - i) != 1:
        return 0.0

    lo, hi = min(i, j), max(i, j)
    # Map the virtual indices -1 and n onto real rows/columns of the
    # force-constant tables.
    ik = n - 1 if lo == -1 else (0 if lo == n else lo)
    jk = n - 1 if hi == -1 else (0 if hi == n else hi)

    s = dx(lo, hi, xall)  # bond stretch, order-independent
    return 0.5 * k2_list[ik][jk] * s**2 \
        + 1/6 * k3_list[ik][jk] * s**3 \
        + 1/24 * k4_list[ik][jk] * s**4

def Uk(xall):
    """Total potential energy of the chain for the positions ``xall``.

    Adds the bond energies in this order:
    1. the n-1 bonds inside the supercell, (0,1) ... (n-2,n-1);
    2. the bond to the left wall, (-1,0);
    3. the bond to the right wall, (n-1,n).
    """
    bonds = [(i, i + 1) for i in range(n - 1)]
    bonds.append((-1, 0))      # left-end bond (drop this for a periodic chain)
    bonds.append((n - 1, n))   # right-end bond (keep it for a periodic chain)
    total = 0.0
    for left, right in bonds:
        total += Uij(left, right, xall)
    return total

# Kinetic energy
def UK(v):
    """Kinetic energy sum(m_i * v_i**2 / 2) over the first n atoms."""
    return sum((0.5 * m_list[k] * v[k]**2 for k in range(n)), 0.0)

# Total energy
def Utot(x, v):
    """Total mechanical energy: potential energy Uk(x) plus kinetic energy UK(v)."""
    return Uk(x) + UK(v)

# Force buffer: allocated once at module level and reused by F_list() so the
# MD loop does not allocate a new array on every step.
f = np.zeros(n, dtype=float)


def F_list(x):
    """Compute the net force on every atom for the positions ``x``.

    Each atom feels the bonds to its left (i-1) and right (i+1) neighbours.
    The result is written into the module-level buffer ``f``, and that same
    array is returned.  Its contents are overwritten by the next call, so
    copy it if you need to keep it.
    """
    for atom in range(n):
        f[atom] = fij(atom, atom - 1, x) + fij(atom, atom + 1, x)
    return f

def minimize_func(xk):
    """Objective for the relaxation: potential energy with both end atoms pinned.

    ``xk`` holds the positions of the interior atoms 1..n-2 (the fitting
    parameters).  Atoms 0 and n-1 stay at their initial positions xini[0] and
    xini[n-1].
    """
    xall = np.empty(n)
    xall[0], xall[n - 1] = xini[0], xini[n - 1]
    xall[1:n - 1] = xk[:n - 2]
    return Uk(xall)

# Relaxation: minimise the potential energy over the interior atoms while
# the two displaced end atoms stay pinned at xini.  The result xopt is the
# starting configuration of the MD run.
print()
print("optimize:")
print("initial x:", xini)

result = minimize(minimize_func, xini[1:n-1], method='BFGS')
if not result.success:
    # Do not silently start the MD run from a partially relaxed configuration.
    print("warning: optimization did not converge:", result.message,
          file=sys.stderr)
xopt = np.empty(n, dtype=float)
xopt[0] = xini[0]
xopt[n-1] = xini[n-1]
xopt[1:n-1] = result.x
print("optimized x:", xini[0:1], xopt, xini[n-1:n], result.fun)

print("Equilibrium positions:")
print("xeq=", xeq)
print("x_l_fixed=", x_l_fixed)
print("x_r_fixed=", x_r_fixed)

print("Initial values:")
print("pos=", xini)
print("dx=", xini - xeq)
print("v  =", v)

print("xini and f0:")
# F_list() returns a shared buffer that every later call (including each
# animation step) overwrites; keep a private copy of the initial forces.
f0 = F_list(xini).copy()
for i, (xi, fi) in enumerate(zip(xini, f0)):
    print(f"#{i}: {xi:8.5f} {fi:10.6g}")

# MD state advanced by the animation update function.
#   x, v   : current positions / velocities (x is what the plot shows)
#   x0, x1 : positions at the previous and the current step (Verlet history)
#   v0, v1 : the matching velocities
x = np.copy(xopt)
UKs = []   # kinetic energy, one entry per frame
Uks = []   # potential energy, one entry per frame
Ets = []   # total energy, one entry per frame
x0 = None
x1 = None
v0 = None
v1 = None


def update(frame):
    """Advance the MD run by one time step and redraw every plotinterval frames.

    method == 'euler': semi-implicit Euler.  v is updated with the current
    force first, then x with the new v.

    otherwise (Verlet): frame 0 starts the recursion with a second-order Taylor
    step x1 = x0 + v0*dt + a0*dt^2/2.  Later frames use position Verlet,
    x2 = 2*x1 - x0 + a1*dt^2.  The central-difference velocity
    (x2 - x0) / (2*dt) belongs to the time of x1, so the energies are evaluated
    at (x1, v) rather than mixing x2 with that velocity.  Energy entry k then
    corresponds to t = k * tstep.

    One energy triple (kinetic, potential, total) is appended per call.
    """
    global x, v, x0, x1, v0, v1, UKs, Uks, Ets

    if method == 'euler':
        x0, v0 = x, v
        acc = F_list(x) / m_list
        v1 = v0 + acc * tstep
        x1 = x0 + v1 * tstep
        x, v = x1, v1
        x_rec, v_rec = x1, v1
    elif frame == 0:
        # Verlet start-up: the Taylor step keeps x1 second-order accurate.
        # The previous Euler-Cromer start overshot it by a*dt^2/2.
        x0, v0 = x, v
        acc = F_list(x) / m_list
        x1 = x0 + v0 * tstep + 0.5 * acc * tstep * tstep
        v1 = v0 + acc * tstep
        x, v = x1, v1
        # Record the initial state so entry k matches t = k * tstep.
        x_rec, v_rec = x0, v0
    else:
        acc = F_list(x1) / m_list
        x2 = 2.0 * x1 - x0 + acc * tstep * tstep
        v2 = (x2 - x0) / (2.0 * tstep)   # velocity at the time of x1
        x_rec, v_rec = x1, v2            # energies at a single time level
        x0, v0 = x1, v1
        x1, v1 = x2, v2
        x, v = x2, v2

    _Uk = Uk(x_rec)
    _UK = UK(v_rec)
    UKs.append(_UK)
    Uks.append(_Uk)
    Ets.append(_UK + _Uk)

    if frame % plotinterval != 0:
        return

    ax1.clear()
    ax1.set_title(f"MD for 1D oscillators: step {frame + 1}", fontdict={'fontname': 'MS Gothic'})
    ax1.plot(xeq, x - xeq, label="running")
    ax1.plot(xini, xini - xeq, label="ini")
    ax1.plot(xopt, xopt - xeq, label="opt")
    ax1.set_xlim([-a0, a0 * (nrepeat + 1)])
    ax1.set_xlabel("x", fontdict={'fontname': 'MS Gothic'})
    ax1.set_ylabel("dx", fontdict={'fontname': 'MS Gothic'})
    ax1.grid(True)

    ax2.clear()
    t_list = tstep * np.arange(len(UKs))
    ax2.plot(t_list, UKs, label='kinetic energy')
    ax2.plot(t_list, Uks, label='potential energy')
    ax2.plot(t_list, Ets, label='total energy')
    ax2.set_title("Energy", fontdict={'fontname': 'MS Gothic'})
    ax2.set_xlabel("t", fontdict={'fontname': 'MS Gothic'})
    ax2.set_ylabel("Energy", fontdict={'fontname': 'MS Gothic'})
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()

# Catch Ctrl-C (SIGINT) and shut the animation down cleanly.
def signal_handler(sig, frame):
    """SIGINT handler: wait for ENTER, stop the animation, close the figure, exit.

    The handler is tolerant of partial state:
    - it is installed before ``ani``/``fig`` exist, so it checks for them;
    - a non-repeating animation that has already finished may no longer have
      a usable ``event_source``;
    - a closed stdin (EOFError) means "stop now" rather than crashing.
    """
    try:
        input('Press ENTER to stop animation>>')
    except EOFError:
        pass  # no interactive stdin: proceed with the shutdown immediately
    anim = globals().get('ani')
    source = getattr(anim, 'event_source', None)
    if source is not None:
        source.stop()
    figure = globals().get('fig')
    if figure is not None:
        plt.close(figure)
    sys.exit(0)

signal.signal(signal.SIGINT, signal_handler)

# Figure: two stacked panels (displacement profile on top, energy history below).
rcParams['font.sans-serif'] = ['MS Gothic']
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(10, 6))

# Animation: one MD step per frame, played once; optionally saved as a GIF.
ani = FuncAnimation(fig, update, frames=nstep, interval=tsleep, repeat=False)
if save:
    ani.save('md_osscilators.gif', writer='pillow', fps=20)
plt.show()
