GUIで条件設定をして最小二乗を行う

10-leastsq-GUI-inifile.py

import sys
import os
import configparser
import csv
from pprint import pprint
from math import sqrt
import numpy as np
from scipy import optimize
from matplotlib import pyplot as plt
from tkinter import *
from tkinter import ttk
from tkinter import filedialog, messagebox


"""
GUIで条件設定をして最小二乗を行う
"""



#=============================
# Global settings
#=============================
# Initial fitting parameters (arbitrary placeholder for a linear least-squares fit).
ai0 = [0, 0, 0]
# Default fit range [x0, x1] shown in the GUI.
xrange = [0, 10]
# Default polynomial order of the fit.
fit_order = 2
# Font size used for the graph text.
font_size = 24

# Window geometry string.
window_size = "500x200"
# File dialogs open in the directory that contains this script.
ini_dir = os.path.abspath(os.path.dirname(__file__))
file_type = [('CSV file', '*.csv')]


# Data (.csv) and settings (.prm) file names are derived from the first
# command-line argument by replacing its extension.
argv = sys.argv
datafile = inifile = None
if len(argv) >= 2:
    path = argv[1]
    datafile, inifile = (os.path.splitext(path)[0] + ext for ext in ('.csv', '.prm'))


#=============================
# functions
#=============================
def eprint(textbox, *args):
    """Echo *args* to stdout and append the same text to a Tk Text widget.

    The arguments are joined exactly as ``print`` joins them (``str()`` of
    each, separated by one space), so the console and the GUI show identical
    text. No newline is added; callers include it in their arguments.

    Args:
        textbox: widget providing a Tk ``Text``-style ``insert(index, chars)``.
        *args: values to output.
    """
    # Text.insert(index, chars, *rest) interprets extra positional arguments
    # as alternating tag names / text, so join first instead of forwarding
    # *args (which would also fail with no arguments at all).
    text = ' '.join(str(arg) for arg in args)
    print(text, end = '')
    textbox.insert('end', text)

#=============================
# Read the ini-format parameter (.prm) file
#=============================
def read_ini(path, obj):
    """Load fit settings from an ini-format parameter file into *obj*.

    The optional ``x0`` / ``x1`` options of the ``[data]`` section are parsed
    as floats and stored into the Tk variables ``obj['x0']`` / ``obj['x1']``;
    options that are absent leave the current values untouched. *path* is
    recorded in ``obj['inipath']``.

    Args:
        path: path of the parameter file (e.g. ``*.prm``).
        obj: GUI state dict; ``'x0'``, ``'x1'`` and ``'inipath'`` must be
            Tk variables (anything with ``set()``).

    Returns:
        The ``ConfigParser`` when the file could be read, otherwise ``None``.

    Raises:
        ValueError: if ``x0`` or ``x1`` is present but not a number.
        configparser.Error: if the file is not valid ini syntax.
    """
    config = configparser.ConfigParser()
    # read() silently skips files it cannot open and returns those it parsed.
    if not config.read(path):
        return None

    # Update the existing Tk variables instead of replacing the dict entries:
    # the Entry widgets are bound to these variables and other callbacks call
    # .get()/.set() on them.
    for key in ('x0', 'x1'):
        value = config.getfloat('data', key, fallback = None)
        if value is not None:
            obj[key].set(value)
    obj['inipath'].set(path)
    return config


#=============================
# csvファイルの読み込み
#=============================
def read_csv(path, obj):
    """Read a CSV file of (x, y) data into the GUI state *obj*.

    The first row is the header; its first two cells become the x / y axis
    labels. Every following non-blank row must start with two numbers x, y.
    Besides returning the data, this stores it in ``obj['header']``,
    ``obj['x']``, ``obj['y']``, fills the Tk variables ``'datapath'``,
    ``'savepath'`` (``<name>-fit.csv``), ``'title'``, ``'xlabel'`` and
    ``'ylabel'``, and echoes the data to ``obj['outputtext']``.

    Args:
        path: CSV file to read.
        obj: GUI state dict holding the Tk variables listed above.

    Returns:
        ``(header, x, y)``: the header row and the x and y values as float lists.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if there is no header row with at least two columns, or a
            data row does not start with two numbers.
    """
    x = []
    y = []
    # The csv module requires files to be opened with newline=''.
    with open(path, "r", newline = '') as f:
        reader = csv.reader(f)
        header = next(reader, None)
        if header is None or len(header) < 2:
            raise ValueError(
                "{}: expected a header row with at least two columns".format(path))

        for row in reader:
            if not row:
                continue  # blank line
            try:
                xi, yi = float(row[0]), float(row[1])
            except (IndexError, ValueError) as e:
                raise ValueError("{}: line {}: expected two numbers, got {}".format(
                    path, reader.line_num, row)) from e
            x.append(xi)
            y.append(yi)

    obj['datapath'].set(path)
    obj['savepath'].set(os.path.splitext(path)[0] + '-fit.csv')
    obj['header'] = header
    obj['x'] = x
    obj['y'] = y
    obj['title'].set(path)
    obj['xlabel'].set(header[0])
    obj['ylabel'].set(header[1])

    textbox = obj['outputtext']
    eprint(textbox, "\n")
    eprint(textbox, "CSV data:\n")
    eprint(textbox, " header: {}\n".format(header))
    eprint(textbox, " x: {}\n".format(x))
    eprint(textbox, " y: {}\n".format(y))
    return header, x, y


#=============================
# File=>openメニュー、pathボタン
#=============================
def path_button_click(obj):
    """Handle File=>Open and the data 'path' button: choose and load a CSV file.

    After a successful read (``read_csv`` also updates ``obj['datapath']``),
    the fit range ``obj['x0']`` .. ``obj['x1']`` is reset to the full x extent
    of the data. Cancelling the dialog does nothing. A file that cannot be
    read is reported in an error dialog instead of raising from the callback.
    """
    selpath = filedialog.askopenfilename(filetypes = file_type, initialdir = ini_dir)
    # Cancel returns an empty string (or an empty tuple on some platforms).
    if not selpath:
        return
    try:
        _, x, _ = read_csv(selpath, obj)
    except (OSError, ValueError, IndexError, csv.Error) as e:
        messagebox.showerror('Error', 'Failed to read {}:\n{}'.format(selpath, e))
        return
    # A header-only file yields no data; min()/max() would raise on it.
    if x:
        obj['x0'].set(min(x))
        obj['x1'].set(max(x))



#=============================
# Polynomial least-squares fit with numpy.polyfit()
#=============================
def _select_range(x, y, x0, x1):
    """Return the (x, y) points whose x lies in the closed range between x0 and x1.

    The two bounds may be given in either order.
    """
    lo, hi = min(x0, x1), max(x0, x1)
    xf = []
    yf = []
    for xi, yi in zip(x, y):
        if lo <= xi <= hi:
            xf.append(xi)
            yf.append(yi)
    return xf, yf


def _plot_fit(obj, x, y, ai, ncal = 100):
    """Plot raw data (x, y) and the polynomial *ai* sampled at ncal points over [min(x), max(x)]."""
    xc = np.linspace(min(x), max(x), ncal)
    yc = np.polyval(ai, xc)

    fontsize = obj['font_size'].get()
    plt.clf()
    plt.plot(x, y, label = 'raw data', marker = 'o', linestyle = 'None')
    plt.plot(xc, yc, label = 'fitted', linestyle = 'dashed')
    plt.title(obj['title'].get(), fontsize = fontsize)
    plt.xlabel(obj['xlabel'].get(), fontsize = fontsize)
    plt.ylabel(obj['ylabel'].get(), fontsize = fontsize)
    plt.legend(fontsize = fontsize)
    plt.tick_params(labelsize = fontsize)
    plt.tight_layout()
    # Draw the figure with a short pause instead of plt.show(), which would
    # block while the Tk mainloop needs to keep running.
    plt.pause(0.001)


def lsq_button_click(obj):
    """Fit a polynomial to the loaded data inside the x range, report it and plot it.

    Reads the data ``obj['x']`` / ``obj['y']`` (set by ``read_csv``), the range
    ``obj['x0']`` .. ``obj['x1']`` and the order ``obj['fit_order']``. Only the
    points inside the range are fitted. Writes the coefficients (highest power
    first, as returned by ``numpy.polyfit``) and the RMS residual of the fitted
    points to ``obj['outputtext']``, then redraws the graph with all raw data.
    Problems such as missing data or too few points are reported in the output
    box, and the fit is skipped.
    """
    textbox = obj['outputtext']
    eprint(textbox, "\n")
    eprint(textbox, "polynomial fit by numpy.polyfit() start:\n")

    x = obj.get('x')
    y = obj.get('y')
    if not x:
        eprint(textbox, " no data: open a CSV file first\n")
        return

    deg = obj['fit_order'].get()
    xf, yf = _select_range(x, y, obj['x0'].get(), obj['x1'].get())
    if deg < 0 or len(xf) < deg + 1:
        eprint(textbox,
               " need fit order >= 0 and at least order+1 points in the x range"
               " (order={}, points={})\n".format(deg, len(xf)))
        return

    ai = np.polyfit(xf, yf, deg = deg)
    # RMS residual over the fitted points. It is computed directly because the
    # residuals array from polyfit(full=True) is empty for an exact or
    # rank-deficient fit.
    res = sqrt(np.mean((np.polyval(ai, xf) - np.asarray(yf)) ** 2))

    eprint(textbox, " lsq result: ai={}\n".format(ai))
    eprint(textbox, " residual={}\n".format(res))

    _plot_fit(obj, x, y, ai)


def _make_path_row(parent, label_text, variable, button_text, command):
    """Create a 'label | 50-character entry | button' row packed at the top of *parent*."""
    frame = ttk.Frame(parent)
    ttk.Label(frame, text = label_text, padding = (5,2)).pack(side = 'left')
    ttk.Entry(frame, textvariable = variable, width = 50).pack(side = 'left', expand = True)
    ttk.Button(frame, text = button_text, command = command).pack(side = 'left')
    frame.pack(side = 'top', anchor = 'w')
    return frame


def inipath_button_click(obj):
    """Ask for a parameter (.prm) file and load it into *obj* with read_ini()."""
    selpath = filedialog.askopenfilename(
        filetypes = [('Parameter file', '*.prm'), ('All files', '*')],
        initialdir = ini_dir)
    if not selpath:  # dialog cancelled
        return
    try:
        config = read_ini(selpath, obj)
    except (configparser.Error, ValueError) as e:
        messagebox.showerror('Error', 'Failed to read {}:\n{}'.format(selpath, e))
        return
    if config is None:
        messagebox.showerror('Error', 'Could not read {}'.format(selpath))


def savepath_button_click(obj):
    """Let the user choose the output file path and store it in obj['savepath'].

    NOTE(review): the 'save' button referred to this name, but no
    implementation existed. What it was meant to write is not shown, so this
    only selects the destination path.
    """
    current = obj['savepath'].get()
    selpath = filedialog.asksaveasfilename(
        filetypes = file_type,
        initialdir = os.path.dirname(current) or ini_dir,
        initialfile = os.path.basename(current),
        defaultextension = '.csv')
    if selpath:
        obj['savepath'].set(selpath)


def main():
    """Build the least-squares GUI, load files named on the command line, and run it."""
    print('Second-order polynomial lsq')

    # Shared GUI state: Tk variables bound to widgets, plus data added by read_csv.
    obj = {}
    root = Tk()

    obj['inipath'] = StringVar()
    obj['datapath'] = StringVar()
    obj['savepath'] = StringVar()
    obj['x0'] = DoubleVar(value = xrange[0])
    obj['x1'] = DoubleVar(value = xrange[1])
    obj['title'] = StringVar()
    obj['xlabel'] = StringVar()
    obj['ylabel'] = StringVar()
    obj['font_size'] = IntVar(value = font_size)
    obj['fit_order'] = IntVar(value = fit_order)
    obj['output'] = StringVar()

    root.title('Second-order polynomial lsq')
    root.minsize(200, 200)

    # Menu. root.destroy ends mainloop() cleanly; the site-provided exit() is
    # not guaranteed to exist (e.g. under python -S or in frozen executables).
    menu_bar = Menu(root)
    menu_file = Menu(menu_bar, tearoff = 0)
    menu_file.add_command(label='Open', accelerator='Ctrl+O',
                    command = lambda: path_button_click(obj))
    menu_file.add_command(label='exit', accelerator='Alt+E',
                    command = root.destroy)
    menu_bar.add_cascade(label = 'File', menu = menu_file)
    root.config(menu = menu_bar)
    # Tk menu accelerators are display-only; bind the keys explicitly.
    root.bind('<Control-o>', lambda event: path_button_click(obj))
    root.bind('<Alt-e>', lambda event: root.destroy())
    root.grid()

    root_frame = ttk.Frame(root, padding=10)
    root_frame.pack(side = 'top', expand = True)

    # Ini / data / save path rows
    _make_path_row(root_frame, 'Ini file path:', obj['inipath'], 'path',
                   lambda: inipath_button_click(obj))
    _make_path_row(root_frame, 'Data file path:', obj['datapath'], 'path',
                   lambda: path_button_click(obj))
    _make_path_row(root_frame, 'Save path:', obj['savepath'], 'save',
                   lambda: savepath_button_click(obj))

    # Fit range x0 - x1
    range_frame = ttk.Frame(root_frame)
    ttk.Label(range_frame, text = 'x range:', padding = (5,2)).grid(row = 1, column = 0, sticky = 'w')
    ttk.Entry(range_frame, textvariable = obj['x0'], width = 10).grid(row = 1, column = 1)
    ttk.Label(range_frame, text = '-', padding = (5,2)).grid(row = 1, column = 2, sticky = E)
    ttk.Entry(range_frame, textvariable = obj['x1'], width = 10).grid(row = 1, column = 3)
    range_frame.pack(side = 'top', anchor = 'w')

    # Title / x label / y label
    xylabel_frame = ttk.Frame(root_frame)
    label_rows = [('Title:', 'title'), ('x label:', 'xlabel'), ('y label:', 'ylabel')]
    for row, (text, key) in enumerate(label_rows):
        ttk.Label(xylabel_frame, text = text).grid(row = row, column = 0, sticky = 'w')
        ttk.Entry(xylabel_frame, textvariable = obj[key], width = 50).grid(row = row, column = 1)
    xylabel_frame.pack(side = 'top', anchor = 'w')

    # Fit order (entry packed on the same side as its label)
    order_frame = ttk.Frame(root_frame)
    ttk.Label(order_frame, text = 'fit order:', padding = (5,2)).pack(side = 'left')
    ttk.Entry(order_frame, textvariable = obj['fit_order'], width = 10).pack(side = 'left')
    order_frame.pack(side = 'top', anchor = 'w')

    # Buttons
    button_frame = ttk.Frame(root_frame)
    ttk.Button(button_frame, text = 'lsq',
        command = lambda: lsq_button_click(obj)).pack(side = 'left')
    ttk.Button(button_frame, text = 'exit', command = root.destroy).pack(side = 'left')
    button_frame.pack(side = 'top', anchor = 'w')

    # Graph font size
    fontsize_frame = ttk.Frame(root_frame)
    ttk.Label(fontsize_frame, text = 'Font size:').pack(side = 'left')
    ttk.Entry(fontsize_frame, textvariable = obj['font_size'], width = 10).pack(side = 'left')
    fontsize_frame.pack(side = 'top', anchor = 'w')

    # Output text
    output_frame = ttk.Frame(root_frame)
    output_text = Text(output_frame, height = 10)
    output_text.pack(side = 'left')
    output_frame.pack(side = 'top', anchor = 'w', expand = True)
    obj['outputtext'] = output_text

    # Files derived from the command-line argument. Report problems in the
    # output box instead of aborting before the window appears.
    ini_loaded = False
    if inifile is not None:
        try:
            ini_loaded = read_ini(inifile, obj) is not None
        except (configparser.Error, ValueError) as e:
            eprint(output_text, "failed to read {}: {}\n".format(inifile, e))
    if datafile is not None:
        if not os.path.isfile(datafile):
            eprint(output_text, "data file not found: {}\n".format(datafile))
        else:
            try:
                _, x, _ = read_csv(datafile, obj)
            except (OSError, ValueError, IndexError, csv.Error) as e:
                eprint(output_text, "failed to read {}: {}\n".format(datafile, e))
            else:
                # Without a parameter file, fit over the whole data range,
                # as the data 'path' button does.
                if not ini_loaded and x:
                    obj['x0'].set(min(x))
                    obj['x1'].set(max(x))

    root.mainloop()


# Start the GUI only when this file is run as a script, not when imported.
if __name__ == '__main__':
    main()