from intel_extension_for_transformers.backends.neural_engine.compile import compile
# Build the inference graph from an on-disk IR directory (presumably the bf16
# export, judging by the directory name). Other IR locations used earlier:
#   /home/ubuntu/mengfeil/IR/llama-7b-bf16/
#   /home/ubuntu/mengfeil/IR/llama-7b-int8/
#   ./llama-7b-hf-conv-itrex-bf16
# NOTE: `compile` here is the Neural Engine function imported above, which
# shadows the Python builtin of the same name.
graph = compile(
    "/home/ubuntu/neuralchat_server/"
    "frameworks.ai.nlp-toolkit.intel-nlp-toolkit/"
    "examples/huggingface/pytorch/text-generation/deployment/bf16ir"
)
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer, AutoModel, AutoConfig
import torch
import numpy as np
import torch.nn.functional as F
import time

# Tokenizer and model config are both read from the same checkpoint directory.
model_name = "../mengfeil/llama-7b"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
config = AutoConfig.from_pretrained(model_name)

# Alternative prompt used previously:
#   "Once upon a time, there existed a little girl, who liked to have adventures.
#    She wanted to go to places and meet new people, and have fun."
prompt = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.\n"
    " Human: tell me something about China.\n"
    " Assistant:"
)

print(prompt)
# Token ids of the prompt with the batch dimension stripped.
init_input_ids = tokenizer(prompt, return_tensors="pt").input_ids[0]

# Every model input below carries a leading batch dimension of 1.
input_ids = init_input_ids.clone().unsqueeze(0)

# The mask has one more slot than there are prompt tokens and slot 0 is zeroed;
# presumably that slot pairs with the length-1 zero-filled past key/value
# entries created below - TODO confirm against the IR's expected inputs.
attention_mask = torch.cat(
    [torch.zeros(1), torch.ones(len(init_input_ids))]
).unsqueeze(0)
position_ids = torch.arange(len(init_input_ids)).unsqueeze(0)

# 32 (key, value) pairs of zero tensors, each shaped [1, 32, 1, 128].
past_key_value = tuple(
    (torch.zeros([1, 32, 1, 128]), torch.zeros([1, 32, 1, 128]))
    for _ in range(32)
)
all_input_ids = input_ids.clone()

# Generation settings.
max_new_tokens = 32
temperature = 0.9


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

        The input tensor is left untouched; a filtered copy is returned.

        Args:
            logits: logits distribution shape (vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            filter_value: value written over every logit that is filtered out.

        Returns:
            A new 1-D tensor with the same shape as `logits` in which the
            rejected entries are set to `filter_value`.
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear

    # Work on a copy so the caller's tensor is not modified as a side effect.
    logits = logits.clone()

    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a logit below the k-th largest one.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right by one so the first token that crosses the
        # threshold is kept; the most likely token is therefore always kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary positions.
        logits[sorted_indices[sorted_indices_to_remove]] = filter_value
    return logits


# Benchmark: repeat the same greedy generation num_iter times and report the
# average duration of the runs that come after the first num_warmup ones.
total_time = 0.0
num_iter = 10
num_warmup = 3

for i in range(num_iter):

    # Reset the engine inputs to the prompt state at the start of every run.
    input_ids_1 = input_ids.cpu().numpy().astype(np.int32)
    attention_mask_1 = attention_mask.cpu().numpy().astype(np.int32)
    # Flatten the (key, value) pairs into one list: k0, v0, k1, v1, ...
    past_k_v = [tensor.cpu().numpy() for layer in past_key_value for tensor in layer]
    # Plain Python ints, so the list stays homogeneous once tokens are appended.
    output_ids = init_input_ids.tolist()

    # perf_counter is monotonic and high-resolution, unlike time.time().
    tic = time.perf_counter()

    for step in range(max_new_tokens):
        step_start = time.perf_counter()
        predictions = graph.inference([input_ids_1, attention_mask_1] + past_k_v)
        # predictions = graph.inference([input_ids_1] + past_k_v + [attention_mask_1])
        print(time.perf_counter() - step_start)

        # First output holds the logits; the rest are fed back as the
        # key/value cache on the next step.
        outs = [predictions[key] for key in predictions]
        logits = torch.from_numpy(outs[0])
        past_k_v = outs[1:]

        # Greedy decoding. Softmax is monotonic, so taking the argmax of the
        # raw last-position logits selects the same token without the extra op.
        # To sample instead: divide the last-position logits by `temperature`,
        # pass them through top_k_top_p_filtering, apply softmax and draw with
        # torch.multinomial.
        token = int(torch.argmax(logits[:, -1, :], dim=-1))

        output_ids.append(token)
        # Next step feeds only the new token and a mask grown by one slot.
        input_ids_1 = np.array([[token]], dtype=np.int32)
        attention_mask_1 = np.concatenate(
            [attention_mask_1, np.ones((1, 1), dtype=np.int32)], axis=-1
        )

    toc = time.perf_counter()
    if i >= num_warmup:
        total_time += (toc - tic)
    # print(output_ids)
    print(tokenizer.decode(output_ids, skip_special_tokens=True))

print("Inference latency: %.3f ms." % (total_time / (num_iter - num_warmup) * 1000))