from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import torch
import locale
import os

# Force a UTF-8 locale so text encoding and decoding behave consistently
locale.setlocale(locale.LC_ALL, 'en_US.UTF-8')

# Load the tokenizer for the base model
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

# Quantization config: 4-bit NF4 with double quantization and bfloat16 compute
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the base model with the quantization config; device_map places it on the available GPU(s)
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    quantization_config=bnb_config,
    device_map="auto",
)
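
# Optional sanity check (purely illustrative, safe to remove): report the loaded
# model's memory footprint to confirm that quantization is actually in effect.
print(f"Model memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")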
# Attach the fine-tuned PEFT adapter (saved under adapter_config_dir) to the quantized base model
adapter_config_dir = "adapter_config"
model = PeftModel.from_pretrained(model, adapter_config_dir)
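
# Optional smoke test before serving: generate once to confirm the adapter-loaded
# model runs end to end (the prompt below is only an example)
# test_inputs = tokenizer("What is PEFT?", return_tensors="pt").to(model.device)
# print(tokenizer.decode(model.generate(**test_inputs, max_new_tokens=32)[0], skip_special_tokens=True))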

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()


class Question(BaseModel):
    question: str


class Answer(BaseModel):
    answer: str


@app.post("/ask", response_model=Answer)
async def ask_question(question: Question):
    try:
        # Tokenize the question and move the tensors to the model's device
        inputs = tokenizer(question.question, return_tensors="pt").to(model.device)
        # Cap the response length; adjust max_new_tokens as needed
        outputs = model.generate(**inputs, max_new_tokens=256)
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return Answer(answer=answer)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
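
# To serve the model (assuming this file is saved as app.py; the module name and
# port are illustrative), run:
#   uvicorn app:app --host 0.0.0.0 --port 8000
# and query it with, for example:
#   curl -X POST http://localhost:8000/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is parameter-efficient fine-tuning?"}'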