#!/usr/local/bin/python3.8

import numpy as np
import h5py
import sys
import os
from scipy import stats
import argparse

epsilon = 1.0e-8

def collect_data(gamma, weights, frequencies, cutoff, max_freq):
    freqs = []
    mode_prop = []
    mode_weights = []
    for w, freq, g in zip(weights, frequencies, gamma):

        tau = 1.0 / np.where(g > 0, g, -1) / (2 * 2 * np.pi)
        if cutoff:
            tau = np.where(tau < cutoff, tau, -1)

        condition = tau > 0
        _tau = np.extract(condition, tau)
        _freq = np.extract(condition, freq)

        if max_freq is None:
            freqs += list(_freq) * w
            mode_prop += list(_tau) * w
        else:
            freqs += list(np.extract(freq < max_freq, freq))
            mode_prop += list(np.extract(freq < max_freq, tau))

    x = np.array(freqs)
    y = np.array(mode_prop)

    return x, y

def run_KDE(x, y, nbins, x_max=None, y_max=None, density_ratio=0.1):
    """Running Gaussian-KDE by scipy
    """

    x_min = 0
    if x_max is None:
        _x_max = np.rint(x.max() * 1.1)
    else:
        _x_max = x_max
    y_min = 0
    if y_max is None:
        _y_max = np.rint(y.max())
    else:
        _y_max = y_max
    values = np.vstack([x.ravel(), y.ravel()])
    kernel = stats.gaussian_kde(values)

    xi, yi = np.mgrid[x_min:_x_max:nbins*1j, y_min:_y_max:nbins*1j]
    positions = np.vstack([xi.ravel(), yi.ravel()])
    zi = np.reshape(kernel(positions).T, xi.shape)

    if y_max is None:
        zi_max = np.max(zi)
        indices = []
        for i, r_zi in enumerate((zi.T)[::-1]):
            if indices:
                indices.append(nbins - i - 1)
            elif np.max(r_zi) > zi_max * density_ratio:
                indices = [nbins - i - 1]
        short_nbinds = len(indices)

        ynbins = nbins ** 2 // short_nbinds
        xi, yi = np.mgrid[x_min:_x_max:nbins*1j, y_min:_y_max:ynbins*1j]
        positions = np.vstack([xi.ravel(), yi.ravel()])
        zi = np.reshape(kernel(positions).T, xi.shape)
    else:
        short_nbinds = nbins

    return xi, yi, zi, short_nbinds

def plot(plt, xi, yi, zi, x, y, short_nbinds, nbins,
         y_max=None, z_max=None, cmap=None, aspect=None,
         flip=False, no_points=False, show_colorbar=True,
         point_size=5, title=None):
    xmax = np.max(x)
    ymax = np.max(y)
    x_cut = []
    y_cut = []
    threshold = ymax / nbins * short_nbinds / nbins * (nbins - 1)
    for _x, _y in zip(x, y):
        if (epsilon < _y and _y < threshold and
            epsilon < _x and _x < xmax - epsilon):
            x_cut.append(_x)
            y_cut.append(_y)

    fig, ax = plt.subplots()
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_tick_params(which='both', direction='out')
    ax.yaxis.set_tick_params(which='both', direction='out')

    if flip:
        # Start Flip
        plt.pcolormesh(yi[:,:nbins], xi[:,:nbins], zi[:,:nbins],
                       vmax=z_max, cmap=cmap)
        if show_colorbar:
            plt.colorbar()
        if not no_points:
            plt.scatter(y_cut, x_cut, s=point_size,
                        c='k', marker='.', linewidth=0)
        plt.ylim(ymin=0, ymax=xi.max())
        if y_max is None:
            plt.xlim(xmin=0, xmax=(np.max(y_cut) + epsilon))
        else:
            plt.xlim(xmin=0, xmax=(y_max + epsilon))
        plt.xlabel('Lifetime (ps)', fontsize=18)
        plt.ylabel('Phonon frequency (THz)', fontsize=18)
        # End Flip
    else:
        plt.pcolormesh(xi[:,:nbins], yi[:,:nbins], zi[:,:nbins],
                       vmax=z_max, cmap=cmap)
        if show_colorbar:
            plt.colorbar()
        if not no_points:
            plt.scatter(x_cut, y_cut, s=point_size,
                        c='k', marker='.', linewidth=0)
        plt.xlim(xmin=0, xmax=xi.max())
        if y_max is None:
            plt.ylim(ymin=0, ymax=(np.max(y_cut) + epsilon))
        else:
            plt.ylim(ymin=0, ymax=(y_max + epsilon))
        if title:
            plt.title(title, fontsize=20)
        plt.xlabel('Phonon frequency (THz)', fontsize=18)
        plt.ylabel('Lifetime (ps)', fontsize=18)

    if aspect is not None:
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        ax_aspect = (xlim[1] - xlim[0]) / (ylim[1] - ylim[0]) * aspect
        ax.set_aspect(ax_aspect)

    if title:
        plt.title(title, fontsize=20)

    return fig

def get_options():
    # Arg-parser
    parser = argparse.ArgumentParser(
        description="Plot property density with gaussian KDE")
    parser.add_argument(
        "--aspect", type=float, default=None,
        help="The ration, height/width of canvas")
    parser.add_argument(
        "--cmap", dest="cmap", default=None,
        help="Matplotlib cmap")
    parser.add_argument(
        "--cutoff", type=float, default=None,
        help=("Property (y-axis) below this value is included in "
              "data before running Gaussian-KDE"))
    parser.add_argument(
        "--dr", "--density-ratio", dest="density_ratio", type=float,
        default=0.1,
        help=("Minimum density ratio with respect to maximum "
              "density used to determine drawing region"))
    parser.add_argument(
        "--flip", action='store_true',
        help='Flip x and y.')
    parser.add_argument(
        "--fmax", type=float, default=None,
        help="Max frequency to plot")
    parser.add_argument(
        "--hc", "--hide-colorbar", dest="hide_colorbar", action='store_true',
        help='Do not show colorbar')
    parser.add_argument(
        "--nbins", type=int, default=100,
        help=("Number of bins in which data are assigned, "
              "i.e., determining resolution of plot")),
    parser.add_argument(
        "--no-points", dest="no_points", action='store_true',
        help='Do not show points')
    parser.add_argument(
        "--nu", action='store_true',
        help="Plot N and U.")
    parser.add_argument(
        "-o", "--output", dest="output_filename", default=None,
        help="Output filename, e.g., to PDF")
    parser.add_argument(
        '--point-size', dest="point_size", type=float, default=5,
        help='Point size (default=5)')
    parser.add_argument(
        '--temperature', type=float, default=300.0, dest='temperature',
        help='Temperature to output data at')
    parser.add_argument(
        "--title", dest="title", default=None,
        help="Plot title")
    parser.add_argument(
        "--xmax", type=float, default=None,
        help="Set maximum x of draw area")
    parser.add_argument(
        "--ymax", type=float, default=None,
        help="Set maximum y of draw area")
    parser.add_argument(
        "--zmax", type=float, default=None,
        help="Set maximum indisity")
    parser.add_argument('filenames', nargs='*')
    args = parser.parse_args()
    return args

def main(args):
    #
    # Matplotlib setting
    #
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib import rc
    rc('text', usetex=True)
    rc('font', family='serif')
    # rc('font', serif='Times New Roman')
    rc('font', serif='Liberation Serif')
    # plt.rcParams['pdf.fonttype'] = 42

    #
    # Initial setting
    #
    if os.path.isfile(args.filenames[0]):
        f = h5py.File(args.filenames[0], 'r')
    else:
        print("File %s doens't exist." % args.filenames[0])
        sys.exit(1)

    if args.title:
        title = args.title
    else:
        title = None

    #
    # Set temperature
    #
    temperatures = f['temperature'][:]
    if len(temperatures) > 29:
        t_index = 30
    else:
        t_index = 0
    for i, t in enumerate(temperatures):
        if np.abs(t - args.temperature) < epsilon:
            t_index = i
            break

    print("Temperature at which lifetime density is drawn: %7.3f"
          % temperatures[t_index])

    #
    # Set data
    #
    weights = f['weight'][:]
    frequencies = f['frequency'][:]
    symbols = ['',]

    gammas = [f['gamma'][t_index],]
    if args.nu:
        if 'gamma_N' in f:
            gammas.append(f['gamma_N'][t_index])
            symbols.append('N')
        if 'gamma_U' in f:
            gammas.append(f['gamma_U'][t_index])
            symbols.append('U')

    #
    # Run
    #
    for gamma, s in zip(gammas, symbols):
        x, y = collect_data(gamma, weights, frequencies, args.cutoff, args.fmax)
        xi, yi, zi, short_nbinds = run_KDE(x, y, args.nbins,
                                           x_max=args.xmax,
                                           y_max=args.ymax,
                                           density_ratio=args.density_ratio)
        fig = plot(plt, xi, yi, zi, x, y, short_nbinds, args.nbins,
                   y_max=args.ymax, z_max=args.zmax, cmap=args.cmap,
                   aspect=args.aspect, flip=args.flip,
                   no_points=args.no_points,
                   show_colorbar=(not args.hide_colorbar),
                   point_size=args.point_size, title=title)

        if args.output_filename:
            fig.savefig(args.output_filename)
        else:
            if s:
                fig.savefig("lifetime-%s.png" % s)
            else:
                fig.savefig("lifetime.png")
        plt.close(fig)

if __name__ == "__main__":
    main(get_options())
