import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import uvicorn

from transformers import T5ForConditionalGeneration, T5Tokenizer
from peft import PeftModel

# --- Configurations ---
BASE_MODEL_NAME = "google/flan-t5-large"
OUTPUT_DIR = "./lora_t5xl_finetuned_8bit/checkpoint-5745"  # Path to your fine-tuned LoRA adapter
MAX_SOURCE_LENGTH = 1024
MAX_TARGET_LENGTH = 1024

# --- Load Tokenizer and Base Model ---
tokenizer = T5Tokenizer.from_pretrained(BASE_MODEL_NAME)
base_model = T5ForConditionalGeneration.from_pretrained(
    BASE_MODEL_NAME,
    device_map="auto",
    low_cpu_mem_usage=True
)

# --- Load Fine-Tuned LoRA Adapter ---
model = PeftModel.from_pretrained(base_model, OUTPUT_DIR)
model.eval()  # Set the model to evaluation mode

# --- Inference Function ---
def generate_text(prompt: str) -> str:
    """
    Given an input prompt, generate text using the fine-tuned T5-large LoRA model.
    """
    input_text = "Humanize this text to be undetectable: " + prompt
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_SOURCE_LENGTH,
        padding="max_length"
    )
    # Move inputs to the same device as the model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    # Generate output (adjust generation parameters as needed)
    outputs = model.generate(
        **inputs,
        max_length=MAX_TARGET_LENGTH,
        do_sample=True,
        top_p=0.95,
        temperature=0.9,
        num_return_sequences=1
    )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text
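
# Example (illustrative only):
#   generate_text("Large language models often produce text with uniform sentence structure.")
# returns the model's rewritten text as a plain string, capped at MAX_TARGET_LENGTH tokens.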

# --- FastAPI Setup ---
app = FastAPI()

@app.post("/predict")
async def predict(request: Request):
    """
    Expects a JSON payload with a "prompt" field.
    Returns the generated text.
    """
    data = await request.json()
    prompt = data.get("prompt", "")
    if not prompt:
        return JSONResponse(status_code=400, content={"error": "No prompt provided."})
    
    output_text = generate_text(prompt)
    return {"generated_text": output_text}

# --- For Local Testing ---
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
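
# Note: outside of local testing, the server is typically started with the uvicorn CLI
# instead of the __main__ block above, e.g. (assuming this file is saved as app.py,
# a hypothetical name):
#   uvicorn app:app --host 0.0.0.0 --port 8000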