Update app.py
app.py
CHANGED
@@ -3,7 +3,7 @@ import torch
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import JSONResponse
 from pydantic import BaseModel
-import story_generator
+import story_generator
 from diffusers import DiffusionPipeline
 from PIL import Image
 import io
@@ -11,7 +11,7 @@ import base64

 app = FastAPI()

-# ✅ Hugging Face cache directory
+# ✅ Set Hugging Face cache directory
 os.environ["HF_HOME"] = "/tmp/huggingface"
 os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
 os.environ["HF_HUB_CACHE"] = "/tmp/huggingface"
@@ -21,47 +21,47 @@ device = "cuda" if torch.cuda.is_available() else "cpu"

 # ✅ Load image generation model
 IMAGE_MODEL = "lykon/dreamshaper-8"
-pipeline = DiffusionPipeline.from_pretrained(
+pipeline = DiffusionPipeline.from_pretrained(
+    IMAGE_MODEL, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+).to(device)
+

 # ✅ Define request format
 class StoryRequest(BaseModel):
     theme: str
     reading_level: str

+
 @app.post("/generate_story_questions_images")
 def generate_story_questions_images(request: StoryRequest):
-    """Generates a story, questions, and cartoonish images."""
+    """Generates a story, questions, and cartoonish storybook images."""
     try:
         print(f"🎭 Generating story for theme: {request.theme} and level: {request.reading_level}")

-        # ✅ Generate the story
+        # ✅ Generate the story
         if not hasattr(story_generator, "generate_story_and_questions"):
-            raise HTTPException(status_code=500, detail="Story
+            raise HTTPException(status_code=500, detail="Story generation function not found in story_generator.py")

         story_result = story_generator.generate_story_and_questions(request.theme, request.reading_level)
-        story_text = story_result
-        questions = story_result
-
-        # ✅ Process Story (Ensure it's correctly structured)
-        if not isinstance(story_text, str) or len(story_text) < 50:
-            raise HTTPException(status_code=500, detail="Generated story is too short or invalid.")
+        story_text = story_result.get("story", "").strip()
+        questions = story_result.get("questions", "").strip()

-        #
-        if not
+        # 🚨 Error Handling: Ensure valid story & questions are generated
+        if not story_text:
+            raise HTTPException(status_code=500, detail="Story generation failed.")
+        if not questions or "?" not in questions:
             raise HTTPException(status_code=500, detail="Question generation failed.")

-        # ✅
-        paragraphs = [p.strip() for p in story_text.split("\n") if p.strip()]
+        # ✅ Split the story into paragraphs (max 6)
+        paragraphs = [p.strip() for p in story_text.split("\n") if p.strip()][:6]  # Only take up to 6 paragraphs
+
+        # ✅ Generate high-quality cartoonish images
         images = []
-
         for i, paragraph in enumerate(paragraphs):
-
-
-
-            prompt = f"Children's storybook illustration of: {paragraph}. Soft pastel colors, hand-drawn style, warm lighting, fantasy setting, watercolor texture."
-            print(f"🖼️ Generating better cartoon-style image for paragraph {i+1}...")
+            prompt = f"Children's storybook illustration of: {paragraph}. Soft pastel colors, hand-drawn style, friendly characters, warm lighting, fantasy setting, watercolor texture, storybook illustration, beautiful composition."
+            print(f"🖼️ Generating cartoon-style image for paragraph {i+1}...")

-            # ✅ Reduce inference steps (from 40 → 15
+            # ✅ Reduce inference steps to speed it up (from 40 → 15)
             image = pipeline(prompt=prompt, num_inference_steps=15, height=768, width=768).images[0]

             # ✅ Convert image to Base64
@@ -81,10 +81,11 @@ def generate_story_questions_images(request: StoryRequest):
         })

     except Exception as e:
-        print(f"❌ Error generating
+        print(f"❌ Error generating story/questions/images: {e}")
         raise HTTPException(status_code=500, detail=str(e))

+
 # ✅ Welcome message
 @app.get("/")
 def home():
-    return {"message": "📖 Welcome to the Story, Question & Image API!"}
+    return {"message": "📖 Welcome to the Story, Question & Image API!"}
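The switch from assigning story_result directly to calling story_result.get("story") and story_result.get("questions") implies that generate_story_and_questions returns a dict with those two keys. The real story_generator.py is not part of this change; the stub below is only a hypothetical sketch of the contract that app.py now relies on.

# story_generator.py -- hypothetical stub; only the returned keys are implied by the diff above.
def generate_story_and_questions(theme: str, reading_level: str) -> dict:
    """Return a multi-paragraph story and newline-separated comprehension questions."""
    story = (
        f"Once upon a time, a curious child set off on an adventure about {theme}.\n"
        f"The tale was written simply, suited to a {reading_level} reading level.\n"
    )
    questions = "What was the adventure about?\nWhy was the journey special?"
    return {"story": story, "questions": questions}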
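The "# ✅ Convert image to Base64" step itself falls outside the visible hunks. Given the io, base64, and PIL imports at the top of the file, a typical implementation of that step looks like the sketch below; the shape of the appended payload is an assumption, not taken from the diff.

# Sketch only: the actual conversion code is not visible in this diff.
import base64
import io

from PIL import Image

def pil_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as a Base64 PNG string suitable for a JSON response."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")  # write the image into an in-memory buffer
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

# Inside the paragraph loop, each generated image could then be collected as:
# images.append({"paragraph": paragraph, "image_base64": pil_to_base64(image)})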
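For reference, the endpoint can be exercised with a small client. The request body (theme, reading_level) comes straight from StoryRequest; the base URL and the response keys are assumptions, since the response construction is not shown in these hunks.

# Hypothetical client call; adjust BASE_URL to wherever the Space is actually served.
import requests

BASE_URL = "http://localhost:7860"  # assumption: default local port for a Hugging Face Space

payload = {"theme": "space pirates", "reading_level": "grade 2"}
resp = requests.post(f"{BASE_URL}/generate_story_questions_images", json=payload, timeout=600)
resp.raise_for_status()

data = resp.json()
print(data.keys())  # inspect the actual response keys; they are not shown in this diff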