#!/usr/local/bin/python3.8

import sys
import numpy as np
from phonopy.phonon.tetrahedron_mesh import TetrahedronMesh
from phonopy.harmonic.force_constants import similarity_transformation
from phonopy.interface.calculator import read_crystal_structure
from phono3py.phonon3.triplets import (get_ir_grid_points,
                                       get_grid_points_by_rotations,
                                       get_grid_address)
from phonopy.phonon.dos import NormalDistribution

epsilon = 1.0e-8


def show_tensor(kdos, temperatures, sampling_points, args):
    """Print tensor-valued accumulation data to standard output.

    One section is written per leading entry of ``kdos``. Unless ``args.gv``
    is set, a section starts with a ``# <T> K`` header taken from
    ``temperatures``. Every following row holds the sampling point and values
    derived from the two tensor rows stored for that point:

    * ``args.average``: sum of the first three elements of each row / 3;
    * ``args.trace``: sum of the first three elements of each row;
    * otherwise: every element of both rows (13 columns in total).

    Each section is terminated by two empty lines.
    """
    reduced_fmt = "%13.5f " * 3
    full_fmt = "%f " * 13

    for section, tensors in enumerate(kdos):
        if not args.gv:
            print("# %d K" % temperatures[section])

        for point, pair in zip(sampling_points[section], tensors):
            first, second = pair[0], pair[1]
            if args.average:
                row = (point, first[:3].sum() / 3, second[:3].sum() / 3)
                print(reduced_fmt % row)
            elif args.trace:
                row = (point, first[:3].sum(), second[:3].sum())
                print(reduced_fmt % row)
            else:
                row = (point,) + tuple(first) + tuple(second)
                print(full_fmt % row)

        # Two empty lines separate consecutive sections.
        print("\n")


def show_scalar(gdos, temperatures, sampling_points, args):
    """Print scalar accumulation data to standard output.

    When any of ``args.pqj``, ``args.gruneisen`` or ``args.gv_norm`` is set,
    only ``gdos[0]`` is printed, without a header, using exponent notation
    for the two value columns. Otherwise one ``# <T> K`` section is printed
    per entry of ``gdos`` in fixed-point notation, each followed by two
    empty lines.
    """
    headerless = args.pqj or args.gruneisen or args.gv_norm
    if headerless:
        for point, pair in zip(sampling_points, gdos[0]):
            print("%f %e %e" % (point, pair[0], pair[1]))
        return

    for section, rows in enumerate(gdos):
        print("# %d K" % temperatures[section])
        for point, pair in zip(sampling_points, rows):
            print("%f %f %f" % (point, pair[0], pair[1]))
        # Two empty lines separate consecutive sections.
        print("\n")


def set_T_target(temperatures,
                 mode_prop,
                 T_target,
                 mean_freepath=None,
                 tolerance=None):
    """Restrict temperature-resolved data to a single target temperature.

    Parameters
    ----------
    temperatures : array_like
        Temperatures; entry ``i`` belongs to ``mode_prop[i]``.
    mode_prop : ndarray
        Property array whose first axis runs over temperatures. It is sliced
        as ``mode_prop[i:i+1, :, :]``, so it needs at least three axes.
    T_target : float
        Temperature to select.
    mean_freepath : ndarray, optional
        Array whose first axis runs over temperatures. When given, it is
        sliced the same way and returned as a third value.
    tolerance : float, optional
        Maximum absolute difference for a temperature to count as a match.
        Defaults to the module-level ``epsilon``.

    Returns
    -------
    tuple
        ``(temperatures, mode_prop)`` or, when ``mean_freepath`` is given,
        ``(temperatures, mode_prop, mean_freepath)``; each keeps a leading
        axis of length one for the first matching temperature.

    Raises
    ------
    ValueError
        If no entry of ``temperatures`` matches ``T_target``.
    """
    if tolerance is None:
        tolerance = epsilon

    for i, t in enumerate(temperatures):
        # Compare against the argument, not a global parsed-arguments object.
        if np.abs(t - T_target) < tolerance:
            temperatures = temperatures[i:i+1]
            mode_prop = mode_prop[i:i+1, :, :]
            if mean_freepath is None:
                return temperatures, mode_prop
            return temperatures, mode_prop, mean_freepath[i:i+1]

    # Falling through would return None, which callers cannot unpack.
    raise ValueError(
        "Temperature %s was not found in the available temperatures."
        % (T_target,))


def run_prop_dos(frequencies,
                 mode_prop,
                 primitive,
                 mesh,
                 grid_address,
                 grid_mapping_table,
                 ir_grid_points,
                 num_sampling_points):
    """Accumulate ``mode_prop`` over frequency with ``KappaDOS``.

    Returns
    -------
    kdos : ndarray
        Array returned by ``KappaDOS.get_kdos()``.
    sampling_points : ndarray
        The frequency points repeated once per leading entry of ``kdos``,
        giving one row of sampling points per entry.
    """
    freq_points, kdos = KappaDOS(
        mode_kappa=mode_prop,
        cell=primitive,
        frequencies=frequencies,
        mesh=mesh,
        grid_address=grid_address,
        grid_mapping_table=grid_mapping_table,
        ir_grid_points=ir_grid_points,
        num_sampling_points=num_sampling_points).get_kdos()

    # Every entry shares the same frequency axis; replicate it row-wise.
    num_rows = len(kdos)
    return kdos, np.tile(freq_points, (num_rows, 1))


def get_mfp(g, gv):
    """Return mean free paths computed from ``g`` and group velocities.

    For every element with ``g > 0`` the result is
    ``|gv| / (2 * 2 * pi * g)``, where ``|gv|`` is the Euclidean norm of
    ``gv`` along its third axis; all other elements are set to 0.
    """
    positive = g > 0
    # Non-positive entries are replaced by -1 so the division never hits 0;
    # the corresponding results are discarded by the mask below.
    safe_g = np.where(positive, g, -1)
    speed = np.sqrt((gv ** 2).sum(axis=2))
    return np.where(positive, speed / (4 * np.pi * safe_g), 0)


def run_mfp_dos(mean_freepath,
                mode_prop,
                primitive,
                mesh,
                grid_address,
                grid_mapping_table,
                ir_grid_points,
                num_sampling_points):
    """Accumulate ``mode_prop`` over mean free path, one entry at a time.

    For every entry ``i`` of ``mean_freepath`` a separate ``KappaDOS`` is
    built with that entry in place of the frequencies and with
    ``mode_prop[i:i+1]`` as the property, because each entry has its own
    range of mean free paths and therefore its own sampling points.

    Returns
    -------
    kdos : ndarray
        Stacked first (only) entries of the per-entry ``KappaDOS`` results.
    sampling_points : ndarray
        Stacked sampling points, one row per entry of ``mean_freepath``.
    """
    per_entry = [
        KappaDOS(mode_prop[index:index + 1, :, :],
                 primitive,
                 mfp_values,
                 mesh,
                 grid_address,
                 grid_mapping_table,
                 ir_grid_points,
                 num_sampling_points=num_sampling_points).get_kdos()
        for index, mfp_values in enumerate(mean_freepath)
    ]

    kdos = np.array([values[0] for _, values in per_entry])
    sampling_points = np.array([points for points, _ in per_entry])
    return kdos, sampling_points


def fracval(frac):
    """Convert a string such as ``"0.5"`` or ``"1/2"`` to a float.

    Without a ``/`` the whole string is passed to ``float``. Otherwise the
    string is split on ``/`` and the first piece is divided by the second;
    any further pieces are ignored.
    """
    pieces = frac.split('/')
    if len(pieces) == 1:
        return float(frac)
    numerator, denominator = pieces[0], pieces[1]
    return float(numerator) / float(denominator)


def get_grid_symmetry(primitive, mesh, weights, qpoints):
    """Rebuild the irreducible grid and check it against stored data.

    The point-group operations of ``primitive`` are used with ``mesh`` to
    regenerate the irreducible grid points. The regenerated weights must
    equal ``weights`` and the regenerated q-points must equal ``qpoints`` up
    to integer shifts.

    Parameters
    ----------
    primitive : cell object accepted by phonopy ``Symmetry``
    mesh : ndarray
        Mesh numbers; ``mesh.astype`` is called on it, so it must be an
        ndarray.
    weights : array_like
        Stored weights of the irreducible grid points.
    qpoints : array_like
        Stored q-points of the irreducible grid points.

    Returns
    -------
    tuple
        ``(ir_grid_points, grid_address, grid_mapping_table)`` as returned
        by ``get_ir_grid_points``.

    Raises
    ------
    AssertionError
        If the weights or the q-points do not match. A hint about the
        ``--pa`` option is printed first when the weights differ.
    """
    # Imported locally so the function also works when this file is imported
    # as a module; at file level Symmetry is only bound by the script entry
    # point.
    from phonopy.structure.symmetry import Symmetry

    symmetry = Symmetry(primitive)
    rotations = symmetry.get_pointgroup_operations()
    (ir_grid_points,
     weights_for_check,
     grid_address,
     grid_mapping_table) = get_ir_grid_points(mesh, rotations)

    try:
        np.testing.assert_array_equal(weights, weights_for_check)
    except AssertionError:
        # Only a genuine mismatch should trigger the hint; a bare except
        # would also fire on KeyboardInterrupt or SystemExit.
        print("*******************************")
        print("** Might forget --pa option? **")
        print("*******************************")
        raise

    qpoints_for_check = grid_address[ir_grid_points] / mesh.astype('double')
    diff_q = qpoints - qpoints_for_check
    # q-points may differ by integers only.
    np.testing.assert_almost_equal(diff_q, np.rint(diff_q))

    return ir_grid_points, grid_address, grid_mapping_table


def get_gv_by_gv(gv,
                 symmetry,
                 primitive,
                 mesh,
                 grid_points,
                 grid_address):
    """Sum outer products of rotated group velocities at each grid point.

    For every grid point the group velocities of all bands are rotated by
    each reciprocal operation (converted by ``similarity_transformation``
    with the inverse of the primitive cell matrix), the outer products
    ``v x v`` are summed, and the sum is divided by the integer ratio of the
    number of rotations to the number of distinct grid points they map to.

    Returns
    -------
    ndarray
        Shape ``(gv.shape[0], gv.shape[1], 6)``; the last axis holds the
        tensor elements (0,0), (1,1), (2,2), (1,2), (0,2), (0,1).
    """
    point_operations = symmetry.get_reciprocal_operations()
    rec_lat = np.linalg.inv(primitive.get_cell())
    rotations_cartesian = np.array(
        [similarity_transformation(rec_lat, rot) for rot in point_operations],
        dtype='double')

    # Row/column indices of the six stored tensor elements.
    elem_rows = [0, 1, 2, 1, 0, 0]
    elem_cols = [0, 1, 2, 2, 2, 1]

    num_gp, num_band = gv.shape[0], gv.shape[1]
    gv_sum2 = np.zeros((num_gp, num_band, 6), dtype='double')
    for index, gp in enumerate(grid_points):
        rotation_map = get_grid_points_by_rotations(
            grid_address[gp], point_operations, mesh)
        multiplicity = len(rotation_map) // len(np.unique(rotation_map))

        outer_sum = np.zeros((num_band, 3, 3), dtype='double')
        for rot_cart in rotations_cartesian:
            rotated = np.dot(gv[index], rot_cart.T)
            # Per-band outer product through broadcasting.
            outer_sum += rotated[:, :, np.newaxis] * rotated[:, np.newaxis, :]
        outer_sum /= multiplicity

        gv_sum2[index] = outer_sum[:, elem_rows, elem_cols]

    return gv_sum2


class KappaDOS(object):
    """Accumulate a mode-resolved tensor property with tetrahedron weights.

    The sampling axis is spanned by ``num_sampling_points`` equally spaced
    points from the minimum of ``frequencies`` to its maximum plus the
    module-level ``epsilon``. The result is computed on construction and
    retrieved with ``get_kdos``.

    Parameters
    ----------
    mode_kappa : ndarray
        Indexed as ``mode_kappa[temp, ir_gp, band, tensor_elem]``.
    cell, mesh, grid_address, grid_mapping_table, ir_grid_points :
        Passed unchanged to ``TetrahedronMesh``.
    frequencies : ndarray
        Values defining the sampling axis; also passed to
        ``TetrahedronMesh``.
    grid_order : optional
        Accepted but not used by this class.
    num_sampling_points : int, optional
        Number of points on the sampling axis.
    """

    def __init__(self,
                 mode_kappa,
                 cell,
                 frequencies,
                 mesh,
                 grid_address,
                 grid_mapping_table,
                 ir_grid_points,
                 grid_order=None,
                 num_sampling_points=100):
        self._mode_kappa = mode_kappa
        self._tetrahedron_mesh = TetrahedronMesh(
            cell,
            frequencies,
            mesh,
            grid_address,
            grid_mapping_table,
            ir_grid_points)

        # NumPy reductions avoid a Python-level loop over every element and
        # match GammaDOS._set_frequency_points.
        min_freq = np.min(frequencies)
        max_freq = np.max(frequencies) + epsilon
        self._frequency_points = np.linspace(min_freq,
                                             max_freq,
                                             num_sampling_points)
        self._kdos = np.zeros(
            (len(mode_kappa), len(self._frequency_points), 2, 6),
            dtype='double')
        self._run_tetrahedron_method()

    def get_kdos(self):
        """Return ``(frequency_points, kdos)``.

        ``kdos`` is indexed as ``kdos[temp, freq_points, IJ, tensor_elem]``
        where index 0 of the ``IJ`` axis holds the 'J' result and index 1
        the 'I' result.
        """
        return self._frequency_points, self._kdos

    def _run_tetrahedron_method(self):
        """Fill ``self._kdos`` from the 'J' and 'I' tetrahedron weights."""
        thm = self._tetrahedron_mesh
        for j, value in enumerate(('J', 'I')):
            thm.set(value=value, frequency_points=self._frequency_points)
            for i, iw in enumerate(thm):
                # kdos[temp, freq_points, IJ, tensor_elem]
                # iw[freq_points, band]
                # mode_kappa[temp, ir_gp, band, tensor_elem]
                # np.dot contracts the band axis and yields
                # [freq_points, temp, tensor_elem]; the transpose moves the
                # temperature axis to the front.
                self._kdos[:, :, j] += np.transpose(
                    np.dot(iw, self._mode_kappa[:, i]), axes=(1, 0, 2))


class GammaDOS(object):
    """Common state for scalar-property accumulation classes.

    The constructor stores its arguments, builds the frequency sampling
    points and allocates a zero-filled result array of shape
    ``(len(gamma), number of frequency points, 2)``. Filling the array is
    left to subclasses.
    """

    def __init__(self,
                 gamma,
                 frequencies,
                 ir_grid_weights,
                 num_fpoints=200):
        self._gamma = gamma
        self._frequencies = frequencies
        self._ir_grid_weights = ir_grid_weights
        self._num_fpoints = num_fpoints

        self._set_frequency_points()
        gdos_shape = (len(gamma), len(self._frequency_points), 2)
        self._gdos = np.zeros(gdos_shape, dtype='double')

    def get_gdos(self):
        """Return ``(frequency_points, gdos)``."""
        return self._frequency_points, self._gdos

    def _set_frequency_points(self):
        """Create equally spaced points covering the stored frequencies.

        The upper end is the maximum frequency plus the module-level
        ``epsilon``.
        """
        lowest = np.min(self._frequencies)
        highest = np.max(self._frequencies) + epsilon
        self._frequency_points = np.linspace(
            lowest, highest, self._num_fpoints)


class GammaDOSsmearing(GammaDOS):
    """Scalar-property density computed with a smearing function.

    Only index 1 of the last axis of the result array is filled; index 0
    keeps the zeros allocated by ``GammaDOS``.

    Parameters
    ----------
    gamma, frequencies, ir_grid_weights, num_fpoints :
        Passed to ``GammaDOS``.
    sigma : float, optional
        Width handed to ``NormalDistribution``. When None, 1/100 of the span
        of the frequency points is used.
    """

    def __init__(self,
                 gamma,
                 frequencies,
                 ir_grid_weights,
                 sigma=None,
                 num_fpoints=200):
        GammaDOS.__init__(self,
                          gamma,
                          frequencies,
                          ir_grid_weights,
                          num_fpoints=num_fpoints)
        if sigma is None:
            self._sigma = (np.max(self._frequency_points) -
                           np.min(self._frequency_points)) / 100
        else:
            # Honour the caller's value instead of a hard-coded 0.1.
            self._sigma = sigma
        self._smearing_function = NormalDistribution(self._sigma)
        self._run_smearing_method()

    def _run_smearing_method(self):
        """Fill ``self._gdos[:, :, 1]`` with the weighted, smeared sums."""
        # Kept from the original implementation; not read in this file.
        self._dos = []
        num_gp = np.sum(self._ir_grid_weights)
        for i, f in enumerate(self._frequency_points):
            dos = self._smearing_function.calc(self._frequencies - f)
            for j, g_t in enumerate(self._gamma):
                # Weight each grid point, sum over grid points and bands,
                # and normalise by the total weight.
                self._gdos[j, i, 1] = np.sum(np.dot(self._ir_grid_weights,
                                                    dos * g_t)) / num_gp


class GammaDOStetrahedron(GammaDOS):
    """Scalar-property accumulation using tetrahedron-method weights.

    On construction a ``TetrahedronMesh`` is built from the cell, the
    frequencies and the grid information, and the result array allocated by
    ``GammaDOS`` is filled for both the 'J' (index 0) and 'I' (index 1)
    weights.
    """

    def __init__(self,
                 gamma,
                 cell,
                 frequencies,
                 mesh,
                 grid_address,
                 grid_mapping_table,
                 ir_grid_points,
                 ir_grid_weights,
                 num_fpoints=200):
        super().__init__(gamma,
                         frequencies,
                         ir_grid_weights,
                         num_fpoints=num_fpoints)
        self._cell = cell
        self._mesh = mesh
        self._grid_address = grid_address
        self._grid_mapping_table = grid_mapping_table
        self._ir_grid_points = ir_grid_points

        self._set_tetrahedron_method()
        self._run_tetrahedron_method()

    def _set_tetrahedron_method(self):
        """Create the ``TetrahedronMesh`` used to obtain the weights."""
        self._tetrahedron_mesh = TetrahedronMesh(self._cell,
                                                 self._frequencies,
                                                 self._mesh,
                                                 self._grid_address,
                                                 self._grid_mapping_table,
                                                 self._ir_grid_points)

    def _run_tetrahedron_method(self):
        """Add the weighted contribution of every grid point to the result.

        Index conventions:
        gdos[temp, freq_points, IJ], iw[freq_points, band],
        gamma[temp, ir_gp, band].
        """
        mesh_weights = self._tetrahedron_mesh
        for slot, kind in enumerate(('J', 'I')):
            mesh_weights.set(value=kind,
                             frequency_points=self._frequency_points)
            for gp_index, iw in enumerate(mesh_weights):
                weighted = (self._gamma[:, gp_index] *
                            self._ir_grid_weights[gp_index])
                self._gdos[:, :, slot] += np.dot(weighted, iw.T)


if __name__ == '__main__':
    # Incremental kappa with respect to frequency and the derivative.
    #
    # Script flow: parse the command line, read the crystal structure and
    # the hdf5 file, rebuild the grid information, select the property to
    # accumulate and print the result. Everything stays at module level
    # because functions in this file rely on names bound here.

    import h5py
    from phonopy.structure.cells import get_primitive
    from phonopy.structure.symmetry import Symmetry
    import argparse

    parser = argparse.ArgumentParser(
        description=("Incremental kappa with respect to frequency "
                     "and the derivative"))
    parser.add_argument(
        "--pa", dest="primitive_matrix", default="1 0 0 0 1 0 0 0 1",
        help="Primitive matrix")
    parser.add_argument(
        "--mesh", dest="mesh", default="1 1 1",
        help="Mesh numbers")
    parser.add_argument(
        "-c", "--cell", dest="cell_filename",
        help="Unit cell filename")
    parser.add_argument(
        '--gv', action='store_true',
        help='Calculate for gv_x_gv (tensor)')
    parser.add_argument(
        '--pqj', action='store_true',
        help='Calculate for Pqj (scalar)')
    parser.add_argument(
        '--cv', action='store_true',
        help='Calculate for Cv (scalar)')
    parser.add_argument(
        '--tau', action='store_true',
        help='Calculate for lifetimes (scalar)')
    parser.add_argument(
        '--gamma', action='store_true',
        help='Calculate for Gamma (scalar)')
    parser.add_argument(
        '--gruneisen', action='store_true',
        help='Calculate for mode-Gruneisen parameters squared (scalar)')
    parser.add_argument(
        '--gv-norm', action='store_true',
        help='Calculate for |g_v| (scalar)')
    parser.add_argument(
        '--mfp', action='store_true',
        help='Mean free path is used instead of frequency')
    parser.add_argument(
        '--temperature', type=float, dest='temperature',
        help='Temperature to output data at')
    parser.add_argument(
        '--nsp', '--num-sampling-points', type=int, dest='num_sampling_points',
        default=100,
        help="Number of sampling points in frequency or MFP axis")
    parser.add_argument(
        '--average', action='store_true',
        help=("Output the traces of the tensors divided by 3 "
              "rather than the unique elements"))
    parser.add_argument(
        '--trace', action='store_true',
        help=("Output the traces of the tensors "
              "rather than the unique elements"))
    parser.add_argument(
        '--smearing', action='store_true',
        help='Use smearing method (only for scalar density)')
    parser.add_argument(
        '--qe', '--pwscf', dest="qe_mode",
        action="store_true", help="Invoke Pwscf mode")
    parser.add_argument(
        '--crystal', dest="crystal_mode",
        action="store_true", help="Invoke CRYSTAL mode")
    parser.add_argument(
        '--abinit', dest="abinit_mode",
        action="store_true", help="Invoke Abinit mode")
    parser.add_argument(
        '--turbomole', dest="turbomole_mode",
        action="store_true", help="Invoke TURBOMOLE mode")
    parser.add_argument(
        "--noks", "--no-kappa-stars",
        dest="no_kappa_stars", action="store_true",
        help="Deactivate summation of partial kappa at q-stars")
    parser.add_argument('filenames', nargs='*')
    args = parser.parse_args()

    # Without this check the script dies with an IndexError below.
    if not args.filenames:
        parser.error("At least one filename (the hdf5 file) is required.")

    interface_mode = None
    if args.qe_mode:
        interface_mode = 'qe'
    elif args.crystal_mode:
        interface_mode = 'crystal'
    elif args.abinit_mode:
        interface_mode = 'abinit'
    elif args.turbomole_mode:
        interface_mode = 'turbomole'
    # Two positional names: <cell file> <hdf5 file>. One positional name:
    # <hdf5 file>, with the cell file taken from -c/--cell.
    if len(args.filenames) > 1:
        cell, _ = read_crystal_structure(args.filenames[0],
                                         interface_mode=interface_mode)
        f = h5py.File(args.filenames[1], 'r')
    else:
        cell, _ = read_crystal_structure(args.cell_filename,
                                         interface_mode=interface_mode)
        f = h5py.File(args.filenames[0], 'r')

    primitive_matrix = np.reshape(
        [fracval(x) for x in args.primitive_matrix.split()], (3, 3))
    primitive = get_primitive(cell, primitive_matrix)

    # The mesh stored in the file takes precedence over --mesh.
    if 'mesh' in f:
        mesh = np.array(f['mesh'][:], dtype='intc')
    else:
        mesh = np.array([int(x) for x in args.mesh.split()], dtype='intc')

    if 'temperature' in f:
        temperatures = f['temperature'][:]
    else:
        temperatures = None
    weights = f['weight'][:]

    if args.no_kappa_stars or (weights == 1).all():
        # Every grid point is treated as its own irreducible point.
        grid_address = get_grid_address(mesh)
        ir_grid_points = np.arange(np.prod(mesh), dtype='uintp')
        grid_mapping_table = np.arange(np.prod(mesh), dtype='uintp')
    else:
        (ir_grid_points,
         grid_address,
         grid_mapping_table) = get_grid_symmetry(primitive,
                                                 mesh,
                                                 weights,
                                                 f['qpoint'][:])

    ################
    # Set property #
    ################
    if args.gv:
        if 'gv_by_gv' in f:
            gv_sum2 = f['gv_by_gv'][:]
        else:  # For backward compatibility. This will be removed someday.
            gv = f['group_velocity'][:]
            symmetry = Symmetry(primitive)
            gv_sum2 = get_gv_by_gv(gv,
                                   symmetry,
                                   primitive,
                                   mesh,
                                   ir_grid_points,
                                   grid_address)

        # gv x gv is divided by primitive cell volume.
        unit_conversion = primitive.get_volume()
        mode_prop = gv_sum2.reshape((1,) + gv_sum2.shape) / unit_conversion
    else:
        if 'mode_kappa' in f:
            mode_prop = f['mode_kappa'][:]
        else:
            mode_prop = None

    frequencies = f['frequency'][:]
    conditions = frequencies > 0
    # Up to three non-positive frequencies are tolerated without a warning.
    if np.logical_not(conditions).sum() > 3:
        sys.stderr.write("# Imaginary frequencies are found. "
                         "They are set to be zero.\n")
        frequencies = np.where(conditions, frequencies, 0)

    #######
    # Run #
    #######
    if (args.gamma or args.gruneisen or args.pqj or
        args.cv or args.tau or args.gv_norm):
        if args.pqj:
            mode_prop = f['ave_pp'][:].reshape((1,) + f['ave_pp'].shape)
        elif args.cv:
            mode_prop = f['heat_capacity'][:]
        elif args.tau:
            g = f['gamma'][:]
            g = np.where(g > 0, g, -1)
            mode_prop = np.where(g > 0, 1.0 / (2 * 2 * np.pi * g), 0)  # tau
        elif args.gv_norm:
            mode_prop = np.sqrt(
                (f['group_velocity'][:, :, :] ** 2).sum(axis=2))
            mode_prop = mode_prop.reshape((1,) + mode_prop.shape)
        elif args.gamma:
            mode_prop = f['gamma'][:]
        elif args.gruneisen:
            mode_prop = f['gruneisen'][:].reshape((1,) + f['gruneisen'].shape)
            mode_prop **= 2
        else:
            # Unreachable given the enclosing condition; a bare ``raise``
            # here has no active exception to re-raise.
            raise RuntimeError("No scalar property is selected.")
        if (args.temperature is not None and
            not (args.gv_norm or args.pqj or args.gruneisen)):
            temperatures, mode_prop = set_T_target(temperatures,
                                                   mode_prop,
                                                   args.temperature)
        if args.smearing:
            mode_prop_dos = GammaDOSsmearing(
                mode_prop,
                frequencies,
                weights,
                num_fpoints=args.num_sampling_points)
            sampling_points, gdos = mode_prop_dos.get_gdos()
        else:
            mode_prop_dos = GammaDOStetrahedron(
                mode_prop,
                primitive,
                frequencies,
                mesh,
                grid_address,
                grid_mapping_table,
                ir_grid_points,
                weights,
                num_fpoints=args.num_sampling_points)
            sampling_points, gdos = mode_prop_dos.get_gdos()
            # Sanity check: the accumulated value at the last sampling point
            # must equal the weighted sum over all modes.
            for i, gdos_t in enumerate(gdos):
                total = np.dot(weights, mode_prop[i]).sum() / weights.sum()
                assert np.isclose(gdos_t[-1][0], total)
        show_scalar(gdos, temperatures, sampling_points, args)
    else:
        if mode_prop is None:
            # Fail with a clear message instead of a TypeError deep inside
            # the accumulation code.
            raise RuntimeError(
                "mode_kappa was not found in the hdf5 file and no other "
                "property was selected.")
        if args.mfp:
            if 'mean_free_path' in f:
                mfp = f['mean_free_path'][:]
                mean_freepath = np.sqrt((mfp ** 2).sum(axis=3))
            else:
                mean_freepath = get_mfp(f['gamma'][:], f['group_velocity'][:])
            if args.temperature is not None:
                (temperatures,
                 mode_prop,
                 mean_freepath) = set_T_target(temperatures,
                                               mode_prop,
                                               args.temperature,
                                               mean_freepath=mean_freepath)

            kdos, sampling_points = run_mfp_dos(mean_freepath,
                                                mode_prop,
                                                primitive,
                                                mesh,
                                                grid_address,
                                                grid_mapping_table,
                                                ir_grid_points,
                                                args.num_sampling_points)
            show_tensor(kdos, temperatures, sampling_points, args)
        else:
            if args.temperature is not None and not args.gv:
                temperatures, mode_prop = set_T_target(temperatures,
                                                       mode_prop,
                                                       args.temperature)
            kdos, sampling_points = run_prop_dos(frequencies,
                                                 mode_prop,
                                                 primitive,
                                                 mesh,
                                                 grid_address,
                                                 grid_mapping_table,
                                                 ir_grid_points,
                                                 args.num_sampling_points)
            show_tensor(kdos, temperatures, sampling_points, args)

    # All datasets have been read into memory; release the file handle.
    f.close()
