import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import uvicorn
from transformers import T5ForConditionalGeneration, T5Tokenizer
from peft import PeftModel
# --- Configurations ---
BASE_MODEL_NAME = "google/flan-t5-large"
OUTPUT_DIR = "./lora_t5xl_finetuned_8bit/checkpoint-5745" # Path to your fine-tuned LoRA adapter
MAX_SOURCE_LENGTH = 1024
MAX_TARGET_LENGTH = 1024
# --- Load Tokenizer and Base Model ---
tokenizer = T5Tokenizer.from_pretrained(BASE_MODEL_NAME)  # low_cpu_mem_usage is a model kwarg, not a tokenizer kwarg
base_model = T5ForConditionalGeneration.from_pretrained(
    BASE_MODEL_NAME,
    device_map="auto",
    low_cpu_mem_usage=True,
)
# --- Load Fine-Tuned LoRA Adapter ---
model = PeftModel.from_pretrained(base_model, OUTPUT_DIR)
model.eval() # Set the model to evaluation mode
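# Optional sketch: if the adapter never needs to be swapped or detached at
# runtime, PEFT can fold the LoRA weights into the base model for slightly
# faster inference. Left commented out as an assumption about your deployment.
# model = model.merge_and_unload()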
# --- Inference Function ---
def generate_text(prompt: str) -> str:
    """
    Given an input prompt, generate text using the fine-tuned FLAN-T5-large LoRA model.
    """
    input_text = "Humanize this text to be undetectable: " + prompt
    # No padding needed for a single sequence; just truncate to the source limit
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_SOURCE_LENGTH,
    )
    # Move inputs to the same device as the model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    # Generate under no_grad to avoid tracking gradients at inference time
    # (adjust generation parameters as needed)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=MAX_TARGET_LENGTH,
            do_sample=True,
            top_p=0.95,
            temperature=0.9,
            num_return_sequences=1,
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text
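# Example call (hypothetical input), assuming the model loaded successfully:
#   generate_text("The results of the experiment were conclusive.")
# Sampling is enabled above, so repeated calls return different paraphrases.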
# --- FastAPI Setup ---
app = FastAPI()
@app.post("/predict")
async def predict(request: Request):
    """
    Expects a JSON payload with a "prompt" field.
    Returns the generated text.
    """
    data = await request.json()
    prompt = data.get("prompt", "")
    if not prompt:
        return JSONResponse(status_code=400, content={"error": "No prompt provided."})
    output_text = generate_text(prompt)
    return {"generated_text": output_text}
# --- For Local Testing ---
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
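# Example request against the local server (hypothetical payload):
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "The results of the experiment were conclusive."}'
# Response shape: {"generated_text": "..."}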