import os
from typing import Dict, List, Tuple

import numpy as np
from loguru import logger
from ml_dtypes import bfloat16

from utils.infer_func import InferManager, KVCacheTools


class LlamaChatSession:
    def __init__(self, builder_instance):
        self.system_prompt = builder_instance.system_prompt
        self.builder_instance = builder_instance
        self.last_reply = ""

    def encode(self, prompt: str) -> Dict:
        """
        Returns the model-inputs dict with keys:
        "message", "model_inputs", "input_ids", "input_embeds", "input_ids_len"
        """
        return self.builder_instance.encoder_prompt(prompt)

    def get_kvcache(self) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        return self.builder_instance.k_caches, self.builder_instance.v_caches
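
    # NOTE (assumption): judging by the slicing in reset_context below, each
    # per-layer cache is a 3-D array laid out as [heads, seq_len, head_dim],
    # with the first `precompute_len` positions holding the system prompt.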
    def generate(self, model_inputs) -> None:
        token_ids = model_inputs["input_ids"]
        self.builder_instance.decode(token_ids)
        return None

    def run(self, model_inputs) -> None:
        return self.generate(model_inputs)
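
    # NOTE (assumption): InferManager.decode appears to stream the reply to
    # stdout as it is generated, which is why generate()/run() have nothing
    # to return and self.last_reply is never populated in this version.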

    def reset_context(self, system_prompt: str = None):
        """
        A reset only needs to clear the KV cache (it could even just be
        overwritten in place). If the system prompt has changed, however,
        its KV cache has to be recomputed.
        """
        if system_prompt is not None:
            self.system_prompt = system_prompt
        self.builder_instance.precompute_len = self.builder_instance.system_input_ids_len
        for i in range(len(self.builder_instance.k_caches)):
            # Zero everything past the precomputed system prompt so the
            # cached prefix survives the reset.
            self.builder_instance.k_caches[i][:, self.builder_instance.precompute_len:, :].fill(0)
            self.builder_instance.v_caches[i][:, self.builder_instance.precompute_len:, :].fill(0)

    def chat_loop(self, live_print: bool = False):
        if self.system_prompt:
            print(f">>> System prompt: {self.system_prompt}")
        logger.info("Type 'q' to exit, Ctrl+C to stop the current generation\n")
        while True:
            try:
                prompt = input("prompt (enter 'q' to quit) >> ")
                if prompt.lower() in ("q", "exit"):
                    print("\nOK, chat ended.")
                    return
                if prompt.lower() == "debug":
                    print(f"\n>>> DEBUG INFO >>>\n precompute_len is {self.builder_instance.precompute_len}\n<<< DEBUG INFO <<<\n")
                    continue
                if not prompt.strip():
                    # An empty prompt just echoes the system prompt back.
                    print(f"\n{self.system_prompt}")
                    continue
                if prompt.strip() == "reset":
                    self.reset_context()
                    print("Context has been reset")
                    continue
                model_inputs = self.encode(prompt)
                # Context guard: the usable window here is 2559 tokens and
                # prefill advances in 128-token slices (see slice_len below);
                # refuse new input once the next chunk would no longer fit.
                if self.builder_instance.precompute_len + 128 >= 2559:
                    logger.info("ERROR: the context window is full! Use the `reset` command to reset the context")
                    continue
                self.run(model_inputs)
            except KeyboardInterrupt:
                # Handle the user pressing Ctrl+C during generation.
                print("\nOK, chat ended.")
                exit()
            except Exception as e:
                print(f"ERROR: {str(e)}")


if __name__ == "__main__":
    hf_model_path = './Qwen2.5-1.5B-Instruct-GPTQ-Int8/'
    axmodel_model_path = './Qwen2.5-1.5B-Instruct-GPTQ-Int8_axmodel/'

    # Initialize the tokenizer, HF config, and system prompt, then build the
    # KV cache buffers and the axengine inference model.
    builder = InferManager(hf_model_path, axmodel_model_path)
    builder.build_system_prompt()
    builder.build_kvcache()
    builder.build_infer_model()

    # One axmodel per transformer layer: 28 matches Qwen2.5-1.5B's layer count.
    cache_manager = KVCacheTools(axmodel_num=28, dtype=bfloat16)
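
    # NOTE (assumption): bfloat16 presumably matches the KV-cache dtype of the
    # exported axmodel, so KVCacheTools can round-trip the arrays through disk
    # without a lossy conversion.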
    if not os.path.exists("./kvcache"):
        # Precompute the system prompt's K/V cache.
        update_kv_cache = builder.prefill(
            builder.model_inputs,
            slice_len=128,
        )
        if cache_manager.save_kvcache(
            target_dir="./kvcache",
            system_prompt=builder.system_prompt,
            precompute_len=builder.system_input_ids_len,
            k_caches=update_kv_cache[0],
            v_caches=update_kv_cache[1],
            metadata={"model_version": "v0.1"}
        ):
            logger.info(">>> Precomputed system prompt kvcache saved to ./kvcache; the next launch can load it directly <<<")
        else:
            logger.error(">>> Failed to save the kvcache, exiting! <<<")
            exit()
    else:
        update_kv_cache, prompt, plen, meta = cache_manager.load_kvcache("./kvcache")
        builder.update_kvcache(update_kv_cache)
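        # Sanity check (hypothetical addition, not in the original): `prompt`
        # is the system prompt the cache was built from; a cached prefix that
        # no longer matches the current one would poison generation.
        if prompt != builder.system_prompt:
            logger.warning("Cached system prompt differs from the current one; delete ./kvcache to rebuild it")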

    logger.debug(">>> Creating LlamaChatSession >>>")
    session = LlamaChatSession(builder_instance=builder)
    session.chat_loop(live_print=False)