abdalraheemdmd commited on
Commit
93355f1
Β·
verified Β·
1 Parent(s): 1f79b90

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -36
app.py CHANGED
@@ -1,53 +1,70 @@
1
- from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  import torch
4
  from diffusers import StableDiffusionPipeline
5
- import story_generator # Import Story Generator
6
 
7
  app = FastAPI()
8
 
9
- # βœ… Use a single request format
10
  class StoryRequest(BaseModel):
11
  theme: str
12
  reading_level: str
13
 
14
- # βœ… Load Image Model (Cartoon-Style)
15
  CACHE_DIR = "/tmp/huggingface"
16
- pipe = StableDiffusionPipeline.from_pretrained(
17
- "nitrosocke/Arcane-Diffusion",
18
- cache_dir=CACHE_DIR,
19
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
20
- )
21
- pipe.to("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
22
 
23
  @app.post("/generate_story_and_image")
24
  def generate_story_and_image(request: StoryRequest):
25
  """Generates a story and an image based on the given theme and reading level."""
26
- print(f"Generating story for theme: {request.theme} and level: {request.reading_level}")
27
-
28
- # βœ… Generate story
29
- story_result = story_generator.generate_story_and_questions(request.theme, request.reading_level)
30
- story_text = story_result["story"]
31
-
32
- # βœ… Generate image
33
- cartoon_prompt = f"A colorful, cartoon-style illustration of: {story_text}, vibrant colors, highly detailed, storybook fantasy."
34
- print("Generating Image for:", cartoon_prompt[:100])
35
- image = pipe(cartoon_prompt, width=1024, height=1024).images[0]
36
-
37
- # βœ… Save image
38
- image_path = "generated_image.png"
39
- image.save(image_path)
40
- print("Image Saved!")
41
-
42
- return {
43
- "theme": request.theme,
44
- "reading_level": request.reading_level,
45
- "story": story_text,
46
- "questions": story_result["questions"],
47
- "image_path": image_path
48
- }
49
-
50
- # βœ… Fix the "Not Found" issue for the root URL
 
 
 
 
 
 
 
 
 
 
 
 
51
  @app.get("/")
52
  def home():
53
- return {"message": "Welcome to the Story & Image Generation API! Use /generate_story_and_image"}
 
1
+ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  import torch
4
  from diffusers import StableDiffusionPipeline
5
+ import story_generator # βœ… Correctly importing the Story Generator
6
 
7
  app = FastAPI()
8
 
9
+ # βœ… Define the input request format
10
  class StoryRequest(BaseModel):
11
  theme: str
12
  reading_level: str
13
 
14
+ # βœ… Load AI Image Model (Cartoon-Style)
15
  CACHE_DIR = "/tmp/huggingface"
16
+ try:
17
+ pipe = StableDiffusionPipeline.from_pretrained(
18
+ "nitrosocke/Arcane-Diffusion",
19
+ cache_dir=CACHE_DIR,
20
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
21
+ )
22
+ pipe.to("cuda" if torch.cuda.is_available() else "cpu")
23
+ print("βœ… Image generation model loaded successfully.")
24
+ except Exception as e:
25
+ print(f"❌ Failed to load image model: {e}")
26
+ pipe = None # Fallback if model fails
27
 
28
  @app.post("/generate_story_and_image")
29
  def generate_story_and_image(request: StoryRequest):
30
  """Generates a story and an image based on the given theme and reading level."""
31
+ try:
32
+ print(f"🎭 Generating story for theme: {request.theme} and level: {request.reading_level}")
33
+
34
+ # βœ… Generate the story
35
+ if not hasattr(story_generator, "generate_story_and_questions"):
36
+ raise HTTPException(status_code=500, detail="Story generation function not found in story_generator.py")
37
+
38
+ story_result = story_generator.generate_story_and_questions(request.theme, request.reading_level)
39
+ story_text = story_result["story"]
40
+
41
+ # βœ… Generate the image
42
+ if not pipe:
43
+ raise HTTPException(status_code=500, detail="Image generation model failed to load.")
44
+
45
+ cartoon_prompt = f"A colorful, cartoon-style illustration of: {story_text}, vibrant colors, highly detailed, storybook fantasy."
46
+ print("πŸ–ŒοΈ Generating Image for:", cartoon_prompt[:100])
47
+
48
+ image = pipe(cartoon_prompt, width=768, height=768).images[0] # βœ… Slightly smaller for faster generation
49
+
50
+ # βœ… Save image
51
+ image_path = "/tmp/generated_image.png"
52
+ image.save(image_path)
53
+ print("βœ… Image Saved!")
54
+
55
+ return {
56
+ "theme": request.theme,
57
+ "reading_level": request.reading_level,
58
+ "story": story_text,
59
+ "questions": story_result["questions"],
60
+ "image_url": f"https://your-api-url.com/generated_image.png" # Replace with actual host
61
+ }
62
+
63
+ except Exception as e:
64
+ print(f"❌ Error generating story and image: {e}")
65
+ raise HTTPException(status_code=500, detail=str(e))
66
+
67
+ # βœ… Welcome message at root
68
  @app.get("/")
69
  def home():
70
+ return {"message": "πŸŽ‰ Welcome to the Story & Image Generation API! Use /generate_story_and_image"}