from fastapi import FastAPI, HTTPException
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

app = FastAPI()


# Load model once at startup
@app.on_event("startup")
async def load_model():
    try:
        # Configuration
        model_name = "unsloth/deepseek-r1-distill-llama-8b-unsloth-bnb-4bit"
        adapter_name = "LAWSA07/medical_fine_tuned_deepseekR1"

        # Load base model with 4-bit quantization
        app.state.base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            load_in_4bit=True,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )

        # Attach the PEFT (LoRA) adapter on top of the quantized base model
        app.state.model = PeftModel.from_pretrained(
            app.state.base_model,
            adapter_name,
        )

        # Load tokenizer; make sure a pad token is defined, since
        # Llama-family tokenizers often ship without one and padding
        # would otherwise raise an error
        app.state.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if app.state.tokenizer.pad_token is None:
            app.state.tokenizer.pad_token = app.state.tokenizer.eos_token
    except Exception as e:
        # Startup runs outside any request context, so raise a plain
        # error (not an HTTPException) to abort application startup
        raise RuntimeError(f"Model loading failed: {e}") from e


@app.get("/")
def health_check():
    return {"status": "OK"}


@app.post("/generate")
async def generate_text(prompt: str, max_length: int = 200):
    try:
        # prompt and max_length are plain function parameters, so FastAPI
        # reads them from the query string of the POST request
        inputs = app.state.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
        ).to("cuda")

        outputs = app.state.model.generate(
            **inputs,
            max_length=max_length,
            temperature=0.7,
            do_sample=True,
        )

        decoded = app.state.tokenizer.decode(
            outputs[0],
            skip_special_tokens=True,
        )
        return {"response": decoded}
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Generation failed: {str(e)}",
        )
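
# --- Usage sketch ---------------------------------------------------------
# A minimal way to run and exercise this server, assuming this file is saved
# as main.py (a hypothetical filename) and a CUDA GPU is available:
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Because /generate declares prompt and max_length as simple parameters,
# FastAPI expects them as query-string arguments on the POST request:
#
#   curl -X POST "http://localhost:8000/generate?prompt=What%20causes%20anemia%3F&max_length=200"
#
# Health check:
#
#   curl http://localhost:8000/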