import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# ✅ Define cache directory
CACHE_DIR = "/tmp/huggingface"

# ✅ Load Story Generation Model
STORY_MODEL_NAME = "abdalraheemdmd/story-api"
story_tokenizer = GPT2Tokenizer.from_pretrained(STORY_MODEL_NAME, cache_dir=CACHE_DIR)
story_model = GPT2LMHeadModel.from_pretrained(STORY_MODEL_NAME, cache_dir=CACHE_DIR)

# ✅ Load Question Generation Model
QUESTION_MODEL_NAME = "abdalraheemdmd/question-gene"
question_tokenizer = GPT2Tokenizer.from_pretrained(QUESTION_MODEL_NAME, cache_dir=CACHE_DIR)
question_model = GPT2LMHeadModel.from_pretrained(QUESTION_MODEL_NAME, cache_dir=CACHE_DIR)

# ✅ Ensure each tokenizer has a pad token (GPT-2 has none by default)
if story_tokenizer.pad_token_id is None:
    story_tokenizer.pad_token_id = story_tokenizer.eos_token_id
if question_tokenizer.pad_token_id is None:
    question_tokenizer.pad_token_id = question_tokenizer.eos_token_id
def generate_story(theme, reading_level, max_new_tokens=400, temperature=0.7):
    """⚡ Ultra-fast, high-quality story generation."""
    prompt = f"A {reading_level} story about {theme}:"
    inputs = story_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output = story_model.generate(
            inputs.input_ids,
            # Use the tokenizer's own attention mask: pad == eos here, so
            # deriving the mask from pad_token_id could mask real tokens.
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=20,
            top_p=0.7,
            do_sample=True,
            pad_token_id=story_tokenizer.pad_token_id,
            eos_token_id=story_tokenizer.eos_token_id,
        )
    return story_tokenizer.decode(output[0], skip_special_tokens=True)
def generate_questions(story, max_new_tokens=150, temperature=0.7):
    """⚡ Ultra-fast, concise question generation."""
    prompt = f"Generate 3 clear comprehension questions for this short story:\n\n{story}"
    inputs = question_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output = question_model.generate(
            inputs.input_ids,
            # As above: use the tokenizer's mask rather than input_ids.ne(pad).
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=20,
            top_p=0.7,
            do_sample=True,
            pad_token_id=question_tokenizer.pad_token_id,
            eos_token_id=question_tokenizer.eos_token_id,
        )
    return question_tokenizer.decode(output[0], skip_special_tokens=True)
def generate_story_and_questions(theme, reading_level):
    """⚡ Generates a story and 3 comprehension questions ultra-fast."""
    story = generate_story(theme, reading_level)
    questions = generate_questions(story)
    return {"story": story, "questions": questions}
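
# Minimal usage sketch. The theme and reading level below are illustrative
# placeholders, not values required by the models.
if __name__ == "__main__":
    result = generate_story_and_questions(theme="friendly dragons", reading_level="beginner")
    print(result["story"])
    print("\n--- Questions ---\n")
    print(result["questions"])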