Spaces:
Sleeping
Sleeping
import torch | |
from transformers import AutoTokenizer | |
from burmese_gpt.config import ModelConfig | |
from burmese_gpt.models import BurmeseGPT | |
# Vocabulary size written into ModelConfig.vocab_size by load_model().
# NOTE(review): presumably the vocabulary size of the
# "bert-base-multilingual-cased" tokenizer used there — confirm.
VOCAB_SIZE = 119547
# Checkpoint file that the __main__ block passes to load_model().
CHECKPOINT_PATH = "checkpoints/best_model.pth"
def download_pretrained_model(path: str):
    """Placeholder for fetching the pretrained checkpoint.

    Not implemented yet: the function ignores ``path``, performs no work and
    returns ``None``.

    Args:
        path: Currently unused.
    """
    return None
def load_model(path: str):
    """Build a BurmeseGPT, restore its weights from ``path`` and prepare it for inference.

    Args:
        path: Checkpoint file readable by ``torch.load`` that contains a
            ``"model_state_dict"`` entry.

    Returns:
        A ``(model, tokenizer, device)`` tuple: the model in eval mode and
        moved to ``device``, the "bert-base-multilingual-cased" tokenizer,
        and the selected device (CUDA when available, otherwise CPU).
    """
    config = ModelConfig()

    tok = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
    # Reuse the EOS token for padding when the tokenizer defines no pad token.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    config.vocab_size = VOCAB_SIZE
    net = BurmeseGPT(config)

    # Weights are first materialized on the CPU so loading works on machines
    # without a GPU; the model is moved to the target device afterwards.
    # NOTE(security): torch.load unpickles the file — only load checkpoints
    # from a trusted source.
    net.load_state_dict(torch.load(path, map_location="cpu")["model_state_dict"])
    net.eval()

    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")
    net.to(target)

    return net, tok, target
def generate_sample(model, tokenizer, device, prompt="မြန်မာ", max_length=50):
    """Greedily extend ``prompt`` token by token and return the decoded text.

    Args:
        model: Callable that takes the token-id tensor and returns logits
            indexed as ``[:, -1, :]`` (assumed ``(batch, seq, vocab)`` —
            confirm against ``BurmeseGPT``).
        tokenizer: Tokenizer providing ``encode``, ``decode`` and
            ``eos_token_id`` (and optionally ``sep_token_id``).
        device: Device the encoded prompt is moved to; the model is expected
            to already live there.
        prompt: Text to continue.
        max_length: Maximum number of *new* tokens to append (not the total
            sequence length).

    Returns:
        The decoded prompt plus generated tokens, with special tokens skipped.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # BERT-style tokenizers define no EOS token (eos_token_id is None), so
    # comparing against it alone can never stop generation early. Fall back
    # to the separator token in that case; tokenizers that do define an EOS
    # id keep the original behaviour.
    stop_token_id = tokenizer.eos_token_id
    if stop_token_id is None:
        stop_token_id = getattr(tokenizer, "sep_token_id", None)

    with torch.no_grad():
        for _ in range(max_length):
            logits = model(input_ids)
            # Greedy decoding: take the most likely token at the last position.
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat((input_ids, next_token), dim=-1)
            # The stop token is appended before breaking; decode() drops it
            # via skip_special_tokens when the tokenizer marks it as special.
            if stop_token_id is not None and next_token.item() == stop_token_id:
                break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
if __name__ == "__main__":
    # Interactive demo: load the checkpoint once, then generate a continuation
    # for every prompt typed by the user.
    # download_pretrained_model() is still a stub, so the checkpoint must
    # already exist at CHECKPOINT_PATH.
    print("Loading model...")
    model, tokenizer, device = load_model(CHECKPOINT_PATH)

    while True:
        try:
            prompt = input("\nEnter prompt (or 'quit' to exit): ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D, Ctrl-C or exhausted piped stdin ends the session
            # cleanly instead of printing a traceback.
            print()
            break

        # Tolerate surrounding whitespace such as "quit " or " QUIT".
        if prompt.strip().lower() == "quit":
            break

        print("\nGenerating...")
        generated = generate_sample(model, tokenizer, device, prompt)
        print(f"\nPrompt: {prompt}")
        print(f"Generated: {generated}")