import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import T5ForConditionalGeneration, T5Tokenizer
from peft import PeftModel


# Base model, fine-tuned LoRA checkpoint, and sequence-length limits.
BASE_MODEL_NAME = "google/flan-t5-large"
OUTPUT_DIR = "./lora_t5xl_finetuned_8bit/checkpoint-5745"
MAX_SOURCE_LENGTH = 1024
MAX_TARGET_LENGTH = 1024

# Load the tokenizer and the base model. low_cpu_mem_usage is a model-loading
# option (not a tokenizer argument), and device_map="auto" lets accelerate
# place the weights on the available device(s).
tokenizer = T5Tokenizer.from_pretrained(BASE_MODEL_NAME)
base_model = T5ForConditionalGeneration.from_pretrained(
    BASE_MODEL_NAME,
    device_map="auto",
    low_cpu_mem_usage=True,
)

# Wrap the base model with the fine-tuned LoRA adapter and switch to eval mode.
model = PeftModel.from_pretrained(base_model, OUTPUT_DIR)
model.eval()
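# Optional: the adapter weights can usually be folded into the base model with
# PEFT's merge_and_unload() for slightly faster inference; the adapter-wrapped
# model above also works as-is.
# model = model.merge_and_unload()
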
def generate_text(prompt: str) -> str:
    """
    Given an input prompt, generate text with the fine-tuned FLAN-T5 LoRA model.
    """
    input_text = "Humanize this text to be undetectable: " + prompt
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_SOURCE_LENGTH,
        padding="max_length",
    )

    # Move the input tensors to the same device as the model.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Sample one output sequence; torch.no_grad() skips building autograd
    # state, which saves memory during inference.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=MAX_TARGET_LENGTH,
            do_sample=True,
            top_p=0.95,
            temperature=0.9,
            num_return_sequences=1,
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text

app = FastAPI()


@app.post("/predict")
async def predict(request: Request):
    """
    Expects a JSON payload with a "prompt" field.
    Returns the generated text.
    """
    data = await request.json()
    prompt = data.get("prompt", "")
    if not prompt:
        return JSONResponse(status_code=400, content={"error": "No prompt provided."})

    output_text = generate_text(prompt)
    return {"generated_text": output_text}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
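# Example client call (a minimal sketch, assuming the server above is running
# locally on port 8000 and that the `requests` package is installed):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/predict",
#       json={"prompt": "Some machine-generated paragraph to rewrite."},
#   )
#   print(resp.json()["generated_text"])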