import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import uvicorn
from transformers import T5ForConditionalGeneration, T5Tokenizer
from peft import PeftModel

# --- Configuration ---
BASE_MODEL_NAME = "google/flan-t5-large"
OUTPUT_DIR = "./lora_t5xl_finetuned_8bit/checkpoint-5745"  # Path to your fine-tuned LoRA adapter
MAX_SOURCE_LENGTH = 1024
MAX_TARGET_LENGTH = 1024

# --- Load Tokenizer and Base Model ---
tokenizer = T5Tokenizer.from_pretrained(BASE_MODEL_NAME)
base_model = T5ForConditionalGeneration.from_pretrained(
    BASE_MODEL_NAME,
    device_map="auto",
    low_cpu_mem_usage=True
)

# --- Load Fine-Tuned LoRA Adapter ---
model = PeftModel.from_pretrained(base_model, OUTPUT_DIR)
model.eval()  # Set the model to evaluation mode

# --- Inference Function ---
def generate_text(prompt: str) -> str:
    """
    Given an input prompt, generate text using the fine-tuned Flan-T5-large LoRA model.
    """
    input_text = "Humanize this text to be undetectable: " + prompt
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_SOURCE_LENGTH,
        padding="max_length"
    )
    # Move inputs to the same device as the model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Generate output without tracking gradients (adjust generation parameters as needed)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=MAX_TARGET_LENGTH,
            do_sample=True,
            top_p=0.95,
            temperature=0.9,
            num_return_sequences=1
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text

# --- FastAPI Setup ---
app = FastAPI()

@app.post("/predict")
async def predict(request: Request):
    """
    Expects a JSON payload with a "prompt" field. Returns the generated text.
    """
    data = await request.json()
    prompt = data.get("prompt", "")
    if not prompt:
        return JSONResponse(status_code=400, content={"error": "No prompt provided."})
    output_text = generate_text(prompt)
    return {"generated_text": output_text}

# --- For Local Testing ---
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
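
# --- Example Client Call (illustrative) ---
# A minimal sketch of how a client might query the /predict endpoint once the
# server above is running locally on port 8000. The use of the `requests`
# library and the sample prompt are assumptions for illustration; they are not
# part of the server script itself.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/predict",
#       json={"prompt": "The quarterly report indicates strong growth."},
#   )
#   print(resp.json()["generated_text"])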