# File size: 3,549 Bytes
# 24a676b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face model ids: the primary checkpoint, and the one that answers
# once a client has used up its quota.
PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A1"
FALLBACK_MODEL = "Smilyai-labs/Sam-reason-S2.1"
# Number of requests per client IP answered by PRIMARY_MODEL before switching.
USAGE_LIMIT = 10

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Model/tokenizer handles; all stay None until load_models() fills them in.
primary_model = primary_tokenizer = None
fallback_model = fallback_tokenizer = None

# Per-IP request counters, kept only in process memory.
usage_counts = {}

def load_models():
    """Load both checkpoints and publish them through the module globals.

    The primary pair is loaded first, and each tokenizer before its model.
    Models are moved to ``device`` and put into evaluation mode.

    Returns:
        A human-readable status string naming both checkpoints.
    """
    global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer

    def _prepared(checkpoint):
        # Load the weights, place them on the selected device, switch to eval mode.
        return AutoModelForCausalLM.from_pretrained(checkpoint).to(device).eval()

    primary_tokenizer = AutoTokenizer.from_pretrained(PRIMARY_MODEL)
    primary_model = _prepared(PRIMARY_MODEL)
    fallback_tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
    fallback_model = _prepared(FALLBACK_MODEL)

    status = "Models loaded: {} + fallback {}".format(PRIMARY_MODEL, FALLBACK_MODEL)
    return status

def _top_p_filter(logits, top_p):
    """Return a copy of ``logits`` with every token outside the top-p nucleus set to -inf.

    Tokens are ranked by probability. The smallest prefix whose cumulative
    probability exceeds ``top_p`` is kept; the most likely token is always
    kept. The mask is computed per row, so any leading batch dimensions work.

    Args:
        logits: Tensor whose last dimension indexes the vocabulary.
        top_p: Cumulative-probability threshold.

    Returns:
        A new tensor with the same shape as ``logits``.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    sorted_remove = cum_probs > top_p
    # Shift right so the token that first crosses the threshold is kept too.
    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
    sorted_remove[..., 0] = False
    # Map the mask from sorted order back to vocabulary order, row by row.
    remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
    return logits.masked_fill(remove, float("-inf"))


def _sample_next_token(logits, temperature, top_p):
    """Pick the next token id(s) from last-position logits of shape (batch, vocab)."""
    if temperature <= 0:
        # Greedy decoding; this also avoids dividing by zero below.
        return logits.argmax(dim=-1, keepdim=True)
    filtered = _top_p_filter(logits / temperature, top_p)
    return torch.multinomial(torch.softmax(filtered, dim=-1), num_samples=1)


def generate_stream(prompt, use_fallback=False, max_length=100, temperature=0.7, top_p=0.9):
    """Stream a sampled continuation of ``prompt``, one token at a time.

    Args:
        prompt: Text fed to the model.
        use_fallback: Use the fallback model/tokenizer instead of the primary pair.
        max_length: Maximum number of new tokens to generate.
        temperature: Softmax temperature; ``<= 0`` selects greedy decoding.
        top_p: Nucleus-sampling threshold.

    Yields:
        The decoded prompt followed by all text generated so far, after each
        step. The EOS token ends the stream and is not included in the text.
    """
    model = fallback_model if use_fallback else primary_model
    tokenizer = fallback_tokenizer if use_fallback else primary_tokenizer
    generated = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    eos_id = tokenizer.eos_token_id
    past = None
    step_input = generated

    for _ in range(max_length):
        # Gradients are disabled per step instead of around the whole loop.
        # Grad mode is thread-local, and a context held open across ``yield``
        # would stay active in the consumer while the generator is suspended.
        with torch.no_grad():
            outputs = model(input_ids=step_input, past_key_values=past, use_cache=True)
            next_token = _sample_next_token(outputs.logits[:, -1, :], temperature, top_p)

        if eos_id is not None and next_token.item() == eos_id:
            # Stop without appending EOS so its literal text never reaches the
            # caller. Still yield, so the caller always receives at least one value.
            yield tokenizer.decode(generated[0])
            break

        generated = torch.cat([generated, next_token], dim=-1)
        # With a KV cache only the new token needs to be fed next. If the model
        # returned no cache, feed the full sequence again.
        past = getattr(outputs, "past_key_values", None)
        step_input = next_token if past is not None else generated
        # Decode the whole sequence, not just the new token, so that characters
        # and word-boundary spaces split across tokens render correctly.
        yield tokenizer.decode(generated[0])

def respond(msg, history, reasoning_enabled, request: gr.Request):
    """Stream a model reply to ``msg`` into the chat history.

    Each non-blank message increments the sender IP's entry in
    ``usage_counts``. Once that count exceeds ``USAGE_LIMIT``, the fallback
    model answers instead of the primary one.

    Args:
        msg: The user's message text.
        history: Current chat history as ``[user, bot]`` pairs.
        reasoning_enabled: If true, prefix the prompt with ``/think``;
            otherwise prefix it with ``/no_think``.
        request: Injected by Gradio; used only to read the client IP.

    Yields:
        ``(history, history)``: the updated history for both the Chatbot
        display and the session state.
    """
    text = (msg or "").strip()
    if not text:
        # Blank input: leave the chat unchanged, use no quota, skip the model.
        yield history, history
        return

    # NOTE(review): behind a reverse proxy this may be the proxy's address. In
    # that case all users would share one counter; confirm for the deployment.
    client = getattr(request, "client", None)
    ip = getattr(client, "host", None) or "unknown"
    usage_counts[ip] = usage_counts.get(ip, 0) + 1
    use_fallback = usage_counts[ip] > USAGE_LIMIT
    model_used = "Fallback S2.1" if use_fallback else "A1"
    prefix = "/think " if reasoning_enabled else "/no_think "
    prompt = prefix + text

    # Build a new list rather than mutating the caller's history in place.
    history = history + [[msg, ""]]
    for output in generate_stream(prompt, use_fallback):
        history[-1][1] = output + f" ({model_used})"
        yield history, history

def clear_chat():
    """Return fresh, empty histories for the Chatbot display and the session state."""
    display_history = []
    state_history = []
    return display_history, state_history

with gr.Blocks() as demo:
    gr.Markdown("# 🤖 SmilyAI Reasoning Chat • Token-by-Token + IP Usage Limits")

    # Both models load once, at import time, while the layout is built. The
    # returned status string becomes the textbox's initial value.
    model_status = gr.Textbox(value=load_models(), label="Model Load Status", interactive=False)
    chat_box = gr.Chatbot(label="Chat", type="tuples")
    chat_state = gr.State([])

    with gr.Row():
        user_input = gr.Textbox(placeholder="Your message here...", show_label=False, scale=6)
        reason_toggle = gr.Checkbox(label="Reason", value=True, scale=1)
        send_btn = gr.Button("Send", scale=1)

    clear_btn = gr.Button("Clear Chat")

    send_inputs = [user_input, chat_state, reason_toggle]
    send_outputs = [chat_box, chat_state]
    send_btn.click(respond, inputs=send_inputs, outputs=send_outputs)
    # Pressing Enter in the message box sends the message, same as clicking "Send".
    user_input.submit(respond, inputs=send_inputs, outputs=send_outputs)

    clear_btn.click(fn=clear_chat, inputs=[], outputs=[chat_box, chat_state])

demo.queue()
demo.launch()