import time

import torch
from tokenizers import Tokenizer

from miniGPT import MiniGPT

# Load the trained word-level tokenizer and take the vocabulary size from it,
# so the model's embedding and output layers match the tokenizer.
tokenizer = Tokenizer.from_file("wordlevel.json")
vocab_size = tokenizer.get_vocab_size()

# These hyperparameters must match the configuration the checkpoint was trained
# with; otherwise load_state_dict() below fails with shape mismatches.
model = MiniGPT(
    vocab_size=vocab_size,
    embed_dim=128,
    num_heads=4,
    ff_dim=512,
    num_layers=4,
    max_seq_len=128,
)

# Load the trained weights onto the CPU and switch the model to inference mode.
checkpoint_path = "model_checkpoint_step20000.pt"
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params:,}")
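
# Optional sketch (not required for this script): run inference on a GPU when
# one is available. The prompt tensor built inside generate_stream() would need
# the same `.to(device)` treatment, so this is left commented out.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)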


def top_k_logits(logits, k):
    """Keep only top-k tokens with highest probability."""
    values, _ = torch.topk(logits, k)
    min_values = values[:, -1].unsqueeze(1)
    logits[logits < min_values] = -float('Inf')
    return logits
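
# Quick illustration of the top-k filter on made-up logits (unrelated to the
# model): with k=2 only the two largest entries survive, the rest become -inf.
_tk_demo = top_k_logits(torch.tensor([[2.0, 1.0, 0.5, -1.0]]), k=2)
assert torch.isinf(_tk_demo[0, 2:]).all()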


def top_p_logits(logits, p=0.9):
    """Keep the smallest set of tokens with cumulative probability >= p."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Mark tokens beyond the nucleus, shifting the mask right by one position so
    # the first token that crosses the threshold is still kept.
    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Map the mask from sorted positions back to vocabulary ids, row by row.
    for batch in range(logits.size(0)):
        remove_ids = sorted_indices[batch][sorted_indices_to_remove[batch]]
        logits[batch, remove_ids] = -float('Inf')

    return logits
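
# Similar check for the nucleus filter: with this skewed distribution and
# p=0.6 the single highest-logit token already covers the nucleus, so it is
# the only one kept (illustrative numbers, unrelated to the model).
_tp_demo = top_p_logits(torch.tensor([[2.0, 1.0, 0.5, -1.0]]), p=0.6)
assert not torch.isinf(_tp_demo[0, 0])
assert torch.isinf(_tp_demo[0, 1:]).all()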


def generate_stream(
    model, tokenizer, prompt,
    max_new_tokens=50,
    temperature=1.0,
    top_k=None,
    top_p=None,
    repetition_penalty=2.0
):
    """Sample tokens one at a time and print them as soon as they are drawn."""
    # Encode the prompt as a (1, seq_len) batch of token ids.
    idx = torch.tensor([tokenizer.encode(prompt).ids], dtype=torch.long)
    generated = []
    start_time = time.time()

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Stop once the context window is full; the model cannot attend
            # beyond max_seq_len positions.
            if idx.shape[1] >= model.max_seq_len:
                break

            # Only the logits at the last position matter for the next token.
            logits = model(idx)
            logits = logits[:, -1, :] / temperature

            # Penalize tokens that were already generated. Dividing a negative
            # logit would make that token *more* likely, so positive logits are
            # divided and negative ones multiplied.
            for token_id in set(generated):
                if logits[0, token_id] > 0:
                    logits[0, token_id] /= repetition_penalty
                else:
                    logits[0, token_id] *= repetition_penalty

            if top_k is not None:
                logits = top_k_logits(logits, top_k)
            if top_p is not None:
                logits = top_p_logits(logits, top_p)

            # Sample the next token, append it to the context, and stream it out.
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)
            generated.append(next_id.item())
            print(tokenizer.decode([next_id.item()]), end=' ', flush=True)

    elapsed = time.time() - start_time
    tps = len(generated) / elapsed if elapsed > 0 else 0
    print(f"\n[Generated {len(generated)} tokens in {elapsed:.2f} seconds | {tps:.2f} tokens/sec]")
    return idx
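
# Example of a more conservative call (illustrative settings, not used below):
# a lower temperature and a small top_k make the output much less random.
# generate_stream(model, tokenizer, "once upon a time",
#                 max_new_tokens=30, temperature=0.7, top_k=10)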


# Simple interactive loop: read a prompt, stream a continuation, repeat.
while True:
    prompt = input("\nEnter your prompt (or type 'exit' to quit): ")
    if prompt.lower() == 'exit':
        break

    print("\nStreaming output:")
    generate_stream(
        model, tokenizer, prompt,
        max_new_tokens=90,
        temperature=2.0,
        top_k=100,
        top_p=0.9,
        repetition_penalty=1.8
    )