import os
import sys
import json
import argparse
import dataclasses
from copy import deepcopy
from pathlib import Path
from typing import List, Tuple, Dict, Optional

import numpy as np
import torch
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
from ml_dtypes import bfloat16
from tqdm import tqdm
from loguru import logger
from transformers import AutoTokenizer, AutoConfig
from axengine import InferenceSession

from utils.infer_func import InferManager, KVCacheTools


class LlamaChatSession:
    def __init__(self, builder_instance):
        self.system_prompt = builder_instance.system_prompt
        self.builder_instance = builder_instance
        self.last_reply = ""

    def encode(self, prompt: str) -> Dict:
        """
        Returns the model inputs built by InferManager.encoder_prompt.
        keys: "message", "model_inputs", "input_ids", "input_embeds", "input_ids_len"
        """
        return self.builder_instance.encoder_prompt(prompt)

    def get_kvcache(self) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        return self.builder_instance.k_caches, self.builder_instance.v_caches

    def generate(self, model_inputs):
        # InferManager.decode streams the reply as it is generated;
        # its return value (if any) is passed through unchanged.
        token_ids = model_inputs["input_ids"]
        return self.builder_instance.decode(token_ids)

    def run(self, model_inputs):
        return self.generate(model_inputs)

    def reset_context(self, system_prompt: str = None):
        """
        Resetting only requires clearing the kv cache (it could even just be
        overwritten in place). If the system_prompt changes, however, the kv
        cache has to be recomputed.
        """
        if system_prompt is not None:
            self.system_prompt = system_prompt
        self.builder_instance.precompute_len = self.builder_instance.system_input_ids_len
        # Zero out everything after the precomputed system-prompt portion of the cache.
        for i in range(len(self.builder_instance.k_caches)):
            self.builder_instance.k_caches[i][:, self.builder_instance.precompute_len:, :].fill(0)
            self.builder_instance.v_caches[i][:, self.builder_instance.precompute_len:, :].fill(0)

    def chat_loop(self, live_print: bool = False):
        if self.system_prompt:
            print(f">>> System prompt: {self.system_prompt}")
        logger.info("Type 'q' to exit, Ctrl+c to stop current generation\n")

        while True:
            try:
                prompt = input("prompt (type 'q' to quit) >> ")
                if prompt.lower() in ("q", "exit"):
                    print("\nOK, exited the chat.")
                    return
                if prompt.lower() == "debug":
                    print(f"\n>>> DEBUG INFO >>>\n precompute_len is {self.builder_instance.precompute_len}\n<<< DEBUG INFO <<<\n")
                    continue
                if not prompt.strip():
                    print(f"\n{self.system_prompt}")
                    continue
                if prompt.strip() == "reset":
                    self.reset_context()
                    print("Context has been reset")
                    continue

                model_inputs = self.encode(prompt)

                # One more 128-token prefill slice would overflow the context window.
                if self.builder_instance.precompute_len + 128 >= 2559:
                    logger.error("Context window is full! Use the `reset` command to reset the context")
                    continue

                response = self.run(model_inputs)

            except KeyboardInterrupt:
                # The user pressed Ctrl+C to interrupt generation.
                print("\nOK, exited the chat.")
                sys.exit()
            except Exception as e:
                print(f"ERROR: {str(e)}")


if __name__ == "__main__":
    hf_model_path = './Qwen2.5-1.5B-Instruct-GPTQ-Int8/'
    axmodel_model_path = './Qwen2.5-1.5B-Instruct-GPTQ-Int8_axmodel/'

    builder = InferManager(hf_model_path, axmodel_model_path)
    # Init tokenizer, hf_config and system prompt.
    builder.build_system_prompt()
    builder.build_kvcache()
    builder.build_infer_model()

    cache_manager = KVCacheTools(axmodel_num=28, dtype=bfloat16)

    if not os.path.exists("./kvcache"):
        # Precompute the k/v cache for the system prompt.
        update_kv_cache = builder.prefill(
            builder.model_inputs,
            slice_len=128,
        )

        if cache_manager.save_kvcache(
            target_dir="./kvcache",
            system_prompt=builder.system_prompt,
            precompute_len=builder.system_input_ids_len,
            k_caches=update_kv_cache[0],
            v_caches=update_kv_cache[1],
            metadata={"model_version": "v0.1"}
        ):
            logger.info(">>> Precomputed system-prompt kvcache saved to ./kvcache; it will be loaded directly on the next startup <<<")
        else:
            logger.error(">>> Failed to save the kvcache, exiting! <<<")
            sys.exit()
    else:
        # Reuse the previously saved system-prompt kv cache.
        update_kv_cache, prompt, plen, meta = cache_manager.load_kvcache("./kvcache")
        builder.update_kvcache(update_kv_cache)

    logger.debug(">>> Creating LlamaChatSession >>>")
    session = LlamaChatSession(builder_instance=builder)
    session.chat_loop(live_print=False)