# nano-coder-free / sample_nano_coder.py
# Uploaded by mlopez6132 with huggingface_hub (commit 1ca3221, verified).
"""
Sampling script for the nano-coder model.
This script loads a trained nano-coder model and generates Python code completions.
"""
import os
import pickle
import torch
import torch.nn.functional as F
from model import GPTConfig, GPT
# Configuration ---------------------------------------------------------------
# Where the training run saved its checkpoint (ckpt.pt is read from here).
out_dir = 'out-nano-coder'
# Prompt that every sample continues from; any Python snippet works.
start = "def fibonacci(n):\n "

# Sampling hyper-parameters (forwarded to model.generate).
num_samples = 5        # how many independent completions to print
max_new_tokens = 500   # tokens sampled per completion
temperature = 0.8      # 1.0 leaves logits unscaled; lower => more conservative output
top_k = 200            # keep only the top_k most likely tokens at each step
seed = 1337            # RNG seed so runs are repeatable

# Hardware selection: prefer CUDA when present, with bfloat16 if the GPU supports it.
if torch.cuda.is_available():
    device = 'cuda'
    dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16'
else:
    device = 'cpu'
    dtype = 'float16'
# Load the model
def load_model():
    """Rebuild the trained nano-coder GPT from ``<out_dir>/ckpt.pt``.

    The checkpoint's ``model_args`` are used to construct a ``GPTConfig``,
    and its ``model`` entry is loaded as the state dict.

    Returns:
        tuple: ``(model, checkpoint)`` — the GPT switched to eval mode and
        moved to ``device``, plus the raw checkpoint dict.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found at {ckpt_path}. Please train the model first.")

    checkpoint = torch.load(ckpt_path, map_location=device)
    model = GPT(GPTConfig(**checkpoint['model_args']))

    # Strip the '_orig_mod.' key prefix (presumably left by saving a
    # torch.compile()-wrapped module) so the keys match the plain GPT.
    weights = checkpoint['model']
    prefix = '_orig_mod.'
    prefixed_keys = [name for name in weights if name.startswith(prefix)]
    for name in prefixed_keys:
        weights[name[len(prefix):]] = weights.pop(name)

    model.load_state_dict(weights)
    model.eval()
    model.to(device)
    return model, checkpoint
def load_vocab():
    """Read the token vocabulary pickled by the dataset-preparation step.

    The file is ``data/python-codes-25k/meta.pkl`` relative to the current
    working directory. It is a local artifact produced by this project, so
    unpickling it is treated as trusted.

    Returns:
        tuple: ``(stoi, itos)`` taken from the ``'stoi'`` and ``'itos'`` keys.

    Raises:
        FileNotFoundError: if ``meta.pkl`` is missing.
    """
    meta_path = os.path.join('data', 'python-codes-25k', 'meta.pkl')
    if not os.path.exists(meta_path):
        raise FileNotFoundError(f"Vocabulary not found at {meta_path}. Please run prepare_code_dataset.py first.")

    with open(meta_path, 'rb') as handle:
        meta = pickle.load(handle)
    stoi, itos = meta['stoi'], meta['itos']
    return stoi, itos
def encode(text, stoi):
    """Turn ``text`` into a list of token ids, one per character, via ``stoi``.

    Raises ``KeyError`` for any character that is not in ``stoi``.
    """
    lookup = stoi.__getitem__
    return list(map(lookup, text))
def decode(ids, itos):
    """Join the characters that ``itos`` maps each id in ``ids`` to into one string."""
    return ''.join(itos[token_id] for token_id in ids)
def generate_code(model, stoi, itos, start_text, max_new_tokens, temperature, top_k):
    """Sample a code completion of ``start_text`` from ``model``.

    Args:
        model: the loaded GPT; its ``generate`` method performs the sampling.
        stoi: character -> token id mapping used to encode the prompt.
        itos: token id -> character mapping used to decode the result.
        start_text: prompt text; every character must be present in ``stoi``
            (``encode`` raises ``KeyError`` otherwise).
        max_new_tokens: number of tokens requested from ``model.generate``.
        temperature: sampling temperature forwarded to ``model.generate``.
        top_k: top-k cutoff forwarded to ``model.generate``.

    Returns:
        str: the decoded text of the first sequence returned by
        ``model.generate`` (presumably prompt + sampled continuation —
        confirm against ``GPT.generate``).
    """
    from contextlib import nullcontext

    start_ids = encode(start_text, stoi)
    # Add a leading batch dimension: shape (1, len(start_ids)).
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

    # Match any CUDA device string ('cuda', 'cuda:0', ...). An exact
    # `device == 'cuda'` test would pick CPU autocast for tensors on a GPU.
    device_type = 'cuda' if 'cuda' in str(device) else 'cpu'
    ptdtype = {
        'float32': torch.float32,
        'bfloat16': torch.bfloat16,
        'float16': torch.float16,
    }.get(dtype, torch.float16)
    # Mixed precision only on CUDA. Older PyTorch releases support only
    # bfloat16 for CPU autocast (float16 is disabled with a warning), so
    # CPU sampling runs in plain float32 instead.
    if device_type == 'cuda':
        ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
    else:
        ctx = nullcontext()

    with torch.no_grad(), ctx:
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
    return decode(y[0].tolist(), itos)
def main():
    """Load the model and vocabulary, then print ``num_samples`` completions of ``start``."""
    print("Loading nano-coder model...")
    model, _ = load_model()
    stoi, itos = load_vocab()

    print("Model loaded successfully!")
    print(f"Vocabulary size: {len(stoi)}")
    print(f"Model parameters: {model.get_num_params() / 1e6:.2f}M")
    print(f"Context length: {model.config.block_size}")
    print(f"Generating {num_samples} samples...")
    print(f"Start text: {start!r}")
    separator = "-" * 80
    print(separator)

    # Seed once before sampling so the whole batch of samples is reproducible.
    torch.manual_seed(seed)

    for sample_no in range(1, num_samples + 1):
        print(f"\n--- Sample {sample_no} ---")
        text = generate_code(model, stoi, itos, start, max_new_tokens, temperature, top_k)
        print(text)
        print(separator)


if __name__ == '__main__':
    main()