danibor committed on
Commit 4e55ab7 · verified · 1 Parent(s): 682ab73

Create handler.py

Files changed (1)
1. handler.py +74 -0
handler.py ADDED
@@ -0,0 +1,74 @@
+ import torch
+ from fastapi import FastAPI, Request
+ from fastapi.responses import JSONResponse
+ import uvicorn
+
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
+ from peft import PeftModel
+
+ # --- Configurations ---
+ BASE_MODEL_NAME = "google/flan-t5-large"
+ OUTPUT_DIR = "./lora_t5xl_finetuned_8bit/checkpoint-5745"  # Path to your fine-tuned LoRA adapter
+ MAX_SOURCE_LENGTH = 1024
+ MAX_TARGET_LENGTH = 1024
+
+ # --- Load Tokenizer and Base Model ---
+ tokenizer = T5Tokenizer.from_pretrained(BASE_MODEL_NAME)
+ base_model = T5ForConditionalGeneration.from_pretrained(
+     BASE_MODEL_NAME,
+     device_map="auto",
+     low_cpu_mem_usage=True
+ )
+
+ # --- Load Fine-Tuned LoRA Adapter ---
+ model = PeftModel.from_pretrained(base_model, OUTPUT_DIR)
+ model.eval()  # Set the model to evaluation mode
+
+ # --- Inference Function ---
+ def generate_text(prompt: str) -> str:
+     """
+     Given an input prompt, generate text using the fine-tuned T5-large LoRA model.
+     """
+     input_text = "Humanize this text to be undetectable: " + prompt
+     inputs = tokenizer(
+         input_text,
+         return_tensors="pt",
+         truncation=True,
+         max_length=MAX_SOURCE_LENGTH,
+         padding="max_length"
+     )
+     # Move inputs to the same device as the model
+     inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+     # Generate output (adjust generation parameters as needed)
+     outputs = model.generate(
+         **inputs,
+         max_length=MAX_TARGET_LENGTH,
+         do_sample=True,
+         top_p=0.95,
+         temperature=0.9,
+         num_return_sequences=1
+     )
+     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return generated_text
+
+ # --- FastAPI Setup ---
+ app = FastAPI()
+
+ @app.post("/predict")
+ async def predict(request: Request):
+     """
+     Expects a JSON payload with a "prompt" field.
+     Returns the generated text.
+     """
+     data = await request.json()
+     prompt = data.get("prompt", "")
+     if not prompt:
+         return JSONResponse(status_code=400, content={"error": "No prompt provided."})
+
+     output_text = generate_text(prompt)
+     return {"generated_text": output_text}
+
+ # --- For Local Testing ---
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=8000)