# Spaces: Sleeping
import os | |
import time | |
import torch | |
from fastapi import FastAPI, HTTPException | |
from fastapi.responses import JSONResponse | |
from pydantic import BaseModel | |
import story_generator | |
from diffusers import DiffusionPipeline | |
from PIL import Image | |
import io | |
import base64 | |
app = FastAPI()

# Directory used for every Hugging Face cache this app touches.
# NOTE(review): presumably /tmp is chosen because it is writable in the
# deployment environment — confirm.
HF_CACHE_DIR = "/tmp/huggingface"

# Set Hugging Face cache directories
os.environ["HF_HOME"] = HF_CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE_DIR  # Deprecated but still used for now
os.environ["HF_HUB_CACHE"] = HF_CACHE_DIR

# Enable GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load image generation model
IMAGE_MODEL = "lykon/dreamshaper-8"
pipeline = DiffusionPipeline.from_pretrained(
    IMAGE_MODEL,
    # Half precision on GPU, full precision on CPU.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    # The environment variables above are assigned after the Hugging Face
    # libraries have already been imported, so anything that read them at
    # import time will not see the new values. Passing the cache directory
    # explicitly makes the download location independent of import order.
    cache_dir=HF_CACHE_DIR,
).to(device)
# Define request schema | |
class StoryRequest(BaseModel):
    """Request body accepted by the story generation handler."""

    # Topic the generated story should be about.
    theme: str
    # Target reading level, forwarded unchanged to the story generator.
    reading_level: str
def generate_images(prompts, max_retries=3, delay=2):
    """Generate one image per prompt, batched first and sequential as fallback.

    The whole prompt list is first sent to the module-level ``pipeline`` in a
    single batched call. If that call fails with an error whose message
    contains "index 16 is out of bounds", the batch is retried, up to
    ``max_retries`` attempts in total, waiting ``delay`` seconds between
    attempts. Any other error is re-raised immediately. When every batched
    attempt fails with that specific error, the prompts are rendered one at a
    time instead.

    Args:
        prompts: Sequence of text prompts, one per image.
        max_retries: Number of batched attempts before falling back.
        delay: Seconds to wait between batched attempts.

    Returns:
        A list of generated images in prompt order; an empty list when
        ``prompts`` is empty.

    Raises:
        Exception: Any pipeline error other than the retried one, or any
            error raised during sequential generation.
    """
    if not prompts:
        # Nothing to render; do not call the pipeline with an empty batch.
        return []

    for attempt in range(max_retries):
        try:
            print(f"Batched image generation attempt {attempt+1}...")
            return pipeline(
                prompt=prompts,
                num_inference_steps=15,
                height=768,
                width=768,
            ).images
        except Exception as e:
            # Only this specific failure is considered transient.
            if "index 16 is out of bounds" not in str(e):
                raise
            print(f"Attempt {attempt+1} failed with error: {e}")
            # No point waiting after the final attempt: no retry follows.
            if attempt + 1 < max_retries:
                time.sleep(delay)

    # Fallback to sequential generation
    print("Falling back to sequential image generation...")
    images = []
    for i, prompt in enumerate(prompts):
        try:
            print(f"Sequential generation for prompt {i+1}...")
            image = pipeline(
                prompt=prompt,
                num_inference_steps=15,
                height=768,
                width=768,
            ).images[0]
        except Exception as e:
            print(f"Error in sequential generation for prompt {i+1}: {e}")
            raise
        images.append(image)
    return images
def _encode_png_base64(image):
    """Serialize ``image`` as PNG and return the bytes as a Base64 string."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


# NOTE(review): route registration (e.g. an @app.post decorator) is not
# visible in this view — confirm this handler is attached to `app`.
def generate_story_questions_images(request: StoryRequest):
    """
    Generates a story, dynamic questions, and cartoonish storybook images.

    The story and questions come from ``story_generator``. The story is split
    on newlines into at most six non-empty paragraphs, one illustration is
    generated per paragraph, and each image is returned as a Base64-encoded
    PNG.

    Args:
        request: Theme and reading level for the story.

    Returns:
        A JSONResponse with the keys ``theme``, ``reading_level``, ``story``,
        ``questions`` and ``images``.

    Raises:
        HTTPException: Status 500 when the story is empty or any step fails.
    """
    try:
        print(f"π Generating story for theme: {request.theme} and level: {request.reading_level}")
        # Generate story and questions using the story_generator module
        story_result = story_generator.generate_story_and_questions(request.theme, request.reading_level)
        # `or ""` treats an explicit None the same as a missing key.
        story_text = (story_result.get("story") or "").strip()
        questions = (story_result.get("questions") or "").strip()
        if not story_text:
            raise HTTPException(status_code=500, detail="Story generation failed.")

        # Split the story into up to 6 paragraphs
        paragraphs = [p.strip() for p in story_text.split("\n") if p.strip()][:6]

        # Build a list of prompts for batched image generation
        prompts = [
            (
                f"Children's storybook illustration of: {p}. "
                "Soft pastel colors, hand-drawn style, friendly characters, warm lighting, "
                "fantasy setting, watercolor texture, storybook illustration, beautiful composition."
            )
            for p in paragraphs
        ]
        print(f"Generating images for {len(prompts)} paragraphs concurrently...")

        # Use the retry mechanism for image generation
        results = generate_images(prompts, max_retries=3, delay=2)

        # Convert each generated image to Base64
        images = [_encode_png_base64(image) for image in results]

        return JSONResponse(content={
            "theme": request.theme,
            "reading_level": request.reading_level,
            "story": story_text,
            "questions": questions,
            "images": images
        })
    except HTTPException as e:
        # Already a well-formed HTTP error: pass it through untouched instead
        # of re-wrapping it, which would replace its detail with str(e).
        print(f"β Error generating story/questions/images: {e.detail}")
        raise
    except Exception as e:
        print(f"β Error generating story/questions/images: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): route registration (e.g. an @app.get decorator) is not
# visible in this view — confirm this handler is attached to `app`.
def home():
    """Return the static welcome payload for the API."""
    welcome = {"message": "π Welcome to the Story, Question & Image API!"}
    return welcome