mlopez6132 committed on
Commit
1ca3221
·
verified ·
1 Parent(s): f59de10

Upload sample_nano_coder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. sample_nano_coder.py +104 -0
sample_nano_coder.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sampling script for the nano-coder model.
3
+ This script loads a trained nano-coder model and generates Python code completions.
4
+ """
5
+
6
+ import os
7
+ import pickle
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from model import GPTConfig, GPT
11
+
12
# ---------------------------------------------------------------------------
# Sampling configuration (edit these values to change what gets generated)
# ---------------------------------------------------------------------------
out_dir = "out-nano-coder"  # directory containing the trained checkpoint (ckpt.pt)
start = "def fibonacci(n):\n "  # prompt to complete; any Python code may be used
num_samples = 5  # how many completions to generate
max_new_tokens = 500  # tokens generated for each completion
temperature = 0.8  # 1.0 = unchanged distribution; lower values give more focused output
top_k = 200  # keep only the top_k most likely tokens; all others get 0 probability
seed = 1337  # passed to torch.manual_seed before sampling

# Use the GPU whenever PyTorch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"

# 'bfloat16' only when CUDA is present and reports bf16 support;
# 'float16' in every other case (including CPU-only machines).
dtype = (
    "bfloat16"
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else "float16"
)
22
+
23
+ # Load the model
24
def load_model():
    """Restore the trained nano-coder model from ``<out_dir>/ckpt.pt``.

    Returns:
        A ``(model, checkpoint)`` tuple: the ``GPT`` instance (switched to
        eval mode and moved to ``device``) and the object returned by
        ``torch.load``. Note that ``checkpoint['model']`` is renamed in
        place (see below), so the returned checkpoint reflects that.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    checkpoint_file = os.path.join(out_dir, 'ckpt.pt')
    if not os.path.exists(checkpoint_file):
        raise FileNotFoundError(f"Checkpoint not found at {checkpoint_file}. Please train the model first.")

    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a source you trust.
    checkpoint = torch.load(checkpoint_file, map_location=device)

    # Rebuild the architecture from the hyper-parameters saved with the weights.
    model = GPT(GPTConfig(**checkpoint['model_args']))

    # Some keys may carry an '_orig_mod.' prefix (presumably left behind by a
    # compiled-model wrapper at training time -- confirm against the training
    # script). Strip it so the names match a plain GPT module.
    weights = checkpoint['model']
    prefix = '_orig_mod.'
    prefixed_names = [name for name in weights if name.startswith(prefix)]
    for name in prefixed_names:
        weights[name[len(prefix):]] = weights.pop(name)

    model.load_state_dict(weights)
    model.eval()
    model.to(device)

    return model, checkpoint
44
+
45
def load_vocab():
    """Read the character vocabulary saved alongside the prepared dataset.

    Returns:
        A ``(stoi, itos)`` tuple taken from the ``'stoi'`` and ``'itos'``
        entries of ``data/python-codes-25k/meta.pkl``.

    Raises:
        FileNotFoundError: if ``meta.pkl`` does not exist.
    """
    meta_path = os.path.join('data', 'python-codes-25k', 'meta.pkl')

    if not os.path.exists(meta_path):
        raise FileNotFoundError(f"Vocabulary not found at {meta_path}. Please run prepare_code_dataset.py first.")

    # NOTE(review): pickle executes arbitrary code on load; meta.pkl must
    # come from a trusted dataset-preparation run.
    with open(meta_path, 'rb') as handle:
        vocab_meta = pickle.load(handle)

    char_to_id = vocab_meta['stoi']
    id_to_char = vocab_meta['itos']
    return char_to_id, id_to_char
57
+
58
def encode(text, stoi):
    """Map each character of ``text`` to its token id via ``stoi``.

    Returns a list with one id per character, in order. A character that
    is missing from ``stoi`` raises ``KeyError``.
    """
    return list(map(stoi.__getitem__, text))
61
+
62
def decode(ids, itos):
    """Turn a sequence of token ids back into text by looking each up in ``itos``.

    Returns the concatenation of the looked-up strings, in order.
    """
    return ''.join(itos[token_id] for token_id in ids)
65
+
66
def generate_code(model, stoi, itos, start_text, max_new_tokens, temperature, top_k):
    """Generate one code completion for ``start_text``.

    Args:
        model: object exposing ``generate(x, max_new_tokens, temperature=..., top_k=...)``.
        stoi: character -> token id mapping used to encode the prompt.
        itos: token id -> character mapping used to decode the output.
        start_text: prompt to continue.
        max_new_tokens: number of tokens to sample after the prompt.
        temperature: sampling temperature forwarded to ``model.generate``.
        top_k: top-k cutoff forwarded to ``model.generate``.

    Returns:
        The decoded text of the first (only) sequence returned by
        ``model.generate``.

    Raises:
        KeyError: if ``start_text`` contains a character missing from ``stoi``.
    """
    # Local import so this function does not depend on a file-level import.
    from contextlib import nullcontext

    # Encode the prompt and add a leading batch dimension of size 1.
    prompt_ids = encode(start_text, stoi)
    x = torch.tensor(prompt_ids, dtype=torch.long, device=device)[None, ...]

    # Mixed precision is only used on CUDA. Off-GPU the configured dtype
    # falls back to 'float16', which CPU autocast either rejects with a
    # warning or runs slowly depending on the PyTorch version, so plain
    # full-precision execution is used instead. The prefix check also
    # treats indexed devices such as 'cuda:0' as CUDA.
    if str(device).startswith('cuda'):
        amp_dtype = torch.bfloat16 if dtype == 'bfloat16' else torch.float16
        precision_ctx = torch.amp.autocast(device_type='cuda', dtype=amp_dtype)
    else:
        precision_ctx = nullcontext()

    # Sampling needs no gradients.
    with torch.no_grad(), precision_ctx:
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)

    return decode(y[0].tolist(), itos)
79
+
80
def main():
    """Load the model and vocabulary, print a summary, then print samples.

    Generates ``num_samples`` completions of ``start`` using the
    module-level sampling settings, seeding PyTorch's RNG once beforehand.
    """
    separator = "-" * 80

    print("Loading nano-coder model...")
    model, _checkpoint = load_model()  # the checkpoint itself is not needed here
    stoi, itos = load_vocab()

    print("Model loaded successfully!")
    print(f"Vocabulary size: {len(stoi)}")
    print(f"Model parameters: {model.get_num_params() / 1e6:.2f}M")
    print(f"Context length: {model.config.block_size}")
    print(f"Generating {num_samples} samples...")
    print(f"Start text: {start!r}")
    print(separator)

    # Seed once so the whole run of samples is reproducible.
    torch.manual_seed(seed)

    for sample_number in range(1, num_samples + 1):
        print(f"\n--- Sample {sample_number} ---")
        text = generate_code(model, stoi, itos, start, max_new_tokens, temperature, top_k)
        print(text)
        print(separator)
102
+
103
# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()