import h_transport_materials as htm
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm


def plot_histogram(data):
    """Draw a histogram of ``data`` with a fitted normal curve on top.

    Both the histogram and the curve are drawn through the pyplot
    interface, so they land on whichever axes are current when the
    function is called. Nothing is returned.

    Parameters
    ----------
    data
        Values to plot. ``min``/``max`` and ``norm.fit`` are applied to
        it directly.
    """
    # histogram of the raw values (default binning)
    heights, edges, _ = plt.hist(data, alpha=0.7, edgecolor="tab:blue")

    # normal distribution fitted to the same values
    mean, std = norm.fit(data)

    # norm.pdf integrates to one, whereas the histogram shows raw counts:
    # multiply the pdf by the histogram area so both share the same scale.
    # Only the first bin's width is used, which assumes equal-width bins.
    width = np.diff(edges)[0]
    area = sum(width * heights)

    # evaluate the fitted pdf across the span of the data and overlay it
    grid = np.linspace(min(data), max(data))
    density = norm.pdf(grid, mean, std)
    plt.plot(grid, area * density)


# 2x2 grid of panels sharing the y axis.
# Rows: log10 of the pre-exponential factor (top), activation energy (bottom).
# Columns: diffusivities (left), solubilities (right).
fig, (axs_top, axs_bot) = plt.subplots(
    nrows=2, ncols=2, sharey=True, figsize=(6.4, 6)
)

property_groups = (htm.diffusivities, htm.solubilities)

for group, ax_top, ax_bot in zip(property_groups, axs_top, axs_bot):
    # keep only the properties filed under Steel
    steel_props = group.filter(material=htm.Steel)

    log_pre_exps = [np.log10(p.pre_exp.magnitude) for p in steel_props]
    act_energies = [p.act_energy.magnitude for p in steel_props]

    # plot_histogram draws on the current axes, so select each panel first
    plt.sca(ax_top)
    plot_histogram(log_pre_exps)

    plt.sca(ax_bot)
    plot_histogram(act_energies)


# column titles
for ax, title in zip(axs_top, ("Diffusivity", "Solubility")):
    ax.set_title(title)

# y label on the left-hand column only
for ax in (axs_top[0], axs_bot[0]):
    ax.set_ylabel("Number of properties")

# x labels; the units shown come from the first entry of each unfiltered group
axs_top[0].set_xlabel(f"log10 ( $D_0$ {htm.diffusivities[0].units:~P} ) ")
axs_top[1].set_xlabel(f"log10 ( $S_0$ {htm.solubilities[0].units:~P} ) ")
axs_bot[0].set_xlabel(f"$E_D$ (eV)")
axs_bot[1].set_xlabel(f"$E_S$ (eV)")

# same x range for both activation-energy panels
for ax in axs_bot:
    ax.set_xlim(0, 0.6)

plt.tight_layout()
plt.show()