"""FastAPI service that generates a story, comprehension questions, and one image per sentence."""

import os

# ✅ Set Hugging Face cache directory to /tmp (Fixes cache write errors).
# These MUST be set before diffusers / transformers / huggingface_hub are imported:
# huggingface_hub resolves the cache paths into module-level constants at import
# time, so assigning them after the imports (as this file previously did) is ignored.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface"

import base64  # noqa: E402
import io  # noqa: E402
import re  # noqa: E402
import threading  # noqa: E402

import torch  # noqa: E402
from diffusers import DiffusionPipeline  # noqa: E402
from fastapi import FastAPI, HTTPException  # noqa: E402
from PIL import Image  # noqa: E402,F401  (kept from the original file; not used directly)
from pydantic import BaseModel  # noqa: E402

import story_generator  # noqa: E402  (may pull in transformers, so it also stays below the env setup)

app = FastAPI()

# ✅ Load Public Image Generation Model (No Token Needed)
IMAGE_MODEL = "stabilityai/sdxl-turbo"  # Fastest model for public access
IMAGE_INFERENCE_STEPS = 5
# SDXL-Turbo was trained without classifier-free guidance; its model card / the
# diffusers docs say to disable guidance with guidance_scale=0.0.
IMAGE_GUIDANCE_SCALE = 0.0
# Sentences this short (after trimming) are treated as fragments and not illustrated.
MIN_SENTENCE_LENGTH = 5

_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_DTYPE = torch.float16 if _DEVICE == "cuda" else torch.float32

pipeline = DiffusionPipeline.from_pretrained(IMAGE_MODEL, torch_dtype=_DTYPE).to(_DEVICE)

# FastAPI executes plain `def` endpoints in a worker thread pool, so several
# requests can reach the pipeline at once; its scheduler keeps per-call mutable
# state, so image generation is serialized through this lock.
_pipeline_lock = threading.Lock()

# A sentence boundary is whitespace (spaces or newlines) following ., ! or ?.
# The lookbehind keeps the punctuation attached to its sentence.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+")


# ✅ Define the input request format
class StoryRequest(BaseModel):
    theme: str
    reading_level: str


def _split_sentences(text):
    """Split *text* into trimmed sentences, dropping fragments of MIN_SENTENCE_LENGTH chars or fewer."""
    candidates = (part.strip() for part in _SENTENCE_BOUNDARY.split(text.strip()))
    return [sentence for sentence in candidates if len(sentence) > MIN_SENTENCE_LENGTH]


def _image_to_base64(image):
    """Encode a PIL image as a base64 PNG string."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _generate_image(prompt):
    """Run the shared diffusion pipeline for one prompt and return the first image."""
    with _pipeline_lock:
        result = pipeline(
            prompt=prompt,
            num_inference_steps=IMAGE_INFERENCE_STEPS,
            guidance_scale=IMAGE_GUIDANCE_SCALE,
        )
    return result.images[0]


@app.post("/generate_story_questions_images")
def generate_story_questions_images(request: StoryRequest):
    """Generates a story, questions, and corresponding images.

    Calls ``story_generator.generate_story_and_questions(theme, reading_level)``,
    which is expected to return a mapping with ``"story"`` and ``"questions"``
    keys. Then it generates one image per sentence of the story.

    Returns:
        A dict with ``theme``, ``reading_level``, ``story``, ``questions`` and
        ``images``. ``images`` is a list of ``{"sentence", "image"}`` dicts,
        where ``image`` is a base64-encoded PNG.

    Raises:
        HTTPException: 500 if the generator function is missing or if any step
            of story/image generation fails.
    """
    print(f"🎭 Generating story for theme: {request.theme} and level: {request.reading_level}")

    generate = getattr(story_generator, "generate_story_and_questions", None)
    if generate is None:
        raise HTTPException(status_code=500, detail="Story generation function not found in story_generator.py")

    try:
        # ✅ Generate the story and questions
        story_result = generate(request.theme, request.reading_level)
        story_text = story_result["story"]
        questions = story_result["questions"]

        # ✅ Generate an image for each sentence
        images = []
        for sentence in _split_sentences(story_text):
            print(f"🖼️ Generating image for: {sentence}")
            image = _generate_image(sentence)
            images.append({"sentence": sentence, "image": _image_to_base64(image)})
    except HTTPException:
        # Already a well-formed HTTP error; don't re-wrap it and lose its detail.
        raise
    except Exception as e:
        print(f"❌ Error generating story, questions, or images: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # ✅ Return the full response
    return {
        "theme": request.theme,
        "reading_level": request.reading_level,
        "story": story_text,
        "questions": questions,
        "images": images,
    }


# ✅ Welcome message at root
@app.get("/")
def home():
    return {"message": "🎉 Welcome to the Story, Question & Image Generation API! Use /generate_story_questions_images"}