abdalraheemdmd commited on
Commit
1c9183b
Β·
verified Β·
1 Parent(s): 06e92c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -10
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import torch
3
  from fastapi import FastAPI, HTTPException
4
  from fastapi.responses import JSONResponse
@@ -13,7 +14,7 @@ app = FastAPI()
13
 
14
  # Set Hugging Face cache directories
15
  os.environ["HF_HOME"] = "/tmp/huggingface"
16
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
17
  os.environ["HF_HUB_CACHE"] = "/tmp/huggingface"
18
 
19
  # Enable GPU if available
@@ -31,24 +32,64 @@ class StoryRequest(BaseModel):
31
  theme: str
32
  reading_level: str
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  @app.post("/generate_story_questions_images")
35
  def generate_story_questions_images(request: StoryRequest):
36
- """Generates a story, dynamic questions, and cartoonish storybook images."""
 
 
37
  try:
38
  print(f"🎭 Generating story for theme: {request.theme} and level: {request.reading_level}")
39
-
40
- # Generate story and questions
41
  story_result = story_generator.generate_story_and_questions(request.theme, request.reading_level)
42
  story_text = story_result.get("story", "").strip()
43
  questions = story_result.get("questions", "").strip()
44
-
45
  if not story_text:
46
  raise HTTPException(status_code=500, detail="Story generation failed.")
47
 
48
  # Split the story into up to 6 paragraphs
49
  paragraphs = [p.strip() for p in story_text.split("\n") if p.strip()][:6]
50
 
51
- # Batch image generation:
52
  prompts = [
53
  (
54
  f"Children's storybook illustration of: {p}. "
@@ -59,9 +100,10 @@ def generate_story_questions_images(request: StoryRequest):
59
  ]
60
  print(f"Generating images for {len(prompts)} paragraphs concurrently...")
61
 
62
- # Single batched call to generate images concurrently
63
- results = pipeline(prompt=prompts, num_inference_steps=15, height=768, width=768).images
64
 
 
65
  images = []
66
  for image in results:
67
  img_byte_arr = io.BytesIO()
@@ -77,11 +119,10 @@ def generate_story_questions_images(request: StoryRequest):
77
  "questions": questions,
78
  "images": images
79
  })
80
-
81
  except Exception as e:
82
  print(f"❌ Error generating story/questions/images: {e}")
83
  raise HTTPException(status_code=500, detail=str(e))
84
 
85
  @app.get("/")
86
  def home():
87
- return {"message": "πŸŽ‰ Welcome to the Story, Question & Image API!"}
 
1
  import os
2
+ import time
3
  import torch
4
  from fastapi import FastAPI, HTTPException
5
  from fastapi.responses import JSONResponse
 
14
 
15
  # Set Hugging Face cache directories
16
  os.environ["HF_HOME"] = "/tmp/huggingface"
17
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface" # Deprecated but still used for now
18
  os.environ["HF_HUB_CACHE"] = "/tmp/huggingface"
19
 
20
  # Enable GPU if available
 
32
  theme: str
33
  reading_level: str
34
 
35
+ def generate_images(prompts, max_retries=3, delay=2):
36
+ """
37
+ Attempts to generate images in batch. If an error related to
38
+ "index 16 is out of bounds" occurs, it retries for up to max_retries.
39
+ If all attempts fail, it falls back to generating images sequentially.
40
+ """
41
+ for attempt in range(max_retries):
42
+ try:
43
+ print(f"Batched image generation attempt {attempt+1}...")
44
+ results = pipeline(
45
+ prompt=prompts,
46
+ num_inference_steps=15,
47
+ height=768,
48
+ width=768
49
+ ).images
50
+ return results
51
+ except Exception as e:
52
+ if "index 16 is out of bounds" in str(e):
53
+ print(f"Attempt {attempt+1} failed with error: {e}")
54
+ time.sleep(delay)
55
+ else:
56
+ raise e
57
+ # Fallback to sequential generation
58
+ print("Falling back to sequential image generation...")
59
+ images = []
60
+ for i, prompt in enumerate(prompts):
61
+ try:
62
+ print(f"Sequential generation for prompt {i+1}...")
63
+ image = pipeline(
64
+ prompt=prompt,
65
+ num_inference_steps=15,
66
+ height=768,
67
+ width=768
68
+ ).images[0]
69
+ images.append(image)
70
+ except Exception as e:
71
+ print(f"Error in sequential generation for prompt {i+1}: {e}")
72
+ raise e
73
+ return images
74
+
75
  @app.post("/generate_story_questions_images")
76
  def generate_story_questions_images(request: StoryRequest):
77
+ """
78
+ Generates a story, dynamic questions, and cartoonish storybook images.
79
+ """
80
  try:
81
  print(f"🎭 Generating story for theme: {request.theme} and level: {request.reading_level}")
82
+ # Generate story and questions using the story_generator module
 
83
  story_result = story_generator.generate_story_and_questions(request.theme, request.reading_level)
84
  story_text = story_result.get("story", "").strip()
85
  questions = story_result.get("questions", "").strip()
 
86
  if not story_text:
87
  raise HTTPException(status_code=500, detail="Story generation failed.")
88
 
89
  # Split the story into up to 6 paragraphs
90
  paragraphs = [p.strip() for p in story_text.split("\n") if p.strip()][:6]
91
 
92
+ # Build a list of prompts for batched image generation
93
  prompts = [
94
  (
95
  f"Children's storybook illustration of: {p}. "
 
100
  ]
101
  print(f"Generating images for {len(prompts)} paragraphs concurrently...")
102
 
103
+ # Use the retry mechanism for image generation
104
+ results = generate_images(prompts, max_retries=3, delay=2)
105
 
106
+ # Convert each generated image to Base64
107
  images = []
108
  for image in results:
109
  img_byte_arr = io.BytesIO()
 
119
  "questions": questions,
120
  "images": images
121
  })
 
122
  except Exception as e:
123
  print(f"❌ Error generating story/questions/images: {e}")
124
  raise HTTPException(status_code=500, detail=str(e))
125
 
126
  @app.get("/")
127
  def home():
128
+ return {"message": "πŸŽ‰ Welcome to the Story, Question & Image API!"}