import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Define cache directory
CACHE_DIR = "/tmp/huggingface"

# Load story generation model
STORY_MODEL_NAME = "abdalraheemdmd/story-api"
story_tokenizer = GPT2Tokenizer.from_pretrained(STORY_MODEL_NAME, cache_dir=CACHE_DIR)
story_model = GPT2LMHeadModel.from_pretrained(STORY_MODEL_NAME, cache_dir=CACHE_DIR)

# Load question generation model
QUESTION_MODEL_NAME = "abdalraheemdmd/question-gene"
question_tokenizer = GPT2Tokenizer.from_pretrained(QUESTION_MODEL_NAME, cache_dir=CACHE_DIR)
question_model = GPT2LMHeadModel.from_pretrained(QUESTION_MODEL_NAME, cache_dir=CACHE_DIR)

# Ensure both tokenizers have a pad token (GPT-2 defines none by default)
if story_tokenizer.pad_token_id is None:
    story_tokenizer.pad_token_id = story_tokenizer.eos_token_id
if question_tokenizer.pad_token_id is None:
    question_tokenizer.pad_token_id = question_tokenizer.eos_token_id
def generate_story(theme, reading_level, max_new_tokens=400, temperature=0.7):
    """Generates a children's story for the given theme and reading level."""
    prompt = f"A {reading_level} children's story about {theme}:"
    # Use the tokenizer's own attention mask so generate() does not have to
    # guess padding positions (pad and eos share the same token id here).
    inputs = story_tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids.to(story_model.device)
    attention_mask = inputs.attention_mask.to(story_model.device)
    with torch.no_grad():
        output = story_model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=20,
            top_p=0.7,
            do_sample=True,
            pad_token_id=story_tokenizer.pad_token_id,
            eos_token_id=story_tokenizer.eos_token_id,
        )
    return story_tokenizer.decode(output[0], skip_special_tokens=True)
def generate_questions(story, max_new_tokens=150, temperature=0.7):
    """Generates three multiple-choice comprehension questions for a story."""
    # Fallback in case the pad token was not set at import time
    if question_tokenizer.pad_token_id is None:
        question_tokenizer.pad_token_id = question_tokenizer.eos_token_id
    # Instruction prompt; the output is trimmed later so the prompt itself
    # does not leak into the returned questions.
input_text = f""" | |
Read the following short story and generate exactly **3 multiple-choice comprehension questions**. | |
**Story:** {story} | |
**Now, generate 3 multiple-choice questions in this format:** | |
**1.** (Question here) | |
A) (Wrong choice) | |
B) (Wrong choice) | |
C) (Correct answer) | |
D) (Wrong choice) | |
**2.** (Question here) | |
A) (Wrong choice) | |
B) (Correct answer) | |
C) (Wrong choice) | |
D) (Wrong choice) | |
**3.** (Question here) | |
A) (Correct answer) | |
B) (Wrong choice) | |
C) (Wrong choice) | |
D) (Wrong choice) | |
""" | |
    # Tokenize and keep the tokenizer's attention mask; deriving the mask from
    # pad_token_id would be unreliable here because pad and eos share one id.
    inputs = question_tokenizer(input_text, return_tensors="pt")
    input_ids = inputs.input_ids.to(question_model.device)
    attention_mask = inputs.attention_mask.to(question_model.device)
    with torch.no_grad():
        output = question_model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=30,
            top_p=0.85,
            do_sample=True,
            num_beams=3,  # beam sampling to encourage the structured question format
            early_stopping=True,
            pad_token_id=question_tokenizer.pad_token_id,
            eos_token_id=question_tokenizer.eos_token_id,
        )
    decoded_output = question_tokenizer.decode(output[0], skip_special_tokens=True)
    # Trim everything before the first "1." so only the question list remains
    questions_start = decoded_output.find("1.")
    if questions_start != -1:
        decoded_output = decoded_output[questions_start:].strip()
    # Keep only lines that look like questions and return at most three
    questions = [q.strip() for q in decoded_output.split("\n") if q.strip() and "?" in q]
    return questions[:3]
def generate_story_and_questions(theme, reading_level):
    """Generates a story and three multiple-choice questions about it."""
    story = generate_story(theme, reading_level)
    questions = generate_questions(story)
    return {"story": story, "questions": questions}