abdalraheemdmd committed on
Commit 3dfe47c · verified · 1 Parent(s): 43cfc3a

Update app.py

Files changed (1):
  1. app.py +15 -22
app.py CHANGED
@@ -15,12 +15,14 @@ os.environ["HF_HOME"] = "/tmp/huggingface"
 os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
 os.environ["HF_HUB_CACHE"] = "/tmp/huggingface"
 
-# ✅ Load Public Image Generation Model (No Token Needed)
-IMAGE_MODEL = "stabilityai/sdxl-turbo"  # Fastest model for public access
+# ✅ Enable GPU if available
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# ✅ Load Public Image Generation Model
+IMAGE_MODEL = "runwayml/stable-diffusion-v1-5"  # ✅ Optimized for GPU
 pipeline = DiffusionPipeline.from_pretrained(
-    IMAGE_MODEL,
-    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
-).to("cuda" if torch.cuda.is_available() else "cpu")
+    IMAGE_MODEL, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+).to(device)
 
 # ✅ Define the input request format
 class StoryRequest(BaseModel):
@@ -41,30 +43,21 @@ def generate_story_questions_images(request: StoryRequest):
         story_text = story_result["story"]
         questions = story_result["questions"]
 
-        # ✅ Split the story into sentences for image generation
-        story_sentences = story_text.strip().split(". ")
-
-        # ✅ Generate an image for each sentence
-        images = []
-        for sentence in story_sentences:
-            if len(sentence) > 5:  # Avoid empty sentences
-                print(f"🖼️ Generating image for: {sentence}")
-                image = pipeline(prompt=sentence, num_inference_steps=5).images[0]
-
-                # Convert Image to Base64
-                img_byte_arr = io.BytesIO()
-                image.save(img_byte_arr, format="PNG")
-                img_base64 = base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
-
-                images.append({"sentence": sentence, "image": img_base64})
-
-        # ✅ Return the full response
+        # ✅ Generate an image for the story theme
+        print(f"🖼️ Generating image for: {request.theme}")
+        image = pipeline(prompt=request.theme, num_inference_steps=5).images[0]
+
+        # Convert Image to Base64
+        img_byte_arr = io.BytesIO()
+        image.save(img_byte_arr, format="PNG")
+        img_base64 = base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
+
         return {
             "theme": request.theme,
             "reading_level": request.reading_level,
             "story": story_text,
            "questions": questions,
-            "images": images,
+            "image": img_base64
         }
 
     except Exception as e:
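
For reference, a minimal standalone sketch of the image path this commit switches to, assuming the same torch/diffusers setup as app.py; the model id, dtype choice, step count, and base64 encoding are taken from the diff, while the example prompt string is made up for illustration.

import base64
import io

import torch
from diffusers import DiffusionPipeline

# Select GPU when available, as the new app.py does
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the same public model the commit switches to
pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(device)

# Generate a single image from a theme string (example prompt, not from the repo)
image = pipeline(prompt="a friendly dragon exploring a castle", num_inference_steps=5).images[0]

# Encode the PIL image as a base64 PNG, matching the "image" field in the response
buf = io.BytesIO()
image.save(buf, format="PNG")
img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
print(f"Encoded image length: {len(img_base64)} characters")

One design note: 5 inference steps is carried over from the sdxl-turbo version, which is built for few-step sampling; Stable Diffusion 1.5 usually needs on the order of 20+ steps, so output at this setting may look noisy.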
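
On the consumer side, the response now carries one base64-encoded PNG under "image" instead of a list under "images". Below is a hedged client sketch; the URL, route, and request payload fields are assumptions (only theme and reading_level are visible in this diff), while the base64 handling mirrors app.py.

import base64

import requests

# Hypothetical request body; field names beyond theme/reading_level are not shown in the diff
payload = {"theme": "space adventure", "reading_level": "grade 2"}

# Hypothetical host/port and route name for the FastAPI endpoint
resp = requests.post(
    "http://localhost:7860/generate_story_questions_images",
    json=payload,
    timeout=300,
)
data = resp.json()

print(data["story"])
print(data["questions"])

# The "image" field is a base64-encoded PNG; decode it back to bytes and save to disk
with open("story_image.png", "wb") as f:
    f.write(base64.b64decode(data["image"]))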