abdalraheemdmd committed
Commit 4607384 · verified · 1 Parent(s): 70a8c2d

Update app.py

Files changed (1): app.py (+26, -25)
app.py CHANGED

@@ -3,7 +3,7 @@ import torch
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import JSONResponse
 from pydantic import BaseModel
-import story_generator  # Make sure story_generator.py is correctly uploaded
+import story_generator
 from diffusers import DiffusionPipeline
 from PIL import Image
 import io
@@ -11,7 +11,7 @@ import base64
 
 app = FastAPI()
 
-# ✅ Hugging Face cache directory
+# ✅ Set Hugging Face cache directory
 os.environ["HF_HOME"] = "/tmp/huggingface"
 os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
 os.environ["HF_HUB_CACHE"] = "/tmp/huggingface"
@@ -21,47 +21,47 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # ✅ Load image generation model
 IMAGE_MODEL = "lykon/dreamshaper-8"
-pipeline = DiffusionPipeline.from_pretrained(IMAGE_MODEL, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32).to(device)
+pipeline = DiffusionPipeline.from_pretrained(
+    IMAGE_MODEL, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+).to(device)
+
 
 # ✅ Define request format
 class StoryRequest(BaseModel):
     theme: str
     reading_level: str
 
+
 @app.post("/generate_story_questions_images")
 def generate_story_questions_images(request: StoryRequest):
-    """Generates a story, questions, and cartoonish images."""
+    """Generates a story, questions, and cartoonish storybook images."""
     try:
         print(f"🎭 Generating story for theme: {request.theme} and level: {request.reading_level}")
 
-        # ✅ Generate the story and questions
+        # ✅ Generate the story
         if not hasattr(story_generator, "generate_story_and_questions"):
-            raise HTTPException(status_code=500, detail="Story generator function not found.")
+            raise HTTPException(status_code=500, detail="Story generation function not found in story_generator.py")
 
         story_result = story_generator.generate_story_and_questions(request.theme, request.reading_level)
-        story_text = story_result["story"]
-        questions = story_result["questions"]
-
-        # ✅ Process Story (Ensure it's correctly structured)
-        if not isinstance(story_text, str) or len(story_text) < 50:
-            raise HTTPException(status_code=500, detail="Generated story is too short or invalid.")
+        story_text = story_result.get("story", "").strip()
+        questions = story_result.get("questions", "").strip()
 
-        # ✅ Process Questions (Ensure they are valid)
-        if not isinstance(questions, list) or len(questions) < 3:
+        # 🚨 Error Handling: Ensure valid story & questions are generated
+        if not story_text:
+            raise HTTPException(status_code=500, detail="Story generation failed.")
+        if not questions or "?" not in questions:
             raise HTTPException(status_code=500, detail="Question generation failed.")
 
-        # ✅ Generate cartoon-style images (for each paragraph)
-        paragraphs = [p.strip() for p in story_text.split("\n") if p.strip()]
+        # ✅ Split the story into paragraphs (max 6)
+        paragraphs = [p.strip() for p in story_text.split("\n") if p.strip()][:6]  # Only take up to 6 paragraphs
+
+        # ✅ Generate high-quality cartoonish images
         images = []
-
         for i, paragraph in enumerate(paragraphs):
-            if i >= 6:  # ✅ Limit image generation to first 6 paragraphs (to save time)
-                break
-
-            prompt = f"Children's storybook illustration of: {paragraph}. Soft pastel colors, hand-drawn style, warm lighting, fantasy setting, watercolor texture."
-            print(f"🖼️ Generating better cartoon-style image for paragraph {i+1}...")
+            prompt = f"Children's storybook illustration of: {paragraph}. Soft pastel colors, hand-drawn style, friendly characters, warm lighting, fantasy setting, watercolor texture, storybook illustration, beautiful composition."
+            print(f"🖼 Generating cartoon-style image for paragraph {i+1}...")
 
-            # ✅ Reduce inference steps (from 40 → 15 for speed)
+            # ✅ Reduce inference steps to speed it up (from 40 → 15)
             image = pipeline(prompt=prompt, num_inference_steps=15, height=768, width=768).images[0]
 
             # ✅ Convert image to Base64
@@ -81,10 +81,11 @@ def generate_story_questions_images(request: StoryRequest):
         })
 
     except Exception as e:
-        print(f"❌ Error generating content: {e}")
+        print(f"❌ Error generating story/questions/images: {e}")
         raise HTTPException(status_code=500, detail=str(e))
 
+
 # ✅ Welcome message
 @app.get("/")
 def home():
-    return {"message": "🎉 Welcome to the Story, Question & Image API!"}
+    return {"message": "🎉 Welcome to the Story, Question & Image API!"}