import spaces  # imported before torch so it can patch CUDA initialization on ZeroGPU hardware
import torch
import gradio as gr
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

import env_config

api_token = env_config.HF_API_TOKEN
max_new_tokens = env_config.MAX_NEW_TOKENS
model_id = env_config.MODEL_ID
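# `env_config` is assumed to be a small sibling module that centralizes
# settings. A minimal sketch of what it might contain (names inferred from the
# attributes read above; the values are illustrative assumptions):
#
#     # env_config.py
#     import os
#     HF_API_TOKEN = os.environ["HF_API_TOKEN"]  # assumption: token supplied via environment/Space secret
#     MAX_NEW_TOKENS = 256                       # assumption: any sensible generation budget
#     MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"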
# Log in to the Hugging Face Hub (needed when the model repo is gated)
login(api_token)
# Model loading helper
def get_llm(model_id):
    # `device_map="auto"` lets accelerate place the weights across the available devices
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    return model
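# Note: without an explicit `torch_dtype`, `from_pretrained` typically loads the
# weights in float32. For a 7B model, half precision roughly halves the memory
# footprint; a hedged variant using the same transformers API:
#
#     model = AutoModelForCausalLM.from_pretrained(
#         model_id, device_map="auto", torch_dtype=torch.float16
#     )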
# Question-answering logic
def retriever_qa(file, query):
    # Load the tokenizer (the model itself is loaded inside the helper below)
    # model_id = 'mistralai/Mistral-7B-Instruct-v0.2'
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    # Pick the device here so CUDA is initialized inside the request handler, not the main thread
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Device: {device}')
    # Model loading and inference are kept inside this inner helper
    def process_inference(file, query):
        # Load the model
        llm = get_llm(model_id)
        # Read the first line of the uploaded file (only this line is used as context)
        with open(file, 'r') as f:
            first_line = f.readline()
        # Build the chat prompt
        messages = [
            {"role": "user", "content": first_line + query}
        ]
        print(messages)
        # Tokenize through the model's chat template; `add_generation_prompt=True`
        # makes the template end with the assistant cue so generation continues the answer
        model_inputs = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(device)
        print(f"Model Inputs: {model_inputs}")
        print('Start Inference')
        # Generate a completion
        generated_ids = llm.generate(model_inputs, max_new_tokens=max_new_tokens, do_sample=True)
        # generated_ids = llm.generate(input_ids=model_inputs['input_ids'], attention_mask=model_inputs['attention_mask'], max_new_tokens=50, do_sample=True)
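        # `do_sample=True` with no further arguments samples with the model's
        # default temperature/top-p. A hedged variant pinning them explicitly
        # (both are standard `generate` kwargs; the values are illustrative):
        #
        #     generated_ids = llm.generate(
        #         model_inputs, max_new_tokens=max_new_tokens,
        #         do_sample=True, temperature=0.7, top_p=0.9,
        #     )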
        print(f'Generated ids: {generated_ids}')
        # Decode the output, slicing off the prompt tokens so only the new answer is returned
        print('Start detokenize')
        response = tokenizer.batch_decode(
            generated_ids[:, model_inputs.shape[-1]:], skip_special_tokens=True
        )[0]
        print(response)
        return response

    # Run the inference helper and hand its answer back to Gradio
    response = process_inference(file, query)
    return response
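# Design note: the tokenizer and the full model are re-loaded on every request,
# so each query pays the multi-gigabyte load time again. A common alternative is
# to load them once at import time and close over them (a sketch, assuming the
# Space has enough memory to keep the model resident between requests):
#
#     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
#     llm = get_llm(model_id)
#
#     def retriever_qa(file, query):
#         ...  # use the module-level `tokenizer` and `llm`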
# Gradio UI
rag_application = gr.Interface(
    fn=retriever_qa,
    allow_flagging="never",
    inputs=[
        gr.File(label="Upload txt File", file_count="single", file_types=['.txt'], type="filepath"),  # TXT files only
        gr.Textbox(label="Input Query", lines=2, placeholder="Type your question here...")  # query input box
    ],
    outputs=gr.Textbox(label="Output"),  # output display box
    title="RAG Chatbot",
    description="Upload a TXT document and ask any question. The chatbot will try to answer using the provided document."
)
# Launch the Gradio app
rag_application.launch(share=True)
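# For a self-contained Space, a requirements.txt along these lines is assumed
# to sit next to this script (package names are real; no pins shown):
#
#     torch
#     transformers
#     accelerate        # required for device_map="auto"
#     gradio
#     spaces
#     huggingface_hub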