# story-image-api / app.py

import os

# ✅ Set the Hugging Face cache directories to /tmp before any model loading
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface"

import io
import base64

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from diffusers import DiffusionPipeline
from huggingface_hub import login

import story_generator  # ✅ Local module that generates the story and questions

app = FastAPI()

# ✅ Hugging Face authentication (only needed if the model is private).
# Read the token from the environment instead of hardcoding it in the source.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# ✅ Load the image generation model (a fast, public model);
# swap in another fast SDXL variant here if needed.
IMAGE_MODEL = "stabilityai/sdxl-turbo"

pipeline = DiffusionPipeline.from_pretrained(
    IMAGE_MODEL,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    token=HF_TOKEN,  # only required for private models
).to("cuda" if torch.cuda.is_available() else "cpu")
# ✅ Define the input request format
class StoryRequest(BaseModel):
    theme: str
    reading_level: str
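
# Example request body for /generate_story_questions_images (values are illustrative):
# {"theme": "space adventure", "reading_level": "grade 2"}
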
@app.post("/generate_story_questions_images")
def generate_story_questions_images(request: StoryRequest):
    """Generates a story, comprehension questions, and a matching image per sentence."""
    try:
        print(f"🎭 Generating story for theme: {request.theme} and level: {request.reading_level}")

        # ✅ Generate the story and questions
        if not hasattr(story_generator, "generate_story_and_questions"):
            raise HTTPException(status_code=500, detail="Story generation function not found in story_generator.py")

        story_result = story_generator.generate_story_and_questions(request.theme, request.reading_level)
        story_text = story_result["story"]
        questions = story_result["questions"]

        # ✅ Split the story into sentences for image generation
        story_sentences = story_text.strip().split(". ")

        # ✅ Generate an image for each sentence
        images = []
        for sentence in story_sentences:
            if len(sentence) > 5:  # Skip empty or very short fragments
                print(f"🖼️ Generating image for: {sentence}")
                image = pipeline(prompt=sentence, num_inference_steps=5).images[0]

                # Convert the PIL image to a base64-encoded PNG
                img_byte_arr = io.BytesIO()
                image.save(img_byte_arr, format="PNG")
                img_base64 = base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")

                images.append({"sentence": sentence, "image": img_base64})

        # ✅ Return the full response
        return {
            "theme": request.theme,
            "reading_level": request.reading_level,
            "story": story_text,
            "questions": questions,
            "images": images,
        }
    except HTTPException:
        # Re-raise HTTP errors as-is instead of wrapping them in a generic 500
        raise
    except Exception as e:
        print(f"❌ Error generating story, questions, or images: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# ✅ Welcome message at root
@app.get("/")
def home():
    return {"message": "🎉 Welcome to the Story, Question & Image Generation API! Use /generate_story_questions_images"}