import gradio as gr
import torch
import tiktoken

from model import DecoderTransformer
from config import Config
from inference import predict
from utils import get_device


def generate_sequence(text):
    """Generate a continuation of the input text with the trained decoder transformer."""
    config = Config()
    device = get_device()

    # Load the trained weights; map_location keeps the checkpoint loadable on CPU-only machines.
    model = DecoderTransformer(config)
    state_dict = torch.load(config.saved_model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    # Encode the prompt with the GPT-2 BPE tokenizer and add a batch dimension: shape (1, T).
    enc = tiktoken.get_encoding("gpt2")
    tokens = enc.encode(text)
    input_tensor = torch.tensor(tokens, device=device).view(1, -1)

    # Autoregressively extend the prompt by up to max_output_len new tokens.
    max_output_len = 30
    y = predict(input_tensor, model, max_output_len=max_output_len)

    # Decode the returned token ids back to text.
    output_tokens = y[0, :].tolist()
    return enc.decode(output_tokens)
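# NOTE: the checkpoint is reloaded on every call to generate_sequence; for a long-running
# app the model could be loaded once at module scope and reused across requests.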
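# Minimal Gradio UI: one textbox for the prompt and one for the generated continuation.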
iface = gr.Interface(
    fn=generate_sequence,
    inputs=gr.Textbox(),
    outputs=gr.Textbox(),
    title="Text Generation",
    description="Enter text to generate a continuation",
    allow_flagging="never",
)


if __name__ == "__main__":
    iface.launch()