Spaces:
Sleeping
Sleeping
File size: 4,680 Bytes
7f33308 1c9183b a950115 7f33308 8b89b10 7f33308 5ec57ad 7f33308 5ec57ad 7f33308 1c9183b 7f33308 a950115 7f33308 3dfe47c 7f33308 4607384 7f33308 5ec57ad 1c9183b 7f33308 1c9183b 7f33308 1c9183b 7f33308 1c9183b 7f33308 1c9183b 7f33308 1c9183b 7f33308 93355f1 7f33308 1c9183b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import base64
import io
import os
import threading
import time

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import story_generator
from diffusers import DiffusionPipeline
from PIL import Image
app = FastAPI()

# Hugging Face cache directory.
# NOTE: huggingface_hub reads HF_HOME / HF_HUB_CACHE when it is first imported,
# and the imports at the top of this file (diffusers) have already triggered
# that import. Setting the variables here therefore does not, on its own,
# change where `from_pretrained` downloads to. The directory is also passed
# explicitly via `cache_dir` below. The variables are still exported for any
# code that reads them later.
HF_CACHE_DIR = "/tmp/huggingface"
os.environ["HF_HOME"] = HF_CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE_DIR  # Deprecated but still used for now
os.environ["HF_HUB_CACHE"] = HF_CACHE_DIR

# Enable GPU if available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the image generation model once, at import time, into a module-level
# pipeline object. fp16 weights are used only on CUDA; CPU runs in fp32.
IMAGE_MODEL = "lykon/dreamshaper-8"
pipeline = DiffusionPipeline.from_pretrained(
    IMAGE_MODEL,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    cache_dir=HF_CACHE_DIR,
).to(device)
class StoryRequest(BaseModel):
    """JSON body accepted by ``POST /generate_story_questions_images``.

    Both fields are passed unchanged to
    ``story_generator.generate_story_and_questions``.
    """

    # Story theme requested by the client (free text).
    theme: str
    # Target reading level requested by the client (free text).
    reading_level: str
# Serialises every call into the diffusion pipeline.
# FastAPI runs plain `def` endpoints in a worker thread pool, so without this
# lock concurrent requests would run the same pipeline object (and its
# stateful scheduler) at the same time.
# NOTE(review): that interleaving is the likely root cause of the
# "index 16 is out of bounds" errors retried below (15 inference steps leave
# 16 scheduler entries, and an extra step from another thread overruns them)
# — confirm.
_PIPELINE_LOCK = threading.Lock()


def generate_images(
    prompts,
    max_retries=3,
    delay=2,
    *,
    num_inference_steps=15,
    height=768,
    width=768,
    pipe=None,
):
    """Generate images for a list of prompts with the diffusion pipeline.

    First tries one batched pipeline call for all prompts. If that call fails
    with an "index <num_inference_steps + 1> is out of bounds" error (i.e.
    "index 16 is out of bounds" with the default 15 steps), it is retried,
    for up to ``max_retries`` attempts in total. Any other error is re-raised
    immediately. If every batched attempt fails that way, the prompts are
    generated one at a time instead.

    Args:
        prompts: List of text prompts.
        max_retries: Number of batched attempts before falling back to
            sequential generation (0 goes straight to the fallback).
        delay: Seconds to sleep between failed batched attempts.
        num_inference_steps: Denoising steps, forwarded to the pipeline.
        height: Output image height, forwarded to the pipeline.
        width: Output image width, forwarded to the pipeline.
        pipe: Callable to use instead of the module-level ``pipeline``. It is
            called with keyword arguments and must return an object with an
            ``images`` list.

    Returns:
        The list of images produced by the pipeline (one per prompt in the
        sequential fallback, in prompt order). Returns ``[]`` for empty
        ``prompts``.

    Raises:
        Exception: Any non-retryable error from a batched call, or any error
            raised during the sequential fallback.
    """
    if not prompts:
        return []
    if pipe is None:
        pipe = pipeline

    # With the default 15 steps this is exactly "index 16 is out of bounds".
    retry_marker = f"index {num_inference_steps + 1} is out of bounds"

    def _run(prompt):
        # Run one pipeline call while holding the shared lock.
        with _PIPELINE_LOCK:
            return pipe(
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                height=height,
                width=width,
            ).images

    for attempt in range(1, max_retries + 1):
        try:
            print(f"Batched image generation attempt {attempt}...")
            return _run(prompts)
        except Exception as e:
            if retry_marker not in str(e):
                raise
            print(f"Attempt {attempt} failed with error: {e}")
            # Only wait when another batched attempt is actually coming.
            if attempt < max_retries:
                time.sleep(delay)

    # Fallback to sequential generation
    print("Falling back to sequential image generation...")
    images = []
    for i, prompt in enumerate(prompts, start=1):
        try:
            print(f"Sequential generation for prompt {i}...")
            images.append(_run(prompt)[0])
        except Exception as e:
            print(f"Error in sequential generation for prompt {i}: {e}")
            raise
    return images
def _image_to_base64(image):
    """Serialise an image (anything with a PIL-style ``save``) to a base64 PNG string."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


@app.post("/generate_story_questions_images")
def generate_story_questions_images(request: StoryRequest):
    """
    Generates a story, dynamic questions, and cartoonish storybook images.

    The story and questions come from
    ``story_generator.generate_story_and_questions``. The story is split into
    up to 6 non-empty lines ("paragraphs"), and one illustration is requested
    per paragraph via ``generate_images``.

    Returns:
        JSONResponse with keys ``theme``, ``reading_level``, ``story``,
        ``questions`` and ``images`` (a list of base64-encoded PNG strings).

    Raises:
        HTTPException: 500 if the generated story is empty (detail
            "Story generation failed."), or 500 with the error message for
            any other failure.
    """
    try:
        print(f"π Generating story for theme: {request.theme} and level: {request.reading_level}")
        # Generate story and questions using the story_generator module
        story_result = story_generator.generate_story_and_questions(request.theme, request.reading_level)
        # `or ""` also covers keys that are present but set to None.
        story_text = (story_result.get("story") or "").strip()
        questions = (story_result.get("questions") or "").strip()
        if not story_text:
            raise HTTPException(status_code=500, detail="Story generation failed.")

        # Split the story into up to 6 paragraphs (non-empty lines).
        paragraphs = [p.strip() for p in story_text.split("\n") if p.strip()][:6]

        # Build a list of prompts for batched image generation
        prompts = [
            (
                f"Children's storybook illustration of: {p}. "
                "Soft pastel colors, hand-drawn style, friendly characters, warm lighting, "
                "fantasy setting, watercolor texture, storybook illustration, beautiful composition."
            )
            for p in paragraphs
        ]
        print(f"Generating images for {len(prompts)} paragraphs concurrently...")
        # Use the retry mechanism for image generation
        results = generate_images(prompts, max_retries=3, delay=2)

        images = [_image_to_base64(image) for image in results]

        return JSONResponse(content={
            "theme": request.theme,
            "reading_level": request.reading_level,
            "story": story_text,
            "questions": questions,
            "images": images
        })
    except Exception as e:
        print(f"β Error generating story/questions/images: {e}")
        # An HTTPException raised above already carries the intended status
        # and detail. Re-wrapping it would replace the detail with str(e),
        # which (in current Starlette) also embeds the status code.
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/")
def home():
    """Root endpoint: return a fixed welcome message as a one-key JSON object."""
    welcome = "π Welcome to the Story, Question & Image API!"
    return dict(message=welcome)
|