import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# ✅ Define cache directory
CACHE_DIR = "/tmp/huggingface"

# ✅ Load Story Generation Model
STORY_MODEL_NAME = "abdalraheemdmd/story-api"
story_tokenizer = GPT2Tokenizer.from_pretrained(STORY_MODEL_NAME, cache_dir=CACHE_DIR)
story_model = GPT2LMHeadModel.from_pretrained(STORY_MODEL_NAME, cache_dir=CACHE_DIR)

# ✅ Load Question Generation Model
QUESTION_MODEL_NAME = "abdalraheemdmd/question-gene"
question_tokenizer = GPT2Tokenizer.from_pretrained(QUESTION_MODEL_NAME, cache_dir=CACHE_DIR)
question_model = GPT2LMHeadModel.from_pretrained(QUESTION_MODEL_NAME, cache_dir=CACHE_DIR)

# ✅ Ensure each tokenizer has a pad token (GPT-2 ships without one)
if story_tokenizer.pad_token_id is None:
    story_tokenizer.pad_token_id = story_tokenizer.eos_token_id
if question_tokenizer.pad_token_id is None:
    question_tokenizer.pad_token_id = question_tokenizer.eos_token_id

# ✅ Inference only: put both models in eval mode to disable dropout
story_model.eval()
question_model.eval()


def generate_story(theme, reading_level, max_new_tokens=400, temperature=0.7):
    """Generate a story for the given theme and reading level via sampled decoding."""
    prompt = f"A {reading_level} story about {theme}:"
    inputs = story_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output = story_model.generate(
            inputs.input_ids,
            # Use the tokenizer's own attention mask: because pad_token_id falls
            # back to eos_token_id above, input_ids.ne(pad_token_id) would wrongly
            # mask out any real EOS tokens in the prompt.
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=20,
            top_p=0.7,
            do_sample=True,
            # Note: early_stopping applies only to beam search, so it is omitted
            # here (with do_sample=True it has no effect and triggers a warning).
            pad_token_id=story_tokenizer.pad_token_id,
            eos_token_id=story_tokenizer.eos_token_id,
        )
    # The decoded text includes the prompt, since output[0] is not sliced.
    return story_tokenizer.decode(output[0], skip_special_tokens=True)


def generate_questions(story, max_new_tokens=150, temperature=0.7):
    """Generate 3 comprehension questions for the given story via sampled decoding."""
    prompt = f"Generate 3 clear comprehension questions for this short story:\n\n{story}"
    inputs = question_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output = question_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=20,
            top_p=0.7,
            do_sample=True,
            pad_token_id=question_tokenizer.pad_token_id,
            eos_token_id=question_tokenizer.eos_token_id,
        )
    return question_tokenizer.decode(output[0], skip_special_tokens=True)


def generate_story_and_questions(theme, reading_level):
    """Generate a story and 3 comprehension questions for it."""
    story = generate_story(theme, reading_level)
    questions = generate_questions(story)
    return {"story": story, "questions": questions}
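

# Minimal usage sketch. This is an assumption about how the module is meant to be
# driven, not part of the original API; the theme and reading-level strings below
# are illustrative placeholders, not values the models are known to expect.
if __name__ == "__main__":
    result = generate_story_and_questions(theme="a lost kitten", reading_level="second-grade")
    print(result["story"])
    print("\n--- Questions ---")
    print(result["questions"])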