"""
Sampling script for the nano-coder model.

This script loads a trained nano-coder model and generates Python code completions.
"""
import os
import pickle
from contextlib import nullcontext

import torch
import torch.nn.functional as F

from model import GPTConfig, GPT
# ---------------------------------------------------------------------------
# Sampling configuration -- edit these values to change the run.
# ---------------------------------------------------------------------------
out_dir = 'out-nano-coder'  # directory that holds ckpt.pt
start = "def fibonacci(n):\n "  # prompt the model continues; any Python code works
num_samples = 5  # how many independent completions to draw
max_new_tokens = 500  # tokens generated for each completion
temperature = 0.8  # 1.0 leaves logits unchanged; lower values make output more focused
top_k = 200  # sample only among the top_k most likely tokens
seed = 1337  # RNG seed used for reproducible sampling

# Choose the compute device and the name of the mixed-precision dtype.
if torch.cuda.is_available():
    device = 'cuda'
    dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16'
else:
    device = 'cpu'
    dtype = 'float16'
# Load the model
def load_model():
    """Load the trained nano-coder model from ``<out_dir>/ckpt.pt``.

    The model is rebuilt from the hyper-parameters stored under
    ``checkpoint['model_args']``, loaded with ``checkpoint['model']``, switched
    to eval mode and moved to ``device``.

    Returns:
        tuple: ``(model, checkpoint)`` -- the ready-to-use ``GPT`` instance and
        the raw dict returned by ``torch.load`` (its ``'model'`` entry has had
        the ``'_orig_mod.'`` key prefix stripped in place).

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found at {ckpt_path}. Please train the model first.")
    checkpoint = torch.load(ckpt_path, map_location=device)

    # Rebuild the architecture from the hyper-parameters saved at training time.
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)

    state_dict = checkpoint['model']
    # Weights saved from a torch.compile()'d module typically carry an
    # '_orig_mod.' key prefix; strip it so the keys match the plain GPT module.
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict):  # snapshot the keys: the dict is mutated inside the loop
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    return model, checkpoint
def load_vocab():
    """Read the token vocabulary produced by the dataset-preparation step.

    Returns:
        tuple: ``(stoi, itos)`` -- the objects stored under those keys in
        ``data/python-codes-25k/meta.pkl``.

    Raises:
        FileNotFoundError: if ``meta.pkl`` is missing.
    """
    meta_path = os.path.join('data', 'python-codes-25k', 'meta.pkl')
    if os.path.exists(meta_path):
        # NOTE: unpickling can execute arbitrary code -- only load a meta.pkl
        # that you generated yourself.
        with open(meta_path, 'rb') as fh:
            meta = pickle.load(fh)
        return meta['stoi'], meta['itos']
    raise FileNotFoundError(f"Vocabulary not found at {meta_path}. Please run prepare_code_dataset.py first.")
def encode(text, stoi):
    """Encode ``text`` into token ids, one id per character.

    Args:
        text: the string to encode.
        stoi: mapping from a single character to its token id.

    Returns:
        list: ``stoi[c]`` for every character ``c`` of ``text``, in order
        (empty for an empty string).

    Raises:
        KeyError: if ``text`` contains a character that is not in ``stoi``.
            The message names the character and its position so that a bad
            prompt is easy to fix.
    """
    ids = []
    for pos, ch in enumerate(text):
        try:
            ids.append(stoi[ch])
        except KeyError:
            # Re-raise with context; the bare KeyError would only show the char.
            raise KeyError(
                f"character {ch!r} at position {pos} is not in the vocabulary"
            ) from None
    return ids
def decode(ids, itos):
    """Turn a sequence of token ids back into text.

    Each id is looked up in ``itos`` (id -> character) and the results are
    concatenated in order.
    """
    lookup = itos.__getitem__
    return ''.join(map(lookup, ids))
def generate_code(model, stoi, itos, start_text, max_new_tokens, temperature, top_k):
    """Sample a code completion of ``start_text`` from ``model``.

    Args:
        model: object exposing ``generate(idx, max_new_tokens, temperature=..., top_k=...)``
            that returns a batch of token-id sequences.
        stoi: character -> token id mapping used to encode the prompt.
        itos: token id -> character mapping used to decode the output.
        start_text: non-empty prompt to continue.
        max_new_tokens: number of tokens to generate.
        temperature: sampling temperature forwarded to ``model.generate``.
        top_k: top-k cutoff forwarded to ``model.generate``.

    Returns:
        str: the decoded first row of ``model.generate``'s output. With a
        nanoGPT-style ``generate`` this is the prompt followed by the new
        tokens (assumption -- depends on the model implementation).

    Raises:
        ValueError: if ``start_text`` is empty (the model needs at least one
            conditioning token).
        KeyError: if ``start_text`` contains characters missing from ``stoi``.
    """
    if not start_text:
        raise ValueError("start_text must be a non-empty string")

    start_ids = encode(start_text, stoi)
    # Shape (1, T): a batch containing a single prompt.
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

    # Mixed precision only on CUDA. CPU autocast does not reliably support
    # float16 (older PyTorch warns and disables it), and the config falls back
    # to 'float16' whenever CUDA is absent, so run CPU inference in full precision.
    if torch.device(device).type == 'cuda':
        ptdtype = torch.bfloat16 if dtype == 'bfloat16' else torch.float16
        amp_ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype)
    else:
        amp_ctx = nullcontext()

    with torch.no_grad(), amp_ctx:
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
    return decode(y[0].tolist(), itos)
def main():
    """Load the model and vocabulary, then print ``num_samples`` completions of ``start``.

    Prints model statistics first, then seeds torch's RNG with ``seed`` and
    prints each sample separated by a horizontal rule.
    """
    print("Loading nano-coder model...")
    model, _checkpoint = load_model()  # checkpoint metadata is not needed here
    stoi, itos = load_vocab()

    print("Model loaded successfully!")
    print(f"Vocabulary size: {len(stoi)}")
    print(f"Model parameters: {model.get_num_params()/1e6:.2f}M")
    print(f"Context length: {model.config.block_size}")
    print(f"Generating {num_samples} samples...")
    print(f"Start text: {start!r}")
    print("-" * 80)

    # Seed once before the loop so the whole sequence of samples is reproducible.
    torch.manual_seed(seed)

    for i in range(num_samples):
        print(f"\n--- Sample {i+1} ---")
        completion = generate_code(model, stoi, itos, start, max_new_tokens, temperature, top_k)
        print(completion)
        print("-" * 80)


if __name__ == '__main__':
    main()