File size: 3,495 Bytes
99f24ee
 
 
 
 
 
 
da0b79d
99f24ee
 
 
 
 
 
 
 
 
 
 
afe17cd
99f24ee
 
 
 
 
1046e61
99f24ee
1046e61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99f24ee
 
 
1046e61
99f24ee
 
1046e61
4aab56b
99f24ee
1046e61
 
 
a23a366
 
 
1046e61
 
a23a366
1046e61
99f24ee
 
 
 
1046e61
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import tiktoken
import gradio as gr
from model import GPT, GPTConfig
from torch.nn import functional as F

# Set device
device = 'cpu'

# Load the tokenizer
TOKENIZER = tiktoken.get_encoding('gpt2')

# Load untrained model
UNTRAINED_MODEL = GPT(GPTConfig)
UNTRAINED_MODEL.to(device)
UNTRAINED_MODEL.eval()

# Load fine-tuned model 
TRAINED_MODEL = GPT(GPTConfig)
checkpoint = torch.load("model_19072.pt", weights_only=False, map_location=torch.device('cpu'))
TRAINED_MODEL.load_state_dict(checkpoint["model"])
TRAINED_MODEL.to(device)
TRAINED_MODEL.eval()


def generate_text(input, max_length=30, top_k=50):
    tokens = TOKENIZER.encode(input)
    x_untrained = torch.tensor([tokens], dtype=torch.long).to(device)
    x_trained   = torch.tensor([tokens], dtype=torch.long).to(device)

    # Iterate until one of the sequences reaches max_length
    while (x_untrained.size(1) < max_length) or (x_trained.size(1) < max_length):

        # --- Untrained Model Forward Pass ---
        if x_untrained.size(1) < max_length:
            with torch.no_grad():
                logits_u, _ = UNTRAINED_MODEL(x_untrained)
                logits_u = logits_u[:, -1, :]
                probs_u = F.softmax(logits_u, dim=-1)
                topk_probs_u, topk_indices_u = torch.topk(probs_u, top_k, dim=-1)
                ix_u = torch.multinomial(topk_probs_u, 1)
                next_token_u = torch.gather(topk_indices_u, -1, ix_u)
            x_untrained = torch.cat((x_untrained, next_token_u), dim=1)

        # --- Trained Model Forward Pass ---
        if x_trained.size(1) < max_length:
            with torch.no_grad():
                logits_t, _ = TRAINED_MODEL(x_trained)
                logits_t = logits_t[:, -1, :]
                probs_t = F.softmax(logits_t, dim=-1)
                topk_probs_t, topk_indices_t = torch.topk(probs_t, top_k, dim=-1)
                ix_t = torch.multinomial(topk_probs_t, 1)
                next_token_t = torch.gather(topk_indices_t, -1, ix_t)
            x_trained = torch.cat((x_trained, next_token_t), dim=1)

        # --- Decode the partial text for each model ---
        untrained_text = TOKENIZER.decode(x_untrained[0].tolist())
        trained_text   = TOKENIZER.decode(x_trained[0].tolist())

        yield (untrained_text, trained_text)


def streaming_fn(prompt, max_length=30, top_k=50):
    for untrained_text, trained_text in generate_text(prompt, max_length, top_k):
        output = (
            f"------------ (Untrained Model) ------------\n\n {untrained_text}\n\n\n"
            f"------------ (Trained Model)------------\n\n {trained_text}"
            )
        yield output
    

def main():
    interface = gr.Interface(
        fn=streaming_fn,
        inputs=[
            gr.Textbox(label="Enter your prompt here:"),
            gr.Slider(minimum=10, maximum=150, step=10, label="Max Length"),
            gr.Slider(minimum=1, maximum=50, step=1, label="Top-K Samples")
        ],
        outputs=gr.Textbox(label="Model Outputs"),
        title="GPT-2 Streaming Text Generator",
        description= (
            "Generate text using an untrained and a trained GPT-2 model. "
            "Use prompts that are short, simple and easy to generate coherent looking text. "
            "For example: \n"
            "- \"This is a summary of\" \n"
            "- \"In this article\" \n"
            "- \"Hello, my name is\" \n"
        )
    )
    interface.launch(share=True)

if __name__ == "__main__":
    main()