import torch
import torch.nn.functional as F
from transformers import GPT2Tokenizer, PreTrainedModel, PretrainedConfig

# Custom Configuration
from transformers import GPT2Config
from transformers.models.auto.configuration_auto import CONFIG_MAPPING


class CustomGPTConfig(GPT2Config):
    """Configuration for the ``custom_gpt`` model type.

    Accepts nanoGPT-style constructor names and maps them onto the
    ``GPT2Config`` fields: ``hidden_size`` -> ``n_embd`` and ``block_size`` ->
    ``n_positions`` / ``n_ctx``. ``block_size`` is additionally stored under
    its own name on the instance.

    Args:
        vocab_size: Vocabulary size.
        n_layer: Number of transformer layers.
        n_head: Number of attention heads.
        hidden_size: Embedding width (stored as ``n_embd``).
        block_size: Maximum context length (stored as ``n_positions``,
            ``n_ctx`` and ``block_size``).
        **kwargs: Forwarded to ``GPT2Config``.
    """

    model_type = "custom_gpt"

    def __init__(self, vocab_size=50304, n_layer=24, n_head=16, hidden_size=1024, block_size=1024, **kwargs):
        # A config.json written by save_pretrained() may store the GPT2Config
        # names (n_embd, n_positions, n_ctx), and from_pretrained() passes them
        # back through **kwargs. Forwarding them next to the explicit keywords
        # below would raise "got multiple values for keyword argument", so pop
        # them here; the serialized values win because they describe the
        # checkpoint being loaded.
        hidden_size = kwargs.pop("n_embd", hidden_size)
        block_size = kwargs.pop("n_positions", block_size)
        # n_ctx is always derived from block_size, so a stored copy is redundant.
        kwargs.pop("n_ctx", None)
        super().__init__(
            vocab_size=vocab_size,
            n_positions=block_size,
            n_ctx=block_size,
            n_embd=hidden_size,
            n_layer=n_layer,
            n_head=n_head,
            **kwargs,
        )
        # Kept under its own name so code reading config.block_size sees the
        # context length directly.
        self.block_size = block_size


# Register the config class in transformers' auto-config mapping under its own
# model_type, so the key and the class attribute cannot drift apart.
CONFIG_MAPPING.register(CustomGPTConfig.model_type, CustomGPTConfig)


# Wrapper for GPT to make it compatible with Hugging Face
class HuggingFaceGPT(PreTrainedModel):
    """``PreTrainedModel`` adapter around the project's ``nova_model.GPT``.

    The Hugging Face base class provides config handling via ``config_class``;
    all computation is delegated to the wrapped ``GPT`` instance.
    """

    config_class = CustomGPTConfig

    def __init__(self, config):
        super().__init__(config)
        # Imported here rather than at module level, so this file can be
        # imported without nova_model being importable.
        from nova_model import GPT  # Replace with your actual model import

        # The attribute name determines the "transformer." prefix of the
        # state-dict keys this module produces and expects.
        self.transformer = GPT(config)

    def forward(self, input_ids, **kwargs):
        """Run the wrapped GPT on ``input_ids``.

        A ``labels`` keyword, if supplied, is passed to the model as
        ``targets``; all other keyword arguments are accepted and ignored.

        Returns:
            ``{"logits": ..., "loss": ...}`` exactly as returned by the
            wrapped model.
        """
        logits, loss = self.transformer(input_ids, targets=kwargs.get("labels"))
        return dict(logits=logits, loss=loss)


class EndpointHandler:
    """Inference-endpoint handler that samples text from a ``HuggingFaceGPT`` checkpoint.

    Loads ``CustomGPTConfig`` and ``pytorch_model.bin`` from ``model_dir`` plus
    the stock ``gpt2`` tokenizer. Instances are called with a request dict of
    the form ``{"inputs": str, "parameters": {...}}``.
    """

    def __init__(self, model_dir, device="cuda"):
        """Load config, weights and tokenizer.

        Args:
            model_dir: Directory containing the config and ``pytorch_model.bin``.
            device: Torch device string. When a CUDA device is requested but
                CUDA is unavailable, the handler falls back to ``"cpu"``
                instead of failing inside ``torch.load``.
        """
        if str(device).startswith("cuda") and not torch.cuda.is_available():
            print(f"CUDA is not available; falling back to CPU (requested {device!r}).")
            device = "cpu"

        print(f"Initializing model from directory: {model_dir}")
        # Load custom configuration and model
        self.config = CustomGPTConfig.from_pretrained(model_dir)
        self.model = HuggingFaceGPT(self.config)
        # NOTE(review): torch.load unpickles the checkpoint, which can execute
        # arbitrary code -- only point model_dir at trusted files. On
        # PyTorch >= 1.13 consider passing weights_only=True.
        state_dict = torch.load(f"{model_dir}/pytorch_model.bin", map_location=torch.device(device))
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.model.eval()
        print("Model initialized successfully.")

        # Load tokenizer
        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        self.device = device
        print("Tokenizer loaded successfully.")

    @staticmethod
    def _read_parameters(parameters):
        """Coerce request ``parameters`` into typed generation settings.

        Values are cast explicitly in case a client sends them as strings.

        Returns:
            ``(max_length, num_return_sequences, temperature, top_k, seed)``;
            ``top_k`` is ``None`` when the request passes ``null``.

        Raises:
            ValueError, TypeError: if a value cannot be converted, or if
                ``num_return_sequences`` is smaller than 1.
        """
        max_length = int(parameters.get("max_length", 32))
        num_return_sequences = int(parameters.get("num_return_sequences", 4))
        temperature = float(parameters.get("temperature", 1.0))
        top_k = parameters.get("top_k", 50)
        top_k = None if top_k is None else int(top_k)
        seed = int(parameters.get("seed", 42))
        if num_return_sequences < 1:
            raise ValueError("num_return_sequences must be at least 1")
        return max_length, num_return_sequences, temperature, top_k, seed

    @staticmethod
    def _sample_next_token(logits, temperature, top_k, generator):
        """Pick one token id per row of last-position ``logits``.

        Args:
            logits: Float tensor of shape ``(batch, vocab)``.
            temperature: Softmax temperature. ``<= 0`` selects greedy (argmax)
                decoding instead of dividing by zero.
            top_k: Sample only among the ``top_k`` most likely tokens. ``None``
                or ``<= 0`` disables the filter; values above the vocabulary
                size are clamped to it.
            generator: ``torch.Generator`` consumed by ``torch.multinomial``.

        Returns:
            Long tensor of shape ``(batch, 1)`` holding the chosen token ids.
        """
        if temperature <= 0:
            return torch.argmax(logits, dim=-1, keepdim=True)
        probs = F.softmax(logits / temperature, dim=-1)
        if top_k is None or top_k <= 0:
            return torch.multinomial(probs, 1, generator=generator)
        k = min(top_k, probs.size(-1))
        topk_probs, topk_indices = torch.topk(probs, k, dim=-1)
        # multinomial accepts unnormalized weights, so no renormalization is needed.
        choice = torch.multinomial(topk_probs, 1, generator=generator)
        return torch.gather(topk_indices, -1, choice)

    def __call__(self, inputs):
        """Generate ``num_return_sequences`` continuations of ``inputs["inputs"]``.

        Recognized ``inputs["parameters"]`` keys:
            ``max_length``: total tokens per sequence, prompt included (default 32).
            ``num_return_sequences``: number of samples (default 4).
            ``temperature``: softmax temperature, ``<= 0`` means greedy (default 1.0).
            ``top_k``: top-k filter, ``None``/``<= 0`` disables it (default 50).
            ``seed``: RNG seed (default 42), so identical requests give identical output.

        Returns:
            A list of ``{"generated_text": str}`` dicts, or ``[{"error": str}]``
            when the prompt is missing or the parameters are invalid.
        """
        print("Processing inputs...")
        prompt = inputs.get("inputs", "")
        parameters = inputs.get("parameters") or {}

        if not prompt:
            print("Error: Input prompt is missing.")
            return [{"error": "Input prompt is missing"}]

        try:
            max_length, num_return_sequences, temperature, top_k, seed = self._read_parameters(parameters)
        except (TypeError, ValueError) as err:
            print(f"Error: Invalid parameters: {err}")
            return [{"error": f"Invalid parameters: {err}"}]

        print(f"Prompt: {prompt}")
        print(f"Parameters: {parameters}")

        # Encode the prompt once and replicate it for every requested sample.
        tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        generated_tokens = tokens.repeat(num_return_sequences, 1)

        # Seeded per request so identical requests are reproducible.
        sample_rng = torch.Generator(device=self.device)
        sample_rng.manual_seed(seed)

        # NOTE(review): CustomGPTConfig sets n_positions = block_size, so the
        # model presumably cannot take more than block_size tokens at once;
        # only the most recent block_size tokens are fed as context.
        block_size = getattr(self.config, "block_size", None)

        with torch.no_grad():
            while generated_tokens.size(1) < max_length:
                context = generated_tokens
                if block_size and context.size(1) > block_size:
                    context = context[:, -block_size:]
                logits = self.model(input_ids=context)["logits"][:, -1, :]
                next_token = self._sample_next_token(logits, temperature, top_k, sample_rng)
                generated_tokens = torch.cat((generated_tokens, next_token), dim=1)
                print(f"Generated tokens so far: {generated_tokens.size(1)}/{max_length}")

        # Each row is decoded in full, so a prompt that already exceeds
        # max_length is returned intact rather than truncated.
        results = [
            {"generated_text": self.tokenizer.decode(row.tolist(), skip_special_tokens=True)}
            for row in generated_tokens
        ]

        print("Generation completed.")
        return results


if __name__ == "__main__":
    # Manual smoke test: load the checkpoint from the current directory and
    # print a few sampled continuations of a fixed prompt.
    handler = EndpointHandler("./")
    request = {
        "inputs": "Hello, I'm a language model,",
        "parameters": {
            "max_length": 32,
            "num_return_sequences": 4,
            "temperature": 0.7,
            "top_k": 50,
        },
    }

    print("Starting inference...")
    for sample_no, sample in enumerate(handler(request)):
        print(f"Sample {sample_no}: {sample['generated_text']}")