"""
DeepSeek Children's Stories Text Generation

Generate children's stories using the trained DeepSeek model.
"""

import os
import sys
import argparse

import torch
import tiktoken
from typing import List, Optional

# Make the project root importable so `model.deepseek` resolves.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from model.deepseek import DeepSeek, DeepSeekConfig
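
# Allowlist DeepSeekConfig for torch.load's weights_only safe loading:
# the checkpoint pickle embeds the config object (see _load_model below).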
torch.serialization.add_safe_globals([DeepSeekConfig])


class DeepSeekStoryGenerator:
    def __init__(self, model_path: str, device: str = 'auto'):
        """Initialize the story generator."""
        self.device = self._get_device(device)
        self.model = self._load_model(model_path)
        self.tokenizer = tiktoken.get_encoding("gpt2")

        # Marker tokens that delimit the prompt, character, story, and moral
        # segments in the training data and in generated text.
        self.special_tokens = {
            "story_start": "<|story|>",
            "story_end": "</|story|>",
            "prompt_start": "<|prompt|>",
            "prompt_end": "</|prompt|>",
            "moral_start": "<|moral|>",
            "moral_end": "</|moral|>",
            "character_start": "<|character|>",
            "character_end": "</|character|>"
        }
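
        # An assembled conditioning string (built by encode_prompt) looks like:
        #   <|prompt|> a brave little mouse </|prompt|> <|character|> mickey </|character|> <|story|>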

    def _get_device(self, device: str) -> str:
        """Resolve 'auto' to CUDA when available, otherwise CPU."""
        if device == 'auto':
            return 'cuda' if torch.cuda.is_available() else 'cpu'
        return device

    def _load_model(self, model_path: str) -> DeepSeek:
        """Load the trained model from a checkpoint file."""
        print(f"Loading model from {model_path}...")

        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

        config = checkpoint['config']
        model = DeepSeek(config)

        # Checkpoints saved from a torch.compile'd model prefix every key
        # with '_orig_mod.'; strip the prefix before loading.
        state_dict = checkpoint['model']
        if all(k.startswith('_orig_mod.') for k in state_dict):
            state_dict = {k[len('_orig_mod.'):]: v for k, v in state_dict.items()}

        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()

        print("Model loaded successfully!")
        print(f"Model configuration: {config.n_layer}L/{config.n_head}H/{config.n_embd}D")
        print(f"Device: {self.device}")

        return model

    def encode_prompt(self, prompt: str, character: Optional[str] = None) -> torch.Tensor:
        """Encode a prompt (and optional character) into input token IDs."""
        full_prompt = f"{self.special_tokens['prompt_start']} {prompt.lower()} {self.special_tokens['prompt_end']}"

        if character:
            full_prompt += f" {self.special_tokens['character_start']} {character.lower()} {self.special_tokens['character_end']}"

        full_prompt += f" {self.special_tokens['story_start']}"

        # encode_ordinary ignores tiktoken's own special tokens, so the
        # markers above are tokenized as plain text.
        token_ids = self.tokenizer.encode_ordinary(full_prompt)
        return torch.tensor([token_ids], dtype=torch.long, device=self.device)

    def generate_story(self, prompt: str, character: Optional[str] = None,
                       max_tokens: int = 200, temperature: float = 0.8,
                       top_k: int = 40, top_p: float = 0.9) -> str:
        """Generate a children's story."""
        print(f"Generating story for prompt: '{prompt}'")
        if character:
            print(f"Character: {character}")

        input_ids = self.encode_prompt(prompt, character)

        # Note: top_p is accepted to mirror the CLI flags but is not
        # forwarded; the model's generate() is called with top-k sampling only.
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_k=top_k
            )

        generated_text = self.tokenizer.decode(generated_ids[0].tolist())

        return self._extract_story(generated_text)

    def _extract_story(self, text: str) -> str:
        """Extract the story text between the story markers."""
        story_start = text.find(self.special_tokens['story_start'])
        story_end = text.find(self.special_tokens['story_end'])

        if story_start != -1 and story_end != -1:
            return text[story_start + len(self.special_tokens['story_start']):story_end].strip()

        # Fallback: if the story markers are missing, return everything after
        # the prompt marker, or the whole text as a last resort.
        prompt_end = text.find(self.special_tokens['prompt_end'])
        if prompt_end != -1:
            return text[prompt_end + len(self.special_tokens['prompt_end']):].strip()
        return text.strip()

    def generate_multiple_stories(self, prompts: List[str], **kwargs) -> List[str]:
        """Generate one story per prompt; kwargs are passed to generate_story."""
        stories = []

        for i, prompt in enumerate(prompts):
            print(f"\nGenerating story {i+1}/{len(prompts)}...")
            stories.append(self.generate_story(prompt, **kwargs))

        return stories

    def interactive_generation(self):
        """Interactive story generation loop."""
        print("DeepSeek Children's Stories - Interactive Mode")
        print("Type 'quit' to exit")
        print("-" * 50)

        while True:
            try:
                prompt = input("\nEnter a story prompt: ").strip()

                if prompt.lower() in ['quit', 'exit', 'q']:
                    print("Goodbye!")
                    break

                if not prompt:
                    print("Please enter a valid prompt.")
                    continue

                character = input("Enter a character name (optional): ").strip() or None

                try:
                    max_tokens = int(input("Max tokens (default 200): ") or "200")
                    temperature = float(input("Temperature (default 0.8): ") or "0.8")
                except ValueError:
                    print("Invalid number; using the defaults.")
                    max_tokens = 200
                    temperature = 0.8

                story = self.generate_story(
                    prompt,
                    character=character,
                    max_tokens=max_tokens,
                    temperature=temperature
                )

                print("\n" + "=" * 50)
                print("GENERATED STORY:")
                print("=" * 50)
                print(story)
                print("=" * 50)

            except KeyboardInterrupt:
                print("\nGoodbye!")
                break
            except Exception as e:
                print(f"Error generating story: {e}")
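

# Example programmatic use (illustrative; adjust the checkpoint path):
#   generator = DeepSeekStoryGenerator('checkpoints/best_model.pt')
#   story = generator.generate_story('A brave little mouse', character='Mickey')
#   print(story)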


def main():
    """Main generation function."""
    parser = argparse.ArgumentParser(description="Generate children's stories with DeepSeek")

    parser.add_argument('--model-path', type=str, default='checkpoints/best_model.pt',
                        help='Path to the trained model checkpoint')
    parser.add_argument('--device', type=str, default='auto',
                        help='Device to use (auto, cuda, cpu)')

    parser.add_argument('--prompt', type=str, help='Story prompt')
    parser.add_argument('--character', type=str, help='Character name')
    parser.add_argument('--max-tokens', type=int, default=200, help='Maximum tokens to generate')
    parser.add_argument('--temperature', type=float, default=0.8, help='Sampling temperature')
    parser.add_argument('--top-k', type=int, default=40, help='Top-k sampling')
    parser.add_argument('--top-p', type=float, default=0.9, help='Top-p (nucleus) sampling')

    parser.add_argument('--num-stories', type=int, default=1, help='Number of stories to generate')
    parser.add_argument('--interactive', action='store_true', help='Interactive mode')

    args = parser.parse_args()

    if not os.path.exists(args.model_path):
        print(f"Error: Model file not found at {args.model_path}")
        print("Please train the model first or specify the correct path.")
        return

    generator = DeepSeekStoryGenerator(args.model_path, args.device)

    if args.interactive:
        generator.interactive_generation()
    elif args.prompt:
        if args.num_stories == 1:
            story = generator.generate_story(
                args.prompt,
                character=args.character,
                max_tokens=args.max_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p
            )

            print(f"\nPrompt: {args.prompt}")
            if args.character:
                print(f"Character: {args.character}")
            print("\n" + "=" * 50)
            print("GENERATED STORY:")
            print("=" * 50)
            print(story)
            print("=" * 50)
        else:
            # Generate several stories from the same prompt.
            prompts = [args.prompt] * args.num_stories
            stories = generator.generate_multiple_stories(
                prompts,
                character=args.character,
                max_tokens=args.max_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p
            )

            for i, story in enumerate(stories):
                print(f"\nStory {i+1}:")
                print("=" * 50)
                print(story)
                print("=" * 50)
    else:
        print("Please provide a prompt or use --interactive mode.")
        print("Example: python generate.py --prompt 'A brave little mouse' --character 'Mickey'")


if __name__ == "__main__":
    main()