"""Gradio chat UI for the Sam-reason models with collapsible reasoning output.

The primary model is used until the client-reported usage count exceeds
``USAGE_LIMIT``; after that the fallback model answers instead.
"""

import json
import re
from html import escape

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# ─── Configuration ─────────────────────────────────────────────────────────
PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A3"
FALLBACK_MODEL = "Smilyai-labs/Sam-reason-A1"
USAGE_LIMIT = 5
RESET_MS = 20 * 60 * 1000  # 20 minutes in milliseconds

device = "cuda" if torch.cuda.is_available() else "cpu"

primary_model = primary_tokenizer = None
fallback_model = fallback_tokenizer = None

# Chat-template tags used both when building prompts and when trimming output.
ASSISTANT_TAG = "<|assistant|>"
USER_TAG = "<|user|>"

# NOTE(review): the tag names inside this pattern were missing from the source
# as reviewed (markup appears to have been stripped). `<think>…</think>` is
# reconstructed to match the "/think" system flag — confirm against the model.
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)


# ─── Model Loading ─────────────────────────────────────────────────────────
def _load_pair(model_name):
    """Load the tokenizer and the fp16, eval-mode model for *model_name*."""
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # NOTE(review): float16 is requested even when `device` is "cpu"; verify
    # that the installed torch build supports half precision on CPU.
    model = (
        AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
        .to(device)
        .eval()
    )
    return tokenizer, model


def load_models():
    """Populate the module-level model/tokenizer globals.

    Returns:
        A status string naming the primary and fallback models.
    """
    global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer
    primary_tokenizer, primary_model = _load_pair(PRIMARY_MODEL)
    fallback_tokenizer, fallback_model = _load_pair(FALLBACK_MODEL)
    return f"✅ Loaded {PRIMARY_MODEL} (fallback: {FALLBACK_MODEL})"


# ─── Build Chat Prompt ──────────────────────────────────────────────────────
def build_chat_prompt(history, user_input, reasoning_enabled):
    """Serialize the conversation into the tagged prompt format.

    Args:
        history: Iterable of ``(user_text, assistant_text)`` pairs.
        user_input: The new user message.
        reasoning_enabled: Selects the ``/think`` or ``/no_think`` system flag.

    Returns:
        The prompt string, ending with an open assistant turn.
    """
    system_flag = "/think" if reasoning_enabled else "/no_think"
    parts = [f"<|system|>\n{system_flag}\n"]
    for user_text, assistant_text in history:
        # A missing reply (None) is serialized as an empty assistant turn.
        reply = "" if assistant_text is None else assistant_text
        parts.append(f"{USER_TAG}\n{user_text}\n{ASSISTANT_TAG}\n{reply}\n")
    parts.append(f"{USER_TAG}\n{user_input}\n{ASSISTANT_TAG}\n")
    return "".join(parts)


# ─── Collapse <think> Blocks ────────────────────────────────────────────────
def format_thinking(text):
    """Return HTML for *text* with the first reasoning block made collapsible.

    Text without a complete reasoning block is returned HTML-escaped. Otherwise
    every reasoning block is replaced by ``[thinking...]`` in the visible part
    and the first block's content is appended inside a ``<details>`` element.
    """
    match = _THINK_RE.search(text)
    if match is None:
        return escape(text)

    reasoning = escape(match.group(1).strip())
    visible = _THINK_RE.sub("[thinking...]", text).strip()
    # NOTE(review): this markup was missing from the source as reviewed; it is
    # reconstructed from the surviving summary label — confirm the exact tags.
    return (
        escape(visible)
        + "<details><summary>🧠 Show reasoning</summary>"
        + f"<pre>{reasoning}</pre></details>"
    )


# ─── Token-by-Token Streaming (Stops on <|user|>) ─────────────────────────
def _sample_next_token(logits, temperature, top_p):
    """Sample one token id per row from *logits* using temperature and top-p.

    A non-positive temperature selects the arg-max token instead of dividing
    by zero.
    """
    if temperature <= 0:
        return logits.argmax(dim=-1, keepdim=True)

    scaled = logits / temperature
    sorted_logits, sorted_idx = torch.sort(scaled, descending=True)
    cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

    # Drop tokens once the cumulative mass exceeds top_p, shifted right by one
    # so the token that crosses the threshold (and always the top token) stays.
    drop_sorted = cumulative > top_p
    drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
    drop_sorted[..., 0] = False

    # Map the mask from sorted order back to vocabulary order, row by row.
    drop = drop_sorted.scatter(-1, sorted_idx, drop_sorted)
    filtered = scaled.masked_fill(drop, float("-inf"))
    return torch.multinomial(torch.softmax(filtered, dim=-1), 1)


def generate_stream(prompt, use_fallback=False, max_length=100, temperature=0.2, top_p=0.9):
    """Yield the assistant reply as it grows, one sampled token at a time.

    Generation stops after *max_length* tokens, at the EOS token, or as soon
    as the model starts a new user turn. Each yielded value is the full reply
    so far, not a delta.
    """
    model = fallback_model if use_fallback else primary_model
    tokenizer = fallback_tokenizer if use_fallback else primary_tokenizer

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    prompt_len = input_ids.shape[-1]
    generated = input_ids

    for _ in range(max_length):
        # Inference only: skip autograd bookkeeping for the forward pass. The
        # context is closed before any `yield` so grad mode never leaks out.
        with torch.no_grad():
            # Sample in float32; fp16 softmax/cumsum over a vocabulary is lossy.
            logits = model(generated).logits[:, -1, :].float()
        next_token = _sample_next_token(logits, temperature, top_p)

        # Stop before the EOS token's text can reach the user.
        if next_token.item() == tokenizer.eos_token_id:
            break
        generated = torch.cat([generated, next_token], dim=-1)

        # Decode the whole continuation rather than a single token so that
        # characters split across tokens are reassembled correctly.
        text = tokenizer.decode(generated[0, prompt_len:], skip_special_tokens=False)
        if text.startswith(ASSISTANT_TAG):
            text = text[len(ASSISTANT_TAG):]

        # A user tag means the model began the next turn: truncate and stop.
        reply, user_tag, _ = text.partition(USER_TAG)
        if user_tag:
            yield reply
            break
        yield text


# ─── Main Chat Handler ──────────────────────────────────────────────────────
def _parse_usage_count(limit_json):
    """Return the ``count`` field of the client usage JSON, or 0 if unusable."""
    if not limit_json:
        return 0
    try:
        info = json.loads(limit_json)
    except (TypeError, ValueError):
        # Client-supplied text: malformed JSON must not break the chat.
        return 0
    if not isinstance(info, dict):
        return 0
    try:
        return int(info.get("count", 0))
    except (TypeError, ValueError, OverflowError):
        return 0


def _render_reply(raw_reply, model_label):
    """Return the chat-bubble HTML for *raw_reply*, tagged with the model label."""
    # NOTE(review): the markup between the reply and the label was missing from
    # the source as reviewed; `<br><sub>` is a reconstruction — confirm.
    return f"{format_thinking(raw_reply)}<br><sub>({model_label})</sub>"


def respond(message, history, reasoning_enabled, limit_json, display_history=None):
    """Stream a reply to *message*, yielding UI updates.

    Args:
        message: Text from the input box.
        history: Raw ``[user, assistant]`` pairs; this is what the next prompt
            is built from, so it holds unformatted model output only.
        reasoning_enabled: State of the "Reason" checkbox.
        limit_json: Client usage JSON; its ``count`` field selects the model.
        display_history: Pairs currently shown in the chatbot. When omitted,
            they are rendered from *history*.

    Yields:
        ``(chatbot_pairs, raw_history, usage_label, button_label)`` tuples.
    """
    count = _parse_usage_count(limit_json)
    use_fallback = count > USAGE_LIMIT
    remaining = max(0, USAGE_LIMIT - count)
    model_label = "Fallback A1" if use_fallback else "A3"
    counter_text = f"🧠 A3 left: {remaining}"

    raw = list(history or [])
    if display_history is None:
        shown = [[user_text, format_thinking(reply or "")] for user_text, reply in raw]
    else:
        shown = [list(turn) for turn in display_history]

    user_text = (message or "").strip()
    if not user_text:
        # Nothing to send: leave the conversation as it is.
        yield shown, raw, counter_text, "Send"
        return

    # initial yield to set "Generating…"
    prompt = build_chat_prompt(raw, user_text, reasoning_enabled)
    raw = raw + [[message, ""]]
    shown = shown + [[message, ""]]
    yield shown, raw, counter_text, "Generating…"

    # stream assistant reply
    for chunk in generate_stream(prompt, use_fallback):
        # Keep the raw text for future prompts; only the chatbot gets HTML.
        raw[-1][1] = chunk
        shown[-1][1] = _render_reply(chunk, model_label)
        yield shown, raw, counter_text, "Generating…"

    # final yield resets button text
    yield shown, raw, counter_text, "Send"


# ─── Clear Chat ─────────────────────────────────────────────────────────────
def clear_chat():
    """Return empty chat contents, a full usage counter and the idle button label."""
    return [], [], f"🧠 A3 left: {USAGE_LIMIT}", "Send"


# ─── Gradio UI ──────────────────────────────────────────────────────────────
with gr.Blocks() as demo:
    # Inject client-side JS + CSS
    # NOTE(review): the markup inside this HTML block is not present in the
    # source as reviewed, so nothing fills `limit_json` — TODO restore the
    # client-side script/style. Also verify that scripts placed in `gr.HTML`
    # execute; Gradio documents `js=` / `head=` on `gr.Blocks` for custom JS.
    gr.HTML(f""" """)

    gr.Markdown("# 🤖 SamAI – Chat Reasoning (Final)")

    # Hidden textbox ferrying usage JSON from JS → Python
    limit_json = gr.Textbox(visible=False, elem_id="limit_json")
    model_status = gr.Textbox(interactive=False, label="Model Status")
    usage_counter = gr.Textbox(f"🧠 A3 left: {USAGE_LIMIT}", interactive=False, show_label=False)

    chat_box = gr.Chatbot(type="tuples")
    # Raw (unformatted) conversation pairs used to build prompts.
    chat_state = gr.State([])

    with gr.Row():
        user_input = gr.Textbox(placeholder="Ask anything...", show_label=False, scale=6)
        reason_toggle = gr.Checkbox(label="Reason", value=True, scale=1)
        send_btn = gr.Button("Send", elem_id="send_btn", elem_classes=["send-circle"], scale=1)

    clear_btn = gr.Button("Clear")

    model_status.value = load_models()

    # Bind Send button -> respond()
    send_btn.click(
        fn=respond,
        inputs=[user_input, chat_state, reason_toggle, limit_json, chat_box],
        outputs=[chat_box, chat_state, usage_counter, send_btn],
    )
    clear_btn.click(
        fn=clear_chat,
        inputs=[],
        outputs=[chat_box, chat_state, usage_counter, send_btn],
    )

demo.queue()
demo.launch()