# Transformer-MiniGPT / inference.py
import torch
import time
from tokenizers import Tokenizer
from miniGPT import MiniGPT
# --- 1. Load tokenizer and model ---
tokenizer = Tokenizer.from_file("wordlevel.json")
vocab_size = tokenizer.get_vocab_size()
# Set model parameters to match your trained model
model = MiniGPT(
    vocab_size=vocab_size,
    embed_dim=128,
    num_heads=4,
    ff_dim=512,
    num_layers=4,
    max_seq_len=128,
)
checkpoint_path = "model_checkpoint_step20000.pt"
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
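# Optional: run on GPU when one is available. This is a sketch, not part of
# the original script; generate_stream below builds its input tensor on the
# CPU, so that tensor would also need a matching .to(device) call.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)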
# --- 2. Show model parameter count ---
num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params:,}")
# --- 3. Sampling helpers ---
def top_k_logits(logits, k):
    """Keep only top-k tokens with highest probability."""
    values, _ = torch.topk(logits, k)
    min_values = values[:, -1].unsqueeze(1)
    logits[logits < min_values] = -float('Inf')
    return logits
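# Illustrative example: with logits [[2.0, 1.0, 0.5, -1.0]] and k=2, only the
# 2.0 and 1.0 entries survive; the rest become -inf and receive zero
# probability after the softmax.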
def top_p_logits(logits, p=0.9):
    """Keep the smallest set of tokens with cumulative probability >= p."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    # Mask tokens once the cumulative probability exceeds p, shifting the mask
    # right by one so the first token that crosses the threshold is still kept.
    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    # Map the mask back from sorted order to vocabulary order, per batch row.
    for batch in range(logits.size(0)):
        remove_ids = sorted_indices[batch][sorted_indices_to_remove[batch]]
        logits[batch, remove_ids] = -float('Inf')
    return logits
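# Illustrative example: if the sorted probabilities are [0.5, 0.3, 0.15, 0.05]
# and p=0.9, the cumulative sums are [0.5, 0.8, 0.95, 1.0]; the shift keeps
# the first token that crosses 0.9, so {0.5, 0.3, 0.15} survive and the 0.05
# token is masked to -inf.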
# --- 4. Streaming generation function ---
def generate_stream(
    model, tokenizer, prompt,
    max_new_tokens=50,
    temperature=1.0,
    top_k=None,
    top_p=None,
    repetition_penalty=2.0,
):
    idx = torch.tensor([tokenizer.encode(prompt).ids], dtype=torch.long)
    generated = []
    start_time = time.time()
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Stop once the model's context window is full.
            if idx.shape[1] >= model.max_seq_len:
                break
            logits = model(idx)
            logits = logits[:, -1, :] / temperature
            # Apply repetition penalty to previously generated tokens.
            # Dividing a negative logit by a penalty > 1 would *raise* its
            # probability, so negative logits are multiplied instead.
            for token_id in set(generated):
                if logits[0, token_id] > 0:
                    logits[0, token_id] /= repetition_penalty
                else:
                    logits[0, token_id] *= repetition_penalty
            # Apply Top-K and/or Top-P filtering
            if top_k is not None:
                logits = top_k_logits(logits, top_k)
            if top_p is not None:
                logits = top_p_logits(logits, top_p)
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)
            generated.append(next_id.item())
            print(tokenizer.decode([next_id.item()]), end=' ', flush=True)
    elapsed = time.time() - start_time
    tps = len(generated) / elapsed if elapsed > 0 else 0
    print(f"\n[Generated {len(generated)} tokens in {elapsed:.2f} seconds | {tps:.2f} tokens/sec]")
    return idx
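# A deterministic greedy-decoding variant, included as a minimal sketch for
# comparison with the sampling loop above (not part of the original script);
# it always picks the argmax token, so repeated prompts give repeated output.
def generate_greedy(model, tokenizer, prompt, max_new_tokens=50):
    idx = torch.tensor([tokenizer.encode(prompt).ids], dtype=torch.long)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if idx.shape[1] >= model.max_seq_len:
                break
            logits = model(idx)
            # Greedy choice: take the single highest-logit token.
            next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            idx = torch.cat([idx, next_id], dim=1)
    return tokenizer.decode(idx[0].tolist())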
# --- 5. Main input loop ---
while True:
    prompt = input("\nEnter your prompt (or type 'exit' to quit): ")
    if prompt.lower() == 'exit':
        break
    print("\nStreaming output:")
    generate_stream(
        model, tokenizer, prompt,
        max_new_tokens=90,
        temperature=2.0,
        top_k=100,
        top_p=0.9,
        repetition_penalty=1.8,
    )
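# Example session: entering a prompt streams up to 90 sampled tokens word by
# word, then prints the token count and tokens/sec. Lower temperature values
# (e.g. 0.7-1.0) give more conservative text than the 2.0 used here.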