# story_generator.py
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# ✅ Define cache directory
CACHE_DIR = "/tmp/huggingface"

# ✅ Load Story Generation Model
STORY_MODEL_NAME = "abdalraheemdmd/story-api"
story_tokenizer = GPT2Tokenizer.from_pretrained(STORY_MODEL_NAME, cache_dir=CACHE_DIR)
story_model = GPT2LMHeadModel.from_pretrained(STORY_MODEL_NAME, cache_dir=CACHE_DIR)

# ✅ Load Question Generation Model
QUESTION_MODEL_NAME = "abdalraheemdmd/question-gene"
question_tokenizer = GPT2Tokenizer.from_pretrained(QUESTION_MODEL_NAME, cache_dir=CACHE_DIR)
question_model = GPT2LMHeadModel.from_pretrained(QUESTION_MODEL_NAME, cache_dir=CACHE_DIR)

# ✅ Ensure tokenizers have a pad token (GPT-2 has none by default, so reuse EOS)
if story_tokenizer.pad_token_id is None:
    story_tokenizer.pad_token_id = story_tokenizer.eos_token_id
if question_tokenizer.pad_token_id is None:
    question_tokenizer.pad_token_id = question_tokenizer.eos_token_id


def generate_story(theme, reading_level, max_new_tokens=400, temperature=0.7):
    """⚡ Generate a story for the given theme and reading level via top-k/top-p sampling."""
    prompt = f"A {reading_level} story about {theme}:"
    inputs = story_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output = story_model.generate(
            inputs.input_ids,
            # Use the tokenizer's own mask: pad and EOS share an id here, so
            # input_ids.ne(pad_token_id) would wrongly mask any EOS in the prompt.
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=20,
            top_p=0.7,
            do_sample=True,
            pad_token_id=story_tokenizer.pad_token_id,
            eos_token_id=story_tokenizer.eos_token_id,
        )
    return story_tokenizer.decode(output[0], skip_special_tokens=True)


def generate_questions(story, max_new_tokens=150, temperature=0.7):
    """⚡ Generate three concise comprehension questions for a story."""
    prompt = f"Generate 3 clear comprehension questions for this short story:\n\n{story}"
    inputs = question_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output = question_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,  # same reasoning as in generate_story
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=20,
            top_p=0.7,
            do_sample=True,
            pad_token_id=question_tokenizer.pad_token_id,
            eos_token_id=question_tokenizer.eos_token_id,
        )
    return question_tokenizer.decode(output[0], skip_special_tokens=True)


def generate_story_and_questions(theme, reading_level):
    """⚡ Generate a story and its comprehension questions in one call."""
    story = generate_story(theme, reading_level)
    questions = generate_questions(story)
    return {"story": story, "questions": questions}