LAWSA07 committed on
Commit
f0a5521
·
verified ·
1 Parent(s): 3e7f095

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -0
app.py CHANGED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from peft import PeftModel
4
+ import torch
5
+
6
+ app = FastAPI()
7
+
8
+ # Load model once at startup
9
@app.on_event("startup")
async def load_model():
    """Load the base model, its PEFT adapter and the tokenizer at startup.

    The loaded objects are stored on ``app.state`` as ``base_model``,
    ``model`` (the base model wrapped with the adapter) and ``tokenizer`` so
    that request handlers can reuse them instead of reloading per request.

    Raises:
        RuntimeError: If loading the model, adapter or tokenizer fails. The
            underlying exception is chained as the cause.
    """
    # Configuration
    model_name = "unsloth/deepseek-r1-distill-llama-8b-unsloth-bnb-4bit"
    adapter_name = "LAWSA07/medical_fine_tuned_deepseekR1"

    try:
        # Load base model with 4-bit quantization.
        # NOTE(review): trust_remote_code=True allows code shipped in the
        # model repository to run locally; confirm this is intended.
        app.state.base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            load_in_4bit=True,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )

        # Attach PEFT adapter
        app.state.model = PeftModel.from_pretrained(
            app.state.base_model,
            adapter_name,
            adapter_weight_name="adapter_model.safetensors",
        )

        # Load tokenizer
        app.state.tokenizer = AutoTokenizer.from_pretrained(model_name)
    except Exception as e:
        # There is no HTTP request during startup, so an HTTPException could
        # never be converted into a response. Raise a plain error instead so
        # the server aborts startup with a clear, chained traceback.
        raise RuntimeError(f"Model loading failed: {e}") from e
40
+
41
@app.get("/")
def health_check():
    """Liveness probe: always reports an ``"OK"`` status."""
    payload = {"status": "OK"}
    return payload
44
+
45
@app.post("/generate")
async def generate_text(prompt: str, max_length: int = 200):
    """Generate text for ``prompt`` using the model stored on ``app.state``.

    Args:
        prompt: Text that is tokenized and passed to the model.
        max_length: Forwarded to ``generate`` as ``max_length``.
            NOTE(review): in transformers this limit presumably counts the
            prompt tokens as well as the generated ones; confirm that is the
            intended meaning for API clients.

    Returns:
        A dict ``{"response": text}`` where ``text`` is the first returned
        sequence decoded with special tokens skipped.

    Raises:
        HTTPException: 400 if ``max_length`` is not positive; 500 if
            tokenization, generation or decoding fails.
    """
    # Validate outside the try block: the broad handler below would
    # otherwise re-wrap this client error as a 500.
    if max_length < 1:
        raise HTTPException(
            status_code=400,
            detail="max_length must be a positive integer",
        )

    try:
        model = app.state.model
        tokenizer = app.state.tokenizer

        # A single prompt needs no padding, so padding is not requested; this
        # also avoids a failure when the tokenizer defines no pad token.
        # Inputs follow the model's device rather than a hard-coded "cuda",
        # because the model is placed with device_map="auto".
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # NOTE(review): generate() is a blocking call inside an async
        # handler, so other requests wait while it runs.
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            temperature=0.7,
            do_sample=True,
        )

        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"response": decoded}
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Generation failed: {str(e)}",
        ) from e