abdalraheemdmd committed
Commit 06e92c4 · verified · 1 Parent(s): 7f33308

Update story_generator.py

Files changed (1)
  1. story_generator.py +177 -107
story_generator.py CHANGED
@@ -1,119 +1,189 @@
- import os
import torch
- from fastapi import FastAPI, HTTPException
- from fastapi.responses import JSONResponse
- from pydantic import BaseModel
- import story_generator
- from diffusers import DiffusionPipeline
- from PIL import Image
- import io
- import base64
- import time

- app = FastAPI()

- # Set Hugging Face cache directories
- os.environ["HF_HOME"] = "/tmp/huggingface"
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
- os.environ["HF_HUB_CACHE"] = "/tmp/huggingface"

- # Enable GPU if available
- device = "cuda" if torch.cuda.is_available() else "cpu"

- # Load image generation model
- IMAGE_MODEL = "lykon/dreamshaper-8"
- pipeline = DiffusionPipeline.from_pretrained(
-     IMAGE_MODEL,
-     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
- ).to(device)

- # Define request schema
- class StoryRequest(BaseModel):
-     theme: str
-     reading_level: str

- def generate_images(prompts, max_retries=3, delay=2):
    """
-     Generates images in batch with a retry mechanism.
-     If a specific indexing error occurs, it retries up to max_retries times.
-     If all retries fail, it falls back to sequential generation.
    """
-     for attempt in range(max_retries):
-         try:
-             print(f"Batch image generation attempt {attempt+1}...")
-             results = pipeline(prompt=prompts, num_inference_steps=15, height=768, width=768).images
-             return results
-         except Exception as e:
-             if "index 16 is out of bounds" in str(e):
-                 print(f"Encountered indexing error on attempt {attempt+1}: {e}")
-                 time.sleep(delay)  # wait before retrying
-             else:
-                 raise e
-     # Fallback: Sequential generation
-     print("Falling back to sequential image generation...")
-     images = []
-     for i, prompt in enumerate(prompts):
-         try:
-             print(f"Generating image for prompt {i+1} sequentially...")
-             image = pipeline(prompt=prompt, num_inference_steps=15, height=768, width=768).images[0]
-             images.append(image)
-         except Exception as e:
-             print(f"Sequential generation failed for prompt {i+1}: {e}")
-             raise e
-     return images

- @app.post("/generate_story_questions_images")
- def generate_story_questions_images(request: StoryRequest):
-     """Generates a story, dynamic questions, and cartoonish storybook images."""
-     try:
-         print(f"🎭 Generating story for theme: {request.theme} and level: {request.reading_level}")
-
-         # Generate story and questions
-         story_result = story_generator.generate_story_and_questions(request.theme, request.reading_level)
-         story_text = story_result.get("story", "").strip()
-         questions = story_result.get("questions", "").strip()
-
-         if not story_text:
-             raise HTTPException(status_code=500, detail="Story generation failed.")
-
-         # Split the story into up to 6 paragraphs
-         paragraphs = [p.strip() for p in story_text.split("\n") if p.strip()][:6]
-
-         # Build a list of prompts for batched image generation
-         prompts = [
-             (
-                 f"Children's storybook illustration of: {p}. "
-                 "Soft pastel colors, hand-drawn style, friendly characters, warm lighting, "
-                 "fantasy setting, watercolor texture, storybook illustration, beautiful composition."
-             )
-             for p in paragraphs
-         ]
-         print(f"Generating images for {len(prompts)} paragraphs concurrently...")
-
-         # Use the retry mechanism for image generation
-         results = generate_images(prompts, max_retries=3, delay=2)
-
-         # Convert each image to Base64
-         images = []
-         for image in results:
-             img_byte_arr = io.BytesIO()
-             image.save(img_byte_arr, format="PNG")
-             img_byte_arr.seek(0)
-             base64_image = base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
-             images.append(base64_image)
-
-         return JSONResponse(content={
-             "theme": request.theme,
-             "reading_level": request.reading_level,
-             "story": story_text,
-             "questions": questions,
-             "images": images
-         })
-
-     except Exception as e:
-         print(f"❌ Error generating story/questions/images: {e}")
-         raise HTTPException(status_code=500, detail=str(e))

- @app.get("/")
- def home():
-     return {"message": "🎉 Welcome to the Story, Question & Image API!"}
import torch
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
+ import random
+ import re

+ # Set Hugging Face cache directory
+ CACHE_DIR = "/tmp/huggingface"

+ # ------------------------
+ # Load Story Generation Model
+ # ------------------------
+ STORY_MODEL_NAME = "abdalraheemdmd/story-api"
+ story_tokenizer = GPT2Tokenizer.from_pretrained(STORY_MODEL_NAME, cache_dir=CACHE_DIR)
+ story_model = GPT2LMHeadModel.from_pretrained(STORY_MODEL_NAME, cache_dir=CACHE_DIR)

+ # ------------------------
+ # Load Question Generation Model
+ # ------------------------
+ QUESTION_MODEL_NAME = "abdalraheemdmd/question-gene"
+ question_tokenizer = GPT2Tokenizer.from_pretrained(QUESTION_MODEL_NAME, cache_dir=CACHE_DIR)
+ question_model = GPT2LMHeadModel.from_pretrained(QUESTION_MODEL_NAME, cache_dir=CACHE_DIR)

+ # Ensure tokenizers have a pad token
+ if story_tokenizer.pad_token_id is None:
+     story_tokenizer.pad_token_id = story_tokenizer.eos_token_id
+ if question_tokenizer.pad_token_id is None:
+     question_tokenizer.pad_token_id = question_tokenizer.eos_token_id

+ def generate_story(theme, reading_level, max_new_tokens=400, temperature=0.7):
+     """Generates a story based on the provided theme and reading level."""
+     prompt = f"A {reading_level} story about {theme}:"
+     input_ids = story_tokenizer(prompt, return_tensors="pt").input_ids
+     with torch.no_grad():
+         output = story_model.generate(
+             input_ids,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_k=20,
+             top_p=0.7,
+             do_sample=True,
+             early_stopping=True,
+             pad_token_id=story_tokenizer.pad_token_id,
+             eos_token_id=story_tokenizer.eos_token_id,
+             attention_mask=input_ids.ne(story_tokenizer.pad_token_id)
+         )
+     return story_tokenizer.decode(output[0], skip_special_tokens=True)

+ def extract_protagonist(story):
    """
+     Attempts to extract the protagonist from the first sentence by searching for the pattern "named <Name>".
+     Returns the first matched name, if available.
    """
+     sentences = re.split(r'\.|\n', story)
+     if sentences:
+         m = re.search(r"named\s+([A-Z][a-z]+)", sentences[0])
+         if m:
+             return m.group(1)
+     return None

+ def extract_characters(story):
+     """
+     Extracts potential character names from the story using a frequency count on capitalized words.
+     Filters out common stopwords so that the most frequently mentioned name is likely the main character.
+     """
+     words = re.findall(r'\b[A-Z][a-zA-Z]+\b', story)
+     stopwords = {"The", "A", "An", "And", "But", "Suddenly", "Quickly", "However", "Well",
+                  "They", "I", "He", "She", "It", "When", "Where", "Dr", "Mr"}
+     filtered = [w for w in words if w not in stopwords and len(w) > 2]
+     if not filtered:
+         return []
+     freq = {}
+     for word in filtered:
+         freq[word] = freq.get(word, 0) + 1
+     sorted_chars = sorted(freq.items(), key=lambda x: x[1], reverse=True)
+     return [item[0] for item in sorted_chars]
+
+ def extract_themes(story):
+     """Extracts themes from the story based on keyword matching."""
+     themes = []
+     story_lower = story.lower()
+     if "space" in story_lower:
+         themes.append("space")
+     if "adventure" in story_lower:
+         themes.append("adventure")
+     if "friend" in story_lower:
+         themes.append("friendship")
+     if "learn" in story_lower or "lesson" in story_lower:
+         themes.append("learning")
+     return themes
+
+ def extract_lesson(story):
+     """
+     Attempts to extract a lesson or moral from the story by finding sentences
+     containing keywords like "learn" or "lesson". Returns the last matching sentence.
+     """
+     sentences = re.split(r'\.|\n', story)
+     lesson_sentences = [
+         s.strip() for s in sentences
+         if ("learn" in s.lower() or "lesson" in s.lower()) and len(s.strip()) > 20
+     ]
+     if lesson_sentences:
+         return lesson_sentences[-1]
+     else:
+         return "No explicit lesson found."
+
+ def format_question(question_prompt, correct_answer, distractors):
+     """
+     Combines the correct answer with three distractors, shuffles the options,
+     and formats the question as a multiple-choice question.
+     """
+     # Ensure exactly 3 distractors are available
+     if len(distractors) < 3:
+         default_distractors = ["Option X", "Option Y", "Option Z"]
+         while len(distractors) < 3:
+             distractors.append(default_distractors[len(distractors) % len(default_distractors)])
+     else:
+         distractors = random.sample(distractors, 3)
+     options = distractors + [correct_answer]
+     random.shuffle(options)
+     letters = ["A", "B", "C", "D"]
+     correct_letter = letters[options.index(correct_answer)]
+     options_text = "\n".join(f"{letters[i]}) {option}" for i, option in enumerate(options))
+     question_text = f"{question_prompt}\n{options_text}\nCorrect Answer: {correct_letter}"
+     return question_text
+
+ def dynamic_fallback_questions(story):
+     """
+     Generates three multiple-choice questions based on dynamic story content.
+     Each question uses a randomly chosen template and shuffles its options.
+     """
+     protagonist = extract_protagonist(story)
+     characters = extract_characters(story)
+     themes = extract_themes(story)
+     lesson = extract_lesson(story)
+
+     # --- Question 1: Theme ---
+     theme_templates = [
+         "What is the main theme of the story?",
+         "Which theme best represents the narrative?",
+         "What subject is central to the story?"
+     ]
+     q1_prompt = random.choice(theme_templates)
+     correct_theme = " and ".join(themes) if themes else "learning"
+     q1_distractors = ["sports and competition", "cooking and baking", "weather and seasons", "technology and innovation"]
+     q1 = format_question(q1_prompt, correct_theme, q1_distractors)
+
+     # --- Question 2: Primary Character ---
+     character_templates = [
+         "Who is the primary character in the story?",
+         "Which character drives the main action in the narrative?",
+         "Who is the central figure in the story?"
+     ]
+     q2_prompt = random.choice(character_templates)
+     if protagonist:
+         correct_character = protagonist
+     elif characters:
+         correct_character = characters[0]
+     else:
+         correct_character = "The main character"
+     q2_distractors = ["a mysterious stranger", "an unknown visitor", "a supporting character", "a sidekick"]
+     q2 = format_question(q2_prompt, correct_character, q2_distractors)

+     # --- Question 3: Lesson/Moral ---
+     lesson_templates = [
+         "What lesson did the characters learn by the end of the story?",
+         "What moral can be inferred from the narrative?",
+         "What is the key takeaway from the story?"
+     ]
+     q3_prompt = random.choice(lesson_templates)
+     if lesson and lesson != "No explicit lesson found.":
+         correct_lesson = lesson  # full sentence without truncation
+     else:
+         correct_lesson = "understanding and growth"
+     q3_distractors = ["always be silent", "never try new things", "do nothing", "ignore opportunities"]
+     q3 = format_question(q3_prompt, correct_lesson, q3_distractors)
+
+     return f"{q1}\n\n{q2}\n\n{q3}"
+
+ def generate_story_and_questions(theme, reading_level):
+     """
+     Generates a story using the story generation model and then creates dynamic,
+     multiple-choice questions based on that story.
+     """
+     story = generate_story(theme, reading_level)
+     questions = dynamic_fallback_questions(story)
+     return {"story": story, "questions": questions}

+ # Alias for backward compatibility
+ create_fallback_questions = dynamic_fallback_questions
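
Since the question pipeline added by this commit is purely heuristic (regex and keyword matching), its behavior can be checked on a hardcoded story without generating one. A minimal sketch, not part of the commit; the sample story string is invented for illustration, and note that merely importing story_generator triggers the two from_pretrained downloads at module load:

import story_generator

sample = (
    "Once upon a time there was a girl named Luna who dreamed of space. "
    "Luna went on an adventure with her friend Max. "
    "In the end, Luna learned that sharing makes every journey better."
)

print(story_generator.extract_protagonist(sample))  # "Luna", via the 'named <Name>' pattern
print(story_generator.extract_themes(sample))       # ['space', 'adventure', 'friendship', 'learning']
print(story_generator.extract_lesson(sample))       # last sentence containing 'learn'/'lesson'
print(story_generator.dynamic_fallback_questions(sample))  # three formatted multiple-choice blocks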
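
An end-to-end call of the new entry point might look like the following sketch; it assumes the abdalraheemdmd/story-api and abdalraheemdmd/question-gene checkpoints can be fetched into /tmp/huggingface, and the "space"/"beginner" arguments are illustrative values, since both parameters are interpolated into the prompt as free-form strings:

result = story_generator.generate_story_and_questions("space", "beginner")
print(result["story"])      # raw decoded GPT-2 output, prompt text included
print(result["questions"])  # blocks ending in "Correct Answer: <letter>", separated by blank lines

# The backward-compatible alias kept by the commit points at the same function:
assert story_generator.create_fallback_questions is story_generator.dynamic_fallback_questions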