import dataclasses
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
import transformers

from .ultravox_config import UltravoxConfig


@dataclasses.dataclass
class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq):
    # When enabled, the alt_input_ids, alt_attention_mask, and alt_labels fields are used to compute the KL loss in UltravoxModel.
    include_alt_fields: bool = False

    def __call__(self, features, *args, **kwargs):
        audio_values = [x for f in features for x in f.pop("audio_values", [])]
        audio_lens = [x for f in features for x in f.pop("audio_lens", [])]
        audio_token_len = [x for f in features for x in f.pop("audio_token_len", [])]
        audio_token_start_idx = [
            x for f in features for x in f.pop("audio_token_start_idx", [])
        ]

        if self.include_alt_fields:
            # These fields are hard-coded in the transformers data collator, so they need special handling before calling the parent __call__.
            alt_features = [
                {
                    "input_ids": f.pop("alt_input_ids"),
                    "attention_mask": f.pop("alt_attention_mask"),
                    "labels": f.pop("alt_labels"),
                }
                for f in features
            ]

        batch = super().__call__(features, *args, **kwargs)
        if self.include_alt_fields:
            alt_batch = super().__call__(alt_features, *args, **kwargs)
            batch["alt_input_ids"] = alt_batch["input_ids"]
            batch["alt_attention_mask"] = alt_batch["attention_mask"]
            batch["alt_labels"] = alt_batch["labels"]

        batch["audio_token_start_idx"] = torch.stack(audio_token_start_idx)
        batch["audio_lens"] = torch.stack(audio_lens)
        batch["audio_token_len"] = torch.stack(audio_token_len)

        # Pad the last dimension of all audio_values to the same length, with 0s on the right.
        if audio_values:
            max_len = max([x.shape[-1] for x in audio_values])
            batch["audio_values"] = torch.stack(
                [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
            )
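            # With left padding, each sample's tokens are shifted right by the amount of
            # padding added to that sample, so the recorded audio token start indices must
            # be shifted by the same displacement (repeated once per audio chunk in the sample).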
            if self.tokenizer.padding_side == "left":
                input_ids_lens = torch.LongTensor(
                    [f["input_ids"].shape[-1] for f in features]
                )
                displacement = batch["input_ids"].shape[-1] - input_ids_lens
                displacement = displacement.repeat_interleave(
                    batch["audio_batch_size"].squeeze(-1)
                )
                batch["audio_token_start_idx"] += displacement.to(
                    batch["audio_token_start_idx"].device
                )
        return batch
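
# Hypothetical usage sketch (illustrative only; `tokenizer` and `train_dataset` are assumed
# to exist elsewhere and are not part of this module):
#
#   collator = DataCollatorForSeq2SeqWithAudio(tokenizer=tokenizer)
#   loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, collate_fn=collator)
#   batch = next(iter(loader))
#   # The batch holds input_ids / attention_mask / labels plus the audio fields
#   # (audio_values, audio_lens, audio_token_len, audio_token_start_idx).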


class UltravoxProcessor(transformers.ProcessorMixin):
    """
    Constructs an Ultravox processor which wraps an audio processor and a tokenizer into a single processor.

    Args:
        audio_processor: The audio processor for the audio encoder.
        tokenizer: The tokenizer for the language model.
    """

    attributes = ["audio_processor", "tokenizer"]
    audio_processor_class = ("WhisperProcessor",)
    tokenizer_class = (
        "PreTrainedTokenizer",
        "PreTrainedTokenizerFast",
    )

    tokenizer: transformers.PreTrainedTokenizerBase
    audio_processor: transformers.ProcessorMixin

    def __init__(
        self,
        audio_processor=None,
        tokenizer=None,
        audio_padding: str = "longest",
        encoder_ds_factor: int = 2,
        stack_factor: int = 8,
        audio_placeholder: str = "<|audio|>",
        # Defaults to the Whisper encoder context size (3000 mel frames, i.e. 30 seconds)
        audio_context_size: Optional[int] = 3000,
    ):
        """
        Args:
            audio_processor: The audio processor for the audio encoder.
            tokenizer: The tokenizer for the language model.
            audio_padding: The padding strategy for the audio encoder.
            encoder_ds_factor: The downsampling factor of the audio encoder.
            stack_factor: The factor by which the audio encoder output is stacked in the multimodal projector.
            audio_placeholder: The placeholder for the audio in the text.
            audio_context_size: The maximum number of frames that the audio encoder can handle.
        """
        self.audio_padding = audio_padding
        self.encoder_ds_factor = encoder_ds_factor
        self.stack_factor = stack_factor
        self.audio_placeholder = audio_placeholder
        self.audio_token_replacement = tokenizer.eos_token
        self.audio_context_size = audio_context_size
        assert (
            self.audio_token_replacement is not None
        ), "The tokenizer has no EOS token. Cannot recover."
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        config: UltravoxConfig = transformers.AutoConfig.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
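        # Prefer the explicitly configured audio model id; fall back to the name recorded in
        # the audio config, then to a small Whisper checkpoint as a last resort.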
        audio_processor = transformers.AutoProcessor.from_pretrained(
            config.audio_model_id
            or config.audio_config._name_or_path
            or "openai/whisper-tiny"
        )

        tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
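        # Left padding is the usual choice for batched decoder-only generation: it keeps each
        # prompt adjacent to its generated continuation when prompts have different lengths.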
        tokenizer.padding_side = "left"
        tokenizer.pad_token = tokenizer.eos_token

        return cls(
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            stack_factor=config.stack_factor,
        )

    def _chunk_and_pad_audio(
        self, audio_values: torch.Tensor, audio_lens: torch.Tensor
    ) -> Dict[str, Any]:
        """
        Chunks each item in the audio batch according to audio_context_size, pads the last
        chunk of each item if needed, and returns a dictionary with the updated audio data.

        Args:
            audio_values (torch.Tensor): A tensor of audio values (e.g., in B, D, T format).
            audio_lens (torch.Tensor): A tensor of audio lengths.

        Returns:
            Dict[str, Any]: Dictionary with the following keys:
                - "audio_values": The concatenated audio tensor after chunking and padding.
                - "audio_lens": Tensor of lengths for each chunk.
                - "audio_is_continuation": Tensor of booleans indicating if the chunk is a continuation of the previous chunk.
                - "audio_batch_size": A Tensor with one integer representing the number of chunks.

        """
        chunked_audio_values: List[torch.Tensor] = []
        chunked_audio_lens: List[int] = []
        is_continuation_list: List[bool] = []
        context_size = self.audio_context_size or audio_values.shape[-1]

        for audio, audio_len in zip(audio_values, audio_lens):
            for offset in range(0, audio_len, context_size):
                is_continuation = offset > 0
                chunk = audio[..., offset : offset + context_size]
                if is_continuation and chunk.shape[-1] < context_size:
                    # N.B. Only continuation chunks need padding here. If no sample requires chunking,
                    # the batch may legitimately be shorter than audio_context_size and is already padded
                    # consistently by the processor. If any continuation chunks exist, however, the batch
                    # must be padded to audio_context_size, because that is the slice length used above.
                    chunk = F.pad(chunk, (0, context_size - chunk.shape[-1]))
                chunked_audio_values.append(torch.as_tensor(chunk))
                chunked_audio_lens.append(min(audio_len - offset, context_size))
                is_continuation_list.append(is_continuation)

        return {
            "audio_values": torch.stack(chunked_audio_values),
            "audio_lens": torch.tensor(chunked_audio_lens),
            "audio_is_continuation": torch.tensor(is_continuation_list),
            "audio_batch_size": torch.tensor([len(chunked_audio_values)]),
        }

    def __call__(
        self,
        text: Optional[str] = None,
        audio: Optional[Union[np.ndarray, torch.Tensor]] = None,
        audios: Optional[
            Union[
                List[Union[np.ndarray, torch.Tensor]], Union[np.ndarray, torch.Tensor]
            ]
        ] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[
            Union[str, transformers.TensorType]
        ] = transformers.TensorType.PYTORCH,
        **kwargs,
    ) -> transformers.BatchFeature:
        """
        Main method to prepare one text sequence and its audio for the model. This method forwards the `text`
        and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None`
        to encode the text. To prepare the audio(s), it forwards the `audio`, `sampling_rate` and `kwargs` arguments to
        the audio processor's [`~WhisperProcessor.__call__`] if `audio` is not `None`. Please refer to the docstrings
        of those two methods for more information.

        Args:
            text (`str`):
                The sequence to be encoded. Must be a single string; batch mode is not supported yet.
            audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The audio to be prepared. Audio can be a single-channel (1-dimensional) NumPy array or PyTorch tensor.
            audios (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                A list of audios or a two-dimensional array of audio to be prepared.
            sampling_rate (`int`, *optional*, defaults to 16000):
                Sampling rate of the input audio. We expect 16kHz audio. Don't change this value unless you know what
                you are doing.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **audio_values** -- Processed audio values to be fed to a model. Returned when `audio` is not `None`.
            - **audio_token_len** -- Predicted number of audio tokens per chunk; this value is guaranteed to be a close
              upper bound on the actual number. Returned when `audio` is not `None`.
            - **audio_token_start_idx** -- The index in the tokenized text where the audio starts. Returned when `audio` is not `None`.
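
        Example (illustrative sketch; the checkpoint path and audio values are placeholders):

            processor = UltravoxProcessor.from_pretrained("path/to/ultravox-checkpoint")
            inputs = processor(
                text="Transcribe the following: <|audio|>",
                audio=np.zeros(16000, dtype=np.float32),  # one second of 16 kHz silence
                sampling_rate=16000,
            )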
        """
        # TODO: Add support for multiple text inputs.
        if audio is not None and audios is not None:
            raise ValueError("Only one of `audio` or `audios` should be provided.")
        elif audio is not None:
            audios = audio if isinstance(audio, list) or audio.ndim == 2 else [audio]
        elif audios is None:
            audios = []

        data = {}
        audio_is_continuation = []
        if len(audios) > 0:
            audios = [x.numpy() if isinstance(x, torch.Tensor) else x for x in audios]

            # Pad out each audio to at least 2 hops (the minimum required by the processor).
            hop_length = self.audio_processor.feature_extractor.hop_length
            audios = [
                (
                    np.pad(x, (0, 2 * hop_length - len(x)), mode="constant")
                    if len(x) < 2 * hop_length
                    else x
                )
                for x in audios
            ]

            # Main audio processing. The processor is model-specific.
            x: transformers.BatchFeature = self.audio_processor(
                audios,
                sampling_rate=sampling_rate,
                padding="longest",
                pad_to_multiple_of=hop_length,  # The attention mask effectively gets padded to the hop length, so pad the audio to be consistent.
                truncation=False,
                return_attention_mask=True,
                **kwargs,
            )

            data.update(
                self._chunk_and_pad_audio(
                    audio_values=torch.as_tensor(
                        x.input_features if "input_features" in x else x.input_values
                    ),
                    audio_lens=torch.as_tensor(x.attention_mask).sum(-1),
                )
            )

            audio_is_continuation = data.pop("audio_is_continuation")
            data["audio_token_len"] = torch.ceil(
                data["audio_lens"] / (self.encoder_ds_factor * self.stack_factor)
            ).to(dtype=torch.int)
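            # Example: 3000 mel frames with encoder_ds_factor=2 and stack_factor=8
            # yields ceil(3000 / 16) = 188 audio tokens.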

        if text is not None:
            if not isinstance(text, str):
                raise ValueError("Text must be a string. Batch mode not supported yet.")

            # Special tokens like BOS should already have been added by the caller.
            tokenized_parts = self.tokenizer(
                # The placeholder isn't part of the vocabulary, so split the text around it.
                text.split(self.audio_placeholder),
                add_special_tokens=False,
                **kwargs,
            )

            audio_token_start_idx = []
            replacement_token_id = self.tokenizer.get_vocab()[
                self.audio_token_replacement
            ]
            placeholder_index = -1
            split_input_ids = tokenized_parts["input_ids"]
            input_ids: List[int] = []

            for i, token_len in enumerate(data.get("audio_token_len", [])):
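                # Each audio chunk contributes a run of replacement tokens. Only the first
                # chunk of each audio consumes a new text segment; continuation chunks extend
                # the same placeholder without advancing placeholder_index.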
                if not audio_is_continuation[i]:
                    placeholder_index += 1
                    if placeholder_index >= len(split_input_ids):
                        raise ValueError(
                            f"Text contains too few audio placeholders. (Expected {len(audios)} placeholders)"
                        )

                    input_ids.extend(split_input_ids[placeholder_index])

                audio_token_start_idx.append(len(input_ids))

                input_ids.extend([replacement_token_id] * token_len)

            # Include any tokens after the last audio and verify no placeholders remain.
            placeholder_index += 1
            if placeholder_index != len(split_input_ids) - 1:
                raise ValueError(
                    f"Text contains too many audio placeholders. (Expected {len(audios)} placeholders)"
                )
            input_ids.extend(split_input_ids[placeholder_index])

            if "audio_token_len" in data:
                data["audio_token_start_idx"] = torch.as_tensor(audio_token_start_idx)

            data["input_ids"] = [input_ids]
            data["attention_mask"] = [[1] * len(input_ids)]

        return transformers.BatchFeature(data=data, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        audio_processor_input_names = self.audio_processor.model_input_names
        return list(set(tokenizer_input_names + audio_processor_input_names))


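# Register the processor for auto-class loading: checkpoints whose config type is UltravoxConfig
# can then be resolved via transformers.AutoProcessor (with trust_remote_code when the custom
# code is loaded from the Hub).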
UltravoxProcessor.register_for_auto_class()

transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)