Keeby-smilyai committed on
Commit 0b0ac5e · verified · 1 Parent(s): aae85c1

Create app.py

Files changed (1)
app.py +370 -0
app.py ADDED
@@ -0,0 +1,370 @@
+ # -------------------------------
+ # app.py - Sam-3.5: The Reasoning AI (Updated Architecture)
+ # -------------------------------
+
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from pathlib import Path
+ from safetensors.torch import load_file
+ from transformers import AutoTokenizer
+ from dataclasses import dataclass
+ from typing import Dict, List
+ import gradio as gr
+ import os
+ from huggingface_hub import hf_hub_download
+ import json
+
+ # -------------------------------
+ # 1) Configuration & Special Tokens
+ # -------------------------------
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Using device: {device}")
+
+ SPECIAL_TOKENS = {
+     "bos": "<|bos|>",
+     "eot": "<|eot|>",
+     "user": "<|user|>",
+     "assistant": "<|assistant|>",
+     "system": "<|system|>",
+     "think": "<|think|>",  # Keep this for reasoning display
+ }
+
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.add_special_tokens({"additional_special_tokens": list(SPECIAL_TOKENS.values())})
+
+ SPECIAL_IDS = {k: tokenizer.convert_tokens_to_ids(v) for k, v in SPECIAL_TOKENS.items()}
+ EOT_ID = SPECIAL_IDS.get("eot", tokenizer.eos_token_id)
+ THINK_ID = SPECIAL_IDS.get("think")
+ assert THINK_ID is not None, "Tokenizer must include <|think|> token"
+
+ MAX_LENGTH = 1024
+
+ # -------------------------------
+ # 2) Model Architecture (Sam-3.5)
+ # -------------------------------
+ class RMSNorm(nn.Module):
+     def __init__(self, d, eps=1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(d))
+     def forward(self, x):
+         return self.weight * x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()
+
+ class MHA(nn.Module):
+     def __init__(self, d_model, n_heads, dropout=0.0):
+         super().__init__()
+         if d_model % n_heads != 0:
+             raise ValueError("d_model must be divisible by n_heads")
+         self.n_heads = n_heads
+         self.head_dim = d_model // n_heads
+         self.q_proj = nn.Linear(d_model, d_model, bias=False)
+         self.k_proj = nn.Linear(d_model, d_model, bias=False)
+         self.v_proj = nn.Linear(d_model, d_model, bias=False)
+         self.out_proj = nn.Linear(d_model, d_model, bias=False)
+         self.dropout = nn.Dropout(dropout)
+     def forward(self, x, attn_mask=None):
+         # attn_mask is accepted for interface compatibility but unused here;
+         # causal masking is applied via is_causal in scaled_dot_product_attention.
+         B, T, C = x.shape
+         q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         out = F.scaled_dot_product_attention(
+             q, k, v,
+             is_causal=True,
+             dropout_p=self.dropout.p if self.training else 0.0
+         )
+         return self.out_proj(out.transpose(1, 2).contiguous().view(B, T, C))
+
+ class SwiGLU(nn.Module):
+     def __init__(self, d_model, d_ff, dropout=0.0):
+         super().__init__()
+         self.w1 = nn.Linear(d_model, d_ff, bias=False)
+         self.w2 = nn.Linear(d_model, d_ff, bias=False)
+         self.w3 = nn.Linear(d_ff, d_model, bias=False)
+         self.dropout = nn.Dropout(dropout)
+     def forward(self, x):
+         return self.w3(self.dropout(F.silu(self.w1(x)) * self.w2(x)))
+
+ class Block(nn.Module):
+     def __init__(self, d_model, n_heads, ff_mult, dropout=0.0):
+         super().__init__()
+         self.norm1 = RMSNorm(d_model)
+         self.attn = MHA(d_model, n_heads, dropout=dropout)
+         self.norm2 = RMSNorm(d_model)
+         self.ff = SwiGLU(d_model, int(ff_mult * d_model), dropout=dropout)
+         self.drop = nn.Dropout(dropout)
+     def forward(self, x, attn_mask=None):
+         x = x + self.drop(self.attn(self.norm1(x), attn_mask=attn_mask))
+         x = x + self.drop(self.ff(self.norm2(x)))
+         return x
+
+ @dataclass
+ class Sam3Config:
+     vocab_size: int
+     d_model: int = 468
+     n_layers: int = 14
+     n_heads: int = 6
+     ff_mult: float = 4.0
+     dropout: float = 0.1
+     input_modality: str = "text"
+     head_type: str = "causal_lm"
+     version: str = "0.1"
+
+ class Sam3(nn.Module):
+     def __init__(self, config: Sam3Config):
+         super().__init__()
+         self.config = config
+         self.embed = nn.Embedding(config.vocab_size, config.d_model)
+         self.blocks = nn.ModuleList([
+             Block(config.d_model, config.n_heads, config.ff_mult, dropout=config.dropout)
+             for _ in range(config.n_layers)
+         ])
+         self.norm = RMSNorm(config.d_model)
+         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+         self.lm_head.weight = self.embed.weight  # Weight tying
+
+     def forward(self, input_ids, attention_mask=None):
+         x = self.embed(input_ids)
+         for blk in self.blocks:
+             x = blk(x, attn_mask=attention_mask)
+         x = self.norm(x)
+         return self.lm_head(x)
+
+ # -------------------------------
+ # 3) Load Model from Hugging Face Hub
+ # -------------------------------
+ def load_sam3_model_from_hf(repo_id: str, filename: str = "sam3-epoch1-best.safetensors"):
+     print(f"📥 Loading config and weights from: {repo_id}")
+     config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
+     weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
+
+     with open(config_path, "r") as f:
+         config_dict = json.load(f)
+
+     # Ensure vocab_size matches tokenizer after adding special tokens
+     config_dict["vocab_size"] = len(tokenizer)
+     config = Sam3Config(**config_dict)
+
+     model = Sam3(config).to(device)
+     state_dict = load_file(weights_path)
+     model.load_state_dict(state_dict, strict=False)
+
+     model.eval()
+     print(f"✅ Model loaded successfully from Hugging Face Hub: {repo_id}")
+     return model
+
+ # Load model
+ model = load_sam3_model_from_hf("Smilyai-labs/Sam-3.5-1")
+
+ # -------------------------------
+ # 4) Sampling Function (Enhanced from your original)
+ # -------------------------------
+ def sample_next_token(
+     logits,
+     past_tokens,
+     temperature=0.8,
+     top_k=60,
+     top_p=0.9,
+     repetition_penalty=1.1,
+     max_repeat=5,
+     no_repeat_ngram_size=3
+ ):
+     # Work on the logits for the last position only.
+     if logits.dim() == 3:
+         logits = logits[:, -1, :].clone()
+     else:
+         logits = logits.clone()
+     batch_size, vocab_size = logits.size(0), logits.size(1)
+     orig_logits = logits.clone()
+
+     if temperature != 1.0:
+         logits = logits / float(temperature)
+
+     past_list = past_tokens.tolist() if isinstance(past_tokens, torch.Tensor) else list(past_tokens)
+
+     # Repetition penalty: penalize already-seen tokens
+     # (divide positive logits, multiply negative ones, so both move away from re-sampling).
+     for token_id in set(past_list):
+         if 0 <= token_id < vocab_size:
+             col = logits[:, token_id]
+             logits[:, token_id] = torch.where(col > 0, col / repetition_penalty, col * repetition_penalty)
+
+     # Hard-block a token that has already repeated max_repeat times in a row.
+     if len(past_list) >= max_repeat:
+         last_token = past_list[-1]
+         count = 1
+         for i in reversed(past_list[:-1]):
+             if i == last_token:
+                 count += 1
+             else:
+                 break
+         if count >= max_repeat:
+             if 0 <= last_token < vocab_size:
+                 logits[:, last_token] = -float("inf")
+
+     # No-repeat n-grams: ban any token that would complete an n-gram already present in the context.
+     if no_repeat_ngram_size > 0 and len(past_list) >= no_repeat_ngram_size:
+         prefix = tuple(past_list[-(no_repeat_ngram_size - 1):])
+         for i in range(len(past_list) - no_repeat_ngram_size + 1):
+             ngram = tuple(past_list[i : i + no_repeat_ngram_size])
+             if ngram[:-1] == prefix:
+                 logits[:, ngram[-1]] = -float("inf")
+
+     # Top-k filtering.
+     if top_k is not None and top_k > 0:
+         tk = min(max(1, int(top_k)), vocab_size)
+         topk_vals, topk_indices = torch.topk(logits, tk, dim=-1)
+         min_topk = topk_vals[:, -1].unsqueeze(-1)
+         logits[logits < min_topk] = -float("inf")
+
+     # Nucleus (top-p) filtering.
+     if top_p is not None and 0.0 < top_p < 1.0:
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+         sorted_probs = F.softmax(sorted_logits, dim=-1)
+         cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+         for b in range(batch_size):
+             sorted_mask = cumulative_probs[b] > top_p
+             if sorted_mask.numel() > 0:
+                 sorted_mask[0] = False
+             tokens_to_remove = sorted_indices[b][sorted_mask]
+             logits[b, tokens_to_remove] = -float("inf")
+
+     # Fall back to the unfiltered distribution if everything got masked out.
+     for b in range(batch_size):
+         if torch.isneginf(logits[b]).all():
+             logits[b] = orig_logits[b]
+
+     probs = F.softmax(logits, dim=-1)
+     if torch.isnan(probs).any():
+         probs = torch.ones_like(logits) / logits.size(1)
+
+     next_token = torch.multinomial(probs, num_samples=1)
+     return next_token.to(device)
+
+ # -------------------------------
+ # 5) Gradio Chat Interface - WITH STYLED THINKING STEPS
+ # -------------------------------
+ def predict(message, history):
+     # Build prompt
+     chat_history = []
+     for human, assistant in history:
+         chat_history.append(f"{SPECIAL_TOKENS['user']} {human} {SPECIAL_TOKENS['eot']}")
+         if assistant:
+             # Assistant responses may contain <|think|>...<|eot|> blocks; we don't reconstruct them here
+             chat_history.append(f"{SPECIAL_TOKENS['assistant']} {assistant} {SPECIAL_TOKENS['eot']}")
+
+     chat_history.append(f"{SPECIAL_TOKENS['user']} {message} {SPECIAL_TOKENS['eot']}")
+
+     system_prompt = "You are Sam-3.5, an advanced reasoning AI. You think step-by-step, analyze deeply, and respond with precision. You do not guess; you deduce. Avoid medical or legal advice."
+     prompt = f"{SPECIAL_TOKENS['system']} {system_prompt} {SPECIAL_TOKENS['eot']}\n" + "\n".join(chat_history) + f"\n{SPECIAL_TOKENS['assistant']} {SPECIAL_TOKENS['think']}"
+
+     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_LENGTH).to(device)
+     input_ids = inputs["input_ids"]
+     attention_mask = inputs["attention_mask"]
+
+     generated_text = ""
+     thinking_mode = False
+     thinking_buffer = ""
+
+     for _ in range(256):
+         with torch.no_grad():
+             logits = model(input_ids, attention_mask=attention_mask)
+         next_token = sample_next_token(
+             logits,
+             input_ids[0],
+             temperature=0.4,
+             top_k=50,
+             top_p=0.9,
+             repetition_penalty=1.1
+         )
+         token_id = int(next_token.squeeze().item())
+         token_str = tokenizer.decode([token_id], skip_special_tokens=False)
+
+         # Append to sequence
+         input_ids = torch.cat([input_ids, next_token], dim=1)
+         attention_mask = torch.cat([attention_mask, torch.ones((1, 1), device=device, dtype=attention_mask.dtype)], dim=1)
+
+         # Handle thinking mode
+         if not thinking_mode and token_str.strip() == "<|think|>":
+             thinking_mode = True
+             thinking_buffer = ""
+             continue
+
+         if thinking_mode:
+             if token_str.strip() == "<|eot|>":
+                 # End of thinking block: keep the styled reasoning as a prefix of the message
+                 # so the answer streaming in afterwards does not overwrite it.
+                 thinking_buffer = thinking_buffer.strip()
+                 if thinking_buffer:
+                     generated_text = f"<div style='background-color:#f8f9fa; padding:12px; border-left:4px solid #007bff; border-radius:0 8px 8px 0; margin:10px 0; font-style:italic; color:#495057; font-size:0.95em;'>💡 <strong>Thinking:</strong> {thinking_buffer}</div>\n"
+                     yield generated_text
+                 thinking_mode = False
+                 continue
+             else:
+                 thinking_buffer += token_str
+                 continue
+
+         # Stop on the final EOT (outside a thinking block) before it gets displayed
+         if token_id == EOT_ID and not thinking_mode:
+             break
+
+         # Normal output
+         if not thinking_mode:
+             # Clean token for display (optional: handle GPT-2 space artifacts)
+             clean_token = token_str.replace('Ġ', ' ').replace('Ċ', '\n')
+             generated_text += clean_token
+             yield generated_text
+
+ # -------------------------------
+ # 6) Launch Gradio Interface
+ # -------------------------------
+ CSS = """
+ .gradio-container .message-bubble {
+     border-radius: 12px !important;
+     padding: 10px 14px !important;
+     font-size: 16px !important;
+ }
+ .gradio-container .message-bubble.user {
+     background-color: #007bff !important;
+     color: white !important;
+ }
+ .gradio-container .message-bubble.assistant {
+     background-color: #f8f9fa !important;
+     color: #212529 !important;
+     border: 1px solid #e9ecef;
+ }
+ """
+
+ demo = gr.ChatInterface(
+     fn=predict,
+     title="🧠 Sam-3.5: The Reasoning AI",
+     description="""
+ Sam-3.5 doesn't just answer; it **thinks first**.
+ Watch its internal reasoning unfold in real time, step by step, clearly shown.
+ No guessing. No fluff. Just pure deduction.
+
+ Try asking:
+ → "Why does a mirror reverse left and right but not up and down?"
+ → "If I have 3 apples and give away half, then buy 5 more, how many do I have?"
+ → "Explain quantum entanglement like I'm 10."
+ → "What's wrong with this argument: 'All birds fly; penguins are birds; therefore penguins can fly'?"
+ """,
+     theme=gr.themes.Soft(
+         primary_hue="indigo",
+         secondary_hue="blue"
+     ),
+     chatbot=gr.Chatbot(
+         label="Sam-3.5 🤔",
+         bubble_full_width=False,
+         height=600,
+         avatar_images=(
+             "https://huggingface.co/datasets/huggingface/branding/resolve/main/avatar-bot.jpg",
+             "https://huggingface.co/datasets/huggingface/branding/resolve/main/avatar-user.jpg"
+         )
+     ),
+     examples=[
+         "What is the capital of France?",
+         "Explain why the sky is blue.",
+         "If a train leaves at 2 PM going 60 mph, and another leaves 30 minutes later at 80 mph, when does the second catch up?",
+         "What are the ethical implications of AI making medical diagnoses?"
+     ],
+     css=CSS,
+     cache_examples=False
+ )
+
+ if __name__ == "__main__":
+     demo.launch(show_api=True)
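
For readers skimming the diff, here is a small illustration of the chat template that predict() builds before generation. It is a standalone sketch that only mirrors the string formatting above; the tokenizer, model, and sampling parameters are not needed for it, and the example question is arbitrary.

# Standalone sketch of the prompt layout produced in predict() above.
SPECIAL = {"eot": "<|eot|>", "user": "<|user|>", "assistant": "<|assistant|>",
           "system": "<|system|>", "think": "<|think|>"}

system_prompt = "You are Sam-3.5, an advanced reasoning AI."
turns = [f"{SPECIAL['user']} Why is the sky blue? {SPECIAL['eot']}"]

prompt = (
    f"{SPECIAL['system']} {system_prompt} {SPECIAL['eot']}\n"
    + "\n".join(turns)
    + f"\n{SPECIAL['assistant']} {SPECIAL['think']}"
)
print(prompt)
# <|system|> You are Sam-3.5, an advanced reasoning AI. <|eot|>
# <|user|> Why is the sky blue? <|eot|>
# <|assistant|> <|think|>

The trailing "<|assistant|> <|think|>" primes the model to emit its reasoning first, which the streaming loop then renders as the styled "Thinking" block before the visible answer.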
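A quick shape-level check of sample_next_token can help when tweaking the filters. This is a minimal sketch that assumes it runs in the same process as the definitions above (for example, appended temporarily under the __main__ guard); the dummy vocabulary size and sampling parameters are arbitrary.

# Hypothetical smoke test: random logits and a short context, expect one sampled id per batch row.
dummy_vocab = 128
dummy_logits = torch.randn(1, 10, dummy_vocab)        # (batch, seq, vocab)
dummy_past = torch.randint(0, dummy_vocab, (10,))     # token ids "already generated"
tok = sample_next_token(dummy_logits, dummy_past, temperature=0.8, top_k=20, top_p=0.9)
assert tok.shape == (1, 1) and 0 <= int(tok) < dummy_vocab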
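Because predict() is a generator that yields the progressively longer assistant message, it can also be exercised without the Gradio UI, for example from a Python shell after importing the module (which downloads and loads the model). The empty history below matches the (user, assistant) tuple format the function iterates over; this is a sketch of local usage, not an API the Space exposes.

# Hypothetical console check of the streaming generator; each yielded value is the full message so far.
last = ""
for partial in predict("What is the capital of France?", history=[]):
    last = partial
print(last)  # final assistant message, possibly prefixed by the styled "Thinking" block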