AseemD committed
Commit 1046e61 · 1 Parent(s): da0b79d

Added output streaming support for the gradio app.
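Background: the streaming here relies on Gradio's generator support. When the function passed as fn to gr.Interface is a generator, Gradio re-renders the output components on every yield, so partial text appears incrementally. A minimal sketch of that pattern in isolation (the echo_stream function below is illustrative only, not part of this repo):

import time
import gradio as gr

def echo_stream(prompt):
    # Gradio re-renders the output Textbox on every yield,
    # so the text appears to stream word by word.
    text = ""
    for word in prompt.split():
        text += word + " "
        time.sleep(0.1)  # stand-in for per-token model latency
        yield text

demo = gr.Interface(fn=echo_stream, inputs=gr.Textbox(), outputs=gr.Textbox())
demo.launch()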

Files changed (1)
  1. app.py +59 -48
app.py CHANGED
@@ -17,67 +17,78 @@ UNTRAINED_MODEL.eval()
 
 # Load fine-tuned model
 TRAINED_MODEL = GPT(GPTConfig)
-checkpoint = torch.load("model_19072.pt", weights_only=False, map_location=torch.device('cpu'))
+checkpoint = torch.load("log/model_19072.pt", weights_only=False)
 TRAINED_MODEL.load_state_dict(checkpoint["model"])
 TRAINED_MODEL.to(device)
 TRAINED_MODEL.eval()
 
 
-def generate_text(input, model, num_sequences, max_length):
+def generate_text(input, max_length=30, top_k=50):
     tokens = TOKENIZER.encode(input)
-    tokens = torch.tensor(tokens, dtype=torch.long)
-    tokens = tokens.unsqueeze(0).repeat(num_sequences, 1)
-    x = tokens.to(device)
-
-    sentences = []
-    while x.size(1) < max_length:
-        with torch.no_grad():
-            logits, loss = model(x)
-            logits = logits[:, -1, :]
-            probs = F.softmax(logits, dim=-1)
-            topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
-            ix = torch.multinomial(topk_probs, 1)
-            xcol = torch.gather(topk_indices, -1, ix)
-            x = torch.cat((x, xcol), dim=1)
-
-    for i in range(num_sequences):
-        tokens = x[i, :max_length].tolist()
-        decoded = TOKENIZER.decode(tokens)
-        sentences.append(decoded)
-
-    return sentences
-
-
-def gradio_fn(prompt, num_sequences=1, max_length=30):
-    """Generate text using both models."""
-    untrained_texts = generate_text(prompt, UNTRAINED_MODEL, num_sequences, max_length)
-    untrained_output = "\n\n".join(f"> {s}" for s in untrained_texts)
-
-    trained_texts = generate_text(prompt, TRAINED_MODEL, num_sequences, max_length)
-    trained_output = "\n\n".join(f"> {s}" for s in trained_texts)
-
-    return untrained_output, trained_output
-
-
-# Gradio interface
+    x_untrained = torch.tensor([tokens], dtype=torch.long).to(device)
+    x_trained = torch.tensor([tokens], dtype=torch.long).to(device)
+
+    # Iterate until both sequences reach max_length
+    while (x_untrained.size(1) < max_length) or (x_trained.size(1) < max_length):
+
+        # --- Untrained Model Forward Pass ---
+        if x_untrained.size(1) < max_length:
+            with torch.no_grad():
+                logits_u, _ = UNTRAINED_MODEL(x_untrained)
+                logits_u = logits_u[:, -1, :]
+                probs_u = F.softmax(logits_u, dim=-1)
+                topk_probs_u, topk_indices_u = torch.topk(probs_u, top_k, dim=-1)
+                ix_u = torch.multinomial(topk_probs_u, 1)
+                next_token_u = torch.gather(topk_indices_u, -1, ix_u)
+                x_untrained = torch.cat((x_untrained, next_token_u), dim=1)
+
+        # --- Trained Model Forward Pass ---
+        if x_trained.size(1) < max_length:
+            with torch.no_grad():
+                logits_t, _ = TRAINED_MODEL(x_trained)
+                logits_t = logits_t[:, -1, :]
+                probs_t = F.softmax(logits_t, dim=-1)
+                topk_probs_t, topk_indices_t = torch.topk(probs_t, top_k, dim=-1)
+                ix_t = torch.multinomial(topk_probs_t, 1)
+                next_token_t = torch.gather(topk_indices_t, -1, ix_t)
+                x_trained = torch.cat((x_trained, next_token_t), dim=1)
+
+        # --- Decode the partial text for each model ---
+        untrained_text = TOKENIZER.decode(x_untrained[0].tolist())
+        trained_text = TOKENIZER.decode(x_trained[0].tolist())
+
+        yield (untrained_text, trained_text)
+
+
+def streaming_fn(prompt, max_length=30, top_k=50):
+    for untrained_text, trained_text in generate_text(prompt, max_length, top_k):
+        output = (
+            f"------------ (Untrained Model) ------------\n\n {untrained_text}\n\n\n"
+            f"------------ (Trained Model) ------------\n\n {trained_text}"
+        )
+        yield output
+
+
 def main():
     interface = gr.Interface(
-        fn=gradio_fn,
+        fn=streaming_fn,
         inputs=[
             gr.Textbox(label="Enter your prompt here:"),
-            gr.Slider(minimum=1, maximum=10, step=1, label="Number of Generations"),
-            gr.Slider(minimum=10, maximum=100, step=10, label="Max Length"),
-        ],
-        outputs=[
-            gr.Textbox(label="Generated Text (Untrained Model)"),
-            gr.Textbox(label="Generated Text (Trained Model)"),
+            gr.Slider(minimum=10, maximum=150, step=10, label="Max Length"),
+            gr.Slider(minimum=1, maximum=50, step=10, label="Top-K Samples")
         ],
-        title="GPT-2 Text Generator",
-        description="Generate text an untrained and a trained GPT-2 model."
+        outputs=gr.Textbox(label="Model Outputs"),
+        title="GPT-2 Streaming Text Generator",
+        description=(
+            "Generate text using an untrained and a trained GPT-2 model. "
+            "Use short, simple prompts that make it easier to generate coherent-looking text. "
+            "For example:\n"
+            "- \"Hello, my name is\"\n"
+            "- \"This is a summary of\"\n"
+            "- \"In this article\"\n"
+        )
     )
-
     interface.launch(share=True)
 
 if __name__ == "__main__":
-    main()
-
-
+    main()
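For reference, each per-model block in the new loop performs one step of standard top-k sampling; distilled to a single model it reads roughly as follows (x, model, and k stand in for either branch's tensor and parameters):

with torch.no_grad():
    logits, _ = model(x)                    # (batch, seq_len, vocab_size)
logits = logits[:, -1, :]                   # keep only the last position
probs = F.softmax(logits, dim=-1)           # convert logits to probabilities
topk_probs, topk_idx = torch.topk(probs, k, dim=-1)  # keep the k most likely tokens
choice = torch.multinomial(topk_probs, 1)   # sample an index within the top k
next_token = torch.gather(topk_idx, -1, choice)      # map back to vocabulary ids
x = torch.cat((x, next_token), dim=1)       # append and repeat until max_length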