# story-image-api/story_generator.py: story and comprehension-question generation with GPT-2 models.
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# ✅ Define cache directory
CACHE_DIR = "/tmp/huggingface"
# ✅ Load Story Generation Model
STORY_MODEL_NAME = "abdalraheemdmd/story-api"
story_tokenizer = GPT2Tokenizer.from_pretrained(STORY_MODEL_NAME, cache_dir=CACHE_DIR)
story_model = GPT2LMHeadModel.from_pretrained(STORY_MODEL_NAME, cache_dir=CACHE_DIR)
# ✅ Load Question Generation Model
QUESTION_MODEL_NAME = "abdalraheemdmd/question-gene"
question_tokenizer = GPT2Tokenizer.from_pretrained(QUESTION_MODEL_NAME, cache_dir=CACHE_DIR)
question_model = GPT2LMHeadModel.from_pretrained(QUESTION_MODEL_NAME, cache_dir=CACHE_DIR)
# ✅ Ensure tokenizers have a pad token
if story_tokenizer.pad_token_id is None:
    story_tokenizer.pad_token_id = story_tokenizer.eos_token_id
if question_tokenizer.pad_token_id is None:
    question_tokenizer.pad_token_id = question_tokenizer.eos_token_id
def generate_story(theme, reading_level, max_new_tokens=400, temperature=0.7):
    """⚡ High-quality, fast story generation."""
    prompt = f"A {reading_level} children's story about {theme}:"
    # ✅ Tokenize the prompt and keep the attention mask so generate() receives proper masking
    encoded = story_tokenizer(prompt, return_tensors="pt").to(story_model.device)
    with torch.no_grad():
        output = story_model.generate(
            encoded.input_ids,
            attention_mask=encoded.attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=20,
            top_p=0.7,
            do_sample=True,
            pad_token_id=story_tokenizer.pad_token_id,
            eos_token_id=story_tokenizer.eos_token_id
        )
    return story_tokenizer.decode(output[0], skip_special_tokens=True)
def generate_questions(story, max_new_tokens=150, temperature=0.7):
    """⚡ Generates structured multiple-choice comprehension questions from the story."""
    # ✅ Ensure the tokenizer has a padding token (GPT-2 has none by default)
    if question_tokenizer.pad_token_id is None:
        question_tokenizer.pad_token_id = question_tokenizer.eos_token_id
    # ✅ Instruction prompt: the story followed by the expected question format
    input_text = f"""
Read the following short story and generate exactly **3 multiple-choice comprehension questions**.
**Story:** {story}
**Now, generate 3 multiple-choice questions in this format:**
**1.** (Question here)
A) (Wrong choice)
B) (Wrong choice)
C) (Correct answer)
D) (Wrong choice)
**2.** (Question here)
A) (Wrong choice)
B) (Correct answer)
C) (Wrong choice)
D) (Wrong choice)
**3.** (Question here)
A) (Correct answer)
B) (Wrong choice)
C) (Wrong choice)
D) (Wrong choice)
"""
    # ✅ Tokenize and keep the attention mask produced by the tokenizer
    encoded = question_tokenizer(input_text, return_tensors="pt").to(question_model.device)
    input_ids = encoded.input_ids
    attention_mask = encoded.attention_mask
    with torch.no_grad():
        output = question_model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=30,
            top_p=0.85,
            do_sample=True,
            num_beams=3,  # ✅ Encourages a structured question format
            early_stopping=True,
            pad_token_id=question_tokenizer.pad_token_id,
            eos_token_id=question_tokenizer.eos_token_id
        )
    # ✅ Decode only the newly generated tokens so the prompt is not echoed back
    generated_tokens = output[0][input_ids.shape[1]:]
    decoded_output = question_tokenizer.decode(generated_tokens, skip_special_tokens=True)
    # ✅ Trim anything before the first question number
    questions_start = decoded_output.find("1.")
    if questions_start != -1:
        decoded_output = decoded_output[questions_start:].strip()
    # ✅ Keep only non-empty lines that contain a question
    questions = [q.strip() for q in decoded_output.split("\n") if q.strip() and "?" in q]
    return questions[:3]  # ✅ Return exactly 3 questions
def generate_story_and_questions(theme, reading_level):
    """⚡ Generates a story and 3 multiple-choice questions."""
    story = generate_story(theme, reading_level)
    questions = generate_questions(story)
    return {"story": story, "questions": questions}