abdalraheemdmd committed · Commit 7f33308 · verified · 1 Parent(s): 8b89b10

Update app.py

Files changed (1):
  1. app.py +77 -180
app.py CHANGED
@@ -1,190 +1,87 @@
+import os
 import torch
-from transformers import GPT2Tokenizer, GPT2LMHeadModel
-import random
-import re
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import JSONResponse
+from pydantic import BaseModel
 import story_generator
+from diffusers import DiffusionPipeline
+from PIL import Image
+import io
+import base64
 
-# Set Hugging Face cache directory
-CACHE_DIR = "/tmp/huggingface"
-
-# ------------------------
-# Load Story Generation Model
-# ------------------------
-STORY_MODEL_NAME = "abdalraheemdmd/story-api"
-story_tokenizer = GPT2Tokenizer.from_pretrained(STORY_MODEL_NAME, cache_dir=CACHE_DIR)
-story_model = GPT2LMHeadModel.from_pretrained(STORY_MODEL_NAME, cache_dir=CACHE_DIR)
-
-# ------------------------
-# Load Question Generation Model
-# ------------------------
-QUESTION_MODEL_NAME = "abdalraheemdmd/question-gene"
-question_tokenizer = GPT2Tokenizer.from_pretrained(QUESTION_MODEL_NAME, cache_dir=CACHE_DIR)
-question_model = GPT2LMHeadModel.from_pretrained(QUESTION_MODEL_NAME, cache_dir=CACHE_DIR)
-
-# Ensure tokenizers have a pad token
-if story_tokenizer.pad_token_id is None:
-    story_tokenizer.pad_token_id = story_tokenizer.eos_token_id
-if question_tokenizer.pad_token_id is None:
-    question_tokenizer.pad_token_id = question_tokenizer.eos_token_id
-
-def generate_story(theme, reading_level, max_new_tokens=400, temperature=0.7):
-    """Generates a story based on the provided theme and reading level."""
-    prompt = f"A {reading_level} story about {theme}:"
-    input_ids = story_tokenizer(prompt, return_tensors="pt").input_ids
-    with torch.no_grad():
-        output = story_model.generate(
-            input_ids,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            top_k=20,
-            top_p=0.7,
-            do_sample=True,
-            early_stopping=True,
-            pad_token_id=story_tokenizer.pad_token_id,
-            eos_token_id=story_tokenizer.eos_token_id,
-            attention_mask=input_ids.ne(story_tokenizer.pad_token_id)
-        )
-    return story_tokenizer.decode(output[0], skip_special_tokens=True)
-
-def extract_protagonist(story):
-    """
-    Attempts to extract the protagonist from the first sentence by searching for the pattern "named <Name>".
-    Returns the first matched name, if available.
-    """
-    sentences = re.split(r'\.|\n', story)
-    if sentences:
-        m = re.search(r"named\s+([A-Z][a-z]+)", sentences[0])
-        if m:
-            return m.group(1)
-    return None
-
-def extract_characters(story):
-    """
-    Extracts potential character names from the story using a frequency count on capitalized words.
-    Filters out common stopwords so that the most frequently mentioned name is likely the main character.
-    """
-    words = re.findall(r'\b[A-Z][a-zA-Z]+\b', story)
-    stopwords = {"The", "A", "An", "And", "But", "Suddenly", "Quickly", "However", "Well",
-                 "They", "I", "He", "She", "It", "When", "Where", "Dr", "Mr"}
-    filtered = [w for w in words if w not in stopwords and len(w) > 2]
-    if not filtered:
-        return []
-    freq = {}
-    for word in filtered:
-        freq[word] = freq.get(word, 0) + 1
-    sorted_chars = sorted(freq.items(), key=lambda x: x[1], reverse=True)
-    return [item[0] for item in sorted_chars]
-
-def extract_themes(story):
-    """Extracts themes from the story based on keyword matching."""
-    themes = []
-    story_lower = story.lower()
-    if "space" in story_lower:
-        themes.append("space")
-    if "adventure" in story_lower:
-        themes.append("adventure")
-    if "friend" in story_lower:
-        themes.append("friendship")
-    if "learn" in story_lower or "lesson" in story_lower:
-        themes.append("learning")
-    return themes
-
-def extract_lesson(story):
-    """
-    Attempts to extract a lesson or moral from the story by finding sentences
-    containing keywords like "learn" or "lesson". Returns the last matching sentence.
-    """
-    sentences = re.split(r'\.|\n', story)
-    lesson_sentences = [
-        s.strip() for s in sentences
-        if ("learn" in s.lower() or "lesson" in s.lower()) and len(s.strip()) > 20
-    ]
-    if lesson_sentences:
-        return lesson_sentences[-1]
-    else:
-        return "No explicit lesson found."
-
-def format_question(question_prompt, correct_answer, distractors):
-    """
-    Combines the correct answer with three distractors, shuffles the options,
-    and formats the question as a multiple-choice question.
-    """
-    # Ensure exactly 3 distractors are available
-    if len(distractors) < 3:
-        default_distractors = ["Option X", "Option Y", "Option Z"]
-        while len(distractors) < 3:
-            distractors.append(default_distractors[len(distractors) % len(default_distractors)])
-    else:
-        distractors = random.sample(distractors, 3)
-    options = distractors + [correct_answer]
-    random.shuffle(options)
-    letters = ["A", "B", "C", "D"]
-    correct_letter = letters[options.index(correct_answer)]
-    options_text = "\n".join(f"{letters[i]}) {option}" for i, option in enumerate(options))
-    question_text = f"{question_prompt}\n{options_text}\nCorrect Answer: {correct_letter}"
-    return question_text
-
-def dynamic_fallback_questions(story):
-    """
-    Generates three multiple-choice questions based on dynamic story content.
-    Each question uses a randomly chosen template and shuffles its options.
-    """
-    protagonist = extract_protagonist(story)
-    characters = extract_characters(story)
-    themes = extract_themes(story)
-    lesson = extract_lesson(story)
-
-    # --- Question 1: Theme ---
-    theme_templates = [
-        "What is the main theme of the story?",
-        "Which theme best represents the narrative?",
-        "What subject is central to the story?"
-    ]
-    q1_prompt = random.choice(theme_templates)
-    correct_theme = " and ".join(themes) if themes else "learning"
-    q1_distractors = ["sports and competition", "cooking and baking", "weather and seasons", "technology and innovation"]
-    q1 = format_question(q1_prompt, correct_theme, q1_distractors)
-
-    # --- Question 2: Primary Character ---
-    character_templates = [
-        "Who is the primary character in the story?",
-        "Which character drives the main action in the narrative?",
-        "Who is the central figure in the story?"
-    ]
-    q2_prompt = random.choice(character_templates)
-    if protagonist:
-        correct_character = protagonist
-    elif characters:
-        correct_character = characters[0]
-    else:
-        correct_character = "The main character"
-    q2_distractors = ["a mysterious stranger", "an unknown visitor", "a supporting character", "a sidekick"]
-    q2 = format_question(q2_prompt, correct_character, q2_distractors)
-
-    # --- Question 3: Lesson/Moral ---
-    lesson_templates = [
-        "What lesson did the characters learn by the end of the story?",
-        "What moral can be inferred from the narrative?",
-        "What is the key takeaway from the story?"
-    ]
-    q3_prompt = random.choice(lesson_templates)
-    if lesson and lesson != "No explicit lesson found.":
-        correct_lesson = lesson  # full sentence without truncation
-    else:
-        correct_lesson = "understanding and growth"
-    q3_distractors = ["always be silent", "never try new things", "do nothing", "ignore opportunities"]
-    q3 = format_question(q3_prompt, correct_lesson, q3_distractors)
-
-    return f"{q1}\n\n{q2}\n\n{q3}"
-
-def generate_story_and_questions(theme, reading_level):
-    """
-    Generates a story using the story generation model and then creates dynamic,
-    multiple-choice questions based on that story.
-    """
-    story = generate_story(theme, reading_level)
-    questions = dynamic_fallback_questions(story)
-    return {"story": story, "questions": questions}
-
-# Alias for backward compatibility
-create_fallback_questions = dynamic_fallback_questions
+app = FastAPI()
+
+# Set Hugging Face cache directories
+os.environ["HF_HOME"] = "/tmp/huggingface"
+os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
+os.environ["HF_HUB_CACHE"] = "/tmp/huggingface"
+
+# Enable GPU if available
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Load image generation model
+IMAGE_MODEL = "lykon/dreamshaper-8"
+pipeline = DiffusionPipeline.from_pretrained(
+    IMAGE_MODEL,
+    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+).to(device)
+
+# Define request schema
+class StoryRequest(BaseModel):
+    theme: str
+    reading_level: str
+
+@app.post("/generate_story_questions_images")
+def generate_story_questions_images(request: StoryRequest):
+    """Generates a story, dynamic questions, and cartoonish storybook images."""
+    try:
+        print(f"🎭 Generating story for theme: {request.theme} and level: {request.reading_level}")
+
+        # Generate story and questions
+        story_result = story_generator.generate_story_and_questions(request.theme, request.reading_level)
+        story_text = story_result.get("story", "").strip()
+        questions = story_result.get("questions", "").strip()
+
+        if not story_text:
+            raise HTTPException(status_code=500, detail="Story generation failed.")
+
+        # Split the story into up to 6 paragraphs
+        paragraphs = [p.strip() for p in story_text.split("\n") if p.strip()][:6]
+
+        # Batch image generation:
+        prompts = [
+            (
+                f"Children's storybook illustration of: {p}. "
+                "Soft pastel colors, hand-drawn style, friendly characters, warm lighting, "
+                "fantasy setting, watercolor texture, storybook illustration, beautiful composition."
+            )
+            for p in paragraphs
+        ]
+        print(f"Generating images for {len(prompts)} paragraphs concurrently...")
+
+        # Single batched call to generate images concurrently
+        results = pipeline(prompt=prompts, num_inference_steps=15, height=768, width=768).images
+
+        images = []
+        for image in results:
+            img_byte_arr = io.BytesIO()
+            image.save(img_byte_arr, format="PNG")
+            img_byte_arr.seek(0)
+            base64_image = base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
+            images.append(base64_image)
+
+        return JSONResponse(content={
+            "theme": request.theme,
+            "reading_level": request.reading_level,
+            "story": story_text,
+            "questions": questions,
+            "images": images
+        })
+
+    except Exception as e:
+        print(f"❌ Error generating story/questions/images: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/")
+def home():
+    return {"message": "🎉 Welcome to the Story, Question & Image API!"}