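"""Gradio demo for Shekswess/trlm-135m, a small 135M-parameter reasoning language model."""
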
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "Shekswess/trlm-135m"

# Load tokenizer & model, picking the GPU if one is available
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision on GPU keeps memory use low; CPU stays in full precision
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
model.to(device)
model.eval()


def generate_reply(prompt, max_new_tokens, temperature, top_p):
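    """Format the prompt with the chat template, sample a completion, and return only the new text."""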
    if not prompt.strip():
        return ""

    # Use the model's chat template (as in the README)
    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(text, return_tensors="pt").to(device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            pad_token_id=tokenizer.eos_token_id,
        )

    # Drop the prompt tokens and decode only the completion
    generated_ids = output_ids[0, inputs["input_ids"].shape[1]:]
    decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)

    return decoded.strip()


with gr.Blocks() as demo:
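    # Layout: prompt and sampling controls on the left, model output on the right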
    gr.Markdown("# Tiny Reasoning LM (trlm-135m)\nA small 135M-parameter reasoning model by **Shekswess**.")

    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                lines=8,
                label="Prompt",
                placeholder="Ask a question or give an instruction…",
            )
            max_new_tokens = gr.Slider(
                minimum=16,
                maximum=256,
                value=128,
                step=8,
                label="Max new tokens",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.5,
                value=0.8,
                step=0.05,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-p",
            )
            generate_btn = gr.Button("Generate")

        with gr.Column(scale=4):
            output = gr.Textbox(
                lines=12,
                label="Model Output",
            )

    generate_btn.click(
        fn=generate_reply,
        inputs=[prompt, max_new_tokens, temperature, top_p],
        outputs=[output],
    )

if __name__ == "__main__":
    demo.launch()