Boning c committed on
Commit 4d45982 · verified · 1 Parent(s): 5283875

Update app.py

Files changed (1)
  app.py  +46 -21
app.py CHANGED
@@ -4,7 +4,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
  import re, json
  from html import escape
 
- # ─── Config ─────────────────────────────────────────────────────────
+ # ─── Configuration ─────────────────────────────────────────────────────────
  PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A3"
  FALLBACK_MODEL = "Smilyai-labs/Sam-reason-A1"
  USAGE_LIMIT = 5
@@ -14,7 +14,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
  primary_model = primary_tokenizer = None
  fallback_model = fallback_tokenizer = None
 
- # ─── Load Models ───────────────────────────────────────────────────
+ # ─── Model Loading ─────────────────────────────────────────────────────────
  def load_models():
      global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer
      primary_tokenizer = AutoTokenizer.from_pretrained(PRIMARY_MODEL, trust_remote_code=True)
@@ -25,9 +25,9 @@ def load_models():
      fallback_model = AutoModelForCausalLM.from_pretrained(
          FALLBACK_MODEL, torch_dtype=torch.float16
      ).to(device).eval()
-     return f"✅ Loaded {PRIMARY_MODEL} with fallback {FALLBACK_MODEL}"
+     return f"✅ Loaded {PRIMARY_MODEL} (fallback: {FALLBACK_MODEL})"
 
- # ─── Build Qwen-style Prompt ──────────────────────────────────────────
+ # ─── Build Chat Prompt ──────────────────────────────────────────────────────
  def build_chat_prompt(history, user_input, reasoning_enabled):
      system_flag = "/think" if reasoning_enabled else "/no_think"
      prompt = f"<|system|>\n{system_flag}\n"
@@ -36,20 +36,20 @@ def build_chat_prompt(history, user_input, reasoning_enabled):
      prompt += f"<|user|>\n{user_input}\n<|assistant|>\n"
      return prompt
 
- # ─── Collapse <think> Blocks ──────────────────────────────────────────
+ # ─── Collapse <think> Blocks ────────────────────────────────────────────────
  def format_thinking(text):
      match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
      if not match:
          return escape(text)
      reasoning = escape(match.group(1).strip())
-     visible = re.sub(r"<think>.*?</think>", "[thinking...]", text, flags=re.DOTALL).strip()
+     visible = re.sub(r"<think>.*?</think>", "[thinking...]", text, flags=re.DOTALL).strip()
      return (
          escape(visible)
          + "<br><details><summary>🧠 Show reasoning</summary>"
          + f"<pre>{reasoning}</pre></details>"
      )
 
- # ─── Token-by-Token Streaming ─────────────────────────────────────────
+ # ─── Token-by-Token Streaming (Stops on <|user|>) ─────────────────────────
  def generate_stream(prompt, use_fallback=False,
                      max_length=100, temperature=0.2, top_p=0.9):
      model = fallback_model if use_fallback else primary_model
@@ -59,7 +59,7 @@ def generate_stream(prompt, use_fallback=False,
      assistant_text = ""
 
      for _ in range(max_length):
-         # 1) Sample next token
+         # 1) Get next-token logits and apply top-p
          logits = model(generated).logits[:, -1, :] / temperature
          sorted_logits, idxs = torch.sort(logits, descending=True)
          probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
@@ -68,40 +68,64 @@ def generate_stream(prompt, use_fallback=False,
          mask[..., 0] = 0
          filtered = logits.clone()
          filtered[:, idxs[mask]] = -float("Inf")
+
+         # 2) Sample and append
          next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)
          generated = torch.cat([generated, next_token], dim=-1)
-
-         # 2) Decode and append
          new_text = tokenizer.decode(next_token[0], skip_special_tokens=False)
          assistant_text += new_text
 
-         # 3) Strip starting <|assistant|> if present
+         # 3) Remove any leading assistant tag
          if assistant_text.startswith("<|assistant|>"):
              assistant_text = assistant_text[len("<|assistant|>"):]
 
-         # 4) If the accumulated text contains the user tag, truncate and stop
+         # 4) If we see a user-turn tag, truncate and bail
          if "<|user|>" in assistant_text:
-             # drop the user tag and anything after it
             assistant_text = assistant_text.split("<|user|>")[0]
              yield assistant_text
              break
 
-         # 5) Otherwise stream the clean assistant text
+         # 5) Otherwise stream clean assistant text
          yield assistant_text
 
-         # 6) Stop on EOS
+         # 6) End if EOS
          if next_token.item() == tokenizer.eos_token_id:
              break
 
+ # ─── Main Chat Handler ──────────────────────────────────────────────────────
+ def respond(message, history, reasoning_enabled, limit_json):
+     # parse client-side usage info
+     info = json.loads(limit_json) if limit_json else {"count": 0}
+     count = info.get("count", 0)
+     use_fallback = count > USAGE_LIMIT
+     remaining = max(0, USAGE_LIMIT - count)
+     model_label = "A3" if not use_fallback else "Fallback A1"
+
+     # initial yield to set "Generating…"
+     prompt = build_chat_prompt(history, message.strip(), reasoning_enabled)
+     history = history + [[message, ""]]
+     yield history, history, f"🧠 A3 left: {remaining}", "Generating…"
+
+     # stream assistant reply
+     for chunk in generate_stream(prompt, use_fallback):
+         formatted = format_thinking(chunk)
+         history[-1][1] = (
+             f"{formatted}<br><sub style='color:gray'>({model_label})</sub>"
+         )
+         yield history, history, f"🧠 A3 left: {remaining}", "Generating…"
+
+     # final yield resets button text
+     yield history, history, f"🧠 A3 left: {remaining}", "Send"
+
+ # ─── Clear Chat ─────────────────────────────────────────────────────────────
  def clear_chat():
      return [], [], "🧠 A3 left: 5", "Send"
 
- # ─── Gradio UI ─────────────────────────────────────────────────────────
+ # ─── Gradio UI ──────────────────────────────────────────────────────────────
  with gr.Blocks() as demo:
      # Inject client-side JS + CSS
      gr.HTML(f"""
      <script>
-      // bump/reset usage in localStorage and write to hidden textbox
      function updateUsageLimit() {{
        const key = "samai_limit";
        const now = Date.now();
@@ -114,7 +138,6 @@ with gr.Blocks() as demo:
        localStorage.setItem(key, JSON.stringify(rec));
        document.getElementById("limit_json").value = JSON.stringify(rec);
      }}
-      // on Send click: update limit & flip button text
      document.addEventListener("DOMContentLoaded", () => {{
        const btn = document.getElementById("send_btn");
        btn.addEventListener("click", () => {{
@@ -133,11 +156,11 @@ with gr.Blocks() as demo:
        text-align: center;
      }}
      </style>
-     """)
+     """)
 
      gr.Markdown("# 🤖 SamAI – Chat Reasoning (Final)")
 
-     # carry usage JSON from JS → Python
+     # Hidden textbox ferrying usage JSON from JS → Python
      limit_json = gr.Textbox(visible=False, elem_id="limit_json")
      model_status = gr.Textbox(interactive=False, label="Model Status")
      usage_counter = gr.Textbox("🧠 A3 left: 5", interactive=False, show_label=False)
@@ -154,14 +177,16 @@ with gr.Blocks() as demo:
 
      model_status.value = load_models()
 
+     # Bind Send button -> respond()
      send_btn.click(
          fn=respond,
          inputs=[user_input, chat_state, reason_toggle, limit_json],
          outputs=[chat_box, chat_state, usage_counter, send_btn]
      )
+
      clear_btn.click(
          fn=clear_chat,
-         inputs=[],
+         inputs=[],
          outputs=[chat_box, chat_state, usage_counter, send_btn]
      )
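
Note on the sampling step touched by the -59 and -68 hunks: the diff only shows part of the top-p filtering in generate_stream, since the construction of the mask falls outside the hunk context. Below is a minimal, self-contained sketch of the nucleus (top-p) sampling step the loop appears to implement. The helper name sample_top_p is hypothetical and not part of app.py, and the line building the mask as probs > top_p is an assumption about the unshown code; the rest mirrors the lines visible in the diff (batch size 1, as in app.py).

import torch

def sample_top_p(logits, temperature=0.2, top_p=0.9):
    # Temperature-scale the raw logits for one position: shape (1, vocab).
    logits = logits / temperature
    # Sort descending so the cumulative probability identifies the nucleus.
    sorted_logits, idxs = torch.sort(logits, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    mask = probs > top_p            # assumed cutoff; not visible in the hunk
    mask[..., 0] = 0                # always keep the single most likely token
    filtered = logits.clone()
    filtered[:, idxs[mask]] = -float("Inf")   # drop tokens outside the nucleus
    # Renormalise and sample one token id.
    return torch.multinomial(torch.softmax(filtered, dim=-1), 1)

# Toy usage: a batch of one over an 8-token vocabulary.
next_token = sample_top_p(torch.randn(1, 8))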