import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


infile = "interpolate_fft_test.xlsx"


krange = [-0.5, 0.5]
n_samples = 10
interp_factor = 10
n_interp = n_samples * interp_factor


argv = sys.argv
narg = len(argv)
if narg > 1: infile = argv[1] 
if narg > 2: do_mirror = int(argv[2])


def periodic_function(k):
    """Evaluate the test signal ``-cos(2*pi*k) * (1 + 5*k**2)``.

    ``k`` may be a scalar or a NumPy array (evaluated element-wise).  The
    envelope ``1 + 5*k**2`` is even in ``k``, so the values at k = -0.5 and
    k = +0.5 coincide and the periodic extension over [-0.5, 0.5) is
    continuous.  Its slope is not continuous (+5 at k = +0.5, -5 at
    k = -0.5), which makes the signal a non-trivial test for FFT
    interpolation.
    """
    phase = 2.0 * np.pi * k
    envelope = 1.0 + 5.0 * k**2
    return -np.cos(phase) * envelope
    # Alternative test signals:
    #    return 0.1 * k**4 + k * k
    #    return np.sin(k) + 0.5 * np.sin(2*k)

def read_data(infile, do_mstructu2irror=False):
    """Read the ``k`` and ``E(k)`` columns of an Excel workbook.

    Parameters
    ----------
    infile : str or path-like
        Workbook handed unchanged to ``pandas.read_excel`` (default sheet).
    do_mstructu2irror : bool, optional
        Currently ignored.  NOTE(review): the name looks like a garbled
        ``do_mirror``; no mirroring is implemented in this function.

    Returns
    -------
    tuple[list, list]
        ``(x, y)``: the ``k`` column and the ``E(k)`` column as plain Python
        lists, in file row order.
    """
    frame = pd.read_excel(infile)
    k_values = frame['k'].to_numpy().tolist()
    energies = frame['E(k)'].to_numpy().tolist()
    return k_values, energies


print()
print(f"Input file: {infile}")

if infile not in (None, "", "generate"):
    # Tabulated input: one period of E(k) on a uniform k grid.
    print()
    print(f"Read [{infile}]")
    x, y = read_data(infile, do_mirror)
    xe = ye = None  # no analytic reference curve for file input
    print("x=", x)
    print("y=", y)
    # The interpolation interval spans the tabulated k values.
    krange[0], krange[1] = min(x), max(x)
    # Drop the final row.  NOTE(review): this assumes the file is sorted in k
    # and its last row repeats the first one shifted by one period, so the
    # remaining points form [start, end) like the generated grid.
    n = len(x)
    x, y = x[:n - 1], y[:n - 1]
    n_samples = len(x)
    n_interp = n_samples * interp_factor
else:
    # Analytic input: sample periodic_function on a uniform grid that
    # excludes the right endpoint (it would duplicate the first point).
    print()
    print("Generate samples")
    x = np.linspace(*krange, n_samples, endpoint=False)
    y = periodic_function(x)
    # Dense evaluation of the exact function for visual comparison.
    xe = np.linspace(*krange, n_interp, endpoint=False)
    ye = periodic_function(xe)

def _pad_spectrum(spectrum, n_out):
    """Zero-pad an FFT spectrum to ``n_out`` bins for trigonometric interpolation.

    ``spectrum`` is in ``numpy.fft.fft`` order (non-negative frequencies
    first, then negative ones).  Every coefficient is copied to the bin of
    the same frequency in the longer array and all new high-frequency bins
    stay 0, so ``ifft(result) * n_out / len(spectrum)`` evaluates the
    trigonometric interpolant of the original samples on ``n_out`` points.

    Parameters
    ----------
    spectrum : array_like
        Unnormalized FFT of ``n_in`` samples.
    n_out : int
        Length of the padded spectrum; must be at least ``len(spectrum)``.

    Returns
    -------
    numpy.ndarray
        Complex array of length ``n_out``.  If ``spectrum`` is Hermitian (the
        FFT of real data), so is the result, and its inverse FFT is real up
        to round-off.

    Raises
    ------
    ValueError
        If ``n_out < len(spectrum)``.
    """
    spectrum = np.asarray(spectrum)
    n_in = len(spectrum)
    if n_out < n_in:
        raise ValueError(f"n_out ({n_out}) must be >= len(spectrum) ({n_in})")
    # Number of non-negative frequency bins (DC included).  Beware that
    # ``-n_in//2`` means ``(-n_in)//2`` in Python; for odd n_in that would move
    # the highest positive frequency into the negative half.
    n_pos = (n_in + 1) // 2
    n_neg = n_in - n_pos
    padded = np.zeros(n_out, dtype=complex)
    padded[:n_pos] = spectrum[:n_pos]
    # Explicit start index: a ``[-n_neg:]`` slice would select the whole
    # array when n_neg == 0.
    padded[n_out - n_neg:] = spectrum[n_pos:]
    if n_in > 0 and n_in % 2 == 0 and n_out > n_in:
        # For even n_in the Nyquist bin stands for both +n_in/2 and -n_in/2;
        # share it equally so a Hermitian input stays Hermitian.
        half = n_in // 2
        padded[half] = spectrum[half] / 2
        padded[n_out - half] = spectrum[half] / 2
    return padded


# Spectrum of the sampled period.
y_fft = np.fft.fft(y)

# Pad the FFT data for interpolation.
y_fft_padded = _pad_spectrum(y_fft, n_interp)

# Output grid: n_interp points over the same interval, right endpoint
# excluded to match the sample grid.
x_interp = np.linspace(krange[0], krange[1], n_interp, endpoint=False)
# ifft divides by n_interp while fft did not divide by n_samples; scaling by
# n_interp / n_samples (== interp_factor) restores the sample amplitudes.
y_interp = np.fft.ifft(y_fft_padded) * interp_factor

# Overlay the input samples, the FFT interpolant and, in "generate" mode,
# the exact function.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(x, y, 'o', label='input data', markersize=6)
# Only the real part of the complex inverse-FFT result is drawn.
ax.plot(x_interp, y_interp.real, '-', label='interpolated',
        marker='x', markersize=3)
if xe is not None:
    ax.plot(xe, ye, '-', label='exact')
ax.set(xlabel='x', ylabel='y',
       title='Interpolation of a Periodic Function using FFT')
ax.legend()
plt.show()