Upload sample_nano_coder.py with huggingface_hub
sample_nano_coder.py +104 -0
sample_nano_coder.py
ADDED
@@ -0,0 +1,104 @@
"""
Sampling script for the nano-coder model.
This script loads a trained nano-coder model and generates Python code completions.
"""

import os
import pickle
from contextlib import nullcontext

import torch
from model import GPTConfig, GPT

# Configuration
out_dir = 'out-nano-coder'
start = "def fibonacci(n):\n "  # or start with any Python code
num_samples = 5  # number of samples to generate
max_new_tokens = 500  # number of tokens generated in each sample
temperature = 0.8  # 1.0 = no change, lower values make output more focused
top_k = 200  # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'

# Load the model
def load_model():
    """Load the trained nano-coder model."""
    # Load the checkpoint
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found at {ckpt_path}. Please train the model first.")

    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # strip the prefix that torch.compile prepends to parameter names
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)

    return model, checkpoint

def load_vocab():
    """Load the character-level vocabulary from the dataset."""
    data_dir = os.path.join('data', 'python-codes-25k')
    meta_path = os.path.join(data_dir, 'meta.pkl')

    if not os.path.exists(meta_path):
        raise FileNotFoundError(f"Vocabulary not found at {meta_path}. Please run prepare_code_dataset.py first.")

    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)

    return meta['stoi'], meta['itos']

def encode(text, stoi):
    """Encode text to token ids."""
    return [stoi[c] for c in text]

def decode(ids, itos):
    """Decode token ids to text."""
    return ''.join([itos[i] for i in ids])

def generate_code(model, stoi, itos, start_text, max_new_tokens, temperature, top_k):
    """Generate a code completion for the given start text."""
    # Encode the start text and add a batch dimension
    start_ids = encode(start_text, stoi)
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

    # Autocast only on CUDA; CPU autocast does not support float16
    ptdtype = torch.bfloat16 if dtype == 'bfloat16' else torch.float16
    ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype) if device == 'cuda' else nullcontext()

    # Generate tokens
    with torch.no_grad():
        with ctx:
            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
    completion = decode(y[0].tolist(), itos)

    return completion

def main():
    print("Loading nano-coder model...")
    model, checkpoint = load_model()
    stoi, itos = load_vocab()

    print("Model loaded successfully!")
    print(f"Vocabulary size: {len(stoi)}")
    print(f"Model parameters: {model.get_num_params()/1e6:.2f}M")
    print(f"Context length: {model.config.block_size}")
    print(f"Generating {num_samples} samples...")
    print(f"Start text: {repr(start)}")
    print("-" * 80)

    # Set random seed for reproducibility
    torch.manual_seed(seed)

    # Generate samples
    for i in range(num_samples):
        print(f"\n--- Sample {i+1} ---")
        completion = generate_code(model, stoi, itos, start, max_new_tokens, temperature, top_k)
        print(completion)
        print("-" * 80)

if __name__ == '__main__':
    main()
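A note on the vocabulary file: load_vocab only reads the stoi and itos mappings from data/python-codes-25k/meta.pkl, which prepare_code_dataset.py is expected to produce. A minimal sketch of the character-level convention this script assumes (the corpus path below is a placeholder, not the real prepare script):

import pickle

# Build a character-level vocabulary; meta.pkl must carry 'stoi' and 'itos'
with open('corpus.txt', 'r', encoding='utf-8') as f:
    data = f.read()
chars = sorted(set(data))
stoi = {ch: i for i, ch in enumerate(chars)}  # char -> id, read by load_vocab
itos = {i: ch for i, ch in enumerate(chars)}  # id -> char, read by load_vocab
with open('meta.pkl', 'wb') as f:
    pickle.dump({'vocab_size': len(chars), 'stoi': stoi, 'itos': itos}, f)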
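The temperature and top_k arguments are consumed inside GPT.generate, which lives in model.py and is not part of this commit. For reference, a minimal sketch of one nanoGPT-style sampling step with these two knobs (illustrative only; the actual generate loop also crops the context to block_size):

import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=0.8, top_k=200):
    """Sample one token id from a (batch, vocab_size) logits tensor."""
    logits = logits / temperature  # <1.0 sharpens, >1.0 flattens the distribution
    if top_k is not None:
        # clamp everything below the k-th largest logit to probability 0
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[:, [-1]]] = -float('inf')
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # (batch, 1) token ids

With a checkpoint at out-nano-coder/ckpt.pt and the meta.pkl in place, running python sample_nano_coder.py prints five completions of the fibonacci prompt.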