"""FastAPI service that returns a generated story, its questions, and an image."""

import os

# Redirect the Hugging Face caches to /tmp (the original note says this fixes
# cache write errors). These assignments must run BEFORE diffusers (and
# presumably story_generator) are imported: huggingface_hub resolves its cache
# locations from the environment when it is first imported, so setting the
# variables afterwards has no effect on where models are stored.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface"

import base64  # noqa: E402
import io  # noqa: E402

import torch  # noqa: E402
from diffusers import DiffusionPipeline  # noqa: E402
from fastapi import FastAPI, HTTPException  # noqa: E402
from PIL import Image  # noqa: E402,F401
from pydantic import BaseModel  # noqa: E402

import story_generator  # noqa: E402

app = FastAPI()

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Public image generation model loaded once at import time and shared by
# every request.
IMAGE_MODEL = "runwayml/stable-diffusion-v1-5"

# Half precision is only requested on CUDA; the CPU path keeps float32.
pipeline = DiffusionPipeline.from_pretrained(
    IMAGE_MODEL,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)


class StoryRequest(BaseModel):
    """Request body for the story/questions/image endpoint."""

    # Theme of the story; also used verbatim as the image prompt.
    theme: str
    # Reading level forwarded unchanged to the story generator.
    reading_level: str


@app.post("/generate_story_questions_images")
def generate_story_questions_images(request: StoryRequest):
    """Generates a story, questions, and corresponding images.

    The response echoes ``theme`` and ``reading_level`` and adds ``story``,
    ``questions`` and ``image`` (a base64-encoded PNG). Any failure is
    reported as an HTTP 500 whose detail carries the error message.
    """
    try:
        print(f"🎭 Generating story for theme: {request.theme} and level: {request.reading_level}")

        # Fail with an explicit message if the project module does not expose
        # the expected entry point.
        if not hasattr(story_generator, "generate_story_and_questions"):
            raise HTTPException(
                status_code=500,
                detail="Story generation function not found in story_generator.py",
            )

        # The generator result is indexed by the "story" and "questions" keys.
        story_result = story_generator.generate_story_and_questions(
            request.theme, request.reading_level
        )
        story_text = story_result["story"]
        questions = story_result["questions"]

        # Generate a single image using the theme as the prompt.
        print(f"🖼️ Generating image for: {request.theme}")
        image = pipeline(prompt=request.theme, num_inference_steps=5).images[0]

        # Serialize the image to PNG in memory and base64-encode it so it can
        # travel inside the JSON response.
        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format="PNG")
        img_base64 = base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")

        return {
            "theme": request.theme,
            "reading_level": request.reading_level,
            "story": story_text,
            "questions": questions,
            "image": img_base64,
        }
    except HTTPException:
        # Already a well-formed HTTP error: pass it through untouched instead
        # of letting the generic handler below re-wrap it via str(e).
        raise
    except Exception as e:
        print(f"❌ Error generating story, questions, or images: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/")
def home():
    """Return a welcome message that points at the generation endpoint."""
    return {"message": "🎉 Welcome to the Story, Question & Image Generation API! Use /generate_story_questions_images"}