# rag_chatbot / app.py
# Last commit d62ccf6 by InspirationYF: "feat: add env config"
import os
import torch
import spaces
import gradio as gr
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer
import env_config
# Runtime settings are read from the local env_config module.
api_token, max_new_tokens, model_id = (
    env_config.HF_API_TOKEN,
    env_config.MAX_NEW_TOKENS,
    env_config.MODEL_ID,
)
# Authenticate with the Hugging Face Hub using the configured token.
login(api_token)
# Model loading helper.
def get_llm(model_id):
    """Load and return the causal LM ``model_id`` via transformers.

    ``device_map="auto"`` delegates weight placement (GPU/CPU) to
    transformers/accelerate instead of choosing a device here.
    """
    return AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
def _read_first_line(path):
    """Return the first line of the text file at ``path`` ("" if the file is empty).

    Decodes as UTF-8, replacing undecodable bytes, so the result does not depend
    on the platform's default locale encoding.
    """
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        return f.readline()


# Question-answering handler wired to the Gradio interface.
@spaces.GPU(duration=120)
def retriever_qa(file, query):
    """Answer ``query`` using the first line of the uploaded text file as context.

    No retrieval step is performed: only the first line of the uploaded file is
    prepended to the query and sent to the chat model as a single user turn.

    Args:
        file: Path to the uploaded ``.txt`` file (Gradio ``type="filepath"``),
            or ``None`` when nothing was uploaded.
        query: The user's question.

    Returns:
        The model's generated answer as text, without the echoed prompt.

    Raises:
        gr.Error: If no file was uploaded.
    """
    if file is None:
        # Gradio passes None for an empty file input; surface a readable message
        # instead of a TypeError from open(None).
        raise gr.Error("Please upload a TXT file.")

    # NOTE(review): the tokenizer and model are reloaded on every request, which
    # is slow for large models — consider loading them once if the runtime allows.
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    llm = get_llm(model_id)
    # With device_map="auto" the model decides its own placement; send inputs to
    # the device reported by the model rather than guessing one.
    device = llm.device
    print(f'Device: {device}')

    context = _read_first_line(file).rstrip("\n")
    # Keep the document line and the question on separate lines even when the
    # file has no trailing newline.
    content = f"{context}\n{query}" if context else query
    messages = [
        {"role": "user", "content": content}
    ]
    print(messages)

    # add_generation_prompt=True appends the assistant-turn header for chat
    # templates that define one, so the model answers instead of continuing the
    # user turn.
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    print(f"Model Inputs: {input_ids}")

    print('Start Inference')
    generated_ids = llm.generate(
        input_ids,
        # Single unpadded sequence: every position is a real token.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )
    print(f'Generated ids: {generated_ids}')

    print('Start detokenize')
    # generate() returns prompt + completion for decoder-only models; decode only
    # the newly generated tokens so the answer does not echo the prompt.
    new_tokens = generated_ids[:, input_ids.shape[-1]:]
    response = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
    print(response)
    return response
# Gradio UI: a single TXT upload plus a free-text question, answered in a textbox.
file_input = gr.File(
    label="Upload txt File",
    file_count="single",
    file_types=['.txt'],  # only plain-text uploads are accepted
    type="filepath",
)
query_input = gr.Textbox(
    label="Input Query",
    lines=2,
    placeholder="Type your question here...",
)
answer_output = gr.Textbox(label="Output")

rag_application = gr.Interface(
    fn=retriever_qa,
    allow_flagging="never",
    inputs=[file_input, query_input],
    outputs=answer_output,
    title="RAG Chatbot",
    description="Upload a TXT document and ask any question. The chatbot will try to answer using the provided document.",
)

# Start the Gradio app (with a public share link requested).
rag_application.launch(share=True)