import os
import pickle
import shutil
import sys

import gradio as gr
import torch
from easydict import EasyDict as edict
from huggingface_hub import hf_hub_download

sys.path.append("./rome/")
sys.path.append('./DECA')

from rome.infer import Infer
from rome.src.utils.processing import process_black_shape, tensor2image

# Fetch the MODNet matting checkpoint and the ROME checkpoint from the
# Hugging Face Hub repo 'Pie31415/rome'; hf_hub_download returns the local
# path of each downloaded file.
default_modnet_path, default_model_path = [
    hf_hub_download('Pie31415/rome', checkpoint_name)
    for checkpoint_name in ('modnet_photographic_portrait_matting.ckpt', 'rome.pth')
]

# Inference options passed to Infer. These presumably mirror the
# command-line options of ROME's inference script (the original comment
# called them "parser configurations") -- confirm against rome/infer.py.
args = edict(dict(
    # output / checkpoints / run control
    save_dir=".",
    save_render=True,
    model_checkpoint=default_model_path,
    modnet_path=default_modnet_path,
    random_seed=0,
    debug=False,
    verbose=False,
    # input size and alignment
    model_image_size=256,
    align_source=True,
    align_target=False,
    align_scale=1.25,
    # mesh and renderer_* options
    use_mesh_deformations=False,
    subdivide_mesh=False,
    renderer_sigma=1e-08,
    renderer_zfar=100.0,
    renderer_type="soft_mesh",
    renderer_texture_type="texture_uv",
    renderer_normalized_alphas=False,
    # asset locations
    deca_path="DECA",
    rome_data_dir="rome/data",
    # autoenc_* options
    autoenc_cat_alphas=False,
    autoenc_align_inputs=False,
    autoenc_use_warp=False,
    autoenc_num_channels=64,
    autoenc_max_channels=512,
    autoenc_num_groups=4,
    autoenc_num_bottleneck_groups=0,
    autoenc_num_blocks=2,
    autoenc_num_layers=4,
    autoenc_block_type="bottleneck",
    neural_texture_channels=8,
    num_harmonic_encoding_funcs=6,
    # unet_* options
    unet_num_channels=64,
    unet_max_channels=512,
    unet_num_groups=4,
    unet_num_blocks=1,
    unet_num_layers=2,
    unet_block_type="conv",
    unet_skip_connection_type="cat",
    unet_use_normals_cond=True,
    unet_use_vertex_cond=False,
    unet_use_uvs_cond=False,
    unet_pred_mask=False,
    use_separate_seg_unet=True,
    # layer types (general and deform_*)
    norm_layer_type="gn",
    activation_type="relu",
    conv_layer_type="ws_conv",
    deform_norm_layer_type="gn",
    deform_activation_type="relu",
    deform_conv_layer_type="ws_conv",
    # segmentation and deformation switches
    unet_seg_weight=0.0,
    unet_seg_type="bce_with_logits",
    deform_face_tightness=0.0001,
    use_whole_segmentation=False,
    mask_hair_for_neck=False,
    use_hair_from_avatar=False,
    use_scalp_deforms=True,
    use_neck_deforms=True,
    # deformer / basis options
    use_basis_deformer=False,
    use_unet_deformer=True,
    pretrained_encoder_basis_path="",
    pretrained_vertex_basis_path="",
    num_basis=50,
    basis_init="pca",
    num_vertex=5023,
    train_basis=True,
    path_to_deca="DECA",
    path_to_linear_hair_model="data/linear_hair.pth",  # N/A
    path_to_mobile_model="data/disp_model.pth",  # N/A
    n_scalp=60,
    use_distill=False,
    use_mobile_version=False,
    deformer_path="data/rome.pth",
    output_unet_deformer_feats=32,
    use_deca_details=False,
    use_flametex=False,
    upsample_type="nearest",
    num_frequencies=6,
    deform_face_scale_coef=0.0,
    device="cpu",
))

# Download the FLAME generic model and the pretrained DECA checkpoint from the
# Hub, then place them under ./DECA/data/ (hard-coded destination, presumably
# where DECA's config looks for them -- confirm against the DECA config).
generic_model_path = hf_hub_download('Pie31415/rome', 'generic_model.pkl')
deca_model_path = hf_hub_download('Pie31415/rome', 'deca_model.tar')

# The destination directory may not exist in a fresh checkout; without this the
# open(..., 'wb') / copy calls below fail with FileNotFoundError.
os.makedirs('./DECA/data', exist_ok=True)

# The .pkl is re-serialised rather than copied: encoding='latin1' is what the
# pickle docs prescribe for reading Python 2 pickles containing NumPy arrays,
# and dumping it again writes a native Python 3 pickle (presumably the file is
# a Python 2 pickle -- confirm).
# SECURITY: pickle.load can execute arbitrary code embedded in the file. This is
# only acceptable because the file comes from a fixed, trusted Hub repo.
with open(generic_model_path, 'rb') as f:
    ss = pickle.load(f, encoding='latin1')
with open('./DECA/data/generic_model.pkl', 'wb') as out:
    pickle.dump(ss, out)

# Plain byte-for-byte copy of the checkpoint. Iterating a binary file "by line"
# splits on arbitrary b'\n' bytes (possibly buffering huge chunks) and also
# rebound the builtin name `input`; copyfile streams in fixed-size blocks.
shutil.copyfile(deca_model_path, './DECA/data/deca_model.tar')

# Build the ROME inference model. Kept after the DECA files are written because
# args points DECA at "DECA" (deca_path / path_to_deca); presumably Infer loads
# those assets during construction -- confirm in rome/infer.py.
infer = Infer(args)

def image_inference(
    source_img=None,
    driver_img=None
):
    """Animate ``source_img`` using the driver image and return a comparison strip.

    Both arguments come from ``gr.Image(type="pil")`` components (or from the
    cached examples), so they are PIL images, or ``None`` when the user has
    not provided one. The previous ``gr.inputs.Image`` annotations were both
    wrong for that type and evaluated at definition time, which breaks
    import on Gradio versions that no longer ship ``gr.inputs``.

    Returns:
        The result of ``tensor2image`` on the source image, driver (target)
        image, masked render and predicted target shape concatenated along
        dim 2, with the order of its last axis reversed.

    Raises:
        gr.Error: if either image is missing, so the UI shows a readable
            message instead of a failure from deep inside the model.
    """
    if source_img is None or driver_img is None:
        raise gr.Error("Please provide both a source image and a driver image.")

    out = infer.evaluate(source_img, driver_img, crop_center=False)
    data_dict = out['source_information']['data_dict']
    panels = [
        data_dict['source_img'][0].cpu(),
        data_dict['target_img'][0].cpu(),
        out['render_masked'].cpu(),
        out['pred_target_shape_img'][0].cpu(),
    ]
    # dim=2 is presumably the width axis of (C, H, W) tensors, giving a
    # horizontal strip -- TODO confirm against tensor2image.
    res = tensor2image(torch.cat(panels, dim=2))
    # tensor2image presumably returns BGR (OpenCV order); reversing the last
    # axis yields the RGB order gr.Image displays -- confirm.
    return res[..., ::-1]

def video_inference():
    """Placeholder for video-driven inference; not implemented yet.

    Takes no arguments, does no work and always returns ``None``.
    """
    return None

with gr.Blocks() as demo:
    gr.Markdown("# **<p align='center'>ROME: Realistic one-shot mesh-based head avatars</p>**")

    # Image tab: a source portrait and a driver image side by side, then the
    # output image and the button that runs image_inference.
    with gr.Tab("Image Inference"):
        with gr.Row():
            source_img = gr.Image(type="pil", label="source image", show_label=True)
            driver_img = gr.Image(type="pil", label="driver image", show_label=True)
        image_output = gr.Image()
        image_button = gr.Button("Predict")

    # Video tab: layout only for now; its button is registered with fn=None
    # below, so clicking it runs no Python function.
    with gr.Tab("Video Inference"):
        with gr.Row():
            source_video = gr.Video(label="source video")
            driver_image_for_vid = gr.Image(type="pil", label="driver image", show_label=True)
        video_output = gr.Image()
        video_button = gr.Button("Predict")

    # (source, driver) example pairs for the image tab; cache_examples=True
    # makes Gradio run image_inference on them ahead of time.
    example_pairs = [
        ["./examples/lincoln.jpg", "./examples/taras2.jpg"],
        ["./examples/lincoln.jpg", "./examples/taras1.jpg"],
    ]
    gr.Examples(
        examples=example_pairs,
        inputs=[source_img, driver_img],
        outputs=[image_output],
        fn=image_inference,
        cache_examples=True,
    )

    image_button.click(
        image_inference,
        inputs=[source_img, driver_img],
        outputs=image_output,
    )
    video_button.click(
        None,
        inputs=[source_video, driver_image_for_vid],
        outputs=video_output,
    )

demo.launch()