# medical_model/app.py
from fastapi import FastAPI, HTTPException
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import torch
app = FastAPI()
# Load model once at startup
@app.on_event("startup")
async def load_model():
    try:
        # Configuration
        model_name = "unsloth/deepseek-r1-distill-llama-8b-unsloth-bnb-4bit"
        adapter_id = "LAWSA07/medical_fine_tuned_deepseekR1"

        # Load the base model with 4-bit (bitsandbytes) quantization
        app.state.base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )

        # Attach the PEFT adapter; the adapter weights
        # (adapter_model.safetensors) are resolved from the adapter repo automatically
        app.state.model = PeftModel.from_pretrained(
            app.state.base_model,
            adapter_id,
        )

        # Load the tokenizer and make sure it can pad
        app.state.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if app.state.tokenizer.pad_token is None:
            app.state.tokenizer.pad_token = app.state.tokenizer.eos_token
    except Exception as e:
        # Failing startup should abort the app; HTTPException belongs in
        # request handlers, not in a startup hook
        raise RuntimeError(f"Model loading failed: {e}") from e
@app.get("/")
def health_check():
    return {"status": "OK"}
@app.post("/generate")
async def generate_text(prompt: str, max_length: int = 200):
    try:
        # Tokenize the prompt and move it to the model's device
        inputs = app.state.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True
        ).to(app.state.model.device)

        # Sampled generation; max_length caps prompt + completion tokens
        outputs = app.state.model.generate(
            **inputs,
            max_length=max_length,
            temperature=0.7,
            do_sample=True
        )

        decoded = app.state.tokenizer.decode(
            outputs[0],
            skip_special_tokens=True
        )
        return {"response": decoded}
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Generation failed: {str(e)}"
        )
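
# Local run / usage sketch (assumption: the service listens on port 7860, the
# Hugging Face Spaces default; adjust as needed). Because `prompt` and
# `max_length` are plain function parameters, FastAPI reads them from the
# query string, e.g.:
#
#   curl -X POST "http://localhost:7860/generate?prompt=What%20is%20hypertension%3F&max_length=200"
#
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)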