import numpy as np
import torch
import imageio

from my.utils.tqdm import tqdm
from my.utils.event import EventStorage, read_stats, get_event_storage
from my.utils.heartbeat import HeartBeat, get_heartbeat
from my.utils.debug import EarlyLoopBreak

from .utils import PSNR, Scrambler, every, at
from .data import load_blender
from .render import (
    as_torch_tsrs, scene_box_filter, render_ray_bundle, render_one_view, rays_from_img
)
from .vis import vis, stitch_vis


# Device that train(), test() and __evaluate_ckpt() move the model to.
# Hard-coded to CUDA, so this module expects a GPU to be available.
device_glb = torch.device("cuda")


def all_train_rays(scene):
    """Build one ray per pixel for every training image of ``scene``.

    Returns a tuple ``(ro, rd, rgbs)``: the per-image outputs of
    ``rays_from_img`` (by name, ray origins and directions) and the pixel
    colors reshaped to ``(-1, 3)``, each concatenated along axis 0 across all
    training images.
    """
    imgs, K, poses = load_blender("train", scene)
    origins, directions, colors = [], [], []

    for idx in tqdm(range(len(imgs))):
        frame = imgs[idx]
        cam_pose = poses[idx]
        height, width = frame.shape[:2]
        frame_ro, frame_rd = rays_from_img(height, width, K, cam_pose)
        origins.append(frame_ro)
        directions.append(frame_rd)
        colors.append(frame.reshape(-1, 3))

    return tuple(
        np.concatenate(chunks, axis=0)
        for chunks in (origins, directions, colors)
    )


class OneTestView:
    """Round-robin renderer over the test split of a Blender scene.

    Each ``render`` call renders a single test view and then advances to the
    next one, wrapping back to the first after the last.
    """

    def __init__(self, scene):
        self.imgs, self.K, self.poses = load_blender("test", scene)
        self.i = 0  # index of the view the next render() call will use

    def render(self, model):
        """Render the current test view with ``model`` and advance the cursor.

        Returns:
            ``(img, rgbs, depth, psnr)``: the test image, the rendered colors
            and depth from ``render_one_view``, and the PSNR between them.
        """
        idx = self.i
        reference = self.imgs[idx]
        cam_pose = self.poses[idx]
        height, width = reference.shape[:2]

        with torch.no_grad():
            box = model.aabb.T.cpu().numpy()
            rendered, depth_map = render_one_view(
                model, box, height, width, self.K, cam_pose
            )
            score = PSNR.psnr(reference, rendered)

        # wrap around so callers can keep requesting views indefinitely
        self.i = (idx + 1) % len(self.imgs)

        return reference, rendered, depth_map, score


def train(
    model, n_epoch=2, bs=4096, lr=0.02, scene="lego"
):
    """Optimize ``model`` on the training rays of ``scene``, then evaluate it.

    Args:
        model: the radiance-field model. It is used through ``aabb``,
            ``device``, ``d_scale``, ``grid_size``, ``make_alpha_mask``,
            ``resample``, ``parameters`` and ``state_dict``.
        n_epoch: number of passes over the box-filtered training rays.
        bs: number of rays per optimization step.
        lr: initial Adam learning rate; decayed each step by ``update_lr``.
        scene: Blender scene name passed to ``load_blender``.

    Side effects: logs scalars and artifacts (periodic visualizations, a
    checkpoint, a training video) to a new ``EventStorage`` and runs
    ``test`` on the trained model at the end.
    """
    fuse = EarlyLoopBreak(500)

    # .cpu() so this also works when the model already lives on the GPU
    # (same as test() and OneTestView.render).
    aabb = model.aabb.T.cpu().numpy()
    model = model.to(device_glb)
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    test_view = OneTestView(scene)
    all_ro, all_rd, all_rgbs = all_train_rays(scene)

    # Filter the rays against the scene box *before* sizing the progress bar:
    # pbar.total drives the lr schedule (update_lr) and the percent-based
    # triggers below, so it must count the steps actually taken over the
    # filtered rays, not over all rays.
    ro, rd, t_min, t_max, intsct_inds = scene_box_filter(all_ro, all_rd, aabb)
    rgbs = all_rgbs[intsct_inds]
    n = len(ro)
    num_batch = int(np.ceil(n / bs))

    with tqdm(total=n_epoch * num_batch) as pbar, \
            HeartBeat(pbar) as hbeat, EventStorage() as metric:

        for _ in range(n_epoch):
            # presumably a fresh random permutation of the ray order each epoch
            scrambler = Scrambler(n)
            ro, rd, t_min, t_max, rgbs = scrambler.apply(ro, rd, t_min, t_max, rgbs)

            for i in range(num_batch):
                if fuse.on_break():
                    break

                s = i * bs
                e = min(n, s + bs)

                optim.zero_grad()
                _ro, _rd, _t_min, _t_max, _rgbs = as_torch_tsrs(
                    model.device, ro[s:e], rd[s:e], t_min[s:e], t_max[s:e], rgbs[s:e]
                )
                pred, _, _ = render_ray_bundle(model, _ro, _rd, _t_min, _t_max)
                loss = ((pred - _rgbs) ** 2).mean()  # per-ray color MSE
                loss.backward()
                optim.step()

                pbar.update()

                psnr = PSNR.psnr_from_mse(loss.item())
                metric.put_scalars(psnr=psnr, d_scale=model.d_scale.item())

                if every(pbar, step=50):
                    pbar.set_description(f"TRAIN: psnr {psnr:.2f}")

                if every(pbar, percent=1):
                    gimg, rimg, depth, view_psnr = test_view.render(model)
                    pane = vis(
                        gimg, rimg, depth,
                        msg=f"psnr: {view_psnr:.2f}", return_buffer=True
                    )
                    # bind `pane` now in case the writer is invoked later
                    metric.put_artifact(
                        "vis", ".png",
                        lambda fn, pane=pane: imageio.imwrite(fn, pane)
                    )

                if at(pbar, percent=30):
                    model.make_alpha_mask()

                if every(pbar, percent=35):
                    # 1.328 ** 3 ~= 2.34: assuming grid_size holds per-axis
                    # resolutions, each resample grows the voxel count ~2.3x.
                    target_xyz = (model.grid_size * 1.328).int().tolist()
                    model.resample(target_xyz)
                    # presumably resample replaces the parameter tensors, so
                    # the optimizer is rebuilt (its Adam state is reset).
                    optim = torch.optim.Adam(model.parameters(), lr=lr)
                    print(f"resamp the voxel to {model.grid_size}")

                curr_lr = update_lr(pbar, optim, lr)
                metric.put_scalars(lr=curr_lr)

                metric.step()
                hbeat.beat()

        metric.put_artifact(
            "ckpt", ".pt", lambda fn: torch.save(model.state_dict(), fn)
        )
        # metric.step(flush=True)  # no need to flush since the test routine directly takes the model

        metric.put_artifact(
            "train_seq", ".mp4",
            lambda fn: stitch_vis(fn, read_stats(metric.output_dir, "vis")[1])
        )

        with EventStorage("test"):
            final_psnr = test(model, scene)
        metric.put("test_psnr", final_psnr)

        metric.step()

        hbeat.done()


def update_lr(pbar, optimizer, init_lr, final_ratio=0.1):
    """Apply an exponentially decaying learning rate to every param group.

    The rate is ``init_lr * final_ratio ** (step / total)``. It equals
    ``init_lr`` at step 0 and ``init_lr * final_ratio`` at ``pbar.total``, and
    keeps decaying past that point.

    Args:
        pbar: progress tracker; ``pbar.n`` is the current step and
            ``pbar.total`` the planned number of steps.
        optimizer: object exposing ``param_groups`` (a list of dicts), e.g. a
            ``torch.optim.Optimizer``.
        init_lr: learning rate at step 0.
        final_ratio: fraction of ``init_lr`` reached at step ``pbar.total``.

    Returns:
        The learning rate that was written into the param groups.
    """
    i, N = pbar.n, pbar.total
    if N:
        lr = init_lr * final_ratio ** (i / N)
    else:
        # total is 0 or unknown (None): there is no schedule to follow, so
        # keep the initial rate instead of dividing by zero.
        lr = init_lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def last_ckpt():
    """Load the last ``"ckpt"`` artifact recorded under the current directory.

    Uses the final entry returned by ``read_stats`` (presumably the most
    recent step) and loads it with ``torch.load`` onto the CPU.

    Returns:
        The loaded object, or ``None`` when no checkpoint has been recorded.
    """
    steps, ckpt_paths = read_stats("./", "ckpt")
    if len(ckpt_paths) == 0:
        return None

    state = torch.load(ckpt_paths[-1], map_location="cpu")
    print(f"loaded ckpt from iter {steps[-1]}")
    return state


def __evaluate_ckpt(model, scene):
    """Evaluate the latest saved checkpoint (if any) on the test split.

    Meant for external scripts; nothing in this module calls it. The
    resulting mean PSNR is recorded as ``test_psnr`` on the event storage
    returned by ``get_event_storage()``, which presumably must be active.
    """
    outer_metric = get_event_storage()

    ckpt_state = last_ckpt()
    if ckpt_state is not None:
        model.load_state_dict(ckpt_state)
    # return value ignored: assumes model is an nn.Module, whose .to() works in place
    model.to(device_glb)

    with EventStorage("test"):
        mean_psnr = test(model, scene)
    outer_metric.put("test_psnr", mean_psnr)


def test(model, scene):
    """Render every test view of ``scene`` and log per-view PSNR.

    For each view, logs a ``psnr`` scalar and a ``test_vis`` image to the
    current event storage. Afterwards it logs a ``test_seq`` video and the
    ``final_psnr`` scalar.

    Returns:
        Mean PSNR over the rendered views, or ``nan`` if none were rendered.
    """
    fuse = EarlyLoopBreak(5)
    metric = get_event_storage()
    hbeat = get_heartbeat()

    aabb = model.aabb.T.cpu().numpy()
    model = model.to(device_glb)

    imgs, K, poses = load_blender("test", scene)
    num_imgs = len(imgs)

    stats = []

    # `with` closes the progress bar even when the fuse breaks out early.
    with tqdm(range(num_imgs)) as pbar:
        for i in pbar:
            if fuse.on_break():
                break

            img, pose = imgs[i], poses[i]
            H, W = img.shape[:2]
            # Inference only: don't build an autograd graph
            # (same as OneTestView.render).
            with torch.no_grad():
                rgbs, depth = render_one_view(model, aabb, H, W, K, pose)
                psnr = PSNR.psnr(img, rgbs)

            stats.append(psnr)
            metric.put_scalars(psnr=psnr)
            pbar.set_description(f"TEST: mean psnr {np.mean(stats):.2f}")

            plot = vis(img, rgbs, depth, msg=f"PSNR: {psnr:.2f}", return_buffer=True)
            # bind `plot` now in case the writer is invoked later
            metric.put_artifact(
                "test_vis", ".png", lambda fn, plot=plot: imageio.imwrite(fn, plot)
            )
            metric.step()
            hbeat.beat()

    metric.put_artifact(
        "test_seq", ".mp4",
        lambda fn: stitch_vis(fn, read_stats(metric.output_dir, "test_vis")[1])
    )

    # np.mean([]) would warn and return nan; make the empty case explicit.
    final_psnr = float(np.mean(stats)) if stats else float("nan")
    metric.put("final_psnr", final_psnr)
    metric.step()

    return final_psnr