import copy
import json
import os
from typing import Optional, Union

import librosa
import numpy as np
import torch
import torch.nn.functional as F
from datasets import Audio
from safetensors.torch import load, load_model
from torch import nn
from .configuring_diva import DiVAConfig
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    AutoModelForCausalLM,
    PreTrainedModel,
    WhisperModel,
)


class WhisperConnector(nn.Module):
    """Turn Whisper encoder states into embeddings sized for the LLM.

    A fixed set of 448 learned query vectors (width 1280) is run through
    ``self.decoder`` with the encoder states as cross-attention input, and
    the decoder output is mapped by a linear projection (1280 -> 4096 as
    constructed here).

    ``decoder`` starts out as ``None`` and has to be assigned by the owner
    before ``forward`` can be called.
    """

    def __init__(self):
        super().__init__()
        # Filled in from outside; calling forward() while this is None fails.
        self.decoder = None
        self.projection = nn.Linear(1280, 4096)
        self.query_tokens = nn.Parameter(torch.randn(448, 1280))

    def forward(self, x, output_device="cuda:1"):
        """Produce the projected "virtual token" embeddings for ``x``.

        Args:
            x: Encoder hidden states; the first dimension is used as the
                batch size and the tensor is handed to the decoder as
                ``encoder_hidden_states``.
            output_device: Device the result is moved to before returning.

        Returns:
            The projected tensor on ``output_device``. In the regular case
            it keeps the batch dimension of ``x``. When the projection takes
            5120 input features the decoder output is first reshaped to the
            fixed shape ``(112, 5120)``, which drops the batch dimension and
            only fits a decoder output of exactly 448 * 1280 elements.
        """
        # Share the learned queries across the batch without copying them.
        batch_queries = self.query_tokens.unsqueeze(0).expand(x.shape[0], -1, -1)
        decoded = self.decoder(
            encoder_hidden_states=x, inputs_embeds=batch_queries
        )[0]

        takes_stacked_input = self.projection.weight.shape[-1] == 5120
        if takes_stacked_input:
            decoded = decoded.reshape(112, 5120)

        return self.projection(decoded).to(output_device)


class DiVAModel(PreTrainedModel):
    config_class = DiVAConfig

    def __init__(
        self,
        via_path=None,
        config_dict=None,
        device_map=None,
        speech_encoder_device=None,
    ):
        """Assemble the Whisper encoder, the connector and the LLM decoder.

        Args:
            via_path: Optional path to a safetensors file with the connector
                weights (``query_tokens``, ``projection.weight``,
                ``projection.bias`` and ``connector.*`` decoder entries).
                When ``None`` the connector keeps its fresh initialisation
                and an unmodified copy of the reference Whisper decoder.
            config_dict: Configuration mapping. It is turned into a
                ``DiVAConfig`` and must provide ``"reference_encoder"`` and
                ``"reference_decoder"`` model identifiers.
            device_map: Device map forwarded to the LLM loader. When
                ``None`` a fixed 32-layer map is built: ``model.embed_tokens``,
                ``model.norm`` and layers 0-15 go to device 1; layers 16-31
                and ``lm_head`` go to device 2.
            speech_encoder_device: Device for the Whisper encoder and the
                connector; defaults to ``"cuda:0"``.

        Raises:
            KeyError: If ``config_dict`` lacks ``"reference_encoder"`` or
                ``"reference_decoder"``.
        """
        # Use a fresh dict per call instead of a shared mutable default.
        if config_dict is None:
            config_dict = {}
        super().__init__(DiVAConfig.from_dict(config_dict))
        if speech_encoder_device is None:
            speech_encoder_device = "cuda:0"
        whisper = WhisperModel.from_pretrained(config_dict["reference_encoder"])
        connector = WhisperConnector()
        # The connector gets its own decoder copy so loading checkpoint
        # weights below does not touch the `whisper` object.
        connector.decoder = copy.deepcopy(whisper.decoder)
        if via_path is not None:
            with open(via_path, "rb") as f:
                sd = load(f.read())

            with torch.no_grad():
                connector.query_tokens = nn.Parameter(sd["query_tokens"])
                # The stored projection matrix is transposed before being
                # installed as the nn.Linear weight.
                connector.projection.weight = nn.Parameter(sd["projection.weight"].T)
                connector.projection.bias = nn.Parameter(sd["projection.bias"])
                wsd = {
                    key.replace("connector.", ""): sd[key]
                    for key in sd
                    if key.startswith("connector.")
                }
                connector.decoder.load_state_dict(wsd)

        if device_map is None:
            num_layers = 32
            num_gpus = 2
            device_map = dict(
                **{"model.embed_tokens": 1, "model.norm": 1, "lm_head": 2},
                **{
                    "model.layers." + str(i): 1 + (i // (num_layers // num_gpus))
                    for i in range(num_layers)
                },
            )

        self.connector = connector.to(speech_encoder_device)
        self.whisper_encoder = whisper.encoder.to(speech_encoder_device)
        self.llm_decoder = AutoModelForCausalLM.from_pretrained(
            config_dict["reference_decoder"],
            device_map=device_map,
            torch_dtype=torch.float16,
        )
        self.processor = AutoProcessor.from_pretrained(config_dict["reference_encoder"])
        self.tokenizer = AutoTokenizer.from_pretrained("WillHeld/via-llama")

        # Prompt scaffolding, kept on the same device as the LLM embedding
        # table. NOTE(review): the hard-coded ids look like chat-template
        # header tokens for the user / assistant turns -- confirm against
        # the tokenizer.
        embed_device = self.llm_decoder.model.embed_tokens.weight.device
        self.prefix = torch.tensor([128000, 128006, 882, 128007, 271]).to(
            embed_device
        )

        self.pre_user_suffix = torch.tensor(
            self.tokenizer.encode(
                "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
            )
        ).to(embed_device)
        self.final_header = torch.tensor([128009, 128006, 78191, 128007, 271]).to(
            embed_device
        )
        self.speech_encoder_device = speech_encoder_device

    @classmethod
    def can_generate(cls):
        """Return ``False``: this class ships its own generation methods.

        Declared as a classmethod so that both ``DiVAModel.can_generate()``
        and ``instance.can_generate()`` work; without the decorator the
        class-level call fails for lack of an argument.
        """
        return False

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config=None,
        cache_dir=None,
        **kwargs,
    ):
        """Build a model from a local directory or a Hugging Face Hub repo.

        Args:
            pretrained_model_name_or_path: Either a directory containing
                ``model.safetensors`` and ``config.json``, or a Hub repo id
                from which those two files are downloaded.
            *model_args: Accepted for signature compatibility; unused here.
            config: Accepted for signature compatibility; unused here.
            cache_dir: Accepted for signature compatibility; unused here.
            **kwargs: Recognised keys are ``token`` (Hub auth token),
                ``device_map`` (defaults to ``"auto"``) and
                ``speech_encoder_device`` (defaults to ``None``).

        Returns:
            A new instance of ``cls`` initialised from the checkpoint.
        """
        if os.path.isdir(pretrained_model_name_or_path):
            # os.path.join also accepts os.PathLike inputs, which plain
            # string concatenation does not.
            via_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
            config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        else:
            # Loading from huggingface repo
            from huggingface_hub import hf_hub_download

            # Files are materialised next to this module, as before; the
            # paths returned by the download call are used directly.
            download_dir = os.path.dirname(__file__)
            token = kwargs.get("token", None)
            via_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="model.safetensors",
                token=token,
                local_dir=download_dir,
            )
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="config.json",
                token=token,
                local_dir=download_dir,
            )
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        return cls(
            via_path,
            config_dict,
            kwargs.get("device_map", "auto"),
            kwargs.get("speech_encoder_device"),
        )

    def forward(
        self, audio, prefix_text_tokens, suffix_text_tokens, past_key_values=None
    ):
        """Run the LLM over ``prefix + audio virtual tokens + suffix``.

        Args:
            audio: Audio input handed to the processor with a 16 kHz
                sampling rate.
            prefix_text_tokens: Token ids embedded and placed before the
                audio tokens. The pieces are concatenated along dim 0 and
                the batch dimension is added afterwards, so unbatched ids
                are expected -- TODO confirm with callers.
            suffix_text_tokens: Token ids embedded and placed after the
                audio tokens; same layout as ``prefix_text_tokens``.
            past_key_values: Optional cache forwarded to the LLM. It used
                to be read from an undefined name, which made every call
                fail; it now defaults to ``None``.

        Returns:
            The LLM output object (``return_dict=True``), including hidden
            states.
        """
        embed_tokens = self.llm_decoder.model.embed_tokens
        llm_device = embed_tokens.weight.device

        inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
        input_features = inputs.input_features.to(self.speech_encoder_device)
        hidden_states = self.whisper_encoder(input_features=input_features)[
            "last_hidden_state"
        ]
        # Drop singleton dimensions so the audio tokens line up with the
        # unbatched text embeddings for the concatenation below.
        virt_tokens = self.connector(
            hidden_states,
            output_device=llm_device,
        ).squeeze()

        prefix_embed = embed_tokens(prefix_text_tokens)
        suffix_embed = embed_tokens(suffix_text_tokens)
        inputs_embeds = torch.cat(
            [prefix_embed, virt_tokens, suffix_embed], dim=0
        ).unsqueeze(0)

        outputs = self.llm_decoder(
            # Cast to half precision before the decoder call.
            inputs_embeds=inputs_embeds.to(llm_device).half(),
            return_dict=True,
            output_hidden_states=True,
            past_key_values=past_key_values,
        )

        return outputs

    @torch.no_grad()
    def generate(
        self,
        audio,
        text_prompt=None,
        do_sample=False,
        logits_processor=None,
        max_new_tokens=128,
        temperature=1.0,
    ):
        """Generate text responses for a batch of audio inputs.

        Args:
            audio: Audio input(s) handed to the processor with a 16 kHz
                sampling rate.
            text_prompt: Optional text instruction(s). ``None`` or ``""``
                means no text prompt. A single string is shared by every
                batch row; a list is tokenized with right padding.
            do_sample: Sample from the softmax distribution instead of
                taking the argmax.
            logits_processor: Optional callable invoked as
                ``logits_processor(generated_ids, next_token_logits)`` with
                tensors of shape ``(batch, steps)`` and ``(batch, vocab)``;
                it must return the adjusted logits.
            max_new_tokens: Upper bound on decoding steps.
            temperature: Divisor applied to the logits when sampling.

        Returns:
            A list with one decoded string per batch row.

        Raises:
            ValueError: If sampling is requested with a non-positive
                temperature.
        """
        if do_sample and temperature <= 0:
            raise ValueError("temperature must be positive when do_sample=True")

        embed_tokens = self.llm_decoder.model.embed_tokens
        llm_device = embed_tokens.weight.device

        inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
        input_features = inputs.input_features.to(self.speech_encoder_device)
        hidden_states = self.whisper_encoder(input_features=input_features)[
            "last_hidden_state"
        ]
        virt_tokens = self.connector(hidden_states, output_device=llm_device)
        bsz = virt_tokens.shape[0]

        if text_prompt is not None and text_prompt != "":
            user_prompt_text = torch.tensor(
                self.tokenizer(
                    text_prompt,
                    add_special_tokens=False,
                    padding=True,
                    padding_side="right",
                )["input_ids"],
                device=self.pre_user_suffix.device,
            )
            # A single string tokenizes to a 1-D tensor; share it across
            # the batch so it can be concatenated with the 2-D pieces.
            if user_prompt_text.dim() == 1:
                user_prompt_text = user_prompt_text.unsqueeze(0).expand(bsz, -1)
            prefix = torch.cat(
                [
                    self.pre_user_suffix.expand(bsz, -1),
                    user_prompt_text,
                    self.prefix.expand(bsz, -1),
                ],
                dim=1,
            )
        else:
            prefix = self.prefix
        prefix_embed = embed_tokens(prefix).expand(bsz, -1, -1)
        suffix_embed = embed_tokens(self.final_header).expand(bsz, -1, -1)
        inputs_embeds = torch.cat([prefix_embed, virt_tokens, suffix_embed], dim=1)

        outs = [[] for _ in range(bsz)]
        complete = [False] * bsz
        # Every token fed back to the model, per row; this rectangular
        # history is what the logits processor sees (``outs`` becomes
        # ragged once some rows have finished).
        history = torch.empty((bsz, 0), dtype=torch.long, device=llm_device)
        past_key_values = None
        steps = 0
        # Count steps explicitly: the length of row 0 stops growing once
        # that row is complete and would no longer bound the loop.
        while not all(complete) and steps < max_new_tokens:
            outputs = self.llm_decoder(
                inputs_embeds=inputs_embeds.to(llm_device).half(),
                return_dict=True,
                output_hidden_states=True,
                past_key_values=past_key_values,
            )
            past_key_values = outputs.past_key_values
            next_token_logits = outputs.logits[:, -1, :]

            if logits_processor:
                next_token_logits = logits_processor(
                    history.to(next_token_logits.device),
                    next_token_logits,
                ).reshape(bsz, -1)
            if do_sample:
                probs = F.softmax(next_token_logits.float() / temperature, dim=-1)
                # One sample per batch row.
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
            else:
                next_tokens = next_token_logits.argmax(dim=-1)

            for row, token in enumerate(next_tokens.tolist()):
                if not complete[row]:
                    outs[row].append(token)
                # 128009 is the stop token id for this loop.
                if token == 128009:
                    complete[row] = True

            history = torch.cat(
                [history, next_tokens.to(history.device).unsqueeze(-1)], dim=1
            )
            # With the cache in place only the new token is fed next.
            inputs_embeds = embed_tokens(next_tokens.reshape(-1, 1))
            steps += 1
        return self.tokenizer.batch_decode(outs, skip_special_tokens=True)

    def generate_stream(
        self,
        audio,
        text_prompt,
        do_sample=False,
        logits_processor=None,
        max_new_tokens=128,
        return_outputs=False,
        init_outputs=None,
        temperature=1.0,
    ):
        """Generate a response for one audio input, yielding partial text.

        This is a generator: after every decoded token it yields the text
        decoded so far (with ``"<|eot_id|>"`` removed). Decoding stops at
        token id 128009 or after ``max_new_tokens`` tokens.

        Args:
            audio: Audio input handed to the processor with a 16 kHz
                sampling rate.
            text_prompt: Optional text instruction; ``None`` or ``""``
                means no text prompt.
            do_sample: Sample from the softmax distribution instead of
                taking the argmax.
            logits_processor: Optional callable invoked as
                ``logits_processor(ids, logits)`` with tensors of shape
                ``(1, n)`` and ``(1, vocab)``. On the first step ``ids``
                holds the assistant-header tokens, afterwards the tokens
                generated so far.
            max_new_tokens: Upper bound on the number of generated tokens.
            return_outputs: When true, yield ``(text, outputs)`` tuples
                where ``outputs`` is the latest LLM output object.
            init_outputs: Optional earlier LLM output whose
                ``past_key_values`` seed the first decoding step.
            temperature: Divisor applied to the logits when sampling.

        Yields:
            The decoded text so far, or ``(text, outputs)`` when
            ``return_outputs`` is true.

        Returns:
            The same value as the last item, as the generator's return
            value.

        Raises:
            ValueError: If sampling is requested with a non-positive
                temperature (raised when iteration starts).
        """
        if do_sample and temperature <= 0:
            raise ValueError("temperature must be positive when do_sample=True")

        embed_tokens = self.llm_decoder.model.embed_tokens
        llm_device = embed_tokens.weight.device

        def _decode(token_ids):
            # Single place for the decode-and-strip step.
            return self.tokenizer.decode(token_ids, skip_special_tokens=True).replace(
                "<|eot_id|>", ""
            )

        # Gradients are never needed here. The no_grad scope is closed
        # before each yield so the caller's grad mode is left untouched.
        with torch.no_grad():
            inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
            input_features = inputs.input_features.to(self.whisper_encoder.device)
            hidden_states = self.whisper_encoder(input_features=input_features)[
                "last_hidden_state"
            ]
            virt_tokens = self.connector(
                hidden_states,
                output_device=llm_device,
            ).squeeze()

            if text_prompt is not None and text_prompt != "":
                user_prompt_text = torch.tensor(
                    self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"],
                    device=self.pre_user_suffix.device,
                )
                prefix = torch.cat(
                    [self.pre_user_suffix, user_prompt_text, self.prefix], dim=0
                )
            else:
                prefix = self.prefix
            prefix_embed = embed_tokens(prefix)
            suffix = self.final_header
            suffix_embed = embed_tokens(suffix)
            inputs_embeds = torch.cat(
                [prefix_embed, virt_tokens, suffix_embed], dim=0
            ).unsqueeze(0)

        outs = []  # generated token ids as plain ints
        outputs = init_outputs
        last_token = None
        while last_token != 128009 and len(outs) < max_new_tokens:
            with torch.no_grad():
                past_key_values = outputs.past_key_values if outputs else None
                outputs = self.llm_decoder(
                    inputs_embeds=inputs_embeds.to(llm_device).half(),
                    return_dict=True,
                    output_hidden_states=True,
                    past_key_values=past_key_values,
                )
                next_token_logits = outputs.logits[-1, -1, :]

                if logits_processor:
                    # Build the history on the logits' device so the
                    # processor never mixes devices.
                    if outs:
                        local_outs = torch.tensor(
                            outs, device=next_token_logits.device
                        )
                    else:
                        local_outs = suffix.to(next_token_logits.device)
                    next_token_logits = logits_processor(
                        local_outs.reshape(1, -1),
                        next_token_logits.reshape(1, -1),
                    ).flatten()
                if do_sample:
                    probs = F.softmax(next_token_logits.float() / temperature, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)[0]
                else:
                    next_token = next_token_logits.argmax()
                last_token = int(next_token)
                outs.append(last_token)
                # With the cache in place only the new token is fed next.
                inputs_embeds = embed_tokens(next_token.reshape(1, 1))

            text = _decode(outs)
            if return_outputs:
                yield (text, outputs)
            else:
                yield text

        final_text = _decode(outs)
        if return_outputs:
            return (final_text, outputs)
        return final_text