import gradio as gr
import torch
import tiktoken

from model import DecoderTransformer
from config import Config
from inference import predict
from utils import get_device
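

# Gradio demo: loads the trained DecoderTransformer checkpoint and serves
# text continuations from a simple web UI.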
def generate_sequence(text):
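    """Generate a continuation of `text` with the trained DecoderTransformer."""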
    config = Config()
    device = get_device()

    # Load the trained weights (on CPU first, then move to the target device)
    model = DecoderTransformer(config)
    model.load_state_dict(
        torch.load(config.saved_model_path, map_location="cpu", weights_only=True)
    )
    model.to(device)
    model.eval()
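
    # Encode the prompt with the GPT-2 BPE tokenizer and add a batch dimension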
    enc = tiktoken.get_encoding("gpt2")
    tokens = enc.encode(text)
    T = len(tokens)
    input_tensor = torch.tensor(tokens, device=device).view(1, T)
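
    # Generate a continuation with the predict helper from inference.py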
    max_output_len = 30
    y = predict(input_tensor, model, max_output_len=max_output_len)

    # Decode the generated token ids back into text
    output_tokens = y[0, :].tolist()
    return enc.decode(output_tokens)

# Earlier greedy-decoding loop, kept commented out for reference:
# # Convert input text to tensor using tokenizer
# input_tensor = torch.tensor([config.tokenizer.encode(text)], device=config.device)
# # Generate sequence
# with torch.no_grad():
#     # Initialize start token and empty sequence
#     current_seq = torch.tensor([[config.start_token]], device=config.device)
#     # Generate tokens one by one
#     for _ in range(config.max_seq_length):
#         # Get model predictions
#         output = model(input_tensor, current_seq)
#         next_token_logits = output[:, -1, :]
#         next_token = torch.argmax(next_token_logits, dim=-1)
#         # Add predicted token to sequence
#         current_seq = torch.cat([current_seq, next_token.unsqueeze(0)], dim=1)
#         # Stop if end token is generated
#         if next_token.item() == config.end_token:
#             break
#     # Convert tokens to text
#     generated_sequence = config.tokenizer.decode(current_seq[0].tolist())
#     return generated_sequence

# Create the Gradio interface
iface = gr.Interface(
    fn=generate_sequence,
    inputs=gr.Textbox(),
    outputs=gr.Textbox(),
    title="Text Generation",
    description="Enter text to generate a continuation",
    allow_flagging="never",
)
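
# launch() starts a local web server; pass share=True to iface.launch() for a temporary public URL.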
if __name__ == "__main__":
    iface.launch()