import cProfile
import pstats
import time
from pathlib import Path
from typing import Optional, Tuple

import h5py
import numpy as np


def append_suffix_to_file(file_path: Path, suffix: str = '_INF', ext: Optional[str] = None) -> Path:
    """
    Adds a suffix to the file name of the given path, optionally swapping the extension first.
    :param file_path: `Path` object to the original file.
    :param suffix: `str` suffix appended to the file stem (before the extension).
    :param ext: optional new file extension, including the leading dot (e.g. '.h5');
        `Path.with_suffix` raises `ValueError` if the dot is missing.
    :return: Updated `Path`.
    """
    if ext:
        file_path = file_path.with_suffix(ext)
    # Rebuild the name as stem + suffix + (possibly replaced) extension.
    return file_path.with_name(file_path.stem + suffix + file_path.suffix)


def array4d_to_h5(array_4ds: Tuple, output_file: Path, group: Optional[str] = None,
                  datasets: Tuple = ('array_data',)):
    """
    Writes a tuple of arrays as gzip-compressed datasets into an HDF5 file (append mode).
    :param array_4ds: tuple of array-like objects to store.
    :param output_file: `Path` to the HDF5 file; opened with mode 'a' so existing content is kept.
    :param group: optional name of a new group to create and write the datasets into;
        when None, datasets are created at the file root.
    :param datasets: tuple of dataset names, one per array. Note: this default was previously
        the bare string 'array_data', whose len() of 10 broke the length check below.
    :raises ValueError: if the number of arrays and dataset names differ.
    """
    if len(array_4ds) != len(datasets):
        raise ValueError(f'Amount of arrays {len(array_4ds)} must match amount of dataset names {len(datasets)}.')
    with h5py.File(output_file, 'a') as h5f:
        # Single write target removes the duplicated loop of the two branches.
        target = h5f.create_group(group) if group is not None else h5f
        for name, data in zip(datasets, array_4ds):
            target.create_dataset(name=name, data=data, compression='gzip', compression_opts=9)


def h5_to_array4d(input_file: Path) -> np.ndarray:
    """
    Loads every top-level dataset of an HDF5 file and stacks them along the first axis.
    :param input_file: `Path` to the HDF5 file to read.
    :return: `np.ndarray` produced by `np.vstack` over all top-level datasets
        (h5py iterates keys in alphabetical order by default).
    """
    with h5py.File(input_file, 'r') as h5f:
        return np.vstack([np.array(h5f[key]) for key in h5f.keys()])


def combined_test_h5_to_array4d(input_file: Path, pc_size: int = 1024, merged: bool = True) -> np.array:
    """
    Reads every group of an HDF5 file, merges each group's 'labeled' and 'unlabeled'
    datasets into a fixed-width point cloud, and stacks the results along the first axis.
    :param input_file: `Path` to the HDF5 file; each top-level group must contain
        'labeled' and 'unlabeled' datasets.
    :param pc_size: target point-cloud width passed to the merge.
    :param merged: NOTE(review): currently unused — confirm whether it should gate the merge.
    :return: stacked merged array.
    """
    stacked = []
    with h5py.File(input_file, 'r') as h5f:
        for grp_name in list(h5f.keys()):
            grp = h5f[grp_name]
            merged_frames = merge_labeled_and_unlabeled_data(np.array(grp['labeled']),
                                                            np.array(grp['unlabeled']),
                                                            pc_size=pc_size)
            stacked.append(merged_frames)
    return np.vstack(stacked)


def merge_labeled_and_unlabeled_data(labeled: np.array, unlabeled: np.array, pc_size: int,
                                     augment: str = None) -> np.array:
    """
    Concatenates unlabeled and labeled markers along the last axis into a point cloud of
    exactly `pc_size` entries, padding with filler markers when there are too few.
    :param labeled: array of shape (n_frames, n_channels, n_labeled).
    :param unlabeled: array of shape (n_frames, n_channels, n_unlabeled).
    :param pc_size: target size of the last axis.
    :param augment: None -> filler of ones; 'normal' -> random filler
        (NOTE(review): `np.random.rand` draws *uniform* [0, 1), not normal — confirm intent);
        any other value -> filler of zeros.
    :return: array of shape (n_frames, n_channels, pc_size).
    """
    n_frames, n_channels = labeled.shape[0], labeled.shape[1]
    missing = pc_size - (labeled.shape[2] + unlabeled.shape[2])

    if missing <= 0:
        # Already enough markers: keep only the trailing pc_size entries.
        return np.concatenate((unlabeled, labeled), axis=2)[:, :, -pc_size:]

    # Filler values mimic the way TrainDataset.fill_point_cloud() fills values.
    if augment == 'normal':
        filler = np.random.rand(n_frames, n_channels, missing)
    elif augment is None:
        filler = np.ones((n_frames, n_channels, missing))
    else:
        filler = np.zeros((n_frames, n_channels, missing))

    # The first two channel rows of the filler are always forced to zero.
    filler[:, 0] = 0.
    filler[:, 1] = 0.

    return np.concatenate((filler, unlabeled, labeled), axis=2)


class Timer:
    """
    Context manager that prints the elapsed wall time of the managed block and can
    optionally collect and print a cProfile report for it.
    """

    def __init__(self, txt: str = 'Execution time', profiler: bool = False):
        """
        :param txt: label printed before the elapsed time. The old default
            'Execution time: ' combined with the ': ' in the print format produced a
            doubled separator ('Execution time: : 0.1234 seconds'); the default no
            longer carries its own separator.
        :param profiler: when True, profile the block with cProfile and print the
            stats sorted by internal time on exit.
        """
        self.txt = txt
        self.profiler = profiler

    def __enter__(self):
        # perf_counter() is monotonic and high-resolution; time.time() can jump
        # with wall-clock adjustments and is unsuitable for interval timing.
        self.start_time = time.perf_counter()
        if self.profiler:
            self.p = cProfile.Profile()
            self.p.enable()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end_time = time.perf_counter()
        dif = self.end_time - self.start_time
        print(f"{self.txt}: {dif:.4f} seconds")

        if self.profiler:
            self.p.disable()
            stats = pstats.Stats(self.p).sort_stats('time')
            stats.print_stats()