Javedalam committed on
Commit e99e68b · verified · 1 Parent(s): b283a26

Create app.py

Files changed (1)
app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
+import torch
+import gradio as gr
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+MODEL_ID = "Shekswess/trlm-135m"
+
+# Load tokenizer & model
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID,
+    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+)
+model.to(device)
+model.eval()
+
+
+def generate_reply(prompt, max_new_tokens, temperature, top_p):
+    if not prompt.strip():
+        return ""
+
+    # Use the model's chat template (as in the README)
+    messages = [{"role": "user", "content": prompt}]
+    text = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True,
+    )
+
+    inputs = tokenizer(text, return_tensors="pt").to(device)
+
+    with torch.no_grad():
+        output_ids = model.generate(
+            **inputs,
+            max_new_tokens=int(max_new_tokens),
+            do_sample=True,
+            temperature=float(temperature),
+            top_p=float(top_p),
+            pad_token_id=tokenizer.eos_token_id,
+        )
+
+    # Drop the prompt tokens and decode only the completion
+    generated_ids = output_ids[0, inputs["input_ids"].shape[1]:]
+    decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+    return decoded.strip()
+
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Tiny Reasoning LM (trlm-135m)\nSmall 135M reasoning model by **Shekswess**.")
+
+    with gr.Row():
+        with gr.Column(scale=3):
+            prompt = gr.Textbox(
+                lines=8,
+                label="Prompt",
+                placeholder="Ask a question or give an instruction…",
+            )
+            max_new_tokens = gr.Slider(
+                minimum=16,
+                maximum=256,
+                value=128,
+                step=8,
+                label="Max new tokens",
+            )
+            temperature = gr.Slider(
+                minimum=0.1,
+                maximum=1.5,
+                value=0.8,
+                step=0.05,
+                label="Temperature",
+            )
+            top_p = gr.Slider(
+                minimum=0.1,
+                maximum=1.0,
+                value=0.9,
+                step=0.05,
+                label="Top-p",
+            )
+            generate_btn = gr.Button("Generate")
+
+        with gr.Column(scale=4):
+            output = gr.Textbox(
+                lines=12,
+                label="Model Output",
+            )
+
+    generate_btn.click(
+        fn=generate_reply,
+        inputs=[prompt, max_new_tokens, temperature, top_p],
+        outputs=[output],
+    )
+
+if __name__ == "__main__":
+    demo.launch()
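
For a quick smoke test of the generation path without launching the Gradio UI, generate_reply can also be called directly. A minimal sketch, assuming app.py is importable from the working directory; the prompt text and sampling values here are illustrative, not part of the commit:

# Sketch: exercise generate_reply() from app.py directly, bypassing the UI.
# Importing app.py still loads the tokenizer and model; the __main__ guard
# only prevents demo.launch() from running on import.
from app import generate_reply

reply = generate_reply(
    prompt="Explain why the sky is blue in two sentences.",  # illustrative prompt
    max_new_tokens=128,
    temperature=0.8,
    top_p=0.9,
)
print(reply)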