"""Gradio chat demo that streams SmilyAI model output token by token.

Each client (keyed by the host reported on the incoming request) is served by
the primary model for its first ``USAGE_LIMIT`` messages and by the fallback
model afterwards.
"""

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Model definitions
PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A1"
FALLBACK_MODEL = "Smilyai-labs/Sam-reason-S2.1"
# Number of messages a client may send before being moved to the fallback model.
USAGE_LIMIT = 10

device = "cuda" if torch.cuda.is_available() else "cpu"

# Globals for models and tokenizers (populated by load_models()).
primary_model, primary_tokenizer = None, None
fallback_model, fallback_tokenizer = None, None

# IP-based usage tracking: host string -> number of messages received so far.
usage_counts = {}


def load_models():
    """Load both models and their tokenizers into the module-level globals.

    Models are moved to ``device`` and switched to eval mode.

    Returns:
        A human-readable status string naming the two loaded models.
    """
    global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer

    primary_tokenizer = AutoTokenizer.from_pretrained(PRIMARY_MODEL)
    primary_model = AutoModelForCausalLM.from_pretrained(PRIMARY_MODEL).to(device).eval()

    fallback_tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
    fallback_model = AutoModelForCausalLM.from_pretrained(FALLBACK_MODEL).to(device).eval()

    return f"Models loaded: {PRIMARY_MODEL} + fallback {FALLBACK_MODEL}"


def generate_stream(prompt, use_fallback=False, max_length=100, temperature=0.7, top_p=0.9):
    """Sample a continuation of ``prompt`` one token at a time (nucleus sampling).

    Args:
        prompt: Text to continue.
        use_fallback: Use the fallback model/tokenizer instead of the primary.
        max_length: Maximum number of new tokens to sample.
        temperature: Softmax temperature; must be strictly positive.
        top_p: Cumulative-probability cutoff for nucleus (top-p) filtering.

    Yields:
        The decoded text of the prompt plus everything generated so far, once
        per sampled token. Nothing is yielded if the very first sampled token
        is the end-of-sequence token.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        # Dividing logits by 0 (or a negative value) yields inf/NaN or inverted
        # probabilities; fail with a clear message instead.
        raise ValueError("temperature must be > 0")

    model = fallback_model if use_fallback else primary_model
    tokenizer = fallback_tokenizer if use_fallback else primary_tokenizer
    eos_token_id = tokenizer.eos_token_id

    generated = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    for _ in range(max_length):
        # eval() does not disable autograd; without no_grad every step would
        # retain a graph over the whole sequence. The context is kept strictly
        # around the tensor work so it is never held across a ``yield``.
        with torch.no_grad():
            logits = model(generated).logits[:, -1, :] / temperature

            # Nucleus filtering: drop the low-probability tail whose cumulative
            # mass exceeds top_p, always keeping the single most likely token.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
            drop_sorted = cumulative > top_p
            # Shift right so the first token crossing the threshold is kept.
            drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
            drop_sorted[..., 0] = False
            # Map the mask from sorted order back to vocabulary order per row.
            drop = torch.zeros_like(drop_sorted).scatter(-1, sorted_indices, drop_sorted)
            filtered = logits.masked_fill(drop, float("-inf"))

            next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)

        # Stop before emitting the end-of-sequence marker so its literal text
        # never reaches the chat window.
        if eos_token_id is not None and next_token.item() == eos_token_id:
            break

        generated = torch.cat([generated, next_token], dim=-1)
        # Decode the full sequence rather than concatenating per-token decodes:
        # a single token may be only part of a multi-byte character or may lose
        # its leading space when decoded in isolation.
        yield tokenizer.decode(generated[0])


def respond(msg, history, reasoning_enabled, request: gr.Request):
    """Stream a model reply for ``msg`` and append it to the chat history.

    Args:
        msg: The user's message text.
        history: Current chat history as a list of ``[user, bot]`` pairs.
        reasoning_enabled: When true the prompt is prefixed with ``/think``,
            otherwise with ``/no_think``.
        request: The incoming Gradio request, used to identify the client.

    Yields:
        ``(history, history)`` tuples: one value for the chatbot component and
        one for the state component.
    """
    # Ignore blank submissions: do not count them or run the model on them.
    if not msg or not msg.strip():
        yield history, history
        return

    client = getattr(request, "client", None) if request else None
    ip = client.host if client else "unknown"

    usage_counts[ip] = usage_counts.get(ip, 0) + 1
    use_fallback = usage_counts[ip] > USAGE_LIMIT
    model_used = "A1" if not use_fallback else "Fallback S2.1"

    prefix = "/think " if reasoning_enabled else "/no_think "
    prompt = prefix + msg.strip()

    history = history + [[msg, ""]]
    # Show the user's turn right away; this also guarantees one UI update even
    # if the model produces no tokens.
    yield history, history

    for output in generate_stream(prompt, use_fallback):
        history[-1][1] = output + f" ({model_used})"
        yield history, history


def clear_chat():
    """Return empty values for the chatbot component and the history state."""
    return [], []


# Load the models before building the UI so the status can be passed to the
# textbox through its constructor.
load_status = load_models()

with gr.Blocks() as demo:
    gr.Markdown("# 🤖 SmilyAI Reasoning Chat • Token-by-Token + IP Usage Limits")

    model_status = gr.Textbox(label="Model Load Status", value=load_status, interactive=False)
    chat_box = gr.Chatbot(label="Chat", type="tuples")
    chat_state = gr.State([])

    with gr.Row():
        user_input = gr.Textbox(placeholder="Your message here...", show_label=False, scale=6)
        reason_toggle = gr.Checkbox(label="Reason", value=True, scale=1)
        send_btn = gr.Button("Send", scale=1)

    clear_btn = gr.Button("Clear Chat")

    send_btn.click(
        respond,
        inputs=[user_input, chat_state, reason_toggle],
        outputs=[chat_box, chat_state]
    )
    clear_btn.click(fn=clear_chat, inputs=[], outputs=[chat_box, chat_state])

demo.queue()
demo.launch()