File size: 4,680 Bytes
7f33308
1c9183b
a950115
7f33308
 
 
8b89b10
7f33308
 
 
 
5ec57ad
7f33308
5ec57ad
7f33308
 
1c9183b
7f33308
a950115
7f33308
 
3dfe47c
7f33308
 
 
 
 
 
4607384
7f33308
 
 
 
5ec57ad
1c9183b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f33308
 
1c9183b
 
 
7f33308
 
1c9183b
7f33308
 
 
 
 
 
 
 
 
1c9183b
7f33308
 
 
 
 
 
 
 
 
 
1c9183b
 
7f33308
1c9183b
7f33308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93355f1
7f33308
 
1c9183b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import time
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import story_generator
from diffusers import DiffusionPipeline
from PIL import Image
import io
import base64

app = FastAPI()

# Set Hugging Face cache directories
# NOTE(review): these are assigned AFTER `diffusers`/`transformers` have been
# imported above; HF cache env vars are typically read at import time, so
# they may not take effect here — confirm, and consider setting them before
# the imports (or in the deployment environment) if the cache path matters.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"  # Deprecated but still used for now
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface"

# Enable GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load image generation model
# Loaded once at import time; fp16 on GPU to halve memory, fp32 on CPU
# (fp16 is poorly supported on CPU).
IMAGE_MODEL = "lykon/dreamshaper-8"
pipeline = DiffusionPipeline.from_pretrained(
    IMAGE_MODEL,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
).to(device)

# Define request schema
class StoryRequest(BaseModel):
    """Request body for POST /generate_story_questions_images."""
    theme: str  # free-text story theme, forwarded to the story generator
    reading_level: str  # free-text reading level, forwarded to the story generator

def generate_images(prompts, max_retries=3, delay=2):
    """
    Generate one image per prompt, preferring a single batched pipeline call.

    The batched call is retried up to ``max_retries`` times when it fails
    with the known "index 16 is out of bounds" error; any other exception
    propagates immediately. If every batched attempt fails, images are
    generated sequentially (one pipeline call per prompt) as a fallback.

    Args:
        prompts: List of text prompts, one per desired image.
        max_retries: Number of batched attempts before falling back.
        delay: Seconds to sleep between failed batched attempts.

    Returns:
        List of PIL images, one per prompt, in the same order.

    Raises:
        Exception: Any pipeline error other than the retryable one, or any
            error raised during the sequential fallback.
    """
    for attempt in range(max_retries):
        try:
            print(f"Batched image generation attempt {attempt+1}...")
            return pipeline(
                prompt=prompts,
                num_inference_steps=15,
                height=768,
                width=768
            ).images
        except Exception as e:
            if "index 16 is out of bounds" not in str(e):
                # Unknown failure mode: bare `raise` preserves the original
                # traceback (`raise e` would rewrite it from here).
                raise
            print(f"Attempt {attempt+1} failed with error: {e}")
            # Don't sleep after the final attempt — we fall back immediately.
            if attempt + 1 < max_retries:
                time.sleep(delay)
    # Fallback to sequential generation
    print("Falling back to sequential image generation...")
    images = []
    for i, prompt in enumerate(prompts):
        try:
            print(f"Sequential generation for prompt {i+1}...")
            image = pipeline(
                prompt=prompt,
                num_inference_steps=15,
                height=768,
                width=768
            ).images[0]
            images.append(image)
        except Exception as e:
            print(f"Error in sequential generation for prompt {i+1}: {e}")
            raise
    return images

@app.post("/generate_story_questions_images")
def generate_story_questions_images(request: StoryRequest):
    """
    Generates a story, dynamic questions, and cartoonish storybook images.

    Returns a JSON payload with the theme, reading level, story text,
    questions, and a list of Base64-encoded PNG images (one per story
    paragraph, up to 6).

    Raises:
        HTTPException: 500 when story generation yields no text, or when
            any other step of the pipeline fails.
    """
    try:
        print(f"🎭 Generating story for theme: {request.theme} and level: {request.reading_level}")
        # Generate story and questions using the story_generator module
        story_result = story_generator.generate_story_and_questions(request.theme, request.reading_level)
        story_text = story_result.get("story", "").strip()
        questions = story_result.get("questions", "").strip()
        if not story_text:
            raise HTTPException(status_code=500, detail="Story generation failed.")

        # Split the story into up to 6 paragraphs
        paragraphs = [p.strip() for p in story_text.split("\n") if p.strip()][:6]

        # Build a list of prompts for batched image generation
        prompts = [
            (
                f"Children's storybook illustration of: {p}. "
                "Soft pastel colors, hand-drawn style, friendly characters, warm lighting, "
                "fantasy setting, watercolor texture, storybook illustration, beautiful composition."
            )
            for p in paragraphs
        ]
        print(f"Generating images for {len(prompts)} paragraphs concurrently...")

        # Use the retry mechanism for image generation
        results = generate_images(prompts, max_retries=3, delay=2)

        # Convert each generated image to Base64
        images = []
        for image in results:
            img_byte_arr = io.BytesIO()
            image.save(img_byte_arr, format="PNG")
            img_byte_arr.seek(0)
            base64_image = base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
            images.append(base64_image)

        return JSONResponse(content={
            "theme": request.theme,
            "reading_level": request.reading_level,
            "story": story_text,
            "questions": questions,
            "images": images
        })
    except HTTPException:
        # Bug fix: deliberate HTTPExceptions (e.g. the "Story generation
        # failed." 500 above) were previously swallowed by the generic
        # handler below and re-wrapped with a mangled detail string.
        raise
    except Exception as e:
        print(f"❌ Error generating story/questions/images: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
def home():
    """Health-check / landing route with a static welcome message."""
    # Bug fix: the emoji was mojibake ("πŸŽ‰" — UTF-8 bytes of 🎉 mis-decoded),
    # inconsistent with the correctly encoded 🎭/❌ emojis used elsewhere.
    return {"message": "🎉 Welcome to the Story, Question & Image API!"}