import os
from typing import Dict, List, Optional, Tuple

import numpy as np
from loguru import logger
from ml_dtypes import bfloat16

from utils.infer_func import InferManager, KVCacheTools


class LlamaChatSession:
    def __init__(self, builder_instance):
        self.system_prompt = builder_instance.system_prompt
        self.builder_instance = builder_instance
        self.last_reply = ""

    def encode(self, prompt: str) -> Dict:
        """
        Build the model inputs for one user turn via the underlying InferManager.
        Expected keys: "message", "model_inputs", "input_ids", "input_embeds", "input_ids_len".
        """
        return self.builder_instance.encoder_prompt(prompt)
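
    # Hypothetical usage sketch (comments only): per the docstring above,
    # encoder_prompt() returns a dict, so a caller would typically do:
    #
    #   model_inputs = session.encode("Hello")
    #   token_ids = model_inputs["input_ids"]       # passed on to generate()/decode()
    #   new_len = model_inputs["input_ids_len"]     # presumably the tokenized length of this turn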

    def get_kvcache(self) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        return self.builder_instance.k_caches, self.builder_instance.v_caches

    def generate(self, model_inputs) -> None:
        # decode() performs the generation step(s); nothing is returned to the caller.
        token_ids = model_inputs["input_ids"]
        self.builder_instance.decode(token_ids)

    def run(self, model_inputs) -> None:
        self.generate(model_inputs)

    def reset_context(self, system_prompt: Optional[str] = None):
        """
        Resetting only requires clearing the kv cache beyond the precomputed
        system prompt (those slots can simply be overwritten). If the system
        prompt itself changes, its kv cache has to be recomputed.
        """
        if system_prompt is not None:
            self.system_prompt = system_prompt

        # Rewind the write position to just after the precomputed system prompt,
        # then zero out everything generated afterwards.
        self.builder_instance.precompute_len = self.builder_instance.system_input_ids_len

        for i in range(len(self.builder_instance.k_caches)):
            self.builder_instance.k_caches[i][:, self.builder_instance.precompute_len:, :].fill(0)
            self.builder_instance.v_caches[i][:, self.builder_instance.precompute_len:, :].fill(0)
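
    # A minimal, hypothetical sketch (not part of the original flow): reset_context()
    # only swaps self.system_prompt, so a genuinely changed system prompt still needs
    # a fresh prefill. The InferManager attributes used below (build_system_prompt,
    # model_inputs, prefill, update_kvcache, system_input_ids_len) are assumptions
    # taken from how they are used in __main__.
    def rebuild_system_kvcache(self):
        self.builder_instance.system_prompt = self.system_prompt
        self.builder_instance.build_system_prompt()
        update_kv_cache = self.builder_instance.prefill(
            self.builder_instance.model_inputs,
            slice_len=128,
        )
        self.builder_instance.update_kvcache(update_kv_cache)
        self.builder_instance.precompute_len = self.builder_instance.system_input_ids_len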

    def chat_loop(self, live_print: bool = False):
        if self.system_prompt:
            print(f">>> System prompt: {self.system_prompt}")

        logger.info("Type 'q' to exit, Ctrl+c to stop current generation\n")

        while True:
            try:
                prompt = input("prompt (type 'q' to quit) >> ")

                if prompt.lower() in ("q", "exit"):
                    print("\nOK, leaving the chat.")
                    return

                if prompt.lower() == "debug":
                    print(f"\n>>> DEBUG INFO >>>\n precompute_len is {self.builder_instance.precompute_len}\n<<< DEBUG INFO <<<\n")
                    continue

                if not prompt.strip():
                    print(f"\n{self.system_prompt}")
                    continue

                if prompt.strip() == "reset":
                    self.reset_context()
                    print("Context has been reset.")
                    continue

                model_inputs = self.encode(prompt)

                # Bail out if there is no longer room for another 128-token slice
                # within the context window.
                if self.builder_instance.precompute_len + 128 >= 2559:
                    logger.error("Context window is full! Use the `reset` command to reset the context.")
                    continue

                self.run(model_inputs)

            except KeyboardInterrupt:
                # The user pressed Ctrl+C to interrupt generation.
                print("\nOK, leaving the chat.")
                exit()

            except Exception as e:
                print(f"ERROR: {str(e)}")


if __name__ == "__main__":

    hf_model_path = './Qwen2.5-1.5B-Instruct-GPTQ-Int8/'
    axmodel_model_path = './Qwen2.5-1.5B-Instruct-GPTQ-Int8_axmodel/'

    builder = InferManager(hf_model_path, axmodel_model_path) # init tokenizer & hf_config & system prompt
    builder.build_system_prompt()
    builder.build_kvcache()
    builder.build_infer_model()

    cache_manager = KVCacheTools(axmodel_num=28, dtype=bfloat16)

    if not os.path.exists("./kvcache"):
        # Precompute the k/v cache for the system prompt
        update_kv_cache = builder.prefill(
            builder.model_inputs,
            slice_len=128,
        )
        if cache_manager.save_kvcache(
            target_dir="./kvcache",
            system_prompt=builder.system_prompt,
            precompute_len=builder.system_input_ids_len,
            k_caches=update_kv_cache[0],
            v_caches=update_kv_cache[1],
            metadata={"model_version": "v0.1"}
        ):
            logger.info(">>> Precomputed system prompt kvcache saved to ./kvcache; it will be loaded directly on the next run <<<")
        else:
            logger.error(">>> Failed to save the kvcache, exiting! <<<")
            exit()
    else:
        update_kv_cache, prompt, plen, meta = cache_manager.load_kvcache("./kvcache")
        builder.update_kvcache(update_kv_cache)
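        # Hypothetical consistency check (a sketch, not in the original script):
        # load_kvcache() also returns the system prompt that was saved alongside the
        # cache, so warn if it no longer matches the prompt the InferManager built.
        if prompt != builder.system_prompt:
            logger.warning(">>> Cached system prompt differs from the current one; delete ./kvcache to force a fresh prefill <<<")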

    logger.debug(">>> Creating LlamaChatSession >>>")

    session = LlamaChatSession(
        builder_instance=builder
    )
    session.chat_loop(live_print=False)