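"""Interactive chat demo for Qwen2.5-1.5B-Instruct-GPTQ-Int8 on the axengine runtime.

InferManager drives tokenization, prefill and decode on the axmodel;
KVCacheTools persists the precomputed system-prompt kv cache under ./kvcache
so that later runs can load it instead of re-running prefill.
"""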
import os
import sys
from typing import Dict, List, Tuple

import numpy as np
from loguru import logger
from ml_dtypes import bfloat16

from utils.infer_func import InferManager, KVCacheTools


class LlamaChatSession:
    """Interactive chat session built on top of an InferManager instance."""

    def __init__(self, builder_instance):
        self.system_prompt = builder_instance.system_prompt
        self.builder_instance = builder_instance
        self.last_reply = ""

    def encode(self, prompt: str) -> Dict:
        """
        Build model inputs for `prompt` via InferManager.encoder_prompt.

        The returned dict has keys: "message", "model_inputs", "input_ids",
        "input_embeds", "input_ids_len".
        """
        return self.builder_instance.encoder_prompt(prompt)
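
    # Note: of the keys listed above, only "input_ids" is consumed by
    # generate()/run() below; the remaining entries are produced by
    # InferManager.encoder_prompt for callers that need them.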

    def get_kvcache(self) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        return self.builder_instance.k_caches, self.builder_instance.v_caches

    def generate(self, model_inputs) -> None:
        # decode() runs the autoregressive loop itself and does not hand the
        # reply back, so there is nothing to return here.
        token_ids = model_inputs["input_ids"]
        self.builder_instance.decode(token_ids)

    def run(self, model_inputs) -> None:
        self.generate(model_inputs)

    def reset_context(self, system_prompt: str = None):
        """
        A reset only needs to clear the kv cache (it can even just be
        overwritten in place), but if the system_prompt has changed, the
        kv cache must be recomputed.
        """
        if system_prompt is not None:
            self.system_prompt = system_prompt

        # Roll the write position back to the end of the precomputed system
        # prompt and zero every cache entry after it.
        self.builder_instance.precompute_len = self.builder_instance.system_input_ids_len
        for i in range(len(self.builder_instance.k_caches)):
            self.builder_instance.k_caches[i][:, self.builder_instance.precompute_len:, :].fill(0)
            self.builder_instance.v_caches[i][:, self.builder_instance.precompute_len:, :].fill(0)
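        # NB: passing a new system_prompt only swaps the stored string; as the
        # docstring notes, the precomputed system-prompt kv cache would then
        # be stale and needs to be recomputed (e.g. via prefill), which this
        # method does not do by itself.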

    def chat_loop(self, live_print: bool = False):

        if self.system_prompt:
            print(f">>> System prompt: {self.system_prompt}")

        logger.info("Type 'q' to exit, Ctrl+c to stop the current generation\n")

        while True:
            try:
                prompt = input("prompt (type q to quit) >> ")

                if prompt.lower() in ("q", "exit"):
                    print("\nOK, left the chat.")
                    return

                if prompt.lower() == "debug":
                    print(f"\n>>> DEBUG INFO >>>\n precompute_len is {self.builder_instance.precompute_len}\n<<< DEBUG INFO <<<\n")
                    continue

                # An empty prompt just re-prints the system prompt.
                if not prompt.strip():
                    print(f"\n{self.system_prompt}")
                    continue

                if prompt.strip() == "reset":
                    self.reset_context()
                    print("Context has been reset.")
                    continue

                model_inputs = self.encode(prompt)

                # 128 matches the prefill slice_len used in __main__ below;
                # 2559 is treated as the context-window limit of this build.
                if self.builder_instance.precompute_len + 128 >= 2559:
                    logger.error("Context window is full! Use the `reset` command to reset the context.")
                    continue

                self.run(model_inputs)

            except KeyboardInterrupt:
                print("\nOK, left the chat.")
                sys.exit()

            except Exception as e:
                print(f"ERROR: {e}")
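
# Minimal programmatic usage sketch (illustrative; it mirrors the __main__
# setup below rather than a separately documented API):
#
#   builder = InferManager(hf_model_path, axmodel_model_path)
#   builder.build_system_prompt()
#   builder.build_kvcache()
#   builder.build_infer_model()
#   session = LlamaChatSession(builder_instance=builder)
#   session.run(session.encode("Hello!"))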


if __name__ == "__main__":

    hf_model_path = './Qwen2.5-1.5B-Instruct-GPTQ-Int8/'
    axmodel_model_path = './Qwen2.5-1.5B-Instruct-GPTQ-Int8_axmodel/'

    builder = InferManager(hf_model_path, axmodel_model_path)
    builder.build_system_prompt()
    builder.build_kvcache()
    builder.build_infer_model()

    # kv-cache save/load helper (28 axmodel decode layers, bfloat16 entries).
    cache_manager = KVCacheTools(axmodel_num=28, dtype=bfloat16)
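
    # First run: prefill the system prompt in 128-token slices and persist the
    # resulting kv cache to ./kvcache; subsequent runs load it directly and
    # skip the prefill.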
    if not os.path.exists("./kvcache"):
        update_kv_cache = builder.prefill(
            builder.model_inputs,
            slice_len=128,
        )
        if cache_manager.save_kvcache(
            target_dir="./kvcache",
            system_prompt=builder.system_prompt,
            precompute_len=builder.system_input_ids_len,
            k_caches=update_kv_cache[0],
            v_caches=update_kv_cache[1],
            metadata={"model_version": "v0.1"}
        ):
            logger.info(">>> Precomputed system-prompt kvcache saved to ./kvcache; the next run can load it directly <<<")
        else:
            logger.error(">>> Failed to save the kvcache, exiting! <<<")
            sys.exit(1)
    else:
        update_kv_cache, prompt, plen, meta = cache_manager.load_kvcache("./kvcache")
        builder.update_kvcache(update_kv_cache)

    logger.debug(">>> Creating LlamaChatSession >>>")
    session = LlamaChatSession(builder_instance=builder)
    session.chat_loop(live_print=False)